Program Listing for File ckksrns-parametergeneration.cpp
↰ Return to documentation for file (pke/lib/scheme/ckksrns/ckksrns-parametergeneration.cpp)
//==================================================================================
// BSD 2-Clause License
//
// Copyright (c) 2014-2022, NJIT, Duality Technologies Inc. and other contributors
//
// All rights reserved.
//
// Author TPOC: contact@openfhe.org
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//==================================================================================
/*
CKKS implementation. See https://eprint.iacr.org/2020/1118 for details.
*/
#define PROFILE
#include "scheme/ckksrns/ckksrns-cryptoparameters.h"
#include "scheme/ckksrns/ckksrns-parametergeneration.h"

#include <algorithm>
#include <cmath>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_set>
#include <vector>
namespace lbcrypto {
// Default bit size for auxiliary moduli. ParamsGenCKKSRNSInternal passes it on as `auxBits` to
// CryptoParametersRNS::EstimateLogP and PrecomputeCRTTables (unless COMPOSITESCALING* with a smaller
// register word size overrides it). The value is selected at build time from NATIVEINT and is kept
// a few bits below the corresponding native word width.
// NOTE(review): presumably these are the P moduli used for HYBRID key switching — confirm.
#if NATIVEINT == 128
constexpr size_t AUXMODSIZE = 119;
#elif NATIVEINT == 32
constexpr size_t AUXMODSIZE = 28;
#else
// Any NATIVEINT other than 128 or 32 (typically the 64-bit build).
constexpr size_t AUXMODSIZE = 60;
#endif
/**
 * Generates the CKKS RNS parameters and stores them in cryptoParams: the ring dimension, the ciphertext
 * moduli Q with their roots of unity, the default batch size and the CRT precomputations.
 *
 * @param cryptoParams   parameter object to fill in; must be a CryptoParametersCKKSRNS instance.
 * @param cyclOrder      cyclotomic order (ring dimension = cyclOrder / 2). If it is 0, the ring dimension is
 *                       taken from the HE standard tables, which requires a security level to be set.
 * @param numPrimes      number of RNS moduli before expansion by the composite degree.
 * @param scalingModSize bit size of each scaling modulus.
 * @param firstModSize   bit size of the first modulus.
 * @param numPartQ       forwarded to EstimateLogP and PrecomputeCRTTables
 *                       (presumably the number of HYBRID key-switching digits — confirm).
 * @param mPIntBootCiphertextCompressionLevel  not read by this function.
 * @return true (the only return path).
 * @throws OpenFHEException on an invalid parameter object, unsupported settings, an invalid batch size,
 *         or a ring dimension that does not satisfy the requested security level.
 */
bool ParameterGenerationCKKSRNS::ParamsGenCKKSRNSInternal(std::shared_ptr<CryptoParametersBase<DCRTPoly>> cryptoParams,
                                                          uint32_t cyclOrder, uint32_t numPrimes,
                                                          uint32_t scalingModSize, uint32_t firstModSize,
                                                          uint32_t numPartQ,
                                                          CompressionLevel mPIntBootCiphertextCompressionLevel) const {
    // the "const" modifier for cryptoParamsCKKSRNS and encodingParams below doesn't mean that the objects those 2 pointers
    // point to are const (not changeable). it means that the pointers themselves are const only.
    const auto cryptoParamsCKKSRNS = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(cryptoParams);
    // dynamic_pointer_cast yields nullptr for a null input or for a parameter object of another scheme;
    // report it instead of dereferencing a null pointer below.
    if (!cryptoParamsCKKSRNS) {
        OPENFHE_THROW("ParamsGenCKKSRNSInternal: cryptoParams must be a non-null CryptoParametersCKKSRNS object.");
    }

    const EncodingParams encodingParams = cryptoParamsCKKSRNS->GetEncodingParams();
    KeySwitchTechnique ksTech           = cryptoParamsCKKSRNS->GetKeySwitchTechnique();
    ScalingTechnique scalTech           = cryptoParamsCKKSRNS->GetScalingTechnique();
    EncryptionTechnique encTech         = cryptoParamsCKKSRNS->GetEncryptionTechnique();
    MultiplicationTechnique multTech    = cryptoParamsCKKSRNS->GetMultiplicationTechnique();
    ProxyReEncryptionMode PREMode       = cryptoParamsCKKSRNS->GetPREMode();

    // Determine appropriate composite degree automatically if scaling technique set to COMPOSITESCALINGAUTO
    cryptoParamsCKKSRNS->ConfigureCompositeDegree(firstModSize);
    uint32_t compositeDegree  = cryptoParamsCKKSRNS->GetCompositeDegree();
    uint32_t registerWordSize = cryptoParamsCKKSRNS->GetRegisterWordSize();

    if (scalTech == COMPOSITESCALINGAUTO || scalTech == COMPOSITESCALINGMANUAL) {
        // TODO (Duhyeong): We need more exception cases in terms of
        // prime size (= scalingModSize / compositeDegree), registerSize, and numPrimes
        // e.g.1, assertion: prime size < registerSize (we may need at least 1-2 bit gap)
        // e.g.2, prime size > ??? if numPrimes > ???
        if (compositeDegree > 2 && scalingModSize < 60) {
            std::string errorMsg = "Prime moduli size is too small. It must generally be greater than 19,";
            errorMsg += " especially for larger multiplicative depth.";
            errorMsg += " Please increase the scaling factor (scalingModSize) or the register word size.";
            errorMsg += " Also, you can use COMPOSITESCALINGMANUAL at your own risk.";
            OPENFHE_THROW(errorMsg);
        }
        else if (compositeDegree == 1 && registerWordSize < 64) {
            OPENFHE_THROW(
                "This COMPOSITESCALING* version does not support composite degree == 1 with register size < 64.");
        }
        else if (compositeDegree < 1) {
            OPENFHE_THROW("Composite degree must be greater than or equal to 1.");
        }
        if (registerWordSize < 20 && scalTech == COMPOSITESCALINGAUTO) {
            OPENFHE_THROW(
                "Register word size must be greater than or equal to 20 for COMPOSITESCALINGAUTO. Otherwise, try it with COMPOSITESCALINGMANUAL.");
        }
    }

    if ((PREMode != INDCPA) && (PREMode != NOT_SET)) {
        std::stringstream s;
        s << "This PRE mode " << PREMode << " is not supported for CKKSRNS";
        OPENFHE_THROW(s.str());
    }

    // TODO: Allow the user to specify this?
    uint32_t extraModSize = (scalTech == FLEXIBLEAUTOEXT) ? DCRT_MODULUS::DEFAULT_EXTRA_MOD_SIZE : 0;

    SecurityLevel stdLevel = cryptoParamsCKKSRNS->GetStdLevel();

    // TODO Duhyeong: Let's check if auxBits = registerWordSize makes an error in the P prime generation.
    // With composite scaling and a small register word, auxiliary primes must fit in the register.
    uint32_t auxBits =
        ((scalTech == COMPOSITESCALINGAUTO || scalTech == COMPOSITESCALINGMANUAL) && registerWordSize <= AUXMODSIZE) ?
            (registerWordSize - 1) :
            AUXMODSIZE;

    uint32_t n = cyclOrder / 2;

    // GAUSSIAN security constraint
    DistributionType distType = (cryptoParamsCKKSRNS->GetSecretKeyDist() == GAUSSIAN) ? HEStd_error : HEStd_ternary;

    if (stdLevel != HEStd_NotSet) {
        uint32_t qBound = firstModSize + (numPrimes - 1) * scalingModSize + extraModSize;
        // we add an extra bit to account for the alternating logic of selecting the RNS moduli in CKKS
        // ignore the case when there is only one max size modulus
        if (qBound != auxBits)
            ++qBound;

        // Estimate ciphertext modulus Q*P bound (in case of HYBRID P*Q)
        if (ksTech == HYBRID)
            qBound += std::get<0>(CryptoParametersRNS::EstimateLogP(numPartQ, firstModSize, scalingModSize,
                                                                    extraModSize, numPrimes, auxBits, scalTech, true));

        uint32_t he_std_n = StdLatticeParm::FindRingDim(distType, stdLevel, qBound);
        if (n == 0) {
            // Choose ring dimension based on security standards
            n         = he_std_n;
            cyclOrder = 2 * n;
        }
        else {
            // Check whether particular selection is standards-compliant
            if (he_std_n > n) {
                OPENFHE_THROW("The specified ring dimension (" + std::to_string(n) +
                              ") does not comply with HE standards recommendation (" + std::to_string(he_std_n) + ").");
            }
        }
    }
    else if (n == 0) {
        OPENFHE_THROW("Please specify the ring dimension or desired security level.");
    }

    if (encodingParams->GetBatchSize() > n / 2)
        OPENFHE_THROW("The batch size cannot be larger than ring dimension / 2.");

    // x & (x - 1) is non-zero exactly when x has more than one bit set.
    if (encodingParams->GetBatchSize() & (encodingParams->GetBatchSize() - 1))
        OPENFHE_THROW("The batch size can only be set to zero (for full packing) or a power of two.");

    uint32_t dcrtBits = scalingModSize;

    // In COMPOSITESCALING mode, each modulus consists of compositeDegree number of primes
    numPrimes *= compositeDegree;

    // FLEXIBLEAUTOEXT reserves one extra slot at index numPrimes for the extra modulus.
    uint32_t vecSize = (extraModSize == 0) ? numPrimes : numPrimes + 1;
    std::vector<NativeInteger> moduliQ(vecSize);
    std::vector<NativeInteger> rootsQ(vecSize);

    if ((scalTech == COMPOSITESCALINGAUTO || scalTech == COMPOSITESCALINGMANUAL) && (compositeDegree > 1)) {
        CompositePrimeModuliGen(moduliQ, rootsQ, compositeDegree, numPrimes, firstModSize, dcrtBits, cyclOrder,
                                registerWordSize);
    }
    else
        SinglePrimeModuliGen(moduliQ, rootsQ, scalTech, numPrimes, firstModSize, dcrtBits, cyclOrder, extraModSize);

    auto paramsDCRT = std::make_shared<ILDCRTParams<BigInteger>>(cyclOrder, moduliQ, rootsQ);

    cryptoParamsCKKSRNS->SetElementParams(paramsDCRT);

    // if no batch size was specified, we set batchSize = n/2 by default (for full packing)
    if (encodingParams->GetBatchSize() == 0) {
        uint32_t batchSize = n / 2;
        EncodingParams encodingParamsNew(
            std::make_shared<EncodingParamsImpl>(encodingParams->GetPlaintextModulus(), batchSize));
        cryptoParamsCKKSRNS->SetEncodingParams(encodingParamsNew);
    }

    cryptoParamsCKKSRNS->PrecomputeCRTTables(ksTech, scalTech, encTech, multTech, numPartQ, auxBits, extraModSize);

    // Validate the ring dimension found using estimated logQ(P) against actual logQ(P)
    if (stdLevel != HEStd_NotSet) {
        uint32_t logActualQ = (ksTech == HYBRID) ? cryptoParamsCKKSRNS->GetParamsQP()->GetModulus().GetMSB() :
                                                   cryptoParamsCKKSRNS->GetElementParams()->GetModulus().GetMSB();

        uint32_t nActual = StdLatticeParm::FindRingDim(distType, stdLevel, logActualQ);
        if (n < nActual) {
            std::string errMsg("The ring dimension [");
            errMsg += std::to_string(n) + "] does not meet security requirements. ";
            OPENFHE_THROW(errMsg);
        }
    }

    return true;
}
void ParameterGenerationCKKSRNS::CompositePrimeModuliGen(std::vector<NativeInteger>& moduliQ,
std::vector<NativeInteger>& rootsQ, uint32_t compositeDegree,
uint32_t numPrimes, uint32_t firstModSize, uint32_t dcrtBits,
uint32_t cyclOrder, uint32_t registerWordSize) const {
if (firstModSize <= dcrtBits) {
OPENFHE_THROW("firstModSize must be > scalingModSize.");
}
std::unordered_set<uint64_t> moduliQRecord;
for (uint32_t d = 1, remBits = dcrtBits; d <= compositeDegree; ++d) {
uint32_t qBitSize = std::ceil(static_cast<double>(remBits) / (compositeDegree - d + 1));
NativeInteger q = FirstPrime<NativeInteger>(qBitSize, cyclOrder);
q = PreviousPrime<NativeInteger>(q, cyclOrder);
while (std::log2(q.ConvertToDouble()) > registerWordSize || std::log2(q.ConvertToDouble()) > qBitSize ||
moduliQRecord.find(q.ConvertToInt()) != moduliQRecord.end()) {
q = PreviousPrime<NativeInteger>(q, cyclOrder);
}
moduliQ[numPrimes - d] = q;
rootsQ[numPrimes - d] = RootOfUnity(cyclOrder, moduliQ[numPrimes - d]);
moduliQRecord.emplace(q.ConvertToInt());
remBits -= std::ceil(std::log2(q.ConvertToDouble()));
}
const std::string compositeScalingErrMsg =
"COMPOSITE SCALING prime sampling error. Consider increasing the scaling factor or the register word size.";
if (numPrimes > 1) {
std::vector<NativeInteger> qPrev(std::ceil(static_cast<double>(compositeDegree) / 2));
std::vector<NativeInteger> qNext(compositeDegree - static_cast<uint32_t>(qPrev.size()));
// Prep to compute initial scaling factor
double sf = moduliQ[numPrimes - 1].ConvertToDouble();
for (uint32_t d = 2; d <= compositeDegree; ++d) {
sf *= moduliQ[numPrimes - d].ConvertToDouble();
}
bool flag = true;
for (uint32_t i = numPrimes - compositeDegree; i >= 2 * compositeDegree; i -= compositeDegree) {
// Compute initial scaling factor
sf = std::pow(sf, 2);
for (uint32_t d = 0; d < compositeDegree; ++d) {
sf /= moduliQ[i + d].ConvertToDouble();
}
auto sf_sqrt = std::pow(sf, 1.0 / compositeDegree);
NativeInteger sfInt = std::llround(sf_sqrt);
NativeInteger sfRem = sfInt.Mod(cyclOrder);
double primeProduct = 1.0;
std::unordered_set<uint64_t> qCurrentRecord; // current prime tracker
for (size_t step = 0; step < qPrev.size(); ++step) {
qPrev[step] = sfInt - sfRem + NativeInteger(1) - NativeInteger(cyclOrder);
do {
try {
qPrev[step] = lbcrypto::PreviousPrime(qPrev[step], cyclOrder);
}
catch (const OpenFHEException& ex) {
OPENFHE_THROW(compositeScalingErrMsg);
}
} while (std::log2(qPrev[step].ConvertToDouble()) > registerWordSize ||
moduliQRecord.find(qPrev[step].ConvertToInt()) != moduliQRecord.end() ||
qCurrentRecord.find(qPrev[step].ConvertToInt()) != qCurrentRecord.end());
qCurrentRecord.emplace(qPrev[step].ConvertToInt());
primeProduct *= qPrev[step].ConvertToDouble();
}
bool fitsRegister = true;
for (size_t step = 0; step < qNext.size(); ++step) {
qNext[step] = sfInt - sfRem + NativeInteger(1) + NativeInteger(cyclOrder);
do {
try {
if (fitsRegister == true) {
qNext[step] = lbcrypto::NextPrime(qNext[step], cyclOrder);
}
else {
qNext[step] = lbcrypto::PreviousPrime(qNext[step], cyclOrder);
}
}
catch (const OpenFHEException& ex) {
OPENFHE_THROW(compositeScalingErrMsg);
}
if (std::log2(qNext[step].ConvertToDouble()) > registerWordSize) {
fitsRegister = false;
}
} while (std::log2(qNext[step].ConvertToDouble()) > registerWordSize ||
moduliQRecord.find(qNext[step].ConvertToInt()) != moduliQRecord.end() ||
qCurrentRecord.find(qNext[step].ConvertToInt()) != qCurrentRecord.end());
qCurrentRecord.emplace(qNext[step].ConvertToInt());
primeProduct *= qNext[step].ConvertToDouble();
}
if (flag == false) {
NativeInteger qPrevNext = NativeInteger(qNext[qNext.size() - 1].ConvertToInt());
while (primeProduct > sf) {
do {
qCurrentRecord.erase(qPrevNext.ConvertToInt()); // constant time
try {
qPrevNext = lbcrypto::PreviousPrime(qPrevNext, cyclOrder);
}
catch (const OpenFHEException& ex) {
OPENFHE_THROW(compositeScalingErrMsg);
}
} while (std::log2(qPrevNext.ConvertToDouble()) > registerWordSize ||
moduliQRecord.find(qPrevNext.ConvertToInt()) != moduliQRecord.end() ||
qCurrentRecord.find(qPrevNext.ConvertToInt()) != qCurrentRecord.end());
qCurrentRecord.emplace(qPrevNext.ConvertToInt());
primeProduct /= qNext[qNext.size() - 1].ConvertToDouble();
qNext[qNext.size() - 1] = qPrevNext;
primeProduct *= qPrevNext.ConvertToDouble();
}
uint32_t m = qPrev.size();
for (uint32_t d = 1; d <= m; ++d) {
moduliQ[i - d] = qPrev[d - 1];
}
for (uint32_t d = m + 1; d <= compositeDegree; ++d) {
moduliQ[i - d] = qNext[d - (m + 1)];
}
for (uint32_t d = 1; d <= compositeDegree; ++d) {
rootsQ[i - d] = RootOfUnity(cyclOrder, moduliQ[i - d]);
moduliQRecord.emplace(moduliQ[i - d].ConvertToInt());
}
flag = true;
}
else {
NativeInteger qNextPrev = NativeInteger(qPrev[qPrev.size() - 1].ConvertToInt());
fitsRegister = true;
while (primeProduct < sf) {
do {
qCurrentRecord.erase(qNextPrev.ConvertToInt()); // constant time
try {
if (fitsRegister) {
qNextPrev = lbcrypto::NextPrime(qNextPrev, cyclOrder);
}
else {
qNextPrev = lbcrypto::PreviousPrime(qNextPrev, cyclOrder);
}
}
catch (const OpenFHEException& ex) {
OPENFHE_THROW(compositeScalingErrMsg);
}
if (std::log2(qNextPrev.ConvertToDouble()) > registerWordSize) {
fitsRegister = false;
}
} while (std::log2(qNextPrev.ConvertToDouble()) > registerWordSize ||
moduliQRecord.find(qNextPrev.ConvertToInt()) != moduliQRecord.end() ||
qCurrentRecord.find(qNextPrev.ConvertToInt()) != qCurrentRecord.end());
qCurrentRecord.emplace(qNextPrev.ConvertToInt());
primeProduct /= qPrev[qPrev.size() - 1].ConvertToDouble();
qPrev[qPrev.size() - 1] = qNextPrev;
primeProduct *= qNextPrev.ConvertToDouble();
}
uint32_t m = qPrev.size();
for (uint32_t d = 1; d <= m; ++d) {
moduliQ[i - d] = qPrev[d - 1];
}
for (uint32_t d = m + 1; d <= compositeDegree; ++d) {
moduliQ[i - d] = qNext[d - (m + 1)];
}
for (uint32_t d = 1; d <= compositeDegree; ++d) {
rootsQ[i - d] = RootOfUnity(cyclOrder, moduliQ[i - d]);
moduliQRecord.emplace(moduliQ[i - d].ConvertToInt());
}
flag = false;
}
} // for loop
} // if numPrimes > 1
for (uint32_t d = 1, remBits = firstModSize; d <= compositeDegree; ++d) {
uint32_t qBitSize = std::ceil(static_cast<double>(remBits) / (compositeDegree - d + 1));
try {
// Find next prime
NativeInteger nextInteger = FirstPrime<NativeInteger>(qBitSize, cyclOrder);
nextInteger = PreviousPrime<NativeInteger>(nextInteger, cyclOrder);
while (std::log2(nextInteger.ConvertToDouble()) > qBitSize ||
std::log2(nextInteger.ConvertToDouble()) > registerWordSize ||
moduliQRecord.find(nextInteger.ConvertToInt()) != moduliQRecord.end())
nextInteger = PreviousPrime<NativeInteger>(nextInteger, cyclOrder);
// Store prime
moduliQ[d - 1] = nextInteger;
rootsQ[d - 1] = RootOfUnity(cyclOrder, moduliQ[d - 1]);
// Keep track of existing primes
moduliQRecord.emplace(moduliQ[d - 1].ConvertToInt());
remBits -= qBitSize;
}
catch (const OpenFHEException& ex) {
OPENFHE_THROW(compositeScalingErrMsg);
}
}
return;
}
/**
 * Fills moduliQ/rootsQ with one prime per RNS modulus (non-composite scaling).
 *
 * - moduliQ[numPrimes - 1] is FirstPrime(dcrtBits).
 * - moduliQ[1 .. numPrimes - 2] are chosen by one of two rules. For FLEXIBLEAUTO/FLEXIBLEAUTOEXT, each
 *   prime follows the level's scaling factor. For other techniques, primes alternate below and above
 *   moduliQ[numPrimes - 1].
 * - moduliQ[0] is the firstModSize-bit modulus. It falls back to the prime after the largest prime so far
 *   when firstModSize == dcrtBits or on a collision.
 * - For FLEXIBLEAUTOEXT, moduliQ[numPrimes] is the extra modulus. moduliQ and rootsQ must therefore
 *   have numPrimes + 1 entries in that mode and numPrimes entries otherwise.
 */
void ParameterGenerationCKKSRNS::SinglePrimeModuliGen(std::vector<NativeInteger>& moduliQ,
                                                      std::vector<NativeInteger>& rootsQ, ScalingTechnique scalTech,
                                                      uint32_t numPrimes, uint32_t firstModSize, uint32_t dcrtBits,
                                                      uint32_t cyclOrder, uint32_t extraModSize) const {
    NativeInteger q        = FirstPrime<NativeInteger>(dcrtBits, cyclOrder);
    moduliQ[numPrimes - 1] = q;
    rootsQ[numPrimes - 1]  = RootOfUnity(cyclOrder, moduliQ[numPrimes - 1]);

    // Largest prime selected so far. NextPrime(maxPrime) is used as a collision-free fallback below.
    NativeInteger maxPrime{q};

    if (numPrimes > 1) {
        if (scalTech != FLEXIBLEAUTO && scalTech != FLEXIBLEAUTOEXT) {
            // Alternate just below and just above q so the scaling primes stay close to q.
            NativeInteger qPrev = q;
            NativeInteger qNext = q;
            for (size_t i = numPrimes - 2, cnt = 0; i >= 1; --i, ++cnt) {
                if ((cnt % 2) == 0) {
                    qPrev      = PreviousPrime(qPrev, cyclOrder);
                    moduliQ[i] = qPrev;
                }
                else {
                    qNext      = NextPrime(qNext, cyclOrder);
                    moduliQ[i] = qNext;
                }

                if (moduliQ[i] > maxPrime)
                    maxPrime = moduliQ[i];

                rootsQ[i] = RootOfUnity(cyclOrder, moduliQ[i]);
            }
        }
        else {  // FLEXIBLEAUTO
            /* Scaling factors in FLEXIBLEAUTO are a bit fragile,
             * in the sense that once one scaling factor gets far enough from the
             * original scaling factor, subsequent level scaling factors quickly
             * diverge to either 0 or infinity. To mitigate this problem to a certain
             * extent, we have a special prime selection process in place. The goal is
             * to maintain the scaling factor of all levels as close to the original
             * scale factor of level 0 as possible.
             */
            double sf = moduliQ[numPrimes - 1].ConvertToDouble();
            for (size_t i = numPrimes - 2, cnt = 0; i >= 1; --i, ++cnt) {
                // Next level's scaling factor: sf^2 divided by the modulus dropped at that level.
                sf                  = std::pow(sf, 2) / moduliQ[i + 1].ConvertToDouble();
                NativeInteger sfInt = std::llround(sf);
                NativeInteger sfRem = sfInt.Mod(cyclOrder);
                bool hasSameMod     = true;
                if ((cnt % 2) == 0) {
                    NativeInteger qPrev = sfInt - NativeInteger(cyclOrder) - sfRem + NativeInteger(1);
                    while (hasSameMod) {
                        hasSameMod = false;
                        qPrev      = PreviousPrime(qPrev, cyclOrder);
                        // Reject primes already used by higher-index moduli.
                        for (size_t j = i + 1; j < numPrimes; j++) {
                            if (qPrev == moduliQ[j]) {
                                hasSameMod = true;
                                break;
                            }
                        }
                    }
                    moduliQ[i] = qPrev;
                }
                else {
                    NativeInteger qNext = sfInt + NativeInteger(cyclOrder) - sfRem + NativeInteger(1);
                    while (hasSameMod) {
                        hasSameMod = false;
                        qNext      = NextPrime(qNext, cyclOrder);
                        // Reject primes already used by higher-index moduli.
                        for (size_t j = i + 1; j < numPrimes; j++) {
                            if (qNext == moduliQ[j]) {
                                hasSameMod = true;
                                break;
                            }
                        }
                    }
                    moduliQ[i] = qNext;
                }

                if (moduliQ[i] > maxPrime)
                    maxPrime = moduliQ[i];

                rootsQ[i] = RootOfUnity(cyclOrder, moduliQ[i]);
            }
        }
    }

    if (firstModSize == dcrtBits) {  // this requires dcrtBits < 60
        moduliQ[0] = NextPrime<NativeInteger>(maxPrime, cyclOrder);
    }
    else {
        moduliQ[0] = LastPrime<NativeInteger>(firstModSize, cyclOrder);

        // find if the value of moduliQ[0] is already in the vector starting with moduliQ[1] and
        // if there is, then get another prime for moduliQ[0]
        const auto pos = std::find(moduliQ.begin() + 1, moduliQ.end(), moduliQ[0]);
        if (pos != moduliQ.end()) {
            moduliQ[0] = NextPrime<NativeInteger>(maxPrime, cyclOrder);
        }
    }
    if (moduliQ[0] > maxPrime)
        maxPrime = moduliQ[0];

    rootsQ[0] = RootOfUnity(cyclOrder, moduliQ[0]);

    if (scalTech == FLEXIBLEAUTOEXT) {
        // moduliQ[numPrimes] must still be 0, so it has to be populated now

        // no need for extra checking as extraModSize is automatically chosen by the library
        auto tempMod = FirstPrime<NativeInteger>(extraModSize - 1, cyclOrder);
        // check if tempMod has a duplicate in the vector (exclude moduliQ[numPrimes] from this operation):
        const auto endPos = moduliQ.end() - 1;
        auto pos          = std::find(moduliQ.begin(), endPos, tempMod);
        // if there is a duplicate, then we call NextPrime()
        moduliQ[numPrimes] = (pos != endPos) ? NextPrime<NativeInteger>(maxPrime, cyclOrder) : tempMod;
        rootsQ[numPrimes]  = RootOfUnity(cyclOrder, moduliQ[numPrimes]);
    }
}
} // namespace lbcrypto