Program Listing for File ckksrns-multiparty.cpp
↰ Return to documentation for file (pke/lib/scheme/ckksrns/ckksrns-multiparty.cpp)
//==================================================================================
// BSD 2-Clause License
//
// Copyright (c) 2014-2022, NJIT, Duality Technologies Inc. and other contributors
//
// All rights reserved.
//
// Author TPOC: contact@openfhe.org
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//==================================================================================
/*
CKKS implementation. See https://eprint.iacr.org/2020/1118 for details.
*/
#define PROFILE
#include "scheme/ckksrns/ckksrns-multiparty.h"

#include <cmath>
#include <memory>
#include <string>
#include <vector>

#include "scheme/ckksrns/ckksrns-cryptoparameters.h"
#include "ciphertext.h"
#include "cryptocontext.h"
namespace lbcrypto {
// Precomputed tables used to lift a polynomial from the RNS basis {Q} to the basis {Q} ∪ {P}.
// {Q} = {q_0,...,q_{l-1}}, original RNS basis (towers [0, from) in PrecomputeRNSExtensionTables)
// {P} = {p_0,...,p_{k-1}}, extension RNS basis (towers [from, to) in PrecomputeRNSExtensionTables)
// The fields are filled by PrecomputeRNSExtensionTables and passed one by one to DCRTPoly::ExpandCRTBasis.
struct RNSExtensionTables {
std::shared_ptr<ILDCRTParams<BigInteger>> paramsQP; // parameters of the combined basis: Q followed by P
std::shared_ptr<ILDCRTParams<BigInteger>> paramsP; // parameters of the extension basis P only
std::vector<NativeInteger> QHatInvModq; // [i] = (Q/q_i)^{-1} mod q_i
std::vector<NativeInteger> QHatInvModqPrecon; // [i] = QHatInvModq[i].PrepModMulConst(q_i)
std::vector<std::vector<NativeInteger>> QHatModp; // [j][i] = (Q/q_i) mod p_j
std::vector<std::vector<NativeInteger>> alphaQModp; // [alpha][j] = alpha*Q mod p_j, for 0 <= alpha <= |Q|
std::vector<DoubleNativeInt> modpBarrettMu; // [j] = floor(2^128 / p_j), Barrett constant for p_j
std::vector<double> qInv; // [i] = 1.0 / q_i
// NOTE(review): resultFormat has no initializer. It is neither set by PrecomputeRNSExtensionTables
// nor read in this file: callers pass the result format to ExpandCRTBasis directly.
Format resultFormat;
};
/**
 * Fuses the partial decryptions of all parties into one multiprecision plaintext polynomial.
 *
 * The first element of every partial decryption is summed. The sum is switched to coefficient
 * format and CRT-interpolated into *plaintext.
 *
 * @param ciphertextVec partial decryptions, one per party; must not be empty.
 * @param plaintext     output polynomial; overwritten.
 * @return DecryptResult carrying the length of the recovered polynomial.
 * @throws if ciphertextVec is empty (indexing element 0 would otherwise be undefined behavior).
 */
DecryptResult MultipartyCKKSRNS::MultipartyDecryptFusion(const std::vector<Ciphertext<DCRTPoly>>& ciphertextVec,
                                                         Poly* plaintext) const {
    if (ciphertextVec.empty()) {
        std::string msg = "MultipartyDecryptFusion: no partial decryptions provided.";
        OPENFHE_THROW(msg);
    }

    // b = sum over parties of their partial decryption's first element.
    DCRTPoly b = ciphertextVec[0]->GetElements()[0];
    for (size_t i = 1; i < ciphertextVec.size(); ++i) {
        b += ciphertextVec[i]->GetElements()[0];
    }

    b.SetFormat(Format::COEFFICIENT);
    *plaintext = b.CRTInterpolate();
    return DecryptResult(plaintext->GetLength());
}
/**
 * Fuses the partial decryptions of all parties into a native (single-tower) plaintext polynomial.
 *
 * The first element of every partial decryption is summed and the sum is switched to coefficient
 * format. Only RNS tower 0 is copied into *plaintext.
 * NOTE(review): presumably callers use this overload only when a single tower remains (a
 * previously commented-out check enforced that); confirm at the call sites.
 *
 * @param ciphertextVec partial decryptions, one per party; must not be empty.
 * @param plaintext     output polynomial; overwritten with tower 0 of the sum.
 * @return DecryptResult carrying the length of the recovered polynomial.
 * @throws if ciphertextVec is empty (indexing element 0 would otherwise be undefined behavior).
 */
DecryptResult MultipartyCKKSRNS::MultipartyDecryptFusion(const std::vector<Ciphertext<DCRTPoly>>& ciphertextVec,
                                                         NativePoly* plaintext) const {
    if (ciphertextVec.empty()) {
        std::string msg = "MultipartyDecryptFusion: no partial decryptions provided.";
        OPENFHE_THROW(msg);
    }

    // b = sum over parties of their partial decryption's first element.
    DCRTPoly b = ciphertextVec[0]->GetElements()[0];
    for (size_t i = 1; i < ciphertextVec.size(); ++i) {
        b += ciphertextVec[i]->GetElements()[0];
    }

    b.SetFormat(Format::COEFFICIENT);
    *plaintext = b.GetElementAtIndex(0);
    return DecryptResult(plaintext->GetLength());
}
/**
 * Prepares a ciphertext for interactive multi-party bootstrapping by compressing it to
 * numTowersToKeep RNS towers.
 *
 * numTowersToKeep = (scalingFactorBits / log2(q_0) + 1) + compressionLevel, where:
 *  - scalingFactorBits is read from the encoding parameters' plaintext modulus;
 *  - q_0 is the ciphertext's first tower modulus;
 *  - compressionLevel comes from GetMPIntBootCiphertextCompressionLevel().
 *
 * For FLEXIBLEAUTO / FLEXIBLEAUTOEXT:
 *  - one extra tower is kept and then dropped by a rescale;
 *  - before the rescale, the ciphertext is multiplied by an adjustment factor;
 *  - the result ends at the scaling factor of level 0 (FLEXIBLEAUTO) or level 1 (FLEXIBLEAUTOEXT).
 *
 * @param ciphertext input ciphertext.
 * @return the compressed (and, for flexible scaling, rescaled) ciphertext.
 * @throws if the ciphertext has no elements or not enough towers for the chosen path.
 */
Ciphertext<DCRTPoly> MultipartyCKKSRNS::IntMPBootAdjustScale(ConstCiphertext<DCRTPoly> ciphertext) const {
    if (ciphertext->NumberCiphertextElements() == 0) {
        std::string msg = "IntMPBootAdjustScale: no polynomials in the input ciphertext.";
        OPENFHE_THROW(msg);
    }

    auto cc                 = ciphertext->GetCryptoContext();
    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(cc->GetCryptoParameters());
    auto compressionLevel   = cryptoParams->GetMPIntBootCiphertextCompressionLevel();

    // Compress ctxt and reduce it to numTowersToKeep towers:
    // (scalingFactorBits / log2(q_0) + 1) towers for the message itself, plus compressionLevel towers.
    size_t scalingFactorBits = cc->GetEncodingParams()->GetPlaintextModulus();
    size_t firstModulusSize =
        std::ceil(std::log2(ciphertext->GetElements()[0].GetAllElements()[0].GetParams()->GetModulus().ConvertToInt()));
    size_t numTowersToKeep = (scalingFactorBits / firstModulusSize + 1) + compressionLevel;

    const auto scalTech = cryptoParams->GetScalingTechnique();
    const bool isFlexible =
        (scalTech == ScalingTechnique::FLEXIBLEAUTO || scalTech == ScalingTechnique::FLEXIBLEAUTOEXT);

    // The flexible path compresses to numTowersToKeep + 1 towers and then rescales once, dropping a tower.
    // The input must therefore provide that extra tower too; otherwise the result is one tower short.
    const size_t numTowersRequired = isFlexible ? numTowersToKeep + 1 : numTowersToKeep;
    if (ciphertext->GetElements()[0].GetNumOfElements() < numTowersRequired) {
        std::string msg = std::string(__func__) + ": not enough towers in the input polynomial.";
        OPENFHE_THROW(msg);
    }

    if (!isFlexible) {
        return cc->Compress(ciphertext, numTowersToKeep);
    }

    auto ciphertextAdjusted = cc->Compress(ciphertext, numTowersToKeep + 1);

    uint32_t lvl    = (scalTech == ScalingTechnique::FLEXIBLEAUTO) ? 0 : 1;
    double targetSF = cryptoParams->GetScalingFactorReal(lvl);
    double sourceSF = ciphertextAdjusted->GetScalingFactor();

    // Modulus of the tower that the rescale below drops.
    uint32_t numTowers = ciphertextAdjusted->GetElements()[0].GetNumOfElements();
    double modToDrop   = cryptoParams->GetElementParams()->GetParams()[numTowers - 1]->GetModulus().ConvertToDouble();

    // Pre-multiply so that, after dividing by modToDrop, the scaling factor equals targetSF.
    double adjustmentFactor = (targetSF / sourceSF) * (modToDrop / sourceSF);
    ciphertextAdjusted      = cc->EvalMult(ciphertextAdjusted, adjustmentFactor);
    cc->GetScheme()->ModReduceInternalInPlace(ciphertextAdjusted, 1);
    ciphertextAdjusted->SetScalingFactor(targetSF);
    return ciphertextAdjusted;
}
/**
 * Creates the common random polynomial "a" used by interactive multi-party bootstrapping.
 *
 * A polynomial is sampled with a uniform generator (DugType) over params' element parameters and
 * put in EVALUATION format. It is returned as the sole element of a ciphertext bound to publicKey.
 *
 * @param params    CKKS RNS crypto parameters whose element parameters define the ring and towers.
 * @param publicKey key the returned ciphertext is associated with.
 * @return a one-element ciphertext holding the random polynomial.
 */
Ciphertext<DCRTPoly> MultipartyCKKSRNS::IntMPBootRandomElementGen(std::shared_ptr<CryptoParametersCKKSRNS> params,
                                                                  const PublicKey<DCRTPoly> publicKey) const {
    typename DCRTPoly::DugType uniformSampler;
    DCRTPoly randomPoly(uniformSampler, params->GetElementParams());
    randomPoly.SetFormat(Format::EVALUATION);

    Ciphertext<DCRTPoly> result = std::make_shared<CiphertextImpl<DCRTPoly>>(publicKey);
    result->SetElements({std::move(randomPoly)});
    return result;
}
// Subroutines for Interactive Multi-Party Bootstrapping
/**
 * Builds the tables consumed by DCRTPoly::ExpandCRTBasis to extend a polynomial from the RNS
 * basis Q = {q_0, ..., q_{from-1}} to Q ∪ P, where P = {q_from, ..., q_{to-1}}. Both bases are
 * consecutive towers of cc's element parameters.
 *
 * Tables produced (0 <= i < |Q|, 0 <= j < |P|):
 *   paramsP / paramsQP     ring parameters for P, and for Q followed by P
 *   QHatInvModq[i]         (Q/q_i)^{-1} mod q_i   (QHatInvModqPrecon: its PrepModMulConst form)
 *   QHatModp[j][i]         (Q/q_i) mod p_j
 *   alphaQModp[alpha][j]   alpha * Q mod p_j, for 0 <= alpha <= |Q|
 *   modpBarrettMu[j]       floor(2^128 / p_j)
 *   qInv[i]                1.0 / q_i
 * rnsExtTables.resultFormat is not modified.
 *
 * @param cc           crypto context supplying the moduli and roots of unity.
 * @param from         number of towers in Q.
 * @param to           one past the last tower of P.
 * @param rnsExtTables output; every table listed above is overwritten.
 * @throws if from > to, or if to exceeds the number of towers in cc's element parameters
 *         (previously undefined behavior: unsigned underflow / out-of-bounds access).
 */
void PrecomputeRNSExtensionTables(CryptoContext<DCRTPoly>& cc, usint from, usint to, RNSExtensionTables& rnsExtTables) {
    // Keep the shared_ptr alive so the tower-parameter vector stays valid while we read it.
    const auto elementParams = cc->GetCryptoParameters()->GetElementParams();
    const auto& towerParams  = elementParams->GetParams();
    if (from > to || to > towerParams.size()) {
        std::string msg = std::string(__func__) + ": invalid tower range [" + std::to_string(from) + ", " +
                          std::to_string(to) + ") for " + std::to_string(towerParams.size()) + " available towers.";
        OPENFHE_THROW(msg);
    }

    const size_t sizeQ = from;
    const size_t sizeP = to - from;

    std::vector<NativeInteger> moduliQ;
    std::vector<NativeInteger> rootsQ;
    std::vector<NativeInteger> moduliP;
    std::vector<NativeInteger> rootsP;
    moduliQ.reserve(sizeQ);
    rootsQ.reserve(sizeQ);
    moduliP.reserve(sizeP);
    rootsP.reserve(sizeP);
    for (size_t i = 0; i < sizeQ; ++i) {
        moduliQ.push_back(towerParams[i]->GetModulus());
        rootsQ.push_back(towerParams[i]->GetRootOfUnity());
    }
    for (size_t i = from; i < to; ++i) {
        moduliP.push_back(towerParams[i]->GetModulus());
        rootsP.push_back(towerParams[i]->GetRootOfUnity());
    }

    // QP basis: the moduli/roots of Q followed by those of P.
    std::vector<NativeInteger> moduliQP(moduliQ);
    moduliQP.insert(moduliQP.end(), moduliP.begin(), moduliP.end());
    std::vector<NativeInteger> rootsQP(rootsQ);
    rootsQP.insert(rootsQP.end(), rootsP.begin(), rootsP.end());

    const usint ringDim   = elementParams->GetRingDimension();
    rnsExtTables.paramsP  = std::make_shared<ILDCRTParams<BigInteger>>(2 * ringDim, moduliP, rootsP);
    rnsExtTables.paramsQP = std::make_shared<ILDCRTParams<BigInteger>>(2 * ringDim, moduliQP, rootsQP);

    // Q = prod q_i and QHat_i = Q / q_i. Each QHat_i is computed once and reused by the tables below.
    BigInteger modulusQ(1);
    for (auto& qi : moduliQ)
        modulusQ *= qi;
    std::vector<BigInteger> QHat;
    QHat.reserve(sizeQ);
    for (size_t i = 0; i < sizeQ; ++i)
        QHat.push_back(modulusQ / BigInteger(moduliQ[i].ConvertToInt()));

    // [(Q/q_i)^{-1}]_{q_i} and its precomputed multiplication constant.
    rnsExtTables.QHatInvModq.resize(sizeQ);
    rnsExtTables.QHatInvModqPrecon.resize(sizeQ);
    for (size_t i = 0; i < sizeQ; ++i) {
        BigInteger qi(moduliQ[i].ConvertToInt());
        rnsExtTables.QHatInvModq[i]       = QHat[i].ModInverse(qi).Mod(qi).ConvertToInt();
        rnsExtTables.QHatInvModqPrecon[i] = rnsExtTables.QHatInvModq[i].PrepModMulConst(moduliQ[i]);
    }

    // [Q/q_i]_{p_j} and [alpha*Q]_{p_j} for 0 <= alpha <= sizeQ.
    // assign (not resize) so that a reused struct never keeps stale, wrongly-sized rows.
    rnsExtTables.QHatModp.assign(sizeP, std::vector<NativeInteger>(sizeQ));
    rnsExtTables.alphaQModp.assign(sizeQ + 1, std::vector<NativeInteger>(sizeP));
    for (size_t j = 0; j < sizeP; ++j) {
        const BigInteger pj(moduliP[j].ConvertToInt());
        for (size_t i = 0; i < sizeQ; ++i)
            rnsExtTables.QHatModp[j][i] = QHat[i].Mod(pj).ConvertToInt();

        const NativeInteger QModpj = modulusQ.Mod(pj).ConvertToInt();
        for (size_t alpha = 0; alpha <= sizeQ; ++alpha)
            rnsExtTables.alphaQModp[alpha][j] = QModpj.ModMul(NativeInteger(alpha), moduliP[j]);
    }

    // Barrett constants floor(2^128 / p_j).
    const auto BarrettBase128Bit(BigInteger(1).LShiftEq(128));
    rnsExtTables.modpBarrettMu.resize(sizeP);
    for (size_t j = 0; j < sizeP; ++j)
        rnsExtTables.modpBarrettMu[j] = (BarrettBase128Bit / BigInteger(moduliP[j])).ConvertToInt<DoubleNativeInt>();

    // Floating-point reciprocals 1 / q_i.
    rnsExtTables.qInv.resize(sizeQ);
    for (size_t i = 0; i < sizeQ; ++i)
        rnsExtTables.qInv[i] = 1. / static_cast<double>(moduliQ[i].ConvertToInt());
}
// Returns sk * poly. When IsZeroNoise is false, an error polynomial is also added to the product
// (as in computing h_{1,i}). The error is sampled from the context's discrete Gaussian generator
// over poly's parameters, in EVALUATION format. IsZeroNoise = true skips the error (as in computing h_{0,i}).
// Throws if sk and poly have different numbers of RNS towers.
DCRTPoly ComputeNoisyMult(CryptoContext<DCRTPoly>& cc, const DCRTPoly& sk, const DCRTPoly& poly, bool IsZeroNoise) {
    if (poly.GetNumOfElements() != sk.GetNumOfElements()) {
        std::string errMsg = "ERROR: Number of towers in input polys does not match!";
        OPENFHE_THROW(errMsg);
    }

    DCRTPoly product = sk * poly;
    if (IsZeroNoise)
        return product;

    const auto rnsParams              = std::dynamic_pointer_cast<CryptoParametersRNS>(cc->GetCryptoParameters());
    const DCRTPoly::DggType& gaussian = rnsParams->GetDiscreteGaussianGenerator();
    product += DCRTPoly(gaussian, poly.GetParams(), Format::EVALUATION);
    return product;
}
// Generate random mask
/**
 * Samples the random mask M_i uniformly (DugType) in EVALUATION format. The RNS basis is the
 * first maskBoundNumTowers towers of c1.
 *
 * @param c1                polynomial whose RNS basis is truncated to obtain the mask's basis.
 * @param maskBoundNumTowers number of leading towers of c1 to keep; must be in [1, towers of c1].
 * @return the sampled mask polynomial.
 * @throws if maskBoundNumTowers is 0 or exceeds c1's tower count. Previously the unsigned
 *         subtraction underflowed and DropLastElements reported an unrelated-looking error.
 */
DCRTPoly GenerateMi(const DCRTPoly& c1, uint32_t maskBoundNumTowers) {
    const size_t numTowers = c1.GetAllElements().size();
    if (maskBoundNumTowers == 0 || maskBoundNumTowers > numTowers) {
        std::string msg = std::string(__func__) + ": maskBoundNumTowers = " + std::to_string(maskBoundNumTowers) +
                          " must be between 1 and the number of towers in c1 (" + std::to_string(numTowers) + ").";
        OPENFHE_THROW(msg);
    }

    // Only the truncated parameters are needed; dropping towers from a copy of c1 yields them.
    DCRTPoly c1Copy = c1;
    c1Copy.DropLastElements(numTowers - maskBoundNumTowers);

    typename DCRTPoly::DugType dug;
    return DCRTPoly(dug, c1Copy.GetParams(), Format::EVALUATION);
}
// Compute h_{0,i}
/**
 * Computes the masked decryption share h_{0,i} = s_i * c_1 - M_i.
 *
 * Steps:
 *  - Truncate the secret key to c1's tower count.
 *  - Multiply it by c1 without added noise (ComputeNoisyMult with IsZeroNoise = true).
 *  - Extend a copy of the mask M_i from the first compressionLevel towers to all of c1's towers.
 *  - Subtract the extended mask.
 * The RNS extension tables are rebuilt on every call. Mi itself is not modified.
 */
DCRTPoly GenerateMaskedDecryptionShare(CryptoContext<DCRTPoly>& cc, const PrivateKey<DCRTPoly> privateKey,
                                       const DCRTPoly& c1, DCRTPoly& Mi, uint32_t compressionLevel) {
    const auto c1NumTowers = c1.GetAllElements().size();

    DCRTPoly secret(privateKey->GetPrivateElement());
    secret.DropLastElements(secret.GetAllElements().size() - c1NumTowers);
    DCRTPoly share = ComputeNoisyMult(cc, secret, c1, true);

    DCRTPoly extendedMask(Mi);
    RNSExtensionTables ext;  // extends M_i from the compressionLevel towers to c1's basis
    PrecomputeRNSExtensionTables(cc, compressionLevel, c1NumTowers, ext);
    extendedMask.ExpandCRTBasis(ext.paramsQP, ext.paramsP, ext.QHatInvModq, ext.QHatInvModqPrecon, ext.QHatModp,
                                ext.alphaQModp, ext.modpBarrettMu, ext.qInv, EVALUATION);

    return share - extendedMask;
}
// Compute h_{1,i}
/**
 * Computes the re-encryption share h_{1,i} = -s_i * a + e_{1,i} + M_i.
 *
 * Steps:
 *  - Multiply the negated secret key by a's first element.
 *  - Add noise e_{1,i} (ComputeNoisyMult with IsZeroNoise = false).
 *  - Extend a copy of the mask M_i from the first compressionLevel towers to all towers of a's polynomial.
 *  - Add the extended mask.
 * The RNS extension tables are rebuilt on every call. Mi itself is not modified.
 */
DCRTPoly GenerateReEncryptionShare(CryptoContext<DCRTPoly>& cc, const PrivateKey<DCRTPoly> privateKey,
                                   ConstCiphertext<DCRTPoly> a, DCRTPoly& Mi, uint32_t compressionLevel) {
    const DCRTPoly& aPoly = a->GetElements()[0];

    DCRTPoly negatedSecret = privateKey->GetPrivateElement().Negate();
    DCRTPoly share         = ComputeNoisyMult(cc, negatedSecret, aPoly, false);

    DCRTPoly extendedMask(Mi);
    RNSExtensionTables ext;  // extends M_i from the compressionLevel towers to a's full basis
    PrecomputeRNSExtensionTables(cc, compressionLevel, aPoly.GetAllElements().size(), ext);
    extendedMask.ExpandCRTBasis(ext.paramsQP, ext.paramsP, ext.QHatInvModq, ext.QHatInvModqPrecon, ext.QHatModp,
                                ext.alphaQModp, ext.modpBarrettMu, ext.qInv, EVALUATION);

    return share + extendedMask;
}
/**
 * Produces one party's share pair for interactive multi-party bootstrapping.
 *
 * A single random mask M_i is sampled over the first compressionLevel towers of c_1 and used for both shares:
 *  - h_{0,i} = s_i * c_1 - M_i        (Encryption-to-Shares; no noise term is added here)
 *  - h_{1,i} = -s_i * a + e_{1,i} + M_i (Shares-to-Encryption)
 *
 * @param privateKey this party's secret-key share s_i.
 * @param ciphertext input ciphertext; only its first element (c_1) is used.
 * @param a          common random polynomial (first element used).
 * @return {h_{0,i}, h_{1,i}}, each wrapped as a one-element ciphertext.
 * @throws if ciphertext or a has no elements (indexing element 0 would otherwise be undefined behavior).
 */
std::vector<Ciphertext<DCRTPoly>> MultipartyCKKSRNS::IntMPBootDecrypt(const PrivateKey<DCRTPoly> privateKey,
                                                                      ConstCiphertext<DCRTPoly> ciphertext,
                                                                      ConstCiphertext<DCRTPoly> a) const {
    if (ciphertext->NumberCiphertextElements() == 0) {
        std::string msg = "IntMPBootDecrypt: no polynomials in the input ciphertext.";
        OPENFHE_THROW(msg);
    }
    if (a->NumberCiphertextElements() == 0) {
        std::string msg = "IntMPBootDecrypt: no polynomials in the common random polynomial a.";
        OPENFHE_THROW(msg);
    }

    auto cc                 = ciphertext->GetCryptoContext();
    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(cc->GetCryptoParameters());
    auto compressionLevel   = cryptoParams->GetMPIntBootCiphertextCompressionLevel();

    // Wraps one polynomial into a ciphertext tied to privateKey. The vector is moved in,
    // which avoids the copy an initializer list would force.
    auto wrap = [&privateKey](DCRTPoly&& poly) {
        Ciphertext<DCRTPoly> ct(std::make_shared<CiphertextImpl<DCRTPoly>>(privateKey));
        std::vector<DCRTPoly> elements;
        elements.push_back(std::move(poly));
        ct->SetElements(std::move(elements));
        return ct;
    };

    // The input ciphertext is used through its first element only: c_1.
    const auto& c1 = ciphertext->GetElements()[0];

    // Secret mask M_i, sampled in EVALUATION format.
    DCRTPoly Mi = GenerateMi(c1, compressionLevel);

    // Encryption-to-Shares: h_{0,i} = s_i * c_1 - M_i.
    Ciphertext<DCRTPoly> maskedDecryptionShare =
        wrap(GenerateMaskedDecryptionShare(cc, privateKey, c1, Mi, compressionLevel));

    // Shares-to-Encryption, reusing the same M_i: h_{1,i} = -s_i * a + e_{1,i} + M_i.
    Ciphertext<DCRTPoly> reEncryptionShare = wrap(GenerateReEncryptionShare(cc, privateKey, a, Mi, compressionLevel));

    return {maskedDecryptionShare, reEncryptionShare};
}
/**
 * Aggregates the share pairs of all parties:
 *   h_0 = sum_i h_{0,i},  h_1 = sum_i h_{1,i}.
 *
 * The first party's ciphertexts are deep-copied before accumulation. The previous version copied
 * only the shared pointers, so the in-place sums overwrote the caller's first-party shares.
 *
 * @param sharesPairVec one entry per party; each entry must hold at least {h_{0,i}, h_{1,i}}.
 * @return a copy of the first party's entry with elements [0] and [1] replaced by the aggregated sums.
 * @throws if sharesPairVec is empty or any entry holds fewer than two ciphertexts.
 */
std::vector<Ciphertext<DCRTPoly>> MultipartyCKKSRNS::IntMPBootAdd(
    std::vector<std::vector<Ciphertext<DCRTPoly>>>& sharesPairVec) const {
    if (sharesPairVec.size() == 0) {
        std::string msg = "IntMPBootAdd: no polynomials in input share(s).";
        OPENFHE_THROW(msg);
    }
    for (const auto& sharesPair : sharesPairVec) {
        if (sharesPair.size() < 2) {
            std::string msg = "IntMPBootAdd: each share entry must contain two ciphertexts (h_0,i and h_1,i).";
            OPENFHE_THROW(msg);
        }
    }

    // Deep copy so that accumulation below never mutates the caller's ciphertexts.
    std::vector<Ciphertext<DCRTPoly>> result;
    result.reserve(sharesPairVec[0].size());
    for (const auto& share : sharesPairVec[0])
        result.push_back(std::make_shared<CiphertextImpl<DCRTPoly>>(*share));

    for (size_t i = 1; i < sharesPairVec.size(); ++i) {
        // h_0 = h_0,0 + h_0,i
        result[0]->GetElements()[0] += sharesPairVec[i][0]->GetElements()[0];
        // h_1 = h_1,0 + h_1,i
        result[1]->GetElements()[0] += sharesPairVec[i][1]->GetElements()[0];
    }
    return result;
}
/**
 * Builds the refreshed ciphertext from the aggregated share pair.
 *
 * Steps:
 *  - c0' = c_0 + h_0, computed in the (compressed) basis of the input ciphertext.
 *  - Extend c0' to the full basis of a.
 *  - Add h_1.
 * The output is {c0', a}, bound to publicKey.
 *
 * @param publicKey  key the output ciphertext is associated with.
 * @param sharesPair aggregated pair {h_0, h_1}; must hold at least two ciphertexts.
 * @param a          common random polynomial; its first element becomes the output's second element.
 * @param ciphertext input ciphertext; provides c_0 and the copied metadata.
 * @return the re-encrypted ciphertext (level 0).
 * @throws if ciphertext or a has no elements, or sharesPair has fewer than two entries.
 */
Ciphertext<DCRTPoly> MultipartyCKKSRNS::IntMPBootEncrypt(const PublicKey<DCRTPoly> publicKey,
                                                         const std::vector<Ciphertext<DCRTPoly>>& sharesPair,
                                                         ConstCiphertext<DCRTPoly> a,
                                                         ConstCiphertext<DCRTPoly> ciphertext) const {
    if (ciphertext->NumberCiphertextElements() == 0) {
        std::string msg = "IntMPBootEncrypt: no polynomials in the input ciphertext.";
        OPENFHE_THROW(msg);
    }
    if (sharesPair.size() < 2) {
        std::string msg = "IntMPBootEncrypt: sharesPair must contain the aggregated shares h_0 and h_1.";
        OPENFHE_THROW(msg);
    }
    if (a->NumberCiphertextElements() == 0) {
        std::string msg = "IntMPBootEncrypt: no polynomials in the common random polynomial a.";
        OPENFHE_THROW(msg);
    }

    auto cc               = ciphertext->GetCryptoContext();
    const DCRTPoly& aPoly = a->GetElements()[0];

    // c0' = c_0 + h_0 in the input ciphertext's basis.
    DCRTPoly c0Prime = ciphertext->GetElements()[0] + sharesPair[0]->GetElements()[0];

    // Extend c0' from its basis to a's full basis; tables are generated on every call.
    RNSExtensionTables C0ForReEncryptRNSExtTables;
    PrecomputeRNSExtensionTables(cc, c0Prime.GetAllElements().size(), aPoly.GetAllElements().size(),
                                 C0ForReEncryptRNSExtTables);
    c0Prime.ExpandCRTBasis(C0ForReEncryptRNSExtTables.paramsQP, C0ForReEncryptRNSExtTables.paramsP,
                           C0ForReEncryptRNSExtTables.QHatInvModq, C0ForReEncryptRNSExtTables.QHatInvModqPrecon,
                           C0ForReEncryptRNSExtTables.QHatModp, C0ForReEncryptRNSExtTables.alphaQModp,
                           C0ForReEncryptRNSExtTables.modpBarrettMu, C0ForReEncryptRNSExtTables.qInv, EVALUATION);
    c0Prime += sharesPair[1]->GetElements()[0];

    // Move c0' into place. "a" is const, so its polynomial is (and always was) copied.
    std::vector<DCRTPoly> elements;
    elements.reserve(2);
    elements.push_back(std::move(c0Prime));
    elements.push_back(aPoly);

    Ciphertext<DCRTPoly> outCtxt(std::make_shared<CiphertextImpl<DCRTPoly>>(publicKey));
    outCtxt->SetElements(std::move(elements));

    // Ciphertext depth, level, and scaling factor should be
    // equal to that of the plaintext. However, Encrypt does
    // not take Plaintext as input (only DCRTPoly), so we
    // don't have access to these here and we copy them
    // from the input ciphertext.
    outCtxt->SetEncodingType(ciphertext->GetEncodingType());
    outCtxt->SetScalingFactor(ciphertext->GetScalingFactor());
    outCtxt->SetNoiseScaleDeg(ciphertext->GetNoiseScaleDeg());
    outCtxt->SetLevel(0);
    outCtxt->SetMetadataMap(ciphertext->GetMetadataMap());
    outCtxt->SetSlots(ciphertext->GetSlots());
    return outCtxt;
}
} // namespace lbcrypto