// Program Listing for File keyswitch-bv.cpp
// Return to documentation for file (pke/lib/keyswitch/keyswitch-bv.cpp)
//==================================================================================
// BSD 2-Clause License
//
// Copyright (c) 2014-2022, NJIT, Duality Technologies Inc. and other contributors
//
// All rights reserved.
//
// Author TPOC: contact@openfhe.org
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//==================================================================================
#include "ciphertext.h"
#include "key/evalkeyrelin.h"
#include "key/privatekey.h"
#include "key/publickey.h"
#include "keyswitch/keyswitch-bv.h"
#include "schemerns/rns-cryptoparameters.h"
namespace lbcrypto {
/**
 * Generates a BV key-switching key from secret key `oldKey` to secret key `newKey`.
 *
 * One (a, b) pair is produced per RNS tower i of the old secret. When the configured
 * digit size is nonzero, one pair is produced per digit returned by
 * PowersOfBase(digitSize) for each tower instead. Each pair satisfies
 *     b = P - (a * sNew + ns * e),
 * where:
 *   - P is zero except in tower i, which holds the old-secret tower (or digit);
 *   - `a` is sampled with the DugType generator;
 *   - `e` is sampled from the configured discrete Gaussian generator.
 *
 * @param oldKey secret key being switched away from
 * @param newKey secret key being switched to; supplies the crypto parameters and key tag
 * @return EvalKeyRelinImpl holding the a- and b-vectors, tagged with newKey's key tag
 */
EvalKey<DCRTPoly> KeySwitchBV::KeySwitchGenInternal(const PrivateKey<DCRTPoly> oldKey,
                                                    const PrivateKey<DCRTPoly> newKey) const {
    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersRNS>(newKey->GetCryptoParameters());
    DugType dug;
    // The configured Gaussian generator must reach the OpenMP threads via firstprivate:
    // a `private` clause would default-construct each per-thread copy and silently
    // discard the generator obtained here.
    auto dgg = cryptoParams->GetDiscreteGaussianGenerator();
    const auto ns = cryptoParams->GetNoiseScale();
    const auto& sNew = newKey->GetPrivateElement();
    const auto& ep = sNew.GetParams();
    const auto& sOld = oldKey->GetPrivateElement();
    const uint32_t sizeSOld = sOld.GetNumOfElements();
    std::vector<DCRTPoly> av, bv;
    if (auto digitSize = cryptoParams->GetDigitSize(); digitSize > 0) {
        // arrWindows[i] is the index of the first key component that belongs to tower i.
        // Tower i contributes ceil(MSB(modulus_i) / digitSize) components.
        std::vector<uint32_t> arrWindows(sizeSOld);
        uint32_t nWindows = 0;
        for (uint32_t i = 0; i < sizeSOld; ++i) {
            arrWindows[i] = nWindows;
            double sOldMSB = sOld.GetElementAtIndex(i).GetModulus().GetMSB();
            nWindows += std::ceil(sOldMSB / digitSize);
        }
        av.resize(nWindows);
        bv.resize(nWindows);
#pragma omp parallel for num_threads(OpenFHEParallelControls.GetThreadLimit(sizeSOld)) private(dug) firstprivate(dgg)
        for (uint32_t i = 0; i < sizeSOld; ++i) {
            auto sOldDecomposed = sOld.GetElementAtIndex(i).PowersOfBase(digitSize);
            for (uint32_t j = arrWindows[i], k = 0; k < sOldDecomposed.size(); ++j, ++k) {
                av[j] = DCRTPoly(dug, ep, Format::EVALUATION);
                // Zero polynomial with only tower i set to the k-th digit of the old secret.
                bv[j] = DCRTPoly(ep, Format::EVALUATION, true);
                bv[j].SetElementAtIndex(i, std::move(sOldDecomposed[k]));
                bv[j] -= (av[j] * sNew + DCRTPoly(dgg, ep, Format::EVALUATION) * ns);
            }
        }
    }
    else {
        av.resize(sizeSOld);
        bv.resize(sizeSOld);
#pragma omp parallel for num_threads(OpenFHEParallelControls.GetThreadLimit(sizeSOld)) private(dug) firstprivate(dgg)
        for (uint32_t i = 0; i < sizeSOld; ++i) {
            av[i] = DCRTPoly(dug, ep, Format::EVALUATION);
            // Zero polynomial with only tower i set to tower i of the old secret.
            bv[i] = DCRTPoly(ep, Format::EVALUATION, true);
            bv[i].SetElementAtIndex(i, sOld.GetElementAtIndex(i));
            bv[i] -= (av[i] * sNew + DCRTPoly(dgg, ep, Format::EVALUATION) * ns);
        }
    }
    auto ek = std::make_shared<EvalKeyRelinImpl<DCRTPoly>>(newKey->GetCryptoContext());
    ek->SetAVector(std::move(av));
    ek->SetBVector(std::move(bv));
    ek->SetKeyTag(newKey->GetKeyTag());
    return ek;
}
/**
 * Generates a BV key-switching key from `oldKey` to `newKey`, optionally reusing the
 * a-components of an existing key.
 *
 * This behaves like the two-argument overload, except for two differences:
 *   - when `ek` is non-null, its a-vector entries are reused instead of freshly
 *     sampled a-components;
 *   - the crypto parameters are read from `oldKey`.
 * Each produced pair satisfies b = P - (a * sNew + ns * e), where P is zero except in
 * tower i, which holds the old-secret tower (or digit when the digit size is nonzero).
 *
 * NOTE(review): when `ek` is non-null, its a-vector is assumed to have at least as many
 * entries as this call produces (sizeSOld, or the total digit count). This is not
 * checked here.
 *
 * @param oldKey secret key being switched away from; supplies the crypto parameters
 * @param newKey secret key being switched to; supplies the key tag
 * @param ek     optional existing eval key whose a-vector is reused (may be null)
 * @return EvalKeyRelinImpl holding the a- and b-vectors, tagged with newKey's key tag
 */
EvalKey<DCRTPoly> KeySwitchBV::KeySwitchGenInternal(const PrivateKey<DCRTPoly> oldKey,
                                                    const PrivateKey<DCRTPoly> newKey,
                                                    const EvalKey<DCRTPoly> ek) const {
    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersRNS>(oldKey->GetCryptoParameters());
    DugType dug;
    // Passed to the OpenMP threads via firstprivate so every thread keeps the configured
    // generator; `private` would default-construct the per-thread copies instead.
    auto dgg = cryptoParams->GetDiscreteGaussianGenerator();
    const auto ns = cryptoParams->GetNoiseScale();
    const auto& sNew = newKey->GetPrivateElement();
    const auto& ep = sNew.GetParams();
    const auto& sOld = oldKey->GetPrivateElement();
    const uint32_t sizeSOld = sOld.GetNumOfElements();
    std::vector<DCRTPoly> av, bv;
    if (auto digitSize = cryptoParams->GetDigitSize(); digitSize > 0) {
        // arrWindows[i] is the index of the first key component that belongs to tower i.
        std::vector<uint32_t> arrWindows(sizeSOld);
        uint32_t nWindows = 0;
        for (uint32_t i = 0; i < sizeSOld; ++i) {
            arrWindows[i] = nWindows;
            double sOldMSB = sOld.GetElementAtIndex(i).GetModulus().GetMSB();
            nWindows += std::ceil(sOldMSB / digitSize);
        }
        av.resize(nWindows);
        bv.resize(nWindows);
#pragma omp parallel for num_threads(OpenFHEParallelControls.GetThreadLimit(sizeSOld)) private(dug) firstprivate(dgg)
        for (uint32_t i = 0; i < sizeSOld; ++i) {
            auto sOldDecomposed = sOld.GetElementAtIndex(i).PowersOfBase(digitSize);
            for (uint32_t j = arrWindows[i], k = 0; k < sOldDecomposed.size(); ++j, ++k) {
                // Reuse the supplied key's a-component when available, otherwise sample one.
                av[j] = ek ? ek->GetAVector()[j] : DCRTPoly(dug, ep, Format::EVALUATION);
                bv[j] = DCRTPoly(ep, Format::EVALUATION, true);
                bv[j].SetElementAtIndex(i, std::move(sOldDecomposed[k]));
                bv[j] -= (av[j] * sNew + DCRTPoly(dgg, ep, Format::EVALUATION) * ns);
            }
        }
    }
    else {
        av.resize(sizeSOld);
        bv.resize(sizeSOld);
#pragma omp parallel for num_threads(OpenFHEParallelControls.GetThreadLimit(sizeSOld)) private(dug) firstprivate(dgg)
        for (uint32_t i = 0; i < sizeSOld; ++i) {
            av[i] = ek ? ek->GetAVector()[i] : DCRTPoly(dug, ep, Format::EVALUATION);
            bv[i] = DCRTPoly(ep, Format::EVALUATION, true);
            bv[i].SetElementAtIndex(i, sOld.GetElementAtIndex(i));
            bv[i] -= (av[i] * sNew + DCRTPoly(dgg, ep, Format::EVALUATION) * ns);
        }
    }
    auto evalKey = std::make_shared<EvalKeyRelinImpl<DCRTPoly>>(newKey->GetCryptoContext());
    evalKey->SetAVector(std::move(av));
    evalKey->SetBVector(std::move(bv));
    evalKey->SetKeyTag(newKey->GetKeyTag());
    return evalKey;
}
/**
 * Generates a BV key-switching key from secret key `oldSk` to the holder of public
 * key `newPk` (p0, p1).
 *
 * For each tower i of the old secret (or each PowersOfBase digit when the digit size
 * is nonzero), a mask polynomial u is sampled. u comes from the Gaussian generator when
 * the secret-key distribution is GAUSSIAN, and from TugType otherwise. The pair is:
 *     b = P + ns * e0 + p0 * u
 *     a = p1 * u + ns * e1
 * where P is zero except in tower i, which holds the old-secret tower (or digit).
 *
 * @param oldSk secret key being switched away from
 * @param newPk public key being switched to; supplies the crypto parameters and key tag
 * @return EvalKeyRelinImpl holding the a- and b-vectors, tagged with newPk's key tag
 */
EvalKey<DCRTPoly> KeySwitchBV::KeySwitchGenInternal(const PrivateKey<DCRTPoly> oldSk,
                                                    const PublicKey<DCRTPoly> newPk) const {
    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersRNS>(newPk->GetCryptoParameters());
    TugType tug;
    // Passed to the OpenMP threads via firstprivate so every thread keeps the configured
    // generator; `private` would default-construct the per-thread copies instead.
    auto dgg = cryptoParams->GetDiscreteGaussianGenerator();
    const auto ns = cryptoParams->GetNoiseScale();
    // Loop-invariant: selects the generator used for the mask polynomial u.
    const bool gaussianMask = (cryptoParams->GetSecretKeyDist() == GAUSSIAN);
    const auto& newp0 = newPk->GetPublicElements().at(0);
    const auto& newp1 = newPk->GetPublicElements().at(1);
    const auto& ep = newp0.GetParams();
    const auto& sOld = oldSk->GetPrivateElement();
    const uint32_t sizeSOld = sOld.GetNumOfElements();
    std::vector<DCRTPoly> av, bv;
    if (uint32_t digitSize = cryptoParams->GetDigitSize(); digitSize > 0) {
        // arrWindows[i] is the index of the first key component that belongs to tower i.
        std::vector<uint32_t> arrWindows(sizeSOld);
        uint32_t nWindows = 0;
        for (uint32_t i = 0; i < sizeSOld; ++i) {
            arrWindows[i] = nWindows;
            double sOldMSB = sOld.GetElementAtIndex(i).GetModulus().GetMSB();
            nWindows += std::ceil(sOldMSB / digitSize);
        }
        av.resize(nWindows);
        bv.resize(nWindows);
#pragma omp parallel for num_threads(OpenFHEParallelControls.GetThreadLimit(sizeSOld)) private(tug) firstprivate(dgg)
        for (uint32_t i = 0; i < sizeSOld; ++i) {
            auto sOldDecomposed = sOld.GetElementAtIndex(i).PowersOfBase(digitSize);
            for (uint32_t j = arrWindows[i], k = 0; k < sOldDecomposed.size(); ++j, ++k) {
                bv[j] = DCRTPoly(ep, Format::EVALUATION, true);
                bv[j].SetElementAtIndex(i, std::move(sOldDecomposed[k]));
                bv[j] += DCRTPoly(dgg, ep, Format::EVALUATION) * ns;
                DCRTPoly u = gaussianMask ? DCRTPoly(dgg, ep, Format::EVALUATION) :
                                            DCRTPoly(tug, ep, Format::EVALUATION);
                bv[j] += newp0 * u;
                av[j] = newp1 * u;
                av[j] += DCRTPoly(dgg, ep, Format::EVALUATION) * ns;
            }
        }
    }
    else {
        av.resize(sizeSOld);
        bv.resize(sizeSOld);
#pragma omp parallel for num_threads(OpenFHEParallelControls.GetThreadLimit(sizeSOld)) private(tug) firstprivate(dgg)
        for (uint32_t i = 0; i < sizeSOld; ++i) {
            bv[i] = DCRTPoly(ep, Format::EVALUATION, true);
            bv[i].SetElementAtIndex(i, sOld.GetElementAtIndex(i));
            bv[i] += DCRTPoly(dgg, ep, Format::EVALUATION) * ns;
            DCRTPoly u = gaussianMask ? DCRTPoly(dgg, ep, Format::EVALUATION) :
                                        DCRTPoly(tug, ep, Format::EVALUATION);
            bv[i] += newp0 * u;
            av[i] = newp1 * u;
            av[i] += DCRTPoly(dgg, ep, Format::EVALUATION) * ns;
        }
    }
    auto ek = std::make_shared<EvalKeyRelinImpl<DCRTPoly>>(newPk->GetCryptoContext());
    ek->SetAVector(std::move(av));
    ek->SetBVector(std::move(bv));
    ek->SetKeyTag(newPk->GetKeyTag());
    return ek;
}
/**
 * Key-switches `ciphertext` in place with the BV eval key `ek`.
 *
 * The last ciphertext element is passed through KeySwitchCore, and the two returned
 * polynomials are folded into elements 0 and 1:
 *   - element 0 (in EVALUATION format) always accumulates the first result;
 *   - element 1 accumulates the second result when the ciphertext had more than two
 *     elements, and is replaced by it otherwise.
 * Finally the ciphertext is truncated to two elements.
 */
void KeySwitchBV::KeySwitchInPlace(Ciphertext<DCRTPoly>& ciphertext, const EvalKey<DCRTPoly> ek) const {
    auto& elements = ciphertext->GetElements();
    const auto switched = KeySwitchCore(elements.back(), ek);
    const auto& parts = *switched;

    elements[0].SetFormat(Format::EVALUATION);
    elements[0] += parts[0];

    const bool hasExtraElements = elements.size() > 2;
    if (!hasExtraElements) {
        // Two-element ciphertext: the switched part fully replaces element 1.
        elements[1] = parts[1];
    }
    else {
        elements[1].SetFormat(Format::EVALUATION);
        elements[1] += parts[1];
    }

    elements.resize(2);
}
/**
 * Applies the BV key-switching core to polynomial `a`.
 *
 * `a` is first decomposed with EvalKeySwitchPrecomputeCore, using the eval key's
 * crypto parameters. The digits are then combined with the key in
 * EvalFastKeySwitchCore, using `a`'s parameters as the target modulus chain.
 *
 * @return two-element vector produced by EvalFastKeySwitchCore
 */
std::shared_ptr<std::vector<DCRTPoly>> KeySwitchBV::KeySwitchCore(const DCRTPoly& a,
                                                                  const EvalKey<DCRTPoly> evalKey) const {
    const auto digits = EvalKeySwitchPrecomputeCore(a, evalKey->GetCryptoParameters());
    const auto targetParams = a.GetParams();
    return EvalFastKeySwitchCore(digits, evalKey, targetParams);
}
std::shared_ptr<std::vector<DCRTPoly>> KeySwitchBV::EvalKeySwitchPrecomputeCore(
const DCRTPoly& c, std::shared_ptr<CryptoParametersBase<DCRTPoly>> cryptoParamsBase) const {
const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersRNS>(cryptoParamsBase);
return std::make_shared<std::vector<DCRTPoly>>(c.CRTDecompose(cryptoParams->GetDigitSize()));
}
std::shared_ptr<std::vector<DCRTPoly>> KeySwitchBV::EvalFastKeySwitchCore(
const std::shared_ptr<std::vector<DCRTPoly>> digits, const EvalKey<DCRTPoly> evalKey,
const std::shared_ptr<ParmType> paramsQl) const {
std::vector<DCRTPoly> bv(evalKey->GetBVector());
std::vector<DCRTPoly> av(evalKey->GetAVector());
const auto diffQl = bv[0].GetParams()->GetParams().size() - paramsQl->GetParams().size();
const uint32_t limit = (*digits).size();
#pragma omp parallel for num_threads(OpenFHEParallelControls.GetThreadLimit(limit))
for (uint32_t i = 0; i < limit; ++i) {
bv[i].DropLastElements(diffQl);
bv[i] *= (*digits)[i];
av[i].DropLastElements(diffQl);
av[i] *= (*digits)[i];
}
std::vector<DCRTPoly> res{std::move(bv[0]), std::move(av[0])};
for (uint32_t i = 1; i < limit; ++i) {
res[0] += bv[i];
res[1] += av[i];
}
return std::make_shared<std::vector<DCRTPoly>>(std::move(res));
}
} // namespace lbcrypto