Program Listing for File ckksrns-schemeswitching.cpp

Return to documentation for file (pke/lib/scheme/ckksrns/ckksrns-schemeswitching.cpp)

//==================================================================================
// BSD 2-Clause License
//
// Copyright (c) 2014-2022, NJIT, Duality Technologies Inc. and other contributors
//
// All rights reserved.
//
// Author TPOC: contact@openfhe.org
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
//    list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
//    this list of conditions and the following disclaimer in the documentation
//    and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//==================================================================================

/*
    CKKS to FHEW scheme switching implementation.
 */

#define PROFILE

#include "scheme/ckksrns/ckksrns-schemeswitching.h"

#include "cryptocontext.h"
#include "gen-cryptocontext.h"
#include "scheme/ckksrns/gen-cryptocontext-ckksrns.h"

#include "math/dftransform.h"

#include <iterator>

// Coefficient table g_coefficientsFHEW16, labelled "K = 16" by the original authors.
// NOTE(review): presumably polynomial-approximation coefficients consumed by the CKKS -> FHEW
// switching code later in this file; the consumer is not visible in this section - confirm the
// expected ordering and basis there before editing any value.
// The trailing entries are on the order of 1e-15 or smaller.
static constexpr std::initializer_list<double> g_coefficientsFHEW16{
    0.2455457340168511,     -0.04791906488334782,   0.2838870204084082,     -0.02994453873551349,
    0.3557652261903648,     0.01510656188507299,    0.2953294667450001,     0.07120360233373937,
    -0.1034734733966807,    0.04499759051255525,    -0.4275071243192574,    -0.09034212972909554,
    0.367628762693249,      0.04931806603933471,    -0.14535986272412,      -0.01510693848306369,
    0.03595193549924024,    0.003103658218868759,   -0.006264460660707066,  -0.0004660943047712052,
    0.0008212879885240095,  5.391053389217882e-05,  -8.455154976914221e-05, -4.977380178602518e-06,
    7.04666204400644e-06,   3.765980757166835e-07,  -4.864851013612591e-07, -2.383026746930811e-08,
    2.832970640938316e-08,  1.28177429687131e-09,   -1.412145521045378e-09, -5.939145408994641e-11,
    6.099273252183116e-11,  2.397381728164642e-12,  -2.307402856353623e-12, -8.500921247536622e-14,
    7.704571444110577e-14,  2.704051671841271e-15,  -2.154585361348821e-15, -7.263493008564584e-16,
    -1.260761739828568e-16, 1.637108527837095e-16,  1.185492382226862e-16,  6.379078056744543e-16,
    -1.411300455031979e-16, 3.123678340470779e-16,  5.946279250534737e-16,  2.954322285866942e-16,
    -8.279629336187608e-17, 5.024229619913844e-16,  -3.293034395074617e-16, -1.189255850106947e-15,
    1.674743206637948e-16,  -1.524204491434537e-16, -7.90328254817908e-17,  3.95164127408954e-16,
    -1.317213758029847e-17, 8.016186584581639e-16,  -3.650563843682718e-16, 3.763467880085276e-16,
    -2.709696873661399e-16, -1.524204491434537e-16, -4.83605622590958e-16,  7.376397044967142e-16,
    1.234417464667971e-15,  -2.672062194860546e-16, -4.892508244110859e-17, -7.122362963061386e-16,
    3.763467880085276e-17,  -2.944913616166729e-16, -2.897870267665663e-16, 9.794425157921932e-16,
    -3.198947698072485e-17, 6.614294799249873e-16,  -5.7769231959309e-16,   6.586068790149234e-16,
    -4.629065492504889e-16, -5.127724986616189e-16, -3.236582376873337e-16, -1.64745806450733e-15,
    -9.408669700213192e-16, -4.986594941112991e-16, -1.209954923447416e-15, -1.373665776231126e-16,
    -2.314532746252445e-16, 3.217765037472911e-16,  3.481207789078881e-16,  8.223177317986329e-16,
    -9.766199148821293e-16, 6.19090466274028e-16,   1.209014056477395e-15,  -3.30244306477483e-16,
    5.974505259635377e-16,  5.993322599035803e-16,  1.829986256691466e-16,  -2.690879534260973e-16,
    8.618341445395283e-16,  -1.002023323072705e-16, 6.374373721894436e-16,  6.270878355192092e-16,
    1.199605386777182e-15,  -8.712428142397415e-16, -2.507410475106815e-16, -1.086230916889613e-15,
    1.072588345824304e-15,  -4.534978795502758e-16, 2.119067633230516e-15,  -1.842923177529259e-15,
    -1.814697168428619e-15, 4.243310034796149e-16,  4.224492695395723e-16,  1.531966643937213e-15,
    -2.850826919164597e-16, -8.958229638315484e-16, -5.02893395476395e-16,  1.096110020074837e-16,
    -6.975352498995555e-16, -8.743006318923108e-16};

// Coefficient table g_coefficientsFHEW128_9, labelled "K = 128" by the original authors.
// NOTE(review): presumably polynomial-approximation coefficients consumed by the CKKS -> FHEW
// switching code later in this file; the meaning of the "_9" suffix and the consumer are not
// visible in this section - confirm at the use site before editing any value.
static constexpr std::initializer_list<double> g_coefficientsFHEW128_9{
    0.08761193238226354,    -0.01738402917379392,   0.08935060894767313,    -0.01667686631436392,
    0.09435445639097996,    -0.01518333497826596,   0.1019473189108075,     -0.01276275748916528,
    0.110882655474149,      -0.009252446966171999,  0.1192111685574758,     -0.004534979909938953,
    0.1242004317120066,     0.001362904847617233,   0.1224283765086551,     0.008145596233693092,
    0.1102080588183085,     0.01512350467093644,    0.08449405378412403,    0.02114203679334985,
    0.04431786059830115,    0.02464956129638114,    -0.007454366487154669,  0.02400059020366966,
    -0.06266441339261235,   0.0180491215413637,     -0.1077943201829795,    0.00695836538813938,
    -0.1265848641500751,    -0.007067567033131986,  -0.1060856934163377,    -0.01966175019277508,
    -0.04512467324356773,   -0.02537595733026167,   0.03862916785371963,    -0.0201785566296389,
    0.1092652333753526,     -0.004612578019766411,  0.1263344585514989,     0.01438496124842956,
    0.07022427857484209,    0.02550072245548077,    -0.03434514153678107,   0.01979242584243335,
    -0.1194659697149694,    -0.001008794768691528,  -0.1149256786653964,    -0.02192904329965044,
    -0.01184295110147364,   -0.02417858011117596,   0.1066507410103885,     -0.003076473516323838,
    0.122343225763269,      0.02209885820126707,    0.005200840409852563,   0.02321022960558625,
    -0.1224755172356864,    -0.003930982569218595,  -0.1000653894904632,    -0.02689795846413602,
    0.05865754664309684,    -0.01297065380451242,   0.1377909895596227,     0.02083617539534925,
    0.006502421233003679,   0.02248299870285675,    -0.139660074659475,     -0.01399307934458518,
    -0.04589168496663835,   -0.0263421662574377,    0.1358978738303921,     0.01130242907664306,
    0.05563799538901031,    0.02715486116995984,    -0.1426236952996719,    -0.01461041285557406,
    -0.03302834981489188,   -0.02454368648125577,   0.1559877850928394,     0.02360418859443232,
    -0.03051465817859748,   0.01394389273916019,    -0.1434779685133351,    -0.0326137520114734,
    0.1272587840850199,     0.00968806150092634,    0.04489729856072615,    0.02496761251245225,
    -0.1723551233719199,    -0.03505277577503257,   0.1396636892583818,     0.01468861799711852,
    0.00597622458952793,    0.01686435635501478,    -0.1508869780062401,    -0.0392626068463298,
    0.2221665014327329,     0.04513725581939847,    -0.2157338005707834,    -0.03852627732394119,
    0.1657363840292956,     0.02705951812948022,    -0.1076077703571204,    -0.01637507278621027,
    0.06108202920758021,    0.00876339054415577,    -0.03094805600072437,   -0.004218297546715713,
    0.01419483272929196,    0.001848310901625205,   -0.005954927442783235,  -0.000743844834357433,
    0.002303049930851211,   0.000276890872388833,   -0.0008263094529170254, -9.587788269377866e-05,
    0.000276459761481133,   3.10280277328683e-05,   -8.662530848949058e-05, -9.421991495095434e-06,
    2.551403977799462e-05,  2.693817139509806e-06,  -7.08627853168575e-06,  -7.273160333789605e-07,
    1.861116908629967e-06,  1.859288982562724e-07,  -4.633601616493963e-07, -4.510790623761216e-08,
    1.096004151863654e-07,  1.040757606442787e-08,  -2.467843591423806e-08, -2.288015736340782e-09,
    5.299290810294302e-09,  4.800959747999802e-10,  -1.086991796689273e-09, -9.63033831430537e-11,
    2.133040952766737e-10,  1.849336626863696e-11,  -4.010173216641153e-11, -3.404204559101731e-12,
    7.232799297263385e-12,  6.027665442316686e-13,  -1.252868789531569e-12, -1.035300267428826e-13,
    2.072453780944445e-13,  1.81572061755555e-14,   -3.012503176280137e-14, -4.417490972407089e-16,
    3.698522891563647e-15,  -4.204154533937635e-16, -2.740777660720187e-15, -1.348919106364917e-15,
    -1.620799477984723e-15, 4.003965342611375e-16,  -5.245330582249314e-16, 1.754761547401069e-15,
    -5.0481471966847e-16,   -4.722624632690369e-16, 1.628901569091919e-16,  -1.219903204684612e-15};

// Coefficient table g_coefficientsFHEW128_8. It is shorter than g_coefficientsFHEW128_9 and its
// leading entries agree with that table to many significant digits.
// NOTE(review): presumably a lower-degree variant of the K = 128 approximation; the meaning of
// the "_8" suffix and the consumer are not visible in this section - confirm at the use site.
static constexpr std::initializer_list<double> g_coefficientsFHEW128_8{
    0.08761193238226343,    -0.01738402917379268,  0.08935060894767202,   -0.0166768663143651,   0.09435445639098095,
    -0.01518333497826714,   0.1019473189108076,    -0.01276275748916462,  0.1108826554741475,    -0.009252446966171845,
    0.1192111685574773,     -0.004534979909938402, 0.1242004317120066,    0.001362904847616587,  0.1224283765086535,
    0.008145596233693802,   0.1102080588183083,    0.0151235046709367,    0.08449405378412395,   0.02114203679334948,
    0.04431786059830203,    0.02464956129638117,   -0.007454366487155707, 0.02400059020367158,   -0.06266441339261287,
    0.01804912154136392,    -0.107794320182978,    0.006958365388138488,  -0.1265848641500738,   -0.007067567033133184,
    -0.1060856934163389,    -0.01966175019277399,  -0.0451246732435682,   -0.02537595733026211,  0.0386291678537217,
    -0.02017855662963969,   0.1092652333753532,    -0.004612578019767425, 0.1263344585514991,    0.01438496124843117,
    0.07022427857484087,    0.02550072245548053,   -0.03434514153678073,  0.01979242584243296,   -0.1194659697149702,
    -0.00100879476868968,   -0.1149256786653952,   -0.02192904329965062,  -0.01184295110147335,  -0.02417858011117619,
    0.1066507410103884,     -0.003076473516322021, 0.1223432257632692,    0.02209885820126752,   0.005200840409853516,
    0.02321022960558683,    -0.1224755172356849,   -0.003930982569218244, -0.1000653894904628,   -0.02689795846413568,
    0.05865754664309823,    -0.01297065380451253,  0.1377909895596233,    0.02083617539534807,   0.006502421233004046,
    0.02248299870285591,    -0.1396600746594754,   -0.01399307934458444,  -0.04589168496663817,  -0.02634216625743798,
    0.1358978738303917,     0.0113024290766429,    0.05563799538901171,   0.02715486116995986,   -0.1426236952996744,
    -0.01461041285557423,   -0.03302834981489241,  -0.02454368648125667,  0.155987785092838,     0.02360418859443058,
    -0.03051465817859778,   0.01394389273915945,   -0.1434779685133346,   -0.03261375201147241,  0.1272587840850196,
    0.009688061500927738,   0.04489729856072736,   0.02496761251245433,   -0.1723551233719191,   -0.03505277577503064,
    0.1396636892583768,     0.01468861799712161,   0.005976224589562133,  0.01686435635499993,   -0.1508869780064481,
    -0.03926260684622985,   0.2221665014339838,    0.04513725581879824,   -0.2157338005780147,   -0.03852627732053739,
    0.1657363840693943,     0.02705951811098469,   -0.1076077705704247,   -0.0163750726899083,   0.06108203029457551,
    0.008763390064061401,   -0.03094806130001855,  -0.00421829525869962,  0.01419485740772663,   0.001848300494047458,
    -0.005955037043195616,  -0.000743799726455706, 0.002303513291010024,  0.0002767049434914361, -0.0008281705698254814,
    -9.515056665793518e-05, 0.0002835460400168608, 2.833421059257267e-05, -0.0001121393482639905};

