Program Listing for File ckksrns-schemeswitching.cpp
↰ Return to documentation for file (pke/lib/scheme/ckksrns/ckksrns-schemeswitching.cpp)
//==================================================================================
// BSD 2-Clause License
//
// Copyright (c) 2014-2022, NJIT, Duality Technologies Inc. and other contributors
//
// All rights reserved.
//
// Author TPOC: contact@openfhe.org
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//==================================================================================
/*
CKKS to FHEW scheme switching implementation.
*/
#define PROFILE
#include "cryptocontext.h"
#include "gen-cryptocontext.h"
#include "math/dftransform.h"
#include "scheme/ckksrns/ckksrns-fhe.h"
#include "scheme/ckksrns/ckksrns-schemeswitching.h"
#include "scheme/ckksrns/gen-cryptocontext-ckksrns.h"
#include <algorithm>
#include <cmath>
#include <iterator>
#include <limits>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
// K = 16
// EvalChebyshevCoefficients([](double x) -> double {return std::pow(2.*M_PI, -1./8.) * std::cos(2.*M_PI/8. * (x - 0.25));}, -16, 16, 117)
// Chebyshev series coefficients (118 values, i.e. degree 117) of
//   f(x) = (2*pi)^(-1/8) * cos(2*pi/8 * (x - 0.25)),
// a function with period 8, approximated on the interval [-K, K] = [-16, 16].
// Generated data: regenerate with the EvalChebyshevCoefficients call above instead of editing values by hand.
static constexpr std::initializer_list<double> g_coefficientsFHEW16{
0.2455457340168511, -0.04791906488334782, 0.2838870204084082, -0.02994453873551349,
0.3557652261903648, 0.01510656188507299, 0.2953294667450001, 0.07120360233373937,
-0.1034734733966807, 0.04499759051255525, -0.4275071243192574, -0.09034212972909554,
0.367628762693249, 0.04931806603933471, -0.14535986272412, -0.01510693848306369,
0.03595193549924024, 0.003103658218868759, -0.006264460660707066, -0.0004660943047712052,
0.0008212879885240095, 5.391053389217882e-05, -8.455154976914221e-05, -4.977380178602518e-06,
7.04666204400644e-06, 3.765980757166835e-07, -4.864851013612591e-07, -2.383026746930811e-08,
2.832970640938316e-08, 1.28177429687131e-09, -1.412145521045378e-09, -5.939145408994641e-11,
6.099273252183116e-11, 2.397381728164642e-12, -2.307402856353623e-12, -8.500921247536622e-14,
7.704571444110577e-14, 2.704051671841271e-15, -2.154585361348821e-15, -7.263493008564584e-16,
-1.260761739828568e-16, 1.637108527837095e-16, 1.185492382226862e-16, 6.379078056744543e-16,
-1.411300455031979e-16, 3.123678340470779e-16, 5.946279250534737e-16, 2.954322285866942e-16,
-8.279629336187608e-17, 5.024229619913844e-16, -3.293034395074617e-16, -1.189255850106947e-15,
1.674743206637948e-16, -1.524204491434537e-16, -7.90328254817908e-17, 3.95164127408954e-16,
-1.317213758029847e-17, 8.016186584581639e-16, -3.650563843682718e-16, 3.763467880085276e-16,
-2.709696873661399e-16, -1.524204491434537e-16, -4.83605622590958e-16, 7.376397044967142e-16,
1.234417464667971e-15, -2.672062194860546e-16, -4.892508244110859e-17, -7.122362963061386e-16,
3.763467880085276e-17, -2.944913616166729e-16, -2.897870267665663e-16, 9.794425157921932e-16,
-3.198947698072485e-17, 6.614294799249873e-16, -5.7769231959309e-16, 6.586068790149234e-16,
-4.629065492504889e-16, -5.127724986616189e-16, -3.236582376873337e-16, -1.64745806450733e-15,
-9.408669700213192e-16, -4.986594941112991e-16, -1.209954923447416e-15, -1.373665776231126e-16,
-2.314532746252445e-16, 3.217765037472911e-16, 3.481207789078881e-16, 8.223177317986329e-16,
-9.766199148821293e-16, 6.19090466274028e-16, 1.209014056477395e-15, -3.30244306477483e-16,
5.974505259635377e-16, 5.993322599035803e-16, 1.829986256691466e-16, -2.690879534260973e-16,
8.618341445395283e-16, -1.002023323072705e-16, 6.374373721894436e-16, 6.270878355192092e-16,
1.199605386777182e-15, -8.712428142397415e-16, -2.507410475106815e-16, -1.086230916889613e-15,
1.072588345824304e-15, -4.534978795502758e-16, 2.119067633230516e-15, -1.842923177529259e-15,
-1.814697168428619e-15, 4.243310034796149e-16, 4.224492695395723e-16, 1.531966643937213e-15,
-2.850826919164597e-16, -8.958229638315484e-16, -5.02893395476395e-16, 1.096110020074837e-16,
-6.975352498995555e-16, -8.743006318923108e-16};
// K = 128
// EvalChebyshevCoefficients([](double x) -> double {return std::pow(2.*M_PI, -1./8.) * std::cos(2.*M_PI/8. * (x - 0.25));}, -128, 128, 159)
// Chebyshev series coefficients (160 values, i.e. degree 159) of the same period-8 function
//   f(x) = (2*pi)^(-1/8) * cos(2*pi/8 * (x - 0.25)),
// approximated on the wider interval [-K, K] = [-128, 128].
// NOTE(review): the meaning of the `_9` suffix is not established in this file — confirm before relying on it.
// Generated data: regenerate with the EvalChebyshevCoefficients call above instead of editing values by hand.
static constexpr std::initializer_list<double> g_coefficientsFHEW128_9{
0.08761193238226354, -0.01738402917379392, 0.08935060894767313, -0.01667686631436392,
0.09435445639097996, -0.01518333497826596, 0.1019473189108075, -0.01276275748916528,
0.110882655474149, -0.009252446966171999, 0.1192111685574758, -0.004534979909938953,
0.1242004317120066, 0.001362904847617233, 0.1224283765086551, 0.008145596233693092,
0.1102080588183085, 0.01512350467093644, 0.08449405378412403, 0.02114203679334985,
0.04431786059830115, 0.02464956129638114, -0.007454366487154669, 0.02400059020366966,
-0.06266441339261235, 0.0180491215413637, -0.1077943201829795, 0.00695836538813938,
-0.1265848641500751, -0.007067567033131986, -0.1060856934163377, -0.01966175019277508,
-0.04512467324356773, -0.02537595733026167, 0.03862916785371963, -0.0201785566296389,
0.1092652333753526, -0.004612578019766411, 0.1263344585514989, 0.01438496124842956,
0.07022427857484209, 0.02550072245548077, -0.03434514153678107, 0.01979242584243335,
-0.1194659697149694, -0.001008794768691528, -0.1149256786653964, -0.02192904329965044,
-0.01184295110147364, -0.02417858011117596, 0.1066507410103885, -0.003076473516323838,
0.122343225763269, 0.02209885820126707, 0.005200840409852563, 0.02321022960558625,
-0.1224755172356864, -0.003930982569218595, -0.1000653894904632, -0.02689795846413602,
0.05865754664309684, -0.01297065380451242, 0.1377909895596227, 0.02083617539534925,
0.006502421233003679, 0.02248299870285675, -0.139660074659475, -0.01399307934458518,
-0.04589168496663835, -0.0263421662574377, 0.1358978738303921, 0.01130242907664306,
0.05563799538901031, 0.02715486116995984, -0.1426236952996719, -0.01461041285557406,
-0.03302834981489188, -0.02454368648125577, 0.1559877850928394, 0.02360418859443232,
-0.03051465817859748, 0.01394389273916019, -0.1434779685133351, -0.0326137520114734,
0.1272587840850199, 0.00968806150092634, 0.04489729856072615, 0.02496761251245225,
-0.1723551233719199, -0.03505277577503257, 0.1396636892583818, 0.01468861799711852,
0.00597622458952793, 0.01686435635501478, -0.1508869780062401, -0.0392626068463298,
0.2221665014327329, 0.04513725581939847, -0.2157338005707834, -0.03852627732394119,
0.1657363840292956, 0.02705951812948022, -0.1076077703571204, -0.01637507278621027,
0.06108202920758021, 0.00876339054415577, -0.03094805600072437, -0.004218297546715713,
0.01419483272929196, 0.001848310901625205, -0.005954927442783235, -0.000743844834357433,
0.002303049930851211, 0.000276890872388833, -0.0008263094529170254, -9.587788269377866e-05,
0.000276459761481133, 3.10280277328683e-05, -8.662530848949058e-05, -9.421991495095434e-06,
2.551403977799462e-05, 2.693817139509806e-06, -7.08627853168575e-06, -7.273160333789605e-07,
1.861116908629967e-06, 1.859288982562724e-07, -4.633601616493963e-07, -4.510790623761216e-08,
1.096004151863654e-07, 1.040757606442787e-08, -2.467843591423806e-08, -2.288015736340782e-09,
5.299290810294302e-09, 4.800959747999802e-10, -1.086991796689273e-09, -9.63033831430537e-11,
2.133040952766737e-10, 1.849336626863696e-11, -4.010173216641153e-11, -3.404204559101731e-12,
7.232799297263385e-12, 6.027665442316686e-13, -1.252868789531569e-12, -1.035300267428826e-13,
2.072453780944445e-13, 1.81572061755555e-14, -3.012503176280137e-14, -4.417490972407089e-16,
3.698522891563647e-15, -4.204154533937635e-16, -2.740777660720187e-15, -1.348919106364917e-15,
-1.620799477984723e-15, 4.003965342611375e-16, -5.245330582249314e-16, 1.754761547401069e-15,
-5.0481471966847e-16, -4.722624632690369e-16, 1.628901569091919e-16, -1.219903204684612e-15};
// EvalChebyshevCoefficients([](double x) -> double {return std::pow(2.*M_PI, -1./8.) * std::cos(2.*M_PI/8. * (x - 0.25));}, -128, 128, 118)
// Lower-degree variant of g_coefficientsFHEW128_9: Chebyshev series coefficients (119 values, i.e. degree 118) of
//   f(x) = (2*pi)^(-1/8) * cos(2*pi/8 * (x - 0.25)) on [-128, 128].
// NOTE(review): the meaning of the `_8` suffix is not established in this file — confirm before relying on it.
// Generated data: regenerate with the EvalChebyshevCoefficients call above instead of editing values by hand.
static constexpr std::initializer_list<double> g_coefficientsFHEW128_8{
0.08761193238226343, -0.01738402917379268, 0.08935060894767202, -0.0166768663143651, 0.09435445639098095,
-0.01518333497826714, 0.1019473189108076, -0.01276275748916462, 0.1108826554741475, -0.009252446966171845,
0.1192111685574773, -0.004534979909938402, 0.1242004317120066, 0.001362904847616587, 0.1224283765086535,
0.008145596233693802, 0.1102080588183083, 0.0151235046709367, 0.08449405378412395, 0.02114203679334948,
0.04431786059830203, 0.02464956129638117, -0.007454366487155707, 0.02400059020367158, -0.06266441339261287,
0.01804912154136392, -0.107794320182978, 0.006958365388138488, -0.1265848641500738, -0.007067567033133184,
-0.1060856934163389, -0.01966175019277399, -0.0451246732435682, -0.02537595733026211, 0.0386291678537217,
-0.02017855662963969, 0.1092652333753532, -0.004612578019767425, 0.1263344585514991, 0.01438496124843117,
0.07022427857484087, 0.02550072245548053, -0.03434514153678073, 0.01979242584243296, -0.1194659697149702,
-0.00100879476868968, -0.1149256786653952, -0.02192904329965062, -0.01184295110147335, -0.02417858011117619,
0.1066507410103884, -0.003076473516322021, 0.1223432257632692, 0.02209885820126752, 0.005200840409853516,
0.02321022960558683, -0.1224755172356849, -0.003930982569218244, -0.1000653894904628, -0.02689795846413568,
0.05865754664309823, -0.01297065380451253, 0.1377909895596233, 0.02083617539534807, 0.006502421233004046,
0.02248299870285591, -0.1396600746594754, -0.01399307934458444, -0.04589168496663817, -0.02634216625743798,
0.1358978738303917, 0.0113024290766429, 0.05563799538901171, 0.02715486116995986, -0.1426236952996744,
-0.01461041285557423, -0.03302834981489241, -0.02454368648125667, 0.155987785092838, 0.02360418859443058,
-0.03051465817859778, 0.01394389273915945, -0.1434779685133346, -0.03261375201147241, 0.1272587840850196,
0.009688061500927738, 0.04489729856072736, 0.02496761251245433, -0.1723551233719191, -0.03505277577503064,
0.1396636892583768, 0.01468861799712161, 0.005976224589562133, 0.01686435635499993, -0.1508869780064481,
-0.03926260684622985, 0.2221665014339838, 0.04513725581879824, -0.2157338005780147, -0.03852627732053739,
0.1657363840693943, 0.02705951811098469, -0.1076077705704247, -0.0163750726899083, 0.06108203029457551,
0.008763390064061401, -0.03094806130001855, -0.00421829525869962, 0.01419485740772663, 0.001848300494047458,
-0.005955037043195616, -0.000743799726455706, 0.002303513291010024, 0.0002767049434914361, -0.0008281705698254814,
-9.515056665793518e-05, 0.0002835460400168608, 2.833421059257267e-05, -0.0001121393482639905};
namespace lbcrypto {
//------------------------------------------------------------------------------
// Key and modulus switch and extraction methods
//------------------------------------------------------------------------------
// Rescales v from modulus Q to modulus q: returns round(v * q / Q) mod q, rounding halves up.
// The product and quotient are computed in double precision, so the result is exact only while
// v * q / Q is representable with 53 bits of mantissa.
static NativeInteger RoundqQAlter(NativeInteger v, NativeInteger q, NativeInteger Q) {
    const double scaled     = v.ConvertToDouble() * q.ConvertToDouble() / Q.ConvertToDouble();
    const auto roundedValue = static_cast<BasicInteger>(std::floor(scaled + 0.5));
    const NativeInteger rescaled(roundedValue);
    return rescaled.Mod(q);
}
// TODO: used anywhere?
// Builds a key-switching key from the CKKS secret key ckksSK to an RLWE secret key whose first n coefficients
// (n = length of the LWE secret key) hold the LWE secret key LWEsk and whose remaining coefficients are zero.
// LWE key entries are treated as ternary: 0 -> 0, 1 -> 1, any other value -> (tower modulus - 1).
// This variant performs no intermediate ModSwitch.
EvalKey<DCRTPoly> switchingKeyGenRLWE(const PrivateKey<DCRTPoly>& ckksSK, ConstLWEPrivateKey& LWEsk) {
    auto embeddedKey      = ckksSK->GetPrivateElement();
    const auto& lweCoeffs = LWEsk->GetElement();
    const size_t lweDim   = lweCoeffs.GetLength();

    for (size_t tower = 0; tower < embeddedKey.GetNumOfElements(); ++tower) {
        auto limb = embeddedKey.GetElementAtIndex(tower);
        limb.SetFormat(Format::COEFFICIENT);
        const auto minusOne = limb.GetModulus() - 1;

        for (size_t k = 0; k < limb.GetLength(); ++k) {
            // Coefficients beyond the LWE dimension are zeroed.
            if (k >= lweDim || lweCoeffs[k] == 0)
                limb[k] = 0;
            else if (lweCoeffs[k].ConvertToInt() == 1)
                limb[k] = 1;
            else
                limb[k] = minusOne;
        }

        limb.SetFormat(Format::EVALUATION);
        embeddedKey.SetElementAtIndex(tower, std::move(limb));
    }
    embeddedKey.OverrideFormat(Format::EVALUATION);

    // A freshly generated key is used only as a container for the embedded coefficients.
    auto cc       = ckksSK->GetCryptoContext();
    auto targetSK = cc->KeyGen().secretKey;
    targetSK->SetPrivateElement(std::move(embeddedKey));
    return cc->KeySwitchGen(ckksSK, targetSK);
}
// Replaces the elements of ctxtKS with the elements of ctxt mapped into ctxtKS's (single-tower) parameters via
// SetValuesModSwitch(..., modulus_CKKS_to); the new elements are stored in EVALUATION format.
// Throws if the two ciphertexts have different ring dimensions or if either has more than one RNS tower.
void ModSwitch(ConstCiphertext<DCRTPoly> ctxt, Ciphertext<DCRTPoly>& ctxtKS, NativeInteger modulus_CKKS_to) {
    const auto& sourceFirst = ctxt->GetElements()[0];
    const auto& targetFirst = ctxtKS->GetElements()[0];
    if (sourceFirst.GetRingDimension() != targetFirst.GetRingDimension())
        OPENFHE_THROW("ModSwitch is implemented only for the same ring dimension.");
    if (sourceFirst.GetNumOfElements() != 1 || targetFirst.GetNumOfElements() != 1)
        OPENFHE_THROW("ModSwitch is implemented only for ciphertext with one tower.");

    const auto& targetParams = targetFirst.GetParams();
    const auto& sources      = ctxt->GetElements();

    std::vector<DCRTPoly> switched;
    switched.reserve(sources.size());
    for (const auto& source : sources) {
        DCRTPoly target(targetParams, Format::COEFFICIENT, true);
        target.SetValuesModSwitch(source, modulus_CKKS_to);
        target.SetFormat(Format::EVALUATION);
        switched.push_back(std::move(target));
    }
    // targetParams is not used past this point, so replacing ctxtKS's elements is safe.
    ctxtKS->SetElements(std::move(switched));
}
// TODO: used anywhere?
// Re-encodes the secret of ckksSKfrom in the RNS towers of ckksSKto (coefficients treated as ternary:
// 0 -> 0, 1 -> 1, any other value -> tower modulus - 1) and returns the key-switching key from that
// re-encoded secret to ckksSKto.
EvalKey<DCRTPoly> switchingKeyGen(const PrivateKey<DCRTPoly>& ckksSKto, const PrivateKey<DCRTPoly>& ckksSKfrom) {
    auto reencoded = ckksSKto->GetPrivateElement();
    reencoded.SetFormat(Format::COEFFICIENT);
    auto sourceKey = ckksSKfrom->GetPrivateElement();
    sourceKey.SetFormat(Format::COEFFICIENT);

    const size_t towerCount = reencoded.GetNumOfElements();
    for (size_t t = 0; t < towerCount; ++t) {
        auto targetLimb     = reencoded.GetElementAtIndex(t);
        auto sourceLimb     = sourceKey.GetElementAtIndex(t);
        const auto minusOne = targetLimb.GetModulus() - 1;

        for (size_t k = 0; k < targetLimb.GetLength(); ++k) {
            if (sourceLimb[k] == 0)
                targetLimb[k] = 0;
            else if (sourceLimb[k] == 1)
                targetLimb[k] = 1;
            else
                targetLimb[k] = minusOne;
        }
        reencoded.SetElementAtIndex(t, std::move(targetLimb));
    }
    reencoded.SetFormat(Format::EVALUATION);

    // A freshly generated key is used only as a container for the re-encoded secret.
    auto cc        = ckksSKto->GetCryptoContext();
    auto sourceSK  = cc->KeyGen().secretKey;
    sourceSK->SetPrivateElement(std::move(reencoded));
    return cc->KeySwitchGen(sourceSK, ckksSKto);
}
// Returns the key-switching key between two secrets, both expressed in the RNS towers of ckksSKto:
//  - source: the secret of ckksSKfrom, re-encoded coefficient-wise;
//  - target: an RLWE secret whose first n coefficients (n = length of LWEsk) are the LWE secret key and whose
//    remaining coefficients are zero.
// Coefficients of both inputs are treated as ternary: 0 -> 0, 1 -> 1, any other value -> tower modulus - 1.
EvalKey<DCRTPoly> switchingKeyGenRLWEcc(const PrivateKey<DCRTPoly>& ckksSKto, const PrivateKey<DCRTPoly>& ckksSKfrom,
                                        ConstLWEPrivateKey& LWEsk) {
    auto sourceElems = ckksSKto->GetPrivateElement();
    sourceElems.SetFormat(Format::COEFFICIENT);
    auto fromElems = ckksSKfrom->GetPrivateElement();
    fromElems.SetFormat(Format::COEFFICIENT);
    auto targetElems = ckksSKto->GetPrivateElement();
    targetElems.SetFormat(Format::COEFFICIENT);

    const auto& lweCoeffs = LWEsk->GetElement();
    const uint32_t lweDim = lweCoeffs.GetLength();

    for (uint32_t t = 0; t < sourceElems.GetNumOfElements(); ++t) {
        auto sourceLimb = sourceElems.GetElementAtIndex(t);
        auto fromLimb   = fromElems.GetElementAtIndex(t);
        auto targetLimb = targetElems.GetElementAtIndex(t);
        // sourceElems and targetElems are copies of the same key, so tower t has one modulus for both.
        const auto minusOne = sourceLimb.GetModulus() - 1;

        for (uint32_t k = 0; k < sourceLimb.GetLength(); ++k) {
            if (fromLimb[k] == 0)
                sourceLimb[k] = 0;
            else if (fromLimb[k] == 1)
                sourceLimb[k] = 1;
            else
                sourceLimb[k] = minusOne;

            // Coefficients beyond the LWE dimension are zeroed.
            if (k >= lweDim || lweCoeffs[k] == 0)
                targetLimb[k] = 0;
            else if (lweCoeffs[k].ConvertToInt() == 1)
                targetLimb[k] = 1;
            else
                targetLimb[k] = minusOne;
        }
        sourceElems.SetElementAtIndex(t, std::move(sourceLimb));
        targetElems.SetElementAtIndex(t, std::move(targetLimb));
    }
    sourceElems.SetFormat(Format::EVALUATION);
    targetElems.SetFormat(Format::EVALUATION);

    // Freshly generated keys are used only as containers for the constructed secrets.
    auto cc       = ckksSKto->GetCryptoContext();
    auto sourceSK = cc->KeyGen().secretKey;
    sourceSK->SetPrivateElement(std::move(sourceElems));
    auto targetSK = cc->KeyGen().secretKey;
    targetSK->SetPrivateElement(std::move(targetElems));
    return cc->KeySwitchGen(sourceSK, targetSK);
}
std::vector<std::vector<NativeInteger>> ExtractLWEpacked(const Ciphertext<DCRTPoly>& ct) {
auto originalA{(ct->GetElements()[1]).GetElementAtIndex(0)};
originalA.SetFormat(Format::COEFFICIENT);
auto* ptrA = &originalA.GetValues()[0];
auto originalB{(ct->GetElements()[0]).GetElementAtIndex(0)};
originalB.SetFormat(Format::COEFFICIENT);
auto* ptrB = &originalB.GetValues()[0];
size_t N = originalB.GetLength();
return {std::vector<NativeInteger>(ptrB, ptrB + N), std::vector<NativeInteger>(ptrA, ptrA + N)};
}
std::shared_ptr<LWECiphertextImpl> ExtractLWECiphertext(const std::vector<std::vector<NativeInteger>>& aANDb,
NativeInteger modulus, uint32_t n, uint32_t index = 0) {
NativeVector a(n, modulus);
for (uint32_t i = 0; i < n && i <= index; ++i)
a[i] = modulus - aANDb[1][index - i];
if (n > index) {
uint32_t N = aANDb[0].size();
for (uint32_t i = index + 1; i < n; ++i)
a[i] = aANDb[1][N + index - i];
}
return std::make_shared<LWECiphertextImpl>(std::move(a), aANDb[0][index]);
}
//------------------------------------------------------------------------------
// Linear transformation methods.
//------------------------------------------------------------------------------
std::vector<ReadOnlyPlaintext> SWITCHCKKSRNS::EvalLTPrecomputeSwitch(
const CryptoContextImpl<DCRTPoly>& cc, const std::vector<std::vector<std::complex<double>>>& A,
const std::vector<std::vector<std::complex<double>>>& B, uint32_t dim1, uint32_t L, double scale = 1.0) const {
const auto cryptoParamsCKKS = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(cc.GetCryptoParameters());
auto elementParams = *(cryptoParamsCKKS->GetElementParams());
uint32_t towersToDrop = 0;
if (L != 0) {
towersToDrop = elementParams.GetParams().size() - L - 1;
for (uint32_t i = 0; i < towersToDrop; ++i)
elementParams.PopLastParam();
}
const auto& paramsQ = elementParams.GetParams();
const auto& paramsP = cryptoParamsCKKS->GetParamsP()->GetParams();
size_t sizeQP = paramsQ.size() + paramsP.size();
std::vector<NativeInteger> moduli;
moduli.reserve(sizeQP);
std::vector<NativeInteger> roots;
roots.reserve(sizeQP);
for (const auto& elem : paramsQ) {
moduli.emplace_back(elem->GetModulus());
roots.emplace_back(elem->GetRootOfUnity());
}
for (const auto& elem : paramsP) {
moduli.emplace_back(elem->GetModulus());
roots.emplace_back(elem->GetRootOfUnity());
}
auto elementParamsPtr = std::make_shared<ILDCRTParams<DCRTPoly::Integer>>(cc.GetCyclotomicOrder(), moduli, roots);
int32_t slots = A.size();
int32_t step = (dim1 == 0) ? getRatioBSGSLT(slots) : dim1;
std::vector<std::vector<std::complex<double>>> newA(slots);
// A and B are concatenated horizontally
for (uint32_t i = 0; i < A.size(); i++) {
newA[i].reserve(A[i].size() + B[i].size());
newA[i].insert(newA[i].end(), A[i].begin(), A[i].end());
newA[i].insert(newA[i].end(), B[i].begin(), B[i].end());
}
uint32_t M4 = cc.GetCyclotomicOrder() / 4;
std::vector<ReadOnlyPlaintext> result(slots);
#if !defined(__MINGW32__) && !defined(__MINGW64__)
#pragma omp parallel for num_threads(OpenFHEParallelControls.GetThreadLimit(slots))
#endif
for (int32_t ji = 0; ji < slots; ++ji) {
auto vec = ExtractShiftedDiagonal(newA, ji);
for (auto& v : vec)
v *= scale;
result[ji] = FHECKKSRNS::MakeAuxPlaintext(cc, elementParamsPtr, Rotate(Fill(vec, M4), -step * (ji / step)), 1,
towersToDrop, M4);
}
return result;
}
// Precomputes one auxiliary plaintext per shifted diagonal of the square matrix A, each entry multiplied by
// `scale` and pre-rotated by -bStep * floor(d / bStep) for baby-step/giant-step evaluation
// (bStep = dim1, or getRatioBSGSLT(#rows) if dim1 == 0).
// Plaintexts carry only the Q towers up to level L (L == 0 keeps all towers), followed by the P towers.
// Throws if A is not square.
std::vector<ReadOnlyPlaintext> SWITCHCKKSRNS::EvalLTPrecomputeSwitch(
    const CryptoContextImpl<DCRTPoly>& cc, const std::vector<std::vector<std::complex<double>>>& A, uint32_t dim1,
    uint32_t L, double scale = 1.0) const {
    if (A[0].size() != A.size())
        OPENFHE_THROW("The matrix passed to EvalLTPrecomputeSwitch is not square");

    // Keep only the moduli the plaintexts need.
    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(cc.GetCryptoParameters());
    auto levelParams        = *(cryptoParams->GetElementParams());
    const uint32_t towersToDrop =
        (L == 0) ? 0 : static_cast<uint32_t>(levelParams.GetParams().size() - L - 1);
    for (uint32_t t = 0; t < towersToDrop; ++t)
        levelParams.PopLastParam();

    const auto& towersQ = levelParams.GetParams();
    const auto& towersP = cryptoParams->GetParamsP()->GetParams();
    std::vector<NativeInteger> moduli;
    std::vector<NativeInteger> roots;
    moduli.reserve(towersQ.size() + towersP.size());
    roots.reserve(towersQ.size() + towersP.size());
    auto appendTowers = [&moduli, &roots](const auto& towers) {
        for (const auto& tower : towers) {
            moduli.emplace_back(tower->GetModulus());
            roots.emplace_back(tower->GetRootOfUnity());
        }
    };
    appendTowers(towersQ);
    appendTowers(towersP);
    auto extendedParams = std::make_shared<ILDCRTParams<DCRTPoly::Integer>>(cc.GetCyclotomicOrder(), moduli, roots);

    const int32_t slots = A.size();
    const int32_t bStep = (dim1 == 0) ? getRatioBSGSLT(slots) : dim1;
    const uint32_t M4   = cc.GetCyclotomicOrder() / 4;

    std::vector<ReadOnlyPlaintext> result(slots);
#if !defined(__MINGW32__) && !defined(__MINGW64__)
#pragma omp parallel for num_threads(OpenFHEParallelControls.GetThreadLimit(slots))
#endif
    for (int32_t d = 0; d < slots; ++d) {
        auto diagonal = ExtractShiftedDiagonal(A, d);
        for (auto& entry : diagonal)
            entry *= scale;
        // Undo in advance the giant-step rotation applied when this diagonal is consumed.
        const int32_t giantOffset = -bStep * (d / bStep);
        result[d] = FHECKKSRNS::MakeAuxPlaintext(cc, extendedParams, Rotate(Fill(diagonal, M4), giantOffset), 1,
                                                 towersToDrop, M4);
    }
    return result;
}
// Precomputes the scaled generalized diagonals of the rectangular matrix A, whose dimensions must both be powers
// of two. Returns n = min(#rows, #cols) diagonals, each entry multiplied by `scale`:
//  - tall A (#rows >= #cols): A is cut into #rows / #cols square (#cols x #cols) blocks stacked vertically, and
//    diagonal d is the concatenation of ExtractShiftedDiagonal(block, d) over all blocks;
//  - wide A (#rows < #cols): diagonal d is ExtractShiftedDiagonal(A, d).
// dim1 does not influence the result (every diagonal is always computed); it is kept for interface compatibility.
// Throws if a dimension of A is not a power of two.
std::vector<std::vector<std::complex<double>>> EvalLTRectPrecomputeSwitch(
    const std::vector<std::vector<std::complex<double>>>& A, uint32_t dim1, double scale) {
    if (!IsPowerOfTwo(A.size()) || !IsPowerOfTwo(A[0].size()))
        OPENFHE_THROW("The matrix passed to EvalLTPrecompute is not padded up to powers of two");
    (void)dim1;

    const uint32_t rows = A.size();
    const uint32_t cols = A[0].size();
    const uint32_t n    = std::min(rows, cols);
    std::vector<std::vector<std::complex<double>>> diags(n);

    if (rows >= cols) {
        const uint32_t numSlices = rows / cols;
        std::vector<std::vector<std::vector<std::complex<double>>>> slices(numSlices);
        for (uint32_t s = 0; s < numSlices; ++s)
            slices[s].assign(A.begin() + static_cast<size_t>(s) * cols, A.begin() + static_cast<size_t>(s + 1) * cols);

        // One flat loop over all diagonals parallelizes n-way instead of only over the giant steps.
#pragma omp parallel for num_threads(OpenFHEParallelControls.GetThreadLimit(n))
        for (uint32_t d = 0; d < n; ++d) {
            std::vector<std::complex<double>> diag;
            // numSlices square cols x cols slices together contribute exactly #rows entries.
            diag.reserve(rows);
            for (uint32_t s = 0; s < numSlices; ++s) {
                auto part = ExtractShiftedDiagonal(slices[s], d);
                diag.insert(diag.end(), std::make_move_iterator(part.begin()), std::make_move_iterator(part.end()));
            }
            for (auto& entry : diag)
                entry *= scale;
            diags[d] = std::move(diag);
        }
    }
    else {
#pragma omp parallel for num_threads(OpenFHEParallelControls.GetThreadLimit(n))
        for (uint32_t ji = 0; ji < n; ++ji) {
            auto diag = ExtractShiftedDiagonal(A, ji);
            for (auto& d : diag)
                d *= scale;
            diags[ji] = std::move(diag);
        }
    }
    return diags;
}
// Applies the linear transform whose precomputed diagonal plaintexts are A (as produced by EvalLTPrecomputeSwitch)
// to ctxt with the baby-step/giant-step method and hoisted rotations:
//  - the bStep - 1 baby-step rotations of ctxt share one precomputation (EvalFastRotationPrecompute) and stay in
//    the extended key-switching form (EvalFastRotationExt(..., true));
//  - giant step j accumulates sum_i rot_i(ctxt) * A[bStep * j + i] and is then rotated by bStep * j;
//  - the first element of every giant-step sum is handled separately (KeySwitchDownFirstElement /
//    AutomorphismTransform) and added back after the final KeySwitchDown.
// dim1 is the baby-step size; 0 selects getRatioBSGSLT(A.size()), the same default EvalLTPrecomputeSwitch uses
// when pre-rotating A. Throws if A is empty.
Ciphertext<DCRTPoly> SWITCHCKKSRNS::EvalLTWithPrecomputeSwitch(const CryptoContextImpl<DCRTPoly>& cc,
                                                               ConstCiphertext<DCRTPoly> ctxt,
                                                               const std::vector<ReadOnlyPlaintext>& A,
                                                               uint32_t dim1) const {
    if (A.empty())
        OPENFHE_THROW("The precomputed plaintexts passed to EvalLTWithPrecomputeSwitch are empty");

    // Computing the baby-step bStep and the giant-step gStep.
    // Without the dim1 == 0 fallback, bStep == 0 would make gStep infinite and fastRotation(bStep - 1) wrap around.
    uint32_t slots = A.size();
    uint32_t bStep = (dim1 == 0) ? getRatioBSGSLT(slots) : dim1;
    uint32_t gStep = std::ceil(static_cast<double>(slots) / bStep);

    // Computes the NTTs for each CRT limb (for the hoisted automorphisms used later on)
    auto digits = cc.EvalFastRotationPrecompute(ctxt);

    // Hoisted automorphisms
    std::vector<Ciphertext<DCRTPoly>> fastRotation(bStep - 1);
#pragma omp parallel for num_threads(OpenFHEParallelControls.GetThreadLimit(bStep - 1))
    for (uint32_t j = 1; j < bStep; ++j)
        fastRotation[j - 1] = cc.EvalFastRotationExt(ctxt, j, digits, true);

    uint32_t M = cc.GetCyclotomicOrder();
    uint32_t N = cc.GetRingDimension();
    std::vector<uint32_t> map(N);

    Ciphertext<DCRTPoly> result;
    DCRTPoly first;
    for (uint32_t j = 0; j < gStep; ++j) {
        // Baby steps of giant step j (bStep * j < slots because gStep = ceil(slots / bStep)).
        auto inner = FHECKKSRNS::EvalMultExt(cc.KeySwitchExt(ctxt, true), A[bStep * j]);
        for (uint32_t i = 1; i < bStep; ++i) {
            if (bStep * j + i < slots)
                FHECKKSRNS::EvalAddExtInPlace(inner, FHECKKSRNS::EvalMultExt(fastRotation[i - 1], A[bStep * j + i]));
        }

        if (j == 0) {
            // Giant step 0 needs no rotation: split off the first element and keep the rest unreduced.
            first         = cc.KeySwitchDownFirstElement(inner);
            auto elements = inner->GetElements();
            elements[0].SetValuesToZero();
            inner->SetElements(std::move(elements));
            result = std::move(inner);
        }
        else {
            inner = cc.KeySwitchDown(inner);
            // Find the automorphism index that corresponds to the rotation index.
            uint32_t autoIndex = FindAutomorphismIndex2nComplex(bStep * j, M);
            PrecomputeAutoMap(N, autoIndex, &map);
            first += inner->GetElements()[0].AutomorphismTransform(autoIndex, map);
            auto&& innerDigits = cc.EvalFastRotationPrecompute(inner);
            FHECKKSRNS::EvalAddExtInPlace(result, cc.EvalFastRotationExt(inner, bStep * j, innerDigits, false));
        }
    }

    result = cc.KeySwitchDown(result);
    result->GetElements()[0] += first;
    return result;
}
// Applies the rectangular linear transform whose scaled diagonals A were produced by EvalLTRectPrecomputeSwitch
// to ct, using hoisted baby-step/giant-step evaluation. The diagonal plaintexts are encoded on the fly per giant
// step (pre-rotated by -bStep * j) over the Q towers up to level L (L == 0 keeps all towers; FLEXIBLEAUTOEXT
// drops one more), followed by the P towers.
// `wide` tells whether the original matrix had more columns than rows (A alone no longer carries this). If so,
// the result is folded with log2(A[0].size() / A.size()) rotate-and-add steps with shifts A.size(), 2 * A.size(),
// ... (dimensions are powers of two). dim1 is the baby-step size (0 selects getRatioBSGSLT(n)).
// Throws if A is empty.
Ciphertext<DCRTPoly> SWITCHCKKSRNS::EvalLTRectWithPrecomputeSwitch(
    const CryptoContextImpl<DCRTPoly>& cc, const std::vector<std::vector<std::complex<double>>>& A,
    ConstCiphertext<DCRTPoly> ct, bool wide, uint32_t dim1, uint32_t L) const {
    if (A.empty())
        OPENFHE_THROW("The precomputed diagonals passed to EvalLTRectWithPrecomputeSwitch are empty");

    uint32_t n = std::min(A.size(), A[0].size());
    // Computing the baby-step bStep and the giant-step gStep
    uint32_t bStep = (dim1 == 0) ? getRatioBSGSLT(n) : dim1;
    uint32_t gStep = std::ceil(static_cast<double>(n) / bStep);
    uint32_t M     = cc.GetCyclotomicOrder();
    uint32_t N     = cc.GetRingDimension();

    // Computes the NTTs for each CRT limb (for the hoisted automorphisms used later on)
    auto digits = cc.EvalFastRotationPrecompute(ct);
    std::vector<Ciphertext<DCRTPoly>> fastRotation(bStep - 1);

    // Make sure the plaintext is created only with the necessary amount of moduli
    const auto cryptoParamsCKKS = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(ct->GetCryptoParameters());
    ILDCRTParams<DCRTPoly::Integer> elementParams = *(cryptoParamsCKKS->GetElementParams());
    uint32_t towersToDrop = 0;
    if (L != 0) {
        towersToDrop = elementParams.GetParams().size() - L - 1;
        for (uint32_t i = 0; i < towersToDrop; i++)
            elementParams.PopLastParam();
    }
    // For FLEXIBLEAUTOEXT we do not need extra modulus in auxiliary plaintexts
    if (cryptoParamsCKKS->GetScalingTechnique() == FLEXIBLEAUTOEXT) {
        towersToDrop += 1;
        elementParams.PopLastParam();
    }

    const auto& paramsQ = elementParams.GetParams();
    const auto& paramsP = cryptoParamsCKKS->GetParamsP()->GetParams();
    size_t sizeQP       = paramsQ.size() + paramsP.size();
    std::vector<NativeInteger> moduli;
    moduli.reserve(sizeQP);
    std::vector<NativeInteger> roots;
    roots.reserve(sizeQP);
    for (const auto& elem : paramsQ) {
        moduli.emplace_back(elem->GetModulus());
        roots.emplace_back(elem->GetRootOfUnity());
    }
    for (const auto& elem : paramsP) {
        moduli.emplace_back(elem->GetModulus());
        roots.emplace_back(elem->GetRootOfUnity());
    }
    auto elementParamsPtr  = std::make_shared<ILDCRTParams<DCRTPoly::Integer>>(M, moduli, roots);
    auto elementParamsPtr2 = std::dynamic_pointer_cast<typename DCRTPoly::Params>(elementParamsPtr);

    // Hoisted automorphisms
#pragma omp parallel for num_threads(OpenFHEParallelControls.GetThreadLimit(bStep - 1))
    for (uint32_t j = 1; j < bStep; ++j)
        fastRotation[j - 1] = cc.EvalFastRotationExt(ct, j, digits, true);

    std::vector<uint32_t> map(N);
    Ciphertext<DCRTPoly> result;
    DCRTPoly first;
    for (uint32_t j = 0; j < gStep; j++) {
        // Diagonals of giant step j are pre-rotated by -bStep * j.
        int32_t offset = (j == 0) ? 0 : -static_cast<int32_t>(bStep * j);
        auto temp      = cc.MakeCKKSPackedPlaintext(Rotate(Fill(A[bStep * j], N / 2), offset), 1, towersToDrop,
                                                    elementParamsPtr2, N / 2);
        auto inner     = FHECKKSRNS::EvalMultExt(cc.KeySwitchExt(ct, true), temp);
        for (uint32_t i = 1; i < bStep; i++) {
            if (bStep * j + i < n) {
                auto tempi = cc.MakeCKKSPackedPlaintext(Rotate(Fill(A[bStep * j + i], N / 2), offset), 1, towersToDrop,
                                                        elementParamsPtr2, N / 2);
                FHECKKSRNS::EvalAddExtInPlace(inner, FHECKKSRNS::EvalMultExt(fastRotation[i - 1], tempi));
            }
        }

        if (j == 0) {
            first         = cc.KeySwitchDownFirstElement(inner);
            auto elements = inner->GetElements();
            elements[0].SetValuesToZero();
            inner->SetElements(std::move(elements));
            result = std::move(inner);
        }
        else {
            inner = cc.KeySwitchDown(inner);
            // Find the automorphism index that corresponds to rotation index index.
            uint32_t autoIndex = FindAutomorphismIndex2nComplex(bStep * j, M);
            PrecomputeAutoMap(N, autoIndex, &map);
            first += inner->GetElements()[0].AutomorphismTransform(autoIndex, map);
            auto&& innerDigits = cc.EvalFastRotationPrecompute(inner);
            FHECKKSRNS::EvalAddExtInPlace(result, cc.EvalFastRotationExt(inner, bStep * j, innerDigits, false));
        }
    }
    result = cc.KeySwitchDown(result);
    result->GetElements()[0] += first;

    // A represents the diagonals, which lose the information whether the initial matrix is tall or wide
    if (wide) {
        // These are powers of two, so log(l) is an integer.
        uint32_t logl = lbcrypto::GetMSB(A[0].size() / A.size()) - 1;
        // Running accumulation: no need to keep every intermediate ciphertext alive.
        for (uint32_t j = 0; j < logl; ++j)
            result = cc.EvalAdd(result, cc.EvalAtIndex(result, A.size() * (1 << j)));
    }
    return result;
}
// Slots-to-coefficients step of CKKS-to-FHEW switching.
// ctxt is first reduced to numTowersToKeep towers. For FLEXIBLEAUTO / FLEXIBLEAUTOEXT one extra tower is kept
// and removed by ModReduceInternalInPlace after a scalar multiplication, and the scaling factor is then set to the
// target value for the remaining level. The precomputed transform m_U0Pre is then applied.
// For sparse packing (M != 4 * m_numSlotsCKKS) the result is additionally folded by adding its rotation by
// m_numSlotsCKKS. Throws if m_U0Pre has not been populated.
Ciphertext<DCRTPoly> SWITCHCKKSRNS::EvalSlotsToCoeffsSwitch(const CryptoContextImpl<DCRTPoly>& cc,
                                                            ConstCiphertext<DCRTPoly> ctxt) const {
    if (m_U0Pre.size() == 0)
        OPENFHE_THROW("Precomputations not generated. Call EvalCKKSToFHEWPrecompute to proceed.");

    uint32_t m    = 4 * m_numSlotsCKKS;
    uint32_t M    = cc.GetCyclotomicOrder();
    bool isSparse = (M != m);

    // ctxt is only read: the in-place operations below (ModReduceInternalInPlace, SetScalingFactor) act on the
    // fresh output of EvalMult, so the input does not need to be cloned before Compress.
    const uint32_t numTowersToKeep = 2;
    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(cc.GetCryptoParameters());
    Ciphertext<DCRTPoly> ctxtToDecode;
    if (cryptoParams->GetScalingTechnique() == ScalingTechnique::FLEXIBLEAUTO ||
        cryptoParams->GetScalingTechnique() == FLEXIBLEAUTOEXT) {
        ctxtToDecode = cc.Compress(ctxt, numTowersToKeep + 1);
        double targetSF =
            cryptoParams->GetScalingFactorReal(cryptoParams->GetElementParams()->GetParams().size() - numTowersToKeep);
        double sourceSF    = ctxtToDecode->GetScalingFactor();
        uint32_t numTowers = ctxtToDecode->GetElements()[0].GetNumOfElements();
        double modToDrop = cryptoParams->GetElementParams()->GetParams()[numTowers - 1]->GetModulus().ConvertToDouble();
        // NOTE(review): factor presumably chosen so that the rescale below lands exactly on targetSF — confirm.
        double adjustmentFactor = (targetSF / sourceSF) * (modToDrop / sourceSF);
        ctxtToDecode            = cc.EvalMult(ctxtToDecode, adjustmentFactor);
        cc.GetScheme()->ModReduceInternalInPlace(ctxtToDecode, 1);
        ctxtToDecode->SetScalingFactor(targetSF);
    }
    else {
        ctxtToDecode = cc.Compress(ctxt, numTowersToKeep);
    }

    // The linear transform is the same for full and sparse packing.
    // (Disabled alternative for full packing:
    //  ctxtToDecode = cc.EvalAdd(ctxtToDecode, cc.GetScheme()->MultByMonomial(ctxtToDecode, M / 4));)
    Ciphertext<DCRTPoly> ctxtDecoded = EvalLTWithPrecomputeSwitch(cc, ctxtToDecode, m_U0Pre, m_dim1CF);
    if (isSparse)
        cc.EvalAddInPlace(ctxtDecoded, cc.EvalAtIndex(ctxtDecoded, m_numSlotsCKKS));
    return ctxtDecoded;
}
// Homomorphically applies the matrix A (one row per LWE ciphertext to switch, one column per LWE coefficient),
// scaled by `scale`, to ct. The number of columns is zero-padded up to the next power of two; the number of rows is,
// by design, already a power of two (EvalLTRectPrecomputeSwitch throws otherwise).
// dim1 is the baby-step size (0 = automatic) and L the level for the plaintext moduli (0 keeps all towers).
// Throws if A is empty or has empty rows.
Ciphertext<DCRTPoly> SWITCHCKKSRNS::EvalPartialHomDecryption(const CryptoContextImpl<DCRTPoly>& cc,
                                                             const std::vector<std::vector<std::complex<double>>>& A,
                                                             ConstCiphertext<DCRTPoly> ct, uint32_t dim1, double scale,
                                                             uint32_t L) const {
    // Otherwise A[0] / log2(0) below are undefined.
    if (A.empty() || A[0].empty())
        OPENFHE_THROW("The matrix passed to EvalPartialHomDecryption is empty");

    // Currently, by design, the # rows (# LWE ciphertexts to switch) is a power of two.
    // Ensure that # cols (LWE lattice parameter n) is padded up to a power of two (integer arithmetic, no log2).
    const size_t cols = A[0].size();
    size_t colsPo2    = 1;
    while (colsPo2 < cols)
        colsPo2 <<= 1;

    // Copy A only when padding is actually required.
    std::vector<std::vector<std::complex<double>>> padded;
    if (colsPo2 > cols) {
        padded = A;
        for (auto& row : padded)
            row.resize(colsPo2);
    }
    const auto& Apadded = (colsPo2 > cols) ? padded : A;

    auto Apre = EvalLTRectPrecomputeSwitch(Apadded, dim1, scale);
    // The result is repeated every Apadded.size() slots
    const bool wide = Apadded.size() < Apadded[0].size();
    return EvalLTRectWithPrecomputeSwitch(cc, Apre, ct, wide, dim1, L);
}
//------------------------------------------------------------------------------
// CKKS to FHEW scheme switching wrappers
//------------------------------------------------------------------------------
/**
 * Sets up CKKS-to-FHEW switching.
 *
 * Side effects visible here:
 *  - stores the number of CKKS slots (m_numSlotsCKKS) and the initial CKKS modulus;
 *  - creates the intermediate key-switching context m_ccKS and records its modulus in m_modulus_CKKS_from;
 *  - creates the FHEW context m_ccLWE and the LWE modulus m_modulus_LWE;
 *  - stores the baby-step (m_dim1CF) and level (m_LCF) of the homomorphic decoding transform.
 *
 * @param params scheme-switching parameters
 * @return a freshly generated LWE (FHEW) private key
 * @throws if the FHEW security level is not TOY or STD128, or if the intermediate modulus size is not strictly
 *         between the FHEW and CKKS modulus sizes
 */
LWEPrivateKey SWITCHCKKSRNS::EvalCKKStoFHEWSetup(const SchSwchParams& params) {
    if (params.GetSecurityLevelFHEW() != TOY && params.GetSecurityLevelFHEW() != STD128)
        OPENFHE_THROW("Only STD128 or TOY are currently supported.");

    const uint32_t ringDim = params.GetRingDimension();
    // 0 slots means fully packed.
    m_numSlotsCKKS         = (params.GetNumSlotsCKKS() == 0) ? ringDim / 2 : params.GetNumSlotsCKKS();
    m_modulus_CKKS_initial = params.GetInitialCKKSModulus();

    // Modulus to switch to in order to have secure RLWE samples with ring dimension n.
    // We can select any Qswitch less than 27 bits corresponding to 128 bits of security for lattice parameter n=1024 < 1305
    // according to https://homomorphicencryption.org/wp-content/uploads/2018/11/HomomorphicEncryptionStandardv1.1.pdf
    // or any Qswitch for TOY security.
    // Ensure that Qswitch is larger than Q_FHEW and smaller than Q_CKKS.
    if (params.GetCtxtModSizeFHEWIntermedSwch() <= params.GetCtxtModSizeFHEWLargePrec() ||
        params.GetCtxtModSizeFHEWIntermedSwch() > m_modulus_CKKS_initial.GetMSB() - 1) {
        OPENFHE_THROW("Qswitch should be larger than QFHEW and smaller than QCKKS.");
    }

    // Intermediate cryptocontext (depth 0, only used for a key switch)
    CCParams<CryptoContextCKKSRNS> parameters;
    parameters.SetMultiplicativeDepth(0);
    parameters.SetFirstModSize(params.GetCtxtModSizeFHEWIntermedSwch());
    // scaling mod size is not used in this case
    parameters.SetScalingModSize(params.GetCtxtModSizeFHEWIntermedSwch());
    // This doesn't need this to be the same scaling technique as the outer cryptocontext, since we only do a key switch
    parameters.SetScalingTechnique(FIXEDMANUAL);
    parameters.SetSecurityLevel(params.GetSecurityLevelCKKS());
    parameters.SetRingDim(ringDim);
    parameters.SetBatchSize(params.GetBatchSize());
    parameters.SetCKKSDataType(REAL);

    m_ccKS = GenCryptoContext(parameters);
    m_ccKS->Enable(PKE);
    m_ccKS->Enable(KEYSWITCH);

    // Get the ciphertext modulus of the intermediate context
    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(m_ccKS->GetCryptoParameters());
    m_modulus_CKKS_from     = cryptoParams->GetElementParams()->GetParams()[0]->GetModulus();

    m_ccLWE = std::make_shared<BinFHEContext>();
    m_ccLWE->BinFHEContext::GenerateBinFHEContext(
        params.GetSecurityLevelFHEW(), params.GetArbitraryFunctionEvaluation(), params.GetCtxtModSizeFHEWLargePrec(), 0,
        GINX, params.GetUseDynamicModeFHEW());

    // For arbitrary functions, the LWE ciphertext needs to be at most the ring dimension in FHEW bootstrapping.
    // Shift a 64-bit one: shifting a plain int by 31 or more bits is undefined behavior.
    m_modulus_LWE = (!params.GetArbitraryFunctionEvaluation()) ?
                        uint64_t{1} << params.GetCtxtModSizeFHEWLargePrec() :
                        m_ccLWE->GetParams()->GetLWEParams()->Getq().ConvertToInt();

    // The baby-step and number of levels for the linear transformation associated to the homomorphic decoding.
    // The decoding transform acts on m_numSlotsCKKS slots (the resolved value, never 0), so derive the default
    // baby-step from it rather than from the raw parameter, which is 0 when full packing is requested.
    m_dim1CF = (params.GetBStepLTrCKKStoFHEW() == 0) ? getRatioBSGSLT(m_numSlotsCKKS) :
                                                       params.GetBStepLTrCKKStoFHEW();
    m_LCF    = params.GetLevelLTrCKKStoFHEW();

    // Return LWE private key
    return m_ccLWE->KeyGen();
}
/**
 * Generates the keys needed for CKKS-to-FHEW switching.
 *
 * Side effects visible here:
 *  - sets m_ctxtKS (an encryption of zero in the intermediate context m_ccKS);
 *  - sets m_CKKStoFHEWswk (the RLWE-form switching key towards the FHEW secret);
 *  - generates the CKKS multiplication key through the CKKS crypto context.
 *
 * @param keyPair CKKS key pair (only the secret key is used)
 * @param lwesk   FHEW/LWE secret key
 * @return map of rotation keys for the slots-to-coefficients transform, plus the conjugation key at index M - 1
 * @throws if the key switching technique is not HYBRID, or (128-bit native integers only) if the scaling
 *         technique is FLEXIBLEAUTO or FLEXIBLEAUTOEXT
 */
std::shared_ptr<std::map<uint32_t, EvalKey<DCRTPoly>>> SWITCHCKKSRNS::EvalCKKStoFHEWKeyGen(
    const KeyPair<DCRTPoly>& keyPair, ConstLWEPrivateKey& lwesk) {
    auto& privateKey        = keyPair.secretKey;
    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(privateKey->GetCryptoParameters());
    if (cryptoParams->GetKeySwitchTechnique() != HYBRID)
        OPENFHE_THROW("CKKS to FHEW scheme switching is only supported for the Hybrid key switching method.");
#if NATIVEINT == 128
    if (cryptoParams->GetScalingTechnique() == FLEXIBLEAUTO || cryptoParams->GetScalingTechnique() == FLEXIBLEAUTOEXT)
        OPENFHE_THROW("128-bit CKKS to FHEW scheme switching is supported for FIXEDMANUAL and FIXEDAUTO methods only.");
#endif
    auto ccCKKS = privateKey->GetCryptoContext();

    // Intermediate cryptocontext for CKKS to FHEW
    auto keys2 = m_ccKS->KeyGen();
    m_ctxtKS   = m_ccKS->Encrypt(keys2.publicKey, m_ccKS->MakeCKKSPackedPlaintext(std::vector<double>{0.0}));
    // Compute switching key between RLWE and LWE via the intermediate cryptocontext, keep it in RLWE form
    m_CKKStoFHEWswk = switchingKeyGenRLWEcc(keys2.secretKey, privateKey, lwesk);

    // Compute indices for rotations for the slotsToCoeffs transform
    const uint32_t M     = ccCKKS->GetCyclotomicOrder();
    const uint32_t slots = m_numSlotsCKKS;
    std::vector<int32_t> indexRotationS2C = FindLTRotationIndicesSwitch(m_dim1CF, M, slots);
    indexRotationS2C.emplace_back(static_cast<int32_t>(slots));
    // Remove possible duplicates and zero
    std::sort(indexRotationS2C.begin(), indexRotationS2C.end());
    indexRotationS2C.erase(std::unique(indexRotationS2C.begin(), indexRotationS2C.end()), indexRotationS2C.end());
    indexRotationS2C.erase(std::remove(indexRotationS2C.begin(), indexRotationS2C.end(), 0), indexRotationS2C.end());

    // Compute the multiplication key through the crypto context, as the other key generation routines in this
    // file do. The previous scheme-level call returned the key and its result was discarded.
    ccCKKS->EvalMultKeyGen(privateKey);

    auto evalKeys = ccCKKS->GetScheme()->EvalAtIndexKeyGen(privateKey, indexRotationS2C);
    // Compute conjugation key
    (*evalKeys)[M - 1] = FHECKKSRNS::ConjugateKeyGen(privateKey);
    return evalKeys;
}
/**
 * Precomputes the plaintexts of the homomorphic decoding (slots-to-coefficients) transform used in
 * CKKS-to-FHEW switching, and stores them in m_U0Pre.
 *
 * The decoding matrix is U0[i][j] = ksi^(j * 5^i mod m), with ksi = exp(2*pi*i/m) and m = 4 * m_numSlotsCKKS.
 * For sparse packing the transform additionally uses U1 = i * U0, which is built only in that case.
 *
 * @param cc    CKKS crypto context
 * @param scale user scaling; it is further multiplied by Q_initial / scFactor, where scFactor is the scaling
 *              factor at index (#towers - 1) of the CKKS parameters
 */
void SWITCHCKKSRNS::EvalCKKStoFHEWPrecompute(const CryptoContextImpl<DCRTPoly>& cc, double scale) {
    const uint32_t M     = cc.GetCyclotomicOrder();
    const uint32_t slots = m_numSlotsCKKS;
    const uint32_t m     = 4 * m_numSlotsCKKS;
    const uint32_t mmask = m - 1;  // assumes m is power of 2
    const bool isSparse  = (M != m);

    // Computes indices for all primitive roots of unity: rotGroup[i] = 5^i mod m
    std::vector<uint32_t> rotGroup(slots);
    uint32_t fivePows = 1;
    for (uint32_t i = 0; i < slots; ++i) {
        rotGroup[i] = fivePows;
        fivePows *= 5;
        fivePows &= mmask;
    }

    // Computes all powers of a primitive root of unity exp(2*M_PI/m)
    std::vector<std::complex<double>> ksiPows(m + 1);
    const double ak = 2 * M_PI / m;
    for (uint32_t j = 0; j < m; ++j) {
        const double angle = ak * j;
        ksiPows[j].real(std::cos(angle));
        ksiPows[j].imag(std::sin(angle));
    }
    ksiPows[m] = ksiPows[0];

    // Matrix for decoding
    std::vector<std::vector<std::complex<double>>> U0(slots, std::vector<std::complex<double>>(slots));
    for (uint32_t i = 0; i < slots; ++i) {
        for (uint32_t j = 0; j < slots; ++j)
            U0[i][j] = ksiPows[(j * rotGroup[i]) & mmask];
    }

    // Obtain the right scaling for encoded messages in FHEW coming from encoded messages in CKKS
    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(cc.GetCryptoParameters());
    const double scFactor =
        cryptoParams->GetScalingFactorReal(cryptoParams->GetElementParams()->GetParams().size() - 1);
    scale *= m_modulus_CKKS_initial.ConvertToDouble() / scFactor;

    if (!isSparse) {  // fully packed: only U0 is needed, so U1 is never built
        m_U0Pre = EvalLTPrecomputeSwitch(cc, U0, m_dim1CF, m_LCF, scale);
        return;
    }

    // sparsely packed: U1 = i * U0 is needed as well
    std::vector<std::vector<std::complex<double>>> U1(slots, std::vector<std::complex<double>>(slots));
    const std::complex<double> imagUnit(0, 1);
    for (uint32_t i = 0; i < slots; ++i) {
        for (uint32_t j = 0; j < slots; ++j)
            U1[i][j] = imagUnit * U0[i][j];
    }
    m_U0Pre = EvalLTPrecomputeSwitch(cc, U0, U1, m_dim1CF, m_LCF, scale);
}
/**
 * Switches a CKKS ciphertext to a vector of FHEW (LWE) ciphertexts.
 *
 * Steps:
 *  1. homomorphic decoding (slots to coefficients);
 *  2. modulus switch to the intermediate modulus Q' = m_modulus_CKKS_from;
 *  3. key switch to the RLWE form of the FHEW key;
 *  4. extraction of LWE ciphertexts modulo Q';
 *  5. if m_modulus_LWE differs from Q', a coefficient-wise rescale-and-round of every LWE ciphertext to
 *     m_modulus_LWE.
 *
 * @param ciphertext CKKS ciphertext to switch
 * @param numCtxts   number of LWE ciphertexts to extract; 0 or a value above m_numSlotsCKKS selects m_numSlotsCKKS
 * @return the extracted LWE ciphertexts
 */
std::vector<std::shared_ptr<LWECiphertextImpl>> SWITCHCKKSRNS::EvalCKKStoFHEW(ConstCiphertext<DCRTPoly> ciphertext,
                                                                              uint32_t numCtxts) {
    if (numCtxts == 0 || numCtxts > m_numSlotsCKKS)
        numCtxts = m_numSlotsCKKS;

    // Step 1. Homomorphic decoding
    auto ccCKKS      = ciphertext->GetCryptoContext();
    auto ctxtDecoded = EvalSlotsToCoeffsSwitch(*ccCKKS, ciphertext);
    ccCKKS->GetScheme()->ModReduceInternalInPlace(ctxtDecoded, 1);

    // Step 2. Modulus switch to Q', such that CKKS is secure for (Q',n)
    auto ctxtKS = m_ctxtKS->Clone();
    ModSwitch(ctxtDecoded, ctxtKS, m_modulus_CKKS_from);

    // Step 3: Key switch from the CKKS key with the new modulus Q' to the RLWE version of the FHEW key with the new modulus Q'
    auto ccKS       = ctxtKS->GetCryptoContext();  // Use this instead of m_ccKS to work with serialization
    auto ctSwitched = ccKS->KeySwitch(ctxtKS, m_CKKStoFHEWswk);

    // Step 4. Extract LWE ciphertexts with the modulus Q'
    const uint32_t n = m_ccLWE->GetParams()->GetLWEParams()->Getn();  // lattice parameter for additive LWE
    std::vector<std::shared_ptr<LWECiphertextImpl>> LWEciphertexts(numCtxts);
    auto AandB = ExtractLWEpacked(ctSwitched);
    // Extract every gap-th position, gap = ringDim / (2 * m_numSlotsCKKS)
    const uint32_t gap = ccKS->GetRingDimension() / (2 * m_numSlotsCKKS);
    for (uint32_t i = 0, idx = 0; i < numCtxts; ++i, idx += gap)
        LWEciphertexts[i] = ExtractLWECiphertext(AandB, m_modulus_CKKS_from, n, idx);

    // Step 5. Modulus switch to q in FHEW
    // Compute the necessary factor to obtain the message Q'/pLWE
    if (m_modulus_LWE != m_modulus_CKKS_from) {
#pragma omp parallel for num_threads(OpenFHEParallelControls.GetThreadLimit(numCtxts))
        for (uint32_t i = 0; i < numCtxts; ++i) {
            const auto& original_a = LWEciphertexts[i]->GetA();
            const auto original_b  = LWEciphertexts[i]->GetB();
            // multiply by Q_LWE/Q' and round to Q_LWE
            NativeVector a_round(n, m_modulus_LWE);
            for (uint32_t j = 0; j < n; ++j)
                a_round[j] = RoundqQAlter(original_a[j], m_modulus_LWE, m_modulus_CKKS_from);
            NativeInteger b_round = RoundqQAlter(original_b, m_modulus_LWE, m_modulus_CKKS_from);
            LWEciphertexts[i]     = std::make_shared<LWECiphertextImpl>(std::move(a_round), std::move(b_round));
        }
    }
    return LWEciphertexts;
}
//------------------------------------------------------------------------------
// FHEW to CKKS and combined (two-way) scheme switching wrappers
//------------------------------------------------------------------------------
/**
 * Configures FHEW-to-CKKS switching: stores the FHEW context, the number of CKKS slots and the LWE modulus.
 *
 * @param ccCKKS       CKKS crypto context
 * @param ccLWE        FHEW crypto context
 * @param numSlotsCKKS number of CKKS slots; 0 selects the batch size of ccCKKS if it is set, else ringDim / 2
 * @param logQ         log2 of the LWE ciphertext modulus; 0 selects the modulus q of ccLWE
 * @throws if twice the LWE lattice parameter n exceeds the CKKS ring dimension
 */
void SWITCHCKKSRNS::EvalFHEWtoCKKSSetup(const CryptoContextImpl<DCRTPoly>& ccCKKS,
                                        const std::shared_ptr<BinFHEContext>& ccLWE, uint32_t numSlotsCKKS,
                                        uint32_t logQ) {
    if (ccLWE->GetParams()->GetLWEParams()->Getn() * 2 > ccCKKS.GetRingDimension())
        OPENFHE_THROW("The lattice parameter in LWE cannot be larger than half the RLWE ring dimension.");

    m_ccLWE = ccLWE;

    if (numSlotsCKKS == 0) {
        if (ccCKKS.GetEncodingParams()->GetBatchSize() != 0)
            m_numSlotsCKKS = ccCKKS.GetEncodingParams()->GetBatchSize();
        else
            m_numSlotsCKKS = ccCKKS.GetRingDimension() / 2;
    }
    else {
        m_numSlotsCKKS = numSlotsCKKS;
    }

    // Shift a 64-bit one: `1 << logQ` on a plain int is undefined behavior for logQ >= 31.
    m_modulus_LWE =
        (logQ != 0) ? uint64_t{1} << logQ : m_ccLWE->GetParams()->GetLWEParams()->Getq().ConvertToInt();
}
/**
 * Generates the keys needed for FHEW-to-CKKS switching.
 *
 * Side effects visible here:
 *  - stores in m_FHEWtoCKKSswk the CKKS encryption of the LWE secret key, zero-padded to a power of two;
 *  - records the baby-step (m_dim1FC) and level (m_LFC) of the homomorphic-decryption linear transform;
 *  - generates the CKKS multiplication key through the CKKS crypto context.
 *
 * @param keyPair  CKKS key pair; the public key encrypts the LWE secret, the secret key generates rotation keys
 * @param lwesk    FHEW/LWE secret key
 * @param numSlots number of CKKS slots of the output encoding (0 -> m_numSlotsCKKS)
 * @param numCtxts number of LWE ciphertexts to switch (0 -> m_numSlotsCKKS)
 * @param dim1     baby-step of the linear transform (0 -> getRatioBSGSLT(numCtxts))
 * @param L        level parameter, stored in m_LFC
 * @return map of rotation keys generated by EvalAtIndexKeyGen
 */
std::shared_ptr<std::map<uint32_t, EvalKey<DCRTPoly>>> SWITCHCKKSRNS::EvalFHEWtoCKKSKeyGen(
const KeyPair<DCRTPoly>& keyPair, ConstLWEPrivateKey& lwesk, uint32_t numSlots, uint32_t numCtxts, uint32_t dim1,
uint32_t L) {
auto& privateKey = keyPair.secretKey;
auto& publicKey = keyPair.publicKey;
const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(privateKey->GetCryptoParameters());
auto ccCKKS = privateKey->GetCryptoContext();
uint32_t n = lwesk->GetElement().GetLength();
uint32_t ringDim = ccCKKS->GetRingDimension();
// Generate FHEW to CKKS switching key, i.e., CKKS encryption of FHEW secret key. Pad up to the closest power of two
uint32_t n_po2 = 1 << static_cast<uint32_t>(std::ceil(std::log2(n)));
auto& skLWEElements = lwesk->GetElement();
// LWE secret coefficients equal to (modulus - 1) stand for -1 and are mapped to -1.0; entries past n stay 0.
const auto neg = lwesk->GetModulus().ConvertToDouble() - 1.0;
std::vector<std::complex<double>> skLWEDouble(n_po2);
for (uint32_t i = 0; i < n; ++i) {
auto tmp = skLWEElements[i].ConvertToDouble();
skLWEDouble[i] = std::complex<double>(tmp == neg ? -1.0 : tmp, 0);
}
// Check encoding and specify the number of slots, otherwise, if batchsize is set and is smaller, it will throw an error.
// The secret is encoded over ringDim / 2 slots (Fill presumably replicates it to that length — confirm); with
// FLEXIBLEAUTOEXT it is encoded at level BASE_NUM_LEVELS_TO_DROP instead of 0.
Plaintext skLWEPlainswk;
if (cryptoParams->GetScalingTechnique() == FLEXIBLEAUTOEXT)
skLWEPlainswk = ccCKKS->MakeCKKSPackedPlaintext(Fill(skLWEDouble, ringDim / 2), 1, BASE_NUM_LEVELS_TO_DROP,
nullptr, ringDim / 2);
else
skLWEPlainswk = ccCKKS->MakeCKKSPackedPlaintext(Fill(skLWEDouble, ringDim / 2), 1, 0, nullptr, ringDim / 2);
m_FHEWtoCKKSswk = ccCKKS->Encrypt(publicKey, skLWEPlainswk);
// Compute automorphism keys for CKKS for baby-step giant-step
if (numCtxts == 0)
numCtxts = m_numSlotsCKKS; // If no value is specified, default to the number of CKKS slots
uint32_t M = ccCKKS->GetCyclotomicOrder();
if (dim1 == 0)
dim1 = getRatioBSGSLT(numCtxts);
m_dim1FC = dim1;
m_LFC = L;
// Compute indices for rotations for homomorphic decryption in CKKS
std::vector<int32_t> indexRotationHomDec = FindLTRotationIndicesSwitch(dim1, M, numCtxts);
// If the linear transform is wide instead of tall, we need extra rotations
// (rotations by numCtxts * 2^(j-1), j = 1..log2(n_po2 / numCtxts), used to sum the partial results)
if (numCtxts < n_po2) {
uint32_t logl = lbcrypto::GetMSB(n_po2 / numCtxts) - 1; // These are powers of two, so log(l) is integer
indexRotationHomDec.reserve(indexRotationHomDec.size() + logl);
for (uint32_t j = 1; j <= logl; ++j)
indexRotationHomDec.emplace_back(numCtxts * (1 << (j - 1)));
}
uint32_t slots = (numSlots == 0) ? m_numSlotsCKKS : numSlots;
// Compute indices for rotations to bring back the final CKKS ciphertext encoding to slots
// (rotations by slots * 2^k while below ringDim / 2)
if (ringDim > 2 * slots) { // if the encoding is full, this does not execute
indexRotationHomDec.reserve(indexRotationHomDec.size() + GetMSB(ringDim) - 2);
for (uint32_t j = 1; j < ringDim / (2 * slots); j <<= 1)
indexRotationHomDec.emplace_back(j * slots);
}
// Remove possible duplicates and zero
sort(indexRotationHomDec.begin(), indexRotationHomDec.end());
indexRotationHomDec.erase(unique(indexRotationHomDec.begin(), indexRotationHomDec.end()),
indexRotationHomDec.end());
indexRotationHomDec.erase(std::remove(indexRotationHomDec.begin(), indexRotationHomDec.end(), 0),
indexRotationHomDec.end());
auto algo = ccCKKS->GetScheme();
auto evalKeys = algo->EvalAtIndexKeyGen(privateKey, indexRotationHomDec);
// Compute multiplication key (stored in the CKKS crypto context)
ccCKKS->EvalMultKeyGen(privateKey);
return evalKeys;
}
/**
 * Switches a vector of FHEW (LWE) ciphertexts to a single CKKS ciphertext.
 *
 * Steps:
 *  1. Gather the LWE "a" vectors into a matrix A and the "b" values into a vector. The b values are pre-scaled
 *     by 1 / (q * K).
 *  2. Homomorphically compute A*s using the CKKS encryption of the LWE secret (m_FHEWtoCKKSswk).
 *  3. Form b - A*s.
 *  4. Evaluate the modular reduction through a Chebyshev approximation, followed by BT_ITER double-angle
 *     iterations.
 *  5. Apply the post-scale and bias, and for sparse packing fold the encoding back to `slots` slots.
 *
 * @param LWECiphertexts input LWE ciphertexts (must be non-empty)
 * @param numCtxts number of ciphertexts to pack (0 -> all of them); also capped by the number of slots
 * @param numSlots number of CKKS slots of the output (0 -> m_numSlotsCKKS)
 * @param p        FHEW plaintext modulus: for 1 <= p <= 4 the post-scale is 2*pi, otherwise p
 * @param pmin     if non-zero, the post-scale is multiplied by (pmax - pmin) / 4 and a bias of (pmax - pmin) / 4
 *                 is added
 * @param pmax     see pmin
 * @param dim1     baby-step of the linear transform (0 -> m_dim1FC)
 * @return the CKKS ciphertext holding the switched values
 * @throws if LWECiphertexts is empty
 */
Ciphertext<DCRTPoly> SWITCHCKKSRNS::EvalFHEWtoCKKS(std::vector<std::shared_ptr<LWECiphertextImpl>>& LWECiphertexts,
                                                  uint32_t numCtxts, uint32_t numSlots, uint32_t p, double pmin,
                                                  double pmax, uint32_t dim1) const {
    if (LWECiphertexts.empty())
        OPENFHE_THROW("Empty input FHEW ciphertext vector");

    // This is the number of CKKS slots to use in encoding
    const uint32_t slots       = (numSlots == 0) ? m_numSlotsCKKS : numSlots;
    const uint32_t numLWECtxts = LWECiphertexts.size();
    uint32_t numValues         = (numCtxts == 0) ? numLWECtxts : std::min(numCtxts, numLWECtxts);
    numValues = std::min(numValues, slots);  // This is the number of LWE ciphertexts to pack into the CKKS ciphertext
    const uint32_t n = LWECiphertexts[0]->GetA().GetLength();

    auto ccCKKS = m_FHEWtoCKKSswk->GetCryptoContext();
    const auto cryptoParamsCKKS = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(ccCKKS->GetCryptoParameters());

    const uint32_t m    = 4 * slots;
    const uint32_t M    = ccCKKS->GetCyclotomicOrder();
    const uint32_t N    = ccCKKS->GetRingDimension();
    const bool isSparse = (M != m);

    double K = 0.0;
    std::vector<double> coefficientsFHEW;  // EvalFHEWtoCKKS assumes lattice parameter n is at most 2048.
    if (n == 32) {
        K = 16.0;
        coefficientsFHEW.assign(g_coefficientsFHEW16);
    }
    else {
        K = 128.0;  // Failure probability of 2^{-49}
        if (p <= 4) {
            // If the output messages are bits, we could use a lower degree polynomial
            coefficientsFHEW.assign(g_coefficientsFHEW128_8);
        }
        else {
            coefficientsFHEW.assign(g_coefficientsFHEW128_9);
        }
    }

    // Step 1. Form matrix A and vector b from the LWE ciphertexts, but only extract the first necessary number of them
    std::vector<std::vector<std::complex<double>>> A(numValues);
    // To have the same encoding as A*s, create b with the appropriate number of elements
    // (numValues rounded up to a multiple of n, capped at N/2)
    const uint32_t b_size = std::min(((numValues % n) != 0) ? (numValues + n - (numValues % n)) : numValues, N / 2);
    std::vector<std::complex<double>> b(b_size);

    // Combine the scale with the division by K to consume fewer levels, but careful since the value might be too small
    const double prescale = (1.0 / LWECiphertexts[0]->GetModulus().ConvertToDouble()) / K;
#pragma omp parallel for num_threads(OpenFHEParallelControls.GetThreadLimit(numValues))
    for (uint32_t i = 0; i < numValues; ++i) {
        auto& a = LWECiphertexts[i]->GetA();
        A[i].resize(a.GetLength());
        for (uint32_t j = 0; j < a.GetLength(); ++j)
            A[i][j] = std::complex<double>(a[j].ConvertToDouble(), 0);
        b[i] = std::complex<double>(prescale * LWECiphertexts[i]->GetB().ConvertToDouble(), 0);
    }

    // Step 2. Perform the homomorphic linear transformation of A*skLWE
    if (dim1 == 0)
        dim1 = m_dim1FC;
    auto AdotS = EvalPartialHomDecryption(*ccCKKS, A, m_FHEWtoCKKSswk, dim1, prescale, 0);

    // Step 3. Get the ciphertext of B - A*s
    Plaintext BPlain = ccCKKS->MakeCKKSPackedPlaintext(b, AdotS->GetNoiseScaleDeg(), AdotS->GetLevel(), nullptr, N / 2);
    auto BminusAdotS = ccCKKS->EvalAdd(ccCKKS->EvalNegate(AdotS), BPlain);
    if (cryptoParamsCKKS->GetScalingTechnique() == FIXEDMANUAL)
        ccCKKS->ModReduceInPlace(BminusAdotS);
    else if (BminusAdotS->GetNoiseScaleDeg() == 2)
        ccCKKS->GetScheme()->ModReduceInternalInPlace(BminusAdotS, BASE_NUM_LEVELS_TO_DROP);

    // Step 4. Do the modulus reduction: homomorphically evaluate modular function. We do it by using sine approximation.
    double a_cheby = -1.0;
    double b_cheby = 1.0;  // The division by K was performed before
    // double a_cheby = -K; double b_cheby = K; // Alternatively, do this separately to not lose precision when scaling with everything at once
    auto BminusAdotS3 = ccCKKS->EvalChebyshevSeries(BminusAdotS, coefficientsFHEW, a_cheby, b_cheby);
    if (cryptoParamsCKKS->GetScalingTechnique() != FIXEDMANUAL)
        ccCKKS->GetScheme()->ModReduceInternalInPlace(BminusAdotS3, BASE_NUM_LEVELS_TO_DROP);

    // Double-angle iterations x -> 2x^2 - 1/(2*pi)^(2^(j - BT_ITER)). BT_ITER must stay signed so that
    // j - BT_ITER is negative rather than wrapping around.
    const int32_t BT_ITER = 3;
    for (int32_t j = 1; j <= BT_ITER; ++j) {
        BminusAdotS3 = ccCKKS->EvalMult(BminusAdotS3, BminusAdotS3);
        ccCKKS->EvalAddInPlace(BminusAdotS3, BminusAdotS3);
        double scalar = 1.0 / std::pow((2.0 * M_PI), std::pow(2.0, j - BT_ITER));
        ccCKKS->EvalSubInPlace(BminusAdotS3, scalar);
        if (cryptoParamsCKKS->GetScalingTechnique() == FIXEDMANUAL)
            ccCKKS->ModReduceInPlace(BminusAdotS3);
        else
            ccCKKS->GetScheme()->ModReduceInternalInPlace(BminusAdotS3, BASE_NUM_LEVELS_TO_DROP);
    }

    /* For p <= 4 and when we only encrypt bits, we don't need sin(2pi*x)/2pi to approximate x,
     * we can directly use sin(0) for 0 and sin(pi/2) for 1.
     * Here pmax is actually the plaintext modulus, not the maximum value of the messages that we
     * consider. For plaintext modulus > 4, even if we only care about encrypting bits, 2pi is not
     * the correct post-scaling factor.
     * Moreover, we have to account for the different encoding the end ciphertext should have.
     */
    double postScale = (p >= 1 && p <= 4) ? (2.0 * M_PI) : static_cast<double>(p);
    double postBias  = 0.0;
    if (pmin != 0) {
        postScale *= (pmax - pmin) / 4.0;
        postBias = (pmax - pmin) / 4.0;
    }

    // numValues are set; the rest of values up to N/2 are made zero when creating the plaintext
    std::vector<std::complex<double>> postScaleVec(numValues, std::complex<double>(postScale, 0));
    std::vector<std::complex<double>> postBiasVec(numValues, std::complex<double>(postBias, 0));

    // Level at which the post-scale plaintext is encoded so that it matches BminusAdotS3
    const uint32_t towersToDrop = BminusAdotS3->GetLevel() + BminusAdotS3->GetNoiseScaleDeg() - 1;

    // Use full packing here to clear up the junk in the slots after numValues
    auto postScalePlain   = ccCKKS->MakeCKKSPackedPlaintext(postScaleVec, 1, towersToDrop, nullptr, N / 2);
    auto BminusAdotSres   = ccCKKS->EvalMult(BminusAdotS3, postScalePlain);
    // Add the plaintext for bias at the correct level and depth
    auto postBiasPlain = ccCKKS->MakeCKKSPackedPlaintext(postBiasVec, BminusAdotSres->GetNoiseScaleDeg(),
                                                         BminusAdotSres->GetLevel(), nullptr, N / 2);
    ccCKKS->EvalAddInPlace(BminusAdotSres, postBiasPlain);

    // Go back to the sparse encoding if needed
    if (isSparse) {
        for (uint32_t j = 1; j < N / (2 * slots); j <<= 1) {
            auto temp = ccCKKS->EvalAtIndex(BminusAdotSres, j * slots);
            ccCKKS->EvalAddInPlace(BminusAdotSres, temp);
        }
        BminusAdotSres->SetSlots(slots);
    }

    if (cryptoParamsCKKS->GetScalingTechnique() == FIXEDMANUAL) {
        ccCKKS->ModReduceInPlace(BminusAdotSres);
    }
    return BminusAdotSres;
}
/**
 * Sets up two-way scheme switching: CKKS-to-FHEW (via EvalCKKStoFHEWSetup) and the FHEW-to-CKKS linear-transform
 * parameters. It also saves the argmin-related flags used later by EvalSchemeSwitchingKeyGen.
 *
 * @param params scheme-switching parameters
 * @return the LWE (FHEW) private key generated by EvalCKKStoFHEWSetup
 */
LWEPrivateKey SWITCHCKKSRNS::EvalSchemeSwitchingSetup(const SchSwchParams& params) {
    // CKKS to FHEW. This must run first: it sets m_numSlotsCKKS, which the FHEW-to-CKKS defaults below read.
    // Previously it ran last, so the defaults were computed from a stale/uninitialized slot count.
    auto privateKeyFHEW = EvalCKKStoFHEWSetup(params);

    // FHEW to CKKS
    // Save the parameters to be used in EvalSchemeSwitchingKeyGen
    m_argmin = params.GetComputeArgmin();
    m_oneHot = params.GetOneHotEncoding();
    m_alt    = params.GetUseAltArgmin();

    // Set parameters for linear transform for FHEW to CKKS
    if (!m_argmin || m_alt) {
        m_numCtxts = (params.GetNumValues() == 0) ? m_numSlotsCKKS : params.GetNumValues();
    }
    else {  // argmin not in the alternative mode: half as many ciphertexts are switched back
        m_numCtxts = (params.GetNumValues() == 0) ? m_numSlotsCKKS / 2 : params.GetNumValues() / 2;
    }
    // There are multiple dim1's required in argmin, but they are specified individually in EvalSchemeSwitchingKeyGen
    m_dim1FC = (params.GetBStepLTrFHEWtoCKKS() == 0) ? getRatioBSGSLT(m_numCtxts) : params.GetBStepLTrFHEWtoCKKS();
    m_LFC    = params.GetLevelLTrFHEWtoCKKS();

    return privateKeyFHEW;
}
/**
 * Generates all keys needed for two-way scheme switching (CKKS -> FHEW and FHEW -> CKKS). When m_argmin is set,
 * it also generates the extra rotation keys used by the argmin routines.
 *
 * Side effects visible here:
 *  - sets m_ctxtKS, m_CKKStoFHEWswk and m_FHEWtoCKKSswk;
 *  - generates the CKKS multiplication key (and the EvalSum keys when m_argmin is set and m_oneHot is false)
 *    through the CKKS crypto context;
 *  - generates the FHEW bootstrapping keys via m_ccLWE->BTKeyGen.
 *
 * Relies on state stored by EvalSchemeSwitchingSetup / EvalCKKStoFHEWSetup: m_numSlotsCKKS, m_numCtxts, m_dim1CF,
 * m_dim1FC, m_argmin, m_alt, m_oneHot, m_ccKS and m_ccLWE.
 *
 * @param keyPair CKKS key pair
 * @param lwesk   FHEW/LWE secret key
 * @return map of CKKS rotation keys, plus the conjugation key at index M - 1
 * @throws if the CKKS key switching technique is not HYBRID, or (128-bit native integers only) if the scaling
 *         technique is FLEXIBLEAUTO or FLEXIBLEAUTOEXT
 */
std::shared_ptr<std::map<uint32_t, EvalKey<DCRTPoly>>> SWITCHCKKSRNS::EvalSchemeSwitchingKeyGen(
const KeyPair<DCRTPoly>& keyPair, ConstLWEPrivateKey& lwesk) {
auto& privateKey = keyPair.secretKey;
auto& publicKey = keyPair.publicKey;
const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(privateKey->GetCryptoParameters());
if (cryptoParams->GetKeySwitchTechnique() != HYBRID)
OPENFHE_THROW("CKKS to FHEW scheme switching is only supported for the Hybrid key switching method.");
#if NATIVEINT == 128
if (cryptoParams->GetScalingTechnique() == FLEXIBLEAUTO || cryptoParams->GetScalingTechnique() == FLEXIBLEAUTOEXT)
OPENFHE_THROW("128-bit CKKS to FHEW scheme switching is supported for FIXEDMANUAL and FIXEDAUTO methods only.");
#endif
auto ccCKKS = privateKey->GetCryptoContext();
uint32_t M = ccCKKS->GetCyclotomicOrder();
uint32_t slots = m_numSlotsCKKS;
uint32_t n = lwesk->GetElement().GetLength();
uint32_t ringDim = ccCKKS->GetRingDimension();
// Intermediate cryptocontext for CKKS to FHEW: encryption of zero that is later used as the container for the
// modulus-switched ciphertext
auto keys2 = m_ccKS->KeyGen();
Plaintext ptxtZeroKS = m_ccKS->MakeCKKSPackedPlaintext(std::vector<double>{0.0}, 1, 0, nullptr, slots);
m_ctxtKS = m_ccKS->Encrypt(keys2.publicKey, ptxtZeroKS);
// Compute switching key between RLWE and LWE via the intermediate cryptocontext, keep it in RLWE form
m_CKKStoFHEWswk = switchingKeyGenRLWEcc(keys2.secretKey, privateKey, lwesk);
// Generate FHEW to CKKS switching key, i.e., CKKS encryption of FHEW secret key. Pad up to the closest power of two
uint32_t n_po2 = 1 << static_cast<uint32_t>(std::ceil(std::log2(n)));
// NOTE(review): this copies the whole secret-key vector; a const reference would suffice.
auto skLWEElements = lwesk->GetElement();
// LWE secret coefficients equal to (modulus - 1) stand for -1 and are mapped to -1.0; entries past n stay 0.
const auto neg = lwesk->GetModulus().ConvertToDouble() - 1.0;
std::vector<std::complex<double>> skLWEDouble(n_po2);
for (uint32_t i = 0; i < n; i++) {
auto tmp = skLWEElements[i].ConvertToDouble();
skLWEDouble[i] = std::complex<double>(tmp == neg ? -1.0 : tmp, 0);
}
// Check encoding and specify the number of slots, otherwise, if batchsize is set and is smaller, it will throw an error.
Plaintext skLWEPlainswk;
if (cryptoParams->GetScalingTechnique() == FLEXIBLEAUTOEXT)
skLWEPlainswk = ccCKKS->MakeCKKSPackedPlaintext(Fill(skLWEDouble, ringDim / 2), 1, BASE_NUM_LEVELS_TO_DROP,
nullptr, ringDim / 2);
else
skLWEPlainswk = ccCKKS->MakeCKKSPackedPlaintext(Fill(skLWEDouble, ringDim / 2), 1, 0, nullptr, ringDim / 2);
m_FHEWtoCKKSswk = ccCKKS->Encrypt(publicKey, skLWEPlainswk);
// Compute automorphism keys
/* CKKS to FHEW */
// Compute indices for rotations for slotToCoeff transform
std::vector<int32_t> indexRotationS2C = FindLTRotationIndicesSwitch(m_dim1CF, M, slots);
indexRotationS2C.emplace_back(static_cast<int32_t>(slots));
// Compute indices for rotations for sparse packing: +2^k below ringDim/2, and -2^k up to slots
if (ringDim > 2 * slots) { // if the encoding is full, this does not execute
indexRotationS2C.reserve(indexRotationS2C.size() + GetMSB(ringDim) - 2 + GetMSB(slots) - 1);
for (uint32_t i = 1; i < ringDim / 2; i <<= 1) {
indexRotationS2C.emplace_back(static_cast<int32_t>(i));
if (i <= slots)
indexRotationS2C.emplace_back(-static_cast<int32_t>(i));
}
}
/* FHEW to CKKS */
std::vector<int32_t> indexRotationHomDec;
std::vector<int32_t> indexRotationArgmin;
if (!m_argmin || (m_argmin && m_alt)) {
// Compute indices for rotations for homomorphic decryption
indexRotationHomDec = FindLTRotationIndicesSwitch(m_dim1FC, M, m_numCtxts);
// If the linear transform is wide instead of tall, we need extra rotations
// NOTE(review): if m_numCtxts were 0 the condition below holds and n_po2 / m_numCtxts divides by zero;
// m_numCtxts is expected to be positive after setup — confirm.
if (m_numCtxts < n_po2) {
uint32_t logl = lbcrypto::GetMSB(n_po2 / m_numCtxts) - 1; // These are powers of two, so log(l) is integer
indexRotationHomDec.reserve(indexRotationHomDec.size() + logl);
for (size_t j = 1; j <= logl; ++j) {
indexRotationHomDec.emplace_back(m_numCtxts * (1 << (j - 1)));
}
}
if (m_argmin) {
// Rotations for postprocessing after a level of the binary tree
// NOTE(review): GetMSB(m_numCtxts) - 2 is unsigned and wraps around when m_numCtxts == 1 (GetMSB(1) is
// presumably 1), requesting a huge reservation; the loop also pushes GetMSB(m_numCtxts) - 1 entries for a
// power-of-two m_numCtxts — confirm the intended size.
indexRotationArgmin.reserve(GetMSB(m_numCtxts) - 2);
for (uint32_t i = 1; i < m_numCtxts; i <<= 1) {
indexRotationArgmin.emplace_back(static_cast<int32_t>(m_numCtxts / (2 * i)));
}
}
}
else { // argmin not in the alternative mode
// Compute indices for rotations for all homomorphic decryptions for the levels of the tree
indexRotationHomDec = FindLTRotationIndicesSwitchArgmin(M, m_numCtxts, n_po2);
// Rotations for postprocessing after a level of the binary tree
indexRotationArgmin.reserve(GetMSB(m_numCtxts) - 1 + 2 * (GetMSB(m_numCtxts) - 1));
for (uint32_t i = 1; i < 2 * m_numCtxts; i <<= 1) {
indexRotationArgmin.emplace_back(static_cast<int32_t>(m_numCtxts / (2 * i)));
indexRotationArgmin.emplace_back(-static_cast<int32_t>(m_numCtxts / (2 * i)));
if (i > 1) {
for (uint32_t j = 2 * m_numCtxts / i; j < 2 * m_numCtxts; j <<= 1)
indexRotationArgmin.emplace_back(-static_cast<int32_t>(j));
}
}
}
// Compute indices for rotations to bring back the final CKKS ciphertext encoding to slots
if (ringDim > 2 * slots) { // if the encoding is full, this does not execute
indexRotationHomDec.reserve(indexRotationHomDec.size() + GetMSB(ringDim) - 2);
for (uint32_t j = 1; j < ringDim / (2 * slots); j <<= 1) {
indexRotationHomDec.emplace_back(j * slots);
}
}
// Combine the indices lists
indexRotationS2C.reserve(indexRotationS2C.size() + indexRotationHomDec.size() + indexRotationArgmin.size());
indexRotationS2C.insert(indexRotationS2C.end(), indexRotationHomDec.begin(), indexRotationHomDec.end());
indexRotationS2C.insert(indexRotationS2C.end(), indexRotationArgmin.begin(), indexRotationArgmin.end());
// Remove possible duplicates and zero
sort(indexRotationS2C.begin(), indexRotationS2C.end());
indexRotationS2C.erase(unique(indexRotationS2C.begin(), indexRotationS2C.end()), indexRotationS2C.end());
indexRotationS2C.erase(std::remove(indexRotationS2C.begin(), indexRotationS2C.end(), 0), indexRotationS2C.end());
auto algo = ccCKKS->GetScheme();
auto evalKeys = algo->EvalAtIndexKeyGen(privateKey, indexRotationS2C);
// Compute conjugation key
auto conjKey = FHECKKSRNS::ConjugateKeyGen(privateKey);
(*evalKeys)[M - 1] = conjKey;
// Compute multiplication key
ccCKKS->EvalMultKeyGen(privateKey);
// Compute automorphism keys if we don't want one hot encoding for argmin
if (m_argmin && (m_oneHot == false)) {
ccCKKS->EvalSumKeyGen(privateKey);
}
/* FHEW computations */
// Generate the bootstrapping keys (refresh and switching keys)
m_ccLWE->BTKeyGen(lwesk);
return evalKeys;
}
/**
 * Redoes the CKKS-to-FHEW decoding precomputation with the scaling needed for comparisons.
 *
 * The decoding scale is scaleSign, divided by pLWE unless pLWE is 0 or the messages are already on the unit
 * interval (unit == true). Without that division, the implicit FHEW plaintext modulus will be
 * m_modulus_CKKS_initial / scFactor.
 *
 * @param ccCKKS    CKKS crypto context
 * @param pLWE      FHEW plaintext modulus (0 disables the division)
 * @param scaleSign extra scaling applied before the sign evaluation
 * @param unit      true if the messages are already scaled to the unit interval
 */
void SWITCHCKKSRNS::EvalCompareSwitchPrecompute(const CryptoContextImpl<DCRTPoly>& ccCKKS, uint32_t pLWE,
                                                double scaleSign, bool unit) {
    const bool divideByPLWE = (pLWE != 0) && !unit;
    const double decodingScale = divideByPLWE ? (1.0 / pLWE) * scaleSign : scaleSign;
    EvalCKKStoFHEWPrecompute(ccCKKS, decodingScale);
}
/**
 * Compares two CKKS ciphertexts slot-wise via scheme switching.
 *
 * The difference ciphertext1 - ciphertext2 is switched to FHEW, EvalSign is evaluated on every LWE ciphertext, and
 * the signs are switched back to CKKS with EvalFHEWtoCKKS(p = 4, pmin = -1, pmax = 1).
 *
 * @param ciphertext1 first operand
 * @param ciphertext2 second operand
 * @param numCtxts    number of values to compare (forwarded to EvalCKKStoFHEW / EvalFHEWtoCKKS)
 * @param numSlots    number of CKKS slots of the output (0 -> m_numSlotsCKKS)
 * @param pLWE        FHEW plaintext modulus; if non-zero, the decoding precomputation is redone for it
 * @param scaleSign   extra scaling applied before the sign evaluation
 * @param unit        if true, the difference is first divided by pLWE (requires pLWE != 0)
 * @return CKKS ciphertext with the switched-back signs
 * @throws if unit is true and pLWE is 0
 */
Ciphertext<DCRTPoly> SWITCHCKKSRNS::EvalCompareSchemeSwitching(ConstCiphertext<DCRTPoly> ciphertext1,
                                                              ConstCiphertext<DCRTPoly> ciphertext2, uint32_t numCtxts,
                                                              uint32_t numSlots, uint32_t pLWE, double scaleSign,
                                                              bool unit) {
    // Validate before doing any homomorphic work.
    if (unit && pLWE == 0)
        OPENFHE_THROW("To scale to the unit circle, pLWE must be non-zero.");

    auto ccCKKS = ciphertext1->GetCryptoContext();
    auto cDiff  = ccCKKS->EvalSub(ciphertext1, ciphertext2);
    if (unit) {
        cDiff = ccCKKS->EvalMult(cDiff, 1.0 / static_cast<double>(pLWE));
        cDiff = ccCKKS->Rescale(cDiff);
    }

    // The precomputation has already been performed, but if it is scaled differently than desired, recompute it.
    // EvalCompareSwitchPrecompute applies exactly the scale previously duplicated here.
    if (pLWE != 0)
        EvalCompareSwitchPrecompute(*ccCKKS, pLWE, scaleSign, unit);

    auto LWECiphertexts = EvalCKKStoFHEW(cDiff, numCtxts);

    const uint32_t n = LWECiphertexts.size();
    std::vector<LWECiphertext> cSigns(n);
#pragma omp parallel for num_threads(OpenFHEParallelControls.GetThreadLimit(n))
    for (uint32_t i = 0; i < n; ++i)
        cSigns[i] = m_ccLWE->EvalSign(LWECiphertexts[i], true);

    return EvalFHEWtoCKKS(cSigns, numCtxts, numSlots, 4, -1.0, 1.0, 0);
}
/**
 * Tournament-style reduction over the first numValues slots of a CKKS ciphertext, using scheme switching for the
 * comparisons. As the name indicates, it is intended to compute the minimum and its position — confirm.
 *
 * At each tree level with half-width n:
 *  - the first n values are compared with the next n: their difference is switched to FHEW, EvalSign is applied,
 *    and the signs are switched back to CKKS;
 *  - the result builds a selection mask cSelect;
 *  - cSelect is multiplied into both the value ciphertext and the index/indicator ciphertext.
 *
 * @param ciphertext CKKS ciphertext holding the values
 * @param publicKey  CKKS public key used to encrypt the initial index/indicator vector
 * @param numValues  number of values; assumed to be a power of two
 * @param numSlots   number of CKKS slots (0 -> m_numSlotsCKKS)
 * @param pLWE       if non-zero, the CKKS-to-FHEW precomputation is redone with scale scaleSign / pLWE
 * @param scaleSign  scaling applied before the sign evaluation
 * @return {newCiphertext, cInd}. cInd holds a one-hot indicator if m_oneHot is set; otherwise it is reduced with
 *         EvalSum over numValues slots.
 */
std::vector<Ciphertext<DCRTPoly>> SWITCHCKKSRNS::EvalMinSchemeSwitching(ConstCiphertext<DCRTPoly> ciphertext,
PublicKey<DCRTPoly> publicKey,
uint32_t numValues, uint32_t numSlots,
uint32_t pLWE, double scaleSign) {
auto cc = ciphertext->GetCryptoContext();
const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersRNS>(ciphertext->GetCryptoParameters());
// The precomputation has already been performed, but if it is scaled differently than desired, recompute it
if (pLWE != 0)
EvalCKKStoFHEWPrecompute(*cc, scaleSign / pLWE);
// Hard-coded level at which the indicator is encoded so that it matches the output level of EvalFHEWtoCKKS.
// NOTE(review): depends on the depth consumed by EvalFHEWtoCKKS — confirm if that changes.
uint32_t towersToDrop = 12; // How many levels are consumed in the EvalFHEWtoCKKS
uint32_t slots = (numSlots == 0) ? m_numSlotsCKKS : numSlots;
// Initial indicator: all ones (one-hot mode) or the indices 0..numValues-1
Plaintext pInd;
if (m_oneHot) {
std::vector<std::complex<double>> ind(numValues, 1.0);
pInd = cc->MakeCKKSPackedPlaintext(ind, 1, towersToDrop, nullptr, slots);
}
else {
std::vector<std::complex<double>> ind(numValues);
std::iota(ind.begin(), ind.end(), 0);
pInd = cc->MakeCKKSPackedPlaintext(ind, 1, towersToDrop, nullptr, slots);
}
auto cInd = cc->Encrypt(publicKey, pInd);
auto newCiphertext = ciphertext->Clone();
// M is the tree level counter (1, 2, 4, ...); n = numValues / (2 * M) is the number of comparisons at this level
for (uint32_t M = 1; M < numValues; M <<= 1) {
const auto n = numValues / (2 * M);
// Compute CKKS ciphertext encoding difference of the first numValues
auto cDiff = cc->EvalSub(newCiphertext, cc->EvalAtIndex(newCiphertext, n));
// Transform the ciphertext from CKKS to FHEW
auto cTemp = EvalCKKStoFHEW(cDiff, n);
// Evaluate the sign
// We always assume for the moment that numValues is a power of 2
std::vector<LWECiphertext> LWESign(n);
#pragma omp parallel for num_threads(OpenFHEParallelControls.GetThreadLimit(n))
for (uint32_t j = 0; j < n; ++j)
LWESign[j] = m_ccLWE->EvalSign(cTemp[j], true);
// Scheme switching from FHEW to CKKS
auto dim1 = getRatioBSGSLT(n);
auto cSelect = EvalFHEWtoCKKS(LWESign, n, numSlots, 4, -1.0, 1.0, dim1);
// Append the complement (1 - cSelect), rotated right by n, to obtain the mask for the second half
std::vector<std::complex<double>> ones(n, 1.0);
Plaintext ptxtOnes = cc->MakeCKKSPackedPlaintext(ones, 1, 0, nullptr, slots);
cc->EvalAddInPlace(cSelect, cc->EvalAtIndex(cc->EvalSub(ptxtOnes, cSelect), -static_cast<int32_t>(n)));
// Replicate the mask across the remaining numValues slots
if (M > 1) {
for (uint32_t j = numValues / M; j < numValues; j <<= 1)
cc->EvalAddInPlace(cSelect, cc->EvalAtIndex(cSelect, -static_cast<int32_t>(j)));
}
// Update the ciphertext of values and the indicator
newCiphertext = cc->EvalMult(newCiphertext, cSelect);
cc->EvalAddInPlace(newCiphertext, cc->EvalAtIndex(newCiphertext, n));
if (cryptoParams->GetScalingTechnique() == FIXEDMANUAL)
cc->ModReduceInPlace(newCiphertext);
cInd = cc->EvalMult(cInd, cSelect);
if (cryptoParams->GetScalingTechnique() == FIXEDMANUAL)
cc->ModReduceInPlace(cInd);
}
// After computing the minimum and argument
if (!m_oneHot)
cInd = cc->EvalSum(cInd, numValues);
return std::vector<Ciphertext<DCRTPoly>>{newCiphertext, cInd};
}
/**
 * Alternative minimum/argmin over the first numValues CKKS slots using CKKS -> FHEW -> CKKS scheme switching.
 *
 * Each round compares the value in slot j (j < n) with the value n slots further along (via EvalAtIndex),
 * evaluates the sign of each difference in FHEW, and switches back to CKKS. Unlike the rotation-based
 * variant, the selector is expanded to all numValues entries on the LWE side: every sign ciphertext and
 * its "negated" copy are placed at the required positions before switching back, so no CKKS rotations
 * are needed to build the selector.
 *
 * @param ciphertext CKKS ciphertext holding the candidate values.
 * @param publicKey  key used to encrypt the initial indicator vector.
 * @param numValues  number of values to compare; the code assumes a power of two.
 * @param numSlots   CKKS slot count; when 0, m_numSlotsCKKS is used for the indicator plaintext
 *                   (numSlots itself is forwarded unchanged to EvalFHEWtoCKKS).
 * @param pLWE       when nonzero, EvalCKKStoFHEWPrecompute is re-run with scale scaleSign / pLWE.
 * @param scaleSign  numerator of that re-precomputation scale.
 * @return {value ciphertext, indicator ciphertext}. The value ciphertext is intended to hold the minimum in
 *         its first slot. With m_oneHot the indicator is the product of the selection masks; otherwise it
 *         is EvalSum'ed over numValues slots to produce the (intended) argmin index.
 */
std::vector<Ciphertext<DCRTPoly>> SWITCHCKKSRNS::EvalMinSchemeSwitchingAlt(ConstCiphertext<DCRTPoly> ciphertext,
                                                                           PublicKey<DCRTPoly> publicKey,
                                                                           uint32_t numValues, uint32_t numSlots,
                                                                           uint32_t pLWE, double scaleSign) {
    auto cc                 = ciphertext->GetCryptoContext();
    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersRNS>(ciphertext->GetCryptoParameters());

    // The precomputation already exists; redo it only when a different scaling was requested.
    if (pLWE != 0) {
        EvalCKKStoFHEWPrecompute(*cc, scaleSign / pLWE);
    }

    // Level at which the indicator is encoded: the number of levels EvalFHEWtoCKKS consumes for binary FHEW output.
    const uint32_t towersToDrop = 12;
    const uint32_t packedSlots  = (numSlots != 0) ? numSlots : m_numSlotsCKKS;

    // Indicator starts as all ones (one-hot mode) or as the slot indices 0, 1, ..., numValues - 1.
    std::vector<std::complex<double>> indicator(numValues, 1.0);
    if (!m_oneHot) {
        std::iota(indicator.begin(), indicator.end(), 0);
    }
    Plaintext indicatorPtxt = cc->MakeCKKSPackedPlaintext(indicator, 1, towersToDrop, nullptr, packedSlots);
    auto indicatorCt        = cc->Encrypt(publicKey, indicatorPtxt);

    // Explicit rescaling is performed only under the FIXEDMANUAL scaling technique.
    auto rescaleIfManual = [&](auto& ct) {
        if (cryptoParams->GetScalingTechnique() == FIXEDMANUAL) {
            cc->ModReduceInPlace(ct);
        }
    };

    auto current = ciphertext->Clone();
    for (uint32_t blocks = 1; blocks < numValues; blocks <<= 1) {
        // Number of pairwise comparisons in this round; numValues is split into `blocks` blocks of 2n entries.
        const uint32_t n = numValues / (2 * blocks);

        // Slot j now holds (value_j - value_{j+n}) for the first n slots.
        auto shiftedCt = cc->EvalAtIndex(current, n);
        auto diffCt    = cc->EvalSub(current, shiftedCt);
        auto lweDiffs  = EvalCKKStoFHEW(diffCt, n);

        // Within every 2n-entry block, the first n entries receive the sign and the next n its "negation".
        // numValues is assumed to be a power of two.
        std::vector<LWECiphertext> selectorLWE(numValues);
#pragma omp parallel for num_threads(OpenFHEParallelControls.GetThreadLimit(n))
        for (uint32_t j = 0; j < n; ++j) {
            LWECiphertext signCt    = m_ccLWE->EvalSign(lweDiffs[j], true);
            LWECiphertext negSignCt = std::make_shared<LWECiphertextImpl>(*signCt);
            // Adding half of the LWE modulus is what produces the "negated" sign here.
            m_ccLWE->GetLWEScheme()->EvalAddConstEq(negSignCt, negSignCt->GetModulus() >> 1);
            for (uint32_t b = 0; b < blocks; ++b) {
                const uint32_t pos   = 2 * b * n + j;
                selectorLWE[pos]     = signCt;
                selectorLWE[pos + n] = negSignCt;
            }
        }

        // Switch the full-width selector back to CKKS.
        auto dim1     = getRatioBSGSLT(numValues);
        auto selectCt = EvalFHEWtoCKKS(selectorLWE, numValues, numSlots, 4, -1.0, 1.0, dim1);

        // Mask the values, then fold slot j + n onto slot j.
        current = cc->EvalMult(current, selectCt);
        cc->EvalAddInPlace(current, cc->EvalAtIndex(current, n));
        rescaleIfManual(current);

        // Apply the same mask to the indicator.
        indicatorCt = cc->EvalMult(indicatorCt, selectCt);
        rescaleIfManual(indicatorCt);
    }

    // Minimum and its selection mask are done; outside one-hot mode collapse the mask-weighted indices.
    if (!m_oneHot) {
        indicatorCt = cc->EvalSum(indicatorCt, numValues);
    }
    return std::vector<Ciphertext<DCRTPoly>>{current, indicatorCt};
}
// TODO: Determine whether this function is used anywhere.
/**
 * Maximum/argmax over the first numValues CKKS slots using CKKS -> FHEW -> CKKS scheme switching.
 *
 * Each round compares the value in slot j (j < n) with the value n slots further along (via EvalAtIndex),
 * evaluates the n signs in FHEW, and switches them back to CKKS. The selector is then completed with CKKS
 * operations: the complement (ones - signs) is combined with the signs rotated by n, and the result is
 * replicated across numValues slots by rotations before masking the values and the indicator.
 *
 * @param ciphertext CKKS ciphertext holding the candidate values.
 * @param publicKey  key used to encrypt the initial indicator vector.
 * @param numValues  number of values to compare; the code assumes a power of two.
 * @param numSlots   CKKS slot count; when 0, m_numSlotsCKKS is used for the plaintexts built here
 *                   (numSlots itself is forwarded unchanged to EvalFHEWtoCKKS).
 * @param pLWE       when nonzero, EvalCKKStoFHEWPrecompute is re-run with scale scaleSign / pLWE.
 * @param scaleSign  numerator of that re-precomputation scale.
 * @return {value ciphertext, indicator ciphertext}. The value ciphertext is intended to hold the maximum in
 *         its first slot. With m_oneHot the indicator is the product of the selection masks; otherwise it
 *         is EvalSum'ed over numValues slots to produce the (intended) argmax index.
 */
std::vector<Ciphertext<DCRTPoly>> SWITCHCKKSRNS::EvalMaxSchemeSwitching(ConstCiphertext<DCRTPoly> ciphertext,
                                                                        PublicKey<DCRTPoly> publicKey,
                                                                        uint32_t numValues, uint32_t numSlots,
                                                                        uint32_t pLWE, double scaleSign) {
    auto cc                 = ciphertext->GetCryptoContext();
    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersRNS>(ciphertext->GetCryptoParameters());

    // The precomputation already exists; redo it only when a different scaling was requested.
    if (pLWE != 0) {
        EvalCKKStoFHEWPrecompute(*cc, scaleSign / pLWE);
    }

    // Level at which the indicator is encoded: the number of levels EvalFHEWtoCKKS consumes for binary FHEW output.
    const uint32_t towersToDrop = 12;
    const uint32_t packedSlots  = (numSlots != 0) ? numSlots : m_numSlotsCKKS;

    // Indicator starts as all ones (one-hot mode) or as the slot indices 0, 1, ..., numValues - 1.
    std::vector<std::complex<double>> indicator(numValues, 1.0);
    if (!m_oneHot) {
        std::iota(indicator.begin(), indicator.end(), 0);
    }
    Plaintext indicatorPtxt = cc->MakeCKKSPackedPlaintext(indicator, 1, towersToDrop, nullptr, packedSlots);
    auto indicatorCt        = cc->Encrypt(publicKey, indicatorPtxt);

    // Explicit rescaling is performed only under the FIXEDMANUAL scaling technique.
    auto rescaleIfManual = [&](auto& ct) {
        if (cryptoParams->GetScalingTechnique() == FIXEDMANUAL) {
            cc->ModReduceInPlace(ct);
        }
    };

    auto current = ciphertext->Clone();
    for (uint32_t blocks = 1; blocks < numValues; blocks <<= 1) {
        // Number of pairwise comparisons in this round.
        const uint32_t n = numValues / (2 * blocks);

        // Slot j now holds (value_j - value_{j+n}) for the first n slots.
        auto shiftedCt = cc->EvalAtIndex(current, n);
        auto diffCt    = cc->EvalSub(current, shiftedCt);
        auto lweDiffs  = EvalCKKStoFHEW(diffCt, n);

        // One FHEW sign per comparison; numValues is assumed to be a power of two.
        std::vector<LWECiphertext> signsLWE(n);
#pragma omp parallel for num_threads(OpenFHEParallelControls.GetThreadLimit(n))
        for (uint32_t j = 0; j < n; ++j) {
            signsLWE[j] = m_ccLWE->EvalSign(lweDiffs[j], true);
        }

        auto dim1     = getRatioBSGSLT(n);
        auto selectCt = EvalFHEWtoCKKS(signsLWE, n, numSlots, 4, -1.0, 1.0, dim1);

        // Max selector: (ones - signs) plus the signs rotated right by n.
        std::vector<std::complex<double>> ones(n, 1.0);
        Plaintext onesPtxt  = cc->MakeCKKSPackedPlaintext(ones, 1, 0, nullptr, packedSlots);
        auto complementCt   = cc->EvalSub(onesPtxt, selectCt);
        auto shiftedSelect  = cc->EvalAtIndex(selectCt, -static_cast<int32_t>(n));
        selectCt            = cc->EvalAdd(complementCt, shiftedSelect);

        // Replicate the selector pattern across numValues slots (the loop body never runs when blocks == 1).
        for (uint32_t shift = numValues / blocks; shift < numValues; shift <<= 1) {
            cc->EvalAddInPlace(selectCt, cc->EvalAtIndex(selectCt, -static_cast<int32_t>(shift)));
        }

        // Mask the values, then fold slot j + n onto slot j.
        current = cc->EvalMult(current, selectCt);
        cc->EvalAddInPlace(current, cc->EvalAtIndex(current, n));
        rescaleIfManual(current);

        // Apply the same mask to the indicator.
        indicatorCt = cc->EvalMult(indicatorCt, selectCt);
        rescaleIfManual(indicatorCt);
    }

    // Maximum and its selection mask are done; outside one-hot mode collapse the mask-weighted indices.
    if (!m_oneHot) {
        indicatorCt = cc->EvalSum(indicatorCt, numValues);
    }
    return std::vector<Ciphertext<DCRTPoly>>{current, indicatorCt};
}
// TODO: Determine whether this function is used anywhere.
/**
 * Alternative maximum/argmax over the first numValues CKKS slots using CKKS -> FHEW -> CKKS scheme switching.
 *
 * Each round compares the value in slot j (j < n) with the value n slots further along (via EvalAtIndex),
 * evaluates the sign of each difference in FHEW, and switches back to CKKS. The selector is expanded to all
 * numValues entries on the LWE side: within every 2n-entry block, the "negated" sign fills the first n
 * entries and the sign fills the next n. The placement is mirrored relative to the Alt minimum, and no CKKS
 * rotations are needed to build the selector.
 *
 * @param ciphertext CKKS ciphertext holding the candidate values.
 * @param publicKey  key used to encrypt the initial indicator vector.
 * @param numValues  number of values to compare; the code assumes a power of two.
 * @param numSlots   CKKS slot count; when 0, m_numSlotsCKKS is used for the indicator plaintext
 *                   (numSlots itself is forwarded unchanged to EvalFHEWtoCKKS).
 * @param pLWE       when nonzero, EvalCKKStoFHEWPrecompute is re-run with scale scaleSign / pLWE.
 * @param scaleSign  numerator of that re-precomputation scale.
 * @return {value ciphertext, indicator ciphertext}. The value ciphertext is intended to hold the maximum in
 *         its first slot. With m_oneHot the indicator is the product of the selection masks; otherwise it
 *         is EvalSum'ed over numValues slots to produce the (intended) argmax index.
 */
std::vector<Ciphertext<DCRTPoly>> SWITCHCKKSRNS::EvalMaxSchemeSwitchingAlt(ConstCiphertext<DCRTPoly> ciphertext,
                                                                           PublicKey<DCRTPoly> publicKey,
                                                                           uint32_t numValues, uint32_t numSlots,
                                                                           uint32_t pLWE, double scaleSign) {
    auto cc                 = ciphertext->GetCryptoContext();
    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersRNS>(ciphertext->GetCryptoParameters());

    // The precomputation already exists; redo it only when a different scaling was requested.
    if (pLWE != 0) {
        EvalCKKStoFHEWPrecompute(*cc, scaleSign / pLWE);
    }

    // Level at which the indicator is encoded: the number of levels EvalFHEWtoCKKS consumes for binary FHEW output.
    const uint32_t towersToDrop = 12;
    const uint32_t packedSlots  = (numSlots != 0) ? numSlots : m_numSlotsCKKS;

    // Indicator starts as all ones (one-hot mode) or as the slot indices 0, 1, ..., numValues - 1.
    std::vector<std::complex<double>> indicator(numValues, 1.0);
    if (!m_oneHot) {
        std::iota(indicator.begin(), indicator.end(), 0);
    }
    Plaintext indicatorPtxt = cc->MakeCKKSPackedPlaintext(indicator, 1, towersToDrop, nullptr, packedSlots);
    auto indicatorCt        = cc->Encrypt(publicKey, indicatorPtxt);

    // Explicit rescaling is performed only under the FIXEDMANUAL scaling technique.
    auto rescaleIfManual = [&](auto& ct) {
        if (cryptoParams->GetScalingTechnique() == FIXEDMANUAL) {
            cc->ModReduceInPlace(ct);
        }
    };

    auto current = ciphertext->Clone();
    for (uint32_t blocks = 1; blocks < numValues; blocks <<= 1) {
        // Number of pairwise comparisons in this round; numValues is split into `blocks` blocks of 2n entries.
        const uint32_t n = numValues / (2 * blocks);

        // Slot j now holds (value_j - value_{j+n}) for the first n slots.
        auto shiftedCt = cc->EvalAtIndex(current, n);
        auto diffCt    = cc->EvalSub(current, shiftedCt);
        auto lweDiffs  = EvalCKKStoFHEW(diffCt, n);

        // Within every 2n-entry block, the first n entries receive the "negated" sign and the next n the sign.
        // numValues is assumed to be a power of two.
        std::vector<LWECiphertext> selectorLWE(numValues);
#pragma omp parallel for num_threads(OpenFHEParallelControls.GetThreadLimit(n))
        for (uint32_t j = 0; j < n; ++j) {
            LWECiphertext signCt    = m_ccLWE->EvalSign(lweDiffs[j], true);
            LWECiphertext negSignCt = std::make_shared<LWECiphertextImpl>(*signCt);
            // Adding half of the LWE modulus is what produces the "negated" sign here.
            m_ccLWE->GetLWEScheme()->EvalAddConstEq(negSignCt, negSignCt->GetModulus() >> 1);
            for (uint32_t b = 0; b < blocks; ++b) {
                const uint32_t pos   = 2 * b * n + j;
                selectorLWE[pos]     = negSignCt;
                selectorLWE[pos + n] = signCt;
            }
        }

        // Switch the full-width selector back to CKKS.
        auto dim1     = getRatioBSGSLT(numValues);
        auto selectCt = EvalFHEWtoCKKS(selectorLWE, numValues, numSlots, 4, -1.0, 1.0, dim1);

        // Mask the values, then fold slot j + n onto slot j.
        current = cc->EvalMult(current, selectCt);
        cc->EvalAddInPlace(current, cc->EvalAtIndex(current, n));
        rescaleIfManual(current);

        // Apply the same mask to the indicator.
        indicatorCt = cc->EvalMult(indicatorCt, selectCt);
        rescaleIfManual(indicatorCt);
    }

    // Maximum and its selection mask are done; outside one-hot mode collapse the mask-weighted indices.
    if (!m_oneHot) {
        indicatorCt = cc->EvalSum(indicatorCt, numValues);
    }
    return std::vector<Ciphertext<DCRTPoly>>{current, indicatorCt};
}
} // namespace lbcrypto