Program Listing for File bfvrns-cryptoparameters.cpp
↰ Return to documentation for file (pke/lib/scheme/bfvrns/bfvrns-cryptoparameters.cpp)
//==================================================================================
// BSD 2-Clause License
//
// Copyright (c) 2014-2022, NJIT, Duality Technologies Inc. and other contributors
//
// All rights reserved.
//
// Author TPOC: contact@openfhe.org
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//==================================================================================
/*
BFV implementation. See https://eprint.iacr.org/2021/204 for details.
*/
#define PROFILE
#include "cryptocontext.h"
#include "scheme/bfvrns/bfvrns-cryptoparameters.h"
namespace lbcrypto {

/**
 * Precomputation of CRT tables for encryption, decryption, and homomorphic multiplication.
 *
 * First calls CryptoParametersRNS::PrecomputeCRTTables with all arguments, then fills the
 * BFV-specific tables:
 *   - encryption constants ([-Q]_t and, when encTech == EXTENDED, the extended basis Q*r
 *     together with [-Q*r]_t);
 *   - for every multiplication technique other than BEHZ: the auxiliary basis R and the
 *     ExpandCRTBasis / ScaleAndRound / SwitchCRTBasis / DropLastElementAndScale tables,
 *     plus the decryption ScaleAndRound tables;
 *   - for BEHZ: the auxiliary bases B and Bsk = B U {msk} and their associated constants.
 *
 * m_moduliB, m_rootsBsk and the inner vectors of m_tRSHatInvModsDivsModr are cleared before
 * being appended to, so repeated calls do not accumulate stale entries.
 *
 * @param ksTech    key-switching technique (forwarded to the base class)
 * @param scalTech  scaling technique (forwarded to the base class)
 * @param encTech   encryption technique; EXTENDED additionally builds the Q*r tables
 * @param multTech  multiplication technique; selects which of the tables below are built
 * @param numPartQ  forwarded to the base class
 * @param auxBits   forwarded to the base class
 * @param extraBits forwarded to the base class
 */
void CryptoParametersBFVRNS::PrecomputeCRTTables(KeySwitchTechnique ksTech, ScalingTechnique scalTech,
                                                 EncryptionTechnique encTech, MultiplicationTechnique multTech,
                                                 uint32_t numPartQ, uint32_t auxBits, uint32_t extraBits) {
    CryptoParametersRNS::PrecomputeCRTTables(ksTech, scalTech, encTech, multTech, numPartQ, auxBits, extraBits);

    NativeInteger t = GetPlaintextModulus();
    uint32_t n = GetElementParams()->GetRingDimension();

    BigInteger modulusQ = GetElementParams()->GetModulus();
    const auto& paramsQ = GetElementParams()->GetParams();
    size_t sizeQ = paramsQ.size();

    m_modqBarrettMu.resize(0);
    m_modqBarrettMu.reserve(sizeQ);
    m_tInvModq.resize(0);
    m_tInvModq.reserve(sizeQ);

    std::vector<NativeInteger> moduliQ, rootsQ;
    moduliQ.reserve(sizeQ);
    rootsQ.reserve(sizeQ);

    // 2^128: numerator of the Barrett constants floor(2^128 / q) stored below
    const auto BarrettBase128Bit(BigInteger(1).LShiftEq(128));
    for (const auto& p : paramsQ) {
        m_tInvModq.emplace_back(t.ModInverse(p->GetModulus()));
        m_modqBarrettMu.emplace_back((BarrettBase128Bit / BigInteger(p->GetModulus())).ConvertToInt<DoubleNativeInt>());
        moduliQ.emplace_back(p->GetModulus());
        rootsQ.emplace_back(p->GetRootOfUnity());
    }

    // BFVrns : Encrypt
    // r = PreviousPrime(q_{sizeQ-1}, 2n). It is also reused as the first modulus of the
    // auxiliary basis R in the HPS precomputation below.
    NativeInteger modulusr = PreviousPrime<NativeInteger>(moduliQ[sizeQ - 1], 2 * n);
    NativeInteger rootr = RootOfUnity<NativeInteger>(2 * n, modulusr);

    // [-Q]_t, computed as t - (Q mod t)
    m_negQModt = modulusQ.Mod(BigInteger(GetPlaintextModulus())).ConvertToInt();
    m_negQModt = t.Sub(m_negQModt);
    m_negQModtPrecon = m_negQModt.PrepModMulConst(t);

    // BFVrns : Encrypt : With extra
    if (encTech == EXTENDED) {
        std::vector<NativeInteger> moduliQr(sizeQ + 1);
        std::vector<NativeInteger> rootsQr(sizeQ + 1);
        m_rInvModq.resize(sizeQ);
        m_tInvModqr.resize(sizeQ + 1);

        for (uint32_t i = 0; i < sizeQ; i++) {
            moduliQr[i] = moduliQ[i];
            rootsQr[i] = rootsQ[i];
            m_tInvModqr[i] = m_tInvModq[i];
            m_rInvModq[i] = modulusr.ModInverse(moduliQ[i]);
        }
        moduliQr[sizeQ] = modulusr;
        rootsQr[sizeQ] = rootr;
        m_paramsQr = std::make_shared<ILDCRTParams<BigInteger>>(2 * n, moduliQr, rootsQr);

        m_tInvModqr[sizeQ] = t.ModInverse(modulusr);

        // [-Q*r]_t, computed as t - (Q*r mod t)
        BigInteger modulusQr = modulusQ.Mul(modulusr);
        m_negQrModt = modulusQr.Mod(BigInteger(t)).ConvertToInt();
        m_negQrModt = t.Sub(m_negQrModt);
        m_negQrModtPrecon = m_negQrModt.PrepModMulConst(t);
    }

    // HPS Precomputation
    if (multTech != BEHZ) {
        // Auxiliary basis R: one more modulus than Q for HPS, as many as Q otherwise.
        size_t sizeR = (multTech == HPS) ? sizeQ + 1 : sizeQ;

        std::vector<NativeInteger> moduliR(sizeR);
        std::vector<NativeInteger> rootsR(sizeR);
        m_modrBarrettMu.resize(sizeR);

        moduliR[0] = modulusr;
        rootsR[0] = rootr;
        m_modrBarrettMu[0] = (BarrettBase128Bit / BigInteger(moduliR[0])).ConvertToInt<DoubleNativeInt>();

        for (size_t j = 1; j < sizeR; j++) {
            moduliR[j] = PreviousPrime<NativeInteger>(moduliR[j - 1], 2 * n);
            rootsR[j] = RootOfUnity<NativeInteger>(2 * n, moduliR[j]);
            m_modrBarrettMu[j] = (BarrettBase128Bit / BigInteger(moduliR[j])).ConvertToInt<DoubleNativeInt>();
        }

        ChineseRemainderTransformFTT<NativeVector>().PreCompute(rootsR, 2 * n, moduliR);

        // BFVrns : Mult : ExpandCRTBasis
        // Pre-compute values [Ql/q_i]_{r_j}
        // Pre-compute values [(Ql/q_i)^{-1}]_{q_i}
        // Layout: m_QlHatInvModq[level][i] and m_QlHatModr[level][j][i].
        BigInteger tmpModulusQ = modulusQ;
        if (multTech == HPSPOVERQLEVELED || multTech == HPSPOVERQ) {
            m_QlHatInvModq.resize(sizeQ);
            m_QlHatInvModqPrecon.resize(sizeQ);
            m_QlHatModr.resize(sizeQ);

            for (size_t l = 0; l < sizeQ; l++) {
                // After this step tmpModulusQ = q_0 * ... * q_{sizeQ-l-1}
                if (l > 0)
                    tmpModulusQ = tmpModulusQ / BigInteger(moduliQ[sizeQ - l]);

                m_QlHatInvModq[sizeQ - l - 1].resize(sizeQ - l);
                m_QlHatInvModqPrecon[sizeQ - l - 1].resize(sizeQ - l);
                m_QlHatModr[sizeQ - l - 1].resize(sizeR);
                for (size_t j = 0; j < sizeR; j++) {
                    m_QlHatModr[sizeQ - l - 1][j].resize(sizeQ - l);
                }

                for (size_t i = 0; i < sizeQ - l; i++) {
                    BigInteger QHati = tmpModulusQ / BigInteger(moduliQ[i]);
                    BigInteger QHatInvModqi = QHati.ModInverse(moduliQ[i]);
                    m_QlHatInvModq[sizeQ - l - 1][i] = QHatInvModqi.ConvertToInt();
                    m_QlHatInvModqPrecon[sizeQ - l - 1][i] =
                        m_QlHatInvModq[sizeQ - l - 1][i].PrepModMulConst(moduliQ[i]);

                    for (size_t j = 0; j < sizeR; j++) {
                        BigInteger QlHatModrij = QHati.Mod(moduliR[j]);
                        m_QlHatModr[sizeQ - l - 1][j][i] = QlHatModrij.ConvertToInt();
                    }
                }
            }
        }
        else {
            // Q/q_i for every i, computed once and reused for both tables below
            std::vector<BigInteger> QHat(sizeQ);

            m_QlHatInvModq.resize(1);
            m_QlHatInvModqPrecon.resize(1);
            m_QlHatInvModq[0].resize(sizeQ);
            m_QlHatInvModqPrecon[0].resize(sizeQ);
            for (size_t i = 0; i < sizeQ; i++) {
                QHat[i] = modulusQ / BigInteger(moduliQ[i]);
                BigInteger QHatInvModqi = QHat[i].ModInverse(moduliQ[i]);
                m_QlHatInvModq[0][i] = QHatInvModqi.ConvertToInt();
                m_QlHatInvModqPrecon[0][i] = m_QlHatInvModq[0][i].PrepModMulConst(moduliQ[i]);
            }

            m_QlHatModr.resize(1);
            m_QlHatModr[0].resize(sizeR);
            for (usint j = 0; j < sizeR; j++) {
                m_QlHatModr[0][j].resize(sizeQ);
                for (usint i = 0; i < sizeQ; i++) {
                    m_QlHatModr[0][j][i] = QHat[i].Mod(moduliR[j]).ConvertToInt();
                }
            }
        }

        // BFVrns : Mult : ExpandCRTBasis
        if (multTech == HPS) {
            m_paramsQl.resize(1);
            m_paramsRl.resize(1);
            m_paramsQlRl.resize(1);

            m_paramsQl[0] = std::make_shared<ILDCRTParams<BigInteger>>(2 * n, moduliQ, rootsQ);
            m_paramsRl[0] = std::make_shared<ILDCRTParams<BigInteger>>(2 * n, moduliR, rootsR);

            std::vector<NativeInteger> moduliQR(sizeQ + sizeR);
            std::vector<NativeInteger> rootsQR(sizeQ + sizeR);
            for (size_t i = 0; i < sizeQ; i++) {
                moduliQR[i] = moduliQ[i];
                rootsQR[i] = rootsQ[i];
            }
            for (size_t j = 0; j < sizeR; j++) {
                moduliQR[sizeQ + j] = moduliR[j];
                rootsQR[sizeQ + j] = rootsR[j];
            }
            m_paramsQlRl[0] = std::make_shared<ILDCRTParams<BigInteger>>(2 * n, moduliQR, rootsQR);
        }
        else if (multTech == HPSPOVERQLEVELED || multTech == HPSPOVERQ) {
            m_paramsQl.resize(sizeQ);
            m_paramsRl.resize(sizeQ);
            m_paramsQlRl.resize(sizeQ);

            std::vector<NativeInteger> moduliQl;
            moduliQl.reserve(sizeQ);
            std::vector<NativeInteger> rootsQl;
            rootsQl.reserve(sizeQ);
            std::vector<NativeInteger> moduliRl;
            moduliRl.reserve(sizeQ);
            std::vector<NativeInteger> rootsRl;
            rootsRl.reserve(sizeQ);
            std::vector<NativeInteger> moduliQlRl;
            moduliQlRl.reserve(2 * sizeQ);
            std::vector<NativeInteger> rootsQlRl;
            rootsQlRl.reserve(2 * sizeQ);

            for (usint l = 0; l < sizeQ; ++l) {
                moduliQl.push_back(moduliQ[l]);
                rootsQl.push_back(rootsQ[l]);
                m_paramsQl[l] = std::make_shared<ILDCRTParams<BigInteger>>(2 * n, moduliQl, rootsQl);

                moduliRl.push_back(moduliR[l]);
                rootsRl.push_back(rootsR[l]);
                m_paramsRl[l] = std::make_shared<ILDCRTParams<BigInteger>>(2 * n, moduliRl, rootsRl);

                // Keep the Q moduli before the R moduli: insert q_l at position l, append r_l.
                moduliQlRl.insert(moduliQlRl.begin() + l, moduliQ[l]);
                rootsQlRl.insert(rootsQlRl.begin() + l, rootsQ[l]);
                moduliQlRl.push_back(moduliR[l]);
                rootsQlRl.push_back(rootsR[l]);
                m_paramsQlRl[l] = std::make_shared<ILDCRTParams<BigInteger>>(2 * n, moduliQlRl, rootsQlRl);
            }
        }

        // m_modrBarrettMu was already filled for every modulus of R when R was generated above.

        m_qInv.resize(sizeQ);
        for (size_t i = 0; i < sizeQ; i++) {
            m_qInv[i] = 1. / static_cast<double>(moduliQ[i].ConvertToInt());
        }

        // BFVrns : Mult : ScaleAndRound
        const BigInteger modulusR = multTech == HPSPOVERQLEVELED || multTech == HPSPOVERQ ?
                                        m_paramsRl[sizeQ - 1]->GetModulus() :
                                        m_paramsRl[0]->GetModulus();
        const BigInteger modulusQR = multTech == HPSPOVERQLEVELED || multTech == HPSPOVERQ ?
                                         m_paramsQlRl[sizeQ - 1]->GetModulus() :
                                         m_paramsQlRl[0]->GetModulus();
        const BigInteger modulust(GetPlaintextModulus());

        m_tRSHatInvModsDivsFrac.resize(sizeQ);
        for (size_t i = 0; i < sizeQ; i++) {
            BigInteger qi(moduliQ[i].ConvertToInt());
            m_tRSHatInvModsDivsFrac[i] =
                static_cast<double>(
                    ((modulusQR.DividedBy(qi)).ModInverse(qi) * modulusR * modulust).Mod(qi).ConvertToInt()) /
                static_cast<double>(qi.ConvertToInt());
        }

        m_tRSHatInvModsDivsModr.resize(sizeR);
        for (usint j = 0; j < sizeR; j++) {
            // Clear before appending: resize() above keeps any inner vector from a previous call.
            m_tRSHatInvModsDivsModr[j].clear();
            m_tRSHatInvModsDivsModr[j].reserve(sizeQ + 1);
            BigInteger rj(moduliR[j].ConvertToInt());
            for (usint i = 0; i < sizeQ; i++) {
                BigInteger qi(moduliQ[i].ConvertToInt());
                BigInteger tRSHatInvMods = modulust * modulusR * ((modulusQR.DividedBy(qi)).ModInverse(qi));
                BigInteger tRSHatInvModsDivs = tRSHatInvMods / qi;
                m_tRSHatInvModsDivsModr[j].push_back(tRSHatInvModsDivs.Mod(rj).ConvertToInt());
            }
            BigInteger tRSHatInvMods = modulust * modulusR * ((modulusQR.DividedBy(rj)).ModInverse(rj));
            BigInteger tRSHatInvModsDivs = tRSHatInvMods / rj;
            m_tRSHatInvModsDivsModr[j].push_back(tRSHatInvModsDivs.Mod(rj).ConvertToInt());
        }

        // BFVrns : Mult : SwitchCRTBasis
        // Ql[l] = q_0*...*q_{l-1}, Rl[l] = r_0*...*r_{l-1}, QlRl[l] = Ql[l]*Rl[l],
        // QlHat[l] = Q / Ql[l], RlHat[l] = R / Rl[l].
        // NOTE(review): these are only populated for HPSPOVERQ / HPSPOVERQLEVELED. For HPS they
        // remain default-constructed, and the ScaleAndRoundQl / ExpandCRTBasisQlHat tables below
        // are still built from those default values. Presumably HPS never reads them - confirm.
        std::vector<BigInteger> Ql(sizeQ + 1);
        std::vector<BigInteger> Rl(sizeQ + 1);
        std::vector<BigInteger> QlRl(sizeQ + 1);
        std::vector<BigInteger> QlHat(sizeQ + 1);
        std::vector<BigInteger> RlHat(sizeQ + 1);
        if (multTech == HPSPOVERQLEVELED || multTech == HPSPOVERQ) {
            Ql[0] = 1;
            Rl[0] = 1;
            QlRl[0] = 1;
            QlHat[0] = modulusQ;
            RlHat[0] = modulusR;
            for (usint l = 0; l < sizeQ; ++l) {
                BigInteger ql(moduliQ[l].ConvertToInt());
                BigInteger rl(moduliR[l].ConvertToInt());

                Ql[l + 1] = Ql[l] * ql;
                Rl[l + 1] = Rl[l] * rl;
                QlRl[l + 1] = QlRl[l] * ql;
                QlRl[l + 1] = QlRl[l + 1] * rl;
                QlHat[l + 1] = QlHat[l] / ql;
                RlHat[l + 1] = RlHat[l] / rl;
            }
        }

        // BFVrns : Mult : ExpandCRTBasis
        // compute [alpha*Ql]_{r_j} for 0 <= alpha <= sizeQl
        if (multTech == HPS) {
            m_alphaQlModr.resize(1);
            m_alphaQlModr[0].resize(sizeQ + 1, std::vector<NativeInteger>(sizeR));
            for (usint j = 0; j < sizeR; j++) {
                NativeInteger QModrj = modulusQ.Mod(moduliR[j]).ConvertToInt();
                for (usint i = 0; i < sizeQ + 1; i++) {
                    m_alphaQlModr[0][i][j] = QModrj.ModMul(NativeInteger(i), moduliR[j]);
                }
            }
        }
        else if (multTech == HPSPOVERQLEVELED || multTech == HPSPOVERQ) {
            m_alphaQlModr.resize(sizeQ);
            for (usint l = sizeQ; l > 0; l--) {
                m_alphaQlModr[l - 1].resize(l + 1, std::vector<NativeInteger>(sizeR));
                for (usint i = 0; i < sizeR; i++) {
                    NativeInteger QlModri = Ql[l].Mod(moduliR[i]).ConvertToInt();
                    for (usint j = 0; j < l + 1; ++j) {
                        m_alphaQlModr[l - 1][j][i] = QlModri.ModMul(NativeInteger(j), moduliR[i]);
                    }
                }
            }
        }

        // Pre-compute values [Rl/r_j]_{q_i}
        // Pre-compute values [(Rl/r_j)^{-1}]_{r_j}
        if (multTech == HPS) {
            m_RlHatInvModr.resize(1);
            m_RlHatInvModrPrecon.resize(1);
            m_RlHatInvModr[0].resize(sizeR);
            m_RlHatInvModrPrecon[0].resize(sizeR);
            for (size_t j = 0; j < sizeR; j++) {
                BigInteger RHatj = modulusR / BigInteger(moduliR[j]);
                m_RlHatInvModr[0][j] = RHatj.ModInverse(moduliR[j]).ConvertToInt();
                m_RlHatInvModrPrecon[0][j] = m_RlHatInvModr[0][j].PrepModMulConst(moduliR[j]);
            }

            m_RlHatModq.resize(1);
            m_RlHatModq[0].resize(sizeQ);
            for (usint i = 0; i < sizeQ; i++) {
                m_RlHatModq[0][i].resize(sizeR);
                for (usint j = 0; j < sizeR; j++) {
                    BigInteger RHatj = modulusR / BigInteger(moduliR[j]);
                    m_RlHatModq[0][i][j] = RHatj.Mod(moduliQ[i]).ConvertToInt();
                }
            }
        }
        else if (multTech == HPSPOVERQ || multTech == HPSPOVERQLEVELED) {
            m_RlHatInvModr.resize(sizeR);
            m_RlHatInvModrPrecon.resize(sizeR);
            m_RlHatModq.resize(sizeR);
            for (usint l = sizeR; l > 0; l--) {
                m_RlHatInvModr[l - 1].resize(l);
                m_RlHatInvModrPrecon[l - 1].resize(l);
                m_RlHatModq[l - 1].resize(l, std::vector<NativeInteger>(l));
                for (size_t j = 0; j < l; j++) {
                    BigInteger RlHatj = Rl[l] / BigInteger(moduliR[j]);
                    BigInteger RlHatInvModrj = RlHatj.ModInverse(moduliR[j]);
                    m_RlHatInvModr[l - 1][j] = RlHatInvModrj.ConvertToInt();
                    m_RlHatInvModrPrecon[l - 1][j] = m_RlHatInvModr[l - 1][j].PrepModMulConst(moduliR[j]);
                    for (size_t i = 0; i < l; i++) {
                        BigInteger RlHatModqji = RlHatj.Mod(moduliQ[i]);
                        m_RlHatModq[l - 1][i][j] = RlHatModqji.ConvertToInt();
                    }
                }
            }
        }

        // compute [\alpha*Rl]_{q_i} for 0 <= alpha <= sizeRl
        // used for homomorphic multiplication
        if (multTech == HPS) {
            m_alphaRlModq.resize(1);
            m_alphaRlModq[0].resize(sizeR + 1, std::vector<NativeInteger>(sizeQ));
            for (usint i = 0; i < sizeQ; i++) {
                NativeInteger RModqi = modulusR.Mod(moduliQ[i]).ConvertToInt();
                for (usint j = 0; j < sizeR + 1; ++j) {
                    m_alphaRlModq[0][j][i] = RModqi.ModMul(NativeInteger(j), moduliQ[i]);
                }
            }
        }
        else if (multTech == HPSPOVERQLEVELED || multTech == HPSPOVERQ) {
            m_alphaRlModq.resize(sizeR);
            for (usint l = sizeR; l > 0; l--) {
                m_alphaRlModq[l - 1].resize(l + 1, std::vector<NativeInteger>(sizeQ));
                for (usint i = 0; i < sizeQ; i++) {
                    NativeInteger RlModqi = Rl[l].Mod(moduliQ[i]).ConvertToInt();
                    for (usint j = 0; j < l + 1; ++j) {
                        m_alphaRlModq[l - 1][j][i] = RlModqi.ModMul(NativeInteger(j), moduliQ[i]);
                    }
                }
            }
        }

        m_rInv.resize(sizeR);
        for (size_t j = 0; j < sizeR; j++) {
            m_rInv[j] = 1. / static_cast<double>(moduliR[j].ConvertToInt());
        }

        // BFVrns : Decrypt : ScaleAndRound
        // If qMSB + log2(sizeQ) >= 52 the fractional parts no longer fit in a double's mantissa,
        // so an additional table pre-multiplied by 2^{qMSB/2} (the "B" variants) is built as well.
        usint qMSB = moduliQ[0].GetMSB();
        usint sizeQMSB = GetMSB64(sizeQ);
        m_tQHatInvModqDivqModt.resize(sizeQ);
        m_tQHatInvModqDivqModtPrecon.resize(sizeQ);
        m_tQHatInvModqDivqFrac.resize(sizeQ);
        if (qMSB + sizeQMSB < 52) {
            for (size_t i = 0; i < sizeQ; i++) {
                BigInteger qi(moduliQ[i].ConvertToInt());
                BigInteger tQHatInvModqi =
                    ((modulusQ.DividedBy(qi)).ModInverse(qi) * BigInteger(GetPlaintextModulus()));
                BigInteger tQHatInvModqDivqi = tQHatInvModqi.DividedBy(qi);
                m_tQHatInvModqDivqModt[i] = tQHatInvModqDivqi.Mod(GetPlaintextModulus()).ConvertToInt();
                m_tQHatInvModqDivqModtPrecon[i] = m_tQHatInvModqDivqModt[i].PrepModMulConst(GetPlaintextModulus());

                int64_t numerator = tQHatInvModqi.Mod(qi).ConvertToInt();
                int64_t denominator = moduliQ[i].ConvertToInt();
                m_tQHatInvModqDivqFrac[i] = static_cast<double>(numerator) / static_cast<double>(denominator);
            }
        }
        else {
            m_tQHatInvModqBDivqModt.resize(sizeQ);
            m_tQHatInvModqBDivqModtPrecon.resize(sizeQ);
            m_tQHatInvModqBDivqFrac.resize(sizeQ);
            usint qMSBHf = qMSB >> 1;
            for (size_t i = 0; i < sizeQ; i++) {
                BigInteger qi(moduliQ[i].ConvertToInt());
                BigInteger tQHatInvModqi =
                    ((modulusQ.DividedBy(qi)).ModInverse(qi) * BigInteger(GetPlaintextModulus()));
                BigInteger tQHatInvModqDivqi = tQHatInvModqi.DividedBy(qi);
                m_tQHatInvModqDivqModt[i] = tQHatInvModqDivqi.Mod(GetPlaintextModulus()).ConvertToInt();
                m_tQHatInvModqDivqModtPrecon[i] = m_tQHatInvModqDivqModt[i].PrepModMulConst(GetPlaintextModulus());

                int64_t numerator = tQHatInvModqi.Mod(qi).ConvertToInt();
                int64_t denominator = moduliQ[i].ConvertToInt();
                m_tQHatInvModqDivqFrac[i] = static_cast<double>(numerator) / static_cast<double>(denominator);

                tQHatInvModqi.LShiftEq(qMSBHf);
                tQHatInvModqDivqi = tQHatInvModqi.DividedBy(qi);
                m_tQHatInvModqBDivqModt[i] = tQHatInvModqDivqi.Mod(GetPlaintextModulus()).ConvertToInt();
                m_tQHatInvModqBDivqModtPrecon[i] = m_tQHatInvModqBDivqModt[i].PrepModMulConst(GetPlaintextModulus());

                numerator = tQHatInvModqi.Mod(qi).ConvertToInt();
                m_tQHatInvModqBDivqFrac[i] = static_cast<double>(numerator) / static_cast<double>(denominator);
            }
        }

        // BFVrns : Mult : FastExpandCRTBasisPloverQ
        // [-Rl * (Q/q_i)^{-1}]_{q_i}
        if (multTech == HPSPOVERQ || multTech == HPSPOVERQLEVELED) {
            m_negRlQHatInvModq.resize(sizeR);
            m_negRlQHatInvModqPrecon.resize(sizeR);
            for (usint l = sizeR; l > 0; l--) {
                m_negRlQHatInvModq[l - 1].resize(sizeQ);
                m_negRlQHatInvModqPrecon[l - 1].resize(sizeQ);
                for (usint i = 0; i < sizeQ; i++) {
                    BigInteger QHati = modulusQ / BigInteger(moduliQ[i]);
                    BigInteger QHatInvModqi = QHati.ModInverse(moduliQ[i]);
                    m_negRlQHatInvModq[l - 1][i] = Rl[l].ModMul(QHatInvModqi, moduliQ[i]).ConvertToInt();
                    m_negRlQHatInvModq[l - 1][i] = moduliQ[i].Sub(m_negRlQHatInvModq[l - 1][i]);
                    m_negRlQHatInvModqPrecon[l - 1][i] = m_negRlQHatInvModq[l - 1][i].PrepModMulConst(moduliQ[i]);
                }
            }
        }

        // [q_i^{-1}]_{r_j}
        m_qInvModr.resize(sizeQ);
        for (usint i = 0; i < sizeQ; i++) {
            m_qInvModr[i].resize(sizeR);
            for (usint j = 0; j < sizeR; j++) {
                m_qInvModr[i][j] = moduliQ[i].ModInverse(moduliR[j]);
            }
        }

        modulusQ = GetElementParams()->GetModulus();

        // BFVrns : Mult : ScaleAndRoundP
        if (multTech == HPS || multTech == HPSPOVERQ) {
            m_tQlSlHatInvModsDivsFrac.resize(1);
            m_tQlSlHatInvModsDivsFrac[0].resize(sizeR);
            for (size_t j = 0; j < sizeR; j++) {
                BigInteger rj(moduliR[j].ConvertToInt());
                m_tQlSlHatInvModsDivsFrac[0][j] =
                    static_cast<double>(
                        ((modulusQR.DividedBy(rj)).ModInverse(rj) * modulusQ * modulust).Mod(rj).ConvertToInt()) /
                    static_cast<double>(rj.ConvertToInt());
            }

            m_tQlSlHatInvModsDivsModq.resize(1);
            m_tQlSlHatInvModsDivsModq[0].resize(sizeQ, std::vector<NativeInteger>(sizeR + 1));
            for (usint i = 0; i < sizeQ; i++) {
                BigInteger qi(moduliQ[i].ConvertToInt());
                for (usint j = 0; j < sizeR; j++) {
                    BigInteger rj(moduliR[j].ConvertToInt());
                    BigInteger tQlSlHatInvMods = modulust * modulusQ * ((modulusQR.DividedBy(rj)).ModInverse(rj));
                    BigInteger tQlSlHatInvModsDivs = tQlSlHatInvMods / rj;
                    m_tQlSlHatInvModsDivsModq[0][i][j] = tQlSlHatInvModsDivs.Mod(qi).ConvertToInt();
                }
                BigInteger tQlSlHatInvMods = modulust * modulusQ * ((modulusQR.DividedBy(qi)).ModInverse(qi));
                BigInteger tQlSlHatInvModsDivs = tQlSlHatInvMods / qi;
                m_tQlSlHatInvModsDivsModq[0][i][sizeR] = tQlSlHatInvModsDivs.Mod(qi).ConvertToInt();
            }
        }
        else if (multTech == HPSPOVERQLEVELED) {
            m_tQlSlHatInvModsDivsFrac.resize(sizeQ);
            m_tQlSlHatInvModsDivsModq.resize(sizeQ);
            for (usint l = sizeQ; l > 0; l--) {
                m_tQlSlHatInvModsDivsFrac[l - 1].resize(l);
                for (size_t j = 0; j < l; j++) {
                    BigInteger rj(moduliR[j].ConvertToInt());
                    m_tQlSlHatInvModsDivsFrac[l - 1][j] =
                        static_cast<double>(
                            ((QlRl[l].DividedBy(rj)).ModInverse(rj) * Ql[l] * modulust).Mod(rj).ConvertToInt()) /
                        static_cast<double>(rj.ConvertToInt());
                }

                m_tQlSlHatInvModsDivsModq[l - 1].resize(l, std::vector<NativeInteger>(l + 1));
                for (usint i = 0; i < l; i++) {
                    BigInteger qi(moduliQ[i].ConvertToInt());
                    for (usint j = 0; j < l; j++) {
                        BigInteger rj(moduliR[j].ConvertToInt());
                        BigInteger tQlSlHatInvMods = modulust * Ql[l] * ((QlRl[l].DividedBy(rj)).ModInverse(rj));
                        BigInteger tQlSlHatInvModsDivs = tQlSlHatInvMods / rj;
                        m_tQlSlHatInvModsDivsModq[l - 1][i][j] = tQlSlHatInvModsDivs.Mod(qi).ConvertToInt();
                    }
                    BigInteger tQlSlHatInvMods = modulust * Ql[l] * ((QlRl[l].DividedBy(qi)).ModInverse(qi));
                    BigInteger tQlSlHatInvModsDivs = tQlSlHatInvMods / qi;
                    m_tQlSlHatInvModsDivsModq[l - 1][i][l] = tQlSlHatInvModsDivs.Mod(qi).ConvertToInt();
                }
            }
        }

        // BFVrns : Mult : ScaleAndRoundQl
        m_QlQHatInvModqDivqModq.resize(sizeQ);
        m_QlQHatInvModqDivqFrac.resize(sizeQ);
        for (usint l = sizeQ; l > 0; l--) {
            m_QlQHatInvModqDivqFrac[l - 1].resize(sizeQ - l);
            for (size_t j = 0; j < sizeQ - l; j++) {
                BigInteger qj(moduliQ[j + l].ConvertToInt());
                m_QlQHatInvModqDivqFrac[l - 1][j] =
                    static_cast<double>(((modulusQ.DividedBy(qj)).ModInverse(qj) * Ql[l]).Mod(qj).ConvertToInt()) /
                    static_cast<double>(qj.ConvertToInt());
            }

            m_QlQHatInvModqDivqModq[l - 1].resize(l);
            for (usint i = 0; i < l; i++) {
                m_QlQHatInvModqDivqModq[l - 1][i].resize(sizeQ - l + 1);
                BigInteger qi(moduliQ[i].ConvertToInt());
                for (usint j = 0; j < sizeQ - l; j++) {
                    BigInteger qj(moduliQ[l + j].ConvertToInt());
                    BigInteger QlQHatInvModq = Ql[l] * ((modulusQ.DividedBy(qj)).ModInverse(qj));
                    BigInteger QlQHatInvModqDivq = QlQHatInvModq / qj;
                    m_QlQHatInvModqDivqModq[l - 1][i][j] = QlQHatInvModqDivq.Mod(qi).ConvertToInt();
                }
                BigInteger QlQHatInvModq = Ql[l] * ((modulusQ.DividedBy(qi)).ModInverse(qi));
                BigInteger QlQHatInvModqDivq = QlQHatInvModq / qi;
                m_QlQHatInvModqDivqModq[l - 1][i][sizeQ - l] = QlQHatInvModqDivq.Mod(qi).ConvertToInt();
            }
        }

        // BFVrns : Mult : ExpandCRTBasisQlHat
        m_QlHatModq.resize(sizeQ);
        m_QlHatModqPrecon.resize(sizeQ);
        for (usint l = sizeQ; l > 0; l--) {
            m_QlHatModq[l - 1].resize(l);
            m_QlHatModqPrecon[l - 1].resize(l);
            for (usint i = 0; i < l; i++) {
                BigInteger qi(moduliQ[i].ConvertToInt());
                m_QlHatModq[l - 1][i] = QlHat[l].Mod(qi).ConvertToInt();
                m_QlHatModqPrecon[l - 1][i] = m_QlHatModq[l - 1][i].PrepModMulConst(qi);
            }
        }

        // DropLastElementAndScale
        // Pre-compute omega values for rescaling in RNS
        // modulusQ holds Q^(l) = \prod_{i=0}^{i=l}(q_i).
        m_QlQlInvModqlDivqlModq.resize(sizeQ - 1);
        m_QlQlInvModqlDivqlModqPrecon.resize(sizeQ - 1);
        m_qlInvModq.resize(sizeQ - 1);
        m_qlInvModqPrecon.resize(sizeQ - 1);
        for (size_t k = 0; k < sizeQ - 1; k++) {
            size_t l = sizeQ - (k + 1);
            modulusQ = modulusQ / BigInteger(moduliQ[l]);
            m_QlQlInvModqlDivqlModq[k].resize(l);
            m_QlQlInvModqlDivqlModqPrecon[k].resize(l);
            m_qlInvModq[k].resize(l);
            m_qlInvModqPrecon[k].resize(l);
            BigInteger QlInvModql = modulusQ.ModInverse(moduliQ[l]);
            BigInteger result = (QlInvModql * modulusQ) / BigInteger(moduliQ[l]);
            for (usint i = 0; i < l; i++) {
                m_QlQlInvModqlDivqlModq[k][i] = result.Mod(moduliQ[i]).ConvertToInt();
                m_QlQlInvModqlDivqlModqPrecon[k][i] = m_QlQlInvModqlDivqlModq[k][i].PrepModMulConst(moduliQ[i]);
                m_qlInvModq[k][i] = moduliQ[l].ModInverse(moduliQ[i]);
                m_qlInvModqPrecon[k][i] = m_qlInvModq[k][i].PrepModMulConst(moduliQ[i]);
            }
        }
    }

    // BEHZ Precomputation
    if (multTech == BEHZ) {
        m_moduliQ = moduliQ;
        m_numq = sizeQ;

        // m_moduliB and m_rootsBsk are built by appending below; start from an empty state so a
        // repeated call does not leave stale entries that misalign m_moduliBsk and m_rootsBsk.
        m_moduliB.clear();
        m_rootsBsk.clear();

        std::vector<std::shared_ptr<ILNativeParams>> params;
        params.reserve(2 * sizeQ + 1);
        for (usint i = 0; i < m_numq; ++i)
            params.emplace_back(std::make_shared<ILNativeParams>(2 * n, moduliQ[i]));

        // Auxiliary basis B: m_numq primes found by successive PreviousPrime(..., 2n) calls
        // starting below the last modulus of Q; B is their product.
        m_moduliB.push_back(PreviousPrime<NativeInteger>(moduliQ.back(), 2 * n));
        m_rootsBsk.push_back(RootOfUnity<NativeInteger>(2 * n, m_moduliB.back()));
        params.emplace_back(std::make_shared<ILNativeParams>(2 * n, m_moduliB.back(), m_rootsBsk.back()));
        BigInteger B(m_moduliB.back());
        for (usint i = 1; i < m_numq; ++i) {  // we already added one prime
            m_moduliB.push_back(PreviousPrime<NativeInteger>(m_moduliB.back(), 2 * n));
            m_rootsBsk.push_back(RootOfUnity<NativeInteger>(2 * n, m_moduliB.back()));
            params.emplace_back(std::make_shared<ILNativeParams>(2 * n, m_moduliB.back(), m_rootsBsk.back()));
            B = B * BigInteger(m_moduliB.back());
        }

        m_numb = m_numq;

        m_msk = PreviousPrime<NativeInteger>(m_moduliB[m_numq - 1], 2 * n);

        usint s = m_msk.GetMSB();
        BigInteger Q(GetElementParams()->GetModulus());
        BigInteger maxConvolutionValue(BigInteger(2 * n) * BigInteger(GetPlaintextModulus()) * Q);

        // check msk is large enough
        while (B * BigInteger(m_msk) < maxConvolutionValue) {
            // TODO: revisit this logic. Maybe change to m_msk = LastPrime<NativeInteger>(++s, 2 * n);
            auto firstInteger{FirstPrime<NativeInteger>(++s, 2 * n)};
            m_msk = NextPrime<NativeInteger>(firstInteger, 2 * n);
        }
        m_rootsBsk.push_back(RootOfUnity<NativeInteger>(2 * n, m_msk));

        // Bsk = B U {msk}
        m_moduliBsk = m_moduliB;
        m_moduliBsk.push_back(m_msk);
        params.emplace_back(std::make_shared<ILNativeParams>(2 * n, m_moduliBsk.back(), m_rootsBsk.back()));
        m_paramsQBsk = std::make_shared<ILDCRTParams<BigInteger>>(2 * n, params);

        ChineseRemainderTransformFTT<NativeVector>().PreCompute(m_rootsBsk, 2 * n, m_moduliBsk);

        // populate Barrett constant for m_BskModuli
        m_modbskBarrettMu.resize(m_moduliBsk.size());
        for (uint32_t i = 0; i < m_modbskBarrettMu.size(); i++) {
            m_modbskBarrettMu[i] = (BarrettBase128Bit / BigInteger(m_moduliBsk[i])).ConvertToInt<DoubleNativeInt>();
        }

        // Populate [t*(Q/q_i)^-1]_{q_i}
        m_tQHatInvModq.resize(m_numq);
        m_tQHatInvModqPrecon.resize(m_numq);
        for (uint32_t i = 0; i < m_tQHatInvModq.size(); i++) {
            BigInteger tQHatInvModqi;
            tQHatInvModqi = Q.DividedBy(moduliQ[i]);
            tQHatInvModqi = tQHatInvModqi.Mod(moduliQ[i]);
            tQHatInvModqi = tQHatInvModqi.ModInverse(moduliQ[i]);
            tQHatInvModqi = tQHatInvModqi.ModMul(t.ConvertToInt(), moduliQ[i]);
            m_tQHatInvModq[i] = tQHatInvModqi.ConvertToInt();
            m_tQHatInvModqPrecon[i] = m_tQHatInvModq[i].PrepModMulConst(moduliQ[i]);
        }

        // Populate [Q/q_i]_{bsk_j, mtilde}
        m_QHatModbsk.resize(m_numq);
        m_QHatModmtilde.resize(m_numq);
        for (uint32_t i = 0; i < m_QHatModbsk.size(); i++) {
            m_QHatModbsk[i].resize(m_numb + 1);
            BigInteger QHati = Q.DividedBy(moduliQ[i]);
            for (uint32_t j = 0; j < m_QHatModbsk[i].size(); j++) {
                BigInteger QHatiModbskj = QHati.Mod(m_moduliBsk[j]);
                m_QHatModbsk[i][j] = QHatiModbskj.ConvertToInt();
            }
            m_QHatModmtilde[i] = QHati.Mod(m_mtilde).ConvertToInt();
        }

        // Populate [1/q_i]_{bsk_j}
        m_qInvModbsk.resize(m_numq);
        for (uint32_t i = 0; i < m_qInvModbsk.size(); i++) {
            m_qInvModbsk[i].resize(m_numb + 1);
            for (uint32_t j = 0; j < m_qInvModbsk[i].size(); j++)
                m_qInvModbsk[i][j] = moduliQ[i].ModInverse(m_moduliBsk[j]);
        }

        // Populate [mtilde*(Q/q_i)^{-1}]_{q_i}
        m_mtildeQHatInvModq.resize(m_numq);
        m_mtildeQHatInvModqPrecon.resize(m_numq);
        BigInteger bmtilde(m_mtilde);
        for (uint32_t i = 0; i < m_mtildeQHatInvModq.size(); i++) {
            BigInteger mtildeQHatInvModqi = Q.DividedBy(moduliQ[i]);
            mtildeQHatInvModqi = mtildeQHatInvModqi.Mod(moduliQ[i]);
            mtildeQHatInvModqi = mtildeQHatInvModqi.ModInverse(moduliQ[i]);
            mtildeQHatInvModqi = mtildeQHatInvModqi * bmtilde;
            mtildeQHatInvModqi = mtildeQHatInvModqi.Mod(moduliQ[i]);
            m_mtildeQHatInvModq[i] = mtildeQHatInvModqi.ConvertToInt();
            m_mtildeQHatInvModqPrecon[i] = m_mtildeQHatInvModq[i].PrepModMulConst(moduliQ[i]);
        }

        // Populate [-Q^{-1}]_{mtilde}
        BigInteger negQInvModmtilde = (BigInteger(m_mtilde - 1) * Q.ModInverse(m_mtilde));
        negQInvModmtilde = negQInvModmtilde.Mod(m_mtilde);
        m_negQInvModmtilde = negQInvModmtilde.ConvertToInt();

        // Populate [Q]_{bski_j}
        m_QModbsk.resize(m_numq + 1);
        m_QModbskPrecon.resize(m_numq + 1);
        for (uint32_t j = 0; j < m_QModbsk.size(); j++) {
            BigInteger QModbskij = Q.Mod(m_moduliBsk[j]);
            m_QModbsk[j] = QModbskij.ConvertToInt();
            m_QModbskPrecon[j] = m_QModbsk[j].PrepModMulConst(m_moduliBsk[j]);
        }

        // Populate [mtilde^{-1}]_{bsk_j}
        m_mtildeInvModbsk.resize(m_numb + 1);
        m_mtildeInvModbskPrecon.resize(m_numb + 1);
        for (uint32_t j = 0; j < m_mtildeInvModbsk.size(); j++) {
            BigInteger mtildeInvModbskij = m_mtilde % m_moduliBsk[j];
            mtildeInvModbskij = mtildeInvModbskij.ModInverse(m_moduliBsk[j]);
            m_mtildeInvModbsk[j] = mtildeInvModbskij.ConvertToInt();
            m_mtildeInvModbskPrecon[j] = m_mtildeInvModbsk[j].PrepModMulConst(m_moduliBsk[j]);
        }

        // Populate {t/Q}_{bsk_j}
        m_tQInvModbsk.resize(m_numb + 1);
        m_tQInvModbskPrecon.resize(m_numb + 1);
        for (uint32_t i = 0; i < m_tQInvModbsk.size(); i++) {
            BigInteger tDivqModBski = Q.ModInverse(m_moduliBsk[i]);
            tDivqModBski.ModMulEq(t.ConvertToInt(), m_moduliBsk[i]);
            m_tQInvModbsk[i] = tDivqModBski.ConvertToInt();
            m_tQInvModbskPrecon[i] = m_tQInvModbsk[i].PrepModMulConst(m_moduliBsk[i]);
        }

        // Populate [(B/b_j)^{-1}]_{b_j}
        m_BHatInvModb.resize(m_numb);
        m_BHatInvModbPrecon.resize(m_numb);
        for (uint32_t i = 0; i < m_BHatInvModb.size(); i++) {
            BigInteger BDivBi;
            BDivBi = B.DividedBy(m_moduliB[i]);
            BDivBi = BDivBi.Mod(m_moduliB[i]);
            BDivBi = BDivBi.ModInverse(m_moduliB[i]);
            m_BHatInvModb[i] = BDivBi.ConvertToInt();
            m_BHatInvModbPrecon[i] = m_BHatInvModb[i].PrepModMulConst(m_moduliB[i]);
        }

        // Populate [B/b_j]_{q_i}
        m_BHatModq.resize(m_numb);
        for (uint32_t i = 0; i < m_BHatModq.size(); i++) {
            m_BHatModq[i].resize(m_numq);
            BigInteger BDivBi = B.DividedBy(m_moduliB[i]);
            for (uint32_t j = 0; j < m_BHatModq[i].size(); j++) {
                BigInteger BDivBiModqj = BDivBi.Mod(moduliQ[j]);
                m_BHatModq[i][j] = BDivBiModqj.ConvertToInt();
            }
        }

        // Populate [B/b_j]_{msk}
        m_BHatModmsk.resize(m_numb);
        for (uint32_t i = 0; i < m_BHatModmsk.size(); i++) {
            BigInteger BDivBi = B.DividedBy(m_moduliB[i]);
            m_BHatModmsk[i] = (BDivBi.Mod(m_msk)).ConvertToInt();
        }

        // Populate [B^{-1}]_{msk}
        m_BInvModmsk = (B.ModInverse(m_msk)).ConvertToInt();
        m_BInvModmskPrecon = m_BInvModmsk.PrepModMulConst(m_msk);

        // Populate [B]_{q_i}
        m_BModq.resize(m_numq);
        m_BModqPrecon.resize(m_numq);
        for (uint32_t i = 0; i < m_BModq.size(); i++) {
            m_BModq[i] = (B.Mod(moduliQ[i])).ConvertToInt();
            m_BModqPrecon[i] = m_BModq[i].PrepModMulConst(moduliQ[i]);
        }

        // Populate Decrns lookup tables
        NativeInteger tgamma = NativeInteger(t.ConvertToInt() * m_gamma);  // t*gamma
        m_tgamma = tgamma;

        // Populate [-1/q_i]_{t*gamma} (t*gamma < 2^58)
        m_negInvqModtgamma.resize(m_numq);
        m_negInvqModtgammaPrecon.resize(m_numq);
        for (uint32_t i = 0; i < m_negInvqModtgamma.size(); i++) {
            BigInteger imod(moduliQ[i]);
            BigInteger negInvqi = BigInteger((tgamma - 1)) * imod.ModInverse(tgamma);
            BigInteger negInvqiModtgamma = negInvqi.Mod(tgamma);
            m_negInvqModtgamma[i] = negInvqiModtgamma.ConvertToInt();
            m_negInvqModtgammaPrecon[i] = m_negInvqModtgamma[i].PrepModMulConst(tgamma);
        }

        // populate [t*gamma*(Q/q_i)^(-1)]_{q_i}
        m_tgammaQHatInvModq.resize(m_numq);
        m_tgammaQHatInvModqPrecon.resize(m_numq);
        BigInteger bmgamma(m_gamma);
        for (uint32_t i = 0; i < m_tgammaQHatInvModq.size(); i++) {
            BigInteger qDivqi = Q.DividedBy(moduliQ[i]);
            BigInteger imod(moduliQ[i]);
            qDivqi = qDivqi.ModInverse(moduliQ[i]);
            BigInteger gammaqDivqi = (qDivqi * bmgamma) % imod;
            BigInteger tgammaqDivqi = (gammaqDivqi * BigInteger(t)) % imod;
            m_tgammaQHatInvModq[i] = tgammaqDivqi.ConvertToInt();
            m_tgammaQHatInvModqPrecon[i] = m_tgammaQHatInvModq[i].PrepModMulConst(moduliQ[i]);
        }
    }
}

/**
 * Returns the step used when searching for auxiliary primes: 2 * ring dimension, i.e. the
 * same 2n step passed to PreviousPrime/RootOfUnity in PrecomputeCRTTables.
 */
uint64_t CryptoParametersBFVRNS::FindAuxPrimeStep() const {
    return 2 * GetElementParams()->GetRingDimension();
}

}  // namespace lbcrypto