Program Listing for File bfvrns-leveledshe.cpp
↰ Return to documentation for file (pke/lib/scheme/bfvrns/bfvrns-leveledshe.cpp)
//==================================================================================
// BSD 2-Clause License
//
// Copyright (c) 2014-2024, NJIT, Duality Technologies Inc. and other contributors
//
// All rights reserved.
//
// Author TPOC: contact@openfhe.org
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//==================================================================================
/*
BFV implementation. See https://eprint.iacr.org/2021/204 for details.
*/
#define PROFILE
#include "scheme/bfvrns/bfvrns-leveledshe.h"
#include "scheme/bfvrns/bfvrns-cryptoparameters.h"
#include "schemebase/base-scheme.h"
#include "cryptocontext.h"
#include "ciphertext.h"
#include <algorithm>
#include <map>
#include <utility>
#include <memory>
#include <vector>
#include <string>
namespace lbcrypto {
/**
 * Adds an encoded plaintext to a ciphertext in place.
 *
 * A private copy of the plaintext polynomial is multiplied via TimesQovert
 * and then added to the first ciphertext component only.
 *
 * @param ciphertext ciphertext that is updated in place.
 * @param plaintext  plaintext whose DCRTPoly element is added.
 */
void LeveledSHEBFVRNS::EvalAddInPlace(Ciphertext<DCRTPoly>& ciphertext, ConstPlaintext& plaintext) const {
    const auto bfvParams = std::dynamic_pointer_cast<CryptoParametersBFVRNS>(ciphertext->GetCryptoParameters());

    // The scaling is applied to a copy, in coefficient representation.
    auto scaledPt = plaintext->GetElement<DCRTPoly>();
    scaledPt.SetFormat(COEFFICIENT);

    // The plaintext may be encoded with fewer RNS limbs than the context holds;
    // the limb difference selects which precomputed table is used.
    const auto numLimbsQ  = bfvParams->GetElementParams()->GetParams().size();
    const auto numLimbsPt = scaledPt.GetParams()->GetParams().size();
    const auto level      = numLimbsQ - numLimbsPt;

    const auto& negQModt       = bfvParams->GetNegQModt(level);
    const auto& negQModtPrecon = bfvParams->GetNegQModtPrecon(level);
    const auto& tInvModq       = bfvParams->GettInvModq();
    const auto& t              = bfvParams->GetPlaintextModulus();

    scaledPt.TimesQovert(bfvParams->GetElementParams(), tInvModq, t, negQModt, negQModtPrecon);
    scaledPt.SetFormat(EVALUATION);

    auto& components = ciphertext->GetElements();
    components[0] += scaledPt;
}
/**
 * Subtracts an encoded plaintext from a ciphertext in place.
 *
 * A private copy of the plaintext polynomial is multiplied via TimesQovert
 * and then subtracted from the first ciphertext component only.
 *
 * @param ciphertext ciphertext that is updated in place.
 * @param plaintext  plaintext whose DCRTPoly element is subtracted.
 */
void LeveledSHEBFVRNS::EvalSubInPlace(Ciphertext<DCRTPoly>& ciphertext, ConstPlaintext& plaintext) const {
    const auto bfvParams = std::dynamic_pointer_cast<CryptoParametersBFVRNS>(ciphertext->GetCryptoParameters());

    // The scaling is applied to a copy, in coefficient representation.
    auto scaledPt = plaintext->GetElement<DCRTPoly>();
    scaledPt.SetFormat(COEFFICIENT);

    // The plaintext may be encoded with fewer RNS limbs than the context holds;
    // the limb difference selects which precomputed table is used.
    const auto numLimbsQ  = bfvParams->GetElementParams()->GetParams().size();
    const auto numLimbsPt = scaledPt.GetParams()->GetParams().size();
    const auto level      = numLimbsQ - numLimbsPt;

    const auto& negQModt       = bfvParams->GetNegQModt(level);
    const auto& negQModtPrecon = bfvParams->GetNegQModtPrecon(level);
    const auto& tInvModq       = bfvParams->GettInvModq();
    const auto& t              = bfvParams->GetPlaintextModulus();

    scaledPt.TimesQovert(bfvParams->GetElementParams(), tInvModq, t, negQModt, negQModtPrecon);
    scaledPt.SetFormat(EVALUATION);

    auto& components = ciphertext->GetElements();
    components[0] -= scaledPt;
}
uint32_t FindLevelsToDrop(uint32_t multiplicativeDepth, std::shared_ptr<CryptoParametersBase<DCRTPoly>> cryptoParams,
uint32_t dcrtBits, bool keySwitch = false) {
const auto cryptoParamsBFVrns = std::dynamic_pointer_cast<CryptoParametersBFVRNS>(cryptoParams);
double sigma = cryptoParamsBFVrns->GetDistributionParameter();
double alpha = cryptoParamsBFVrns->GetAssuranceMeasure();
double p = static_cast<double>(cryptoParamsBFVrns->GetPlaintextModulus());
uint32_t n = cryptoParamsBFVrns->GetElementParams()->GetRingDimension();
uint32_t relinWindow = cryptoParamsBFVrns->GetDigitSize();
KeySwitchTechnique scalTechnique = cryptoParamsBFVrns->GetKeySwitchTechnique();
EncryptionTechnique encTech = cryptoParamsBFVrns->GetEncryptionTechnique();
uint32_t k = cryptoParamsBFVrns->GetNumPerPartQ();
uint32_t numPartQ = cryptoParamsBFVrns->GetNumPartQ();
uint32_t thresholdParties = cryptoParamsBFVrns->GetThresholdNumOfParties();
// Bound of the Gaussian error polynomial
double Berr = sigma * std::sqrt(alpha);
// Bkey set to thresholdParties * 1 for ternary distribution
const double Bkey =
(cryptoParamsBFVrns->GetSecretKeyDist() == GAUSSIAN) ? std::sqrt(thresholdParties) * Berr : thresholdParties;
double w = std::pow(2, relinWindow == 0 ? dcrtBits : relinWindow);
// expansion factor delta for a multiplication of a Gaussian polynomial by a random polynomial
auto delta = [](uint32_t n) -> double {
return (2. * std::sqrt(n));
};
// expansion factor delta for modulus switching
auto deltaMS = [](uint32_t n) -> double {
return (4. * std::sqrt(n));
};
// norm of fresh ciphertext polynomial (for EXTENDED the noise is reduced to modulus switching noise)
auto Vnorm = [&](uint32_t n) -> double {
if (encTech == EXTENDED)
return (1. + deltaMS(n) * Bkey) / 2.;
else
return Berr * (1. + 2. * delta(n) * Bkey);
};
auto noiseKS = [&](uint32_t n, double logqPrev, double w) -> double {
if (scalTechnique == HYBRID)
#if defined(WITH_REDUCED_NOISE)
return k * (numPartQ * delta(n) * Berr + delta(n) * Bkey + 1.0) / 2.0;
#else
return k * (numPartQ * delta(n) * Berr + deltaMS(n) * Bkey + 1.0);
#endif
else {
double numDigitsPerTower = (relinWindow == 0) ? 1 : ((dcrtBits / relinWindow) + 1);
return delta(n) * numDigitsPerTower * (std::floor(logqPrev / dcrtBits) + 1) * w * Berr / 2.0;
}
};
// function used in the EvalMult constraint
auto C1 = [&](uint32_t n) -> double {
return delta(n) * deltaMS(n) * p * Bkey;
};
// function used in the EvalMult constraint
auto C2 = [&](uint32_t n, double logqPrev) -> double {
return delta(n) * deltaMS(n) * Bkey * Bkey / 2.0 + noiseKS(n, logqPrev, w);
};
// main correctness constraint
auto logqBFV = [&](uint32_t n, double logqPrev) -> double {
if (multiplicativeDepth > 0) {
return std::log2(4 * p) + (multiplicativeDepth - 1) * std::log2(C1(n)) +
std::log2(C1(n) * Vnorm(n) + multiplicativeDepth * C2(n, logqPrev));
}
return std::log2(p * (4 * (Vnorm(n))));
};
// initial values
double logqPrev = 6. * std::log2(10);
double logq = logqBFV(n, logqPrev);
while (std::fabs(logq - logqPrev) > std::log2(1.001)) {
logqPrev = logq;
logq = logqBFV(n, logqPrev);
}
// get an estimate of the error q / (4t)
double loge = logq - 2 - std::log2(p);
double logExtra = keySwitch ? std::log2(noiseKS(n, logq, w)) : std::log2(deltaMS(n));
// adding the cushon to the error (see Appendix D of https://eprint.iacr.org/2021/204.pdf for details)
// adjusted empirical parameter to 16 from 4 for threshold scenarios to work correctly, this might need to
// be further refined
int32_t levels = std::floor((loge - 3 * multiplicativeDepth - 16 - logExtra) / dcrtBits);
size_t sizeQ = cryptoParamsBFVrns->GetElementParams()->GetParams().size();
if (levels < 0)
levels = 0;
else if (levels > static_cast<int32_t>(sizeQ) - 1)
levels = sizeQ - 1;
return levels;
};
// Multiplies two ciphertexts without relinearization.
//
// The result has cv1Size + cv2Size - 1 component polynomials and a noise scale
// degree of max(input degrees) + 1. Throws (OPENFHE_THROW) when the two inputs
// do not share the same crypto parameters object.
//
// The function has three stages, each dispatched on the multiplication technique:
//   1. extend local copies of the operands to a larger CRT basis,
//   2. form the tensor product of the component polynomials,
//   3. scale and round the product and bring it back to the ciphertext basis.
Ciphertext<DCRTPoly> LeveledSHEBFVRNS::EvalMult(ConstCiphertext<DCRTPoly>& ciphertext1,
ConstCiphertext<DCRTPoly>& ciphertext2) const {
if (ciphertext1->GetCryptoParameters() != ciphertext2->GetCryptoParameters())
OPENFHE_THROW("crypto parameters are not the same");
const auto cryptoParams =
std::dynamic_pointer_cast<CryptoParametersBFVRNS>(ciphertext1->GetCryptoContext()->GetCryptoParameters());
// Local copies: the basis extensions below modify cv1/cv2, not the inputs.
std::vector<DCRTPoly> cv1 = ciphertext1->GetElements();
std::vector<DCRTPoly> cv2 = ciphertext2->GetElements();
uint32_t cv1Size = cv1.size();
uint32_t cv2Size = cv2.size();
uint32_t cvMultSize = cv1Size + cv2Size - 1;
// Number of RNS limbs of the first operand
uint32_t sizeQ = cv1[0].GetNumOfElements();
const auto elementParams = cryptoParams->GetElementParams();
// Maximum number of RNS limbs in the crypto context
uint32_t sizeQM = elementParams->GetParams().size();
// l is index corresponding to leveled parameters in cryptoParameters precomputations in HPSPOVERQLEVELED
uint32_t l = 0;
std::vector<DCRTPoly> cvMult(cvMultSize);
// ---- Stage 1: basis extension of both operands ----
if (cryptoParams->GetMultiplicationTechnique() == HPS) {
// HPS: both operands are expanded with the same (default-index) tables.
for (uint32_t i = 0; i < cv1Size; ++i) {
cv1[i].ExpandCRTBasis(cryptoParams->GetParamsQlRl(), cryptoParams->GetParamsRl(),
cryptoParams->GetQlHatInvModq(), cryptoParams->GetQlHatInvModqPrecon(),
cryptoParams->GetQlHatModr(), cryptoParams->GetalphaQlModr(),
cryptoParams->GetModrBarrettMu(), cryptoParams->GetqInv(), Format::EVALUATION);
}
for (uint32_t i = 0; i < cv2Size; ++i) {
cv2[i].ExpandCRTBasis(cryptoParams->GetParamsQlRl(), cryptoParams->GetParamsRl(),
cryptoParams->GetQlHatInvModq(), cryptoParams->GetQlHatInvModqPrecon(),
cryptoParams->GetQlHatModr(), cryptoParams->GetalphaQlModr(),
cryptoParams->GetModrBarrettMu(), cryptoParams->GetqInv(), Format::EVALUATION);
}
}
else if ((cryptoParams->GetMultiplicationTechnique() == HPSPOVERQ) ||
((cryptoParams->GetMultiplicationTechnique() == HPSPOVERQLEVELED) && (sizeQ < sizeQM))) {
// HPSPOVERQ, or HPSPOVERQLEVELED on an operand that already has fewer limbs
// than the context: all tables are indexed by sizeQ - 1.
for (uint32_t i = 0; i < cv1Size; ++i) {
// Expand ciphertext1 from basis Q to PQ (from Q_l to P_l*Q_l if manual compress/lower-level-encode was called)
cv1[i].ExpandCRTBasis(cryptoParams->GetParamsQlRl(sizeQ - 1), cryptoParams->GetParamsRl(sizeQ - 1),
cryptoParams->GetQlHatInvModq(sizeQ - 1),
cryptoParams->GetQlHatInvModqPrecon(sizeQ - 1), cryptoParams->GetQlHatModr(sizeQ - 1),
cryptoParams->GetalphaQlModr(sizeQ - 1), cryptoParams->GetModrBarrettMu(),
cryptoParams->GetqInv(), Format::EVALUATION);
}
DCRTPoly::CRTBasisExtensionPrecomputations basisPQ(
cryptoParams->GetParamsQlRl(sizeQ - 1), cryptoParams->GetParamsRl(sizeQ - 1),
cryptoParams->GetParamsQl(sizeQ - 1), cryptoParams->GetmNegRlQlHatInvModq(sizeQ - 1),
cryptoParams->GetmNegRlQlHatInvModqPrecon(sizeQ - 1), cryptoParams->GetqInvModr(),
cryptoParams->GetModrBarrettMu(), cryptoParams->GetRlHatInvModr(sizeQ - 1),
cryptoParams->GetRlHatInvModrPrecon(sizeQ - 1), cryptoParams->GetRlHatModq(sizeQ - 1),
cryptoParams->GetalphaRlModq(sizeQ - 1), cryptoParams->GetModqBarrettMu(), cryptoParams->GetrInv());
for (uint32_t i = 0; i < cv2Size; ++i) {
// FastExpandCRTBasisPloverQ is applied in coefficient representation.
cv2[i].SetFormat(Format::COEFFICIENT);
// Switch ciphertext2 from basis Q to P to PQ (from Q_l to P_l to P_l*Q_l if manual compress/lower-level-encode was called).
cv2[i].FastExpandCRTBasisPloverQ(basisPQ);
cv2[i].SetFormat(Format::EVALUATION);
}
}
else if ((cryptoParams->GetMultiplicationTechnique() == HPSPOVERQLEVELED) && (sizeQ == sizeQM)) {
// HPSPOVERQLEVELED on a full-size operand: the level index l is derived from
// the larger noise scale degree of the two inputs.
uint32_t c1depth = ciphertext1->GetNoiseScaleDeg();
uint32_t c2depth = ciphertext2->GetNoiseScaleDeg();
uint32_t levels = std::max(c1depth, c2depth) - 1;
double dcrtBits = cv1[0].GetElementAtIndex(0).GetModulus().GetMSB();
// how many levels to drop
uint32_t levelsDropped = FindLevelsToDrop(levels, cryptoParams, dcrtBits, false);
l = levelsDropped > 0 ? sizeQ - 1 - levelsDropped : sizeQ - 1;
for (uint32_t i = 0; i < cv1Size; ++i) {
cv1[i].SetFormat(Format::COEFFICIENT);
if (l < sizeQ - 1) {
// Drop from basis Q to Q_l.
cv1[i] =
cv1[i].ScaleAndRound(cryptoParams->GetParamsQl(l), cryptoParams->GetQlQHatInvModqDivqModq(l),
cryptoParams->GetQlQHatInvModqDivqFrac(l), cryptoParams->GetModqBarrettMu());
}
// Expand ciphertext1 from basis Q_l to P_l*Q_l.
cv1[i].ExpandCRTBasis(cryptoParams->GetParamsQlRl(l), cryptoParams->GetParamsRl(l),
cryptoParams->GetQlHatInvModq(l), cryptoParams->GetQlHatInvModqPrecon(l),
cryptoParams->GetQlHatModr(l), cryptoParams->GetalphaQlModr(l),
cryptoParams->GetModrBarrettMu(), cryptoParams->GetqInv(), Format::EVALUATION);
}
// Note: this branch uses GetmNegRlQHatInvModq, whereas the previous branch
// uses GetmNegRlQlHatInvModq.
DCRTPoly::CRTBasisExtensionPrecomputations basisPQ(
cryptoParams->GetParamsQlRl(l), cryptoParams->GetParamsRl(l), cryptoParams->GetParamsQl(l),
cryptoParams->GetmNegRlQHatInvModq(l), cryptoParams->GetmNegRlQHatInvModqPrecon(l),
cryptoParams->GetqInvModr(), cryptoParams->GetModrBarrettMu(), cryptoParams->GetRlHatInvModr(l),
cryptoParams->GetRlHatInvModrPrecon(l), cryptoParams->GetRlHatModq(l), cryptoParams->GetalphaRlModq(l),
cryptoParams->GetModqBarrettMu(), cryptoParams->GetrInv());
for (uint32_t i = 0; i < cv2Size; ++i) {
cv2[i].SetFormat(Format::COEFFICIENT);
// Switch ciphertext2 from basis Q to P_l to P_l*Q_l.
cv2[i].FastExpandCRTBasisPloverQ(basisPQ);
cv2[i].SetFormat(Format::EVALUATION);
}
}
else {
// Remaining technique (presumably BEHZ - confirm against the technique enum):
// both operands are converted from Q to the Bsk basis.
for (uint32_t i = 0; i < cv1Size; ++i) {
cv1[i].FastBaseConvqToBskMontgomery(
cryptoParams->GetParamsQBsk(), cryptoParams->GetModuliQ(), cryptoParams->GetModuliBsk(),
cryptoParams->GetModbskBarrettMu(), cryptoParams->GetmtildeQHatInvModq(),
cryptoParams->GetmtildeQHatInvModqPrecon(), cryptoParams->GetQHatModbsk(),
cryptoParams->GetQHatModmtilde(), cryptoParams->GetQModbsk(), cryptoParams->GetQModbskPrecon(),
cryptoParams->GetNegQInvModmtilde(), cryptoParams->GetmtildeInvModbsk(),
cryptoParams->GetmtildeInvModbskPrecon());
cv1[i].SetFormat(Format::EVALUATION);
}
for (uint32_t i = 0; i < cv2Size; ++i) {
cv2[i].FastBaseConvqToBskMontgomery(
cryptoParams->GetParamsQBsk(), cryptoParams->GetModuliQ(), cryptoParams->GetModuliBsk(),
cryptoParams->GetModbskBarrettMu(), cryptoParams->GetmtildeQHatInvModq(),
cryptoParams->GetmtildeQHatInvModqPrecon(), cryptoParams->GetQHatModbsk(),
cryptoParams->GetQHatModmtilde(), cryptoParams->GetQModbsk(), cryptoParams->GetQModbskPrecon(),
cryptoParams->GetNegQInvModmtilde(), cryptoParams->GetmtildeInvModbsk(),
cryptoParams->GetmtildeInvModbskPrecon());
cv2[i].SetFormat(Format::EVALUATION);
}
}
// ---- Stage 2: tensor product, cvMult[i + j] accumulates cv1[i] * cv2[j] ----
#ifdef USE_KARATSUBA
if (cv1Size == 2 && cv2Size == 2) {
// size of each ciphertext = 2, use Karatsuba (three products instead of four)
cvMult[0] = cv1[0] * cv2[0]; // a
cvMult[2] = cv1[1] * cv2[1]; // b
cvMult[1] = cv1[0] + cv1[1];
cvMult[1] *= (cv2[0] + cv2[1]);
cvMult[1] -= cvMult[2];
cvMult[1] -= cvMult[0];
}
else { // if size of any of the ciphertexts > 2
// isFirstAdd[k] is true until cvMult[k] receives its first product, so the
// first product is assigned and later ones are accumulated.
std::vector<bool> isFirstAdd(cvMultSize, true);
for (uint32_t i = 0; i < cv1Size; i++) {
for (uint32_t j = 0; j < cv2Size; j++) {
if (isFirstAdd[i + j] == true) {
cvMult[i + j] = cv1[i] * cv2[j];
isFirstAdd[i + j] = false;
}
else {
cvMult[i + j] += cv1[i] * cv2[j];
}
}
}
}
#else
// isFirstAdd[k] is true until cvMult[k] receives its first product, so the
// first product is assigned and later ones are accumulated.
std::vector<bool> isFirstAdd(cvMultSize, true);
for (uint32_t i = 0; i < cv1Size; i++) {
for (uint32_t j = 0; j < cv2Size; j++) {
if (isFirstAdd[i + j] == true) {
cvMult[i + j] = cv1[i] * cv2[j];
isFirstAdd[i + j] = false;
}
else {
cvMult[i + j] += cv1[i] * cv2[j];
}
}
}
#endif
// ---- Stage 3: scale and round, then return to the ciphertext basis ----
if (cryptoParams->GetMultiplicationTechnique() == HPS) {
for (uint32_t i = 0; i < cvMultSize; ++i) {
// converts to coefficient representation before rounding
cvMult[i].SetFormat(Format::COEFFICIENT);
// Performs the scaling by t/Q followed by rounding; the result is in the
// CRT basis P
cvMult[i] =
cvMult[i].ScaleAndRound(cryptoParams->GetParamsRl(), cryptoParams->GettRSHatInvModsDivsModr(),
cryptoParams->GettRSHatInvModsDivsFrac(), cryptoParams->GetModrBarrettMu());
// Converts from the CRT basis P to Q
cvMult[i] = cvMult[i].SwitchCRTBasis(cryptoParams->GetElementParams(), cryptoParams->GetRlHatInvModr(),
cryptoParams->GetRlHatInvModrPrecon(), cryptoParams->GetRlHatModq(),
cryptoParams->GetalphaRlModq(), cryptoParams->GetModqBarrettMu(),
cryptoParams->GetrInv());
}
}
else if ((cryptoParams->GetMultiplicationTechnique() == HPSPOVERQ) ||
((cryptoParams->GetMultiplicationTechnique() == HPSPOVERQLEVELED) && (sizeQ < sizeQM))) {
// Same table index as in stage 1 of this branch.
l = sizeQ - 1;
for (uint32_t i = 0; i < cvMultSize; ++i) {
cvMult[i].SetFormat(COEFFICIENT);
// Performs the scaling by t/P followed by rounding; the result is in the
// CRT basis Q (Q_l if compress/lower-level encode was used)
cvMult[i] =
cvMult[i].ScaleAndRound(cryptoParams->GetParamsQl(l), cryptoParams->GettQlSlHatInvModsDivsModq(l),
cryptoParams->GettQlSlHatInvModsDivsFrac(l), cryptoParams->GetModqBarrettMu());
}
}
else if ((cryptoParams->GetMultiplicationTechnique() == HPSPOVERQLEVELED) && (sizeQ == sizeQM)) {
// l still holds the level index computed in stage 1 of this branch.
for (uint32_t i = 0; i < cvMultSize; ++i) {
cvMult[i].SetFormat(COEFFICIENT);
// Performs the scaling by t/P followed by rounding; the result is in the
// CRT basis Ql
cvMult[i] =
cvMult[i].ScaleAndRound(cryptoParams->GetParamsQl(l), cryptoParams->GettQlSlHatInvModsDivsModq(l),
cryptoParams->GettQlSlHatInvModsDivsFrac(l), cryptoParams->GetModqBarrettMu());
if (l < sizeQ - 1) {
// Expand back to basis Q.
cvMult[i].ExpandCRTBasisQlHat(cryptoParams->GetElementParams(), cryptoParams->GetQlHatModq(l),
cryptoParams->GetQlHatModqPrecon(l), sizeQ);
}
}
}
else {
const NativeInteger& t = cryptoParams->GetPlaintextModulus();
for (uint32_t i = 0; i < cvMultSize; ++i) {
// converts to Format::COEFFICIENT representation before rounding
cvMult[i].SetFormat(Format::COEFFICIENT);
// Performs the scaling by t/Q followed by rounding; the result is in the
// CRT basis {Bsk}
cvMult[i].FastRNSFloorq(
t, cryptoParams->GetModuliQ(), cryptoParams->GetModuliBsk(), cryptoParams->GetModbskBarrettMu(),
cryptoParams->GettQHatInvModq(), cryptoParams->GettQHatInvModqPrecon(), cryptoParams->GetQHatModbsk(),
cryptoParams->GetqInvModbsk(), cryptoParams->GettQInvModbsk(), cryptoParams->GettQInvModbskPrecon());
// Converts from the CRT basis {Bsk} to {Q}
cvMult[i].FastBaseConvSK(cryptoParams->GetElementParams(), cryptoParams->GetModqBarrettMu(),
cryptoParams->GetModuliBsk(), cryptoParams->GetModbskBarrettMu(),
cryptoParams->GetBHatInvModb(), cryptoParams->GetBHatInvModbPrecon(),
cryptoParams->GetBHatModmsk(), cryptoParams->GetBInvModmsk(),
cryptoParams->GetBInvModmskPrecon(), cryptoParams->GetBHatModq(),
cryptoParams->GetBModq(), cryptoParams->GetBModqPrecon());
}
}
// The result is built from an empty clone of ciphertext1 and takes ownership of cvMult.
auto ciphertextMult = ciphertext1->CloneEmpty();
ciphertextMult->SetElements(std::move(cvMult));
ciphertextMult->SetNoiseScaleDeg(std::max(ciphertext1->GetNoiseScaleDeg(), ciphertext2->GetNoiseScaleDeg()) + 1);
return ciphertextMult;
}
/**
 * Squares a ciphertext without relinearization.
 *
 * The result has 2 * cvSize - 1 component polynomials and a noise scale degree
 * one higher than the input. The work is done in three stages, each dispatched
 * on the multiplication technique:
 *   1. extend a local copy of the operand to a larger CRT basis (for the
 *      HPSPOVERQ / HPSPOVERQLEVELED techniques a second, differently extended
 *      copy cvPoverQ is produced as well),
 *   2. form the tensor product,
 *   3. scale and round the product and bring it back to the ciphertext basis.
 *
 * @param ciphertext ciphertext to square.
 * @return new ciphertext holding the unrelinearized square.
 */
Ciphertext<DCRTPoly> LeveledSHEBFVRNS::EvalSquare(ConstCiphertext<DCRTPoly>& ciphertext) const {
    const auto cryptoParams =
        std::dynamic_pointer_cast<CryptoParametersBFVRNS>(ciphertext->GetCryptoContext()->GetCryptoParameters());
    // The technique is read once; every dispatch below keys off this value.
    const auto multTech = cryptoParams->GetMultiplicationTechnique();

    // Local copy: the basis extensions below modify cv, not the input.
    std::vector<DCRTPoly> cv = ciphertext->GetElements();

    size_t cvSize   = cv.size();
    size_t cvSqSize = 2 * cvSize - 1;
    // Number of RNS limbs of the operand
    size_t sizeQ = cv[0].GetNumOfElements();

    const auto elementParams = cryptoParams->GetElementParams();
    // Maximum number of RNS limbs in the crypto context
    size_t sizeQM = elementParams->GetParams().size();
    // l is index corresponding to leveled parameters in cryptoParameters precomputations in HPSPOVERQLEVELED
    size_t l = 0;

    // Second operand copy; only populated by the two HPSPOVERQ* branches.
    std::vector<DCRTPoly> cvPoverQ;

    // ---- Stage 1: basis extension ----
    if (multTech == HPS) {
        for (size_t i = 0; i < cvSize; i++) {
            cv[i].ExpandCRTBasis(cryptoParams->GetParamsQlRl(), cryptoParams->GetParamsRl(),
                                 cryptoParams->GetQlHatInvModq(), cryptoParams->GetQlHatInvModqPrecon(),
                                 cryptoParams->GetQlHatModr(), cryptoParams->GetalphaQlModr(),
                                 cryptoParams->GetModrBarrettMu(), cryptoParams->GetqInv(), Format::EVALUATION);
        }
    }
    else if ((multTech == HPSPOVERQ) || ((multTech == HPSPOVERQLEVELED) && (sizeQ < sizeQM))) {
        // The second copy is taken before cv is expanded.
        cvPoverQ = cv;
        for (size_t i = 0; i < cvSize; i++) {
            // Expand the first copy from basis Q to PQ.
            cv[i].ExpandCRTBasis(cryptoParams->GetParamsQlRl(sizeQ - 1), cryptoParams->GetParamsRl(sizeQ - 1),
                                 cryptoParams->GetQlHatInvModq(sizeQ - 1),
                                 cryptoParams->GetQlHatInvModqPrecon(sizeQ - 1), cryptoParams->GetQlHatModr(sizeQ - 1),
                                 cryptoParams->GetalphaQlModr(sizeQ - 1), cryptoParams->GetModrBarrettMu(),
                                 cryptoParams->GetqInv(), Format::EVALUATION);
        }

        DCRTPoly::CRTBasisExtensionPrecomputations basisPQ(
            cryptoParams->GetParamsQlRl(sizeQ - 1), cryptoParams->GetParamsRl(sizeQ - 1),
            cryptoParams->GetParamsQl(sizeQ - 1), cryptoParams->GetmNegRlQlHatInvModq(sizeQ - 1),
            cryptoParams->GetmNegRlQlHatInvModqPrecon(sizeQ - 1), cryptoParams->GetqInvModr(),
            cryptoParams->GetModrBarrettMu(), cryptoParams->GetRlHatInvModr(sizeQ - 1),
            cryptoParams->GetRlHatInvModrPrecon(sizeQ - 1), cryptoParams->GetRlHatModq(sizeQ - 1),
            cryptoParams->GetalphaRlModq(sizeQ - 1), cryptoParams->GetModqBarrettMu(), cryptoParams->GetrInv());

        for (size_t i = 0; i < cvSize; i++) {
            cvPoverQ[i].SetFormat(Format::COEFFICIENT);
            // Switch the second copy from basis Q to P to PQ (from Q_l to P_l to P_l*Q_l if manual
            // compress/lower-level-encode was called).
            cvPoverQ[i].FastExpandCRTBasisPloverQ(basisPQ);
            cvPoverQ[i].SetFormat(Format::EVALUATION);
        }
    }
    else if ((multTech == HPSPOVERQLEVELED) && (sizeQ == sizeQM)) {
        size_t cdepth   = ciphertext->GetNoiseScaleDeg();
        size_t levels   = cdepth - 1;
        double dcrtBits = cv[0].GetElementAtIndex(0).GetModulus().GetMSB();

        // how many levels to drop
        uint32_t levelsDropped = FindLevelsToDrop(levels, cryptoParams, dcrtBits, false);
        l                      = levelsDropped > 0 ? sizeQ - 1 - levelsDropped : sizeQ - 1;

        for (size_t i = 0; i < cvSize; i++) {
            cv[i].SetFormat(Format::COEFFICIENT);
        }
        // The second copy is taken in coefficient format, before any limbs are dropped from cv.
        cvPoverQ = cv;

        for (size_t i = 0; i < cvSize; i++) {
            if (l < sizeQ - 1) {
                // Drop from basis Q to Q_l.
                cv[i] =
                    cv[i].ScaleAndRound(cryptoParams->GetParamsQl(l), cryptoParams->GetQlQHatInvModqDivqModq(l),
                                        cryptoParams->GetQlQHatInvModqDivqFrac(l), cryptoParams->GetModqBarrettMu());
            }
            // Expand the first copy from basis Q_l to P_l*Q_l.
            cv[i].ExpandCRTBasis(cryptoParams->GetParamsQlRl(l), cryptoParams->GetParamsRl(l),
                                 cryptoParams->GetQlHatInvModq(l), cryptoParams->GetQlHatInvModqPrecon(l),
                                 cryptoParams->GetQlHatModr(l), cryptoParams->GetalphaQlModr(l),
                                 cryptoParams->GetModrBarrettMu(), cryptoParams->GetqInv(), Format::EVALUATION);
        }

        DCRTPoly::CRTBasisExtensionPrecomputations basisPQ(
            cryptoParams->GetParamsQlRl(l), cryptoParams->GetParamsRl(l), cryptoParams->GetParamsQl(l),
            cryptoParams->GetmNegRlQHatInvModq(l), cryptoParams->GetmNegRlQHatInvModqPrecon(l),
            cryptoParams->GetqInvModr(), cryptoParams->GetModrBarrettMu(), cryptoParams->GetRlHatInvModr(l),
            cryptoParams->GetRlHatInvModrPrecon(l), cryptoParams->GetRlHatModq(l), cryptoParams->GetalphaRlModq(l),
            cryptoParams->GetModqBarrettMu(), cryptoParams->GetrInv());

        for (size_t i = 0; i < cvSize; i++) {
            cvPoverQ[i].FastExpandCRTBasisPloverQ(basisPQ);
            cvPoverQ[i].SetFormat(Format::EVALUATION);
        }
    }
    else {
        // Remaining technique: convert from Q to the Bsk basis.
        for (size_t i = 0; i < cvSize; i++) {
            cv[i].FastBaseConvqToBskMontgomery(
                cryptoParams->GetParamsQBsk(), cryptoParams->GetModuliQ(), cryptoParams->GetModuliBsk(),
                cryptoParams->GetModbskBarrettMu(), cryptoParams->GetmtildeQHatInvModq(),
                cryptoParams->GetmtildeQHatInvModqPrecon(), cryptoParams->GetQHatModbsk(),
                cryptoParams->GetQHatModmtilde(), cryptoParams->GetQModbsk(), cryptoParams->GetQModbskPrecon(),
                cryptoParams->GetNegQInvModmtilde(), cryptoParams->GetmtildeInvModbsk(),
                cryptoParams->GetmtildeInvModbskPrecon());
            cv[i].SetFormat(Format::EVALUATION);
        }
    }

    // ---- Stage 2: tensor product ----
    std::vector<DCRTPoly> cvSquare(cvSqSize);

    // Schoolbook product, shared by both preprocessor configurations.
    // isFirstAdd[k] is true until cvSquare[k] receives its first term, so the
    // first term is assigned and later ones are accumulated.
    auto schoolbookSquare = [&]() {
        std::vector<bool> isFirstAdd(cvSqSize, true);
        if (multTech == HPS || multTech == BEHZ) {
            // Both factors are the same vector, so each cross term cv[i] * cv[j]
            // (i < j) is computed once and added twice.
            DCRTPoly cvtemp;
            for (size_t i = 0; i < cvSize; i++) {
                for (size_t j = i; j < cvSize; j++) {
                    if (isFirstAdd[i + j] == true) {
                        if (j == i) {
                            cvSquare[i + j] = cv[i] * cv[j];
                        }
                        else {
                            cvtemp          = cv[i] * cv[j];
                            cvSquare[i + j] = cvtemp;
                            cvSquare[i + j] += cvtemp;
                        }
                        isFirstAdd[i + j] = false;
                    }
                    else {
                        if (j == i) {
                            cvSquare[i + j] += cv[i] * cv[j];
                        }
                        else {
                            cvtemp = cv[i] * cv[j];
                            cvSquare[i + j] += cvtemp;
                            cvSquare[i + j] += cvtemp;
                        }
                    }
                }
            }
        }
        else {
            // The two factors were extended differently (cv vs cvPoverQ), so
            // the full i/j product is formed.
            for (size_t i = 0; i < cvSize; i++) {
                for (size_t j = 0; j < cvSize; j++) {
                    if (isFirstAdd[i + j] == true) {
                        cvSquare[i + j]   = cv[i] * cvPoverQ[j];
                        isFirstAdd[i + j] = false;
                    }
                    else {
                        cvSquare[i + j] += cv[i] * cvPoverQ[j];
                    }
                }
            }
        }
    };

#ifdef USE_KARATSUBA
    if (cvSize == 2) {
        if (multTech == HPS || multTech == BEHZ) {
            // two components, identical factors: direct squaring with a doubled cross term
            cvSquare[0] = cv[0] * cv[0];  // a
            cvSquare[2] = cv[1] * cv[1];  // b
            // Fix: the operand vector in this function is cv; the original referenced an
            // undeclared cv1 here, which broke the build whenever USE_KARATSUBA was defined.
            cvSquare[1] = cv[0] * cv[1];
            cvSquare[1] += cvSquare[1];
        }
        else {
            // two components, distinct factor representations: Karatsuba
            cvSquare[0] = cv[0] * cvPoverQ[0];  // a
            cvSquare[2] = cv[1] * cvPoverQ[1];  // b
            cvSquare[1] = cv[0] + cv[1];
            cvSquare[1] *= (cvPoverQ[0] + cvPoverQ[1]);
            cvSquare[1] -= cvSquare[2];
            cvSquare[1] -= cvSquare[0];
        }
    }
    else {
        schoolbookSquare();
    }
#else
    schoolbookSquare();
#endif

    // ---- Stage 3: scale and round, then return to the ciphertext basis ----
    if (multTech == HPS) {
        for (size_t i = 0; i < cvSqSize; i++) {
            // converts to coefficient representation before rounding
            cvSquare[i].SetFormat(Format::COEFFICIENT);
            // Performs the scaling by t/Q followed by rounding; the result is in the
            // CRT basis P
            cvSquare[i] =
                cvSquare[i].ScaleAndRound(cryptoParams->GetParamsRl(), cryptoParams->GettRSHatInvModsDivsModr(),
                                          cryptoParams->GettRSHatInvModsDivsFrac(), cryptoParams->GetModrBarrettMu());
            // Converts from the CRT basis P to Q
            cvSquare[i] = cvSquare[i].SwitchCRTBasis(cryptoParams->GetElementParams(), cryptoParams->GetRlHatInvModr(),
                                                     cryptoParams->GetRlHatInvModrPrecon(),
                                                     cryptoParams->GetRlHatModq(), cryptoParams->GetalphaRlModq(),
                                                     cryptoParams->GetModqBarrettMu(), cryptoParams->GetrInv());
        }
    }
    else if ((multTech == HPSPOVERQ) || ((multTech == HPSPOVERQLEVELED) && (sizeQ < sizeQM))) {
        // Same table index as in stage 1 of this branch.
        l = sizeQ - 1;
        for (size_t i = 0; i < cvSqSize; i++) {
            cvSquare[i].SetFormat(COEFFICIENT);
            // Performs the scaling by t/P followed by rounding; the result is in the
            // CRT basis Q (Q_l if compress/lower-level encode was used)
            cvSquare[i] = cvSquare[i].ScaleAndRound(
                cryptoParams->GetParamsQl(l), cryptoParams->GettQlSlHatInvModsDivsModq(l),
                cryptoParams->GettQlSlHatInvModsDivsFrac(l), cryptoParams->GetModqBarrettMu());
        }
    }
    else if ((multTech == HPSPOVERQLEVELED) && (sizeQ == sizeQM)) {
        // l still holds the level index computed in stage 1 of this branch.
        for (size_t i = 0; i < cvSqSize; i++) {
            cvSquare[i].SetFormat(COEFFICIENT);
            // Performs the scaling by t/P followed by rounding; the result is in the
            // CRT basis Q_l
            cvSquare[i] = cvSquare[i].ScaleAndRound(
                cryptoParams->GetParamsQl(l), cryptoParams->GettQlSlHatInvModsDivsModq(l),
                cryptoParams->GettQlSlHatInvModsDivsFrac(l), cryptoParams->GetModqBarrettMu());
            if (l < sizeQ - 1) {
                // Expand back to basis Q.
                cvSquare[i].ExpandCRTBasisQlHat(cryptoParams->GetElementParams(), cryptoParams->GetQlHatModq(l),
                                                cryptoParams->GetQlHatModqPrecon(l), sizeQ);
            }
        }
    }
    else {
        const NativeInteger& t = cryptoParams->GetPlaintextModulus();
        for (size_t i = 0; i < cvSqSize; i++) {
            // converts to Format::COEFFICIENT representation before rounding
            cvSquare[i].SetFormat(Format::COEFFICIENT);
            // Performs the scaling by t/Q followed by rounding; the result is in the
            // CRT basis {Bsk}
            cvSquare[i].FastRNSFloorq(
                t, cryptoParams->GetModuliQ(), cryptoParams->GetModuliBsk(), cryptoParams->GetModbskBarrettMu(),
                cryptoParams->GettQHatInvModq(), cryptoParams->GettQHatInvModqPrecon(), cryptoParams->GetQHatModbsk(),
                cryptoParams->GetqInvModbsk(), cryptoParams->GettQInvModbsk(), cryptoParams->GettQInvModbskPrecon());
            // Converts from the CRT basis {Bsk} to {Q}
            cvSquare[i].FastBaseConvSK(cryptoParams->GetElementParams(), cryptoParams->GetModqBarrettMu(),
                                       cryptoParams->GetModuliBsk(), cryptoParams->GetModbskBarrettMu(),
                                       cryptoParams->GetBHatInvModb(), cryptoParams->GetBHatInvModbPrecon(),
                                       cryptoParams->GetBHatModmsk(), cryptoParams->GetBInvModmsk(),
                                       cryptoParams->GetBInvModmskPrecon(), cryptoParams->GetBHatModq(),
                                       cryptoParams->GetBModq(), cryptoParams->GetBModqPrecon());
        }
    }

    // The result is built from an empty clone of the input and takes ownership of cvSquare.
    auto ciphertextSq = ciphertext->CloneEmpty();
    ciphertextSq->SetElements(std::move(cvSquare));
    ciphertextSq->SetNoiseScaleDeg(ciphertext->GetNoiseScaleDeg() + 1);
    return ciphertextSq;
}
/**
 * Multiplies two ciphertexts and relinearizes the product with evalKey.
 *
 * @param ciphertext1 first factor.
 * @param ciphertext2 second factor.
 * @param evalKey     key passed to RelinearizeCore.
 * @return the relinearized product.
 */
Ciphertext<DCRTPoly> LeveledSHEBFVRNS::EvalMult(ConstCiphertext<DCRTPoly>& ciphertext1,
                                                ConstCiphertext<DCRTPoly>& ciphertext2,
                                                const EvalKey<DCRTPoly> evalKey) const {
    // Unrelinearized product first, then relinearize it in place.
    auto product = EvalMult(ciphertext1, ciphertext2);
    RelinearizeCore(product, evalKey);
    return product;
}
/**
 * Multiplies ciphertext1 by ciphertext2, stores the product in ciphertext1,
 * and relinearizes it with evalKey.
 *
 * @param ciphertext1 first factor; replaced by the relinearized product.
 * @param ciphertext2 second factor.
 * @param evalKey     key passed to RelinearizeCore.
 */
void LeveledSHEBFVRNS::EvalMultInPlace(Ciphertext<DCRTPoly>& ciphertext1, ConstCiphertext<DCRTPoly>& ciphertext2,
                                       const EvalKey<DCRTPoly> evalKey) const {
    // ciphertext1 is rebound to the new product before relinearization runs.
    ciphertext1 = this->EvalMult(ciphertext1, ciphertext2);
    this->RelinearizeCore(ciphertext1, evalKey);
}
/**
 * Squares a ciphertext and relinearizes the result with evalKey.
 *
 * @param ciphertext ciphertext to square.
 * @param evalKey    key passed to RelinearizeCore.
 * @return the relinearized square.
 */
Ciphertext<DCRTPoly> LeveledSHEBFVRNS::EvalSquare(ConstCiphertext<DCRTPoly>& ciphertext,
                                                  const EvalKey<DCRTPoly> evalKey) const {
    // Unrelinearized square first, then relinearize it in place.
    auto squared = EvalSquare(ciphertext);
    RelinearizeCore(squared, evalKey);
    return squared;
}
/**
 * Squares a ciphertext, stores the result back in it, and relinearizes it
 * with evalKey.
 *
 * @param ciphertext ciphertext to square; replaced by the relinearized square.
 * @param evalKey    key passed to RelinearizeCore.
 */
void LeveledSHEBFVRNS::EvalSquareInPlace(Ciphertext<DCRTPoly>& ciphertext, const EvalKey<DCRTPoly> evalKey) const {
    // ciphertext is rebound to the new square before relinearization runs.
    ciphertext = this->EvalSquare(ciphertext);
    this->RelinearizeCore(ciphertext, evalKey);
}
/**
 * Multiplies every component polynomial of a ciphertext by a scalar, in
 * place, and increments the ciphertext's noise scale degree by one.
 *
 * @param ciphertext ciphertext that is updated in place.
 * @param scalar     value each component is multiplied by.
 */
void LeveledSHEBFVRNS::EvalMultCoreInPlace(Ciphertext<DCRTPoly>& ciphertext, NativeInteger scalar) const {
    auto& components = ciphertext->GetElements();
    for (size_t idx = 0; idx < components.size(); ++idx) {
        components[idx] *= scalar;
    }

    const auto updatedDegree = ciphertext->GetNoiseScaleDeg() + 1;
    ciphertext->SetNoiseScaleDeg(updatedDegree);
}
/**
 * Applies the automorphism with index i to a ciphertext.
 *
 * A clone of the input is passed through RelinearizeCore with the key stored
 * under i, and AutomorphismTransform is then applied to its first two
 * components.
 *
 * @param ciphertext input ciphertext (not modified).
 * @param i          automorphism index; also the lookup key in evalKeyMap.
 * @param evalKeyMap map from automorphism index to evaluation key. The lookup
 *                   uses std::map::at, so a missing index raises
 *                   std::out_of_range.
 * @return the transformed clone.
 */
Ciphertext<DCRTPoly> LeveledSHEBFVRNS::EvalAutomorphism(ConstCiphertext<DCRTPoly>& ciphertext, uint32_t i,
                                                        const std::map<uint32_t, EvalKey<DCRTPoly>>& evalKeyMap,
                                                        CALLER_INFO_ARGS_CPP) const {
    // Index map, sized to the ring dimension and filled by PrecomputeAutoMap.
    const uint32_t ringDim = ciphertext->GetElements()[0].GetRingDimension();
    std::vector<uint32_t> autoMap(ringDim);
    PrecomputeAutoMap(ringDim, i, &autoMap);

    auto transformed = ciphertext->Clone();
    RelinearizeCore(transformed, evalKeyMap.at(i));

    // Only components 0 and 1 are transformed.
    auto& components = transformed->GetElements();
    for (size_t idx = 0; idx < 2; ++idx) {
        components[idx] = components[idx].AutomorphismTransform(i, autoMap);
    }
    return transformed;
}
/**
 * Precomputes the key switching digits used by fast rotation.
 *
 * Outside the HPSPOVERQLEVELED mode, or when the ciphertext already has fewer
 * RNS limbs than the context, the second ciphertext component goes straight
 * to EvalKeySwitchPrecomputeCore. Otherwise a copy of that component is first
 * scaled and rounded to the basis selected by FindLevelsToDrop.
 *
 * @param ciphertext input ciphertext (not modified).
 * @return the digits returned by EvalKeySwitchPrecomputeCore.
 */
std::shared_ptr<std::vector<DCRTPoly>> LeveledSHEBFVRNS::EvalFastRotationPrecompute(
    ConstCiphertext<DCRTPoly>& ciphertext) const {
    const auto bfvParams = std::dynamic_pointer_cast<CryptoParametersBFVRNS>(ciphertext->GetCryptoParameters());
    auto scheme          = ciphertext->GetCryptoContext()->GetScheme();

    const size_t numLimbs = ciphertext->GetElements()[0].GetNumOfElements();
    // Maximum number of RNS limbs in the crypto context
    const size_t maxNumLimbs = bfvParams->GetElementParams()->GetParams().size();

    // In the HPSPOVERQLEVELED mode (without manually calling compress-like operations)
    // an extra modulus reduction step is needed; every other case uses the
    // shared EvalKeySwitchPrecomputeCore directly.
    const bool needsModReduction =
        (bfvParams->GetMultiplicationTechnique() == HPSPOVERQLEVELED) && (numLimbs == maxNumLimbs);
    if (!needsModReduction) {
        return scheme->EvalKeySwitchPrecomputeCore(ciphertext->GetElements()[1], ciphertext->GetCryptoParameters());
    }

    DCRTPoly reduced      = ciphertext->GetElements()[1];
    const size_t depth    = ciphertext->GetNoiseScaleDeg() - 1;
    const double limbBits = reduced.GetElementAtIndex(0).GetModulus().GetMSB();

    // how many levels to drop
    const uint32_t levelsDropped = FindLevelsToDrop(depth, bfvParams, limbBits, true);
    // index into the leveled precomputations held by the crypto parameters
    const uint32_t levelIndex = numLimbs - 1 - levelsDropped;

    // ScaleAndRound is applied in coefficient representation.
    reduced.SetFormat(COEFFICIENT);
    reduced = reduced.ScaleAndRound(bfvParams->GetParamsQl(levelIndex),
                                    bfvParams->GetQlQHatInvModqDivqModq(levelIndex),
                                    bfvParams->GetQlQHatInvModqDivqFrac(levelIndex), bfvParams->GetModqBarrettMu());
    reduced.SetFormat(EVALUATION);

    return scheme->EvalKeySwitchPrecomputeCore(reduced, ciphertext->GetCryptoParameters());
}
// Rotates a ciphertext using key-switching digits that were precomputed for it.
//
// Parameters:
//   ciphertext - the ciphertext to rotate.
//   index      - the rotation index; 0 returns a clone of the input without any key switching.
//   m          - forwarded together with `index` to FindAutomorphismIndex.
//   digits     - precomputed digits of the ciphertext (presumably the output of
//                EvalFastRotationPrecompute for the same ciphertext - confirm against callers).
// Returns:
//   a new ciphertext holding the two rotated components.
// Errors (reported through OPENFHE_THROW):
//   - `digits` is null or empty (only checked for a non-zero index);
//   - no automorphism key is registered for the computed automorphism index.
Ciphertext<DCRTPoly> LeveledSHEBFVRNS::EvalFastRotation(ConstCiphertext<DCRTPoly>& ciphertext, const uint32_t index,
                                                        const uint32_t m,
                                                        const std::shared_ptr<std::vector<DCRTPoly>> digits) const {
    if (index == 0) {
        return ciphertext->Clone();
    }

    // The first digit is dereferenced below; fail with a clear message instead of undefined behavior.
    if (!digits || digits->empty()) {
        OPENFHE_THROW("The precomputed digits are empty.");
    }

    const auto cc            = ciphertext->GetCryptoContext();
    const uint32_t autoIndex = FindAutomorphismIndex(index, m);

    // Bind by reference: only a single key is needed, so there is no reason to copy the whole key map.
    const auto& evalKeyMap = cc->GetEvalAutomorphismKeyMap(ciphertext->GetKeyTag());
    // verify if the key autoIndex exists in the evalKeyMap
    const auto evalKeyIterator = evalKeyMap.find(autoIndex);
    if (evalKeyIterator == evalKeyMap.end()) {
        OPENFHE_THROW("EvalKey for index [" + std::to_string(autoIndex) + "] is not found.");
    }
    auto evalKey = evalKeyIterator->second;

    auto algo                       = cc->GetScheme();
    const std::vector<DCRTPoly>& cv = ciphertext->GetElements();
    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersBFVRNS>(ciphertext->GetCryptoParameters());

    // We remove all auxiliary moduli P_i in the case of hybrid key switching
    // ATTN: elemParams should not be a shared_ptr because it would modify digits.
    // TODO (dsuponit): wrap the lines below in a function to return elemParams as an object
    auto elemParams = *((*digits)[0].GetParams());
    if (cryptoParams->GetKeySwitchTechnique() == HYBRID) {
        const size_t sizeP = cryptoParams->GetParamsP()->GetParams().size();
        for (size_t i = 0; i < sizeP; ++i) {
            elemParams.PopLastParam();
        }
    }

    std::shared_ptr<std::vector<DCRTPoly>> ba =
        algo->EvalFastKeySwitchCore(digits, evalKey, std::make_shared<DCRTPoly::Params>(elemParams));

    const size_t sizeQ = cv[0].GetNumOfElements();
    // Maximum number of RNS limbs in the crypto context
    const size_t sizeQM = cryptoParams->GetElementParams()->GetParams().size();

    // In the HPSPOVERQLEVELED mode, we need to increase the modulus back to Q
    if ((cryptoParams->GetMultiplicationTechnique() == HPSPOVERQLEVELED) && (sizeQ == sizeQM)) {
        // l is index corresponding to leveled parameters in cryptoParameters precomputations in
        // HPSPOVERQLEVELED, after the level dropping
        const uint32_t l = elemParams.GetParams().size() - 1;
        for (size_t k = 0; k < 2; ++k) {
            (*ba)[k].ExpandCRTBasisQlHat(cryptoParams->GetElementParams(), cryptoParams->GetQlHatModq(l),
                                         cryptoParams->GetQlHatModqPrecon(l), sizeQ);
        }
    }

    const uint32_t N = cryptoParams->GetElementParams()->GetRingDimension();
    std::vector<uint32_t> vec(N);
    PrecomputeAutoMap(N, autoIndex, &vec);

    (*ba)[0] += cv[0];

    // Collect the rotated components in a vector that is moved into the result. A braced list
    // would go through std::initializer_list, whose elements are const, and copy both polynomials.
    std::vector<DCRTPoly> rotated;
    rotated.reserve(2);
    rotated.push_back((*ba)[0].AutomorphismTransform(autoIndex, vec));
    rotated.push_back((*ba)[1].AutomorphismTransform(autoIndex, vec));

    auto result = ciphertext->CloneEmpty();
    result->SetElements(std::move(rotated));
    return result;
}
// Converts a rotation index into the corresponding automorphism index.
// The computation is delegated unchanged to FindAutomorphismIndex2n(index, m).
uint32_t LeveledSHEBFVRNS::FindAutomorphismIndex(uint32_t index, uint32_t m) const {
    const uint32_t autoIndex = FindAutomorphismIndex2n(index, m);
    return autoIndex;
}
// Key-switches one component of `ciphertext` with `evalKey` and folds the outcome back into the
// ciphertext, which always ends up with exactly two components.
//
// - Two-component input (plain key switch): component 1 is switched; the switched pair is added
//   to component 0 and replaces component 1.
// - Any other size (relinearization): component 2 is switched; the switched pair is added to
//   components 0 and 1, and the extra components are discarded.
//
// In the HPSPOVERQLEVELED mode, when the ciphertext carries the maximum number of RNS limbs, the
// switched component is scaled down to a leveled modulus before the key switch and the switched
// pair is expanded back afterwards.
void LeveledSHEBFVRNS::RelinearizeCore(Ciphertext<DCRTPoly>& ciphertext, const EvalKey<DCRTPoly> evalKey) const {
    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersBFVRNS>(ciphertext->GetCryptoParameters());
    auto algo               = ciphertext->GetCryptoContext()->GetScheme();

    std::vector<DCRTPoly>& parts = ciphertext->GetElements();
    const bool isKeySwitch       = (parts.size() == 2);
    // index of the component that goes through the key switch
    const size_t target = isKeySwitch ? 1 : 2;

    // Number of RNS limbs in this ciphertext vs. the maximum number in the crypto context
    const size_t numLimbs    = parts[0].GetNumOfElements();
    const size_t maxNumLimbs = cryptoParams->GetElementParams()->GetParams().size();
    const bool leveled =
        (cryptoParams->GetMultiplicationTechnique() == HPSPOVERQLEVELED) && (maxNumLimbs == numLimbs);

    // l is the index of the leveled parameters in the cryptoParameters precomputations
    // (only meaningful in the leveled mode)
    uint32_t l = 0;
    if (leveled) {
        const size_t levels   = ciphertext->GetNoiseScaleDeg() - 1;
        const double dcrtBits = parts[0].GetElementAtIndex(0).GetModulus().GetMSB();
        // subtract the number of levels to drop
        l = numLimbs - 1 - FindLevelsToDrop(levels, cryptoParams, dcrtBits, isKeySwitch);

        parts[target].SetFormat(COEFFICIENT);
        parts[target] =
            parts[target].ScaleAndRound(cryptoParams->GetParamsQl(l), cryptoParams->GetQlQHatInvModqDivqModq(l),
                                        cryptoParams->GetQlQHatInvModqDivqFrac(l), cryptoParams->GetModqBarrettMu());
    }

    parts[target].SetFormat(Format::EVALUATION);
    auto switched = algo->KeySwitchCore(parts[target], evalKey);

    if (leveled) {
        // bring both switched polynomials back to the full modulus
        for (size_t k = 0; k < 2; ++k) {
            (*switched)[k].ExpandCRTBasisQlHat(cryptoParams->GetElementParams(), cryptoParams->GetQlHatModq(l),
                                               cryptoParams->GetQlHatModqPrecon(l), numLimbs);
        }
    }

    parts[0].SetFormat(Format::EVALUATION);
    parts[0] += (*switched)[0];

    if (!isKeySwitch) {
        parts[1].SetFormat(Format::EVALUATION);
        parts[1] += (*switched)[1];
    }
    else {
        parts[1] = std::move((*switched)[1]);
    }

    parts.resize(2);
}
// Returns a copy of `ciphertext` whose polynomials have been reduced to `towersLeft` RNS limbs
// by repeatedly dropping the last limb and rescaling (DropLastElementAndScale).
//
// Parameters:
//   ciphertext    - the ciphertext to compress; it is not modified.
//   towersLeft    - the number of RNS limbs to keep. If the ciphertext already has this many
//                   limbs or fewer, an unmodified copy is returned.
//   noiseScaleDeg - not used by this implementation.
// Errors (reported through OPENFHE_THROW):
//   - the multiplication technique is BEHZ or HPS;
//   - the encryption technique is EXTENDED.
Ciphertext<DCRTPoly> LeveledSHEBFVRNS::Compress(ConstCiphertext<DCRTPoly>& ciphertext, size_t towersLeft,
                                                size_t noiseScaleDeg) const {
    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersBFVRNS>(ciphertext->GetCryptoParameters());

    const auto multTechnique = cryptoParams->GetMultiplicationTechnique();
    if ((multTechnique == BEHZ) || (multTechnique == HPS)) {
        OPENFHE_THROW(
            "BFV Compress is not currently supported for BEHZ or HPS. Use one of the HPSPOVERQ* methods instead.");
    }

    if (cryptoParams->GetEncryptionTechnique() == EXTENDED) {
        OPENFHE_THROW(
            "BFV Compress is not currently supported for the EXTENDED encryption method. Use the STANDARD encryption method instead.");
    }

    Ciphertext<DCRTPoly> result = std::make_shared<CiphertextImpl<DCRTPoly>>(*ciphertext);
    std::vector<DCRTPoly>& cv   = result->GetElements();

    const size_t sizeQ  = cryptoParams->GetElementParams()->GetParams().size();
    const size_t sizeQl = cv[0].GetNumOfElements();

    // Nothing to drop. This also keeps the unsigned subtraction below from wrapping around
    // (and the loop from running far past the available limbs) when towersLeft > sizeQl.
    if (towersLeft >= sizeQl) {
        return result;
    }

    const size_t diffQl = sizeQ - sizeQl;
    const size_t levels = sizeQl - towersLeft;
    for (size_t l = 0; l < levels; ++l) {
        // The precomputed tables depend only on the level, not on the component.
        const auto& qlQlInvModqlDivqlModq = cryptoParams->GetQlQlInvModqlDivqlModq(diffQl + l);
        const auto& qlInvModq             = cryptoParams->GetqlInvModq(diffQl + l);
        for (auto& poly : cv) {
            poly.DropLastElementAndScale(qlQlInvModqlDivqlModq, qlInvModq);
        }
    }

    return result;
}
// No automated adjustment of ciphertexts is typically done in BFV, so the Eval*Mutable and
// Eval*MutableInPlace entry points of LeveledSHEBFVRNS are not supported. This is the message
// they report.
static const std::string EVAL_MUTABLE_ERROR =
    "The mutable features are not supported in the BFV scheme. "
    "Please use a non-mutable version of this function";
// Unsupported in BFV: mutable EvalMult overload taking two ciphertexts.
// Unconditionally reports EVAL_MUTABLE_ERROR through OPENFHE_THROW; the arguments are not read.
Ciphertext<DCRTPoly> LeveledSHEBFVRNS::EvalMultMutable(Ciphertext<DCRTPoly>& ciphertext1,
Ciphertext<DCRTPoly>& ciphertext2) const {
OPENFHE_THROW(EVAL_MUTABLE_ERROR);
}
// Unsupported in BFV: mutable EvalMult overload taking two ciphertexts and an evaluation key.
// Unconditionally reports EVAL_MUTABLE_ERROR through OPENFHE_THROW; the arguments are not read.
Ciphertext<DCRTPoly> LeveledSHEBFVRNS::EvalMultMutable(Ciphertext<DCRTPoly>& ciphertext1,
Ciphertext<DCRTPoly>& ciphertext2,
const EvalKey<DCRTPoly> evalKey) const {
OPENFHE_THROW(EVAL_MUTABLE_ERROR);
}
// Unsupported in BFV: mutable EvalMult overload taking a ciphertext and a plaintext.
// Unconditionally reports EVAL_MUTABLE_ERROR through OPENFHE_THROW; the arguments are not read.
Ciphertext<DCRTPoly> LeveledSHEBFVRNS::EvalMultMutable(Ciphertext<DCRTPoly>& ciphertext, Plaintext& plaintext) const {
OPENFHE_THROW(EVAL_MUTABLE_ERROR);
}
// Unsupported in BFV: in-place mutable EvalMult overload taking two ciphertexts and an evaluation key.
// Unconditionally reports EVAL_MUTABLE_ERROR through OPENFHE_THROW; the arguments are not read.
void LeveledSHEBFVRNS::EvalMultMutableInPlace(Ciphertext<DCRTPoly>& ciphertext1, Ciphertext<DCRTPoly>& ciphertext2,
const EvalKey<DCRTPoly> evalKey) const {
OPENFHE_THROW(EVAL_MUTABLE_ERROR);
}
// Unsupported in BFV: in-place mutable EvalMult overload taking a ciphertext and a plaintext.
// Unconditionally reports EVAL_MUTABLE_ERROR through OPENFHE_THROW; the arguments are not read.
void LeveledSHEBFVRNS::EvalMultMutableInPlace(Ciphertext<DCRTPoly>& ciphertext, Plaintext& plaintext) const {
OPENFHE_THROW(EVAL_MUTABLE_ERROR);
}
// Unsupported in BFV: mutable EvalAdd overload taking two ciphertexts.
// Unconditionally reports EVAL_MUTABLE_ERROR through OPENFHE_THROW; the arguments are not read.
Ciphertext<DCRTPoly> LeveledSHEBFVRNS::EvalAddMutable(Ciphertext<DCRTPoly>& ciphertext1,
Ciphertext<DCRTPoly>& ciphertext2) const {
OPENFHE_THROW(EVAL_MUTABLE_ERROR);
}
// Unsupported in BFV: mutable EvalAdd overload taking a ciphertext and a plaintext.
// Unconditionally reports EVAL_MUTABLE_ERROR through OPENFHE_THROW; the arguments are not read.
Ciphertext<DCRTPoly> LeveledSHEBFVRNS::EvalAddMutable(Ciphertext<DCRTPoly>& ciphertext, Plaintext& plaintext) const {
OPENFHE_THROW(EVAL_MUTABLE_ERROR);
}
// Unsupported in BFV: in-place mutable EvalAdd overload taking two ciphertexts.
// Unconditionally reports EVAL_MUTABLE_ERROR through OPENFHE_THROW; the arguments are not read.
void LeveledSHEBFVRNS::EvalAddMutableInPlace(Ciphertext<DCRTPoly>& ciphertext1,
Ciphertext<DCRTPoly>& ciphertext2) const {
OPENFHE_THROW(EVAL_MUTABLE_ERROR);
}
// Unsupported in BFV: in-place mutable EvalAdd overload taking a ciphertext and a plaintext.
// Unconditionally reports EVAL_MUTABLE_ERROR through OPENFHE_THROW; the arguments are not read.
void LeveledSHEBFVRNS::EvalAddMutableInPlace(Ciphertext<DCRTPoly>& ciphertext, Plaintext& plaintext) const {
OPENFHE_THROW(EVAL_MUTABLE_ERROR);
}
// Unsupported in BFV: mutable EvalSub overload taking two ciphertexts.
// Unconditionally reports EVAL_MUTABLE_ERROR through OPENFHE_THROW; the arguments are not read.
Ciphertext<DCRTPoly> LeveledSHEBFVRNS::EvalSubMutable(Ciphertext<DCRTPoly>& ciphertext1,
Ciphertext<DCRTPoly>& ciphertext2) const {
OPENFHE_THROW(EVAL_MUTABLE_ERROR);
}
// Unsupported in BFV: mutable EvalSub overload taking a ciphertext and a plaintext.
// Unconditionally reports EVAL_MUTABLE_ERROR through OPENFHE_THROW; the arguments are not read.
Ciphertext<DCRTPoly> LeveledSHEBFVRNS::EvalSubMutable(Ciphertext<DCRTPoly>& ciphertext, Plaintext& plaintext) const {
OPENFHE_THROW(EVAL_MUTABLE_ERROR);
}
// Unsupported in BFV: in-place mutable EvalSub overload taking two ciphertexts.
// Unconditionally reports EVAL_MUTABLE_ERROR through OPENFHE_THROW; the arguments are not read.
void LeveledSHEBFVRNS::EvalSubMutableInPlace(Ciphertext<DCRTPoly>& ciphertext1,
Ciphertext<DCRTPoly>& ciphertext2) const {
OPENFHE_THROW(EVAL_MUTABLE_ERROR);
}
// Unsupported in BFV: in-place mutable EvalSub overload taking a ciphertext and a plaintext.
// Unconditionally reports EVAL_MUTABLE_ERROR through OPENFHE_THROW; the arguments are not read.
void LeveledSHEBFVRNS::EvalSubMutableInPlace(Ciphertext<DCRTPoly>& ciphertext, Plaintext& plaintext) const {
OPENFHE_THROW(EVAL_MUTABLE_ERROR);
}
} // namespace lbcrypto