Program Listing for File rns-cryptoparameters.cpp
↰ Return to documentation for file (pke/lib/schemerns/rns-cryptoparameters.cpp
)
//==================================================================================
// BSD 2-Clause License
//
// Copyright (c) 2014-2022, NJIT, Duality Technologies Inc. and other contributors
//
// All rights reserved.
//
// Author TPOC: contact@openfhe.org
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//==================================================================================
#define PROFILE
#include "math/dftransform.h"
#include "cryptocontext.h"
#include "schemerns/rns-cryptoparameters.h"
namespace lbcrypto {
void CryptoParametersRNS::PrecomputeCRTTables(KeySwitchTechnique ksTech, ScalingTechnique scalTech,
EncryptionTechnique encTech, MultiplicationTechnique multTech,
uint32_t numPartQ, uint32_t auxBits, uint32_t extraBits) {
// Set the key switching technique.
m_ksTechnique = ksTech;
// Set the scaling technique.
m_scalTechnique = scalTech;
// Set the key encryption technique.
m_encTechnique = encTech;
// Set the multiplication technique.
m_multTechnique = multTech;
// Set number of digits in HYBRID
m_numPartQ = numPartQ;
// Set auxiliary primes bit size in HYBRID
m_auxBits = auxBits;
// Set number of extraBits for lower error
m_extraBits = extraBits;
size_t sizeQ = GetElementParams()->GetParams().size();
size_t n = GetElementParams()->GetRingDimension();
// Construct moduliQ and rootsQ from crypto parameters
std::vector<NativeInteger> moduliQ(sizeQ);
std::vector<NativeInteger> rootsQ(sizeQ);
for (size_t i = 0; i < sizeQ; i++) {
moduliQ[i] = GetElementParams()->GetParams()[i]->GetModulus();
rootsQ[i] = GetElementParams()->GetParams()[i]->GetRootOfUnity();
}
// Pre-compute CRT::FFT values for Q
DiscreteFourierTransform::Initialize(n * 2, n / 2);
ChineseRemainderTransformFTT<NativeVector>().PreCompute(rootsQ, 2 * n, moduliQ);
if (m_ksTechnique == HYBRID) {
// Compute ceil(sizeQ/m_numPartQ), the # of towers per digit
uint32_t a = ceil(static_cast<double>(sizeQ) / numPartQ);
if ((int32_t)(sizeQ - a * (numPartQ - 1)) <= 0) {
auto str =
"CryptoParametersRNS::PrecomputeCRTTables - HYBRID key "
"switching parameters: Can't appropriately distribute " +
std::to_string(sizeQ) + " towers into " + std::to_string(numPartQ) +
" digits. Please select different number of digits.";
OPENFHE_THROW(str);
}
m_numPerPartQ = a;
// Compute the composite digits PartQ = Q_j
std::vector<BigInteger> moduliPartQ;
moduliPartQ.resize(m_numPartQ);
for (usint j = 0; j < m_numPartQ; j++) {
moduliPartQ[j] = BigInteger(1);
for (usint i = a * j; i < (j + 1) * a; i++) {
if (i < moduliQ.size())
moduliPartQ[j] *= moduliQ[i];
}
}
// Compute PartQHat_i = Q/Q_j
std::vector<BigInteger> PartQHat;
PartQHat.resize(m_numPartQ);
for (size_t i = 0; i < m_numPartQ; i++) {
PartQHat[i] = BigInteger(1);
for (size_t j = 0; j < m_numPartQ; j++) {
if (j != i)
PartQHat[i] *= moduliPartQ[j];
}
}
// Compute partitions of Q into numPartQ digits
m_paramsPartQ.resize(m_numPartQ);
for (uint32_t j = 0; j < m_numPartQ; j++) {
auto startTower = j * a;
auto endTower = ((j + 1) * a - 1 < sizeQ) ? (j + 1) * a - 1 : sizeQ - 1;
std::vector<std::shared_ptr<ILNativeParams>> params =
GetElementParams()->GetParamPartition(startTower, endTower);
std::vector<NativeInteger> moduli(params.size());
std::vector<NativeInteger> roots(params.size());
for (uint32_t i = 0; i < params.size(); i++) {
moduli[i] = params[i]->GetModulus();
roots[i] = params[i]->GetRootOfUnity();
}
m_paramsPartQ[j] =
std::make_shared<ILDCRTParams<BigInteger>>(params[0]->GetCyclotomicOrder(), moduli, roots);
}
uint32_t sizeP;
// Find number and size of individual special primes.
uint32_t maxBits = moduliPartQ[0].GetLengthForBase(2);
for (usint j = 1; j < m_numPartQ; j++) {
uint32_t bits = moduliPartQ[j].GetLengthForBase(2);
if (bits > maxBits)
maxBits = bits;
}
// Select number of primes in auxiliary CRT basis
sizeP = ceil(static_cast<double>(maxBits) / auxBits);
uint64_t primeStep = FindAuxPrimeStep();
// Choose special primes in auxiliary basis and compute their roots
// moduliP holds special primes p1, p2, ..., pk
// m_modulusP holds the product of special primes P = p1*p2*...pk
std::vector<NativeInteger> moduliP(sizeP);
std::vector<NativeInteger> rootsP(sizeP);
// firstP contains a prime whose size is PModSize.
NativeInteger firstP = FirstPrime<NativeInteger>(auxBits, primeStep);
NativeInteger pPrev = firstP;
BigInteger modulusP(1);
for (usint i = 0; i < sizeP; i++) {
// The following loop makes sure that moduli in
// P and Q are different
bool foundInQ = false;
do {
moduliP[i] = PreviousPrime<NativeInteger>(pPrev, primeStep);
foundInQ = false;
for (usint j = 0; j < sizeQ; j++)
if (moduliP[i] == moduliQ[j])
foundInQ = true;
pPrev = moduliP[i];
} while (foundInQ);
rootsP[i] = RootOfUnity<NativeInteger>(2 * n, moduliP[i]);
modulusP *= moduliP[i];
pPrev = moduliP[i];
}
// Store the created moduli and roots in m_paramsP
m_paramsP = std::make_shared<ILDCRTParams<BigInteger>>(2 * n, moduliP, rootsP);
// Create the moduli and roots for the extended CRT basis QP
std::vector<NativeInteger> moduliQP(sizeQ + sizeP);
std::vector<NativeInteger> rootsQP(sizeQ + sizeP);
for (size_t i = 0; i < sizeQ; i++) {
moduliQP[i] = moduliQ[i];
rootsQP[i] = rootsQ[i];
}
for (size_t i = 0; i < sizeP; i++) {
moduliQP[sizeQ + i] = moduliP[i];
rootsQP[sizeQ + i] = rootsP[i];
}
m_paramsQP = std::make_shared<ILDCRTParams<BigInteger>>(2 * n, moduliQP, rootsQP);
// Pre-compute CRT::FFT values for P
ChineseRemainderTransformFTT<NativeVector>().PreCompute(rootsP, 2 * n, moduliP);
// Pre-compute values [P]_{q_i}
m_PModq.resize(sizeQ);
for (usint i = 0; i < sizeQ; i++) {
m_PModq[i] = modulusP.Mod(moduliQ[i]).ConvertToInt();
}
// Pre-compute values [P^{-1}]_{q_i}
m_PInvModq.resize(sizeQ);
m_PInvModqPrecon.resize(sizeQ);
for (size_t i = 0; i < sizeQ; i++) {
BigInteger PInvModqi = modulusP.ModInverse(moduliQ[i]);
m_PInvModq[i] = PInvModqi.ConvertToInt();
m_PInvModqPrecon[i] = m_PInvModq[i].PrepModMulConst(moduliQ[i]);
}
// Pre-compute values [P/p_j]_{q_i}
// Pre-compute values [(P/p_j)^{-1}]_{p_j}
m_PHatInvModp.resize(sizeP);
m_PHatInvModpPrecon.resize(sizeP);
m_PHatModq.resize(sizeP);
for (size_t j = 0; j < sizeP; j++) {
BigInteger PHatj = modulusP / BigInteger(moduliP[j]);
BigInteger PHatInvModpj = PHatj.ModInverse(moduliP[j]);
m_PHatInvModp[j] = PHatInvModpj.ConvertToInt();
m_PHatInvModpPrecon[j] = m_PHatInvModp[j].PrepModMulConst(moduliP[j]);
m_PHatModq[j].resize(sizeQ);
for (size_t i = 0; i < sizeQ; i++) {
BigInteger PHatModqji = PHatj.Mod(moduliQ[i]);
m_PHatModq[j][i] = PHatModqji.ConvertToInt();
}
}
BigInteger modulusQ = GetElementParams()->GetModulus();
// Pre-compute values [Q/q_i]_{p_j}
// Pre-compute values [(Q/q_i)^{-1}]_{q_i}
m_QlHatInvModq.resize(sizeQ);
m_QlHatInvModqPrecon.resize(sizeQ);
// l will run from 0 to size-2, but modulusQ values
// run from Q^(l-1) to Q^(0)
for (size_t l = 0; l < sizeQ; l++) {
if (l > 0)
modulusQ = modulusQ / BigInteger(moduliQ[sizeQ - l]);
m_QlHatInvModq[sizeQ - l - 1].resize(sizeQ - l);
m_QlHatInvModqPrecon[sizeQ - l - 1].resize(sizeQ - l);
for (size_t i = 0; i < sizeQ - l; i++) {
BigInteger QHati = modulusQ / BigInteger(moduliQ[i]);
BigInteger QHatInvModqi = QHati.ModInverse(moduliQ[i]);
m_QlHatInvModq[sizeQ - l - 1][i] = QHatInvModqi.ConvertToInt();
m_QlHatInvModqPrecon[sizeQ - l - 1][i] = m_QlHatInvModq[sizeQ - l - 1][i].PrepModMulConst(moduliQ[i]);
}
}
// Pre-compute compementary partitions for ModUp
uint32_t alpha = ceil(static_cast<double>(sizeQ) / m_numPartQ);
m_paramsComplPartQ.resize(sizeQ);
m_modComplPartqBarrettMu.resize(sizeQ);
for (int32_t l = sizeQ - 1; l >= 0; l--) {
uint32_t beta = ceil(static_cast<double>(l + 1) / alpha);
m_paramsComplPartQ[l].resize(beta);
m_modComplPartqBarrettMu[l].resize(beta);
for (uint32_t j = 0; j < beta; j++) {
const std::shared_ptr<ILDCRTParams<BigInteger>> digitPartition = GetParamsPartQ(j);
auto cyclOrder = digitPartition->GetCyclotomicOrder();
uint32_t sizePartQj = digitPartition->GetParams().size();
if (j == beta - 1)
sizePartQj = (l + 1) - j * alpha;
uint32_t sizeComplPartQj = (l + 1) - sizePartQj + sizeP;
std::vector<NativeInteger> moduli(sizeComplPartQj);
std::vector<NativeInteger> roots(sizeComplPartQj);
for (uint32_t k = 0; k < sizeComplPartQj; k++) {
if (k < (l + 1) - sizePartQj) {
uint32_t currDigit = k / alpha;
if (currDigit >= j)
currDigit++;
moduli[k] = GetParamsPartQ(currDigit)->GetParams()[k % alpha]->GetModulus();
roots[k] = GetParamsPartQ(currDigit)->GetParams()[k % alpha]->GetRootOfUnity();
}
else {
moduli[k] = moduliP[k - ((l + 1) - sizePartQj)];
roots[k] = rootsP[k - ((l + 1) - sizePartQj)];
}
}
m_paramsComplPartQ[l][j] = std::make_shared<ParmType>(cyclOrder, moduli, roots);
const auto BarrettBase128Bit(BigInteger(1).LShiftEq(128));
m_modComplPartqBarrettMu[l][j].resize(moduli.size());
for (uint32_t i = 0; i < moduli.size(); i++) {
m_modComplPartqBarrettMu[l][j][i] =
(BarrettBase128Bit / BigInteger(moduli[i])).ConvertToInt<DoubleNativeInt>();
}
}
}
// Pre-compute values [Q^(l)_j/q_i)^{-1}]_{q_i}
m_PartQlHatInvModq.resize(m_numPartQ);
m_PartQlHatInvModqPrecon.resize(m_numPartQ);
for (uint32_t k = 0; k < m_numPartQ; k++) {
auto params = m_paramsPartQ[k]->GetParams();
uint32_t sizePartQk = params.size();
m_PartQlHatInvModq[k].resize(sizePartQk);
m_PartQlHatInvModqPrecon[k].resize(sizePartQk);
auto modulusPartQ = m_paramsPartQ[k]->GetModulus();
for (size_t l = 0; l < sizePartQk; l++) {
if (l > 0)
modulusPartQ = modulusPartQ / BigInteger(params[sizePartQk - l]->GetModulus());
m_PartQlHatInvModq[k][sizePartQk - l - 1].resize(sizePartQk - l);
m_PartQlHatInvModqPrecon[k][sizePartQk - l - 1].resize(sizePartQk - l);
for (size_t i = 0; i < sizePartQk - l; i++) {
BigInteger QHat = modulusPartQ / BigInteger(params[i]->GetModulus());
BigInteger QHatInvModqi = QHat.ModInverse(params[i]->GetModulus());
m_PartQlHatInvModq[k][sizePartQk - l - 1][i] = QHatInvModqi.ConvertToInt();
m_PartQlHatInvModqPrecon[k][sizePartQk - l - 1][i] =
m_PartQlHatInvModq[k][sizePartQk - l - 1][i].PrepModMulConst(params[i]->GetModulus());
}
}
}
// Pre-compute QHat mod complementary partition qi's
m_PartQlHatModp.resize(sizeQ);
for (uint32_t l = 0; l < sizeQ; l++) {
uint32_t alpha = ceil(static_cast<double>(sizeQ) / m_numPartQ);
uint32_t beta = ceil(static_cast<double>(l + 1) / alpha);
m_PartQlHatModp[l].resize(beta);
for (uint32_t k = 0; k < beta; k++) {
auto paramsPartQ = GetParamsPartQ(k)->GetParams();
auto partQ = GetParamsPartQ(k)->GetModulus();
uint32_t digitSize = paramsPartQ.size();
if (k == beta - 1) {
digitSize = l + 1 - k * alpha;
for (uint32_t idx = digitSize; idx < paramsPartQ.size(); idx++) {
partQ = partQ / BigInteger(paramsPartQ[idx]->GetModulus());
}
}
m_PartQlHatModp[l][k].resize(digitSize);
for (uint32_t i = 0; i < digitSize; i++) {
BigInteger partQHat = partQ / BigInteger(paramsPartQ[i]->GetModulus());
auto complBasis = GetParamsComplPartQ(l, k);
m_PartQlHatModp[l][k][i].resize(complBasis->GetParams().size());
for (size_t j = 0; j < complBasis->GetParams().size(); j++) {
BigInteger QHatModpj = partQHat.Mod(complBasis->GetParams()[j]->GetModulus());
m_PartQlHatModp[l][k][i][j] = QHatModpj.ConvertToInt();
}
}
}
}
}
// BFVrns and BGVrns : Multiparty Decryption : ExpandCRTBasis
if (GetMultipartyMode() == NOISE_FLOODING_MULTIPARTY) {
// Pre-compute values [*(Q/q_i/q_0)^{-1}]_{q_i}
BigInteger modulusQ = BigInteger(GetElementParams()->GetModulus()) / BigInteger(moduliQ[0]);
m_multipartyQHatInvModq.resize(sizeQ - 1);
m_multipartyQHatInvModqPrecon.resize(sizeQ - 1);
m_multipartyQHatModq0.resize(sizeQ - 1);
// l will run from 0 to size-2, but modulusQ values
// run from Q^(l-1) to Q^(0)
for (size_t l = 0, m = sizeQ - l - 2; l < sizeQ - 1; ++l, --m) {
if (l > 0)
modulusQ = modulusQ / BigInteger(moduliQ[sizeQ - l]);
m_multipartyQHatInvModq[m].resize(m + 1);
m_multipartyQHatInvModqPrecon[m].resize(m + 1);
m_multipartyQHatModq0[m].resize(1);
m_multipartyQHatModq0[m][0].resize(m + 1);
for (size_t i = 1; i < m + 2; i++) {
BigInteger QHati = modulusQ / BigInteger(moduliQ[i]);
BigInteger QHatInvModqi = QHati.ModInverse(moduliQ[i]);
m_multipartyQHatInvModq[m][i - 1] = QHatInvModqi.ConvertToInt();
m_multipartyQHatInvModqPrecon[m][i - 1] = m_multipartyQHatInvModq[m][i - 1].PrepModMulConst(moduliQ[i]);
m_multipartyQHatModq0[m][0][i - 1] = QHati.Mod(moduliQ[0]);
}
}
modulusQ = BigInteger(GetElementParams()->GetModulus()) / BigInteger(moduliQ[0]);
m_multipartyAlphaQModq0.resize(sizeQ - 1);
for (usint l = sizeQ - 1; l > 0; l--) {
if (l < sizeQ - 1)
modulusQ = modulusQ / BigInteger(moduliQ[l + 1]);
m_multipartyAlphaQModq0[l - 1].resize(l + 1);
NativeInteger QlModq0 = modulusQ.Mod(moduliQ[0]).ConvertToInt();
for (usint j = 0; j < l + 1; ++j) {
m_multipartyAlphaQModq0[l - 1][j] = {QlModq0.ModMul(NativeInteger(j), moduliQ[0])};
}
}
const auto BarrettBase128Bit(BigInteger(1).LShiftEq(128));
m_multipartyModq0BarrettMu.resize(1);
m_multipartyModq0BarrettMu[0] = (BarrettBase128Bit / BigInteger(moduliQ[0])).ConvertToInt<DoubleNativeInt>();
// Stores \frac{1/q_i}
m_multipartyQInv.resize(sizeQ - 1);
for (size_t i = 1; i < sizeQ; i++) {
m_multipartyQInv[i - 1] = 1. / static_cast<double>(moduliQ[i].ConvertToInt());
}
}
}
uint64_t CryptoParametersRNS::FindAuxPrimeStep() const {
return GetElementParams()->GetRingDimension();
}
} // namespace lbcrypto