Program Listing for File ckksrns-schemeswitching.cpp

Return to documentation for file (pke/lib/scheme/ckksrns/ckksrns-schemeswitching.cpp)

//==================================================================================
// BSD 2-Clause License
//
// Copyright (c) 2014-2022, NJIT, Duality Technologies Inc. and other contributors
//
// All rights reserved.
//
// Author TPOC: contact@openfhe.org
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
//    list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
//    this list of conditions and the following disclaimer in the documentation
//    and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//==================================================================================

/*
    CKKS to FHEW scheme switching implementation.
 */

#define PROFILE

#include "cryptocontext.h"
#include "gen-cryptocontext.h"
#include "math/dftransform.h"
#include "scheme/ckksrns/ckksrns-fhe.h"
#include "scheme/ckksrns/ckksrns-schemeswitching.h"
#include "scheme/ckksrns/gen-cryptocontext-ckksrns.h"

#include <algorithm>
#include <cmath>
#include <iterator>
#include <limits>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// K = 16
// EvalChebyshevCoefficients([](double x) -> double {return std::pow(2.*M_PI, -1./8.) * std::cos(2.*M_PI/8. * (x - 0.25));}, -16, 16, 117)
static constexpr std::initializer_list<double> g_coefficientsFHEW16{
    0.2455457340168511,     -0.04791906488334782,   0.2838870204084082,     -0.02994453873551349,
    0.3557652261903648,     0.01510656188507299,    0.2953294667450001,     0.07120360233373937,
    -0.1034734733966807,    0.04499759051255525,    -0.4275071243192574,    -0.09034212972909554,
    0.367628762693249,      0.04931806603933471,    -0.14535986272412,      -0.01510693848306369,
    0.03595193549924024,    0.003103658218868759,   -0.006264460660707066,  -0.0004660943047712052,
    0.0008212879885240095,  5.391053389217882e-05,  -8.455154976914221e-05, -4.977380178602518e-06,
    7.04666204400644e-06,   3.765980757166835e-07,  -4.864851013612591e-07, -2.383026746930811e-08,
    2.832970640938316e-08,  1.28177429687131e-09,   -1.412145521045378e-09, -5.939145408994641e-11,
    6.099273252183116e-11,  2.397381728164642e-12,  -2.307402856353623e-12, -8.500921247536622e-14,
    7.704571444110577e-14,  2.704051671841271e-15,  -2.154585361348821e-15, -7.263493008564584e-16,
    -1.260761739828568e-16, 1.637108527837095e-16,  1.185492382226862e-16,  6.379078056744543e-16,
    -1.411300455031979e-16, 3.123678340470779e-16,  5.946279250534737e-16,  2.954322285866942e-16,
    -8.279629336187608e-17, 5.024229619913844e-16,  -3.293034395074617e-16, -1.189255850106947e-15,
    1.674743206637948e-16,  -1.524204491434537e-16, -7.90328254817908e-17,  3.95164127408954e-16,
    -1.317213758029847e-17, 8.016186584581639e-16,  -3.650563843682718e-16, 3.763467880085276e-16,
    -2.709696873661399e-16, -1.524204491434537e-16, -4.83605622590958e-16,  7.376397044967142e-16,
    1.234417464667971e-15,  -2.672062194860546e-16, -4.892508244110859e-17, -7.122362963061386e-16,
    3.763467880085276e-17,  -2.944913616166729e-16, -2.897870267665663e-16, 9.794425157921932e-16,
    -3.198947698072485e-17, 6.614294799249873e-16,  -5.7769231959309e-16,   6.586068790149234e-16,
    -4.629065492504889e-16, -5.127724986616189e-16, -3.236582376873337e-16, -1.64745806450733e-15,
    -9.408669700213192e-16, -4.986594941112991e-16, -1.209954923447416e-15, -1.373665776231126e-16,
    -2.314532746252445e-16, 3.217765037472911e-16,  3.481207789078881e-16,  8.223177317986329e-16,
    -9.766199148821293e-16, 6.19090466274028e-16,   1.209014056477395e-15,  -3.30244306477483e-16,
    5.974505259635377e-16,  5.993322599035803e-16,  1.829986256691466e-16,  -2.690879534260973e-16,
    8.618341445395283e-16,  -1.002023323072705e-16, 6.374373721894436e-16,  6.270878355192092e-16,
    1.199605386777182e-15,  -8.712428142397415e-16, -2.507410475106815e-16, -1.086230916889613e-15,
    1.072588345824304e-15,  -4.534978795502758e-16, 2.119067633230516e-15,  -1.842923177529259e-15,
    -1.814697168428619e-15, 4.243310034796149e-16,  4.224492695395723e-16,  1.531966643937213e-15,
    -2.850826919164597e-16, -8.958229638315484e-16, -5.02893395476395e-16,  1.096110020074837e-16,
    -6.975352498995555e-16, -8.743006318923108e-16};

// K = 128
// EvalChebyshevCoefficients([](double x) -> double {return std::pow(2.*M_PI, -1./8.) * std::cos(2.*M_PI/8. * (x - 0.25));}, -128, 128, 159)
static constexpr std::initializer_list<double> g_coefficientsFHEW128_9{
    0.08761193238226354,    -0.01738402917379392,   0.08935060894767313,    -0.01667686631436392,
    0.09435445639097996,    -0.01518333497826596,   0.1019473189108075,     -0.01276275748916528,
    0.110882655474149,      -0.009252446966171999,  0.1192111685574758,     -0.004534979909938953,
    0.1242004317120066,     0.001362904847617233,   0.1224283765086551,     0.008145596233693092,
    0.1102080588183085,     0.01512350467093644,    0.08449405378412403,    0.02114203679334985,
    0.04431786059830115,    0.02464956129638114,    -0.007454366487154669,  0.02400059020366966,
    -0.06266441339261235,   0.0180491215413637,     -0.1077943201829795,    0.00695836538813938,
    -0.1265848641500751,    -0.007067567033131986,  -0.1060856934163377,    -0.01966175019277508,
    -0.04512467324356773,   -0.02537595733026167,   0.03862916785371963,    -0.0201785566296389,
    0.1092652333753526,     -0.004612578019766411,  0.1263344585514989,     0.01438496124842956,
    0.07022427857484209,    0.02550072245548077,    -0.03434514153678107,   0.01979242584243335,
    -0.1194659697149694,    -0.001008794768691528,  -0.1149256786653964,    -0.02192904329965044,
    -0.01184295110147364,   -0.02417858011117596,   0.1066507410103885,     -0.003076473516323838,
    0.122343225763269,      0.02209885820126707,    0.005200840409852563,   0.02321022960558625,
    -0.1224755172356864,    -0.003930982569218595,  -0.1000653894904632,    -0.02689795846413602,
    0.05865754664309684,    -0.01297065380451242,   0.1377909895596227,     0.02083617539534925,
    0.006502421233003679,   0.02248299870285675,    -0.139660074659475,     -0.01399307934458518,
    -0.04589168496663835,   -0.0263421662574377,    0.1358978738303921,     0.01130242907664306,
    0.05563799538901031,    0.02715486116995984,    -0.1426236952996719,    -0.01461041285557406,
    -0.03302834981489188,   -0.02454368648125577,   0.1559877850928394,     0.02360418859443232,
    -0.03051465817859748,   0.01394389273916019,    -0.1434779685133351,    -0.0326137520114734,
    0.1272587840850199,     0.00968806150092634,    0.04489729856072615,    0.02496761251245225,
    -0.1723551233719199,    -0.03505277577503257,   0.1396636892583818,     0.01468861799711852,
    0.00597622458952793,    0.01686435635501478,    -0.1508869780062401,    -0.0392626068463298,
    0.2221665014327329,     0.04513725581939847,    -0.2157338005707834,    -0.03852627732394119,
    0.1657363840292956,     0.02705951812948022,    -0.1076077703571204,    -0.01637507278621027,
    0.06108202920758021,    0.00876339054415577,    -0.03094805600072437,   -0.004218297546715713,
    0.01419483272929196,    0.001848310901625205,   -0.005954927442783235,  -0.000743844834357433,
    0.002303049930851211,   0.000276890872388833,   -0.0008263094529170254, -9.587788269377866e-05,
    0.000276459761481133,   3.10280277328683e-05,   -8.662530848949058e-05, -9.421991495095434e-06,
    2.551403977799462e-05,  2.693817139509806e-06,  -7.08627853168575e-06,  -7.273160333789605e-07,
    1.861116908629967e-06,  1.859288982562724e-07,  -4.633601616493963e-07, -4.510790623761216e-08,
    1.096004151863654e-07,  1.040757606442787e-08,  -2.467843591423806e-08, -2.288015736340782e-09,
    5.299290810294302e-09,  4.800959747999802e-10,  -1.086991796689273e-09, -9.63033831430537e-11,
    2.133040952766737e-10,  1.849336626863696e-11,  -4.010173216641153e-11, -3.404204559101731e-12,
    7.232799297263385e-12,  6.027665442316686e-13,  -1.252868789531569e-12, -1.035300267428826e-13,
    2.072453780944445e-13,  1.81572061755555e-14,   -3.012503176280137e-14, -4.417490972407089e-16,
    3.698522891563647e-15,  -4.204154533937635e-16, -2.740777660720187e-15, -1.348919106364917e-15,
    -1.620799477984723e-15, 4.003965342611375e-16,  -5.245330582249314e-16, 1.754761547401069e-15,
    -5.0481471966847e-16,   -4.722624632690369e-16, 1.628901569091919e-16,  -1.219903204684612e-15};

// EvalChebyshevCoefficients([](double x) -> double {return std::pow(2.*M_PI, -1./8.) * std::cos(2.*M_PI/8. * (x - 0.25));}, -128, 128, 118)
static constexpr std::initializer_list<double> g_coefficientsFHEW128_8{
    0.08761193238226343,    -0.01738402917379268,  0.08935060894767202,   -0.0166768663143651,   0.09435445639098095,
    -0.01518333497826714,   0.1019473189108076,    -0.01276275748916462,  0.1108826554741475,    -0.009252446966171845,
    0.1192111685574773,     -0.004534979909938402, 0.1242004317120066,    0.001362904847616587,  0.1224283765086535,
    0.008145596233693802,   0.1102080588183083,    0.0151235046709367,    0.08449405378412395,   0.02114203679334948,
    0.04431786059830203,    0.02464956129638117,   -0.007454366487155707, 0.02400059020367158,   -0.06266441339261287,
    0.01804912154136392,    -0.107794320182978,    0.006958365388138488,  -0.1265848641500738,   -0.007067567033133184,
    -0.1060856934163389,    -0.01966175019277399,  -0.0451246732435682,   -0.02537595733026211,  0.0386291678537217,
    -0.02017855662963969,   0.1092652333753532,    -0.004612578019767425, 0.1263344585514991,    0.01438496124843117,
    0.07022427857484087,    0.02550072245548053,   -0.03434514153678073,  0.01979242584243296,   -0.1194659697149702,
    -0.00100879476868968,   -0.1149256786653952,   -0.02192904329965062,  -0.01184295110147335,  -0.02417858011117619,
    0.1066507410103884,     -0.003076473516322021, 0.1223432257632692,    0.02209885820126752,   0.005200840409853516,
    0.02321022960558683,    -0.1224755172356849,   -0.003930982569218244, -0.1000653894904628,   -0.02689795846413568,
    0.05865754664309823,    -0.01297065380451253,  0.1377909895596233,    0.02083617539534807,   0.006502421233004046,
    0.02248299870285591,    -0.1396600746594754,   -0.01399307934458444,  -0.04589168496663817,  -0.02634216625743798,
    0.1358978738303917,     0.0113024290766429,    0.05563799538901171,   0.02715486116995986,   -0.1426236952996744,
    -0.01461041285557423,   -0.03302834981489241,  -0.02454368648125667,  0.155987785092838,     0.02360418859443058,
    -0.03051465817859778,   0.01394389273915945,   -0.1434779685133346,   -0.03261375201147241,  0.1272587840850196,
    0.009688061500927738,   0.04489729856072736,   0.02496761251245433,   -0.1723551233719191,   -0.03505277577503064,
    0.1396636892583768,     0.01468861799712161,   0.005976224589562133,  0.01686435635499993,   -0.1508869780064481,
    -0.03926260684622985,   0.2221665014339838,    0.04513725581879824,   -0.2157338005780147,   -0.03852627732053739,
    0.1657363840693943,     0.02705951811098469,   -0.1076077705704247,   -0.0163750726899083,   0.06108203029457551,
    0.008763390064061401,   -0.03094806130001855,  -0.00421829525869962,  0.01419485740772663,   0.001848300494047458,
    -0.005955037043195616,  -0.000743799726455706, 0.002303513291010024,  0.0002767049434914361, -0.0008281705698254814,
    -9.515056665793518e-05, 0.0002835460400168608, 2.833421059257267e-05, -0.0001121393482639905};