namespace lbcrypto {

//------------------------------------------------------------------------------
// Complex Plaintext Functions, copied from ckksrns-fhe. TODO: fix this
//------------------------------------------------------------------------------

#if NATIVEINT == 128 && !defined(__EMSCRIPTEN__)
/**
 * Encodes a vector of complex values into a CKKS plaintext (variant compiled when NATIVEINT == 128
 * and not under Emscripten).
 *
 * A copy of `value` is truncated or zero-padded to `slots` entries, run through the special
 * inverse FFT and converted to integers: each real/imaginary component is split with std::frexp
 * into a mantissa (scaled by 2^MAX_DOUBLE_PRECISION and rounded to 64 bits) and an exponent, and
 * the remaining power of two (pBits - MAX_DOUBLE_PRECISION + exponent) is applied in 128-bit
 * integer arithmetic. Negative results are stored as Max128BitValue() + x and then reduced into
 * every RNS tower by FitToNativeVector.
 *
 * @param cc            crypto context supplying crypto parameters, encoding parameters and ring dimension.
 * @param params        element parameters forwarded to the CKKSPackedEncoding constructor.
 * @param value         complex values to encode.
 * @param noiseScaleDeg scaling degree; when above 1 the element is multiplied by
 *                      (2^pBits)^(noiseScaleDeg - 1).
 * @param level         level used to look up the real scaling factor.
 * @param slots         number of slots to encode.
 * @return the plaintext, switched to EVALUATION format, with its scaling factor raised to the
 *         power `noiseScaleDeg`.
 * @throws (via OPENFHE_THROW) when a scaled component fails the 128-bit overflow checks.
 */
Plaintext SWITCHCKKSRNS::MakeAuxPlaintext(const CryptoContextImpl<DCRTPoly>& cc, const std::shared_ptr<ParmType> params,
                                          const std::vector<std::complex<double>>& value, size_t noiseScaleDeg,
                                          uint32_t level, usint slots) const {
    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(cc.GetCryptoParameters());

    double scFact = cryptoParams->GetScalingFactorReal(level);

    Plaintext p = Plaintext(std::make_shared<CKKSPackedEncoding>(params, cc.GetEncodingParams(), value, noiseScaleDeg,
                                                                 level, scFact, slots));

    DCRTPoly& plainElement = p->GetElement<DCRTPoly>();

    usint N = cc.GetRingDimension();

    std::vector<std::complex<double>> inverse = value;

    inverse.resize(slots);

    DiscreteFourierTransform::FFTSpecialInv(inverse, N * 2);
    // The plaintext modulus of the encoding parameters is used here as a bit count.
    uint64_t pBits = cc.GetEncodingParams()->GetPlaintextModulus();

    double powP      = std::pow(2.0, MAX_DOUBLE_PRECISION);
    int32_t pCurrent = pBits - MAX_DOUBLE_PRECISION;

    // Rounds a scaled mantissa to 64 bits and applies the remaining power of two
    // (pCurrent + exponent). The left shift is done on a 128-bit one for BOTH components: the
    // imaginary path previously shifted a 64-bit one, which overflows once the shift count
    // reaches the width of int64_t.
    auto applyExponent = [pCurrent](double scaledMantissa, int32_t exponent) -> __int128 {
        int64_t rounded    = std::llround(scaledMantissa);
        int32_t pRemaining = pCurrent + exponent;
        if (pRemaining < 0) {
            return rounded >> (-pRemaining);
        }
        __int128 pPowRemaining = static_cast<__int128>(1) << pRemaining;
        return pPowRemaining * rounded;
    };

    std::vector<int128_t> temp(2 * slots);
    for (size_t i = 0; i < slots; ++i) {
        // extract the mantissa of real part and multiply it by 2^MAX_DOUBLE_PRECISION
        int32_t n1 = 0;
        double dre = std::frexp(inverse[i].real(), &n1) * powP;
        // extract the mantissa of imaginary part and multiply it by 2^MAX_DOUBLE_PRECISION
        int32_t n2 = 0;
        double dim = std::frexp(inverse[i].imag(), &n2) * powP;

        // Check for possible overflow
        if (is128BitOverflow(dre) || is128BitOverflow(dim)) {
            // Transform back so the diagnostic can point at the input entries that contribute
            // most to the overflowing slot.
            DiscreteFourierTransform::FFTSpecial(inverse, N * 2);

            double invLen = static_cast<double>(inverse.size());
            double factor = 2 * M_PI * i;

            double realMax = -1, imagMax = -1;
            // -1 wraps to the largest uint32_t; it is what gets printed if no contribution
            // exceeds the initial maxima above.
            uint32_t realMaxIdx = static_cast<uint32_t>(-1), imagMaxIdx = static_cast<uint32_t>(-1);

            for (uint32_t idx = 0; idx < inverse.size(); idx++) {
                // exp( j*2*pi*n*k/N )
                std::complex<double> expFactor = {cos((factor * idx) / invLen), sin((factor * idx) / invLen)};

                // X[k] * exp( j*2*pi*n*k/N )
                std::complex<double> prodFactor = inverse[idx] * expFactor;

                double realVal = prodFactor.real();
                double imagVal = prodFactor.imag();

                if (realVal > realMax) {
                    realMax    = realVal;
                    realMaxIdx = idx;
                }
                if (imagVal > imagMax) {
                    imagMax    = imagVal;
                    imagMaxIdx = idx;
                }
            }

            // Report the size of the larger offending component. Using the magnitude keeps log2
            // defined for negative values (log2 of a negative number is NaN).
            double overflowMag   = (std::fabs(dre) > std::fabs(dim)) ? std::fabs(dre) : std::fabs(dim);
            auto scaledInputSize = ceil(log2(overflowMag));

            std::stringstream buffer;
            buffer << std::endl
                   << "Overflow in data encoding - scaled input is too large to fit "
                      "into a NativeInteger (60 bits). Try decreasing scaling factor."
                   << std::endl;
            buffer << "Overflow at slot number " << i << std::endl;
            buffer << "- Max real part contribution from input[" << realMaxIdx << "]: " << realMax << std::endl;
            buffer << "- Max imaginary part contribution from input[" << imagMaxIdx << "]: " << imagMax << std::endl;
            buffer << "Scaling factor is " << ceil(log2(powP)) << " bits " << std::endl;
            buffer << "Scaled input is " << scaledInputSize << " bits " << std::endl;
            OPENFHE_THROW(buffer.str());
        }

        __int128 re = applyExponent(dre, n1);
        __int128 im = applyExponent(dim, n2);

        // Negative values are stored offset by Max128BitValue(); FitToNativeVector undoes this
        // when reducing modulo each tower modulus.
        temp[i]         = (re < 0) ? Max128BitValue() + re : re;
        temp[i + slots] = (im < 0) ? Max128BitValue() + im : im;

        if (is128BitOverflow(temp[i]) || is128BitOverflow(temp[i + slots])) {
            OPENFHE_THROW("Overflow, try to decrease scaling factor");
        }
    }

    const std::shared_ptr<ILDCRTParams<BigInteger>> bigParams        = plainElement.GetParams();
    const std::vector<std::shared_ptr<ILNativeParams>>& nativeParams = bigParams->GetParams();

    // Reduce the integer vector into every RNS tower of the plaintext element.
    for (size_t i = 0; i < nativeParams.size(); i++) {
        NativeVector nativeVec(N, nativeParams[i]->GetModulus());
        FitToNativeVector(N, temp, Max128BitValue(), &nativeVec);
        NativePoly element = plainElement.GetElementAtIndex(i);
        element.SetValues(nativeVec, Format::COEFFICIENT);
        plainElement.SetElementAtIndex(i, element);
    }

    usint numTowers = nativeParams.size();
    std::vector<DCRTPoly::Integer> moduli(numTowers);
    for (usint i = 0; i < numTowers; i++) {
        moduli[i] = nativeParams[i]->GetModulus();
    }

    DCRTPoly::Integer intPowP = NativeInteger(1) << pBits;
    std::vector<DCRTPoly::Integer> crtPowP(numTowers, intPowP);

    auto currPowP = crtPowP;

    // We want to scale temp by 2^(pd), and the loop starts from i=2
    // because temp is already scaled by 2^p in the re/im loop above,
    // and currPowP already is 2^p.
    for (size_t i = 2; i < noiseScaleDeg; i++) {
        currPowP = CKKSPackedEncoding::CRTMult(currPowP, crtPowP, moduli);
    }

    if (noiseScaleDeg > 1) {
        plainElement = plainElement.Times(currPowP);
    }

    p->SetFormat(Format::EVALUATION);
    p->SetScalingFactor(pow(p->GetScalingFactor(), noiseScaleDeg));

    return p;
}
#else
/**
 * Encodes a vector of complex values into a CKKS plaintext (variant compiled in the #else branch
 * of the NATIVEINT == 128 check, i.e. with 64-bit intermediate integers).
 *
 * A copy of `value` is truncated or zero-padded to `slots` entries, run through the special
 * inverse FFT and multiplied by the real scaling factor of `level`. If the largest component
 * needs more than MAX_BITS_IN_WORD bits, all components are divided by a power of two
 * (approxFactor) before rounding to int64_t, and the plaintext element is multiplied by the same
 * power of two again at the end. Negative integers are stored as Max64BitValue() + x and then
 * reduced into every RNS tower by FitToNativeVector.
 *
 * @param cc            crypto context supplying crypto parameters, encoding parameters and ring dimension.
 * @param params        element parameters forwarded to the CKKSPackedEncoding constructor.
 * @param value         complex values to encode.
 * @param noiseScaleDeg scaling degree; when above 1 the element is multiplied by
 *                      round(scFact)^(noiseScaleDeg - 1).
 * @param level         level used to look up the real scaling factor.
 * @param slots         number of slots to encode.
 * @return the plaintext, switched to EVALUATION format, with its scaling factor raised to the
 *         power `noiseScaleDeg`.
 * @throws (via OPENFHE_THROW) when a scaled component fails the 64-bit overflow check.
 */
Plaintext SWITCHCKKSRNS::MakeAuxPlaintext(const CryptoContextImpl<DCRTPoly>& cc, const std::shared_ptr<ParmType> params,
                                          const std::vector<std::complex<double>>& value, size_t noiseScaleDeg,
                                          uint32_t level, usint slots) const {
    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(cc.GetCryptoParameters());

    double scFact = cryptoParams->GetScalingFactorReal(level);

    Plaintext p = Plaintext(std::make_shared<CKKSPackedEncoding>(params, cc.GetEncodingParams(), value, noiseScaleDeg,
                                                                 level, scFact, slots));

    DCRTPoly& plainElement = p->GetElement<DCRTPoly>();

    usint N = cc.GetRingDimension();

    std::vector<std::complex<double>> inverse = value;

    inverse.resize(slots);

    DiscreteFourierTransform::FFTSpecialInv(inverse, N * 2);
    double powP = scFact;

    // Compute approxFactor, a value to scale down by, in case the value exceeds a 64-bit integer.
    constexpr int32_t MAX_BITS_IN_WORD = 61;

    // logc ends up as the largest ceil(log2(|component|)) over all scaled components (at least 0).
    int32_t logc = 0;
    for (size_t i = 0; i < slots; ++i) {
        inverse[i] *= powP;
        if (inverse[i].real() != 0) {
            int32_t logci = static_cast<int32_t>(ceil(log2(std::abs(inverse[i].real()))));
            if (logc < logci)
                logc = logci;
        }
        if (inverse[i].imag() != 0) {
            int32_t logci = static_cast<int32_t>(ceil(log2(std::abs(inverse[i].imag()))));
            if (logc < logci)
                logc = logci;
        }
    }
    // NOTE(review): logc starts at 0 and only ever grows, so this check cannot fire as written;
    // kept unchanged from the original.
    if (logc < 0) {
        OPENFHE_THROW("Too small scaling factor");
    }
    int32_t logValid    = (logc <= MAX_BITS_IN_WORD) ? logc : MAX_BITS_IN_WORD;
    int32_t logApprox   = logc - logValid;
    double approxFactor = pow(2, logApprox);

    std::vector<int64_t> temp(2 * slots);
    for (size_t i = 0; i < slots; ++i) {
        // Scale down by approxFactor in case the value exceeds a 64-bit integer.
        double dre = inverse[i].real() / approxFactor;
        double dim = inverse[i].imag() / approxFactor;

        // Check for possible overflow
        if (is64BitOverflow(dre) || is64BitOverflow(dim)) {
            // Transform back so the diagnostic can point at the input entries that contribute
            // most to the overflowing slot.
            DiscreteFourierTransform::FFTSpecial(inverse, N * 2);

            double invLen = static_cast<double>(inverse.size());
            double factor = 2 * M_PI * i;

            double realMax = -1;
            double imagMax = -1;
            // -1 wraps to the largest uint32_t; it is what gets printed if no contribution
            // exceeds the initial maxima above.
            uint32_t realMaxIdx = static_cast<uint32_t>(-1);
            uint32_t imagMaxIdx = static_cast<uint32_t>(-1);

            for (uint32_t idx = 0; idx < inverse.size(); idx++) {
                // exp( j*2*pi*n*k/N )
                std::complex<double> expFactor = {cos((factor * idx) / invLen), sin((factor * idx) / invLen)};

                // X[k] * exp( j*2*pi*n*k/N )
                std::complex<double> prodFactor = inverse[idx] * expFactor;

                double realVal = prodFactor.real();
                double imagVal = prodFactor.imag();

                if (realVal > realMax) {
                    realMax    = realVal;
                    realMaxIdx = idx;
                }
                if (imagVal > imagMax) {
                    imagMax    = imagVal;
                    imagMaxIdx = idx;
                }
            }

            // Report the size of the larger offending component. Using the magnitude keeps log2
            // defined for negative values (log2 of a negative number is NaN).
            double overflowMag   = (std::abs(dre) > std::abs(dim)) ? std::abs(dre) : std::abs(dim);
            auto scaledInputSize = ceil(log2(overflowMag));

            std::stringstream buffer;
            buffer << std::endl
                   << "Overflow in data encoding - scaled input is too large to fit "
                      "into a NativeInteger (60 bits). Try decreasing scaling factor."
                   << std::endl;
            buffer << "Overflow at slot number " << i << std::endl;
            buffer << "- Max real part contribution from input[" << realMaxIdx << "]: " << realMax << std::endl;
            buffer << "- Max imaginary part contribution from input[" << imagMaxIdx << "]: " << imagMax << std::endl;
            buffer << "Scaling factor is " << ceil(log2(powP)) << " bits " << std::endl;
            buffer << "Scaled input is " << scaledInputSize << " bits " << std::endl;
            OPENFHE_THROW(buffer.str());
        }

        int64_t re = std::llround(dre);
        int64_t im = std::llround(dim);

        // Negative values are stored offset by Max64BitValue(); FitToNativeVector undoes this
        // when reducing modulo each tower modulus.
        temp[i]         = (re < 0) ? Max64BitValue() + re : re;
        temp[i + slots] = (im < 0) ? Max64BitValue() + im : im;
    }

    const std::shared_ptr<ILDCRTParams<BigInteger>> bigParams        = plainElement.GetParams();
    const std::vector<std::shared_ptr<ILNativeParams>>& nativeParams = bigParams->GetParams();

    // Reduce the integer vector into every RNS tower of the plaintext element.
    for (size_t i = 0; i < nativeParams.size(); i++) {
        NativeVector nativeVec(N, nativeParams[i]->GetModulus());
        FitToNativeVector(N, temp, Max64BitValue(), &nativeVec);
        NativePoly element = plainElement.GetElementAtIndex(i);
        element.SetValues(nativeVec, Format::COEFFICIENT);
        plainElement.SetElementAtIndex(i, element);
    }

    usint numTowers = nativeParams.size();
    std::vector<DCRTPoly::Integer> moduli(numTowers);
    for (usint i = 0; i < numTowers; i++) {
        moduli[i] = nativeParams[i]->GetModulus();
    }

    DCRTPoly::Integer intPowP = std::llround(powP);
    std::vector<DCRTPoly::Integer> crtPowP(numTowers, intPowP);

    auto currPowP = crtPowP;

    // We want to scale temp by powP^d, and the loop starts from i=2
    // because temp is already scaled by powP in the loop above,
    // and currPowP already is powP.
    for (size_t i = 2; i < noiseScaleDeg; i++) {
        currPowP = CKKSPackedEncoding::CRTMult(currPowP, crtPowP, moduli);
    }

    if (noiseScaleDeg > 1) {
        plainElement = plainElement.Times(currPowP);
    }

    // Scale back up by the approxFactor to get the correct encoding. The power of two is built
    // in chunks of at most MAX_LOG_STEP bits, multiplied together per tower with CRTMult.
    if (logApprox > 0) {
        int32_t logStep = (logApprox <= MAX_LOG_STEP) ? logApprox : MAX_LOG_STEP;
        auto intStep    = DCRTPoly::Integer(uint64_t(1) << logStep);
        std::vector<DCRTPoly::Integer> crtApprox(numTowers, intStep);
        logApprox -= logStep;

        while (logApprox > 0) {
            logStep = (logApprox <= MAX_LOG_STEP) ? logApprox : MAX_LOG_STEP;
            intStep = DCRTPoly::Integer(uint64_t(1) << logStep);
            std::vector<DCRTPoly::Integer> crtSF(numTowers, intStep);
            crtApprox = CKKSPackedEncoding::CRTMult(crtApprox, crtSF, moduli);
            logApprox -= logStep;
        }
        plainElement = plainElement.Times(crtApprox);
    }

    p->SetFormat(Format::EVALUATION);
    p->SetScalingFactor(pow(p->GetScalingFactor(), noiseScaleDeg));

    return p;
}
#endif

/**
 * Multiplies every polynomial of a ciphertext by a plaintext polynomial.
 *
 * The input ciphertext is cloned, so it is left untouched. The result's noise-scale degree is the
 * sum of both operands' degrees and its scaling factor is the product of both scaling factors.
 *
 * @param ciphertext ciphertext whose elements are multiplied.
 * @param plaintext  plaintext whose DCRTPoly element (switched to EVALUATION format on a copy)
 *                   is the multiplier.
 * @return the new ciphertext holding the products.
 */
Ciphertext<DCRTPoly> SWITCHCKKSRNS::EvalMultExt(ConstCiphertext<DCRTPoly> ciphertext, ConstPlaintext plaintext) const {
    Ciphertext<DCRTPoly> product = ciphertext->Clone();
    std::vector<DCRTPoly>& polys = product->GetElements();

    // Work on a private copy of the plaintext element so its format can be changed freely.
    DCRTPoly multiplier = plaintext->GetElement<DCRTPoly>();
    multiplier.SetFormat(Format::EVALUATION);

    for (size_t k = 0; k < polys.size(); ++k) {
        polys[k] *= multiplier;
    }

    const auto combinedDegree = product->GetNoiseScaleDeg() + plaintext->GetNoiseScaleDeg();
    product->SetNoiseScaleDeg(combinedDegree);

    const auto combinedFactor = product->GetScalingFactor() * plaintext->GetScalingFactor();
    product->SetScalingFactor(combinedFactor);

    return product;
}

/**
 * Adds the elements of `ciphertext2` to the elements of `ciphertext1`, position by position,
 * modifying `ciphertext1` in place.
 *
 * The number of additions is the element count of `ciphertext1`; `ciphertext2` is indexed at the
 * same positions without a size check.
 *
 * @param ciphertext1 ciphertext that receives the sums.
 * @param ciphertext2 ciphertext whose elements are added.
 */
void SWITCHCKKSRNS::EvalAddExtInPlace(Ciphertext<DCRTPoly>& ciphertext1, ConstCiphertext<DCRTPoly> ciphertext2) const {
    std::vector<DCRTPoly>& accumulators = ciphertext1->GetElements();
    const std::vector<DCRTPoly>& addends = ciphertext2->GetElements();

    auto addend = addends.begin();
    for (DCRTPoly& accumulator : accumulators) {
        accumulator += *addend;
        ++addend;
    }
}

/**
 * Returns the element-wise sum of two ciphertexts without modifying either input.
 *
 * @param ciphertext1 first operand; it is cloned to hold the result.
 * @param ciphertext2 second operand, added via EvalAddExtInPlace.
 * @return a new ciphertext containing the sum.
 */
Ciphertext<DCRTPoly> SWITCHCKKSRNS::EvalAddExt(ConstCiphertext<DCRTPoly> ciphertext1,
                                               ConstCiphertext<DCRTPoly> ciphertext2) const {
    // Clone first so the in-place addition never touches the caller's ciphertext.
    Ciphertext<DCRTPoly> sum = ciphertext1->Clone();
    EvalAddExtInPlace(sum, ciphertext2);
    return sum;
}

/**
 * Generates the key-switching key associated with the automorphism of index 2N - 1, where N is
 * the ring dimension of the secret key.
 *
 * A second private key is built whose element is the automorphism (index 2N - 1) of the original
 * secret, carrying the same key tag, and the scheme's KeySwitchGen is invoked on the pair.
 *
 * @param privateKey secret key to derive the evaluation key from.
 * @return the evaluation key returned by KeySwitchGen(privateKey, permuted key).
 */
EvalKey<DCRTPoly> SWITCHCKKSRNS::ConjugateKeyGen(const PrivateKey<DCRTPoly> privateKey) const {
    const auto cc = privateKey->GetCryptoContext();

    const DCRTPoly& secret  = privateKey->GetPrivateElement();
    const usint ringDim     = secret.GetRingDimension();
    const usint conjIndex   = 2 * ringDim - 1;

    // Precompute the index map for the automorphism and apply it to the secret polynomial.
    std::vector<usint> autoMap(ringDim);
    PrecomputeAutoMap(ringDim, conjIndex, &autoMap);
    DCRTPoly permutedSecret = secret.AutomorphismTransform(conjIndex, autoMap);

    PrivateKey<DCRTPoly> permutedKey = std::make_shared<PrivateKeyImpl<DCRTPoly>>(cc);
    permutedKey->SetPrivateElement(permutedSecret);
    permutedKey->SetKeyTag(privateKey->GetKeyTag());

    auto scheme = cc->GetScheme();
    return scheme->KeySwitchGen(privateKey, permutedKey);
}

/**
 * Applies the automorphism of index 2N - 1 to a ciphertext, where N is the ring dimension of its
 * first element.
 *
 * The ciphertext is cloned, key-switched in place with the evaluation key stored under index
 * 2N - 1, and then its first two elements are transformed by the automorphism.
 *
 * @param ciphertext ciphertext to transform; it is not modified.
 * @param evalKeyMap map from automorphism index to evaluation key; must contain key 2N - 1
 *                   (std::map::at throws std::out_of_range otherwise).
 * @return the transformed ciphertext.
 */
Ciphertext<DCRTPoly> SWITCHCKKSRNS::Conjugate(ConstCiphertext<DCRTPoly> ciphertext,
                                              const std::map<usint, EvalKey<DCRTPoly>>& evalKeyMap) const {
    const usint ringDim   = ciphertext->GetElements()[0].GetRingDimension();
    const usint conjIndex = 2 * ringDim - 1;

    std::vector<usint> autoMap(ringDim);
    PrecomputeAutoMap(ringDim, conjIndex, &autoMap);

    auto scheme = ciphertext->GetCryptoContext()->GetScheme();

    Ciphertext<DCRTPoly> conjugated = ciphertext->Clone();
    scheme->KeySwitchInPlace(conjugated, evalKeyMap.at(conjIndex));

    // Only the first two elements are transformed, as in the two-component ciphertext layout.
    std::vector<DCRTPoly>& polys = conjugated->GetElements();
    for (size_t k = 0; k < 2; ++k) {
        polys[k] = polys[k].AutomorphismTransform(conjIndex, autoMap);
    }

    return conjugated;
}

#if NATIVEINT == 128 && !defined(__EMSCRIPTEN__)
/**
 * Reduces a vector of 128-bit integers modulo the modulus of `nativeVec` and writes the results
 * into `nativeVec` at a fixed stride (variant compiled when NATIVEINT == 128).
 *
 * Entries larger than bigBound / 2 take the ModSub path with diff = bigBound - modulus; every
 * other entry is reduced with Mod. Entry i is written to position (ringDim / vec.size()) * i.
 * NOTE(review): ModSub presumably computes (n - diff) mod modulus, which would make the first
 * path equivalent to reducing n - bigBound, i.e. undoing a "bigBound + negative value" encoding -
 * confirm against the NativeInteger documentation.
 *
 * @param ringDim   ring dimension used to derive the write stride.
 * @param vec       values to reduce; must not be empty.
 * @param bigBound  bound that separates the two reduction paths (at bigBound / 2).
 * @param nativeVec destination vector; must not be null.
 * @throws (via OPENFHE_THROW) if `nativeVec` is null or `vec` is empty.
 */
void SWITCHCKKSRNS::FitToNativeVector(uint32_t ringDim, const std::vector<__int128>& vec, __int128 bigBound,
                                      NativeVector* nativeVec) const {
    if (nativeVec == nullptr) {
        OPENFHE_THROW("The passed native vector is empty.");
    }
    // An empty input would make the stride computation below divide by zero.
    if (vec.empty()) {
        OPENFHE_THROW("The passed input vector is empty.");
    }
    NativeInteger bigValueHf((unsigned __int128)bigBound >> 1);
    NativeInteger modulus(nativeVec->GetModulus());
    NativeInteger diff = NativeInteger((unsigned __int128)bigBound) - modulus;
    uint32_t dslots    = vec.size();
    // Distance between consecutive written positions in the destination.
    uint32_t gap       = ringDim / dslots;
    for (usint i = 0; i < vec.size(); i++) {
        NativeInteger n((unsigned __int128)vec[i]);
        if (n > bigValueHf) {
            (*nativeVec)[gap * i] = n.ModSub(diff, modulus);
        }
        else {
            (*nativeVec)[gap * i] = n.Mod(modulus);
        }
    }
}
#else  // NATIVEINT == 64
/**
 * Reduces a vector of 64-bit integers modulo the modulus of `nativeVec` and writes the results
 * into `nativeVec` at a fixed stride (variant compiled in the #else branch of the
 * NATIVEINT == 128 check).
 *
 * Entries larger than bigBound / 2 take the ModSub path with diff = bigBound - modulus; every
 * other entry is reduced with Mod. Entry i is written to position (ringDim / vec.size()) * i.
 * NOTE(review): ModSub presumably computes (n - diff) mod modulus, which would make the first
 * path equivalent to reducing n - bigBound, i.e. undoing a "bigBound + negative value" encoding -
 * confirm against the NativeInteger documentation.
 *
 * @param ringDim   ring dimension used to derive the write stride.
 * @param vec       values to reduce; must not be empty.
 * @param bigBound  bound that separates the two reduction paths (at bigBound / 2).
 * @param nativeVec destination vector; must not be null.
 * @throws (via OPENFHE_THROW) if `nativeVec` is null or `vec` is empty.
 */
void SWITCHCKKSRNS::FitToNativeVector(uint32_t ringDim, const std::vector<int64_t>& vec, int64_t bigBound,
                                      NativeVector* nativeVec) const {
    if (nativeVec == nullptr) {
        OPENFHE_THROW("The passed native vector is empty.");
    }
    // An empty input would make the stride computation below divide by zero.
    if (vec.empty()) {
        OPENFHE_THROW("The passed input vector is empty.");
    }
    NativeInteger bigValueHf(bigBound >> 1);
    NativeInteger modulus(nativeVec->GetModulus());
    NativeInteger diff = bigBound - modulus;
    uint32_t dslots    = vec.size();
    // Distance between consecutive written positions in the destination.
    uint32_t gap       = ringDim / dslots;
    for (usint i = 0; i < vec.size(); i++) {
        NativeInteger n(vec[i]);
        if (n > bigValueHf) {
            (*nativeVec)[gap * i] = n.ModSub(diff, modulus);
        }
        else {
            (*nativeVec)[gap * i] = n.Mod(modulus);
        }
    }
}
#endif

//------------------------------------------------------------------------------
// Key and modulus switch and extraction methods
//------------------------------------------------------------------------------

/**
 * Rescales `v` from modulus `Q` to modulus `q`: computes floor(0.5 + v * q / Q) in double
 * precision and reduces the result modulo `q`.
 *
 * @param v value to rescale.
 * @param q target modulus.
 * @param Q source modulus.
 * @return the rounded, rescaled value reduced modulo `q`.
 */
NativeInteger RoundqQAlter(const NativeInteger& v, const NativeInteger& q, const NativeInteger& Q) {
    const double scaled  = v.ConvertToDouble() * q.ConvertToDouble() / Q.ConvertToDouble();
    const double rounded = std::floor(0.5 + scaled);

    NativeInteger rescaled(static_cast<BasicInteger>(rounded));
    return rescaled.Mod(q);
}

/**
 * Generates a CKKS key-switching key from `ckksSK` to an RLWE secret whose leading coefficients
 * are copied from the LWE secret `LWEsk` (no intermediate modulus switch is involved).
 *
 * For every tower of a copy of the CKKS secret, coefficient j is overwritten with:
 *   - 0 when j is beyond the LWE key length or the LWE coefficient is 0,
 *   - 1 when the LWE coefficient converts to integer 1,
 *   - (tower modulus - 1) otherwise.
 * The rewritten polynomial becomes the private element of a freshly generated secret key, and
 * the context's KeySwitchGen is called from `ckksSK` to that key.
 *
 * @param ckksSK CKKS secret key to switch from.
 * @param LWEsk  LWE secret key whose coefficients are embedded.
 * @return the evaluation key produced by KeySwitchGen.
 */
EvalKey<DCRTPoly> switchingKeyGenRLWE(
    const PrivateKey<DCRTPoly>& ckksSK,
    ConstLWEPrivateKey& LWEsk) {
    auto rlweSecret = ckksSK->GetPrivateElement();
    rlweSecret.SetFormat(Format::COEFFICIENT);
    auto lweSecret = LWEsk->GetElement();

    const size_t towerCount = rlweSecret.GetNumOfElements();
    for (size_t t = 0; t < towerCount; ++t) {
        auto tower            = rlweSecret.GetElementAtIndex(t);
        const size_t towerLen = tower.GetLength();
        for (size_t k = 0; k < towerLen; ++k) {
            // Positions past the LWE key, and zero LWE coefficients, both map to zero.
            if (k >= lweSecret.GetLength() || lweSecret[k] == 0) {
                tower[k] = 0;
                continue;
            }
            if (lweSecret[k].ConvertToInt() == 1) {
                tower[k] = 1;
            }
            else {
                // Any other coefficient is written as modulus - 1 in this tower.
                tower[k] = tower.GetModulus() - 1;
            }
        }
        rlweSecret.SetElementAtIndex(t, tower);
    }

    rlweSecret.SetFormat(Format::EVALUATION);

    // A fresh key pair is generated only to obtain a key object; its secret element is replaced.
    auto ckksContext = ckksSK->GetCryptoContext();
    auto embeddedKey = ckksContext->KeyGen().secretKey;
    embeddedKey->SetPrivateElement(std::move(rlweSecret));

    return ckksContext->KeySwitchGen(ckksSK, embeddedKey);
}

// Switches every polynomial of ctxt to the parameters of ctxtKS and stores the result in ctxtKS.
// Each component is rebuilt through SetValuesModSwitch(..., modulus_CKKS_to) in coefficient form
// and then converted to evaluation form.
// Throws if the two ciphertexts differ in ring dimension or if either has more than one tower.
void ModSwitch(ConstCiphertext<DCRTPoly> ctxt, Ciphertext<DCRTPoly>& ctxtKS, NativeInteger modulus_CKKS_to) {
    const std::vector<DCRTPoly>& srcPolys = ctxt->GetElements();
    const DCRTPoly& dstFirst              = ctxtKS->GetElements()[0];

    if (srcPolys[0].GetRingDimension() != dstFirst.GetRingDimension()) {
        OPENFHE_THROW("ModSwitch is implemented only for the same ring dimension.");
    }

    if (srcPolys[0].GetNumOfElements() != 1 || dstFirst.GetNumOfElements() != 1) {
        OPENFHE_THROW("ModSwitch is implemented only for ciphertext with one tower.");
    }

    // Parameters of the destination ciphertext; every switched polynomial is created over them.
    const auto& dstParams = dstFirst.GetParams();

    std::vector<DCRTPoly> switched;
    switched.reserve(srcPolys.size());
    for (size_t i = 0; i < srcPolys.size(); ++i) {
        switched.emplace_back(dstParams, Format::COEFFICIENT, true);
        DCRTPoly& target = switched.back();
        target.SetValuesModSwitch(srcPolys[i], modulus_CKKS_to);
        target.SetFormat(Format::EVALUATION);
    }

    ctxtKS->SetElements(switched);
}

// Builds the key-switching key from the secret of ckksSKfrom, re-encoded over the towers of
// ckksSKto, to the secret key ckksSKto.
//
// Every coefficient of the source key is re-encoded per tower: 0 stays 0, 1 stays 1, and any
// other value is written as (tower modulus - 1).
EvalKey<DCRTPoly> switchingKeyGen(const PrivateKey<DCRTPoly>& ckksSKto, const PrivateKey<DCRTPoly>& ckksSKfrom) {
    // The destination key only provides the tower layout; its coefficients are overwritten below.
    auto reencoded = ckksSKto->GetPrivateElement();
    reencoded.SetFormat(Format::COEFFICIENT);
    auto source = ckksSKfrom->GetPrivateElement();
    source.SetFormat(Format::COEFFICIENT);

    const size_t numTowers = reencoded.GetNumOfElements();
    for (size_t t = 0; t < numTowers; ++t) {
        auto dstTower        = reencoded.GetElementAtIndex(t);
        const auto& srcTower = source.GetElementAtIndex(t);
        const auto minusOne  = dstTower.GetModulus() - 1;
        const size_t len     = dstTower.GetLength();
        for (size_t k = 0; k < len; ++k) {
            if (srcTower[k] == 0) {
                dstTower[k] = 0;
            }
            else if (srcTower[k] == 1) {
                dstTower[k] = 1;
            }
            else {
                dstTower[k] = minusOne;
            }
        }
        reencoded.SetElementAtIndex(t, dstTower);
    }

    reencoded.SetFormat(Format::EVALUATION);

    // A fresh key object is generated only to serve as a container for the re-encoded key.
    auto ccCKKSto        = ckksSKto->GetCryptoContext();
    auto oldTranformedSK = ccCKKSto->KeyGen().secretKey;
    oldTranformedSK->SetPrivateElement(std::move(reencoded));

    return ccCKKSto->KeySwitchGen(oldTranformedSK, ckksSKto);
}

// Builds the key-switching key between two keys that are both expressed over the towers of ckksSKto:
//   - the secret of ckksSKfrom, re-encoded coefficient by coefficient, and
//   - an RLWE key whose first coefficients are the LWE secret LWEsk, zero-padded afterwards.
// In both re-encodings 0 stays 0, 1 stays 1, and any other value becomes (tower modulus - 1).
EvalKey<DCRTPoly> switchingKeyGenRLWEcc(const PrivateKey<DCRTPoly>& ckksSKto, const PrivateKey<DCRTPoly>& ckksSKfrom,
                                        ConstLWEPrivateKey& LWEsk) {
    // Both target polynomials start as copies of ckksSKto to inherit its tower layout.
    auto fromKeyLifted = ckksSKto->GetPrivateElement();
    fromKeyLifted.SetFormat(Format::COEFFICIENT);
    auto fromKey = ckksSKfrom->GetPrivateElement();
    fromKey.SetFormat(Format::COEFFICIENT);
    auto lweKeyLifted = ckksSKto->GetPrivateElement();
    lweKeyLifted.SetFormat(Format::COEFFICIENT);

    const auto& lweCoeffs = LWEsk->GetElement();
    const size_t lweLen   = lweCoeffs.GetLength();

    const size_t numTowers = fromKeyLifted.GetNumOfElements();
    for (size_t t = 0; t < numTowers; ++t) {
        auto dstFrom        = fromKeyLifted.GetElementAtIndex(t);
        auto dstLWE         = lweKeyLifted.GetElementAtIndex(t);
        const auto& src     = fromKey.GetElementAtIndex(t);
        const auto minusOne = dstFrom.GetModulus() - 1;
        const size_t len    = dstFrom.GetLength();

        for (size_t k = 0; k < len; ++k) {
            // Re-encode the source CKKS key coefficient.
            if (src[k] == 0) {
                dstFrom[k] = 0;
            }
            else if (src[k] == 1) {
                dstFrom[k] = 1;
            }
            else {
                dstFrom[k] = minusOne;
            }

            // Re-encode the LWE key coefficient; positions past the LWE key length are zero.
            if (k >= lweLen || lweCoeffs[k] == 0) {
                dstLWE[k] = 0;
            }
            else if (lweCoeffs[k].ConvertToInt() == 1) {
                dstLWE[k] = 1;
            }
            else {
                dstLWE[k] = minusOne;
            }
        }
        fromKeyLifted.SetElementAtIndex(t, dstFrom);
        lweKeyLifted.SetElementAtIndex(t, dstLWE);
    }

    fromKeyLifted.SetFormat(Format::EVALUATION);
    lweKeyLifted.SetFormat(Format::EVALUATION);

    // Fresh key objects are generated only to serve as containers for the two re-encoded keys.
    auto ccCKKSto        = ckksSKto->GetCryptoContext();
    auto oldTranformedSK = ccCKKSto->KeyGen().secretKey;
    oldTranformedSK->SetPrivateElement(std::move(fromKeyLifted));
    auto RLWELWEsk = ccCKKSto->KeyGen().secretKey;
    RLWELWEsk->SetPrivateElement(std::move(lweKeyLifted));

    return ccCKKSto->KeySwitchGen(oldTranformedSK, RLWELWEsk);
}

std::vector<std::vector<NativeInteger>> ExtractLWEpacked(const Ciphertext<DCRTPoly>& ct) {
    auto originalA{(ct->GetElements()[1]).GetElementAtIndex(0)};
    originalA.SetFormat(Format::COEFFICIENT);
    auto* ptrA = &originalA.GetValues()[0];
    auto originalB{(ct->GetElements()[0]).GetElementAtIndex(0)};
    originalB.SetFormat(Format::COEFFICIENT);
    auto* ptrB = &originalB.GetValues()[0];
    size_t N = originalB.GetLength();
    std::vector<std::vector<NativeInteger>> extracted{std::vector<NativeInteger>(ptrB, ptrB + N),
                                                      std::vector<NativeInteger>(ptrA, ptrA + N)};
    return extracted;
}

std::shared_ptr<LWECiphertextImpl> ExtractLWECiphertext(const std::vector<std::vector<NativeInteger>>& aANDb,
                                                        NativeInteger modulus, uint32_t n, uint32_t index = 0) {
    auto N = aANDb[0].size();
    NativeVector a(n, modulus);
    NativeInteger b;

    for (size_t i = 0; i < n && i <= index; ++i) {
        a[i] = modulus - aANDb[1][index - i];
    }
    if (n > index) {
        for (size_t i = index + 1; i < n; ++i) {
            a[i] = aANDb[1][N + index - i];
        }
    }

    b           = aANDb[0][index];
    auto result = std::make_shared<LWECiphertextImpl>(std::move(a), std::move(b));
    return result;
}

//------------------------------------------------------------------------------
// Linear transformation methods.
// Currently mostly copied from ckksrns-fhe, because there an internal bootstrapping global structure is used.
// TODO: fix this.
//------------------------------------------------------------------------------

std::vector<ConstPlaintext> SWITCHCKKSRNS::EvalLTPrecomputeSwitch(
    const CryptoContextImpl<DCRTPoly>& cc, const std::vector<std::vector<std::complex<double>>>& A,
    const std::vector<std::vector<std::complex<double>>>& B, uint32_t dim1, uint32_t L, double scale = 1) const {
    uint32_t slots = A.size();
    uint32_t M     = cc.GetCyclotomicOrder();

    // Computing the baby-step bStep and the giant-step gStep
    uint32_t bStep = (dim1 == 0) ? getRatioBSGSLT(slots) : dim1;
    uint32_t gStep = ceil(static_cast<double>(slots) / bStep);

    const auto cryptoParamsCKKS = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(cc.GetCryptoParameters());

    ILDCRTParams<DCRTPoly::Integer> elementParams = *(cryptoParamsCKKS->GetElementParams());

    uint32_t towersToDrop = 0;
    if (L != 0) {
        towersToDrop = elementParams.GetParams().size() - L - 1;
        for (uint32_t i = 0; i < towersToDrop; i++)
            elementParams.PopLastParam();
    }

    const auto& paramsQ = elementParams.GetParams();
    const auto& paramsP = cryptoParamsCKKS->GetParamsP()->GetParams();

    size_t sizeQP = paramsQ.size() + paramsP.size();
    std::vector<NativeInteger> moduli;
    moduli.reserve(sizeQP);
    std::vector<NativeInteger> roots;
    roots.reserve(sizeQP);
    for (const auto& elem : paramsQ) {
        moduli.emplace_back(elem->GetModulus());
        roots.emplace_back(elem->GetRootOfUnity());
    }
    for (const auto& elem : paramsP) {
        moduli.emplace_back(elem->GetModulus());
        roots.emplace_back(elem->GetRootOfUnity());
    }

    auto elementParamsPtr = std::make_shared<ILDCRTParams<DCRTPoly::Integer>>(M, moduli, roots);

    std::vector<std::vector<std::complex<double>>> newA(slots);
    std::vector<ConstPlaintext> result(slots);

    //  A and B are concatenated horizontally
    for (uint32_t i = 0; i < A.size(); i++) {
        newA[i].reserve(A[i].size() + B[i].size());
        newA[i].insert(newA[i].end(), A[i].begin(), A[i].end());
        newA[i].insert(newA[i].end(), B[i].begin(), B[i].end());
    }

#pragma omp parallel for
    for (uint32_t j = 0; j < gStep; j++) {
        int32_t offset = -static_cast<int32_t>(bStep * j);
        for (uint32_t i = 0; i < bStep; i++) {
            if (bStep * j + i < slots) {
                // shifted diagonal is computed for rectangular map newA of dimension slots x 2*slots
                auto vec = ExtractShiftedDiagonal(newA, bStep * j + i);
                std::transform(vec.begin(), vec.end(), vec.begin(),
                               [&](const std::complex<double>& elem) { return elem * scale; });

                result[bStep * j + i] =
                    MakeAuxPlaintext(cc, elementParamsPtr, Rotate(Fill(vec, M / 4), offset), 1, towersToDrop, M / 4);
            }
        }
    }

    return result;
}

std::vector<ConstPlaintext> SWITCHCKKSRNS::EvalLTPrecomputeSwitch(
    const CryptoContextImpl<DCRTPoly>& cc, const std::vector<std::vector<std::complex<double>>>& A, uint32_t dim1,
    uint32_t L, double scale = 1) const {
    if (A[0].size() != A.size()) {
        OPENFHE_THROW("The matrix passed to EvalLTPrecomputeSwitch is not square");
    }

    uint32_t slots = A.size();

    uint32_t M     = cc.GetCyclotomicOrder();
    uint32_t bStep = (dim1 == 0) ? getRatioBSGSLT(slots) : dim1;
    uint32_t gStep = ceil(static_cast<double>(slots) / bStep);

    // Make sure the plaintext is created only with the necessary amount of moduli
    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(cc.GetCryptoParameters());

    ILDCRTParams<DCRTPoly::Integer> elementParams = *(cryptoParams->GetElementParams());

    uint32_t towersToDrop = 0;
    if (L != 0) {
        towersToDrop = elementParams.GetParams().size() - L - 1;
        for (uint32_t i = 0; i < towersToDrop; i++)
            elementParams.PopLastParam();
    }

    const auto& paramsQ = elementParams.GetParams();
    const auto& paramsP = cryptoParams->GetParamsP()->GetParams();

    size_t sizeQP = paramsQ.size() + paramsP.size();
    std::vector<NativeInteger> moduli;
    moduli.reserve(sizeQP);
    std::vector<NativeInteger> roots;
    roots.reserve(sizeQP);
    for (const auto& elem : paramsQ) {
        moduli.emplace_back(elem->GetModulus());
        roots.emplace_back(elem->GetRootOfUnity());
    }
    for (const auto& elem : paramsP) {
        moduli.emplace_back(elem->GetModulus());
        roots.emplace_back(elem->GetRootOfUnity());
    }

    auto elementParamsPtr = std::make_shared<ILDCRTParams<DCRTPoly::Integer>>(M, moduli, roots);

    std::vector<ConstPlaintext> result(slots);
#pragma omp parallel for
    for (uint32_t j = 0; j < gStep; j++) {
        int32_t offset = -static_cast<int32_t>(bStep * j);
        for (uint32_t i = 0; i < bStep; i++) {
            if (bStep * j + i < slots) {
                auto diag = ExtractShiftedDiagonal(A, bStep * j + i);
                std::transform(diag.begin(), diag.end(), diag.begin(),
                               [&](const std::complex<double>& elem) { return elem * scale; });
                result[bStep * j + i] =
                    MakeAuxPlaintext(cc, elementParamsPtr, Rotate(Fill(diag, M / 4), offset), 1, towersToDrop, M / 4);
            }
        }
    }

    return result;
}

// Extracts and scales the n = min(rows, cols) shifted diagonals of the rectangular matrix A.
//   dim1  - baby step; 0 selects getRatioBSGSLT(n)
//   scale - factor multiplied into every diagonal entry
// For a tall (or square) matrix, A is cut into rows/cols consecutive blocks of cols rows each and
// the shifted diagonals of the blocks are concatenated. For a wide matrix the shifted diagonals
// of A are used directly. Throws if either dimension is not a power of two.
std::vector<std::vector<std::complex<double>>> EvalLTRectPrecomputeSwitch(
    const std::vector<std::vector<std::complex<double>>>& A, uint32_t dim1, double scale) {
    if (!(IsPowerOfTwo(A.size()) && IsPowerOfTwo(A[0].size()))) {
        OPENFHE_THROW("The matrix passed to EvalLTPrecompute is not padded up to powers of two");
    }

    const size_t rows = A.size();
    const size_t cols = A[0].size();

    const uint32_t n     = std::min(rows, cols);
    const uint32_t bStep = (dim1 != 0) ? dim1 : getRatioBSGSLT(n);
    const uint32_t gStep = ceil(static_cast<double>(n) / bStep);

    std::vector<std::vector<std::complex<double>>> diags(n);

    // Multiplies every entry of a diagonal by the requested scale.
    auto scaleInPlace = [scale](std::vector<std::complex<double>>& diagonal) {
        for (auto& entry : diagonal)
            entry *= scale;
    };

    if (rows >= cols) {
        // Tall or square matrix: split into blocks of `cols` consecutive rows.
        const auto num_slices = rows / cols;
        std::vector<std::vector<std::vector<std::complex<double>>>> A_slices;
        A_slices.reserve(num_slices);
        for (size_t s = 0; s < num_slices; ++s) {
            A_slices.emplace_back(A.begin() + s * cols, A.begin() + (s + 1) * cols);
        }
#pragma omp parallel for
        for (uint32_t j = 0; j < gStep; j++) {
            for (uint32_t i = 0; i < bStep; i++) {
                const uint32_t diagIdx = bStep * j + i;
                if (diagIdx >= n)
                    continue;

                std::vector<std::complex<double>> diag;
                diag.reserve(rows * num_slices);
                for (uint32_t k = 0; k < num_slices; k++) {
                    const auto piece = ExtractShiftedDiagonal(A_slices[k], diagIdx);
                    diag.insert(diag.end(), piece.begin(), piece.end());
                }
                scaleInPlace(diag);
                diags[diagIdx] = std::move(diag);
            }
        }
    }
    else {
        // Wide matrix: the shifted diagonals are taken from A as a whole.
#pragma omp parallel for
        for (uint32_t j = 0; j < gStep; j++) {
            for (uint32_t i = 0; i < bStep; i++) {
                const uint32_t diagIdx = bStep * j + i;
                if (diagIdx >= n)
                    continue;

                std::vector<std::complex<double>> diag = ExtractShiftedDiagonal(A, diagIdx);
                scaleInPlace(diag);
                diags[diagIdx] = std::move(diag);
            }
        }
    }
    return diags;
}

// Applies a precomputed linear transform to ctxt using the baby-step/giant-step method with
// hoisted rotations.
//   A    - precomputed plaintext diagonals (one per diagonal, e.g. from EvalLTPrecomputeSwitch)
//   dim1 - baby step; 0 selects getRatioBSGSLT(A.size()), the same default the precomputation uses
// Returns the transformed ciphertext.
Ciphertext<DCRTPoly> SWITCHCKKSRNS::EvalLTWithPrecomputeSwitch(const CryptoContextImpl<DCRTPoly>& cc,
                                                               ConstCiphertext<DCRTPoly> ctxt,
                                                               const std::vector<ConstPlaintext>& A,
                                                               uint32_t dim1) const {
    uint32_t slots = A.size();

    // Computing the baby-step bStep and the giant-step gStep.
    // dim1 == 0 must fall back to the default ratio: a zero baby step would divide by zero below
    // and size fastRotation with bStep - 1 wrapped around.
    uint32_t bStep = (dim1 == 0) ? getRatioBSGSLT(slots) : dim1;
    uint32_t gStep = ceil(static_cast<double>(slots) / bStep);

    uint32_t M = cc.GetCyclotomicOrder();
    uint32_t N = cc.GetRingDimension();

    // Computes the NTTs for each CRT limb (for the hoisted automorphisms used later on)
    auto digits = cc.EvalFastRotationPrecompute(ctxt);

    // fastRotation[j - 1] holds the rotation of ctxt by j, for j = 1 .. bStep - 1
    std::vector<Ciphertext<DCRTPoly>> fastRotation(bStep - 1);

    // Hoisted automorphisms
#pragma omp parallel for
    for (uint32_t j = 1; j < bStep; j++) {
        fastRotation[j - 1] = cc.EvalFastRotationExt(ctxt, j, digits, true);
    }

    Ciphertext<DCRTPoly> result;
    DCRTPoly first;

    for (uint32_t j = 0; j < gStep; j++) {
        // Inner (baby-step) sum for giant step j: sum over i of rot_i(ctxt) * A[bStep * j + i]
        Ciphertext<DCRTPoly> inner = EvalMultExt(cc.KeySwitchExt(ctxt, true), A[bStep * j]);

        for (uint32_t i = 1; i < bStep; i++) {
            if (bStep * j + i < slots) {
                EvalAddExtInPlace(inner, EvalMultExt(fastRotation[i - 1], A[bStep * j + i]));
            }
        }

        if (j == 0) {
            // The first giant step needs no rotation: its first element is kept aside in `first`
            // and zeroed in the accumulator, which stays in the extended basis.
            first         = cc.KeySwitchDownFirstElement(inner);
            auto elements = inner->GetElements();
            elements[0].SetValuesToZero();
            inner->SetElements(elements);
            result = inner;
        }
        else {
            inner = cc.KeySwitchDown(inner);
            // Find the automorphism index that corresponds to the rotation index.
            usint autoIndex = FindAutomorphismIndex2nComplex(bStep * j, M);
            std::vector<usint> map(N);
            PrecomputeAutoMap(N, autoIndex, &map);
            // The first element is rotated directly and accumulated outside the extended basis.
            DCRTPoly firstCurrent = inner->GetElements()[0].AutomorphismTransform(autoIndex, map);
            first += firstCurrent;

            auto innerDigits = cc.EvalFastRotationPrecompute(inner);
            EvalAddExtInPlace(result, cc.EvalFastRotationExt(inner, bStep * j, innerDigits, false));
        }
    }

    // Bring the accumulator back from the extended basis and add the separately tracked first element.
    result        = cc.KeySwitchDown(result);
    auto elements = result->GetElements();
    elements[0] += first;
    result->SetElements(elements);

    return result;
}

// Applies a rectangular linear transform to ct using the baby-step/giant-step method with
// hoisted rotations. Unlike EvalLTWithPrecomputeSwitch, the diagonals are passed as complex
// vectors and are encoded into plaintexts on the fly.
//   A    - the scaled shifted diagonals of the matrix (e.g. from EvalLTRectPrecomputeSwitch)
//   ct   - ciphertext to transform
//   wide - true if the original matrix has more columns than rows; the partial results are then
//          folded together by rotate-and-add at the end
//   dim1 - baby step; 0 selects getRatioBSGSLT(n) with n = min(A.size(), A[0].size())
//   L    - if non-zero, only L + 1 of the Q towers are kept for the plaintexts
// Returns the transformed ciphertext.
Ciphertext<DCRTPoly> SWITCHCKKSRNS::EvalLTRectWithPrecomputeSwitch(
    const CryptoContextImpl<DCRTPoly>& cc, const std::vector<std::vector<std::complex<double>>>& A,
    ConstCiphertext<DCRTPoly> ct, bool wide, uint32_t dim1, uint32_t L) const {
    uint32_t n = std::min(A.size(), A[0].size());

    // Computing the baby-step bStep and the giant-step gStep
    uint32_t bStep = (dim1 == 0) ? getRatioBSGSLT(n) : dim1;
    uint32_t gStep = ceil(static_cast<double>(n) / bStep);

    uint32_t M = cc.GetCyclotomicOrder();
    uint32_t N = cc.GetRingDimension();

    // Computes the NTTs for each CRT limb (for the hoisted automorphisms used later on)
    auto digits = cc.EvalFastRotationPrecompute(ct);

    // fastRotation[j - 1] holds the rotation of ct by j, for j = 1 .. bStep - 1
    std::vector<Ciphertext<DCRTPoly>> fastRotation(bStep - 1);

    // Make sure the plaintext is created only with the necessary amount of moduli
    const auto cryptoParamsCKKS = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(ct->GetCryptoParameters());

    // Local copy of the element parameters so towers can be dropped without touching the context
    ILDCRTParams<DCRTPoly::Integer> elementParams = *(cryptoParamsCKKS->GetElementParams());
    uint32_t towersToDrop                         = 0;

    // Keep only L + 1 towers when a level budget L is given
    if (L != 0) {
        towersToDrop = elementParams.GetParams().size() - L - 1;
        for (uint32_t i = 0; i < towersToDrop; i++)
            elementParams.PopLastParam();
    }
    // For FLEXIBLEAUTOEXT we do not need extra modulus in auxiliary plaintexts
    if (cryptoParamsCKKS->GetScalingTechnique() == FLEXIBLEAUTOEXT) {
        towersToDrop += 1;
        elementParams.PopLastParam();
    }

    const auto& paramsQ = elementParams.GetParams();
    const auto& paramsP = cryptoParamsCKKS->GetParamsP()->GetParams();

    // Collect the moduli and roots of unity of the remaining Q towers followed by the P towers
    size_t sizeQP = paramsQ.size() + paramsP.size();
    std::vector<NativeInteger> moduli;
    moduli.reserve(sizeQP);
    std::vector<NativeInteger> roots;
    roots.reserve(sizeQP);
    for (const auto& elem : paramsQ) {
        moduli.emplace_back(elem->GetModulus());
        roots.emplace_back(elem->GetRootOfUnity());
    }
    for (const auto& elem : paramsP) {
        moduli.emplace_back(elem->GetModulus());
        roots.emplace_back(elem->GetRootOfUnity());
    }

    // Parameters over which the auxiliary plaintexts are encoded
    auto elementParamsPtr  = std::make_shared<ILDCRTParams<DCRTPoly::Integer>>(M, moduli, roots);
    auto elementParamsPtr2 = std::dynamic_pointer_cast<typename DCRTPoly::Params>(elementParamsPtr);

// Hoisted automorphisms
#pragma omp parallel for
    for (uint32_t j = 1; j < bStep; j++) {
        fastRotation[j - 1] = cc.EvalFastRotationExt(ct, j, digits, true);
    }

    Ciphertext<DCRTPoly> result;
    DCRTPoly first;

    for (uint32_t j = 0; j < gStep; j++) {
        // Every diagonal of giant step j is passed through Rotate with offset -bStep * j before encoding
        int32_t offset = (j == 0) ? 0 : -static_cast<int32_t>(bStep * j);
        auto temp      = cc.MakeCKKSPackedPlaintext(Rotate(Fill(A[bStep * j], N / 2), offset), 1, towersToDrop,
                                                    elementParamsPtr2, N / 2);
        // Inner (baby-step) sum for giant step j, starting with the unrotated ciphertext
        Ciphertext<DCRTPoly> inner = EvalMultExt(cc.KeySwitchExt(ct, true), temp);

        for (uint32_t i = 1; i < bStep; i++) {
            if (bStep * j + i < n) {
                auto tempi = cc.MakeCKKSPackedPlaintext(Rotate(Fill(A[bStep * j + i], N / 2), offset), 1, towersToDrop,
                                                        elementParamsPtr2, N / 2);
                EvalAddExtInPlace(inner, EvalMultExt(fastRotation[i - 1], tempi));
            }
        }

        if (j == 0) {
            // The first giant step needs no rotation: its first element is kept aside in `first`
            // and zeroed in the accumulator, which stays in the extended basis
            first         = cc.KeySwitchDownFirstElement(inner);
            auto elements = inner->GetElements();
            elements[0].SetValuesToZero();
            inner->SetElements(elements);
            result = inner;
        }
        else {
            inner = cc.KeySwitchDown(inner);
            // Find the automorphism index that corresponds to the rotation index.
            usint autoIndex = FindAutomorphismIndex2nComplex(bStep * j, M);
            std::vector<usint> map(N);
            PrecomputeAutoMap(N, autoIndex, &map);
            // The first element is rotated directly and accumulated outside the extended basis
            DCRTPoly firstCurrent = inner->GetElements()[0].AutomorphismTransform(autoIndex, map);
            first += firstCurrent;

            auto innerDigits = cc.EvalFastRotationPrecompute(inner);
            EvalAddExtInPlace(result, cc.EvalFastRotationExt(inner, bStep * j, innerDigits, false));
        }
    }
    // Bring the accumulator back from the extended basis and add the separately tracked first element
    result        = cc.KeySwitchDown(result);
    auto elements = result->GetElements();
    elements[0] += first;
    result->SetElements(elements);

    // A represents the diagonals, which lose the information whether the initial matrix is tall or wide
    if (wide) {
        // Fold the partial results with log2(A[0].size() / A.size()) rotate-and-add steps,
        // doubling the rotation amount (A.size() * 2^(j-1)) at each step
        uint32_t logl = lbcrypto::GetMSB(A[0].size() / A.size()) - 1;  // These are powers of two, so log(l) is integer
        std::vector<Ciphertext<DCRTPoly>> ctxt(logl + 1);
        ctxt[0] = result;
        for (size_t j = 1; j <= logl; ++j) {
            ctxt[j] = cc.EvalAdd(ctxt[j - 1], cc.EvalAtIndex(ctxt[j - 1], A.size() * (1 << (j - 1))));
        }
        result = ctxt[logl];
    }

    return result;
}

// Homomorphically applies the precomputed decoding transform m_U0Pre to ctxt.
// The input is first compressed to a small number of towers (with a scaling-factor adjustment
// for the FLEXIBLEAUTO / FLEXIBLEAUTOEXT techniques). In the sparsely packed case the result is
// additionally added to its rotation by `slots`.
// Throws if the precomputations (EvalCKKStoFHEWPrecompute) have not been generated.
Ciphertext<DCRTPoly> SWITCHCKKSRNS::EvalSlotsToCoeffsSwitch(const CryptoContextImpl<DCRTPoly>& cc,
                                                            ConstCiphertext<DCRTPoly> ctxt) const {
    const uint32_t slots = m_numSlotsCKKS;
    const uint32_t M     = cc.GetCyclotomicOrder();
    // Fully packed exactly when the cyclotomic order equals 4 * slots
    const bool isSparse = (M != 4 * slots);

    auto ctxtToDecode = ctxt->Clone();

    const uint32_t numTowersToKeep = 2;
    const auto cryptoParams        = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(cc.GetCryptoParameters());
    const auto scalingTechnique    = cryptoParams->GetScalingTechnique();
    if (scalingTechnique == ScalingTechnique::FLEXIBLEAUTO || scalingTechnique == FLEXIBLEAUTOEXT) {
        // Keep one extra tower, rescale to the target scaling factor, then drop that tower.
        ctxtToDecode = cc.Compress(ctxtToDecode, numTowersToKeep + 1);

        const auto elementParams = cryptoParams->GetElementParams();
        const auto& towerParams  = elementParams->GetParams();

        const double targetSF    = cryptoParams->GetScalingFactorReal(towerParams.size() - numTowersToKeep);
        const double sourceSF    = ctxtToDecode->GetScalingFactor();
        const uint32_t numTowers = ctxtToDecode->GetElements()[0].GetNumOfElements();
        const double modToDrop   = towerParams[numTowers - 1]->GetModulus().ConvertToDouble();

        const double adjustmentFactor = (targetSF / sourceSF) * (modToDrop / sourceSF);

        ctxtToDecode = cc.EvalMult(ctxtToDecode, adjustmentFactor);
        cc.GetScheme()->ModReduceInternalInPlace(ctxtToDecode, 1);
        ctxtToDecode->SetScalingFactor(targetSF);
    }
    else {
        ctxtToDecode = cc.Compress(ctxtToDecode, numTowersToKeep);
    }

    if (slots != m_numSlotsCKKS || m_U0Pre.size() == 0) {
        std::string errorMsg(std::string("Precomputations for ") + std::to_string(slots) +
                             std::string(" slots were not generated") +
                             std::string(" Need to call EvalCKKSToFHEWPrecompute to proceed"));
        OPENFHE_THROW(errorMsg);
    }

    // The same linear transform is used for both packings.
    // Disabled alternative for the fully packed case:
    // ctxtToDecode = cc.EvalAdd(ctxtToDecode, cc.GetScheme()->MultByMonomial(ctxtToDecode, M / 4));
    Ciphertext<DCRTPoly> ctxtDecoded = EvalLTWithPrecomputeSwitch(cc, ctxtToDecode, m_U0Pre, m_dim1CF);

    if (isSparse) {  // sparsely packed
        cc.EvalAddInPlace(ctxtDecoded, cc.EvalAtIndex(ctxtDecoded, slots));
    }

    return ctxtDecoded;
}

// Multiplies the matrix A homomorphically with the ciphertext ct via the rectangular linear transform.
// Currently, by design, the # rows (# LWE ciphertexts to switch) is a power of two; the # cols
// (LWE lattice parameter n) is zero-padded up to the next power of two before the diagonals are extracted.
//   dim1  - baby step forwarded to the linear transform (0 selects its default)
//   scale - factor multiplied into the matrix entries
//   L     - level budget forwarded to the linear transform
Ciphertext<DCRTPoly> SWITCHCKKSRNS::EvalPartialHomDecryption(const CryptoContextImpl<DCRTPoly>& cc,
                                                             const std::vector<std::vector<std::complex<double>>>& A,
                                                             ConstCiphertext<DCRTPoly> ct, uint32_t dim1, double scale,
                                                             uint32_t L) const {
    std::vector<std::vector<std::complex<double>>> Acopy(A);

    const size_t origCols   = A[0].size();
    const uint32_t cols_po2 = 1 << static_cast<uint32_t>(std::ceil(std::log2(origCols)));

    // Pad every row with zeros when the column count is not already a power of two.
    if (cols_po2 > origCols) {
        for (auto& row : Acopy) {
            row.resize(cols_po2);
        }
    }

    // The matrix is treated as wide when it has fewer rows than (unpadded) columns.
    const bool wide = Acopy.size() < origCols;

    auto Apre = EvalLTRectPrecomputeSwitch(Acopy, dim1, scale);
    // The result is repeated every Acopy.size() slots
    auto res = EvalLTRectWithPrecomputeSwitch(cc, Apre, ct, wide, dim1, L);

    return res;
}

//------------------------------------------------------------------------------
// Scheme switching Wrapper
//------------------------------------------------------------------------------
// Sets up CKKS -> FHEW scheme switching from the given parameters:
//   - resolves and stores the number of CKKS slots and the initial CKKS modulus,
//   - creates the intermediate CKKS cryptocontext (m_ccKS) used for the key switch,
//   - creates the FHEW context (m_ccLWE) and derives the LWE modulus,
//   - stores the baby step (m_dim1CF) and level budget (m_LCF) of the decoding linear transform.
// Returns a freshly generated LWE private key.
// Throws if the FHEW security level is neither TOY nor STD128, or if the intermediate modulus
// size is not strictly above the FHEW modulus size and within the CKKS modulus size.
LWEPrivateKey SWITCHCKKSRNS::EvalCKKStoFHEWSetup(const SchSwchParams& params) {
    if (params.GetSecurityLevelFHEW() != TOY && params.GetSecurityLevelFHEW() != STD128)
        OPENFHE_THROW("Only STD128 or TOY are currently supported.");

    uint32_t ringDim = params.GetRingDimension();
    if (params.GetNumSlotsCKKS() == 0 || params.GetNumSlotsCKKS() == (ringDim / 2))  // fully-packed
        m_numSlotsCKKS = ringDim / 2;
    else  // sparsely-packed
        m_numSlotsCKKS = params.GetNumSlotsCKKS();

    m_modulus_CKKS_initial = params.GetInitialCKKSModulus();
    // Modulus to switch to in order to have secure RLWE samples with ring dimension n.
    // We can select any Qswitch less than 27 bits corresponding to 128 bits of security for lattice parameter n=1024 < 1305
    // according to https://homomorphicencryption.org/wp-content/uploads/2018/11/HomomorphicEncryptionStandardv1.1.pdf
    // or any Qswitch for TOY security.
    // Ensure that Qswitch is larger than Q_FHEW and smaller than Q_CKKS.
    if (params.GetCtxtModSizeFHEWIntermedSwch() <= params.GetCtxtModSizeFHEWLargePrec() ||
        params.GetCtxtModSizeFHEWIntermedSwch() > GetMSB(m_modulus_CKKS_initial.ConvertToInt()) - 1) {
        OPENFHE_THROW("Qswitch should be larger than QFHEW and smaller than QCKKS.");
    }
    // Intermediate cryptocontext
    CCParams<CryptoContextCKKSRNS> parameters;
    parameters.SetMultiplicativeDepth(0);
    parameters.SetFirstModSize(params.GetCtxtModSizeFHEWIntermedSwch());
    parameters.SetScalingModSize(params.GetScalingModSize());
    // This doesn't need this to be the same scaling technique as the outer cryptocontext, since we only do a key switch
    parameters.SetScalingTechnique(FIXEDMANUAL);
    parameters.SetSecurityLevel(params.GetSecurityLevelCKKS());
    parameters.SetRingDim(ringDim);
    parameters.SetBatchSize(params.GetBatchSize());

    m_ccKS = GenCryptoContext(parameters);

    // Enable the features that you wish to use
    m_ccKS->Enable(PKE);
    m_ccKS->Enable(KEYSWITCH);

    // Get the ciphertext modulus
    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(m_ccKS->GetCryptoParameters());
    m_modulus_CKKS_from     = cryptoParams->GetElementParams()->GetParams()[0]->GetModulus();

    m_ccLWE = std::make_shared<BinFHEContext>();
    m_ccLWE->BinFHEContext::GenerateBinFHEContext(
        params.GetSecurityLevelFHEW(), params.GetArbitraryFunctionEvaluation(), params.GetCtxtModSizeFHEWLargePrec(), 0,
        GINX, params.GetUseDynamicModeFHEW());

    // For arbitrary functions, the LWE ciphertext needs to be at most the ring dimension in FHEW bootstrapping.
    // The power of two is computed in BasicInteger so the shift cannot overflow a 32-bit int.
    m_modulus_LWE = (!params.GetArbitraryFunctionEvaluation()) ?
                        static_cast<BasicInteger>(1) << params.GetCtxtModSizeFHEWLargePrec() :
                        m_ccLWE->GetParams()->GetLWEParams()->Getq().ConvertToInt();

    // LWE private key
    LWEPrivateKey lwesk = m_ccLWE->KeyGen();

    // The baby-step and number of levels for the linear transformation associated to the homomorphic decoding.
    // The default baby step is derived from the resolved slot count (m_numSlotsCKKS): the raw parameter
    // may be 0, which only means "fully packed" and is not a usable transform dimension.
    m_dim1CF = (params.GetBStepLTrCKKStoFHEW() == 0) ? getRatioBSGSLT(m_numSlotsCKKS) :
                                                       params.GetBStepLTrCKKStoFHEW();
    m_LCF    = params.GetLevelLTrCKKStoFHEW();

    return lwesk;
}

// Generates the evaluation keys needed for CKKS -> FHEW scheme switching:
//   - an encryption of zero in the intermediate cryptocontext (stored in m_ctxtKS),
//   - the RLWE switching key towards the LWE secret key (stored in m_CKKStoFHEWswk),
//   - the CKKS multiplication key,
//   - the rotation keys for the decoding linear transform (plus the rotation by `slots`),
//   - the conjugation key, stored under index M - 1.
// Returns the map of rotation/conjugation keys.
// Throws if EvalCKKStoFHEWSetup has not been called, if the key switching technique is not HYBRID,
// or (128-bit native integers only) if a FLEXIBLEAUTO* scaling technique is used.
std::shared_ptr<std::map<usint, EvalKey<DCRTPoly>>> SWITCHCKKSRNS::EvalCKKStoFHEWKeyGen(
    const KeyPair<DCRTPoly>& keyPair, ConstLWEPrivateKey& lwesk) {
    auto privateKey = keyPair.secretKey;
    auto publicKey  = keyPair.publicKey;

    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(privateKey->GetCryptoParameters());

    if (cryptoParams->GetKeySwitchTechnique() != HYBRID)
        OPENFHE_THROW("CKKS to FHEW scheme switching is only supported for the Hybrid key switching method.");
#if NATIVEINT == 128 && !defined(__EMSCRIPTEN__)
    if (cryptoParams->GetScalingTechnique() == FLEXIBLEAUTO || cryptoParams->GetScalingTechnique() == FLEXIBLEAUTOEXT)
        OPENFHE_THROW("128-bit CKKS to FHEW scheme switching is supported for FIXEDMANUAL and FIXEDAUTO methods only.");
#endif

    // The intermediate cryptocontext is created by EvalCKKStoFHEWSetup; fail clearly instead of
    // dereferencing a null context.
    if (m_ccKS == nullptr)
        OPENFHE_THROW("EvalCKKStoFHEWSetup needs to be called before EvalCKKStoFHEWKeyGen.");

    auto ccCKKS = privateKey->GetCryptoContext();

    // Intermediate cryptocontext for CKKS to FHEW
    auto keys2 = m_ccKS->KeyGen();

    Plaintext ptxtZeroKS = m_ccKS->MakeCKKSPackedPlaintext(std::vector<double>{0.0});
    m_ctxtKS             = m_ccKS->Encrypt(keys2.publicKey, ptxtZeroKS);

    // Compute switching key between RLWE and LWE via the intermediate cryptocontext, keep it in RLWE form
    m_CKKStoFHEWswk = switchingKeyGenRLWEcc(keys2.secretKey, privateKey, lwesk);

    // Compute automorphism keys
    uint32_t M     = ccCKKS->GetCyclotomicOrder();
    uint32_t slots = m_numSlotsCKKS;

    // Compute indices for rotations for slotToCoeff transform
    std::vector<int32_t> indexRotationS2C = FindLTRotationIndicesSwitch(m_dim1CF, M, slots);
    indexRotationS2C.emplace_back(static_cast<int32_t>(slots));

    // Remove possible duplicates and zero.
    // The algorithms are qualified with std:: so the calls do not depend on argument-dependent lookup.
    std::sort(indexRotationS2C.begin(), indexRotationS2C.end());
    indexRotationS2C.erase(std::unique(indexRotationS2C.begin(), indexRotationS2C.end()), indexRotationS2C.end());
    indexRotationS2C.erase(std::remove(indexRotationS2C.begin(), indexRotationS2C.end(), 0), indexRotationS2C.end());

    auto algo = ccCKKS->GetScheme();

    // Compute multiplication key
    algo->EvalMultKeyGen(privateKey);

    auto evalKeys = algo->EvalAtIndexKeyGen(publicKey, privateKey, indexRotationS2C);

    // Compute conjugation key
    auto conjKey       = ConjugateKeyGen(privateKey);
    (*evalKeys)[M - 1] = conjKey;

    return evalKeys;
}

// Builds the decoding matrices for the homomorphic slots-to-coefficients step of the CKKS to FHEW switch and
// stores the precomputed linear transform in m_U0Pre. The caller-provided scale is further multiplied by
// m_modulus_CKKS_initial divided by the scaling factor of the last RNS limb before it is applied.
void SWITCHCKKSRNS::EvalCKKStoFHEWPrecompute(const CryptoContextImpl<DCRTPoly>& cc, double scale) {
    const uint32_t cyclOrder = cc.GetCyclotomicOrder();
    const uint32_t numSlots  = m_numSlotsCKKS;
    const uint32_t m         = 4 * m_numSlotsCKKS;

    // The packing is full exactly when the cyclotomic order equals 4 * slots
    const bool fullyPacked = (cyclOrder == m);

    // Indices of the primitive roots of unity: successive powers of 5 reduced modulo m
    std::vector<uint32_t> rotGroup(numSlots);
    uint32_t pow5 = 1;
    for (auto& entry : rotGroup) {
        entry = pow5;
        pow5  = (pow5 * 5) % m;
    }

    // All powers of the primitive root of unity exp(2*M_PI/m); the last entry wraps around to the first
    std::vector<std::complex<double>> ksiPows(m + 1);
    for (uint32_t k = 0; k < m; ++k) {
        const double theta = 2.0 * M_PI * k / m;
        ksiPows[k]         = std::complex<double>(cos(theta), sin(theta));
    }
    ksiPows[m] = ksiPows[0];

    // Decoding matrices: U0 holds the root powers, U1 is U0 multiplied by the imaginary unit
    using ComplexMatrix = std::vector<std::vector<std::complex<double>>>;
    ComplexMatrix U0(numSlots, std::vector<std::complex<double>>(numSlots));
    ComplexMatrix U1(numSlots, std::vector<std::complex<double>>(numSlots));

    for (size_t row = 0; row < numSlots; ++row) {
        auto& rowU0 = U0[row];
        auto& rowU1 = U1[row];
        for (size_t col = 0; col < numSlots; ++col) {
            rowU0[col] = ksiPows[(col * rotGroup[row]) % m];
            rowU1[col] = std::complex<double>(0, 1) * rowU0[col];
        }
    }

    // Fold into the scale the ratio between the initial CKKS modulus and the scaling factor of the last limb,
    // so that messages encoded in CKKS end up correctly scaled as FHEW messages
    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(cc.GetCryptoParameters());
    const size_t lastLimb   = cryptoParams->GetElementParams()->GetParams().size() - 1;
    double scFactor         = cryptoParams->GetScalingFactorReal(lastLimb);
    scale *= m_modulus_CKKS_initial.ConvertToDouble() / scFactor;

    if (fullyPacked) {
        m_U0Pre = EvalLTPrecomputeSwitch(cc, U0, m_dim1CF, m_LCF, scale);
    }
    else {  // sparse packing needs both matrices
        m_U0Pre = EvalLTPrecomputeSwitch(cc, U0, U1, m_dim1CF, m_LCF, scale);
    }
}

// Switches a CKKS ciphertext to a vector of FHEW (LWE) ciphertexts.
//
// ciphertext: the CKKS ciphertext to switch.
// numCtxts:   number of LWE ciphertexts to extract; 0 or any value larger than m_numSlotsCKKS selects
//             m_numSlotsCKKS ciphertexts.
// Returns exactly that (clamped) number of LWE ciphertexts. They are modulus switched from m_modulus_CKKS_from
// to m_modulus_LWE whenever the two moduli differ.
std::vector<std::shared_ptr<LWECiphertextImpl>> SWITCHCKKSRNS::EvalCKKStoFHEW(ConstCiphertext<DCRTPoly> ciphertext,
                                                                              uint32_t numCtxts) {
    auto ccCKKS    = ciphertext->GetCryptoContext();
    uint32_t slots = m_numSlotsCKKS;

    // Step 1. Homomorphic decoding
    auto ctxtDecoded = EvalSlotsToCoeffsSwitch(*ccCKKS, ciphertext);
    ccCKKS->GetScheme()->ModReduceInternalInPlace(ctxtDecoded, 1);

    // Step 2. Modulus switch to Q', such that CKKS is secure for (Q',n)
    auto ctxtKS = m_ctxtKS->Clone();
    ModSwitch(ctxtDecoded, ctxtKS, m_modulus_CKKS_from);

    // Step 3: Key switch from the CKKS key with the new modulus Q' to the RLWE version of the FHEW key with the new modulus Q'
    auto ccKS       = ctxtKS->GetCryptoContext();  // Use this instead of m_ccKS to work with serialization
    auto ctSwitched = ccKS->KeySwitch(ctxtKS, m_CKKStoFHEWswk);

    // Step 4. Extract LWE ciphertexts with the modulus Q'
    uint32_t n = m_ccLWE->GetParams()->GetLWEParams()->Getn();  // lattice parameter for additive LWE
    auto AandB = ExtractLWEpacked(ctSwitched);

    // Clamp the requested count BEFORE sizing the output vector: sizing it with the raw argument leaves it
    // empty when numCtxts == 0 (the loop below would then write out of bounds) and leaves trailing null
    // entries when numCtxts > slots.
    if (numCtxts == 0 || numCtxts > slots) {
        numCtxts = slots;
    }
    std::vector<std::shared_ptr<LWECiphertextImpl>> LWEciphertexts(numCtxts);

    // Distance between the coefficients that carry consecutive values
    uint32_t gap = ccKS->GetRingDimension() / (2 * slots);

    for (uint32_t i = 0, idx = 0; i < numCtxts; ++i, idx += gap) {
        LWEciphertexts[i] = ExtractLWECiphertext(AandB, m_modulus_CKKS_from, n, idx);
    }

    // Step 5. Modulus switch to q in FHEW

    // Compute the necessary factor to obtain the message Q'/pLWE
    if (m_modulus_LWE != m_modulus_CKKS_from) {
#pragma omp parallel for
        for (uint32_t i = 0; i < numCtxts; i++) {
            // References are only read before LWEciphertexts[i] is replaced below
            const auto& original_a = LWEciphertexts[i]->GetA();
            const auto& original_b = LWEciphertexts[i]->GetB();
            // multiply by Q_LWE/Q' and round to Q_LWE
            NativeVector a_round(n, m_modulus_LWE);
            for (uint32_t j = 0; j < n; ++j) {
                a_round[j] = RoundqQAlter(original_a[j], m_modulus_LWE, m_modulus_CKKS_from);
            }
            NativeInteger b_round = RoundqQAlter(original_b, m_modulus_LWE, m_modulus_CKKS_from);
            LWEciphertexts[i]     = std::make_shared<LWECiphertextImpl>(std::move(a_round), std::move(b_round));
        }
    }

    return LWEciphertexts;
}

//------------------------------------------------------------------------------
// Scheme switching Wrapper
//------------------------------------------------------------------------------
// Stores the parameters needed for switching from FHEW to CKKS.
//
// ccCKKS:       CKKS cryptocontext; only its ring dimension and encoding parameters are read here.
// ccLWE:        FHEW cryptocontext, stored in m_ccLWE.
// numSlotsCKKS: number of CKKS slots to use; 0 selects the batch size of ccCKKS if it is non-zero, otherwise
//               half of the ring dimension.
// logQ:         log2 of the LWE modulus to use; 0 selects the modulus q of the LWE parameters.
// Throws if twice the LWE lattice parameter exceeds the CKKS ring dimension.
void SWITCHCKKSRNS::EvalFHEWtoCKKSSetup(const CryptoContextImpl<DCRTPoly>& ccCKKS,
                                        const std::shared_ptr<BinFHEContext>& ccLWE, uint32_t numSlotsCKKS,
                                        uint32_t logQ) {
    m_ccLWE = ccLWE;

    if (m_ccLWE->GetParams()->GetLWEParams()->Getn() * 2 > ccCKKS.GetRingDimension())
        OPENFHE_THROW("The lattice parameter in LWE cannot be larger than half the RLWE ring dimension.");

    if (numSlotsCKKS == 0) {
        if (ccCKKS.GetEncodingParams()->GetBatchSize() != 0)
            m_numSlotsCKKS = ccCKKS.GetEncodingParams()->GetBatchSize();
        else
            m_numSlotsCKKS = ccCKKS.GetRingDimension() / 2;
    }
    else {
        m_numSlotsCKKS = numSlotsCKKS;
    }

    // Shift a 64-bit one: a plain `1 << logQ` is evaluated in int and overflows for logQ >= 31
    m_modulus_LWE = (logQ != 0) ? (static_cast<uint64_t>(1) << logQ) :
                                  m_ccLWE->GetParams()->GetLWEParams()->Getq().ConvertToInt();
}

// Generates the key material for switching from FHEW to CKKS: stores a CKKS encryption of the LWE secret key
// in m_FHEWtoCKKSswk, records dim1 and L in m_dim1FC / m_LFC, generates the CKKS multiplication key and
// returns the rotation keys for all indices needed by the homomorphic decryption.
//
// numSlots: number of CKKS slots of the final encoding (0 selects m_numSlotsCKKS).
// numCtxts: number of LWE ciphertexts to be switched (0 selects m_numSlotsCKKS).
// dim1:     baby-step dimension of the linear transform (0 selects getRatioBSGSLT(numCtxts)).
// L:        level value stored in m_LFC.
std::shared_ptr<std::map<usint, EvalKey<DCRTPoly>>> SWITCHCKKSRNS::EvalFHEWtoCKKSKeyGen(
    const KeyPair<DCRTPoly>& keyPair, ConstLWEPrivateKey& lwesk, uint32_t numSlots, uint32_t numCtxts, uint32_t dim1,
    uint32_t L) {
    auto secretKey = keyPair.secretKey;
    auto pubKey    = keyPair.publicKey;

    const auto ckksParams = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(secretKey->GetCryptoParameters());
    auto ccCKKS           = secretKey->GetCryptoContext();

    const uint32_t lweDim  = lwesk->GetElement().GetLength();
    const uint32_t ringDim = ccCKKS->GetRingDimension();

    // The FHEW to CKKS switching key is a CKKS encryption of the FHEW secret key, padded with zeros up to the
    // closest power of two. Coefficients equal to modulus - 1 are encoded as -1.
    const uint32_t lweDimPo2 = 1 << static_cast<uint32_t>(std::ceil(std::log2(lweDim)));
    auto lweSecret           = lwesk->GetElement();
    std::vector<std::complex<double>> secretSlots(lweDimPo2);
    for (uint32_t k = 0; k < lweDim; ++k) {
        auto coeff             = lweSecret[k].ConvertToDouble();
        const bool isMinusOne = (coeff == lwesk->GetModulus().ConvertToInt() - 1);
        if (isMinusOne)
            coeff = -1;
        secretSlots[k] = std::complex<double>(coeff, 0);
    }

    // The number of slots is given explicitly; otherwise, if batchsize is set and is smaller, encoding throws.
    // FLEXIBLEAUTOEXT encodes at level BASE_NUM_LEVELS_TO_DROP, every other technique at level 0.
    const bool extraLevel = (ckksParams->GetScalingTechnique() == FLEXIBLEAUTOEXT);
    Plaintext secretPlain =
        extraLevel ? ccCKKS->MakeCKKSPackedPlaintext(Fill(secretSlots, ringDim / 2), 1, BASE_NUM_LEVELS_TO_DROP,
                                                     nullptr, ringDim / 2) :
                     ccCKKS->MakeCKKSPackedPlaintext(Fill(secretSlots, ringDim / 2), 1, 0, nullptr, ringDim / 2);

    m_FHEWtoCKKSswk = ccCKKS->Encrypt(pubKey, secretPlain);

    // Resolve defaults for the baby-step giant-step linear transform and remember them
    numCtxts = (numCtxts == 0) ? m_numSlotsCKKS : numCtxts;  // default: column size of the linear transformation
    const uint32_t cyclOrder = ccCKKS->GetCyclotomicOrder();
    if (dim1 == 0)
        dim1 = getRatioBSGSLT(numCtxts);
    m_dim1FC = dim1;
    m_LFC    = L;

    // Rotation indices for the homomorphic decryption in CKKS
    std::vector<int32_t> rotIndices = FindLTRotationIndicesSwitch(dim1, cyclOrder, numCtxts);

    // A wide (rather than tall) linear transform needs extra rotations
    if (numCtxts < lweDimPo2) {
        const uint32_t logl = lbcrypto::GetMSB(lweDimPo2 / numCtxts) - 1;  // both are powers of two
        rotIndices.reserve(rotIndices.size() + logl);
        for (uint32_t shift = 0; shift < logl; ++shift) {
            rotIndices.emplace_back(numCtxts * (1 << shift));
        }
    }

    // Rotation indices that bring the final CKKS ciphertext back to an encoding over `slots` slots;
    // nothing is added when the encoding is full
    const uint32_t slots = (numSlots == 0) ? m_numSlotsCKKS : numSlots;
    if (ringDim > 2 * slots) {
        rotIndices.reserve(rotIndices.size() + GetMSB(ringDim) - 2);
        const uint32_t copies = ringDim / (2 * slots);
        for (uint32_t j = 1; j < copies; j <<= 1) {
            rotIndices.emplace_back(j * slots);
        }
    }

    // Keep each index once and drop the zero rotation
    std::sort(rotIndices.begin(), rotIndices.end());
    rotIndices.erase(std::unique(rotIndices.begin(), rotIndices.end()), rotIndices.end());
    rotIndices.erase(std::remove(rotIndices.begin(), rotIndices.end(), 0), rotIndices.end());

    auto rotationKeys = ccCKKS->GetScheme()->EvalAtIndexKeyGen(pubKey, secretKey, rotIndices);

    // Multiplication key
    ccCKKS->EvalMultKeyGen(secretKey);

    return rotationKeys;
}

// Switches a vector of FHEW (LWE) ciphertexts to a single CKKS ciphertext.
//
// LWECiphertexts: input LWE ciphertexts; must not be empty (throws otherwise).
// numCtxts:       number of LWE ciphertexts to pack; 0 means all of them. The number actually used is also
//                 capped by the size of the input vector and by the number of slots.
// numSlots:       number of CKKS slots of the output encoding; 0 selects m_numSlotsCKKS.
// p:              plaintext modulus of the LWE messages; selects the Chebyshev coefficients (p <= 4 or larger)
//                 and the post-scaling factor.
// pmin, pmax:     when pmin != 0, the result is additionally scaled by (pmax - pmin) / 4 and shifted by the
//                 same amount.
// dim1:           baby-step dimension for the linear transform; 0 selects m_dim1FC.
// Returns the CKKS ciphertext created in the cryptocontext of m_FHEWtoCKKSswk.
Ciphertext<DCRTPoly> SWITCHCKKSRNS::EvalFHEWtoCKKS(std::vector<std::shared_ptr<LWECiphertextImpl>>& LWECiphertexts,
                                                   uint32_t numCtxts, uint32_t numSlots, uint32_t p, double pmin,
                                                   double pmax, uint32_t dim1) const {
    if (!LWECiphertexts.size())
        OPENFHE_THROW("Empty input FHEW ciphertext vector");

    // This is the number of CKKS slots to use in encoding
    const uint32_t slots = (numSlots == 0) ? m_numSlotsCKKS : numSlots;

    uint32_t numLWECtxts = LWECiphertexts.size();
    uint32_t numValues   = (numCtxts == 0) ? numLWECtxts : std::min(numCtxts, numLWECtxts);
    numValues = std::min(numValues, slots);  // This is the number of LWE ciphertexts to pack into the CKKS ciphertext

    // LWE lattice parameter, read from the first input ciphertext
    uint32_t n = LWECiphertexts[0]->GetA().GetLength();

    // All CKKS operations below run in the cryptocontext of the stored switching key
    auto ccCKKS                 = m_FHEWtoCKKSswk->GetCryptoContext();
    const auto cryptoParamsCKKS = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(ccCKKS->GetCryptoParameters());

    uint32_t m    = 4 * slots;
    uint32_t M    = ccCKKS->GetCyclotomicOrder();
    uint32_t N    = ccCKKS->GetRingDimension();
    bool isSparse = (M != m) ? true : false;  // sparse packing whenever the cyclotomic order differs from 4 * slots

    // Choose the range bound K and the Chebyshev coefficients for the modular reduction
    double K = 0.0;
    std::vector<double> coefficientsFHEW;  // EvalFHEWtoCKKS assumes lattice parameter n is at most 2048.
    if (n == 32) {
        K = 16.0;
        coefficientsFHEW.assign(g_coefficientsFHEW16);
    }
    else {
        K = 128.0;  // Failure probability of 2^{-49}
        if (p <= 4) {
            // If the output messages are bits, we could use a lower degree polynomial
            coefficientsFHEW.assign(g_coefficientsFHEW128_8);
        }
        else {
            coefficientsFHEW.assign(g_coefficientsFHEW128_9);
        }
    }

    // Step 1. Form matrix A and vector b from the LWE ciphertexts, but only extract the first necessary number of them
    std::vector<std::vector<std::complex<double>>> A(numValues);

    // To have the same encoding as A*s, create b with the appropriate number of elements
    // (numValues rounded up to the next multiple of n)
    const uint32_t b_size = ((numValues % n) != 0) ? (numValues + n - (numValues % n)) : numValues;
    std::vector<std::complex<double>> b(b_size);

    // Combine the scale with the division by K to consume fewer levels, but careful since the value might be too small
    const double prescale = (1.0 / LWECiphertexts[0]->GetModulus().ConvertToDouble()) / K;

    // Each iteration fills its own row A[i] and entry b[i]; b is prescaled here, A is passed unscaled
    // together with prescale to EvalPartialHomDecryption in Step 2
#pragma omp parallel for
    for (uint32_t i = 0; i < numValues; i++) {
        auto a = LWECiphertexts[i]->GetA();
        A[i]   = std::vector<std::complex<double>>(a.GetLength());
        for (uint32_t j = 0; j < a.GetLength(); j++) {
            A[i][j] = std::complex<double>(a[j].ConvertToDouble(), 0);
        }
        b[i] = std::complex<double>(prescale * LWECiphertexts[i]->GetB().ConvertToDouble(), 0);
    }

    // Step 2. Perform the homomorphic linear transformation of A*skLWE
    if (dim1 == 0) {
        dim1 = m_dim1FC;
    }
    Ciphertext<DCRTPoly> AdotS = EvalPartialHomDecryption(*ccCKKS, A, m_FHEWtoCKKSswk, dim1, prescale, 0);

    // Step 3. Get the ciphertext of B - A*s
    // b is encoded at the same noise scale degree and level as AdotS so that the two can be added
    Plaintext BPlain = ccCKKS->MakeCKKSPackedPlaintext(b, AdotS->GetNoiseScaleDeg(), AdotS->GetLevel(), nullptr, N / 2);

    auto BminusAdotS = ccCKKS->EvalAdd(ccCKKS->EvalNegate(AdotS), BPlain);

    // FIXEDMANUAL rescales explicitly; the other techniques rescale only if the noise scale degree is 2
    if (cryptoParamsCKKS->GetScalingTechnique() == FIXEDMANUAL) {
        ccCKKS->ModReduceInPlace(BminusAdotS);
    }
    else {
        if (BminusAdotS->GetNoiseScaleDeg() == 2)
            ccCKKS->GetScheme()->ModReduceInternalInPlace(BminusAdotS, BASE_NUM_LEVELS_TO_DROP);
    }

    // Step 4. Do the modulus reduction: homomorphically evaluate modular function. We do it by using sine approximation.
    // auto BminusAdotS2 = BminusAdotS;  // Instead of zeroing out slots which are not of interest as done above

    double a_cheby = -1;
    double b_cheby = 1;  // The division by K was performed before

    // double a_cheby = -K; double b_cheby = K; // Alternatively, do this separately to not lose precision when scaling with everything at once

    auto BminusAdotS3 = ccCKKS->EvalChebyshevSeries(BminusAdotS, coefficientsFHEW, a_cheby, b_cheby);

    if (cryptoParamsCKKS->GetScalingTechnique() != FIXEDMANUAL) {
        ccCKKS->GetScheme()->ModReduceInternalInPlace(BminusAdotS3, BASE_NUM_LEVELS_TO_DROP);
    }

    // BT_ITER rounds of x -> 2*x^2 - scalar, each followed by a rescale
    enum { BT_ITER = 3 };
    for (int32_t j = 1; j < BT_ITER + 1; j++) {
        BminusAdotS3 = ccCKKS->EvalMult(BminusAdotS3, BminusAdotS3);
        ccCKKS->EvalAddInPlace(BminusAdotS3, BminusAdotS3);
        double scalar = 1.0 / std::pow((2.0 * Pi), std::pow(2.0, j - BT_ITER));
        ccCKKS->EvalSubInPlace(BminusAdotS3, scalar);
        if (cryptoParamsCKKS->GetScalingTechnique() == FIXEDMANUAL) {
            ccCKKS->ModReduceInPlace(BminusAdotS3);
        }
        else {
            ccCKKS->GetScheme()->ModReduceInternalInPlace(BminusAdotS3, BASE_NUM_LEVELS_TO_DROP);
        }
    }

    /* For p <= 4 and when we only encrypt bits, we don't need sin(2pi*x)/2pi to approximate x,
     * we can directly use sin(0) for 0 and sin(pi/2) for 1.
     * Here pmax is actually the plaintext modulus, not the maximum value of the messages that we
     * consider. For plaintext modulus > 4, even if we only care about encrypting bits, 2pi is not
     * the correct post-scaling factor.
     * Moreover, we have to account for the different encoding the end ciphertext should have.
     */
    double postScale = (p >= 1 && p <= 4) ? (static_cast<double>(2) * Pi) : static_cast<double>(p);
    double postBias  = 0.0;
    if (pmin != 0) {
        postScale *= (pmax - pmin) / 4.0;
        postBias = (pmax - pmin) / 4.0;
    }

    // numValues are set; the rest of values up to N/2 are made zero when creating the plaintext
    std::vector<std::complex<double>> postScaleVec(numValues, std::complex<double>(postScale, 0));
    std::vector<std::complex<double>> postBiasVec(numValues, std::complex<double>(postBias, 0));

    // Work on a copy of the element parameters and drop as many towers as the ciphertext has consumed
    ILDCRTParams<DCRTPoly::Integer> elementParams = *(cryptoParamsCKKS->GetElementParams());

    uint32_t towersToDrop = BminusAdotS3->GetLevel() + BminusAdotS3->GetNoiseScaleDeg() - 1;
    for (uint32_t i = 0; i < towersToDrop; i++)
        elementParams.PopLastParam();

    const auto& paramsQ = elementParams.GetParams();
    const auto& paramsP = cryptoParamsCKKS->GetParamsP()->GetParams();

    // Collect the moduli and roots of unity of the remaining Q towers followed by the P towers
    size_t sizeQP = paramsQ.size() + paramsP.size();
    std::vector<NativeInteger> moduli;
    moduli.reserve(sizeQP);
    std::vector<NativeInteger> roots;
    roots.reserve(sizeQP);
    for (const auto& elem : paramsQ) {
        moduli.emplace_back(elem->GetModulus());
        roots.emplace_back(elem->GetRootOfUnity());
    }
    for (const auto& elem : paramsP) {
        moduli.emplace_back(elem->GetModulus());
        roots.emplace_back(elem->GetRootOfUnity());
    }

    auto elementParamsPtr  = std::make_shared<ILDCRTParams<DCRTPoly::Integer>>(M, moduli, roots);
    auto elementParamsPtr2 = std::dynamic_pointer_cast<typename DCRTPoly::Params>(elementParamsPtr);

    // Use full packing here to clear up the junk in the slots after numValues
    auto postScalePlain = ccCKKS->MakeCKKSPackedPlaintext(postScaleVec, 1, towersToDrop, elementParamsPtr2, N / 2);
    auto BminusAdotSres = ccCKKS->EvalMult(BminusAdotS3, postScalePlain);

    // Add the plaintext for bias at the correct level and depth
    auto postBiasPlain = ccCKKS->MakeCKKSPackedPlaintext(postBiasVec, BminusAdotSres->GetNoiseScaleDeg(),
                                                         BminusAdotSres->GetLevel(), nullptr, N / 2);

    ccCKKS->EvalAddInPlace(BminusAdotSres, postBiasPlain);

    // Go back to the sparse encoding if needed
    if (isSparse) {
        // Add rotations by slots, 2*slots, 4*slots, ... and then mark the result as having `slots` slots
        for (uint32_t j = 1; j < N / (2 * slots); j <<= 1) {
            auto temp = ccCKKS->EvalAtIndex(BminusAdotSres, j * slots);
            ccCKKS->EvalAddInPlace(BminusAdotSres, temp);
        }
        BminusAdotSres->SetSlots(slots);
    }

    // Only FIXEDMANUAL rescales explicitly after the final plaintext multiplication
    if (cryptoParamsCKKS->GetScalingTechnique() == FIXEDMANUAL) {
        ccCKKS->ModReduceInPlace(BminusAdotSres);
    }

    return BminusAdotSres;
}

// Sets up bidirectional scheme switching: runs the CKKS to FHEW setup and then records the FHEW to CKKS
// settings (argmin flags, number of ciphertexts, linear transform parameters) for EvalSchemeSwitchingKeyGen.
// Returns the LWE private key produced by EvalCKKStoFHEWSetup.
LWEPrivateKey SWITCHCKKSRNS::EvalSchemeSwitchingSetup(const SchSwchParams& params) {
    /* CKKS to FHEW */
    auto lwesk = EvalCKKStoFHEWSetup(params);

    /* FHEW to CKKS */
    // Remember the argmin configuration for EvalSchemeSwitchingKeyGen
    m_argmin = params.GetComputeArgmin();
    m_oneHot = params.GetOneHotEncoding();
    m_alt    = params.GetUseAltArgmin();

    // Number of ciphertexts for the FHEW to CKKS linear transform: the requested number of values (all slots
    // when none is requested), halved when argmin runs in the non-alternative mode
    const bool halveCount = m_argmin && !m_alt;
    if (halveCount) {
        m_numCtxts = (params.GetNumValues() == 0) ? m_numSlotsCKKS / 2 : params.GetNumValues() / 2;
    }
    else {
        m_numCtxts = (params.GetNumValues() == 0) ? m_numSlotsCKKS : params.GetNumValues();
    }

    // Argmin requires several dim1 values, but those are specified individually in EvalSchemeSwitchingKeyGen
    const auto requestedBStep = params.GetBStepLTrFHEWtoCKKS();
    if (requestedBStep == 0) {
        m_dim1FC = getRatioBSGSLT(m_numCtxts);
    }
    else {
        m_dim1FC = params.GetBStepLTrFHEWtoCKKS();
    }
    m_LFC = params.GetLevelLTrFHEWtoCKKS();

    return lwesk;
}

// Generates all key material for bidirectional scheme switching between CKKS and FHEW.
//
// keyPair: CKKS key pair.
// lwesk:   FHEW (LWE) secret key.
// Side effects: sets m_ctxtKS, m_CKKStoFHEWswk and m_FHEWtoCKKSswk, generates the CKKS multiplication key
// (and the EvalSum keys when argmin is computed without one-hot encoding) and the FHEW bootstrapping keys.
// Returns the CKKS rotation keys for all required indices plus the conjugation key stored under index M - 1.
// Throws if the key switching technique is not HYBRID, or, for 128-bit native integers, if a FLEXIBLEAUTO*
// scaling technique is used.
std::shared_ptr<std::map<usint, EvalKey<DCRTPoly>>> SWITCHCKKSRNS::EvalSchemeSwitchingKeyGen(
    const KeyPair<DCRTPoly>& keyPair, ConstLWEPrivateKey& lwesk) {
    auto privateKey = keyPair.secretKey;
    auto publicKey  = keyPair.publicKey;

    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(privateKey->GetCryptoParameters());

    if (cryptoParams->GetKeySwitchTechnique() != HYBRID)
        OPENFHE_THROW("CKKS to FHEW scheme switching is only supported for the Hybrid key switching method.");
#if NATIVEINT == 128 && !defined(__EMSCRIPTEN__)
    if (cryptoParams->GetScalingTechnique() == FLEXIBLEAUTO || cryptoParams->GetScalingTechnique() == FLEXIBLEAUTOEXT)
        OPENFHE_THROW("128-bit CKKS to FHEW scheme switching is supported for FIXEDMANUAL and FIXEDAUTO methods only.");
#endif

    auto ccCKKS = privateKey->GetCryptoContext();

    uint32_t M       = ccCKKS->GetCyclotomicOrder();
    uint32_t slots   = m_numSlotsCKKS;
    uint32_t n       = lwesk->GetElement().GetLength();
    uint32_t ringDim = ccCKKS->GetRingDimension();

    // Intermediate cryptocontext for CKKS to FHEW
    auto keys2 = m_ccKS->KeyGen();

    // Encryption of zero in the intermediate cryptocontext, kept in m_ctxtKS
    Plaintext ptxtZeroKS = m_ccKS->MakeCKKSPackedPlaintext(std::vector<double>{0.0}, 1, 0, nullptr, slots);
    m_ctxtKS             = m_ccKS->Encrypt(keys2.publicKey, ptxtZeroKS);

    // Compute switching key between RLWE and LWE via the intermediate cryptocontext, keep it in RLWE form
    m_CKKStoFHEWswk = switchingKeyGenRLWEcc(keys2.secretKey, privateKey, lwesk);

    // Generate FHEW to CKKS switching key, i.e., CKKS encryption of FHEW secret key. Pad up to the closest power of two
    uint32_t n_po2     = 1 << static_cast<uint32_t>(std::ceil(std::log2(n)));
    auto skLWEElements = lwesk->GetElement();
    std::vector<std::complex<double>> skLWEDouble(n_po2);
    for (uint32_t i = 0; i < n; i++) {
        auto tmp = skLWEElements[i].ConvertToDouble();
        // A coefficient equal to modulus - 1 is encoded as -1
        if (tmp == lwesk->GetModulus().ConvertToInt() - 1)
            tmp = -1;
        skLWEDouble[i] = std::complex<double>(tmp, 0);
    }

    // Check encoding and specify the number of slots, otherwise, if batchsize is set and is smaller, it will throw an error.
    // FLEXIBLEAUTOEXT encodes at level BASE_NUM_LEVELS_TO_DROP, all other techniques at level 0.
    Plaintext skLWEPlainswk;
    if (cryptoParams->GetScalingTechnique() == FLEXIBLEAUTOEXT)
        skLWEPlainswk = ccCKKS->MakeCKKSPackedPlaintext(Fill(skLWEDouble, ringDim / 2), 1, BASE_NUM_LEVELS_TO_DROP,
                                                        nullptr, ringDim / 2);
    else
        skLWEPlainswk = ccCKKS->MakeCKKSPackedPlaintext(Fill(skLWEDouble, ringDim / 2), 1, 0, nullptr, ringDim / 2);

    m_FHEWtoCKKSswk = ccCKKS->Encrypt(publicKey, skLWEPlainswk);

    // Compute automorphism keys

    /* CKKS to FHEW */
    // Compute indices for rotations for slotToCoeff transform
    std::vector<int32_t> indexRotationS2C = FindLTRotationIndicesSwitch(m_dim1CF, M, slots);
    indexRotationS2C.emplace_back(static_cast<int32_t>(slots));

    // Compute indices for rotations for sparse packing
    if (ringDim > 2 * slots) {  // if the encoding is full, this does not execute
        indexRotationS2C.reserve(indexRotationS2C.size() + GetMSB(ringDim) - 2 + GetMSB(slots) - 1);
        // Powers of two below ringDim / 2, plus their negatives for powers up to slots
        for (uint32_t i = 1; i < ringDim / 2; i <<= 1) {
            indexRotationS2C.emplace_back(static_cast<int32_t>(i));
            if (i <= slots)
                indexRotationS2C.emplace_back(-static_cast<int32_t>(i));
        }
    }

    /* FHEW to CKKS */
    std::vector<int32_t> indexRotationHomDec;
    std::vector<int32_t> indexRotationArgmin;

    // No argmin, or argmin in the alternative mode: a single linear transform over m_numCtxts ciphertexts
    if (!m_argmin || (m_argmin && m_alt)) {
        // Compute indices for rotations for homomorphic decryption
        indexRotationHomDec = FindLTRotationIndicesSwitch(m_dim1FC, M, m_numCtxts);

        // If the linear transform is wide instead of tall, we need extra rotations
        if (m_numCtxts < n_po2) {
            uint32_t logl = lbcrypto::GetMSB(n_po2 / m_numCtxts) - 1;  // These are powers of two, so log(l) is integer
            indexRotationHomDec.reserve(indexRotationHomDec.size() + logl);
            for (size_t j = 1; j <= logl; ++j) {
                indexRotationHomDec.emplace_back(m_numCtxts * (1 << (j - 1)));
            }
        }
        if (m_argmin) {
            // Rotations for postprocessing after a level of the binary tree
            indexRotationArgmin.reserve(GetMSB(m_numCtxts) - 2);
            for (uint32_t i = 1; i < m_numCtxts; i <<= 1) {
                indexRotationArgmin.emplace_back(static_cast<int32_t>(m_numCtxts / (2 * i)));
            }
        }
    }
    else {  // argmin not in the alternative mode
        // Compute indices for rotations for all homomorphic decryptions for the levels of the tree
        indexRotationHomDec = FindLTRotationIndicesSwitchArgmin(M, m_numCtxts, n_po2);

        // Rotations for postprocessing after a level of the binary tree
        indexRotationArgmin.reserve(GetMSB(m_numCtxts) - 1 + 2 * (GetMSB(m_numCtxts) - 1));
        for (uint32_t i = 1; i < 2 * m_numCtxts; i <<= 1) {
            indexRotationArgmin.emplace_back(static_cast<int32_t>(m_numCtxts / (2 * i)));
            indexRotationArgmin.emplace_back(-static_cast<int32_t>(m_numCtxts / (2 * i)));
            if (i > 1) {
                // Negative power-of-two multiples from 2 * m_numCtxts / i up to (excluding) 2 * m_numCtxts
                for (uint32_t j = 2 * m_numCtxts / i; j < 2 * m_numCtxts; j <<= 1)
                    indexRotationArgmin.emplace_back(-static_cast<int32_t>(j));
            }
        }
    }

    // Compute indices for rotations to bring back the final CKKS ciphertext encoding to slots
    if (ringDim > 2 * slots) {  // if the encoding is full, this does not execute
        indexRotationHomDec.reserve(indexRotationHomDec.size() + GetMSB(ringDim) - 2);
        for (uint32_t j = 1; j < ringDim / (2 * slots); j <<= 1) {
            indexRotationHomDec.emplace_back(j * slots);
        }
    }

    // Combine the indices lists
    indexRotationS2C.reserve(indexRotationS2C.size() + indexRotationHomDec.size() + indexRotationArgmin.size());
    indexRotationS2C.insert(indexRotationS2C.end(), indexRotationHomDec.begin(), indexRotationHomDec.end());
    indexRotationS2C.insert(indexRotationS2C.end(), indexRotationArgmin.begin(), indexRotationArgmin.end());

    // Remove possible duplicates and zero
    sort(indexRotationS2C.begin(), indexRotationS2C.end());
    indexRotationS2C.erase(unique(indexRotationS2C.begin(), indexRotationS2C.end()), indexRotationS2C.end());
    indexRotationS2C.erase(std::remove(indexRotationS2C.begin(), indexRotationS2C.end(), 0), indexRotationS2C.end());

    // One rotation key per remaining index
    auto algo     = ccCKKS->GetScheme();
    auto evalKeys = algo->EvalAtIndexKeyGen(publicKey, privateKey, indexRotationS2C);

    // Compute conjugation key
    auto conjKey       = ConjugateKeyGen(privateKey);
    (*evalKeys)[M - 1] = conjKey;

    // Compute multiplication key
    ccCKKS->EvalMultKeyGen(privateKey);

    // Compute automorphism keys if we don't want one hot encoding for argmin
    if (m_argmin && (m_oneHot == false)) {
        ccCKKS->EvalSumKeyGen(privateKey);
    }

    /* FHEW computations */
    // Generate the bootstrapping keys (refresh and switching keys)
    m_ccLWE->BTKeyGen(lwesk);

    return evalKeys;
}

// Runs the CKKS to FHEW precomputation with the scale used for comparisons: scaleSign, additionally divided
// by pLWE when pLWE is non-zero and the messages are not already scaled to the unit interval.
void SWITCHCKKSRNS::EvalCompareSwitchPrecompute(const CryptoContextImpl<DCRTPoly>& ccCKKS, uint32_t pLWE,
                                                double scaleSign, bool unit) {
    // When unit is set the messages are already scaled between 0 and 1, so there is no division by pLWE.
    // With pLWE == 0 no scaling is done either; the implicit FHEW plaintext modulus is then
    // m_modulus_CKKS_initial / scFactor.
    const bool divideByP = (pLWE != 0) && !unit;

    double scaleCF = divideByP ? (1.0 / pLWE) : 1.0;
    scaleCF *= scaleSign;

    EvalCKKStoFHEWPrecompute(ccCKKS, scaleCF);
}

// Compares two CKKS ciphertexts slot-wise through scheme switching: the difference ciphertext1 - ciphertext2
// is switched to FHEW, the sign of every resulting LWE ciphertext is evaluated there, and the signs are
// switched back to a single CKKS ciphertext.
//
// numCtxts / numSlots: forwarded to EvalCKKStoFHEW and EvalFHEWtoCKKS.
// pLWE:      when non-zero, the CKKS to FHEW precomputation is redone with the matching scale.
// scaleSign: extra factor applied to that scale.
// unit:      when true, the difference is multiplied by 1/pLWE and rescaled first; throws if pLWE is 0.
Ciphertext<DCRTPoly> SWITCHCKKSRNS::EvalCompareSchemeSwitching(ConstCiphertext<DCRTPoly> ciphertext1,
                                                               ConstCiphertext<DCRTPoly> ciphertext2, uint32_t numCtxts,
                                                               uint32_t numSlots, uint32_t pLWE, double scaleSign,
                                                               bool unit) {
    auto ccCKKS             = ciphertext1->GetCryptoContext();
    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(ccCKKS->GetCryptoParameters());

    auto difference = ccCKKS->EvalSub(ciphertext1, ciphertext2);

    if (unit) {
        if (pLWE == 0)
            OPENFHE_THROW("To scale to the unit circle, pLWE must be non-zero.");

        const double invP = 1.0 / static_cast<double>(pLWE);
        difference        = ccCKKS->EvalMult(difference, invP);
        difference        = ccCKKS->Rescale(difference);
    }

    // The precomputation has already been performed, but it is redone here with the scale requested by the
    // caller whenever pLWE is given
    if (pLWE != 0) {
        double scaleCF = unit ? 1.0 : (1.0 / pLWE);
        scaleCF *= scaleSign;

        EvalCKKStoFHEWPrecompute(*ccCKKS, scaleCF);
    }

    auto lweDiffs = EvalCKKStoFHEW(difference, numCtxts);

    // Evaluate the sign of every extracted LWE ciphertext
    const uint32_t numLWE = lweDiffs.size();
    std::vector<LWECiphertext> signs(lweDiffs.size());
#pragma omp parallel for
    for (uint32_t k = 0; k < numLWE; k++) {
        signs[k] = m_ccLWE->EvalSign(lweDiffs[k], true);
    }

    return EvalFHEWtoCKKS(signs, numCtxts, numSlots, 4, -1.0, 1.0, 0);
}

// Computes the minimum and its argument over the first numValues slots of a CKKS ciphertext by repeatedly
// halving the set of candidates: the two halves are subtracted, the signs of the differences are evaluated
// in FHEW and switched back to CKKS as a selector, and the selector is applied to the values and to the
// index ciphertext.
//
// publicKey: used to encrypt the index (or one-hot indicator) ciphertext.
// numValues: number of values to compare; assumed to be a power of two.
// numSlots:  number of CKKS slots (0 selects m_numSlotsCKKS).
// pLWE / scaleSign: when pLWE is non-zero, the CKKS to FHEW precomputation is redone with scaleSign / pLWE.
// Returns {ciphertext of the minimum, ciphertext of the argmin}.
std::vector<Ciphertext<DCRTPoly>> SWITCHCKKSRNS::EvalMinSchemeSwitching(ConstCiphertext<DCRTPoly> ciphertext,
                                                                        PublicKey<DCRTPoly> publicKey,
                                                                        uint32_t numValues, uint32_t numSlots,
                                                                        uint32_t pLWE, double scaleSign) {
    auto cc                 = ciphertext->GetCryptoContext();
    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersRNS>(ciphertext->GetCryptoParameters());

    // The precomputation has already been performed, but it is redone with the requested scale when pLWE is given
    if (pLWE != 0) {
        double scaleCF = 1.0 / pLWE;
        scaleCF *= scaleSign;
        EvalCKKStoFHEWPrecompute(*cc, scaleCF);
    }

    const uint32_t towersToDrop = 12;  // How many levels are consumed in the EvalFHEWtoCKKS
    const uint32_t slots        = (numSlots == 0) ? m_numSlotsCKKS : numSlots;

    // Indicator vector: all ones for a one-hot result, otherwise the indices 0, 1, ..., numValues - 1
    std::vector<std::complex<double>> indicator(numValues, 1.0);
    if (!m_oneHot) {
        std::iota(indicator.begin(), indicator.end(), 0);
    }
    Plaintext pInd = cc->MakeCKKSPackedPlaintext(indicator, 1, towersToDrop, nullptr, slots);

    Ciphertext<DCRTPoly> cInd    = cc->Encrypt(publicKey, pInd);
    Ciphertext<DCRTPoly> cValues = ciphertext->Clone();

    for (uint32_t M = 1; M < numValues; M <<= 1) {
        // Number of comparisons at this level of the binary tree
        const uint32_t half = numValues / (2 * M);

        // CKKS ciphertext encoding the difference between the two halves of the remaining values
        auto cDiff = cc->EvalSub(cValues, cc->EvalAtIndex(cValues, half));

        // Transform the ciphertext from CKKS to FHEW
        auto lweDiffs = EvalCKKStoFHEW(cDiff, half);

        // Evaluate the sign
        // We always assume for the moment that numValues is a power of 2
        std::vector<LWECiphertext> lweSigns(half);
#pragma omp parallel for
        for (uint32_t k = 0; k < half; k++) {
            lweSigns[k] = m_ccLWE->EvalSign(lweDiffs[k], true);
        }

        // Scheme switching from FHEW to CKKS
        auto dim1    = getRatioBSGSLT(half);
        auto cSelect = EvalFHEWtoCKKS(lweSigns, half, numSlots, 4, -1.0, 1.0, dim1);

        // Place the complement (1 - selector) in the block of slots next to the selector
        std::vector<std::complex<double>> ones(half, 1.0);
        Plaintext ptxtOnes = cc->MakeCKKSPackedPlaintext(ones, 1, 0, nullptr, slots);
        auto cComplement   = cc->EvalSub(ptxtOnes, cSelect);
        cc->EvalAddInPlace(cSelect, cc->EvalAtIndex(cComplement, -static_cast<int32_t>(half)));

        // Replicate the selector over the remaining blocks of slots (refers to the same ciphertext as cSelect)
        auto cExpandSelect = cSelect;
        if (M > 1) {
            for (uint32_t shift = numValues / M; shift < numValues; shift <<= 1) {
                auto rotated = cc->EvalAtIndex(cExpandSelect, -static_cast<int32_t>(shift));
                cc->EvalAddInPlace(cExpandSelect, rotated);
            }
        }

        // Update the ciphertext of values and the indicator
        cValues = cc->EvalMult(cValues, cExpandSelect);
        cc->EvalAddInPlace(cValues, cc->EvalAtIndex(cValues, half));
        if (cryptoParams->GetScalingTechnique() == FIXEDMANUAL) {
            cc->ModReduceInPlace(cValues);
        }

        cInd = cc->EvalMult(cInd, cExpandSelect);
        if (cryptoParams->GetScalingTechnique() == FIXEDMANUAL) {
            cc->ModReduceInPlace(cInd);
        }
    }

    // Without one-hot encoding, sum the slots of the indicator ciphertext
    if (!m_oneHot) {
        cInd = cc->EvalSum(cInd, numValues);
    }

    return std::vector<Ciphertext<DCRTPoly>>{cValues, cInd};
}

// Alternative computation of the minimum and its argument over the first numValues slots of a CKKS
// ciphertext. Unlike EvalMinSchemeSwitching, the expanded selector is assembled on the FHEW side (each sign
// and its "negated" copy are replicated into a vector of numValues LWE ciphertexts) and switched back to
// CKKS in one call.
//
// publicKey: used to encrypt the index (or one-hot indicator) ciphertext.
// numValues: number of values to compare; assumed to be a power of two.
// numSlots:  number of CKKS slots (0 selects m_numSlotsCKKS).
// pLWE / scaleSign: when pLWE is non-zero, the CKKS to FHEW precomputation is redone with scaleSign / pLWE.
// Returns {ciphertext of the minimum, ciphertext of the argmin}.
std::vector<Ciphertext<DCRTPoly>> SWITCHCKKSRNS::EvalMinSchemeSwitchingAlt(ConstCiphertext<DCRTPoly> ciphertext,
                                                                           PublicKey<DCRTPoly> publicKey,
                                                                           uint32_t numValues, uint32_t numSlots,
                                                                           uint32_t pLWE, double scaleSign) {
    auto cc                 = ciphertext->GetCryptoContext();
    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersRNS>(ciphertext->GetCryptoParameters());

    // The precomputation has already been performed, but it is redone with the requested scale when pLWE is given
    if (pLWE != 0) {
        double scaleCF = 1.0 / pLWE;
        scaleCF *= scaleSign;
        EvalCKKStoFHEWPrecompute(*cc, scaleCF);
    }

    const uint32_t towersToDrop = 12;  // How many levels are consumed in the EvalFHEWtoCKKS, for binary FHEW output.
    const uint32_t slots        = (numSlots == 0) ? m_numSlotsCKKS : numSlots;

    // Indicator vector: all ones for a one-hot result, otherwise the indices 0, 1, ..., numValues - 1
    std::vector<std::complex<double>> indicator(numValues, 1.0);
    if (!m_oneHot) {
        std::iota(indicator.begin(), indicator.end(), 0);
    }
    Plaintext pInd = cc->MakeCKKSPackedPlaintext(indicator, 1, towersToDrop, nullptr, slots);

    Ciphertext<DCRTPoly> cInd    = cc->Encrypt(publicKey, pInd);
    Ciphertext<DCRTPoly> cValues = ciphertext->Clone();

    for (uint32_t M = 1; M < numValues; M <<= 1) {
        // Number of comparisons at this level of the binary tree
        const uint32_t half = numValues / (2 * M);

        // CKKS ciphertext encoding the difference between the two halves of the remaining values
        auto cDiff = cc->EvalSub(cValues, cc->EvalAtIndex(cValues, half));

        // Transform the ciphertext from CKKS to FHEW
        auto lweDiffs = EvalCKKStoFHEW(cDiff, half);

        // Evaluate the sign and lay out the expanded selector directly as LWE ciphertexts
        // We always assume for the moment that numValues is a power of 2
        std::vector<LWECiphertext> lweSelect(numValues);
#pragma omp parallel for
        for (uint32_t k = 0; k < half; k++) {
            LWECiphertext sign    = m_ccLWE->EvalSign(lweDiffs[k], true);
            LWECiphertext negSign = std::make_shared<LWECiphertextImpl>(*sign);
            m_ccLWE->GetLWEScheme()->EvalAddConstEq(negSign, negSign->GetModulus() >> 1);  // "negated" sign
            // Even-numbered blocks receive the sign, odd-numbered blocks its negation
            for (uint32_t blk = 0; blk < 2 * M; blk += 2) {
                lweSelect[blk * numValues / (2 * M) + k]       = sign;
                lweSelect[(blk + 1) * numValues / (2 * M) + k] = negSign;
            }
        }

        // Scheme switching from FHEW to CKKS
        auto dim1          = getRatioBSGSLT(numValues);
        auto cExpandSelect = EvalFHEWtoCKKS(lweSelect, numValues, numSlots, 4, -1.0, 1.0, dim1);

        // Update the ciphertext of values and the indicator
        cValues = cc->EvalMult(cValues, cExpandSelect);
        cc->EvalAddInPlace(cValues, cc->EvalAtIndex(cValues, half));
        if (cryptoParams->GetScalingTechnique() == FIXEDMANUAL) {
            cc->ModReduceInPlace(cValues);
        }

        cInd = cc->EvalMult(cInd, cExpandSelect);
        if (cryptoParams->GetScalingTechnique() == FIXEDMANUAL) {
            cc->ModReduceInPlace(cInd);
        }
    }

    // Without one-hot encoding, sum the slots of the indicator ciphertext
    if (!m_oneHot) {
        cInd = cc->EvalSum(cInd, numValues);
    }

    return std::vector<Ciphertext<DCRTPoly>>{cValues, cInd};
}

// Max/argmax evaluation through CKKS <-> FHEW scheme switching.
//
// Each round halves the comparison distance: the value ciphertext is compared against a copy of itself
// rotated by numValues / (2 * M), the signs of the differences are evaluated in FHEW and switched back
// to CKKS, and the resulting selector is completed and replicated on the CKKS side with rotations.
//
// ciphertext - CKKS ciphertext holding the values to compare (it is cloned, not modified)
// publicKey  - key used to encrypt the initial indicator plaintext
// numValues  - number of packed values; the code assumes it is a power of 2 and does not check it
// numSlots   - slot count forwarded to EvalFHEWtoCKKS; when 0, m_numSlotsCKKS is used for the plaintext encodings
// pLWE       - when non-zero, the CKKS->FHEW precomputation is redone with scale scaleSign / pLWE
// scaleSign  - extra scaling factor applied together with pLWE
//
// Returns a two-element vector: {updated value ciphertext, indicator ciphertext}.
std::vector<Ciphertext<DCRTPoly>> SWITCHCKKSRNS::EvalMaxSchemeSwitching(ConstCiphertext<DCRTPoly> ciphertext,
                                                                        PublicKey<DCRTPoly> publicKey,
                                                                        uint32_t numValues, uint32_t numSlots,
                                                                        uint32_t pLWE, double scaleSign) {
    auto cc                 = ciphertext->GetCryptoContext();
    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersRNS>(ciphertext->GetCryptoParameters());

    // A precomputation already exists; redo it only when the caller asks for a different scaling
    if (pLWE != 0) {
        double scaleCF = (1.0 / pLWE) * scaleSign;
        EvalCKKStoFHEWPrecompute(*cc, scaleCF);
    }

    uint32_t towersToDrop = 12;  // How many levels are consumed in the EvalFHEWtoCKKS, for binary FHEW output.
    uint32_t slots        = (numSlots == 0) ? m_numSlotsCKKS : numSlots;

    // Indicator starts as all ones (one-hot output) or as the indices 0, 1, ..., numValues - 1
    std::vector<std::complex<double>> indicator(numValues, m_oneHot ? 1.0 : 0.0);
    if (!m_oneHot) {
        std::iota(indicator.begin(), indicator.end(), 0);
    }
    Plaintext pInd = cc->MakeCKKSPackedPlaintext(indicator, 1, towersToDrop, nullptr, slots);

    Ciphertext<DCRTPoly> cInd          = cc->Encrypt(publicKey, pInd);
    Ciphertext<DCRTPoly> newCiphertext = ciphertext->Clone();

    for (uint32_t M = 1; M < numValues; M <<= 1) {
        // Rotation distance and number of comparisons performed in this round
        const uint32_t half = numValues / (2 * M);

        // Difference between the current values and the values `half` slots away
        auto cRotated = cc->EvalAtIndex(newCiphertext, half);
        auto cDiff    = cc->EvalSub(newCiphertext, cRotated);

        // Move the first `half` differences from CKKS to FHEW
        auto cTemp = EvalCKKStoFHEW(cDiff, half);

        // Evaluate the sign of every difference (numValues is assumed to be a power of 2 for the moment)
        std::vector<LWECiphertext> LWESign(half);
#pragma omp parallel for
        for (uint32_t pos = 0; pos < half; pos++) {
            LWESign[pos] = m_ccLWE->EvalSign(cTemp[pos], true);
        }

        // Switch the `half` sign results back from FHEW to CKKS
        auto dim1          = getRatioBSGSLT(half);
        auto cExpandSelect = EvalFHEWtoCKKS(LWESign, half, numSlots, 4, -1.0, 1.0, dim1);

        // Build (ones - selector) and add the selector rotated by -half
        std::vector<std::complex<double>> ones(half, 1.0);
        Plaintext ptxtOnes = cc->MakeCKKSPackedPlaintext(ones, 1, 0, nullptr, slots);
        cExpandSelect      = cc->EvalAdd(cc->EvalSub(ptxtOnes, cExpandSelect),
                                         cc->EvalAtIndex(cExpandSelect, -static_cast<int32_t>(half)));

        // From the second round on, replicate the selector by adding copies rotated by -shift,
        // with shift doubling from numValues / M while it stays below numValues
        if (M > 1) {
            for (uint32_t shift = numValues / M; shift < numValues; shift <<= 1) {
                cc->EvalAddInPlace(cExpandSelect, cc->EvalAtIndex(cExpandSelect, -static_cast<int32_t>(shift)));
            }
        }

        // Apply the selector to the values and fold in the rotated copy
        newCiphertext = cc->EvalMult(newCiphertext, cExpandSelect);
        cc->EvalAddInPlace(newCiphertext, cc->EvalAtIndex(newCiphertext, half));

        // With FIXEDMANUAL scaling the rescaling after each multiplication is done explicitly
        const bool manualRescale = (cryptoParams->GetScalingTechnique() == FIXEDMANUAL);
        if (manualRescale) {
            cc->ModReduceInPlace(newCiphertext);
        }

        // Apply the same selector to the indicator
        cInd = cc->EvalMult(cInd, cExpandSelect);
        if (manualRescale) {
            cc->ModReduceInPlace(cInd);
        }
    }

    // After the maximum and its argument are computed: when the output is not one-hot,
    // sum the indicator over numValues
    if (!m_oneHot) {
        cInd = cc->EvalSum(cInd, numValues);
    }

    return std::vector<Ciphertext<DCRTPoly>>{newCiphertext, cInd};
}

// Alternative max/argmax evaluation through CKKS <-> FHEW scheme switching.
//
// Each round halves the comparison distance: the value ciphertext is compared against a copy of itself
// rotated by numValues / (2 * M), the signs of the differences are evaluated in FHEW, and a full
// numValues-long selector is assembled on the FHEW side ("negated" sign / sign in alternating blocks)
// before being switched back to CKKS in a single EvalFHEWtoCKKS call.
//
// ciphertext - CKKS ciphertext holding the values to compare (it is cloned, not modified)
// publicKey  - key used to encrypt the initial indicator plaintext
// numValues  - number of packed values; the code assumes it is a power of 2 and does not check it
// numSlots   - slot count forwarded to EvalFHEWtoCKKS; when 0, m_numSlotsCKKS is used for the indicator encoding
// pLWE       - when non-zero, the CKKS->FHEW precomputation is redone with scale scaleSign / pLWE
// scaleSign  - extra scaling factor applied together with pLWE
//
// Returns a two-element vector: {updated value ciphertext, indicator ciphertext}.
std::vector<Ciphertext<DCRTPoly>> SWITCHCKKSRNS::EvalMaxSchemeSwitchingAlt(ConstCiphertext<DCRTPoly> ciphertext,
                                                                           PublicKey<DCRTPoly> publicKey,
                                                                           uint32_t numValues, uint32_t numSlots,
                                                                           uint32_t pLWE, double scaleSign) {
    auto cc                 = ciphertext->GetCryptoContext();
    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersRNS>(ciphertext->GetCryptoParameters());

    // A precomputation already exists; redo it only when the caller asks for a different scaling
    if (pLWE != 0) {
        double scaleCF = (1.0 / pLWE) * scaleSign;
        EvalCKKStoFHEWPrecompute(*cc, scaleCF);
    }

    uint32_t towersToDrop = 12;  // How many levels are consumed in the EvalFHEWtoCKKS, for binary FHEW output
    uint32_t slots        = (numSlots == 0) ? m_numSlotsCKKS : numSlots;

    // Indicator starts as all ones (one-hot output) or as the indices 0, 1, ..., numValues - 1
    std::vector<std::complex<double>> indicator(numValues, m_oneHot ? 1.0 : 0.0);
    if (!m_oneHot) {
        std::iota(indicator.begin(), indicator.end(), 0);
    }
    Plaintext pInd = cc->MakeCKKSPackedPlaintext(indicator, 1, towersToDrop, nullptr, slots);

    Ciphertext<DCRTPoly> cInd          = cc->Encrypt(publicKey, pInd);
    Ciphertext<DCRTPoly> newCiphertext = ciphertext->Clone();

    for (uint32_t M = 1; M < numValues; M <<= 1) {
        // Rotation distance and number of comparisons performed in this round
        const uint32_t half = numValues / (2 * M);

        // Difference between the current values and the values `half` slots away
        auto cRotated = cc->EvalAtIndex(newCiphertext, half);
        auto cDiff    = cc->EvalSub(newCiphertext, cRotated);

        // Move the first `half` differences from CKKS to FHEW
        auto cTemp = EvalCKKStoFHEW(cDiff, half);

        // Evaluate the signs and lay them out over all numValues positions:
        // blocks of length `half` alternate between the "negated" sign and the sign itself
        // (the reverse of the order used by the min variant).
        // numValues is assumed to be a power of 2 for the moment.
        std::vector<LWECiphertext> LWESign(numValues);
#pragma omp parallel for
        for (uint32_t pos = 0; pos < half; pos++) {
            LWECiphertext sign    = m_ccLWE->EvalSign(cTemp[pos], true);
            LWECiphertext negSign = std::make_shared<LWECiphertextImpl>(*sign);
            // Adding half of the modulus yields the "negated" sign
            m_ccLWE->GetLWEScheme()->EvalAddConstEq(negSign, negSign->GetModulus() >> 1);
            for (uint32_t blk = 0; blk < 2 * M; blk += 2) {
                LWESign[blk * numValues / (2 * M) + pos]       = negSign;
                LWESign[(blk + 1) * numValues / (2 * M) + pos] = sign;
            }
        }

        // Switch the complete selector back from FHEW to CKKS
        auto dim1          = getRatioBSGSLT(numValues);
        auto cExpandSelect = EvalFHEWtoCKKS(LWESign, numValues, numSlots, 4, -1.0, 1.0, dim1);

        // Apply the selector to the values and fold in the rotated copy
        newCiphertext = cc->EvalMult(newCiphertext, cExpandSelect);
        cc->EvalAddInPlace(newCiphertext, cc->EvalAtIndex(newCiphertext, half));

        // With FIXEDMANUAL scaling the rescaling after each multiplication is done explicitly
        const bool manualRescale = (cryptoParams->GetScalingTechnique() == FIXEDMANUAL);
        if (manualRescale) {
            cc->ModReduceInPlace(newCiphertext);
        }

        // Apply the same selector to the indicator
        cInd = cc->EvalMult(cInd, cExpandSelect);
        if (manualRescale) {
            cc->ModReduceInPlace(cInd);
        }
    }

    // After the maximum and its argument are computed: when the output is not one-hot,
    // sum the indicator over numValues
    if (!m_oneHot) {
        cInd = cc->EvalSum(cInd, numValues);
    }

    return std::vector<Ciphertext<DCRTPoly>>{newCiphertext, cInd};
}

}  // namespace lbcrypto