Program Listing for File base-multiparty.cpp
↰ Return to documentation for file (pke/lib/schemebase/base-multiparty.cpp)
//==================================================================================
// BSD 2-Clause License
//
// Copyright (c) 2014-2022, NJIT, Duality Technologies Inc. and other contributors
//
// All rights reserved.
//
// Author TPOC: contact@openfhe.org
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//==================================================================================
#include "cryptocontext.h"
#include "key/evalkey.h"
#include "key/evalkeyrelin.h"
#include "key/privatekey.h"
#include "key/publickey.h"
#include "schemebase/base-multiparty.h"
#include "schemebase/base-pke.h"
#include "schemebase/base-scheme.h"
#include "schemebase/rlwe-cryptoparameters.h"
#include <iostream>
#include <map>
#include <memory>
#include <utility>
#include <vector>
namespace lbcrypto {
// Builds a joint key pair from a collection of existing secret keys.
//
// The joint secret is the sum of the private elements of every key in
// privateKeyVec; the public key is the pair (b, a) with b = ns * e - a * s,
// where a is drawn from the discrete uniform generator, e from the context's
// discrete Gaussian generator and ns is the configured noise scale.
//
// @param cc            crypto context the resulting keys are bound to
// @param privateKeyVec secret keys whose private elements are summed
// @param makeSparse    not used by this scheme
// @return key pair holding the summed secret and the matching public key
template <class Element>
KeyPair<Element> MultipartyBase<Element>::MultipartyKeyGen(CryptoContext<Element> cc,
                                                           const std::vector<PrivateKey<Element>>& privateKeyVec,
                                                           bool makeSparse) {
    const auto rlweParams = std::dynamic_pointer_cast<CryptoParametersRLWE<Element>>(cc->GetCryptoParameters());
    const auto ringParams = rlweParams->GetElementParams();

    // Accumulate the joint secret (the `true` flag presumably zero-initialises
    // the element - TODO confirm against the Element constructor).
    Element jointSecret(ringParams, Format::EVALUATION, true);
    for (size_t i = 0; i < privateKeyVec.size(); ++i) {
        jointSecret += privateKeyVec[i]->GetPrivateElement();
    }

    // Sample the uniform part first, then the Gaussian noise.
    DugType uniformGen;
    Element uniformPart(uniformGen, ringParams, Format::EVALUATION);
    Element noise(rlweParams->GetDiscreteGaussianGenerator(), ringParams, Format::EVALUATION);
    NativeInteger noiseScale = rlweParams->GetNoiseScale();

    // maskedPart = noiseScale * noise - uniformPart * jointSecret
    noise *= noiseScale;
    noise -= (uniformPart * jointSecret);
    Element maskedPart(std::move(noise));

    KeyPair<Element> jointKeys(std::make_shared<PublicKeyImpl<Element>>(cc),
                               std::make_shared<PrivateKeyImpl<Element>>(cc));
    jointKeys.secretKey->SetPrivateElement(std::move(jointSecret));
    jointKeys.publicKey->SetPublicElements({std::move(maskedPart), std::move(uniformPart)});
    return jointKeys;
}
// Generates a key pair for a party joining an existing public key.
//
// A fresh secret s is sampled according to the configured secret-key
// distribution, and b = ns * e - a * s is computed using the `a` component of
// the supplied public key. Unless `fresh` is set, the supplied key's b
// component is added so that the result is a joint public key.
//
// @param cc         crypto context the resulting keys are bound to
// @param publicKey  existing public key whose `a` component is reused
// @param makeSparse not referenced in this implementation
// @param fresh      when true, the supplied key's b component is NOT added
// @return key pair with the new secret and the (joint) public key
// @throws via OPENFHE_THROW if the extended (PK) parameters are missing or the
//         secret-key distribution is not recognised
template <class Element>
KeyPair<Element> MultipartyBase<Element>::MultipartyKeyGen(CryptoContext<Element> cc,
                                                           const PublicKey<Element> publicKey, bool makeSparse,
                                                           bool fresh) {
    const auto rlweParams = std::dynamic_pointer_cast<CryptoParametersRLWE<Element>>(cc->GetCryptoParameters());
    const auto ringParams = rlweParams->GetElementParams();
    const auto extParams  = rlweParams->GetParamsPK();
    if (!extParams)
        OPENFHE_THROW("PrecomputeCRTTables() must be called before using precomputed params.");

    const DggType& gaussianGen = rlweParams->GetDiscreteGaussianGenerator();
    TugType ternaryGen;

    // Sample the secret over the extended parameter set.
    Element secret;
    const auto keyDist = rlweParams->GetSecretKeyDist();
    if (keyDist == GAUSSIAN) {
        secret = Element(gaussianGen, extParams, Format::EVALUATION);
    }
    else if (keyDist == UNIFORM_TERNARY) {
        secret = Element(ternaryGen, extParams, Format::EVALUATION);
    }
    else if (keyDist == SPARSE_TERNARY || keyDist == SPARSE_ENCAPSULATED) {
        // NOTE(review): 192 is presumably the Hamming weight of the sparse secret - confirm.
        secret = Element(ternaryGen, extParams, Format::EVALUATION, 192);
    }
    else {
        OPENFHE_THROW("Unknown SecretKeyDist.");
    }

    const auto& inputElems = publicKey->GetPublicElements();
    Element sharedA(inputElems[1]);
    Element noise(gaussianGen, extParams, Format::EVALUATION);
    NativeInteger noiseScale = rlweParams->GetNoiseScale();

    // newB = noiseScale * noise - sharedA * secret
    noise *= noiseScale;
    noise -= (sharedA * secret);
    Element newB(std::move(noise));

    // Fold in the existing b component to obtain a joint key.
    if (!fresh)
        newB += inputElems[0];

    // Trim the secret back to the base modulus chain if the PK chain is longer.
    auto towersQ  = ringParams->GetParams().size();
    auto towersPK = extParams->GetParams().size();
    if (towersPK > towersQ)
        secret.DropLastElements(towersPK - towersQ);

    KeyPair<Element> partyKeys(std::make_shared<PublicKeyImpl<Element>>(cc),
                               std::make_shared<PrivateKeyImpl<Element>>(cc));
    partyKeys.secretKey->SetPrivateElement(std::move(secret));
    partyKeys.publicKey->SetPublicElements({std::move(newB), std::move(sharedA)});
    return partyKeys;
}
// Produces a key-switching key from oldPrivateKey to newPrivateKey by
// forwarding to the KeySwitchGen overload of the scheme attached to
// oldPrivateKey's crypto context, passing evalKey through unchanged.
//
// @param oldPrivateKey key being switched from (also supplies the context)
// @param newPrivateKey key being switched to
// @param evalKey       existing evaluation key handed to KeySwitchGen
// @return whatever the scheme's KeySwitchGen returns
template <class Element>
EvalKey<Element> MultipartyBase<Element>::MultiKeySwitchGen(const PrivateKey<Element> oldPrivateKey,
                                                            const PrivateKey<Element> newPrivateKey,
                                                            const EvalKey<Element> evalKey) const {
    const auto& context = oldPrivateKey->GetCryptoContext();
    const auto& scheme  = context->GetScheme();
    return scheme->KeySwitchGen(oldPrivateKey, newPrivateKey, evalKey);
}
// Generates, for every automorphism index in indexList, this party's
// key-switching key from its secret key to the automorphism-permuted secret.
//
// Each requested index must already be present in evalKeyMap; the entry found
// there is forwarded to MultiKeySwitchGen.
//
// @param privateKey this party's secret key
// @param evalKeyMap existing automorphism keys, looked up by index
// @param indexList  automorphism indices to generate keys for
// @return map from each requested index to the newly generated key
// @throws via OPENFHE_THROW if indexList has more than N - 1 entries or an
//         index is missing from evalKeyMap
template <class Element>
std::shared_ptr<std::map<uint32_t, EvalKey<Element>>> MultipartyBase<Element>::MultiEvalAutomorphismKeyGen(
    const PrivateKey<Element> privateKey, const std::shared_ptr<std::map<uint32_t, EvalKey<Element>>> evalKeyMap,
    const std::vector<uint32_t>& indexList) const {
    const Element& secret = privateKey->GetPrivateElement();
    uint32_t ringDim      = secret.GetRingDimension();
    if (indexList.size() > ringDim - 1)
        OPENFHE_THROW("size exceeds the ring dimension");

    const auto cc  = privateKey->GetCryptoContext();
    auto keyShares = std::make_shared<std::map<uint32_t, EvalKey<Element>>>();

    // Iterations are independent of each other (the original carried a
    // disabled OpenMP pragma here); they are executed sequentially.
    for (const uint32_t autoIndex : indexList) {
        PrivateKey<Element> permutedKey = std::make_shared<PrivateKeyImpl<Element>>(cc);

        // The secret is permuted by the inverse of the index modulo 2N.
        uint32_t inverseIndex = NativeInteger(autoIndex).ModInverse(2 * ringDim).ConvertToInt();
        std::vector<uint32_t> permutation(ringDim);
        PrecomputeAutoMap(ringDim, inverseIndex, &permutation);
        permutedKey->SetPrivateElement(secret.AutomorphismTransform(inverseIndex, permutation));

        // The caller must have supplied a key for this index already.
        auto existing = evalKeyMap->find(autoIndex);
        if (existing == evalKeyMap->end()) {
            OPENFHE_THROW("EvalKey for index [" + std::to_string(autoIndex) + "] is not found.");
        }

        auto share              = MultiKeySwitchGen(privateKey, permutedKey, existing->second);
        (*keyShares)[autoIndex] = std::move(share);
    }
    return keyShares;
}
// Translates each signed index in indexList into an automorphism index and
// delegates to MultiEvalAutomorphismKeyGen.
//
// For CKKS contexts the translation uses FindAutomorphismIndex2nComplex, for
// every other scheme FindAutomorphismIndex2n; both receive the cyclotomic
// order taken from the private key's element parameters.
//
// @param privateKey this party's secret key
// @param evalKeyMap existing keys forwarded to MultiEvalAutomorphismKeyGen
// @param indexList  signed indices to translate
// @return the map produced by MultiEvalAutomorphismKeyGen
template <class Element>
std::shared_ptr<std::map<uint32_t, EvalKey<Element>>> MultipartyBase<Element>::MultiEvalAtIndexKeyGen(
    const PrivateKey<Element> privateKey, const std::shared_ptr<std::map<uint32_t, EvalKey<Element>>> evalKeyMap,
    const std::vector<int32_t>& indexList) const {
    const auto cc       = privateKey->GetCryptoContext();
    uint32_t cyclOrder  = privateKey->GetCryptoParameters()->GetElementParams()->GetCyclotomicOrder();

    std::vector<uint32_t> translated;
    translated.reserve(indexList.size());
    for (const int32_t rotation : indexList) {
        if (isCKKS(cc->getSchemeId())) {
            translated.push_back(FindAutomorphismIndex2nComplex(rotation, cyclOrder));
        }
        else {
            translated.push_back(FindAutomorphismIndex2n(rotation, cyclOrder));
        }
    }
    return MultiEvalAutomorphismKeyGen(privateKey, evalKeyMap, translated);
}
// Builds the list of automorphism indices derived from the batch size and
// delegates to MultiEvalAutomorphismKeyGen for this party's key shares.
//
// With isize = ceil(log2(batchSize)) - 1, the list holds 5, 5^2, 5^4, ...
// (each reduced modulo the cyclotomic order M) for isize entries, followed by
// one final entry: the next power in that sequence when 2 * batchSize < M,
// otherwise M - 1. No indices are generated when batchSize <= 1.
//
// @param privateKey this party's secret key (also supplies the parameters)
// @param evalKeyMap existing keys forwarded to MultiEvalAutomorphismKeyGen
// @return the map produced by MultiEvalAutomorphismKeyGen
template <class Element>
std::shared_ptr<std::map<uint32_t, EvalKey<Element>>> MultipartyBase<Element>::MultiEvalSumKeyGen(
    const PrivateKey<Element> privateKey,
    const std::shared_ptr<std::map<uint32_t, EvalKey<Element>>> evalKeyMap) const {
    const auto cryptoParams = privateKey->GetCryptoParameters();

    uint32_t batchSize = cryptoParams->GetEncodingParams()->GetBatchSize();
    uint32_t M         = cryptoParams->GetElementParams()->GetCyclotomicOrder();

    std::vector<uint32_t> indices;
    if (batchSize > 1) {
        int isize = std::ceil(std::log2(batchSize)) - 1;
        indices.reserve(isize + 1);

        uint32_t g = 5;
        for (int i = 0; i < isize; i++) {
            indices.push_back(g);
            // Square in 64 bits: g can reach M - 1, so a 32-bit product may
            // wrap. The wrapped value only reduces correctly when M is a
            // power of two; the widened product is correct for any M.
            g = static_cast<uint32_t>((static_cast<uint64_t>(g) * g) % M);
        }

        if (2 * batchSize < M)
            indices.push_back(g);
        else
            indices.push_back(M - 1);
    }
    return MultiEvalAutomorphismKeyGen(privateKey, evalKeyMap, indices);
}
// Computes the lead party's partial decryption of a ciphertext.
//
// The result holds the single element c0 + s * c1 + ns * e, where (c0, c1)
// are the first two ciphertext elements, s is the party's secret, ns is the
// noise scale and e is sampled from a discrete Gaussian generator constructed
// with NoiseFlooding::MP_SD.
//
// @param ciphertext ciphertext to partially decrypt
// @param privateKey lead party's secret key
// @return an empty clone of ciphertext carrying the partial decryption
template <class Element>
Ciphertext<Element> MultipartyBase<Element>::MultipartyDecryptLead(ConstCiphertext<Element> ciphertext,
                                                                   const PrivateKey<Element> privateKey) const {
    const auto rlweParams =
        std::dynamic_pointer_cast<CryptoParametersRLWE<Element>>(privateKey->GetCryptoParameters());
    const auto ringParams = rlweParams->GetElementParams();
    const auto noiseScale = rlweParams->GetNoiseScale();

    const std::vector<Element>& ctElems = ciphertext->GetElements();
    const Element& secret               = privateKey->GetPrivateElement();

    // Flooding noise is drawn from a dedicated generator, not the context's.
    DggType floodingGen(NoiseFlooding::MP_SD);
    Element flood(floodingGen, ringParams, Format::EVALUATION);

    Element partial = ctElems[0] + secret * ctElems[1];
    partial += noiseScale * flood;

    auto share = ciphertext->CloneEmpty();
    share->SetElement(std::move(partial));
    return share;
}
// Computes a non-lead party's partial decryption of a ciphertext.
//
// The result holds the single element s * c1 + ns * e, where c1 is the second
// ciphertext element, s is the party's secret, ns is the noise scale and e is
// flooding noise sampled from a discrete Gaussian generator constructed with
// NoiseFlooding::MP_SD. Unlike the lead variant, c0 is not included.
//
// @param ciphertext ciphertext to partially decrypt
// @param privateKey this party's secret key
// @return an empty clone of ciphertext carrying the partial decryption
template <class Element>
Ciphertext<Element> MultipartyBase<Element>::MultipartyDecryptMain(ConstCiphertext<Element> ciphertext,
                                                                   const PrivateKey<Element> privateKey) const {
    const auto rlweParams =
        std::dynamic_pointer_cast<CryptoParametersRLWE<Element>>(privateKey->GetCryptoParameters());
    const auto ringParams = rlweParams->GetElementParams();
    const auto noiseScale = rlweParams->GetNoiseScale();

    const std::vector<Element>& ctElems = ciphertext->GetElements();
    const Element& secret               = privateKey->GetPrivateElement();

    // Flooding noise is drawn from a dedicated generator, not the context's.
    DggType floodingGen(NoiseFlooding::MP_SD);
    Element flood(floodingGen, ringParams, Format::EVALUATION);

    Element partial = secret * ctElems[1];
    partial += noiseScale * flood;

    auto share = ciphertext->CloneEmpty();
    share->SetElement(std::move(partial));
    return share;
}
// Combines the parties' partial decryptions into the final plaintext polynomial.
//
// The first element of every ciphertext in ciphertextVec is summed, the sum is
// switched to coefficient format, converted with ToNativePoly() and written to
// *plaintext.
//
// @param ciphertextVec partial decryptions, one per party; must not be empty
// @param plaintext     output polynomial; must not be null
// @return DecryptResult constructed from the length of the output polynomial
// @throws via OPENFHE_THROW if ciphertextVec is empty or plaintext is null
template <class Element>
DecryptResult MultipartyBase<Element>::MultipartyDecryptFusion(const std::vector<Ciphertext<Element>>& ciphertextVec,
                                                               NativePoly* plaintext) const {
    // Guard the unconditional ciphertextVec[0] / *plaintext accesses below.
    if (ciphertextVec.empty())
        OPENFHE_THROW("MultipartyDecryptFusion: the vector of partial ciphertexts is empty.");
    if (plaintext == nullptr)
        OPENFHE_THROW("MultipartyDecryptFusion: the output plaintext pointer is null.");

    const std::vector<Element>& cv0 = ciphertextVec[0]->GetElements();

    Element b = cv0[0];
    for (size_t i = 1; i < ciphertextVec.size(); i++) {
        const std::vector<Element>& cvi = ciphertextVec[i]->GetElements();
        b += cvi[0];
    }

    b.SetFormat(Format::COEFFICIENT);
    *plaintext = b.ToNativePoly();

    return DecryptResult(plaintext->GetLength());
}
// Adds two public keys that share the same `a` component.
//
// The result is bound to publicKey1's crypto context and holds
// (b1 + b2, a), where b1 and b2 are the first public elements of the two
// inputs and a is the second public element of publicKey1. The `a` component
// of publicKey2 is not read.
//
// @param publicKey1 first key; supplies the context and the `a` component
// @param publicKey2 second key; only its first element is used
// @return the combined public key
template <class Element>
PublicKey<Element> MultipartyBase<Element>::MultiAddPubKeys(PublicKey<Element> publicKey1,
                                                            PublicKey<Element> publicKey2) const {
    const std::vector<Element>& elems1 = publicKey1->GetPublicElements();
    const std::vector<Element>& elems2 = publicKey2->GetPublicElements();

    std::vector<Element> summed;
    summed.reserve(2);
    summed.push_back(elems1[0] + elems2[0]);
    summed.push_back(elems1[1]);

    PublicKey<Element> combined = std::make_shared<PublicKeyImpl<Element>>(publicKey1->GetCryptoContext());
    combined->SetPublicElements(std::move(summed));
    return combined;
}
// Adds two evaluation keys that share the same A vector.
//
// The result is bound to evalKey1's crypto context, copies evalKey1's A vector
// unchanged and stores the element-wise sum of the two B vectors. The number
// of summed entries equals the length of evalKey1's A vector.
//
// @param evalKey1 first key; supplies the context and the A vector
// @param evalKey2 second key; only its B vector is used
// @return the combined evaluation key
template <class Element>
EvalKey<Element> MultipartyBase<Element>::MultiAddEvalKeys(EvalKey<Element> evalKey1, EvalKey<Element> evalKey2) const {
    const auto cc = evalKey1->GetCryptoContext();

    EvalKey<Element> combined = std::make_shared<EvalKeyRelinImpl<Element>>(cc);

    const std::vector<Element>& sharedA = evalKey1->GetAVector();
    const std::vector<Element>& lhsB    = evalKey1->GetBVector();
    const std::vector<Element>& rhsB    = evalKey2->GetBVector();

    const size_t digits = sharedA.size();
    std::vector<Element> summedB;
    summedB.reserve(digits);
    for (size_t k = 0; k < digits; ++k) {
        summedB.push_back(lhsB[k] + rhsB[k]);
    }

    combined->SetAVector(sharedA);
    combined->SetBVector(std::move(summedB));
    return combined;
}
// Adds two evaluation keys component-wise.
//
// The result is bound to evalKey1's crypto context and stores the
// element-wise sums of both the A vectors and the B vectors. The number of
// summed entries equals the length of evalKey1's A vector.
//
// @param evalKey1 first key; supplies the context and the entry count
// @param evalKey2 second key
// @return the combined evaluation key
template <class Element>
EvalKey<Element> MultipartyBase<Element>::MultiAddEvalMultKeys(EvalKey<Element> evalKey1,
                                                               EvalKey<Element> evalKey2) const {
    const auto cc = evalKey1->GetCryptoContext();

    EvalKey<Element> combined = std::make_shared<EvalKeyRelinImpl<Element>>(cc);

    const std::vector<Element>& lhsA = evalKey1->GetAVector();
    const std::vector<Element>& rhsA = evalKey2->GetAVector();
    const std::vector<Element>& lhsB = evalKey1->GetBVector();
    const std::vector<Element>& rhsB = evalKey2->GetBVector();

    const size_t digits = lhsA.size();
    std::vector<Element> summedA;
    summedA.reserve(digits);
    std::vector<Element> summedB;
    summedB.reserve(digits);

    for (size_t k = 0; k < digits; ++k) {
        summedA.push_back(lhsA[k] + rhsA[k]);
        summedB.push_back(lhsB[k] + rhsB[k]);
    }

    combined->SetAVector(std::move(summedA));
    combined->SetBVector(std::move(summedB));
    return combined;
}
// Multiplies every component of an evaluation key by a secret key and adds
// fresh scaled noise.
//
// For each entry i, the result stores a0[i] * s + ns * e and
// b0[i] * s + ns * e', where s is the private element, ns is the noise scale
// and e, e' are independent samples from the context's discrete Gaussian
// generator. The result is bound to evalKey's crypto context.
//
// @param privateKey secret key to multiply by
// @param evalKey    evaluation key whose A and B vectors are transformed
// @return the transformed evaluation key
template <class Element>
EvalKey<Element> MultipartyBase<Element>::MultiMultEvalKey(PrivateKey<Element> privateKey,
                                                           EvalKey<Element> evalKey) const {
    const auto cc         = evalKey->GetCryptoContext();
    const auto rlweParams = std::dynamic_pointer_cast<CryptoParametersRLWE<Element>>(cc->GetCryptoParameters());

    const DggType& gaussianGen = rlweParams->GetDiscreteGaussianGenerator();
    const auto ringParams      = rlweParams->GetElementParams();

    EvalKey<Element> product = std::make_shared<EvalKeyRelinImpl<Element>>(cc);

    const std::vector<Element>& inputA = evalKey->GetAVector();
    const std::vector<Element>& inputB = evalKey->GetBVector();
    const Element& secret              = privateKey->GetPrivateElement();
    const auto noiseScale              = rlweParams->GetNoiseScale();

    const size_t digits = inputA.size();
    std::vector<Element> outA;
    outA.reserve(digits);
    std::vector<Element> outB;
    outB.reserve(digits);

    for (size_t k = 0; k < digits; ++k) {
        // Noise for the A component is drawn before noise for the B component.
        Element noiseA(gaussianGen, ringParams, Format::EVALUATION);
        outA.push_back(inputA[k] * secret + noiseScale * noiseA);

        Element noiseB(gaussianGen, ringParams, Format::EVALUATION);
        outB.push_back(inputB[k] * secret + noiseScale * noiseB);
    }

    product->SetAVector(std::move(outA));
    product->SetBVector(std::move(outB));
    return product;
}
// Combines two maps of automorphism keys index by index.
//
// For every index present in BOTH maps, the two keys are merged with
// MultiAddEvalKeys. Indices that appear in only one of the maps are omitted
// from the result.
//
// @param evalKeyMap1 first map; its entries drive the iteration
// @param evalKeyMap2 second map; looked up by index
// @return newly allocated map holding the merged keys
template <class Element>
std::shared_ptr<std::map<uint32_t, EvalKey<Element>>> MultipartyBase<Element>::MultiAddEvalAutomorphismKeys(
    const std::shared_ptr<std::map<uint32_t, EvalKey<Element>>> evalKeyMap1,
    const std::shared_ptr<std::map<uint32_t, EvalKey<Element>>> evalKeyMap2) const {
    auto merged = std::make_shared<std::map<uint32_t, EvalKey<Element>>>();

    for (const auto& entry : *evalKeyMap1) {
        const auto match = evalKeyMap2->find(entry.first);
        if (match == evalKeyMap2->end())
            continue;  // no counterpart in the second map
        (*merged)[entry.first] = MultiAddEvalKeys(entry.second, match->second);
    }
    return merged;
}
// Combines two maps of summation keys index by index.
//
// For every index present in BOTH maps, the two keys are merged with
// MultiAddEvalKeys. Indices that appear in only one of the maps are omitted
// from the result.
//
// @param evalKeyMap1 first map; its entries drive the iteration
// @param evalKeyMap2 second map; looked up by index
// @return newly allocated map holding the merged keys
template <class Element>
std::shared_ptr<std::map<uint32_t, EvalKey<Element>>> MultipartyBase<Element>::MultiAddEvalSumKeys(
    const std::shared_ptr<std::map<uint32_t, EvalKey<Element>>> evalKeyMap1,
    const std::shared_ptr<std::map<uint32_t, EvalKey<Element>>> evalKeyMap2) const {
    auto merged = std::make_shared<std::map<uint32_t, EvalKey<Element>>>();

    for (const auto& entry : *evalKeyMap1) {
        const auto match = evalKeyMap2->find(entry.first);
        if (match == evalKeyMap2->end())
            continue;  // no counterpart in the second map
        (*merged)[entry.first] = MultiAddEvalKeys(entry.second, match->second);
    }
    return merged;
}
} // namespace lbcrypto
// The code below is from base-multiparty-impl.cpp.
// It follows the member-function template definitions above so that they are
// visible at the point of instantiation.
namespace lbcrypto {
// Explicit instantiation: compiles every MultipartyBase member defined above
// for the DCRTPoly element type into this translation unit.
template class MultipartyBase<DCRTPoly>;
} // namespace lbcrypto