namespace lbcrypto {

//------------------------------------------------------------------------------
// Key and modulus switch and extraction methods
//------------------------------------------------------------------------------

static NativeInteger RoundqQAlter(NativeInteger v, NativeInteger q, NativeInteger Q) {
    return NativeInteger(static_cast<BasicInteger>(
                             std::floor(0.5 + v.ConvertToDouble() * q.ConvertToDouble() / Q.ConvertToDouble())))
        .Mod(q);
}

// TODO: used anywhere?
EvalKey<DCRTPoly> switchingKeyGenRLWE(const PrivateKey<DCRTPoly>& ckksSK, ConstLWEPrivateKey& LWEsk) {
    // This function is without the intermediate ModSwitch
    // Extract CKKS params: method which populates the first n elements of a new RLWE key with the n elements of the target LWE key
    auto skelements    = ckksSK->GetPrivateElement();
    auto lweskElements = LWEsk->GetElement();

    for (size_t i = 0; i < skelements.GetNumOfElements(); i++) {
        auto skelementsPlain = skelements.GetElementAtIndex(i);
        skelementsPlain.SetFormat(Format::COEFFICIENT);
        for (size_t j = 0; j < skelementsPlain.GetLength(); j++) {
            if (j >= lweskElements.GetLength()) {
                skelementsPlain[j] = 0;
            }
            else {
                if (lweskElements[j] == 0) {
                    skelementsPlain[j] = 0;
                }
                else if (lweskElements[j].ConvertToInt() == 1) {
                    skelementsPlain[j] = 1;
                }
                else
                    skelementsPlain[j] = skelementsPlain.GetModulus() - 1;
            }
        }
        skelementsPlain.SetFormat(Format::EVALUATION);
        skelements.SetElementAtIndex(i, std::move(skelementsPlain));
    }
    skelements.OverrideFormat(Format::EVALUATION);

    auto ccCKKS    = ckksSK->GetCryptoContext();
    auto RLWELWEsk = ccCKKS->KeyGen().secretKey;
    RLWELWEsk->SetPrivateElement(std::move(skelements));

    return ccCKKS->KeySwitchGen(ckksSK, RLWELWEsk);
}

void ModSwitch(ConstCiphertext<DCRTPoly> ctxt, Ciphertext<DCRTPoly>& ctxtKS, NativeInteger modulus_CKKS_to) {
    if (ctxt->GetElements()[0].GetRingDimension() != ctxtKS->GetElements()[0].GetRingDimension())
        OPENFHE_THROW("ModSwitch is implemented only for the same ring dimension.");
    if (ctxt->GetElements()[0].GetNumOfElements() != 1 || ctxtKS->GetElements()[0].GetNumOfElements() != 1)
        OPENFHE_THROW("ModSwitch is implemented only for ciphertext with one tower.");

    std::vector<DCRTPoly> resultElements;
    const auto& paramsQlP = ctxtKS->GetElements()[0].GetParams();
    for (const auto& elem : ctxt->GetElements()) {
        auto& ref = resultElements.emplace_back(paramsQlP, Format::COEFFICIENT, true);
        ref.SetValuesModSwitch(elem, modulus_CKKS_to);
        ref.SetFormat(Format::EVALUATION);
    }
    ctxtKS->SetElements(std::move(resultElements));
}

// TODO: used anywhere?
EvalKey<DCRTPoly> switchingKeyGen(const PrivateKey<DCRTPoly>& ckksSKto, const PrivateKey<DCRTPoly>& ckksSKfrom) {
    auto skElements = ckksSKto->GetPrivateElement();
    skElements.SetFormat(Format::COEFFICIENT);
    auto skElementsFrom = ckksSKfrom->GetPrivateElement();
    skElementsFrom.SetFormat(Format::COEFFICIENT);

    for (size_t i = 0; i < skElements.GetNumOfElements(); i++) {
        auto skElementsPlain     = skElements.GetElementAtIndex(i);
        auto skElementsFromPlain = skElementsFrom.GetElementAtIndex(i);
        for (size_t j = 0; j < skElementsPlain.GetLength(); j++) {
            if (skElementsFromPlain[j] == 0) {
                skElementsPlain[j] = 0;
            }
            else if (skElementsFromPlain[j] == 1) {
                skElementsPlain[j] = 1;
            }
            else
                skElementsPlain[j] = skElementsPlain.GetModulus() - 1;
        }
        skElements.SetElementAtIndex(i, std::move(skElementsPlain));
    }

    skElements.SetFormat(Format::EVALUATION);

    auto ccCKKSto        = ckksSKto->GetCryptoContext();
    auto oldTranformedSK = ccCKKSto->KeyGen().secretKey;
    oldTranformedSK->SetPrivateElement(std::move(skElements));

    return ccCKKSto->KeySwitchGen(oldTranformedSK, ckksSKto);
}

EvalKey<DCRTPoly> switchingKeyGenRLWEcc(const PrivateKey<DCRTPoly>& ckksSKto, const PrivateKey<DCRTPoly>& ckksSKfrom,
                                        ConstLWEPrivateKey& LWEsk) {
    auto skElements = ckksSKto->GetPrivateElement();
    skElements.SetFormat(Format::COEFFICIENT);
    auto skElementsFrom = ckksSKfrom->GetPrivateElement();
    skElementsFrom.SetFormat(Format::COEFFICIENT);
    auto skElements2 = ckksSKto->GetPrivateElement();
    skElements2.SetFormat(Format::COEFFICIENT);
    auto lweskElements = LWEsk->GetElement();

    for (uint32_t i = 0; i < skElements.GetNumOfElements(); ++i) {
        auto skElementsPlain     = skElements.GetElementAtIndex(i);
        auto skElementsFromPlain = skElementsFrom.GetElementAtIndex(i);
        auto skElementsPlainLWE  = skElements2.GetElementAtIndex(i);
        for (uint32_t j = 0; j < skElementsPlain.GetLength(); ++j) {
            if (skElementsFromPlain[j] == 0)
                skElementsPlain[j] = 0;
            else if (skElementsFromPlain[j] == 1)
                skElementsPlain[j] = 1;
            else
                skElementsPlain[j] = skElementsPlain.GetModulus() - 1;

            if (j >= lweskElements.GetLength()) {
                skElementsPlainLWE[j] = 0;
            }
            else {
                if (lweskElements[j] == 0)
                    skElementsPlainLWE[j] = 0;
                else if (lweskElements[j].ConvertToInt() == 1)
                    skElementsPlainLWE[j] = 1;
                else
                    skElementsPlainLWE[j] = skElementsPlain.GetModulus() - 1;
            }
        }
        skElements.SetElementAtIndex(i, std::move(skElementsPlain));
        skElements2.SetElementAtIndex(i, std::move(skElementsPlainLWE));
    }
    skElements.SetFormat(Format::EVALUATION);
    skElements2.SetFormat(Format::EVALUATION);

    auto ccCKKSto        = ckksSKto->GetCryptoContext();
    auto oldTranformedSK = ccCKKSto->KeyGen().secretKey;
    oldTranformedSK->SetPrivateElement(std::move(skElements));
    auto RLWELWEsk = ccCKKSto->KeyGen().secretKey;
    RLWELWEsk->SetPrivateElement(std::move(skElements2));

    return ccCKKSto->KeySwitchGen(oldTranformedSK, RLWELWEsk);
}

std::vector<std::vector<NativeInteger>> ExtractLWEpacked(const Ciphertext<DCRTPoly>& ct) {
    auto originalA{(ct->GetElements()[1]).GetElementAtIndex(0)};
    originalA.SetFormat(Format::COEFFICIENT);
    auto* ptrA = &originalA.GetValues()[0];
    auto originalB{(ct->GetElements()[0]).GetElementAtIndex(0)};
    originalB.SetFormat(Format::COEFFICIENT);
    auto* ptrB = &originalB.GetValues()[0];
    size_t N   = originalB.GetLength();
    return {std::vector<NativeInteger>(ptrB, ptrB + N), std::vector<NativeInteger>(ptrA, ptrA + N)};
}

std::shared_ptr<LWECiphertextImpl> ExtractLWECiphertext(const std::vector<std::vector<NativeInteger>>& aANDb,
                                                        NativeInteger modulus, uint32_t n, uint32_t index = 0) {
    NativeVector a(n, modulus);
    for (uint32_t i = 0; i < n && i <= index; ++i)
        a[i] = modulus - aANDb[1][index - i];
    if (n > index) {
        uint32_t N = aANDb[0].size();
        for (uint32_t i = index + 1; i < n; ++i)
            a[i] = aANDb[1][N + index - i];
    }
    return std::make_shared<LWECiphertextImpl>(std::move(a), aANDb[0][index]);
}

//------------------------------------------------------------------------------
// Linear transformation methods.
//------------------------------------------------------------------------------

std::vector<ReadOnlyPlaintext> SWITCHCKKSRNS::EvalLTPrecomputeSwitch(
    const CryptoContextImpl<DCRTPoly>& cc, const std::vector<std::vector<std::complex<double>>>& A,
    const std::vector<std::vector<std::complex<double>>>& B, uint32_t dim1, uint32_t L, double scale = 1.0) const {
    const auto cryptoParamsCKKS = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(cc.GetCryptoParameters());

    auto elementParams = *(cryptoParamsCKKS->GetElementParams());

    uint32_t towersToDrop = 0;
    if (L != 0) {
        towersToDrop = elementParams.GetParams().size() - L - 1;
        for (uint32_t i = 0; i < towersToDrop; ++i)
            elementParams.PopLastParam();
    }

    const auto& paramsQ = elementParams.GetParams();
    const auto& paramsP = cryptoParamsCKKS->GetParamsP()->GetParams();

    size_t sizeQP = paramsQ.size() + paramsP.size();
    std::vector<NativeInteger> moduli;
    moduli.reserve(sizeQP);
    std::vector<NativeInteger> roots;
    roots.reserve(sizeQP);
    for (const auto& elem : paramsQ) {
        moduli.emplace_back(elem->GetModulus());
        roots.emplace_back(elem->GetRootOfUnity());
    }
    for (const auto& elem : paramsP) {
        moduli.emplace_back(elem->GetModulus());
        roots.emplace_back(elem->GetRootOfUnity());
    }

    auto elementParamsPtr = std::make_shared<ILDCRTParams<DCRTPoly::Integer>>(cc.GetCyclotomicOrder(), moduli, roots);

    int32_t slots = A.size();
    int32_t step  = (dim1 == 0) ? getRatioBSGSLT(slots) : dim1;

    std::vector<std::vector<std::complex<double>>> newA(slots);

    //  A and B are concatenated horizontally
    for (uint32_t i = 0; i < A.size(); i++) {
        newA[i].reserve(A[i].size() + B[i].size());
        newA[i].insert(newA[i].end(), A[i].begin(), A[i].end());
        newA[i].insert(newA[i].end(), B[i].begin(), B[i].end());
    }

    uint32_t M4 = cc.GetCyclotomicOrder() / 4;
    std::vector<ReadOnlyPlaintext> result(slots);
#if !defined(__MINGW32__) && !defined(__MINGW64__)
    #pragma omp parallel for num_threads(OpenFHEParallelControls.GetThreadLimit(slots))
#endif
    for (int32_t ji = 0; ji < slots; ++ji) {
        auto vec = ExtractShiftedDiagonal(newA, ji);
        for (auto& v : vec)
            v *= scale;
        result[ji] = FHECKKSRNS::MakeAuxPlaintext(cc, elementParamsPtr, Rotate(Fill(vec, M4), -step * (ji / step)), 1,
                                                  towersToDrop, M4);
    }
    return result;
}

std::vector<ReadOnlyPlaintext> SWITCHCKKSRNS::EvalLTPrecomputeSwitch(
    const CryptoContextImpl<DCRTPoly>& cc, const std::vector<std::vector<std::complex<double>>>& A, uint32_t dim1,
    uint32_t L, double scale = 1.0) const {
    if (A[0].size() != A.size())
        OPENFHE_THROW("The matrix passed to EvalLTPrecomputeSwitch is not square");

    // Make sure the plaintext is created only with the necessary amount of moduli
    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(cc.GetCryptoParameters());

    auto elementParams = *(cryptoParams->GetElementParams());

    uint32_t towersToDrop = 0;
    if (L != 0) {
        towersToDrop = elementParams.GetParams().size() - L - 1;
        for (uint32_t i = 0; i < towersToDrop; i++)
            elementParams.PopLastParam();
    }

    const auto& paramsQ = elementParams.GetParams();
    const auto& paramsP = cryptoParams->GetParamsP()->GetParams();

    size_t sizeQP = paramsQ.size() + paramsP.size();
    std::vector<NativeInteger> moduli;
    moduli.reserve(sizeQP);
    std::vector<NativeInteger> roots;
    roots.reserve(sizeQP);
    for (const auto& elem : paramsQ) {
        moduli.emplace_back(elem->GetModulus());
        roots.emplace_back(elem->GetRootOfUnity());
    }
    for (const auto& elem : paramsP) {
        moduli.emplace_back(elem->GetModulus());
        roots.emplace_back(elem->GetRootOfUnity());
    }

    auto elementParamsPtr = std::make_shared<ILDCRTParams<DCRTPoly::Integer>>(cc.GetCyclotomicOrder(), moduli, roots);

    int32_t slots = A.size();
    int32_t step  = (dim1 == 0) ? getRatioBSGSLT(slots) : dim1;

    uint32_t M4 = cc.GetCyclotomicOrder() / 4;
    std::vector<ReadOnlyPlaintext> result(slots);
#if !defined(__MINGW32__) && !defined(__MINGW64__)
    #pragma omp parallel for num_threads(OpenFHEParallelControls.GetThreadLimit(slots))
#endif
    for (int32_t ji = 0; ji < slots; ++ji) {
        auto vec = ExtractShiftedDiagonal(A, ji);
        for (auto& v : vec)
            v *= scale;
        result[ji] = FHECKKSRNS::MakeAuxPlaintext(cc, elementParamsPtr, Rotate(Fill(vec, M4), -step * (ji / step)), 1,
                                                  towersToDrop, M4);
    }
    return result;
}

std::vector<std::vector<std::complex<double>>> EvalLTRectPrecomputeSwitch(
    const std::vector<std::vector<std::complex<double>>>& A, uint32_t dim1, double scale) {
    if (!IsPowerOfTwo(A.size()) || !IsPowerOfTwo(A[0].size()))
        OPENFHE_THROW("The matrix passed to EvalLTPrecompute is not padded up to powers of two");

    const uint32_t n = std::min(A.size(), A[0].size());
    std::vector<std::vector<std::complex<double>>> diags(n);

    if (A.size() >= A[0].size()) {
        uint32_t bStep = (dim1 == 0) ? getRatioBSGSLT(n) : dim1;
        uint32_t gStep = std::ceil(static_cast<double>(n) / bStep);

        auto num_slices = A.size() / A[0].size();
        std::vector<std::vector<std::vector<std::complex<double>>>> A_slices(num_slices);
        for (size_t i = 0; i < num_slices; i++) {
            A_slices[i] = std::vector<std::vector<std::complex<double>>>(A.begin() + i * A[0].size(),
                                                                         A.begin() + (i + 1) * A[0].size());
        }
#pragma omp parallel for
        for (uint32_t j = 0; j < gStep; j++) {
            for (uint32_t i = 0; i < bStep; i++) {
                if (bStep * j + i < n) {
                    std::vector<std::complex<double>> diag;
                    diag.reserve(A.size() * num_slices);
                    for (uint32_t k = 0; k < num_slices; k++) {
                        auto tmp = ExtractShiftedDiagonal(A_slices[k], bStep * j + i);
                        diag.insert(diag.end(), std::make_move_iterator(tmp.begin()),
                                    std::make_move_iterator(tmp.end()));
                    }
                    std::transform(diag.begin(), diag.end(), diag.begin(),
                                   [&](const std::complex<double>& elem) { return elem * scale; });
                    diags[bStep * j + i] = std::move(diag);
                }
            }
        }
    }
    else {
#pragma omp parallel for num_threads(OpenFHEParallelControls.GetThreadLimit(n))
        for (uint32_t ji = 0; ji < n; ++ji) {
            auto diag = ExtractShiftedDiagonal(A, ji);
            for (auto& d : diag)
                d *= scale;
            diags[ji] = std::move(diag);
        }
    }
    return diags;
}

Ciphertext<DCRTPoly> SWITCHCKKSRNS::EvalLTWithPrecomputeSwitch(const CryptoContextImpl<DCRTPoly>& cc,
                                                               ConstCiphertext<DCRTPoly> ctxt,
                                                               const std::vector<ReadOnlyPlaintext>& A,
                                                               uint32_t dim1) const {
    // Computing the baby-step bStep and the giant-step gStep
    uint32_t slots = A.size();
    uint32_t bStep = dim1;
    uint32_t gStep = std::ceil(static_cast<double>(slots) / bStep);

    // Computes the NTTs for each CRT limb (for the hoisted automorphisms used later on)
    auto digits = cc.EvalFastRotationPrecompute(ctxt);

    // Hoisted automorphisms
    std::vector<Ciphertext<DCRTPoly>> fastRotation(bStep - 1);
#pragma omp parallel for num_threads(OpenFHEParallelControls.GetThreadLimit(bStep - 1))
    for (uint32_t j = 1; j < bStep; ++j)
        fastRotation[j - 1] = cc.EvalFastRotationExt(ctxt, j, digits, true);

    uint32_t M = cc.GetCyclotomicOrder();
    uint32_t N = cc.GetRingDimension();
    std::vector<uint32_t> map(N);
    Ciphertext<DCRTPoly> result;
    DCRTPoly first;

    for (uint32_t j = 0; j < gStep; ++j) {
        auto inner = FHECKKSRNS::EvalMultExt(cc.KeySwitchExt(ctxt, true), A[bStep * j]);
        for (uint32_t i = 1; i < bStep; ++i) {
            if (bStep * j + i < slots)
                FHECKKSRNS::EvalAddExtInPlace(inner, FHECKKSRNS::EvalMultExt(fastRotation[i - 1], A[bStep * j + i]));
        }

        if (j == 0) {
            first         = cc.KeySwitchDownFirstElement(inner);
            auto elements = inner->GetElements();
            elements[0].SetValuesToZero();
            inner->SetElements(std::move(elements));
            result = std::move(inner);
        }
        else {
            inner = cc.KeySwitchDown(inner);
            // Find the automorphism index that corresponds to the rotation index.
            uint32_t autoIndex = FindAutomorphismIndex2nComplex(bStep * j, M);
            PrecomputeAutoMap(N, autoIndex, &map);
            first += inner->GetElements()[0].AutomorphismTransform(autoIndex, map);

            auto&& innerDigits = cc.EvalFastRotationPrecompute(inner);
            FHECKKSRNS::EvalAddExtInPlace(result, cc.EvalFastRotationExt(inner, bStep * j, innerDigits, false));
        }
    }

    result = cc.KeySwitchDown(result);
    result->GetElements()[0] += first;
    return result;
}

Ciphertext<DCRTPoly> SWITCHCKKSRNS::EvalLTRectWithPrecomputeSwitch(
    const CryptoContextImpl<DCRTPoly>& cc, const std::vector<std::vector<std::complex<double>>>& A,
    ConstCiphertext<DCRTPoly> ct, bool wide, uint32_t dim1, uint32_t L) const {
    uint32_t n = std::min(A.size(), A[0].size());

    // Computing the baby-step bStep and the giant-step gStep
    uint32_t bStep = (dim1 == 0) ? getRatioBSGSLT(n) : dim1;
    uint32_t gStep = std::ceil(static_cast<double>(n) / bStep);

    uint32_t M = cc.GetCyclotomicOrder();
    uint32_t N = cc.GetRingDimension();

    // Computes the NTTs for each CRT limb (for the hoisted automorphisms used later on)
    auto digits = cc.EvalFastRotationPrecompute(ct);

    std::vector<Ciphertext<DCRTPoly>> fastRotation(bStep - 1);

    // Make sure the plaintext is created only with the necessary amount of moduli
    const auto cryptoParamsCKKS = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(ct->GetCryptoParameters());

    ILDCRTParams<DCRTPoly::Integer> elementParams = *(cryptoParamsCKKS->GetElementParams());
    uint32_t towersToDrop                         = 0;

    // For FLEXIBLEAUTOEXT we do not need extra modulus in auxiliary plaintexts
    if (L != 0) {
        towersToDrop = elementParams.GetParams().size() - L - 1;
        for (uint32_t i = 0; i < towersToDrop; i++)
            elementParams.PopLastParam();
    }
    if (cryptoParamsCKKS->GetScalingTechnique() == FLEXIBLEAUTOEXT) {
        towersToDrop += 1;
        elementParams.PopLastParam();
    }

    const auto& paramsQ = elementParams.GetParams();
    const auto& paramsP = cryptoParamsCKKS->GetParamsP()->GetParams();

    size_t sizeQP = paramsQ.size() + paramsP.size();
    std::vector<NativeInteger> moduli;
    moduli.reserve(sizeQP);
    std::vector<NativeInteger> roots;
    roots.reserve(sizeQP);
    for (const auto& elem : paramsQ) {
        moduli.emplace_back(elem->GetModulus());
        roots.emplace_back(elem->GetRootOfUnity());
    }
    for (const auto& elem : paramsP) {
        moduli.emplace_back(elem->GetModulus());
        roots.emplace_back(elem->GetRootOfUnity());
    }

    auto elementParamsPtr  = std::make_shared<ILDCRTParams<DCRTPoly::Integer>>(M, moduli, roots);
    auto elementParamsPtr2 = std::dynamic_pointer_cast<typename DCRTPoly::Params>(elementParamsPtr);

// Hoisted automorphisms
#pragma omp parallel for num_threads(OpenFHEParallelControls.GetThreadLimit(bStep - 1))
    for (uint32_t j = 1; j < bStep; ++j)
        fastRotation[j - 1] = cc.EvalFastRotationExt(ct, j, digits, true);

    std::vector<uint32_t> map(N);
    Ciphertext<DCRTPoly> result;
    DCRTPoly first;

    for (uint32_t j = 0; j < gStep; j++) {
        int32_t offset = (j == 0) ? 0 : -static_cast<int32_t>(bStep * j);
        auto temp      = cc.MakeCKKSPackedPlaintext(Rotate(Fill(A[bStep * j], N / 2), offset), 1, towersToDrop,
                                                    elementParamsPtr2, N / 2);
        auto inner     = FHECKKSRNS::EvalMultExt(cc.KeySwitchExt(ct, true), temp);

        for (uint32_t i = 1; i < bStep; i++) {
            if (bStep * j + i < n) {
                auto tempi = cc.MakeCKKSPackedPlaintext(Rotate(Fill(A[bStep * j + i], N / 2), offset), 1, towersToDrop,
                                                        elementParamsPtr2, N / 2);
                FHECKKSRNS::EvalAddExtInPlace(inner, FHECKKSRNS::EvalMultExt(fastRotation[i - 1], tempi));
            }
        }

        if (j == 0) {
            first         = cc.KeySwitchDownFirstElement(inner);
            auto elements = inner->GetElements();
            elements[0].SetValuesToZero();
            inner->SetElements(std::move(elements));
            result = std::move(inner);
        }
        else {
            inner = cc.KeySwitchDown(inner);
            // Find the automorphism index that corresponds to rotation index index.
            uint32_t autoIndex = FindAutomorphismIndex2nComplex(bStep * j, M);
            PrecomputeAutoMap(N, autoIndex, &map);
            first += inner->GetElements()[0].AutomorphismTransform(autoIndex, map);
            auto&& innerDigits = cc.EvalFastRotationPrecompute(inner);
            FHECKKSRNS::EvalAddExtInPlace(result, cc.EvalFastRotationExt(inner, bStep * j, innerDigits, false));
        }
    }
    result = cc.KeySwitchDown(result);
    result->GetElements()[0] += first;

    // A represents the diagonals, which lose the information whether the initial matrix is tall or wide
    if (wide) {
        uint32_t logl = lbcrypto::GetMSB(A[0].size() / A.size()) - 1;  // These are powers of two, so log(l) is integer
        std::vector<Ciphertext<DCRTPoly>> ctxt(logl + 1);
        ctxt[0] = result;
        for (uint32_t j = 1; j <= logl; ++j) {
            ctxt[j] = cc.EvalAdd(ctxt[j - 1], cc.EvalAtIndex(ctxt[j - 1], A.size() * (1 << (j - 1))));
        }
        result = ctxt[logl];
    }

    return result;
}

Ciphertext<DCRTPoly> SWITCHCKKSRNS::EvalSlotsToCoeffsSwitch(const CryptoContextImpl<DCRTPoly>& cc,
                                                            ConstCiphertext<DCRTPoly> ctxt) const {
    if (m_U0Pre.size() == 0)
        OPENFHE_THROW("Precomputations not generated. Call EvalCKKSToFHEWPrecompute to proceed.");

    uint32_t m    = 4 * m_numSlotsCKKS;
    uint32_t M    = cc.GetCyclotomicOrder();
    bool isSparse = (M != m);

    auto ctxtToDecode = ctxt->Clone();

    uint32_t numTowersToKeep = 2;
    const auto cryptoParams  = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(cc.GetCryptoParameters());
    if (cryptoParams->GetScalingTechnique() == ScalingTechnique::FLEXIBLEAUTO ||
        cryptoParams->GetScalingTechnique() == FLEXIBLEAUTOEXT) {
        ctxtToDecode = cc.Compress(ctxtToDecode, numTowersToKeep + 1);

        double targetSF =
            cryptoParams->GetScalingFactorReal(cryptoParams->GetElementParams()->GetParams().size() - numTowersToKeep);
        double sourceSF    = ctxtToDecode->GetScalingFactor();
        uint32_t numTowers = ctxtToDecode->GetElements()[0].GetNumOfElements();
        double modToDrop = cryptoParams->GetElementParams()->GetParams()[numTowers - 1]->GetModulus().ConvertToDouble();
        double adjustmentFactor = (targetSF / sourceSF) * (modToDrop / sourceSF);

        ctxtToDecode = cc.EvalMult(ctxtToDecode, adjustmentFactor);
        cc.GetScheme()->ModReduceInternalInPlace(ctxtToDecode, 1);
        ctxtToDecode->SetScalingFactor(targetSF);
    }
    else {
        ctxtToDecode = cc.Compress(ctxtToDecode, numTowersToKeep);
    }

    Ciphertext<DCRTPoly> ctxtDecoded;
    if (!isSparse) {  // fully packed
        // ctxtToDecode = cc.EvalAdd(ctxtToDecode, cc.GetScheme()->MultByMonomial(ctxtToDecode, M / 4));
        ctxtDecoded = EvalLTWithPrecomputeSwitch(cc, ctxtToDecode, m_U0Pre, m_dim1CF);
    }
    else {  // sparsely packed
        ctxtDecoded = EvalLTWithPrecomputeSwitch(cc, ctxtToDecode, m_U0Pre, m_dim1CF);
        cc.EvalAddInPlace(ctxtDecoded, cc.EvalAtIndex(ctxtDecoded, m_numSlotsCKKS));
    }
    return ctxtDecoded;
}

Ciphertext<DCRTPoly> SWITCHCKKSRNS::EvalPartialHomDecryption(const CryptoContextImpl<DCRTPoly>& cc,
                                                             const std::vector<std::vector<std::complex<double>>>& A,
                                                             ConstCiphertext<DCRTPoly> ct, uint32_t dim1, double scale,
                                                             uint32_t L) const {
    // Currently, by design, the # rows (# LWE ciphertexts to switch) is a power of two.
    // Ensure that # cols (LWE lattice parameter n) is padded up to a power of two
    auto Acopy      = A;
    size_t cols_po2 = 1 << static_cast<uint32_t>(std::ceil(std::log2(A[0].size())));
    if (cols_po2 > A[0].size()) {
        for (size_t i = 0; i < A.size(); ++i)
            Acopy[i].resize(cols_po2);
    }

    auto Apre = EvalLTRectPrecomputeSwitch(Acopy, dim1, scale);
    // The result is repeated every Acopy.size() slots
    return EvalLTRectWithPrecomputeSwitch(cc, Apre, ct, (Acopy.size() < A[0].size()), dim1, L);
}

//------------------------------------------------------------------------------
// Scheme switching Wrapper
//------------------------------------------------------------------------------
LWEPrivateKey SWITCHCKKSRNS::EvalCKKStoFHEWSetup(const SchSwchParams& params) {
    if (params.GetSecurityLevelFHEW() != TOY && params.GetSecurityLevelFHEW() != STD128)
        OPENFHE_THROW("Only STD128 or TOY are currently supported.");

    uint32_t ringDim = params.GetRingDimension();
    m_numSlotsCKKS   = (params.GetNumSlotsCKKS() == 0) ? ringDim / 2 : params.GetNumSlotsCKKS();

    m_modulus_CKKS_initial = params.GetInitialCKKSModulus();
    // Modulus to switch to in order to have secure RLWE samples with ring dimension n.
    // We can select any Qswitch less than 27 bits corresponding to 128 bits of security for lattice parameter n=1024 < 1305
    // according to https://homomorphicencryption.org/wp-content/uploads/2018/11/HomomorphicEncryptionStandardv1.1.pdf
    // or any Qswitch for TOY security.

    // Ensure that Qswitch is larger than Q_FHEW and smaller than Q_CKKS.
    if (params.GetCtxtModSizeFHEWIntermedSwch() <= params.GetCtxtModSizeFHEWLargePrec() ||
        params.GetCtxtModSizeFHEWIntermedSwch() > m_modulus_CKKS_initial.GetMSB() - 1) {
        OPENFHE_THROW("Qswitch should be larger than QFHEW and smaller than QCKKS.");
    }

    // Intermediate cryptocontext
    CCParams<CryptoContextCKKSRNS> parameters;
    parameters.SetMultiplicativeDepth(0);
    parameters.SetFirstModSize(params.GetCtxtModSizeFHEWIntermedSwch());
    // scaling mod size is not used in this case
    parameters.SetScalingModSize(params.GetCtxtModSizeFHEWIntermedSwch());
    // This doesn't need this to be the same scaling technique as the outer cryptocontext, since we only do a key switch
    parameters.SetScalingTechnique(FIXEDMANUAL);
    parameters.SetSecurityLevel(params.GetSecurityLevelCKKS());
    parameters.SetRingDim(ringDim);
    parameters.SetBatchSize(params.GetBatchSize());
    parameters.SetCKKSDataType(REAL);

    m_ccKS = GenCryptoContext(parameters);

    // Enable the features that you wish to use
    m_ccKS->Enable(PKE);
    m_ccKS->Enable(KEYSWITCH);

    // Get the ciphertext modulus
    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(m_ccKS->GetCryptoParameters());
    m_modulus_CKKS_from     = cryptoParams->GetElementParams()->GetParams()[0]->GetModulus();

    m_ccLWE = std::make_shared<BinFHEContext>();
    m_ccLWE->BinFHEContext::GenerateBinFHEContext(
        params.GetSecurityLevelFHEW(), params.GetArbitraryFunctionEvaluation(), params.GetCtxtModSizeFHEWLargePrec(), 0,
        GINX, params.GetUseDynamicModeFHEW());

    // For arbitrary functions, the LWE ciphertext needs to be at most the ring dimension in FHEW bootstrapping
    m_modulus_LWE = (!params.GetArbitraryFunctionEvaluation()) ?
                        1 << params.GetCtxtModSizeFHEWLargePrec() :
                        m_ccLWE->GetParams()->GetLWEParams()->Getq().ConvertToInt();

    // The baby-step and number of levels for the linear transformation associated to the homomorphic decoding
    m_dim1CF = (params.GetBStepLTrCKKStoFHEW() == 0) ? getRatioBSGSLT(params.GetNumSlotsCKKS()) :
                                                       params.GetBStepLTrCKKStoFHEW();
    m_LCF    = params.GetLevelLTrCKKStoFHEW();

    // Return LWE private key
    return m_ccLWE->KeyGen();
}

std::shared_ptr<std::map<uint32_t, EvalKey<DCRTPoly>>> SWITCHCKKSRNS::EvalCKKStoFHEWKeyGen(
    const KeyPair<DCRTPoly>& keyPair, ConstLWEPrivateKey& lwesk) {
    auto& privateKey = keyPair.secretKey;

    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(privateKey->GetCryptoParameters());

    if (cryptoParams->GetKeySwitchTechnique() != HYBRID)
        OPENFHE_THROW("CKKS to FHEW scheme switching is only supported for the Hybrid key switching method.");
#if NATIVEINT == 128
    if (cryptoParams->GetScalingTechnique() == FLEXIBLEAUTO || cryptoParams->GetScalingTechnique() == FLEXIBLEAUTOEXT)
        OPENFHE_THROW("128-bit CKKS to FHEW scheme switching is supported for FIXEDMANUAL and FIXEDAUTO methods only.");
#endif

    auto ccCKKS = privateKey->GetCryptoContext();

    // Intermediate cryptocontext for CKKS to FHEW
    auto keys2 = m_ccKS->KeyGen();

    m_ctxtKS = m_ccKS->Encrypt(keys2.publicKey, m_ccKS->MakeCKKSPackedPlaintext(std::vector<double>{0.0}));

    // Compute switching key between RLWE and LWE via the intermediate cryptocontext, keep it in RLWE form
    m_CKKStoFHEWswk = switchingKeyGenRLWEcc(keys2.secretKey, privateKey, lwesk);

    // Compute automorphism keys
    uint32_t M     = ccCKKS->GetCyclotomicOrder();
    uint32_t slots = m_numSlotsCKKS;

    // Compute indices for rotations for slotToCoeff transform
    std::vector<int32_t> indexRotationS2C = FindLTRotationIndicesSwitch(m_dim1CF, M, slots);
    indexRotationS2C.emplace_back(static_cast<int32_t>(slots));

    // Remove possible duplicates and zero
    sort(indexRotationS2C.begin(), indexRotationS2C.end());
    indexRotationS2C.erase(unique(indexRotationS2C.begin(), indexRotationS2C.end()), indexRotationS2C.end());
    indexRotationS2C.erase(std::remove(indexRotationS2C.begin(), indexRotationS2C.end(), 0), indexRotationS2C.end());

    auto algo = ccCKKS->GetScheme();

    // Compute multiplication key
    algo->EvalMultKeyGen(privateKey);

    auto evalKeys = algo->EvalAtIndexKeyGen(privateKey, indexRotationS2C);

    // Compute conjugation key
    (*evalKeys)[M - 1] = FHECKKSRNS::ConjugateKeyGen(privateKey);

    return evalKeys;
}

void SWITCHCKKSRNS::EvalCKKStoFHEWPrecompute(const CryptoContextImpl<DCRTPoly>& cc, double scale) {
    uint32_t M     = cc.GetCyclotomicOrder();
    uint32_t slots = m_numSlotsCKKS;
    uint32_t m     = 4 * m_numSlotsCKKS;
    uint32_t mmask = m - 1;  // assumes m is power of 2
    bool isSparse  = (M != m);

    // Computes indices for all primitive roots of unity
    std::vector<uint32_t> rotGroup(slots);
    uint32_t fivePows = 1;
    for (uint32_t i = 0; i < slots; ++i) {
        rotGroup[i] = fivePows;
        fivePows *= 5;
        fivePows &= mmask;
    }
    // Computes all powers of a primitive root of unity exp(2*M_PI/m)
    std::vector<std::complex<double>> ksiPows(m + 1);
    double ak = 2 * M_PI / m;
    for (uint32_t j = 0; j < m; ++j) {
        double angle = ak * j;
        ksiPows[j].real(std::cos(angle));
        ksiPows[j].imag(std::sin(angle));
    }
    ksiPows[m] = ksiPows[0];

    // Matrices for decoding
    std::vector<std::vector<std::complex<double>>> U0(slots, std::vector<std::complex<double>>(slots));
    std::vector<std::vector<std::complex<double>>> U1(slots, std::vector<std::complex<double>>(slots));

    for (uint32_t i = 0; i < slots; ++i) {
        for (uint32_t j = 0; j < slots; ++j) {
            U0[i][j] = ksiPows[(j * rotGroup[i]) & mmask];
            U1[i][j] = std::complex<double>(0, 1) * U0[i][j];
        }
    }

    // Obtain the right scaling for encoded messages in FHEW coming from encoded messages in CKKS
    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(cc.GetCryptoParameters());
    double scFactor = cryptoParams->GetScalingFactorReal(cryptoParams->GetElementParams()->GetParams().size() - 1);
    scale *= m_modulus_CKKS_initial.ConvertToDouble() / scFactor;

    if (!isSparse) {  // fully packed
        m_U0Pre = EvalLTPrecomputeSwitch(cc, U0, m_dim1CF, m_LCF, scale);
    }
    else {  // sparsely packed
        m_U0Pre = EvalLTPrecomputeSwitch(cc, U0, U1, m_dim1CF, m_LCF, scale);
    }
}

std::vector<std::shared_ptr<LWECiphertextImpl>> SWITCHCKKSRNS::EvalCKKStoFHEW(ConstCiphertext<DCRTPoly> ciphertext,
                                                                              uint32_t numCtxts) {
    if (numCtxts == 0 || numCtxts > m_numSlotsCKKS)
        numCtxts = m_numSlotsCKKS;

    // Step 1. Homomorphic decoding
    auto ccCKKS      = ciphertext->GetCryptoContext();
    auto ctxtDecoded = EvalSlotsToCoeffsSwitch(*ccCKKS, ciphertext);
    ccCKKS->GetScheme()->ModReduceInternalInPlace(ctxtDecoded, 1);

    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(ccCKKS->GetCryptoParameters());

    // Step 2. Modulus switch to Q', such that CKKS is secure for (Q',n)
    auto ctxtKS = m_ctxtKS->Clone();
    ModSwitch(ctxtDecoded, ctxtKS, m_modulus_CKKS_from);

    // Step 3: Key switch from the CKKS key with the new modulus Q' to the RLWE version of the FHEW key with the new modulus Q'
    auto ccKS       = ctxtKS->GetCryptoContext();  // Use this instead of m_ccKS to work with serialization
    auto ctSwitched = ccKS->KeySwitch(ctxtKS, m_CKKStoFHEWswk);

    // Step 4. Extract LWE ciphertexts with the modulus Q'
    uint32_t n = m_ccLWE->GetParams()->GetLWEParams()->Getn();  // lattice parameter for additive LWE
    std::vector<std::shared_ptr<LWECiphertextImpl>> LWEciphertexts(numCtxts);
    auto AandB = ExtractLWEpacked(ctSwitched);

    uint32_t gap = ccKS->GetRingDimension() / (2 * m_numSlotsCKKS);
    for (uint32_t i = 0, idx = 0; i < numCtxts; ++i, idx += gap)
        LWEciphertexts[i] = ExtractLWECiphertext(AandB, m_modulus_CKKS_from, n, idx);

    // Step 5. Modulus switch to q in FHEW

    // Compute the necessary factor to obtaine the message Q'/pLWE
    if (m_modulus_LWE != m_modulus_CKKS_from) {
#pragma omp parallel for num_threads(OpenFHEParallelControls.GetThreadLimit(numCtxts))
        for (uint32_t i = 0; i < numCtxts; ++i) {
            auto& original_a = LWEciphertexts[i]->GetA();
            auto original_b = LWEciphertexts[i]->GetB();
            // multiply by Q_LWE/Q' and round to Q_LWE
            NativeVector a_round(n, m_modulus_LWE);
            for (uint32_t j = 0; j < n; ++j)
                a_round[j] = RoundqQAlter(original_a[j], m_modulus_LWE, m_modulus_CKKS_from);
            NativeInteger b_round = RoundqQAlter(original_b, m_modulus_LWE, m_modulus_CKKS_from);
            LWEciphertexts[i]     = std::make_shared<LWECiphertextImpl>(std::move(a_round), std::move(b_round));
        }
    }

    return LWEciphertexts;
}

//------------------------------------------------------------------------------
// Scheme switching Wrapper
//------------------------------------------------------------------------------
void SWITCHCKKSRNS::EvalFHEWtoCKKSSetup(const CryptoContextImpl<DCRTPoly>& ccCKKS,
                                        const std::shared_ptr<BinFHEContext>& ccLWE, uint32_t numSlotsCKKS,
                                        uint32_t logQ) {
    if (ccLWE->GetParams()->GetLWEParams()->Getn() * 2 > ccCKKS.GetRingDimension())
        OPENFHE_THROW("The lattice parameter in LWE cannot be larger than half the RLWE ring dimension.");
    m_ccLWE = ccLWE;

    if (numSlotsCKKS == 0) {
        if (ccCKKS.GetEncodingParams()->GetBatchSize() != 0)
            m_numSlotsCKKS = ccCKKS.GetEncodingParams()->GetBatchSize();
        else
            m_numSlotsCKKS = ccCKKS.GetRingDimension() / 2;
    }
    else {
        m_numSlotsCKKS = numSlotsCKKS;
    }

    m_modulus_LWE = (logQ != 0) ? 1 << logQ : m_ccLWE->GetParams()->GetLWEParams()->Getq().ConvertToInt();
}

std::shared_ptr<std::map<uint32_t, EvalKey<DCRTPoly>>> SWITCHCKKSRNS::EvalFHEWtoCKKSKeyGen(
    const KeyPair<DCRTPoly>& keyPair, ConstLWEPrivateKey& lwesk, uint32_t numSlots, uint32_t numCtxts, uint32_t dim1,
    uint32_t L) {
    auto& privateKey = keyPair.secretKey;
    auto& publicKey  = keyPair.publicKey;

    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(privateKey->GetCryptoParameters());
    auto ccCKKS             = privateKey->GetCryptoContext();

    uint32_t n       = lwesk->GetElement().GetLength();
    uint32_t ringDim = ccCKKS->GetRingDimension();

    // Generate FHEW to CKKS switching key, i.e., CKKS encryption of FHEW secret key. Pad up to the closest power of two
    uint32_t n_po2      = 1 << static_cast<uint32_t>(std::ceil(std::log2(n)));
    auto& skLWEElements = lwesk->GetElement();
    const auto neg      = lwesk->GetModulus().ConvertToDouble() - 1.0;
    std::vector<std::complex<double>> skLWEDouble(n_po2);
    for (uint32_t i = 0; i < n; ++i) {
        auto tmp       = skLWEElements[i].ConvertToDouble();
        skLWEDouble[i] = std::complex<double>(tmp == neg ? -1.0 : tmp, 0);
    }

    // Check encoding and specify the number of slots, otherwise, if batchsize is set and is smaller, it will throw an error.
    Plaintext skLWEPlainswk;
    if (cryptoParams->GetScalingTechnique() == FLEXIBLEAUTOEXT)
        skLWEPlainswk = ccCKKS->MakeCKKSPackedPlaintext(Fill(skLWEDouble, ringDim / 2), 1, BASE_NUM_LEVELS_TO_DROP,
                                                        nullptr, ringDim / 2);
    else
        skLWEPlainswk = ccCKKS->MakeCKKSPackedPlaintext(Fill(skLWEDouble, ringDim / 2), 1, 0, nullptr, ringDim / 2);

    m_FHEWtoCKKSswk = ccCKKS->Encrypt(publicKey, skLWEPlainswk);

    // Compute automorphism keys for CKKS for baby-step giant-step
    if (numCtxts == 0)
        numCtxts = m_numSlotsCKKS;  // If no value is specified, default to the column size of the linear transformation
    uint32_t M = ccCKKS->GetCyclotomicOrder();
    if (dim1 == 0)
        dim1 = getRatioBSGSLT(numCtxts);
    m_dim1FC = dim1;
    m_LFC    = L;

    // Compute indices for rotations for homomorphic decryption in CKKS
    std::vector<int32_t> indexRotationHomDec = FindLTRotationIndicesSwitch(dim1, M, numCtxts);

    // If the linear transform is wide instead of tall, we need extra rotations
    if (numCtxts < n_po2) {
        uint32_t logl = lbcrypto::GetMSB(n_po2 / numCtxts) - 1;  // These are powers of two, so log(l) is integer
        indexRotationHomDec.reserve(indexRotationHomDec.size() + logl);
        for (uint32_t j = 1; j <= logl; ++j)
            indexRotationHomDec.emplace_back(numCtxts * (1 << (j - 1)));
    }

    uint32_t slots = (numSlots == 0) ? m_numSlotsCKKS : numSlots;
    // Compute indices for rotations to bring back the final CKKS ciphertext encoding to slots
    if (ringDim > 2 * slots) {  // if the encoding is full, this does not execute
        indexRotationHomDec.reserve(indexRotationHomDec.size() + GetMSB(ringDim) - 2);
        for (uint32_t j = 1; j < ringDim / (2 * slots); j <<= 1)
            indexRotationHomDec.emplace_back(j * slots);
    }

    // Remove possible duplicates and zero
    sort(indexRotationHomDec.begin(), indexRotationHomDec.end());
    indexRotationHomDec.erase(unique(indexRotationHomDec.begin(), indexRotationHomDec.end()),
                              indexRotationHomDec.end());
    indexRotationHomDec.erase(std::remove(indexRotationHomDec.begin(), indexRotationHomDec.end(), 0),
                              indexRotationHomDec.end());

    auto algo     = ccCKKS->GetScheme();
    auto evalKeys = algo->EvalAtIndexKeyGen(privateKey, indexRotationHomDec);

    // Compute multiplication key
    ccCKKS->EvalMultKeyGen(privateKey);

    return evalKeys;
}

Ciphertext<DCRTPoly> SWITCHCKKSRNS::EvalFHEWtoCKKS(std::vector<std::shared_ptr<LWECiphertextImpl>>& LWECiphertexts,
                                                   uint32_t numCtxts, uint32_t numSlots, uint32_t p, double pmin,
                                                   double pmax, uint32_t dim1) const {
    if (LWECiphertexts.empty())
        OPENFHE_THROW("Empty input FHEW ciphertext vector");

    // This is the number of CKKS slots to use in encoding
    const uint32_t slots = (numSlots == 0) ? m_numSlotsCKKS : numSlots;

    uint32_t numLWECtxts = LWECiphertexts.size();
    uint32_t numValues   = (numCtxts == 0) ? numLWECtxts : std::min(numCtxts, numLWECtxts);
    numValues = std::min(numValues, slots);  // This is the number of LWE ciphertexts to pack into the CKKS ciphertext

    uint32_t n = LWECiphertexts[0]->GetA().GetLength();

    auto ccCKKS                 = m_FHEWtoCKKSswk->GetCryptoContext();
    const auto cryptoParamsCKKS = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(ccCKKS->GetCryptoParameters());

    uint32_t m    = 4 * slots;
    uint32_t M    = ccCKKS->GetCyclotomicOrder();
    uint32_t N    = ccCKKS->GetRingDimension();
    bool isSparse = (M != m);

    double K = 0.0;
    std::vector<double> coefficientsFHEW;  // EvalFHEWtoCKKS assumes lattice parameter n is at most 2048.
    if (n == 32) {
        K = 16.0;
        coefficientsFHEW.assign(g_coefficientsFHEW16);
    }
    else {
        K = 128.0;  // Failure probability of 2^{-49}
        if (p <= 4) {
            // If the output messages are bits, we could use a lower degree polynomial
            coefficientsFHEW.assign(g_coefficientsFHEW128_8);
        }
        else {
            coefficientsFHEW.assign(g_coefficientsFHEW128_9);
        }
    }

    // Step 1. Form matrix A and vector b from the LWE ciphertexts, but only extract the first necessary number of them
    std::vector<std::vector<std::complex<double>>> A(numValues);

    // To have the same encoding as A*s, create b with the appropriate number of elements
    const uint32_t b_size = std::min(((numValues % n) != 0) ? (numValues + n - (numValues % n)) : numValues, N / 2);
    std::vector<std::complex<double>> b(b_size);

    // Combine the scale with the division by K to consume fewer levels, but careful since the value might be too small
    const double prescale = (1.0 / LWECiphertexts[0]->GetModulus().ConvertToDouble()) / K;

#pragma omp parallel for num_threads(OpenFHEParallelControls.GetThreadLimit(numValues))
    for (uint32_t i = 0; i < numValues; ++i) {
        auto& a = LWECiphertexts[i]->GetA();
        A[i].resize(a.GetLength());
        for (uint32_t j = 0; j < a.GetLength(); ++j)
            A[i][j] = std::complex<double>(a[j].ConvertToDouble(), 0);
        b[i] = std::complex<double>(prescale * LWECiphertexts[i]->GetB().ConvertToDouble(), 0);
    }

    // Step 2. Perform the homomorphic linear transformation of A*skLWE
    if (dim1 == 0)
        dim1 = m_dim1FC;
    auto AdotS = EvalPartialHomDecryption(*ccCKKS, A, m_FHEWtoCKKSswk, dim1, prescale, 0);

    // Step 3. Get the ciphertext of B - A*s
    Plaintext BPlain = ccCKKS->MakeCKKSPackedPlaintext(b, AdotS->GetNoiseScaleDeg(), AdotS->GetLevel(), nullptr, N / 2);
    auto BminusAdotS = ccCKKS->EvalAdd(ccCKKS->EvalNegate(AdotS), BPlain);
    if (cryptoParamsCKKS->GetScalingTechnique() == FIXEDMANUAL)
        ccCKKS->ModReduceInPlace(BminusAdotS);
    else if (BminusAdotS->GetNoiseScaleDeg() == 2)
        ccCKKS->GetScheme()->ModReduceInternalInPlace(BminusAdotS, BASE_NUM_LEVELS_TO_DROP);

    // Step 4. Do the modulus reduction: homomorphically evaluate modular function. We do it by using sine approximation.
    // auto BminusAdotS2 = BminusAdotS;  // Instead of zeroing out slots which are not of interest as done above

    double a_cheby = -1.0;
    double b_cheby = 1.0;  // The division by K was performed before

    // double a_cheby = -K; double b_cheby = K; // Alternatively, do this separately to not lose precision when scaling with everything at once
    auto BminusAdotS3 = ccCKKS->EvalChebyshevSeries(BminusAdotS, coefficientsFHEW, a_cheby, b_cheby);
    if (cryptoParamsCKKS->GetScalingTechnique() != FIXEDMANUAL)
        ccCKKS->GetScheme()->ModReduceInternalInPlace(BminusAdotS3, BASE_NUM_LEVELS_TO_DROP);

    const int32_t BT_ITER = 3;
    for (int32_t j = 1; j <= BT_ITER; ++j) {
        BminusAdotS3 = ccCKKS->EvalMult(BminusAdotS3, BminusAdotS3);
        ccCKKS->EvalAddInPlace(BminusAdotS3, BminusAdotS3);
        double scalar = 1.0 / std::pow((2.0 * M_PI), std::pow(2.0, j - BT_ITER));
        ccCKKS->EvalSubInPlace(BminusAdotS3, scalar);
        if (cryptoParamsCKKS->GetScalingTechnique() == FIXEDMANUAL)
            ccCKKS->ModReduceInPlace(BminusAdotS3);
        else
            ccCKKS->GetScheme()->ModReduceInternalInPlace(BminusAdotS3, BASE_NUM_LEVELS_TO_DROP);
    }

    /* For p <= 4 and when we only encrypt bits, we don't need sin(2pi*x)/2pi to approximate x,
     * we can directly use sin(0) for 0 and sin(pi/2) for 1.
     * Here pmax is actually the plaintext modulus, not the maximum value of the messages that we
     * consider. For plaintext modulus > 4, even if we only care about encrypting bits, 2pi is not
     * the correct post-scaling factor.
     * Moreover, we have to account for the different encoding the end ciphertext should have.
     */
    double postScale = (p >= 1 && p <= 4) ? (2.0 * M_PI) : static_cast<double>(p);
    double postBias  = 0.0;
    if (pmin != 0) {
        postScale *= (pmax - pmin) / 4.0;
        postBias = (pmax - pmin) / 4.0;
    }

    // numValues are set; the rest of values up to N/2 are made zero when creating the plaintext
    std::vector<std::complex<double>> postScaleVec(numValues, std::complex<double>(postScale, 0));
    std::vector<std::complex<double>> postBiasVec(numValues, std::complex<double>(postBias, 0));

    ILDCRTParams<DCRTPoly::Integer> elementParams = *(cryptoParamsCKKS->GetElementParams());

    uint32_t towersToDrop = BminusAdotS3->GetLevel() + BminusAdotS3->GetNoiseScaleDeg() - 1;
    for (uint32_t i = 0; i < towersToDrop; i++)
        elementParams.PopLastParam();

    // Use full packing here to clear up the junk in the slots after numValues
    auto postScalePlain = ccCKKS->MakeCKKSPackedPlaintext(postScaleVec, 1, towersToDrop, nullptr, N / 2);
    auto BminusAdotSres = ccCKKS->EvalMult(BminusAdotS3, postScalePlain);

    // Add the plaintext for bias at the correct level and depth
    auto postBiasPlain = ccCKKS->MakeCKKSPackedPlaintext(postBiasVec, BminusAdotSres->GetNoiseScaleDeg(),
                                                         BminusAdotSres->GetLevel(), nullptr, N / 2);

    ccCKKS->EvalAddInPlace(BminusAdotSres, postBiasPlain);

    // Go back to the sparse encoding if needed
    if (isSparse) {
        for (uint32_t j = 1; j < N / (2 * slots); j <<= 1) {
            auto temp = ccCKKS->EvalAtIndex(BminusAdotSres, j * slots);
            ccCKKS->EvalAddInPlace(BminusAdotSres, temp);
        }
        BminusAdotSres->SetSlots(slots);
    }

    if (cryptoParamsCKKS->GetScalingTechnique() == FIXEDMANUAL) {
        ccCKKS->ModReduceInPlace(BminusAdotSres);
    }

    return BminusAdotSres;
}

LWEPrivateKey SWITCHCKKSRNS::EvalSchemeSwitchingSetup(const SchSwchParams& params) {
    // FHEW to CKKS
    // Save the parameters to be used in EvalSchemeSwitchingKeyGen
    m_argmin = params.GetComputeArgmin();
    m_oneHot = params.GetOneHotEncoding();
    m_alt    = params.GetUseAltArgmin();

    // Set parameters for linear transform for FHEW to CKKS
    if (!m_argmin || (m_argmin && m_alt)) {
        m_numCtxts = (params.GetNumValues() == 0) ? m_numSlotsCKKS : params.GetNumValues();
    }
    else {  // argmin not in the alternative mode
        m_numCtxts = (params.GetNumValues() == 0) ? m_numSlotsCKKS / 2 : params.GetNumValues() / 2;
    }

    // There are multiple dim1's required in argmin, but they are specified individually in EvalSchemeSwitchingKeyGen
    m_dim1FC = (params.GetBStepLTrFHEWtoCKKS() == 0) ? getRatioBSGSLT(m_numCtxts) : params.GetBStepLTrFHEWtoCKKS();
    m_LFC    = params.GetLevelLTrFHEWtoCKKS();

    // CKKS to FHEW
    return EvalCKKStoFHEWSetup(params);
}

std::shared_ptr<std::map<uint32_t, EvalKey<DCRTPoly>>> SWITCHCKKSRNS::EvalSchemeSwitchingKeyGen(
    const KeyPair<DCRTPoly>& keyPair, ConstLWEPrivateKey& lwesk) {
    auto& privateKey = keyPair.secretKey;
    auto& publicKey  = keyPair.publicKey;

    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(privateKey->GetCryptoParameters());

    if (cryptoParams->GetKeySwitchTechnique() != HYBRID)
        OPENFHE_THROW("CKKS to FHEW scheme switching is only supported for the Hybrid key switching method.");
#if NATIVEINT == 128
    if (cryptoParams->GetScalingTechnique() == FLEXIBLEAUTO || cryptoParams->GetScalingTechnique() == FLEXIBLEAUTOEXT)
        OPENFHE_THROW("128-bit CKKS to FHEW scheme switching is supported for FIXEDMANUAL and FIXEDAUTO methods only.");
#endif

    auto ccCKKS = privateKey->GetCryptoContext();

    uint32_t M       = ccCKKS->GetCyclotomicOrder();
    uint32_t slots   = m_numSlotsCKKS;
    uint32_t n       = lwesk->GetElement().GetLength();
    uint32_t ringDim = ccCKKS->GetRingDimension();

    // Intermediate cryptocontext for CKKS to FHEW
    auto keys2 = m_ccKS->KeyGen();

    Plaintext ptxtZeroKS = m_ccKS->MakeCKKSPackedPlaintext(std::vector<double>{0.0}, 1, 0, nullptr, slots);
    m_ctxtKS             = m_ccKS->Encrypt(keys2.publicKey, ptxtZeroKS);

    // Compute switching key between RLWE and LWE via the intermediate cryptocontext, keep it in RLWE form
    m_CKKStoFHEWswk = switchingKeyGenRLWEcc(keys2.secretKey, privateKey, lwesk);

    // Generate FHEW to CKKS switching key, i.e., CKKS encryption of FHEW secret key. Pad up to the closest power of two
    uint32_t n_po2     = 1 << static_cast<uint32_t>(std::ceil(std::log2(n)));
    auto skLWEElements = lwesk->GetElement();
    const auto neg     = lwesk->GetModulus().ConvertToDouble() - 1.0;
    std::vector<std::complex<double>> skLWEDouble(n_po2);
    for (uint32_t i = 0; i < n; i++) {
        auto tmp       = skLWEElements[i].ConvertToDouble();
        skLWEDouble[i] = std::complex<double>(tmp == neg ? -1.0 : tmp, 0);
    }

    // Check encoding and specify the number of slots, otherwise, if batchsize is set and is smaller, it will throw an error.
    Plaintext skLWEPlainswk;
    if (cryptoParams->GetScalingTechnique() == FLEXIBLEAUTOEXT)
        skLWEPlainswk = ccCKKS->MakeCKKSPackedPlaintext(Fill(skLWEDouble, ringDim / 2), 1, BASE_NUM_LEVELS_TO_DROP,
                                                        nullptr, ringDim / 2);
    else
        skLWEPlainswk = ccCKKS->MakeCKKSPackedPlaintext(Fill(skLWEDouble, ringDim / 2), 1, 0, nullptr, ringDim / 2);

    m_FHEWtoCKKSswk = ccCKKS->Encrypt(publicKey, skLWEPlainswk);

    // Compute automorphism keys

    /* CKKS to FHEW */
    // Compute indices for rotations for slotToCoeff transform
    std::vector<int32_t> indexRotationS2C = FindLTRotationIndicesSwitch(m_dim1CF, M, slots);
    indexRotationS2C.emplace_back(static_cast<int32_t>(slots));

    // Compute indices for rotations for sparse packing
    if (ringDim > 2 * slots) {  // if the encoding is full, this does not execute
        indexRotationS2C.reserve(indexRotationS2C.size() + GetMSB(ringDim) - 2 + GetMSB(slots) - 1);
        for (uint32_t i = 1; i < ringDim / 2; i <<= 1) {
            indexRotationS2C.emplace_back(static_cast<int32_t>(i));
            if (i <= slots)
                indexRotationS2C.emplace_back(-static_cast<int32_t>(i));
        }
    }

    /* FHEW to CKKS */
    std::vector<int32_t> indexRotationHomDec;
    std::vector<int32_t> indexRotationArgmin;

    if (!m_argmin || (m_argmin && m_alt)) {
        // Compute indices for rotations for homomorphic decryption
        indexRotationHomDec = FindLTRotationIndicesSwitch(m_dim1FC, M, m_numCtxts);

        // If the linear transform is wide instead of tall, we need extra rotations
        if (m_numCtxts < n_po2) {
            uint32_t logl = lbcrypto::GetMSB(n_po2 / m_numCtxts) - 1;  // These are powers of two, so log(l) is integer
            indexRotationHomDec.reserve(indexRotationHomDec.size() + logl);
            for (size_t j = 1; j <= logl; ++j) {
                indexRotationHomDec.emplace_back(m_numCtxts * (1 << (j - 1)));
            }
        }
        if (m_argmin) {
            // Rotations for postprocessing after a level of the binary tree
            indexRotationArgmin.reserve(GetMSB(m_numCtxts) - 2);
            for (uint32_t i = 1; i < m_numCtxts; i <<= 1) {
                indexRotationArgmin.emplace_back(static_cast<int32_t>(m_numCtxts / (2 * i)));
            }
        }
    }
    else {  // argmin not in the alternative mode
        // Compute indices for rotations for all homomorphic decryptions for the levels of the tree
        indexRotationHomDec = FindLTRotationIndicesSwitchArgmin(M, m_numCtxts, n_po2);

        // Rotations for postprocessing after a level of the binary tree
        indexRotationArgmin.reserve(GetMSB(m_numCtxts) - 1 + 2 * (GetMSB(m_numCtxts) - 1));
        for (uint32_t i = 1; i < 2 * m_numCtxts; i <<= 1) {
            indexRotationArgmin.emplace_back(static_cast<int32_t>(m_numCtxts / (2 * i)));
            indexRotationArgmin.emplace_back(-static_cast<int32_t>(m_numCtxts / (2 * i)));
            if (i > 1) {
                for (uint32_t j = 2 * m_numCtxts / i; j < 2 * m_numCtxts; j <<= 1)
                    indexRotationArgmin.emplace_back(-static_cast<int32_t>(j));
            }
        }
    }

    // Compute indices for rotations to bring back the final CKKS ciphertext encoding to slots
    if (ringDim > 2 * slots) {  // if the encoding is full, this does not execute
        indexRotationHomDec.reserve(indexRotationHomDec.size() + GetMSB(ringDim) - 2);
        for (uint32_t j = 1; j < ringDim / (2 * slots); j <<= 1) {
            indexRotationHomDec.emplace_back(j * slots);
        }
    }

    // Combine the indices lists
    indexRotationS2C.reserve(indexRotationS2C.size() + indexRotationHomDec.size() + indexRotationArgmin.size());
    indexRotationS2C.insert(indexRotationS2C.end(), indexRotationHomDec.begin(), indexRotationHomDec.end());
    indexRotationS2C.insert(indexRotationS2C.end(), indexRotationArgmin.begin(), indexRotationArgmin.end());

    // Remove possible duplicates and zero
    sort(indexRotationS2C.begin(), indexRotationS2C.end());
    indexRotationS2C.erase(unique(indexRotationS2C.begin(), indexRotationS2C.end()), indexRotationS2C.end());
    indexRotationS2C.erase(std::remove(indexRotationS2C.begin(), indexRotationS2C.end(), 0), indexRotationS2C.end());

    auto algo     = ccCKKS->GetScheme();
    auto evalKeys = algo->EvalAtIndexKeyGen(privateKey, indexRotationS2C);

    // Compute conjugation key
    auto conjKey       = FHECKKSRNS::ConjugateKeyGen(privateKey);
    (*evalKeys)[M - 1] = conjKey;

    // Compute multiplication key
    ccCKKS->EvalMultKeyGen(privateKey);

    // Compute automorphism keys if we don't want one hot encoding for argmin
    if (m_argmin && (m_oneHot == false)) {
        ccCKKS->EvalSumKeyGen(privateKey);
    }

    /* FHEW computations */
    // Generate the bootstrapping keys (refresh and switching keys)
    m_ccLWE->BTKeyGen(lwesk);

    return evalKeys;
}

void SWITCHCKKSRNS::EvalCompareSwitchPrecompute(const CryptoContextImpl<DCRTPoly>& ccCKKS, uint32_t pLWE,
                                                double scaleSign, bool unit) {
    double scaleCF = 1.0;
    if ((pLWE != 0) && (!unit)) {  // The messages are already scaled between 0 and 1, no need to divide by pLWE
        scaleCF = 1.0 / pLWE;
    }
    // Else perform no scaling; the implicit FHEW plaintext modulus will be m_modulus_CKKS_initial / scFactor
    scaleCF *= scaleSign;

    EvalCKKStoFHEWPrecompute(ccCKKS, scaleCF);
}

Ciphertext<DCRTPoly> SWITCHCKKSRNS::EvalCompareSchemeSwitching(ConstCiphertext<DCRTPoly> ciphertext1,
                                                               ConstCiphertext<DCRTPoly> ciphertext2, uint32_t numCtxts,
                                                               uint32_t numSlots, uint32_t pLWE, double scaleSign,
                                                               bool unit) {
    auto ccCKKS = ciphertext1->GetCryptoContext();

    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(ccCKKS->GetCryptoParameters());

    auto cDiff = ccCKKS->EvalSub(ciphertext1, ciphertext2);

    if (unit) {
        if (pLWE == 0)
            OPENFHE_THROW("To scale to the unit circle, pLWE must be non-zero.");
        cDiff = ccCKKS->EvalMult(cDiff, 1.0 / static_cast<double>(pLWE));
        cDiff = ccCKKS->Rescale(cDiff);
    }

    // The precomputation has already been performed, but if it is scaled differently than desired, recompute it
    if (pLWE != 0) {
        double scaleCF = 1.0;
        if ((pLWE != 0) && (!unit)) {
            scaleCF = 1.0 / pLWE;
        }
        scaleCF *= scaleSign;

        EvalCKKStoFHEWPrecompute(*ccCKKS, scaleCF);
    }

    auto LWECiphertexts = EvalCKKStoFHEW(cDiff, numCtxts);
    const uint32_t n    = LWECiphertexts.size();
    std::vector<LWECiphertext> cSigns(n);
#pragma omp parallel for num_threads(OpenFHEParallelControls.GetThreadLimit(n))
    for (uint32_t i = 0; i < n; ++i)
        cSigns[i] = m_ccLWE->EvalSign(LWECiphertexts[i], true);

    return EvalFHEWtoCKKS(cSigns, numCtxts, numSlots, 4, -1.0, 1.0, 0);
}

std::vector<Ciphertext<DCRTPoly>> SWITCHCKKSRNS::EvalMinSchemeSwitching(ConstCiphertext<DCRTPoly> ciphertext,
                                                                        PublicKey<DCRTPoly> publicKey,
                                                                        uint32_t numValues, uint32_t numSlots,
                                                                        uint32_t pLWE, double scaleSign) {
    auto cc = ciphertext->GetCryptoContext();

    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersRNS>(ciphertext->GetCryptoParameters());

    // The precomputation has already been performed, but if it is scaled differently than desired, recompute it
    if (pLWE != 0)
        EvalCKKStoFHEWPrecompute(*cc, scaleSign / pLWE);

    uint32_t towersToDrop = 12;  // How many levels are consumed in the EvalFHEWtoCKKS

    uint32_t slots = (numSlots == 0) ? m_numSlotsCKKS : numSlots;

    Plaintext pInd;
    if (m_oneHot) {
        std::vector<std::complex<double>> ind(numValues, 1.0);
        pInd = cc->MakeCKKSPackedPlaintext(ind, 1, towersToDrop, nullptr, slots);
    }
    else {
        std::vector<std::complex<double>> ind(numValues);
        std::iota(ind.begin(), ind.end(), 0);
        pInd = cc->MakeCKKSPackedPlaintext(ind, 1, towersToDrop, nullptr, slots);
    }
    auto cInd          = cc->Encrypt(publicKey, pInd);
    auto newCiphertext = ciphertext->Clone();

    for (uint32_t M = 1; M < numValues; M <<= 1) {
        const auto n = numValues / (2 * M);

        // Compute CKKS ciphertext encoding difference of the first numValues
        auto cDiff = cc->EvalSub(newCiphertext, cc->EvalAtIndex(newCiphertext, n));

        // Transform the ciphertext from CKKS to FHEW
        auto cTemp = EvalCKKStoFHEW(cDiff, n);

        // Evaluate the sign
        // We always assume for the moment that numValues is a power of 2
        std::vector<LWECiphertext> LWESign(n);
#pragma omp parallel for num_threads(OpenFHEParallelControls.GetThreadLimit(n))
        for (uint32_t j = 0; j < n; ++j)
            LWESign[j] = m_ccLWE->EvalSign(cTemp[j], true);

        // Scheme switching from FHEW to CKKS
        auto dim1    = getRatioBSGSLT(n);
        auto cSelect = EvalFHEWtoCKKS(LWESign, n, numSlots, 4, -1.0, 1.0, dim1);

        std::vector<std::complex<double>> ones(n, 1.0);
        Plaintext ptxtOnes = cc->MakeCKKSPackedPlaintext(ones, 1, 0, nullptr, slots);
        cc->EvalAddInPlace(cSelect, cc->EvalAtIndex(cc->EvalSub(ptxtOnes, cSelect), -static_cast<int32_t>(n)));

        if (M > 1) {
            for (uint32_t j = numValues / M; j < numValues; j <<= 1)
                cc->EvalAddInPlace(cSelect, cc->EvalAtIndex(cSelect, -static_cast<int32_t>(j)));
        }

        // Update the ciphertext of values and the indicator
        newCiphertext = cc->EvalMult(newCiphertext, cSelect);
        cc->EvalAddInPlace(newCiphertext, cc->EvalAtIndex(newCiphertext, n));
        if (cryptoParams->GetScalingTechnique() == FIXEDMANUAL)
            cc->ModReduceInPlace(newCiphertext);

        cInd = cc->EvalMult(cInd, cSelect);
        if (cryptoParams->GetScalingTechnique() == FIXEDMANUAL)
            cc->ModReduceInPlace(cInd);
    }

    // After computing the minimum and argument
    if (!m_oneHot)
        cInd = cc->EvalSum(cInd, numValues);

    return std::vector<Ciphertext<DCRTPoly>>{newCiphertext, cInd};
}

std::vector<Ciphertext<DCRTPoly>> SWITCHCKKSRNS::EvalMinSchemeSwitchingAlt(ConstCiphertext<DCRTPoly> ciphertext,
                                                                           PublicKey<DCRTPoly> publicKey,
                                                                           uint32_t numValues, uint32_t numSlots,
                                                                           uint32_t pLWE, double scaleSign) {
    auto cc = ciphertext->GetCryptoContext();

    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersRNS>(ciphertext->GetCryptoParameters());

    // The precomputation has already been performed, but if it is scaled differently than desired, recompute it
    if (pLWE != 0)
        EvalCKKStoFHEWPrecompute(*cc, scaleSign / pLWE);

    uint32_t towersToDrop = 12;  // How many levels are consumed in the EvalFHEWtoCKKS, for binary FHEW output.

    uint32_t slots = (numSlots == 0) ? m_numSlotsCKKS : numSlots;

    Plaintext pInd;
    if (m_oneHot) {
        std::vector<std::complex<double>> ind(numValues, 1.0);
        pInd = cc->MakeCKKSPackedPlaintext(ind, 1, towersToDrop, nullptr, slots);
    }
    else {
        std::vector<std::complex<double>> ind(numValues);
        std::iota(ind.begin(), ind.end(), 0);
        pInd = cc->MakeCKKSPackedPlaintext(ind, 1, towersToDrop, nullptr, slots);
    }
    auto cInd          = cc->Encrypt(publicKey, pInd);
    auto newCiphertext = ciphertext->Clone();

    for (uint32_t M = 1; M < numValues; M <<= 1) {
        const auto n = numValues / (2 * M);

        // Compute CKKS ciphertext encoding difference of the first numValues
        auto cDiff = cc->EvalSub(newCiphertext, cc->EvalAtIndex(newCiphertext, n));

        // Transform the ciphertext from CKKS to FHEW
        auto cTemp = EvalCKKStoFHEW(cDiff, n);

        // Evaluate the sign
        // We always assume for the moment that numValues is a power of 2
        std::vector<LWECiphertext> LWESign(numValues);
#pragma omp parallel for num_threads(OpenFHEParallelControls.GetThreadLimit(n))
        for (uint32_t j = 0; j < n; ++j) {
            LWECiphertext tempSign    = m_ccLWE->EvalSign(cTemp[j], true);
            LWECiphertext negTempSign = std::make_shared<LWECiphertextImpl>(*tempSign);
            m_ccLWE->GetLWEScheme()->EvalAddConstEq(negTempSign, negTempSign->GetModulus() >> 1);  // "negated" tempSign
            for (uint32_t i = 0; i < 2 * M; i += 2) {
                LWESign[i * n + j]       = tempSign;
                LWESign[(i + 1) * n + j] = negTempSign;
            }
        }

        // Scheme switching from FHEW to CKKS
        auto dim1          = getRatioBSGSLT(numValues);
        auto cExpandSelect = EvalFHEWtoCKKS(LWESign, numValues, numSlots, 4, -1.0, 1.0, dim1);

        // Update the ciphertext of values and the indicator
        newCiphertext = cc->EvalMult(newCiphertext, cExpandSelect);
        cc->EvalAddInPlace(newCiphertext, cc->EvalAtIndex(newCiphertext, n));

        if (cryptoParams->GetScalingTechnique() == FIXEDMANUAL)
            cc->ModReduceInPlace(newCiphertext);

        cInd = cc->EvalMult(cInd, cExpandSelect);
        if (cryptoParams->GetScalingTechnique() == FIXEDMANUAL)
            cc->ModReduceInPlace(cInd);
    }

    // After computing the minimum and argument
    if (!m_oneHot)
        cInd = cc->EvalSum(cInd, numValues);

    return std::vector<Ciphertext<DCRTPoly>>{newCiphertext, cInd};
}

// TODO: used anywhere?
std::vector<Ciphertext<DCRTPoly>> SWITCHCKKSRNS::EvalMaxSchemeSwitching(ConstCiphertext<DCRTPoly> ciphertext,
                                                                        PublicKey<DCRTPoly> publicKey,
                                                                        uint32_t numValues, uint32_t numSlots,
                                                                        uint32_t pLWE, double scaleSign) {
    auto cc = ciphertext->GetCryptoContext();

    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersRNS>(ciphertext->GetCryptoParameters());

    // The precomputation has already been performed, but if it is scaled differently than desired, recompute it
    if (pLWE != 0)
        EvalCKKStoFHEWPrecompute(*cc, scaleSign / pLWE);

    uint32_t towersToDrop = 12;  // How many levels are consumed in the EvalFHEWtoCKKS, for binary FHEW output.

    uint32_t slots = (numSlots == 0) ? m_numSlotsCKKS : numSlots;

    Plaintext pInd;
    if (m_oneHot) {
        std::vector<std::complex<double>> ind(numValues, 1.0);
        pInd = cc->MakeCKKSPackedPlaintext(ind, 1, towersToDrop, nullptr, slots);
    }
    else {
        std::vector<std::complex<double>> ind(numValues);
        std::iota(ind.begin(), ind.end(), 0);
        pInd = cc->MakeCKKSPackedPlaintext(ind, 1, towersToDrop, nullptr, slots);
    }
    auto cInd          = cc->Encrypt(publicKey, pInd);
    auto newCiphertext = ciphertext->Clone();

    for (uint32_t M = 1; M < numValues; M <<= 1) {
        const auto n = numValues / (2 * M);

        // Compute CKKS ciphertext encoding difference of the first numValues
        auto cDiff = cc->EvalSub(newCiphertext, cc->EvalAtIndex(newCiphertext, n));

        // Transform the ciphertext from CKKS to FHEW
        auto cTemp = EvalCKKStoFHEW(cDiff, n);

        // Evaluate the sign
        // We always assume for the moment that numValues is a power of 2
        std::vector<LWECiphertext> LWESign(n);
#pragma omp parallel for num_threads(OpenFHEParallelControls.GetThreadLimit(n))
        for (uint32_t j = 0; j < n; ++j)
            LWESign[j] = m_ccLWE->EvalSign(cTemp[j], true);

        // Scheme switching from FHEW to CKKS
        auto dim1    = getRatioBSGSLT(n);
        auto cSelect = EvalFHEWtoCKKS(LWESign, n, numSlots, 4, -1.0, 1.0, dim1);

        std::vector<std::complex<double>> ones(n, 1.0);
        Plaintext ptxtOnes = cc->MakeCKKSPackedPlaintext(ones, 1, 0, nullptr, slots);
        cSelect = cc->EvalAdd(cc->EvalSub(ptxtOnes, cSelect), cc->EvalAtIndex(cSelect, -static_cast<int32_t>(n)));

        if (M > 1) {
            for (uint32_t j = numValues / M; j < numValues; j <<= 1)
                cc->EvalAddInPlace(cSelect, cc->EvalAtIndex(cSelect, -static_cast<int32_t>(j)));
        }

        // Update the ciphertext of values and the indicator
        newCiphertext = cc->EvalMult(newCiphertext, cSelect);
        cc->EvalAddInPlace(newCiphertext, cc->EvalAtIndex(newCiphertext, n));

        if (cryptoParams->GetScalingTechnique() == FIXEDMANUAL)
            cc->ModReduceInPlace(newCiphertext);

        cInd = cc->EvalMult(cInd, cSelect);
        if (cryptoParams->GetScalingTechnique() == FIXEDMANUAL)
            cc->ModReduceInPlace(cInd);
    }

    // After computing the minimum and argument
    if (!m_oneHot)
        cInd = cc->EvalSum(cInd, numValues);

    return std::vector<Ciphertext<DCRTPoly>>{newCiphertext, cInd};
}

// TODO: used anywhere?
std::vector<Ciphertext<DCRTPoly>> SWITCHCKKSRNS::EvalMaxSchemeSwitchingAlt(ConstCiphertext<DCRTPoly> ciphertext,
                                                                           PublicKey<DCRTPoly> publicKey,
                                                                           uint32_t numValues, uint32_t numSlots,
                                                                           uint32_t pLWE, double scaleSign) {
    auto cc = ciphertext->GetCryptoContext();

    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersRNS>(ciphertext->GetCryptoParameters());

    // The precomputation has already been performed, but if it is scaled differently than desired, recompute it
    if (pLWE != 0)
        EvalCKKStoFHEWPrecompute(*cc, scaleSign / pLWE);

    uint32_t towersToDrop = 12;  // How many levels are consumed in the EvalFHEWtoCKKS, for binary FHEW output

    uint32_t slots = (numSlots == 0) ? m_numSlotsCKKS : numSlots;

    Plaintext pInd;
    if (m_oneHot) {
        std::vector<std::complex<double>> ind(numValues, 1.0);
        pInd = cc->MakeCKKSPackedPlaintext(ind, 1, towersToDrop, nullptr, slots);
    }
    else {
        std::vector<std::complex<double>> ind(numValues);
        std::iota(ind.begin(), ind.end(), 0);
        pInd = cc->MakeCKKSPackedPlaintext(ind, 1, towersToDrop, nullptr, slots);
    }
    auto cInd          = cc->Encrypt(publicKey, pInd);
    auto newCiphertext = ciphertext->Clone();

    for (uint32_t M = 1; M < numValues; M <<= 1) {
        const auto n = numValues / (2 * M);

        // Compute CKKS ciphertext encoding difference of the first numValues
        auto cDiff = cc->EvalSub(newCiphertext, cc->EvalAtIndex(newCiphertext, n));

        // Transform the ciphertext from CKKS to FHEW
        auto cTemp = EvalCKKStoFHEW(cDiff, n);

        // Evaluate the sign
        // We always assume for the moment that numValues is a power of 2
        std::vector<LWECiphertext> LWESign(numValues);
#pragma omp parallel for num_threads(OpenFHEParallelControls.GetThreadLimit(n))
        for (uint32_t j = 0; j < n; ++j) {
            LWECiphertext tempSign    = m_ccLWE->EvalSign(cTemp[j], true);
            LWECiphertext negTempSign = std::make_shared<LWECiphertextImpl>(*tempSign);
            m_ccLWE->GetLWEScheme()->EvalAddConstEq(negTempSign, negTempSign->GetModulus() >> 1);  // "negated" tempSign
            for (uint32_t i = 0; i < 2 * M; i += 2) {
                LWESign[i * n + j]       = negTempSign;
                LWESign[(i + 1) * n + j] = tempSign;
            }
        }

        // Scheme switching from FHEW to CKKS
        auto dim1          = getRatioBSGSLT(numValues);
        auto cExpandSelect = EvalFHEWtoCKKS(LWESign, numValues, numSlots, 4, -1.0, 1.0, dim1);

        // Update the ciphertext of values and the indicator
        newCiphertext = cc->EvalMult(newCiphertext, cExpandSelect);
        cc->EvalAddInPlace(newCiphertext, cc->EvalAtIndex(newCiphertext, n));

        if (cryptoParams->GetScalingTechnique() == FIXEDMANUAL)
            cc->ModReduceInPlace(newCiphertext);

        cInd = cc->EvalMult(cInd, cExpandSelect);
        if (cryptoParams->GetScalingTechnique() == FIXEDMANUAL)
            cc->ModReduceInPlace(cInd);
    }

    // After computing the minimum and argument
    if (!m_oneHot)
        cInd = cc->EvalSum(cInd, numValues);

    return std::vector<Ciphertext<DCRTPoly>>{newCiphertext, cInd};
}

}  // namespace lbcrypto