Program Listing for File rns-multiparty.cpp
↰ Return to documentation for file (pke/lib/schemerns/rns-multiparty.cpp)
//==================================================================================
// BSD 2-Clause License
//
// Copyright (c) 2014-2022, NJIT, Duality Technologies Inc. and other contributors
//
// All rights reserved.
//
// Author TPOC: contact@openfhe.org
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//==================================================================================
#include "schemerns/rns-multiparty.h"
#include "key/privatekey.h"
#include "key/evalkeyrelin.h"
#include "cryptocontext.h"
#include "schemerns/rns-pke.h"
#include <memory>
#include <vector>
#include <utility>
#include <string>
#include <cstring>
namespace lbcrypto {
/**
 * Computes the lead party's partial decryption share of a ciphertext:
 *     b = c0 + s * c1 + ns * noise,
 * where s is this party's secret share restricted to the ciphertext's towers and
 * ns is the noise scale taken from the crypto parameters.
 *
 * The flooding noise is selected as follows:
 *  - NOISE_FLOODING_MULTIPARTY: e is sampled with DugType over all towers except the
 *    first, then extended back to the ciphertext basis (requires at least 3 towers);
 *  - NOISE_FLOODING_DECRYPT with EXEC_EVALUATION: sampled from the flooding Gaussian
 *    generator held by the crypto parameters;
 *  - otherwise: sampled from a Gaussian with standard deviation NoiseFlooding::MP_SD.
 *
 * @param ciphertext  ciphertext to partially decrypt (elements c0 and c1 are read).
 * @param privateKey  this party's secret key share.
 * @return an empty clone of the input ciphertext holding the single element b.
 * @throws when NOISE_FLOODING_MULTIPARTY is active and the ciphertext has fewer than 3 towers.
 */
Ciphertext<DCRTPoly> MultipartyRNS::MultipartyDecryptLead(ConstCiphertext<DCRTPoly> ciphertext,
                                                          const PrivateKey<DCRTPoly> privateKey) const {
    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersRNS>(privateKey->GetCryptoParameters());
    const std::vector<DCRTPoly>& cv = ciphertext->GetElements();
    const auto ns                   = cryptoParams->GetNoiseScale();

    // Restrict the secret key to the towers the ciphertext still has.
    auto s(privateKey->GetPrivateElement());
    size_t sizeQ  = s.GetParams()->GetParams().size();
    size_t sizeQl = cv[0].GetParams()->GetParams().size();
    size_t diffQl = sizeQ - sizeQl;
    s.DropLastElements(diffQl);

    DCRTPoly noise;
    if (cryptoParams->GetMultipartyMode() == NOISE_FLOODING_MULTIPARTY) {
        if (sizeQl < 3) {
            OPENFHE_THROW("sizeQl " + std::to_string(sizeQl) +
                          " must be at least 3 in NOISE_FLOODING_MULTIPARTY mode.");
        }
        DugType dug;
        auto params    = cv[0].GetParams();
        auto cyclOrder = params->GetCyclotomicOrder();

        // Parameters describing only the first tower.
        std::vector<NativeInteger> moduliFirst = {params->GetParams()[0]->GetModulus()};
        std::vector<NativeInteger> rootsFirst  = {params->GetParams()[0]->GetRootOfUnity()};
        auto paramsFirst = std::make_shared<ILDCRTParams<BigInteger>>(cyclOrder, moduliFirst, rootsFirst);

        // Parameters describing every tower except the first.
        std::vector<NativeInteger> moduliAllButFirst(sizeQl - 1);
        std::vector<NativeInteger> rootsAllButFirst(sizeQl - 1);
        for (size_t i = 1; i < sizeQl; i++) {
            moduliAllButFirst[i - 1] = params->GetParams()[i]->GetModulus();
            rootsAllButFirst[i - 1]  = params->GetParams()[i]->GetRootOfUnity();
        }
        auto paramsAllButFirst =
            std::make_shared<ILDCRTParams<BigInteger>>(cyclOrder, moduliAllButFirst, rootsAllButFirst);

        // Sample e over all towers but the first, then extend it to the full ciphertext basis
        // using the multiparty CRT tables for this level (index sizeQl - 2).
        DCRTPoly e(dug, paramsAllButFirst, Format::EVALUATION);
        e.ExpandCRTBasisReverseOrder(params, paramsFirst, cryptoParams->GetMultipartyQHatInvModqAtIndex(sizeQl - 2),
                                     cryptoParams->GetMultipartyQHatInvModqPreconAtIndex(sizeQl - 2),
                                     cryptoParams->GetMultipartyQHatModq0AtIndex(sizeQl - 2),
                                     cryptoParams->GetMultipartyAlphaQModq0AtIndex(sizeQl - 2),
                                     cryptoParams->GetMultipartyModq0BarrettMu(), cryptoParams->GetMultipartyQInv(),
                                     Format::EVALUATION);
        // e is not used afterwards: move instead of copying the whole polynomial.
        noise = std::move(e);
    }
    else if (cryptoParams->GetDecryptionNoiseMode() == NOISE_FLOODING_DECRYPT &&
             cryptoParams->GetExecutionMode() == EXEC_EVALUATION) {
        // Bind by const reference: sampling only needs read access, so copying the
        // generator (and any tables it holds) is unnecessary.
        const auto& dgg = cryptoParams->GetFloodingDiscreteGaussianGenerator();
        DCRTPoly e(dgg, cv[0].GetParams(), Format::EVALUATION);
        noise = std::move(e);
    }
    else {
        DggType dgg(NoiseFlooding::MP_SD);
        DCRTPoly e(dgg, cv[0].GetParams(), Format::EVALUATION);
        noise = std::move(e);
    }

    // noise is added to do noise flooding
    DCRTPoly b  = cv[0] + s * cv[1] + ns * noise;
    auto result = ciphertext->CloneEmpty();
    result->SetElement(std::move(b));
    return result;
}
/**
 * Computes a non-lead party's partial decryption share of a ciphertext:
 *     b = s * c1 + ns * noise        (no c0 term),
 * where s is this party's secret share restricted to the ciphertext's towers and
 * ns is the noise scale taken from the crypto parameters.
 *
 * The flooding noise is selected as follows:
 *  - NOISE_FLOODING_MULTIPARTY: e is sampled with DugType over all towers except the
 *    first, then extended back to the ciphertext basis (requires at least 3 towers);
 *  - NOISE_FLOODING_DECRYPT with EXEC_EVALUATION: sampled from the flooding Gaussian
 *    generator held by the crypto parameters;
 *  - otherwise: sampled from a Gaussian with standard deviation NoiseFlooding::MP_SD.
 *
 * @param ciphertext  ciphertext to partially decrypt (element c1 is read).
 * @param privateKey  this party's secret key share.
 * @return an empty clone of the input ciphertext holding the single element b.
 * @throws when NOISE_FLOODING_MULTIPARTY is active and the ciphertext has fewer than 3 towers.
 */
Ciphertext<DCRTPoly> MultipartyRNS::MultipartyDecryptMain(ConstCiphertext<DCRTPoly> ciphertext,
                                                          const PrivateKey<DCRTPoly> privateKey) const {
    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersRNS>(privateKey->GetCryptoParameters());
    const auto ns                   = cryptoParams->GetNoiseScale();
    const std::vector<DCRTPoly>& cv = ciphertext->GetElements();

    // Restrict the secret key to the towers the ciphertext still has.
    auto s(privateKey->GetPrivateElement());
    size_t sizeQ  = s.GetParams()->GetParams().size();
    size_t sizeQl = cv[0].GetParams()->GetParams().size();
    size_t diffQl = sizeQ - sizeQl;
    s.DropLastElements(diffQl);

    DCRTPoly noise;
    if (cryptoParams->GetMultipartyMode() == NOISE_FLOODING_MULTIPARTY) {
        if (sizeQl < 3) {
            OPENFHE_THROW("sizeQl " + std::to_string(sizeQl) +
                          " must be at least 3 in NOISE_FLOODING_MULTIPARTY mode.");
        }
        DugType dug;
        auto params = cv[0].GetParams();

        // Parameters describing every tower except the first.
        ILDCRTParams<BigInteger> paramsCopy = *params;
        paramsCopy.PopFirstParam();
        auto paramsAllButFirst = std::make_shared<ILDCRTParams<BigInteger>>(paramsCopy);
        DCRTPoly e(dug, paramsAllButFirst, Format::EVALUATION);

        // Parameters describing only the first tower.
        auto cyclOrder = params->GetCyclotomicOrder();
        std::vector<NativeInteger> moduliFirst = {params->GetParams()[0]->GetModulus()};
        std::vector<NativeInteger> rootsFirst  = {params->GetParams()[0]->GetRootOfUnity()};
        auto paramsFirst = std::make_shared<ILDCRTParams<BigInteger>>(cyclOrder, moduliFirst, rootsFirst);

        // Extend e to the full ciphertext basis using the multiparty CRT tables for this
        // level (index sizeQl - 2).
        e.ExpandCRTBasisReverseOrder(params, paramsFirst, cryptoParams->GetMultipartyQHatInvModqAtIndex(sizeQl - 2),
                                     cryptoParams->GetMultipartyQHatInvModqPreconAtIndex(sizeQl - 2),
                                     cryptoParams->GetMultipartyQHatModq0AtIndex(sizeQl - 2),
                                     cryptoParams->GetMultipartyAlphaQModq0AtIndex(sizeQl - 2),
                                     cryptoParams->GetMultipartyModq0BarrettMu(), cryptoParams->GetMultipartyQInv(),
                                     Format::EVALUATION);
        // e is not used afterwards: move instead of copying the whole polynomial.
        noise = std::move(e);
    }
    else if (cryptoParams->GetDecryptionNoiseMode() == NOISE_FLOODING_DECRYPT &&
             cryptoParams->GetExecutionMode() == EXEC_EVALUATION) {
        // Bind by const reference: sampling only needs read access, so copying the
        // generator (and any tables it holds) is unnecessary.
        const auto& dgg = cryptoParams->GetFloodingDiscreteGaussianGenerator();
        DCRTPoly e(dgg, cv[0].GetParams(), Format::EVALUATION);
        noise = std::move(e);
    }
    else {
        DggType dgg(NoiseFlooding::MP_SD);
        DCRTPoly e(dgg, cv[0].GetParams(), Format::EVALUATION);
        noise = std::move(e);
    }

    // noise is added to do noise flooding
    DCRTPoly b  = s * cv[1] + ns * noise;
    auto result = ciphertext->CloneEmpty();
    result->SetElement(std::move(b));
    return result;
}
/**
 * Multiplies every A and B component of an evaluation key by this party's secret share
 * and adds a fresh Gaussian error (scaled by the noise scale) to each product.
 *
 * With BV key switching the share is used directly. Otherwise the share is first lifted
 * to the QP basis: its Q towers are copied verbatim, and each extra tower is tower 0 of
 * the share (in COEFFICIENT form) switched to that tower's modulus and root of unity.
 *
 * @param privateKey  this party's secret key share.
 * @param evalKey     evaluation key whose A/B vectors are combined with the share.
 * @return a new EvalKeyRelinImpl holding the combined A and B vectors.
 */
EvalKey<DCRTPoly> MultipartyRNS::MultiMultEvalKey(PrivateKey<DCRTPoly> privateKey, EvalKey<DCRTPoly> evalKey) const {
    const auto cryptoParams =
        std::dynamic_pointer_cast<CryptoParametersRNS>(evalKey->GetCryptoContext()->GetCryptoParameters());
    const auto ns      = cryptoParams->GetNoiseScale();
    const DggType& dgg = cryptoParams->GetDiscreteGaussianGenerator();

    EvalKey<DCRTPoly> evalKeyResult = std::make_shared<EvalKeyRelinImpl<DCRTPoly>>(evalKey->GetCryptoContext());

    const std::vector<DCRTPoly>& keyA = evalKey->GetAVector();
    const std::vector<DCRTPoly>& keyB = evalKey->GetBVector();
    const size_t count                = keyA.size();

    std::vector<DCRTPoly> outA;
    std::vector<DCRTPoly> outB;
    outA.reserve(count);
    outB.reserve(count);

    // out[idx] = component[idx] * secret + ns * e, with a fresh e drawn from dgg for every
    // product; for each index the A-side error is sampled before the B-side error.
    auto multiplyAndFlood = [&](const DCRTPoly& secret, const auto& params) {
        for (size_t idx = 0; idx < count; ++idx) {
            outA.push_back(keyA[idx] * secret + ns * DCRTPoly(dgg, params, Format::EVALUATION));
            outB.push_back(keyB[idx] * secret + ns * DCRTPoly(dgg, params, Format::EVALUATION));
        }
    };

    if (cryptoParams->GetKeySwitchTechnique() == BV) {
        const DCRTPoly& secret = privateKey->GetPrivateElement();
        multiplyAndFlood(secret, secret.GetParams());
    }
    else {
        const auto& paramsQ  = cryptoParams->GetElementParams();
        const auto& paramsQP = cryptoParams->GetParamsQP();
        const usint towersQ  = paramsQ->GetParams().size();
        const usint towersQP = paramsQP->GetParams().size();

        // Coefficient-form copy of the share, so its towers can be copied/reduced directly.
        DCRTPoly secret = privateKey->GetPrivateElement().Clone();
        secret.SetFormat(Format::COEFFICIENT);

        DCRTPoly secretQP(paramsQP, Format::COEFFICIENT, true);
        usint tower = 0;
        // Q part: copy each tower as-is.
        for (; tower < towersQ; ++tower) {
            secretQP.SetElementAtIndex(tower, secret.GetElementAtIndex(tower));
        }
        // P part: take tower 0 of the share and switch it to each extra modulus.
        for (; tower < towersQP; ++tower) {
            const NativeInteger modulus = paramsQP->GetParams()[tower]->GetModulus();
            const NativeInteger root    = paramsQP->GetParams()[tower]->GetRootOfUnity();
            auto lifted                 = secret.GetElementAtIndex(0);
            lifted.SwitchModulus(modulus, root, 0, 0);
            secretQP.SetElementAtIndex(tower, std::move(lifted));
        }
        secretQP.SetFormat(Format::EVALUATION);

        multiplyAndFlood(secretQP, paramsQP);
    }

    evalKeyResult->SetAVector(std::move(outA));
    evalKeyResult->SetBVector(std::move(outB));
    return evalKeyResult;
}
// Used as a subroutine in interactive bootstrapping.
// Takes a polynomial with exactly two towers (RNS limbs), Q = q0 * q1, whose entries are
// treated as coefficients (the caller is expected to pass it in COEFFICIENT format).
// For each coefficient, applies the following logic:
// if |coefficient| > Q/4 (centered representative), then add Q/2 to it.
// This guarantees that rounded c_0 + c_1 < Q/2, which
// prevents an overflow during interactive bootstrapping.
// Throws if the polynomial does not have exactly two towers.
void PolynomialRound(DCRTPoly& dcrtpoly) {
    const uint32_t NUM_TOWERS = dcrtpoly.GetNumOfElements();
    if (2 != NUM_TOWERS) {
        OPENFHE_THROW("The input polynomial has " + std::to_string(NUM_TOWERS) + " instead of 2 RNS limbs");
    }
    using DInt = NativeInteger::DNativeInt;

    std::vector<NativeInteger> q(NUM_TOWERS);
    std::vector<NativePoly> poly(NUM_TOWERS);
    for (size_t i = 0; i < NUM_TOWERS; i++) {
        poly[i] = dcrtpoly.GetElementAtIndex(i);
        q[i]    = poly[i].GetModulus();
    }

    // CRT interpolation constants: qInv[0] = q1^{-1} mod q0, qInv[1] = q0^{-1} mod q1,
    // plus the precomputed constants ModMulFastConst needs.
    std::vector<NativeInteger> qInv(NUM_TOWERS);
    qInv[0] = q[1].ModInverse(q[0]);
    qInv[1] = q[0].ModInverse(q[1]);
    std::vector<NativeInteger> precon(NUM_TOWERS);
    for (size_t i = 0; i < NUM_TOWERS; i++) {
        precon[i] = qInv[i].PrepModMulConst(q[i]);
    }

    // NOTE(review): 3 * Q below assumes Q < 2^126 (e.g. both native moduli below 63 bits);
    // confirm against the library's maximum native modulus size.
    const DInt q0      = q[0].ConvertToInt();
    const DInt q1      = q[1].ConvertToInt();
    const DInt Q       = q0 * q1;
    const DInt Qhalf   = Q / 2;
    const DInt Q1quart = Q / 4;
    const DInt Q3quart = 3 * Q / 4;

    // Q/2 reduced into each tower.
    std::vector<NativeInteger> qHalf(NUM_TOWERS);
    for (size_t i = 0; i < NUM_TOWERS; i++) {
        qHalf[i] = Qhalf % q[i].ConvertToInt();
    }

    // To do the comparison |coefficient[k]| > Q/4, reconstruct each coefficient in [0, Q)
    // via CRT composition (interpolation) using 128-bit integers.
    for (size_t k = 0; k < dcrtpoly.GetRingDimension(); k++) {
        // Widen to DInt BEFORE multiplying: ConvertToInt() yields a 64-bit value, and a
        // 64-bit * 64-bit product would wrap around and corrupt the reconstruction.
        DInt x128 = DInt(poly[0][k].ModMulFastConst(qInv[0], q[0], precon[0]).ConvertToInt()) * q1;
        x128 += DInt(poly[1][k].ModMulFastConst(qInv[1], q[1], precon[1]).ConvertToInt()) * q0;
        // Each term is < Q, so the sum is < 2Q and one conditional subtraction reduces it
        // modulo Q (including x128 == Q, which must map to 0).
        if (x128 >= Q)
            x128 -= Q;
        if ((x128 > Q1quart) && (x128 <= Q3quart)) {
            poly[0][k].ModAddFastEq(qHalf[0], q[0]);
            poly[1][k].ModAddFastEq(qHalf[1], q[1]);
        }
    }
    dcrtpoly.SetElementAtIndex(0, std::move(poly[0]));
    dcrtpoly.SetElementAtIndex(1, std::move(poly[1]));
}
// Used as a subroutine in interactive bootstrapping.
// Extends a DCRTPoly with 2 RNS limbs (basis Q) to the larger RNS basis paramsQP.
// The first limbs of paramsQP are treated as the input's limbs; the limbs from index
// sizeQ onward are the new ones that get filled in. The exact basis extension RNS
// procedure from https://eprint.iacr.org/2018/117 is used, and the result is produced
// in COEFFICIENT format (callers in this file pass the input in COEFFICIENT format).
// Throws if the input does not have 2 limbs, or if paramsQP has fewer limbs than the input.
void ExtendBasis(DCRTPoly& dcrtpoly, const std::shared_ptr<DCRTPoly::Params> paramsQP) {
    if (dcrtpoly.GetNumOfElements() != 2) {
        OPENFHE_THROW(" The input polynomial should have 2 RNS limbs");
    }
    const auto paramsQ = dcrtpoly.GetParams();
    usint sizeQP       = paramsQP->GetParams().size();
    usint sizeQ        = paramsQ->GetParams().size();
    // sizeP is unsigned: without this check a too-small target basis would underflow it
    // and trigger enormous allocations below.
    if (sizeQP < sizeQ) {
        OPENFHE_THROW("The target basis has " + std::to_string(sizeQP) + " RNS limbs, fewer than the " +
                      std::to_string(sizeQ) + " limbs of the input polynomial");
    }
    usint sizeP = sizeQP - sizeQ;

    // Loads all moduli and roots of unity
    std::vector<NativeInteger> moduliQ(sizeQ);
    // std::vector<NativeInteger> rootsQ(sizeQ); // TODO (dsuponit): do we need rootsQ?
    for (size_t i = 0; i < sizeQ; i++) {
        moduliQ[i] = paramsQ->GetParams()[i]->GetModulus();
        // rootsQ[i] = paramsQ->GetParams()[i]->GetRootOfUnity();
    }
    std::vector<NativeInteger> moduliP(sizeP);
    std::vector<NativeInteger> rootsP(sizeP);
    for (size_t i = 0; i < sizeP; i++) {
        moduliP[i] = paramsQP->GetParams()[i + sizeQ]->GetModulus();
        rootsP[i]  = paramsQP->GetParams()[i + sizeQ]->GetRootOfUnity();
    }
    auto paramsP = std::make_shared<typename DCRTPoly::Params>(2 * paramsQ->GetRingDimension(), moduliP, rootsP);

    // Does all RNS precomputations
    std::vector<NativeInteger> QHatInvModq(sizeQ);
    std::vector<NativeInteger> QHatInvModqPrecon(sizeQ);
    std::vector<std::vector<NativeInteger>> QHatModp(sizeP);
    NativeInteger::DNativeInt modulusQ = dcrtpoly.GetModulus().ConvertToInt<NativeInteger::DNativeInt>();
    for (usint i = 0; i < sizeQ; i++) {
        NativeInteger::DNativeInt qi(moduliQ[i].ConvertToInt());
        NativeInteger QHati  = modulusQ / qi;
        QHatInvModq[i]       = QHati.ModInverse(moduliQ[i]).Mod(moduliQ[i]);
        QHatInvModqPrecon[i] = QHatInvModq[i].PrepModMulConst(moduliQ[i]);
        for (usint j = 0; j < sizeP; j++) {
            const NativeInteger& pj = moduliP[j];
            QHatModp[j].push_back(QHati.Mod(pj));
        }
    }
    std::vector<std::vector<NativeInteger>> alphaQModp(sizeQ + 1);
    for (usint j = 0; j < sizeP; j++) {
        NativeInteger::DNativeInt pj(moduliP[j].ConvertToInt());
        NativeInteger QModpj = modulusQ % pj;
        for (usint i = 0; i < sizeQ + 1; i++) {
            alphaQModp[i].push_back(QModpj.ModMul(NativeInteger(i), moduliP[j]));
        }
    }

    const BigInteger BarrettBase128Bit("340282366920938463463374607431768211456");  // 2^128
    const BigInteger TwoPower64("18446744073709551616");                            // 2^64
    // Precomputations for Barrett modulo reduction: mu_j = floor(2^128 / p_j) as a 128-bit value
    std::vector<NativeInteger::DNativeInt> modpBarrettMu(sizeP);
    for (uint32_t j = 0; j < sizeP; j++) {
        BigInteger mu = BarrettBase128Bit / BigInteger(moduliP[j]);
        // Assemble the 128-bit value from its 64-bit halves arithmetically; copying a
        // uint64_t[2] into the 128-bit integer with memcpy assumed a little-endian layout.
        const NativeInteger::DNativeInt lo = (mu % TwoPower64).ConvertToInt();
        const NativeInteger::DNativeInt hi = mu.RShift(64).ConvertToInt();
        modpBarrettMu[j]                   = (hi << 64) | lo;
    }
    std::vector<double> qInv(sizeQ);
    for (size_t i = 0; i < sizeQ; i++) {
        qInv[i] = 1. / static_cast<double>(moduliQ[i].ConvertToInt());
    }

    // Calls the exact RNS basis extension procedure
    dcrtpoly.ExpandCRTBasis(paramsQP, paramsP, QHatInvModq, QHatInvModqPrecon, QHatModp, alphaQModp, modpBarrettMu,
                            qInv, Format::COEFFICIENT);
}
/**
 * One party's decryption step in interactive bootstrapping.
 *
 * Computes the share c0 * s (single-element ciphertext) or c1 * s + c0 (two-element
 * ciphertext), where s is the secret share restricted to the ciphertext's towers, then
 * converts it to COEFFICIENT format and applies PolynomialRound (which requires exactly
 * 2 towers).
 *
 * @param privateKey  this party's secret key share.
 * @param ciphertext  input ciphertext with one or two elements.
 * @return a clone of the input ciphertext whose elements are replaced by the rounded share.
 * @throws if the ciphertext has neither one nor two elements (PolynomialRound additionally
 *         throws when the share does not have 2 towers).
 */
Ciphertext<DCRTPoly> MultipartyRNS::IntBootDecrypt(const PrivateKey<DCRTPoly> privateKey,
                                                   ConstCiphertext<DCRTPoly> ciphertext) const {
    const size_t NUM_POLYNOMIALS = ciphertext->NumberCiphertextElements();
    const bool supportedCount    = (NUM_POLYNOMIALS == 1) || (NUM_POLYNOMIALS == 2);
    if (!supportedCount) {
        std::string msg = "Ciphertext should contain either one or two polynomials. The input ciphertext has " +
                          std::to_string(NUM_POLYNOMIALS) + ".";
        OPENFHE_THROW(msg);
    }

    // Local copies moved to EVALUATION format so they can be multiplied by the key.
    std::vector<DCRTPoly> elems = ciphertext->GetElements();
    for (size_t idx = 0; idx < NUM_POLYNOMIALS; ++idx) {
        elems[idx].SetFormat(Format::EVALUATION);
    }

    // Restrict a copy of the secret share to the towers the ciphertext still has.
    DCRTPoly secret(privateKey->GetPrivateElement());
    const size_t keyTowers = secret.GetParams()->GetParams().size();
    const size_t ctTowers  = elems[0].GetParams()->GetParams().size();
    secret.DropLastElements(keyTowers - ctTowers);

    DCRTPoly share;
    if (NUM_POLYNOMIALS == 1) {
        share = elems[0] * secret;
    }
    else {
        share = elems[1] * secret + elems[0];
    }

    share.SetFormat(Format::COEFFICIENT);
    PolynomialRound(share);

    Ciphertext<DCRTPoly> result = ciphertext->Clone();
    result->SetElements(std::vector<DCRTPoly>{share});
    return result;
}
/**
 * Re-encrypts a rounded interactive-bootstrapping share under a public key.
 *
 * The share (first element of ctxt) is lifted from its small 2-limb basis to the full
 * element basis with ExtendBasis, then encrypted as
 *     (b * v + (m + e0), a * v + e1)
 * with (b, a) the public key (truncated to the plaintext's towers if needed), v drawn from
 * the Gaussian or ternary-uniform distribution depending on the secret-key distribution,
 * and e0, e1 Gaussian errors. Randomness is drawn in the order v, e0, e1.
 *
 * @param publicKey  public key to encrypt under.
 * @param ctxt       ciphertext whose first element holds the share; its encoding type,
 *                   scaling factor, noise-scale degree, metadata and slot count are copied.
 * @return the new ciphertext, with level set to 0.
 * @throws if ctxt has no elements.
 */
Ciphertext<DCRTPoly> MultipartyRNS::IntBootEncrypt(const PublicKey<DCRTPoly> publicKey,
                                                   ConstCiphertext<DCRTPoly> ctxt) const {
    if (ctxt->GetElements().empty()) {
        OPENFHE_THROW("No polynomials found in the input ciphertext");
    }
    using DggType  = typename DCRTPoly::DggType;
    using TugType  = typename DCRTPoly::TugType;
    using ParmType = typename DCRTPoly::Params;

    const auto cryptoParams =
        std::static_pointer_cast<CryptoParametersRLWE<DCRTPoly>>(publicKey->GetCryptoParameters());

    // Lift the plaintext from the small 2-limb modulus to the full modulus so it can
    // support further computation.
    DCRTPoly plain = ctxt->GetElements()[0];
    plain.SetFormat(Format::COEFFICIENT);
    ExtendBasis(plain, cryptoParams->GetElementParams());
    const std::shared_ptr<ParmType> plainParams = plain.GetParams();

    const DggType& dgg = cryptoParams->GetDiscreteGaussianGenerator();
    TugType tug;
    // GAUSSIAN secret keys use a Gaussian v; otherwise v is ternary uniform.
    DCRTPoly v;
    if (cryptoParams->GetSecretKeyDist() == GAUSSIAN) {
        v = DCRTPoly(dgg, plainParams, Format::EVALUATION);
    }
    else {
        v = DCRTPoly(tug, plainParams, Format::EVALUATION);
    }
    DCRTPoly e0(dgg, plainParams, Format::COEFFICIENT);
    DCRTPoly e1(dgg, plainParams, Format::EVALUATION);

    // Fold e0 into the plaintext while both are in COEFFICIENT form, saving an NTT.
    plain += e0;
    plain.SetFormat(Format::EVALUATION);

    const std::vector<DCRTPoly>& pk = publicKey->GetPublicElements();
    const uint32_t sizeQl           = plainParams->GetParams().size();
    const uint32_t sizeQ            = pk[0].GetParams()->GetParams().size();

    std::vector<DCRTPoly> cv;
    // e0 is already inside plain, so only e1 is added to the second component.
    auto encryptWith = [&](const DCRTPoly& b, const DCRTPoly& a) {
        cv.push_back(b * v + plain);
        cv.push_back(a * v + e1);
    };

    if (sizeQl == sizeQ) {
        // Key and plaintext share the same towers: use the key as it is.
        encryptWith(pk[0], pk[1]);
    }
    else {
        // Drop the extra towers on clones so the public key itself is left untouched.
        const uint32_t towersToDrop = sizeQ - sizeQl;
        DCRTPoly b                  = pk[0].Clone();
        DCRTPoly a                  = pk[1].Clone();
        b.DropLastElements(towersToDrop);
        a.DropLastElements(towersToDrop);
        encryptWith(b, a);
    }

    Ciphertext<DCRTPoly> ciphertext(std::make_shared<CiphertextImpl<DCRTPoly>>(publicKey));
    ciphertext->SetElements(std::move(cv));

    // Encrypt receives a DCRTPoly rather than a Plaintext, so the plaintext's depth, level and
    // scaling factor are not available here; they are taken from the input ciphertext instead.
    ciphertext->SetEncodingType(ctxt->GetEncodingType());
    ciphertext->SetScalingFactor(ctxt->GetScalingFactor());
    ciphertext->SetNoiseScaleDeg(ctxt->GetNoiseScaleDeg());
    ciphertext->SetLevel(0);
    ciphertext->SetMetadataMap(ctxt->GetMetadataMap());
    ciphertext->SetSlots(ctxt->GetSlots());
    return ciphertext;
}
/**
 * Adds an interactive-bootstrapping share to a ciphertext.
 *
 * The first element of ciphertext2 is lifted from its 2-limb basis to the crypto
 * parameters' element basis with ExtendBasis, converted to EVALUATION format, and added
 * to the first element of ciphertext1. All other elements of ciphertext1 are kept as-is.
 *
 * @param ciphertext1  ciphertext receiving the share.
 * @param ciphertext2  ciphertext whose first element is the 2-limb share (its other
 *                     elements, if any, are ignored).
 * @return a clone of ciphertext1 with the updated elements.
 * @throws if either ciphertext has no elements (ExtendBasis also throws when the share
 *         does not have 2 limbs).
 */
Ciphertext<DCRTPoly> MultipartyRNS::IntBootAdd(ConstCiphertext<DCRTPoly> ciphertext1,
                                               ConstCiphertext<DCRTPoly> ciphertext2) const {
    if (ciphertext1->GetElements().empty()) {
        OPENFHE_THROW("No polynomials found in the input ciphertext1");
    }
    if (ciphertext2->GetElements().empty()) {
        OPENFHE_THROW("No polynomials found in the input ciphertext2");
    }

    // Only element 0 of ciphertext2 is used, so copy just that polynomial instead of the
    // whole element vector.
    DCRTPoly share = ciphertext2->GetElements()[0];
    share.SetFormat(Format::COEFFICIENT);
    const auto cryptoParams =
        std::static_pointer_cast<CryptoParametersRLWE<DCRTPoly>>(ciphertext1->GetCryptoParameters());
    ExtendBasis(share, cryptoParams->GetElementParams());
    share.SetFormat(Format::EVALUATION);

    std::vector<DCRTPoly> elements1 = ciphertext1->GetElements();
    elements1[0] += share;

    Ciphertext<DCRTPoly> result = ciphertext1->Clone();
    // elements1 is no longer needed: hand it over instead of copying it.
    result->SetElements(std::move(elements1));
    return result;
}
} // namespace lbcrypto