Program Listing for File ckksrns-cryptoparameters.cpp
↰ Return to documentation for file (pke/lib/scheme/ckksrns/ckksrns-cryptoparameters.cpp)
//==================================================================================
// BSD 2-Clause License
//
// Copyright (c) 2014-2022, NJIT, Duality Technologies Inc. and other contributors
//
// All rights reserved.
//
// Author TPOC: contact@openfhe.org
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//==================================================================================
/*
CKKS implementation. See https://eprint.iacr.org/2020/1118 for details.
*/
#define PROFILE
#include "scheme/ckksrns/ckksrns-cryptoparameters.h"
#include <vector>
namespace lbcrypto {
// Precomputation of CRT tables for encryption, decryption, and homomorphic
// multiplication.
//
// Runs the generic RNS precomputation first, then adds the CKKS-specific tables:
//   1. rescaling constants for dropping the last remaining modulus q_l,
//   2. per-level real scaling factors for FLEXIBLEAUTO, FLEXIBLEAUTOEXT and the
//      COMPOSITESCALING* techniques, or a single approximate scaling factor 2^p
//      (p is stored in the plaintext-modulus field) for every other technique,
//   3. Barrett constants floor(2^128 / q_i) when HYBRID key switching is used.
//
// All parameters are forwarded unchanged to CryptoParametersRNS::PrecomputeCRTTables.
// NOTE(review): the branches below read the members m_scalTechnique / m_ksTechnique,
// not scalTech / ksTech; presumably the base-class call stores them - TODO confirm.
// extraBits > 0 selects the "extra modulus" seeding of the scaling factors, which the
// comments below associate with FLEXIBLEAUTOEXT.
//
// Raises an error via OPENFHE_THROW when a FLEXIBLEAUTO/FLEXIBLEAUTOEXT scaling factor
// leaves the open interval (0.5, 2.0) relative to the last preset factor.
void CryptoParametersCKKSRNS::PrecomputeCRTTables(KeySwitchTechnique ksTech, ScalingTechnique scalTech,
                                                  EncryptionTechnique encTech, MultiplicationTechnique multTech,
                                                  uint32_t numPartQ, uint32_t auxBits, uint32_t extraBits) {
    CryptoParametersRNS::PrecomputeCRTTables(ksTech, scalTech, encTech, multTech, numPartQ, auxBits, extraBits);

    // Fetch the element parameters once instead of on every loop iteration.
    const auto& elementParams      = GetElementParams();
    const auto& towerParams        = elementParams->GetParams();
    const size_t sizeQ             = towerParams.size();
    const uint32_t compositeDegree = this->GetCompositeDegree();

    // Only the moduli are needed by this function (the roots of unity were never used).
    std::vector<NativeInteger> moduliQ(sizeQ);
    for (size_t i = 0; i < sizeQ; i++) {
        moduliQ[i] = towerParams[i]->GetModulus();
    }

    // ---- Rescaling constants ----
    // Entry k handles dropping q_l with l = sizeQ - (k + 1). modulusQ starts as the full
    // modulus Q = \prod_i q_i; after the division in iteration k it holds \prod_{i=0}^{l-1} q_i.
    BigInteger modulusQ = elementParams->GetModulus();
    m_QlQlInvModqlDivqlModq.resize(sizeQ - 1);
    m_QlQlInvModqlDivqlModqPrecon.resize(sizeQ - 1);
    m_qlInvModq.resize(sizeQ - 1);
    m_qlInvModqPrecon.resize(sizeQ - 1);
    for (size_t k = 0; k < sizeQ - 1; k++) {
        const size_t l = sizeQ - (k + 1);
        modulusQ       = modulusQ / BigInteger(moduliQ[l]);
        m_QlQlInvModqlDivqlModq[k].resize(l);
        m_QlQlInvModqlDivqlModqPrecon[k].resize(l);
        m_qlInvModq[k].resize(l);
        m_qlInvModqPrecon[k].resize(l);
        // result = (modulusQ^{-1} mod q_l) * modulusQ / q_l  (integer division)
        BigInteger QlInvModql = modulusQ.ModInverse(moduliQ[l]);
        BigInteger result     = (QlInvModql * modulusQ) / BigInteger(moduliQ[l]);
        for (size_t i = 0; i < l; i++) {
            m_QlQlInvModqlDivqlModq[k][i]       = result.Mod(moduliQ[i]).ConvertToInt();
            m_QlQlInvModqlDivqlModqPrecon[k][i] = m_QlQlInvModqlDivqlModq[k][i].PrepModMulConst(moduliQ[i]);
            m_qlInvModq[k][i]                   = moduliQ[l].ModInverse(moduliQ[i]);
            m_qlInvModqPrecon[k][i]             = m_qlInvModq[k][i].PrepModMulConst(moduliQ[i]);
        }
    }

    const bool isCompositeScaling =
        (m_scalTechnique == COMPOSITESCALINGAUTO || m_scalTechnique == COMPOSITESCALINGMANUAL);

    // ---- Scaling factors for each level (FLEXIBLE* and COMPOSITESCALING* techniques) ----
    if (m_scalTechnique == FLEXIBLEAUTO || m_scalTechnique == FLEXIBLEAUTOEXT || isCompositeScaling) {
        m_scalingFactorsReal.resize(sizeQ);
        if ((sizeQ == 1) && (extraBits == 0) && !isCompositeScaling) {
            // mult depth = 0 and FLEXIBLEAUTO
            // when multiplicative depth = 0, we use the scaling mod size instead of modulus size
            // Plaintext modulus is used in EncodingParamsImpl to store the exponent p of the scaling factor
            m_scalingFactorsReal[0] = std::pow(2, GetPlaintextModulus());
        }
        else if ((sizeQ == 2) && (extraBits > 0) && !isCompositeScaling) {
            // mult depth = 0 and FLEXIBLEAUTOEXT
            // when multiplicative depth = 0, we use the scaling mod size instead of modulus size
            // Plaintext modulus is used in EncodingParamsImpl to store the exponent p of the scaling factor
            m_scalingFactorsReal[0] = moduliQ[sizeQ - 1].ConvertToDouble();
            m_scalingFactorsReal[1] = std::pow(2, GetPlaintextModulus());
        }
        else {
            // Level 0: the last modulus, or for composite scaling the product of the last
            // compositeDegree moduli.
            m_scalingFactorsReal[0] = moduliQ[sizeQ - 1].ConvertToDouble();
            if (isCompositeScaling) {
                for (uint32_t j = 1; j < compositeDegree; j++) {
                    m_scalingFactorsReal[0] *= moduliQ[sizeQ - j - 1].ConvertToDouble();
                }
            }
            else if (extraBits > 0) {
                m_scalingFactorsReal[1] = moduliQ[sizeQ - 2].ConvertToDouble();
            }

            // Number of levels whose factors were set above (composite scaling presets one).
            const size_t numPresetFactors = (extraBits == 0 || isCompositeScaling) ? 1 : 2;
            // Reference factor for the FLEXIBLE* drift check. Indexing by numPresetFactors - 1
            // selects factor [0] or [1] exactly as before for FLEXIBLE*, and never reads the
            // un-preset (possibly nonexistent, when sizeQ == 1) factor [1] under composite scaling.
            const double lastPresetFactor = m_scalingFactorsReal[numPresetFactors - 1];

            for (size_t k = numPresetFactors; k < sizeQ; k++) {
                if (isCompositeScaling) {
                    if (k % compositeDegree == 0) {
                        // SF_k = SF_{k-d}^2 / (q_{sizeQ-k} * ... * q_{sizeQ-k+d-1}), d = compositeDegree
                        double prevSF           = m_scalingFactorsReal[k - compositeDegree];
                        m_scalingFactorsReal[k] = prevSF * prevSF;
                        for (uint32_t j = 0; j < compositeDegree; j++) {
                            m_scalingFactorsReal[k] /= moduliQ[sizeQ - k + j].ConvertToDouble();
                        }
                    }
                    else {
                        // Levels that are not a multiple of compositeDegree get a factor of 1.
                        m_scalingFactorsReal[k] = 1;
                    }
                }
                else {
                    // SF_k = SF_{k-1}^2 / q_{sizeQ-k}
                    const double prevSF     = m_scalingFactorsReal[k - 1];
                    m_scalingFactorsReal[k] = prevSF * prevSF / moduliQ[sizeQ - k].ConvertToDouble();
                    // Only FLEXIBLEAUTO / FLEXIBLEAUTOEXT reach this branch, so the drift check always applies.
                    const double ratio = m_scalingFactorsReal[k] / lastPresetFactor;
                    if (ratio <= 0.5 || ratio >= 2.0) {
                        OPENFHE_THROW("FLEXIBLEAUTO scaling failed at level " + std::to_string(k) +
                                      " with scaling factor ratio " + std::to_string(ratio) +
                                      ". Use FIXEDMANUAL or FIXEDAUTO instead.");
                    }
                }
            }
        }

        // Squared scaling factors: [0] = SF_0 * SF_0 (SF_0 * SF_1 when extraBits > 0),
        // [k] = SF_k * SF_k for 1 <= k < sizeQ - 1.
        m_scalingFactorsRealBig.resize(sizeQ - 1);
        if (!m_scalingFactorsRealBig.empty()) {
            if (extraBits == 0) {
                m_scalingFactorsRealBig[0] = m_scalingFactorsReal[0] * m_scalingFactorsReal[0];
            }
            else {
                m_scalingFactorsRealBig[0] = m_scalingFactorsReal[0] * m_scalingFactorsReal[1];
            }
            for (size_t k = 1; k < sizeQ - 1; k++) {
                m_scalingFactorsRealBig[k] = m_scalingFactorsReal[k] * m_scalingFactorsReal[k];
            }
        }

        // Moduli as real numbers
        m_dmoduliQ.resize(sizeQ);
        for (size_t i = 0; i < sizeQ; ++i) {
            m_dmoduliQ[i] = moduliQ[i].ConvertToDouble();
        }
    }
    else {
        // All other scaling techniques use a single approximate scaling factor 2^p,
        // where the plaintext-modulus field stores the exponent p.
        const auto p = GetPlaintextModulus();
        m_approxSF   = std::pow(2, p);
    }

    // ---- Barrett constants for HYBRID key switching: floor(2^128 / q_i) ----
    if (m_ksTechnique == HYBRID) {
        const auto BarrettBase128Bit(BigInteger(1).LShiftEq(128));
        m_modqBarrettMu.resize(sizeQ);
        for (size_t i = 0; i < sizeQ; i++) {
            m_modqBarrettMu[i] = (BarrettBase128Bit / BigInteger(moduliQ[i])).ConvertToInt<DoubleNativeInt>();
        }
    }
}
// Returns twice the ring dimension n, i.e. 2n.
// Judging by the name, this is the step used when searching for auxiliary primes
// (TODO confirm against the caller).
uint64_t CryptoParametersCKKSRNS::FindAuxPrimeStep() const {
    const size_t ringDimension = GetElementParams()->GetRingDimension();
    const size_t step          = ringDimension * 2;
    return static_cast<uint64_t>(step);
}
// Chooses the composite degree (how many register-word-sized moduli together form one
// scaling modulus) when the scaling technique is COMPOSITESCALINGAUTO.
//
// @param scalingModSize requested scaling modulus size in bits
//
// For any other scaling technique this is a no-op. If scalingModSize fits in one register
// word (registerWordSize >= scalingModSize), m_compositeDegree is left unchanged (the
// original note says it remains 1). Otherwise m_compositeDegree is set to
// ceil(scalingModSize / registerWordSize).
//
// Raises an error via OPENFHE_THROW if the register word size exceeds 64, is 0 while a
// split is required, or if each resulting modulus would be shorter than 19 bits.
void CryptoParametersCKKSRNS::ConfigureCompositeDegree(uint32_t scalingModSize) {
    if (GetScalingTechnique() != COMPOSITESCALINGAUTO) {
        return;
    }

    const uint32_t registerWordSize = GetRegisterWordSize();
    if (registerWordSize > 64) {
        OPENFHE_THROW("COMPOSITESCALING scaling technique only supports register word size <= 64.");
    }
    if (registerWordSize >= scalingModSize) {
        return;  // one modulus per scaling factor suffices; composite degree stays as is
    }
    if (registerWordSize == 0) {
        // Without this guard the ceiling division below would divide by zero
        // (the former float version produced +inf and converted it to uint32_t: UB).
        OPENFHE_THROW("COMPOSITESCALING scaling technique requires a register word size greater than 0.");
    }

    // ceil(scalingModSize / registerWordSize) in exact integer arithmetic (no float rounding).
    const uint32_t compositeDegree =
        scalingModSize / registerWordSize + ((scalingModSize % registerWordSize != 0) ? 1u : 0u);

    // Assert minimum allowed moduli size on composite scaling mode:
    // scalingModSize / compositeDegree < 19  <=>  scalingModSize < 19 * compositeDegree (integers;
    // evaluated in 64 bits so the product cannot overflow).
    // @fdiasmor TODO: make it more robust for a range of multiplicative depth
    constexpr uint64_t kMinModulusBits = 19;
    if (static_cast<uint64_t>(scalingModSize) < kMinModulusBits * compositeDegree) {
        std::string errMsg = "Moduli size (";
        errMsg += std::to_string(static_cast<float>(scalingModSize) / compositeDegree);
        errMsg +=
            ") is too short (< 19) for target multiplicative depth. Consider increasing the scaling factor or the register word size.";
        OPENFHE_THROW(errMsg);
    }
    m_compositeDegree = compositeDegree;
}
} // namespace lbcrypto