Program Listing for File ckksrns-leveledshe.cpp

Return to documentation for file (pke/lib/scheme/ckksrns/ckksrns-leveledshe.cpp)

//==================================================================================
// BSD 2-Clause License
//
// Copyright (c) 2014-2025, NJIT, Duality Technologies Inc. and other contributors
//
// All rights reserved.
//
// Author TPOC: contact@openfhe.org
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
//    list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
//    this list of conditions and the following disclaimer in the documentation
//    and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//==================================================================================

/*
CKKS implementation. See https://eprint.iacr.org/2020/1118 for details.
 */

#include "cryptocontext.h"
#include "math/hal/basicint.h"
#include "scheme/ckksrns/ckksrns-cryptoparameters.h"
#include "scheme/ckksrns/ckksrns-leveledshe.h"
#include "schemebase/base-scheme.h"

#include <algorithm>
#include <map>
#include <memory>
#include <utility>
#include <vector>

namespace lbcrypto {

// SHE ADDITION CONSTANT

// Out-of-place addition of a real constant: clones the input ciphertext and applies
// EvalAddInPlace to the clone, so the argument itself is left unchanged.
//   ciphertext - input ciphertext
//   operand    - real constant to add
//   returns    - the modified clone
Ciphertext<DCRTPoly> LeveledSHECKKSRNS::EvalAdd(ConstCiphertext<DCRTPoly>& ciphertext, double operand) const {
    Ciphertext<DCRTPoly> sum = ciphertext->Clone();
    EvalAddInPlace(sum, operand);
    return sum;
}

// In-place addition of a real constant. The per-tower representation of the scaled
// constant comes from GetElementForEvalAddOrSub; entry t is added to tower t of the
// first ciphertext element only (element 0). Towers are processed in an OpenMP
// parallel loop.
void LeveledSHECKKSRNS::EvalAddInPlace(Ciphertext<DCRTPoly>& ciphertext, double operand) const {
    auto towerConstants = GetElementForEvalAddOrSub(ciphertext, operand);
    auto& towers        = ciphertext->GetElements()[0].GetAllElements();

    const uint32_t numTowers = towers.size();
#pragma omp parallel for num_threads(OpenFHEParallelControls.GetThreadLimit(numTowers))
    for (uint32_t t = 0; t < numTowers; ++t)
        towers[t] += towerConstants[t];
}

// Out-of-place addition of a complex constant: clones the input ciphertext and applies
// the complex EvalAddInPlace overload to the clone.
//   ciphertext - input ciphertext
//   operand    - complex constant to add
//   returns    - the modified clone
Ciphertext<DCRTPoly> LeveledSHECKKSRNS::EvalAdd(ConstCiphertext<DCRTPoly>& ciphertext,
                                                std::complex<double> operand) const {
    Ciphertext<DCRTPoly> sum = ciphertext->Clone();
    EvalAddInPlace(sum, operand);
    return sum;
}

// In-place addition of a complex constant.
//
// The magnitudes |Re(operand)| and |Im(operand)| are scaled separately through
// GetElementForEvalAddOrSub. A polynomial is then built in COEFFICIENT format whose
// coefficient 0 carries the real part and whose coefficient N/2 carries the imaginary
// part (all other coefficients are zero); a part that is not strictly positive is
// stored as its negation modulo the tower modulus. The polynomial is converted to
// EVALUATION format and added to ciphertext element 0.
void LeveledSHECKKSRNS::EvalAddInPlace(Ciphertext<DCRTPoly>& ciphertext, std::complex<double> operand) const {
    auto& cv = ciphertext->GetElements();

    auto absRe = GetElementForEvalAddOrSub(ciphertext, std::fabs(operand.real()));
    auto absIm = GetElementForEvalAddOrSub(ciphertext, std::fabs(operand.imag()));

    const uint32_t ringDim = cv[0].GetLength();
    const uint32_t halfDim = ringDim >> 1;

    // The sign is applied per tower below, since only magnitudes were scaled above.
    const bool rePositive = operand.real() > 0.;
    const bool imPositive = operand.imag() > 0.;

    DCRTPoly constantPoly(cv[0].GetParams(), Format::COEFFICIENT, true);
    const uint32_t numTowers = constantPoly.GetNumOfElements();
    for (uint32_t t = 0; t < numTowers; ++t) {
        // Start from a copy of the ciphertext tower and overwrite its values.
        auto tower = cv[0].GetElementAtIndex(t);
        auto q     = tower.GetModulus();

        NativeVector coeffs(ringDim, q);
        if (rePositive)
            coeffs[0] = NativeInteger(absRe[t].Mod(q));
        else
            coeffs[0] = q.ModSub(absRe[t], q);

        if (imPositive)
            coeffs[halfDim] = NativeInteger(absIm[t].Mod(q));
        else
            coeffs[halfDim] = q.ModSub(absIm[t], q);

        tower.SetValues(std::move(coeffs), Format::COEFFICIENT);
        constantPoly.SetElementAtIndex(t, std::move(tower));
    }
    constantPoly.SetFormat(Format::EVALUATION);

    cv[0] += constantPoly;
}

// SHE SUBTRACTION CONSTANT

// Out-of-place subtraction of a real constant: clones the input ciphertext and applies
// EvalSubInPlace to the clone, so the argument itself is left unchanged.
//   ciphertext - input ciphertext
//   operand    - real constant to subtract
//   returns    - the modified clone
Ciphertext<DCRTPoly> LeveledSHECKKSRNS::EvalSub(ConstCiphertext<DCRTPoly>& ciphertext, double operand) const {
    Ciphertext<DCRTPoly> difference = ciphertext->Clone();
    EvalSubInPlace(difference, operand);
    return difference;
}

// In-place subtraction of a real constant. The per-tower representation of the scaled
// constant comes from GetElementForEvalAddOrSub; entry t is subtracted from tower t of
// the first ciphertext element only (element 0). Towers are processed in an OpenMP
// parallel loop.
void LeveledSHECKKSRNS::EvalSubInPlace(Ciphertext<DCRTPoly>& ciphertext, double operand) const {
    auto towerConstants = GetElementForEvalAddOrSub(ciphertext, operand);
    auto& towers        = ciphertext->GetElements()[0].GetAllElements();

    const uint32_t numTowers = towers.size();
#pragma omp parallel for num_threads(OpenFHEParallelControls.GetThreadLimit(numTowers))
    for (uint32_t t = 0; t < numTowers; ++t)
        towers[t] -= towerConstants[t];
}

// SHE MULTIPLICATION

// Out-of-place multiplication by a real constant: clones the input ciphertext and
// applies EvalMultInPlace to the clone.
//   ciphertext - input ciphertext
//   operand    - real constant to multiply by
//   returns    - the modified clone
Ciphertext<DCRTPoly> LeveledSHECKKSRNS::EvalMult(ConstCiphertext<DCRTPoly>& ciphertext, double operand) const {
    Ciphertext<DCRTPoly> product = ciphertext->Clone();
    EvalMultInPlace(product, operand);
    return product;
}

// In-place multiplication by a real constant.
// Unless the scaling technique is FIXEDMANUAL, a ciphertext whose noise-scale degree
// is 2 is first mod-reduced by GetCompositeDegree() levels; the multiplication itself
// is done by EvalMultCoreInPlace.
void LeveledSHECKKSRNS::EvalMultInPlace(Ciphertext<DCRTPoly>& ciphertext, double operand) const {
    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(ciphertext->GetCryptoParameters());

    const bool needsModReduce =
        cryptoParams->GetScalingTechnique() != FIXEDMANUAL && ciphertext->GetNoiseScaleDeg() == 2;
    if (needsModReduce)
        ModReduceInternalInPlace(ciphertext, cryptoParams->GetCompositeDegree());

    EvalMultCoreInPlace(ciphertext, operand);
}

// Out-of-place multiplication by a complex constant: clones the input ciphertext and
// applies the complex EvalMultInPlace overload to the clone.
//   ciphertext - input ciphertext
//   operand    - complex constant to multiply by
//   returns    - the modified clone
Ciphertext<DCRTPoly> LeveledSHECKKSRNS::EvalMult(ConstCiphertext<DCRTPoly>& ciphertext,
                                                 std::complex<double> operand) const {
    Ciphertext<DCRTPoly> product = ciphertext->Clone();
    EvalMultInPlace(product, operand);
    return product;
}

// In-place multiplication by a complex constant.
// Unless the scaling technique is FIXEDMANUAL, a ciphertext whose noise-scale degree
// is 2 is first mod-reduced by GetCompositeDegree() levels; the multiplication itself
// is done by the complex EvalMultCoreInPlace overload.
void LeveledSHECKKSRNS::EvalMultInPlace(Ciphertext<DCRTPoly>& ciphertext, std::complex<double> operand) const {
    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(ciphertext->GetCryptoParameters());

    const bool needsModReduce =
        cryptoParams->GetScalingTechnique() != FIXEDMANUAL && ciphertext->GetNoiseScaleDeg() == 2;
    if (needsModReduce)
        ModReduceInternalInPlace(ciphertext, cryptoParams->GetCompositeDegree());

    EvalMultCoreInPlace(ciphertext, operand);
}

// In-place ciphertext-by-plaintext multiplication.
// The multiplication is delegated to LeveledSHERNS::EvalMultInPlace; afterwards, for
// every scaling technique except NORESCALE, the ciphertext's scaling factor is replaced
// by its own square.
// NOTE(review): the squared value is the ciphertext's scaling factor as it stands after
// the base-class call, not the plaintext's. Presumably the base call has already
// aligned the two operands to the same scaling factor - confirm against LeveledSHERNS.
void LeveledSHECKKSRNS::EvalMultInPlace(Ciphertext<DCRTPoly>& ciphertext, ConstPlaintext& plaintext) const {
    LeveledSHERNS::EvalMultInPlace(ciphertext, plaintext);

    const auto params = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(ciphertext->GetCryptoParameters());
    if (params->GetScalingTechnique() == NORESCALE)
        return;

    const double scalingFactor = ciphertext->GetScalingFactor();
    ciphertext->SetScalingFactor(scalingFactor * scalingFactor);
}

// SHE MULTIPLICATION PLAINTEXT

// Mod Reduce

// Rescales the ciphertext in place by dropping `levels` RNS towers.
//
// Bookkeeping done here:
//   - noise-scale degree is decreased by levels / GetCompositeDegree() (integer division),
//   - level is increased by `levels`,
//   - for every dropped tower, each ciphertext element is passed through
//     DropLastElementAndScale and the scaling factor is divided by the matching
//     GetModReduceFactor entry.
// The precomputed tables are indexed by the number of towers already missing from the
// full modulus chain (sizeQ - sizeQl) plus the current step.
void LeveledSHECKKSRNS::ModReduceInternalInPlace(Ciphertext<DCRTPoly>& ciphertext, size_t levels) const {
    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(ciphertext->GetCryptoParameters());
    auto& cv                = ciphertext->GetElements();

    const size_t fullChainLen   = cryptoParams->GetElementParams()->GetParams().size();
    const size_t currentLen     = cv[0].GetNumOfElements();
    const size_t alreadyDropped = fullChainLen - currentLen;

    ciphertext->SetNoiseScaleDeg(ciphertext->GetNoiseScaleDeg() - levels / cryptoParams->GetCompositeDegree());
    ciphertext->SetLevel(ciphertext->GetLevel() + levels);

    for (size_t step = 0; step < levels; ++step) {
        const size_t tableIdx = alreadyDropped + step;
        for (auto& poly : cv)
            poly.DropLastElementAndScale(cryptoParams->GetQlQlInvModqlDivqlModq(tableIdx),
                                         cryptoParams->GetqlInvModq(tableIdx));

        // currentLen was captured before any tower was dropped, so this walks the
        // dropped towers from the last one backwards.
        const double droppedFactor = cryptoParams->GetModReduceFactor(currentLen - 1 - step);
        ciphertext->SetScalingFactor(ciphertext->GetScalingFactor() / droppedFactor);
    }
}

// Level Reduce

// Drops the last `levels` towers from every ciphertext element (no rescaling is done
// here) and increases the ciphertext level by `levels`.
void LeveledSHECKKSRNS::LevelReduceInternalInPlace(Ciphertext<DCRTPoly>& ciphertext, size_t levels) const {
    auto& cv = ciphertext->GetElements();
    for (size_t k = 0; k < cv.size(); ++k)
        cv[k].DropLastElements(levels);

    ciphertext->SetLevel(ciphertext->GetLevel() + levels);
}

// Compress

// CKKS Core

#if NATIVEINT == 128
// NATIVEINT == 128 variant.
// Builds, for every RNS tower of the ciphertext, the integer that represents `operand`
// at the ciphertext's scale: the constant is scaled to roughly operand * 2^p (p being
// GetPlaintextModulus()) and then multiplied by 2^p another (noise-scale degree - 1)
// times through CKKSPackedEncoding::CRTMult.
//   ciphertext - supplies the tower moduli and the noise-scale degree
//   operand    - real constant; it is converted through an unsigned 128-bit value, so
//                it is presumably expected to be non-negative (TODO confirm with callers)
//   returns    - one integer per tower
std::vector<DCRTPoly::Integer> LeveledSHECKKSRNS::GetElementForEvalAddOrSub(ConstCiphertext<DCRTPoly>& ciphertext,
                                                                            double operand) const {
    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(ciphertext->GetCryptoParameters());

    const uint32_t precision = 52;
    const double powP        = std::pow(2, precision);

    const auto& cv           = ciphertext->GetElements();
    const uint32_t numTowers = cv[0].GetNumOfElements();
    std::vector<DCRTPoly::Integer> moduli(numTowers);

    for (uint32_t i = 0; i < numTowers; ++i)
        moduli[i] = cv[0].GetElementAtIndex(i).GetModulus();

    // the idea is to break down real numbers
    // expressed as input_mantissa * 2^input_exponent
    // into (input_mantissa * 2^52) * 2^(p - 52 + input_exponent)
    // to preserve 52-bit precision of doubles
    // when converting to 128-bit numbers
    int32_t n1             = 0;
    const int64_t scaled64 = std::llround(static_cast<double>(std::frexp(operand, &n1)) * powP);

    const int32_t pCurrent   = cryptoParams->GetPlaintextModulus() - precision;
    const int32_t pRemaining = pCurrent + n1;

    DCRTPoly::Integer scaledConstant;
    if (pRemaining < 0) {
        // Shifting a 128-bit value by 128 or more bits is undefined behaviour; for such
        // tiny operands every bit is shifted out, so the scaled constant is zero.
        const int32_t shift = -pRemaining;
        uint128_t shifted   = 0;
        if (shift < 128)
            shifted = static_cast<uint128_t>(scaled64) >> shift;
        scaledConstant = NativeInteger(shifted);
    }
    else {
        const int128_t ppRemaining = static_cast<int128_t>(1) << pRemaining;
        scaledConstant             = NativeInteger(static_cast<int128_t>(scaled64) * ppRemaining);
    }

    DCRTPoly::Integer intPowP;
    const uint64_t powp64 = static_cast<uint64_t>(1) << precision;
    if (pCurrent < 0) {
        intPowP = NativeInteger(static_cast<uint128_t>(powp64) >> (-pCurrent));
    }
    else {
        intPowP = NativeInteger(static_cast<uint128_t>(powp64) << pCurrent);
    }

    std::vector<DCRTPoly::Integer> crtPowP(numTowers, intPowP);
    std::vector<DCRTPoly::Integer> currPowP(numTowers, scaledConstant);

    // multiply c*powP with powP a total of (depth-1) times to get c*powP^d.
    // Counting from 1 (as the 64-bit variant does) avoids the unsigned wrap-around of
    // "depth - 1" when the noise-scale degree is 0.
    const auto noiseScaleDeg = ciphertext->GetNoiseScaleDeg();
    for (uint32_t i = 1; i < noiseScaleDeg; ++i)
        currPowP = CKKSPackedEncoding::CRTMult(currPowP, crtPowP, moduli);
    return currPowP;
}
#else  // NATIVEINT == 64
// NATIVEINT == 64 variant.
// Builds, for every RNS tower of the ciphertext, the integer that represents `operand`
// at the ciphertext's scale. Outline:
//   1. pick the scaling factor for the ciphertext's level (the "Big" one for
//      FLEXIBLEAUTOEXT at level 0);
//   2. if operand * scFactor needs more than MAX_BITS_IN_WORD bits, divide by a power of
//      two (approxFactor) before rounding to an integer and multiply that power back in
//      through CRTMult, in chunks of at most MAX_LOG_STEP bits;
//   3. multiply by the scaling factor once per additional noise-scale degree (skipped
//      for FLEXIBLEAUTOEXT at level 0).
//   ciphertext - supplies tower moduli, level and noise-scale degree
//   operand    - real constant to encode
//   returns    - one integer per tower
// NOTE(review): the scaled value is cast to uint64_t, so the operand is presumably
// expected to be non-negative - confirm with callers.
std::vector<DCRTPoly::Integer> LeveledSHECKKSRNS::GetElementForEvalAddOrSub(ConstCiphertext<DCRTPoly>& ciphertext,
                                                                            double operand) const {
    const auto& polys     = ciphertext->GetElements()[0].GetAllElements();
    const uint32_t sizeQl = polys.size();
    std::vector<DCRTPoly::Integer> moduli(sizeQl);
    for (uint32_t i = 0; i < sizeQl; ++i)
        moduli[i] = polys[i].GetModulus();

    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(ciphertext->GetCryptoParameters());

    double scFactor = 0;
    if (cryptoParams->GetScalingTechnique() == FLEXIBLEAUTOEXT && ciphertext->GetLevel() == 0) {
        scFactor = cryptoParams->GetScalingFactorRealBig(ciphertext->GetLevel());
    }
    else {
        scFactor = cryptoParams->GetScalingFactorReal(ciphertext->GetLevel());
    }

    // Compute approxFactor, a value to scale down by, in case the value exceeds a 64-bit integer.

    // The "res > 0" guard below was added because the code crashes when res is ZERO and it is
    // linked with clang++ in Debug mode with the following flags:
    // -O2
    // -g
    // -fsanitize-trap=all
    // -fsanitize=alignment,return,returns-nonnull-attribute,vla-bound,unreachable,float-cast-overflow
    // -fsanitize=null
    // -gz=zlib
    // -fno-asynchronous-unwind-tables
    // -fno-optimize-sibling-calls
    // -fsplit-dwarf-inlining
    // -gsimple-template-names
    // -gsplit-dwarf
    int32_t logApprox = 0;
    // Duhyeong: We need to take into account the 64-bit overflow for both operand * scFactor and scFactor
    double res = std::fabs(operand * scFactor);
    if (cryptoParams->GetScalingTechnique() == COMPOSITESCALINGAUTO ||
        cryptoParams->GetScalingTechnique() == COMPOSITESCALINGMANUAL) {
        res = std::max(res, std::fabs(scFactor));
    }
    if (res > 0) {
        // logApprox = number of bits by which res exceeds MAX_BITS_IN_WORD (0 if it fits)
        int32_t logSF    = static_cast<int32_t>(std::ceil(std::log2(res)));
        int32_t logValid = (logSF <= LargeScalingFactorConstants::MAX_BITS_IN_WORD) ?
                               logSF :
                               LargeScalingFactorConstants::MAX_BITS_IN_WORD;
        logApprox        = logSF - logValid;
    }
    // logApprox is consumed by the loop below; keep a copy for the composite-scaling path.
    int32_t logApprox_cp = logApprox;
    double approxFactor  = std::pow(2, logApprox);

    // Round-to-nearest of the scaled-down constant (adding 0.5 before truncation).
    DCRTPoly::Integer scConstant = static_cast<uint64_t>(operand * scFactor / approxFactor + 0.5);
    std::vector<DCRTPoly::Integer> crtConstant(sizeQl, scConstant);

    // Scale back up by approxFactor within the CRT multiplications.
    if (logApprox > 0) {
        int32_t logStep = (logApprox <= LargeScalingFactorConstants::MAX_LOG_STEP) ?
                              logApprox :
                              LargeScalingFactorConstants::MAX_LOG_STEP;
        auto intStep    = static_cast<DCRTPoly::Integer>(1) << logStep;
        std::vector<DCRTPoly::Integer> crtApprox(sizeQl, intStep);
        logApprox -= logStep;

        // Accumulate 2^logApprox in crtApprox, at most MAX_LOG_STEP bits per iteration.
        // (logStep / intStep here intentionally shadow the outer ones.)
        while (logApprox > 0) {
            int32_t logStep = (logApprox <= LargeScalingFactorConstants::MAX_LOG_STEP) ?
                                  logApprox :
                                  LargeScalingFactorConstants::MAX_LOG_STEP;
            auto intStep    = static_cast<DCRTPoly::Integer>(1) << logStep;
            std::vector<DCRTPoly::Integer> crtSF(sizeQl, intStep);
            crtApprox = CKKSPackedEncoding::CRTMult(crtApprox, crtSF, moduli);
            logApprox -= logStep;
        }
        crtConstant = CKKSPackedEncoding::CRTMult(crtConstant, crtApprox, moduli);
    }

    // In FLEXIBLEAUTOEXT mode at level 0, we don't use the depth to calculate the scaling factor,
    // so we return the value before taking the depth into account.
    if (cryptoParams->GetScalingTechnique() == FLEXIBLEAUTOEXT && ciphertext->GetLevel() == 0) {
        return crtConstant;
    }

    // COMPOSITESCALING support to 128-bit scaling factor
    if (cryptoParams->GetScalingTechnique() == COMPOSITESCALINGAUTO ||
        cryptoParams->GetScalingTechnique() == COMPOSITESCALINGMANUAL) {
        // In these modes res >= |scFactor| (see the std::max above).
        int32_t logSF_cp = static_cast<int32_t>(std::ceil(std::log2(res)));
        if (logSF_cp < 64) {
            // The scaling factor is converted to a 64-bit integer directly.
            DCRTPoly::Integer intScFactor = static_cast<uint64_t>(scFactor + 0.5);
            std::vector<DCRTPoly::Integer> crtScFactor(sizeQl, intScFactor);
            for (uint32_t i = 1; i < ciphertext->GetNoiseScaleDeg(); ++i)
                crtConstant = CKKSPackedEncoding::CRTMult(crtConstant, crtScFactor, moduli);
        }
        else {
            // Multiply scFactor in two steps: scFactor / approxFactor and then approxFactor
            DCRTPoly::Integer intScFactor = static_cast<uint64_t>(scFactor / approxFactor + 0.5);
            std::vector<DCRTPoly::Integer> crtScFactor(sizeQl, intScFactor);
            for (uint32_t i = 1; i < ciphertext->GetNoiseScaleDeg(); ++i)
                crtConstant = CKKSPackedEncoding::CRTMult(crtConstant, crtScFactor, moduli);
            if (logApprox_cp > 0) {
                // Rebuild 2^logApprox in CRT form (same chunking as above) from the saved copy.
                int32_t logStep = (logApprox_cp <= LargeScalingFactorConstants::MAX_LOG_STEP) ?
                                      logApprox_cp :
                                      LargeScalingFactorConstants::MAX_LOG_STEP;
                auto intStep    = DCRTPoly::Integer(1) << logStep;
                std::vector<DCRTPoly::Integer> crtApprox(sizeQl, intStep);
                logApprox_cp -= logStep;

                while (logApprox_cp > 0) {
                    int32_t logStep = (logApprox_cp <= LargeScalingFactorConstants::MAX_LOG_STEP) ?
                                          logApprox_cp :
                                          LargeScalingFactorConstants::MAX_LOG_STEP;
                    auto intStep    = DCRTPoly::Integer(1) << logStep;
                    std::vector<DCRTPoly::Integer> crtSF(sizeQl, intStep);
                    crtApprox = CKKSPackedEncoding::CRTMult(crtApprox, crtSF, moduli);
                    logApprox_cp -= logStep;
                }
                // Apply the approxFactor part once per additional noise-scale degree.
                for (uint32_t i = 1; i < ciphertext->GetNoiseScaleDeg(); ++i)
                    crtConstant = CKKSPackedEncoding::CRTMult(crtConstant, crtApprox, moduli);
            }
        }
    }
    else {
        // Non-composite modes: multiply by the rounded scaling factor (degree - 1) times.
        DCRTPoly::Integer intScFactor = static_cast<uint64_t>(scFactor + 0.5);
        std::vector<DCRTPoly::Integer> crtScFactor(sizeQl, intScFactor);
        for (uint32_t i = 1; i < ciphertext->GetNoiseScaleDeg(); ++i)
            crtConstant = CKKSPackedEncoding::CRTMult(crtConstant, crtScFactor, moduli);
    }

    return crtConstant;
}
#endif

#if NATIVEINT == 128
// NATIVEINT == 128 variant.
// Builds, for every RNS tower of the ciphertext, the residue of the signed integer that
// represents `operand` scaled by roughly 2^p (p being GetPlaintextModulus()). A negative
// value is stored as modulus minus the reduced magnitude.
//   ciphertext - supplies the tower moduli
//   operand    - real constant (may be negative)
//   returns    - one multiplication factor per tower
std::vector<DCRTPoly::Integer> LeveledSHECKKSRNS::GetElementForEvalMult(ConstCiphertext<DCRTPoly>& ciphertext,
                                                                        double operand) const {
    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(ciphertext->GetCryptoParameters());

    const uint32_t precision = 52;
    const double powP        = std::pow(2, precision);

    // the idea is to break down real numbers
    // expressed as input_mantissa * 2^input_exponent
    // into (input_mantissa * 2^52) * 2^(p - 52 + input_exponent)
    // to preserve 52-bit precision of doubles
    // when converting to 128-bit numbers
    int32_t n1               = 0;
    const int64_t scaled64   = std::llround(static_cast<double>(std::frexp(operand, &n1)) * powP);
    const int32_t pCurrent   = cryptoParams->GetPlaintextModulus() - precision;
    const int32_t pRemaining = pCurrent + n1;
    int128_t scaled128       = 0;

    if (pRemaining < 0) {
        // Shifting a 64-bit value by 64 or more bits is undefined behaviour (on x86 the
        // count is typically taken modulo 64, which yields a wildly wrong factor for very
        // small operands). Saturate to what a sign-extending shift converges to instead.
        const int32_t shift = -pRemaining;
        if (shift < 64)
            scaled128 = scaled64 >> shift;
        else
            scaled128 = (scaled64 < 0) ? -1 : 0;
    }
    else {
        const int128_t ppRemaining = static_cast<int128_t>(1) << pRemaining;
        scaled128                  = ppRemaining * scaled64;
    }

    const auto& cv           = ciphertext->GetElements();
    const uint32_t numTowers = cv[0].GetNumOfElements();
    std::vector<DCRTPoly::Integer> factors(numTowers);

    for (uint32_t i = 0; i < numTowers; i++) {
        DCRTPoly::Integer modulus = cv[0].GetElementAtIndex(i).GetModulus();

        if (scaled128 < 0) {
            // negative constant: reduce the magnitude, then negate modulo the tower modulus
            DCRTPoly::Integer reducedUnsigned = static_cast<BasicInteger>(-scaled128);
            reducedUnsigned.ModEq(modulus);
            factors[i] = modulus - reducedUnsigned;
        }
        else {
            DCRTPoly::Integer reducedUnsigned = static_cast<BasicInteger>(scaled128);
            reducedUnsigned.ModEq(modulus);
            factors[i] = reducedUnsigned;
        }
    }
    return factors;
}
#else  // NATIVEINT == 64
// NATIVEINT == 64 variant.
// Builds, for every RNS tower of the ciphertext, the residue of the signed integer that
// represents operand * scFactor, where scFactor is GetScalingFactorReal() at the
// ciphertext's level. Outline:
//   1. if |operand * scFactor| needs more bits than the working integer type allows
//      (125 with HAVE_INT128, MAX_BITS_IN_WORD otherwise), divide by a power of two
//      (approxFactor) before converting to an integer;
//   2. reduce the signed integer modulo every tower modulus, mapping negative
//      remainders into [0, modulus);
//   3. multiply the power of two back in through CRTMult, in chunks of at most
//      MAX_LOG_STEP bits.
//   ciphertext - supplies tower moduli and level
//   operand    - real constant (may be negative)
//   returns    - one multiplication factor per tower
std::vector<DCRTPoly::Integer> LeveledSHECKKSRNS::GetElementForEvalMult(ConstCiphertext<DCRTPoly>& ciphertext,
                                                                        double operand) const {
    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(ciphertext->GetCryptoParameters());

    const std::vector<DCRTPoly>& cv = ciphertext->GetElements();
    uint32_t numTowers              = cv[0].GetNumOfElements();
    std::vector<DCRTPoly::Integer> moduli(numTowers);
    for (uint32_t i = 0; i < numTowers; ++i)
        moduli[i] = cv[0].GetElementAtIndex(i).GetModulus();

    double scFactor = cryptoParams->GetScalingFactorReal(ciphertext->GetLevel());

    // Working integer type for the scaled constant and the bit budget that goes with it.
    #if defined(HAVE_INT128)
    typedef int128_t DoubleInteger;
    int32_t MAX_BITS_IN_WORD_LOCAL = 125;
    #else
    typedef int64_t DoubleInteger;
    int32_t MAX_BITS_IN_WORD_LOCAL = LargeScalingFactorConstants::MAX_BITS_IN_WORD;
    #endif

    // Compute approxFactor, a value to scale down by, in case the value exceeds a 64-bit integer.

    // The "res > 0" guard below was added because the code crashes when res is ZERO and it is
    // linked with clang++ in Debug mode with the following flags:
    // -O2
    // -g
    // -fsanitize-trap=all
    // -fsanitize=alignment,return,returns-nonnull-attribute,vla-bound,unreachable,float-cast-overflow
    // -fsanitize=null
    // -gz=zlib
    // -fno-asynchronous-unwind-tables
    // -fno-optimize-sibling-calls
    // -fsplit-dwarf-inlining
    // -gsimple-template-names
    // -gsplit-dwarf
    int32_t logApprox = 0;
    const double res  = std::fabs(operand * scFactor);
    if (res > 0) {
        // logApprox = number of bits by which res exceeds MAX_BITS_IN_WORD_LOCAL (0 if it fits)
        int32_t logSF    = static_cast<int32_t>(std::ceil(std::log2(res)));
        int32_t logValid = (logSF <= MAX_BITS_IN_WORD_LOCAL) ? logSF : MAX_BITS_IN_WORD_LOCAL;
        logApprox        = logSF - logValid;
    }
    double approxFactor = std::pow(2, logApprox);

    // Adding 0.5 before the truncating cast rounds non-negative values to nearest.
    // NOTE(review): for negative values the cast truncates toward zero, so this is not
    // round-to-nearest there - confirm this is acceptable.
    DoubleInteger large     = static_cast<DoubleInteger>(operand / approxFactor * scFactor + 0.5);
    DoubleInteger large_abs = (large < 0 ? -large : large);
    // NOTE(review): 2^63 fits DoubleInteger only when it is int128_t; with the int64_t
    // typedef this conversion does not yield a positive 2^63 - verify the non-INT128 build.
    DoubleInteger bound     = static_cast<uint64_t>(1) << 63;

    std::vector<DCRTPoly::Integer> factors(numTowers);

    if (large_abs >= bound) {
        // Value does not fit in a signed 64-bit integer: reduce in the wide type.
        for (uint32_t i = 0; i < numTowers; i++) {
            DoubleInteger reduced = large % moduli[i].ConvertToInt();

            factors[i] = (reduced < 0) ? static_cast<uint64_t>(reduced + moduli[i].ConvertToInt()) :
                                         static_cast<uint64_t>(reduced);
        }
    }
    else {
        // Value fits in int64_t: reduce with 64-bit arithmetic.
        int64_t scConstant = static_cast<int64_t>(large);
        for (uint32_t i = 0; i < numTowers; i++) {
            int64_t reduced = scConstant % static_cast<int64_t>(moduli[i].ConvertToInt());

            factors[i] = (reduced < 0) ? reduced + moduli[i].ConvertToInt() : reduced;
        }
    }

    // Scale back up by approxFactor within the CRT multiplications.
    if (logApprox > 0) {
        int32_t logStep           = (logApprox <= LargeScalingFactorConstants::MAX_LOG_STEP) ?
                                        logApprox :
                                        LargeScalingFactorConstants::MAX_LOG_STEP;
        DCRTPoly::Integer intStep = static_cast<uint64_t>(1) << logStep;
        std::vector<DCRTPoly::Integer> crtApprox(numTowers, intStep);
        logApprox -= logStep;

        // Accumulate 2^logApprox in crtApprox, at most MAX_LOG_STEP bits per iteration.
        // (logStep / intStep here intentionally shadow the outer ones.)
        while (logApprox > 0) {
            int32_t logStep = (logApprox <= LargeScalingFactorConstants::MAX_LOG_STEP) ?
                                  logApprox :
                                  LargeScalingFactorConstants::MAX_LOG_STEP;
            auto intStep    = DCRTPoly::Integer(1) << logStep;
            std::vector<DCRTPoly::Integer> crtSF(numTowers, intStep);
            crtApprox = CKKSPackedEncoding::CRTMult(crtApprox, crtSF, moduli);
            logApprox -= logStep;
        }
        factors = CKKSPackedEncoding::CRTMult(factors, crtApprox, moduli);
    }

    return factors;
}

#endif

// Rotation that consumes precomputed key-switching digits.
//
// Steps, in order:
//   1. translate the rotation index into an automorphism index and look up its key in
//      `evalKeys` (throws through OPENFHE_THROW if the key is absent);
//   2. run EvalFastKeySwitchCoreExt on `digits` with that key;
//   3. if `addFirst` is set, add ciphertext element 0 multiplied by GetPModq() to the
//      first key-switched polynomial;
//   4. apply the automorphism to both resulting polynomials.
// The result is returned in an empty clone of the input ciphertext.
// Note: index == 0 is not special-cased; it goes through the same path.
//
//   ciphertext - ciphertext being rotated
//   index      - rotation index
//   digits     - precomputed digits passed to EvalFastKeySwitchCoreExt
//   addFirst   - whether to fold the first ciphertext element into the result
//   evalKeys   - automorphism keys, keyed by automorphism index
Ciphertext<DCRTPoly> LeveledSHECKKSRNS::EvalFastRotationExt(
    ConstCiphertext<DCRTPoly>& ciphertext, uint32_t index, const std::shared_ptr<std::vector<DCRTPoly>> digits,
    bool addFirst, const std::map<uint32_t, EvalKey<DCRTPoly>>& evalKeys) const {
    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(ciphertext->GetCryptoParameters());

    // Automorphism index corresponding to the requested rotation.
    const uint32_t cyclotomicOrder = cryptoParams->GetElementParams()->GetCyclotomicOrder();
    const uint32_t autoIndex       = FindAutomorphismIndex2nComplex(index, cyclotomicOrder);

    // Matching automorphism key.
    const auto keyPos = evalKeys.find(autoIndex);
    if (keyPos == evalKeys.end())
        OPENFHE_THROW("EvalKey for index [" + std::to_string(autoIndex) + "] is not found.");
    auto& rotationKey = keyPos->second;

    const auto& cv      = ciphertext->GetElements();
    const auto paramsQl = cv[0].GetParams();

    const auto cc = ciphertext->GetCryptoContext();
    auto switched = *cc->GetScheme()->EvalFastKeySwitchCoreExt(digits, rotationKey, paramsQl);

    if (addFirst) {
        // Only the first sizeQl towers of the zero-initialized polynomial are filled in.
        DCRTPoly liftedC0(switched[0].GetParams(), Format::EVALUATION, true);
        auto scaledC0 = cv[0].TimesNoCheck(cryptoParams->GetPModq());

        const uint32_t sizeQl = paramsQl->GetParams().size();
        for (uint32_t t = 0; t < sizeQl; ++t)
            liftedC0.SetElementAtIndex(t, std::move(scaledC0.GetElementAtIndex(t)));

        switched[0] += liftedC0;
    }

    const uint32_t ringDim = cryptoParams->GetElementParams()->GetRingDimension();
    std::vector<uint32_t> autoMap(ringDim);
    PrecomputeAutoMap(ringDim, autoIndex, &autoMap);

    for (uint32_t k = 0; k < 2; ++k)
        switched[k] = switched[k].AutomorphismTransform(autoIndex, autoMap);

    auto rotated = ciphertext->CloneEmpty();
    rotated->SetElements(std::move(switched));
    return rotated;
}

// Out-of-place multiplication of every ciphertext element by an unsigned integer.
// The products are stored in an empty clone of the input; the input is not modified,
// and no scaling-factor, level or depth bookkeeping is done here.
Ciphertext<DCRTPoly> LeveledSHECKKSRNS::MultByInteger(ConstCiphertext<DCRTPoly>& ciphertext, uint64_t integer) const {
    const std::vector<DCRTPoly>& cv = ciphertext->GetElements();
    const size_t numElements        = cv.size();

    std::vector<DCRTPoly> products;
    products.reserve(numElements);
    for (size_t k = 0; k < numElements; ++k)
        products.push_back(cv[k].Times(NativeInteger(integer)));

    auto scaled = ciphertext->CloneEmpty();
    scaled->SetElements(std::move(products));
    return scaled;
}

// In-place multiplication of every ciphertext element by an unsigned integer.
// No scaling-factor, level or depth bookkeeping is done here.
void LeveledSHECKKSRNS::MultByIntegerInPlace(Ciphertext<DCRTPoly>& ciphertext, uint64_t integer) const {
    for (auto& poly : ciphertext->GetElements())
        poly = poly.Times(NativeInteger(integer));
}

// Brings two ciphertexts to a common level and scaling factor, modifying whichever one
// has the smaller GetLevel() value (levels grow as towers are dropped).
//
// - Equal levels: only the noise-scale degree is equalized, by multiplying the
//   ciphertext with the smaller degree by the constant 1.0.
// - Different levels: the ciphertext at the smaller level ("low") is multiplied by a
//   correction constant and then mod-/level-reduced until it sits at the level of the
//   other one ("high"). The exact sequence depends on whether each operand has
//   noise-scale degree 2; see the helper below. The two directions are exact mirrors of
//   each other, so a single helper is invoked with the roles swapped.
//
//   ciphertext1, ciphertext2 - operands; at most one of them is modified
void LeveledSHECKKSRNS::AdjustLevelsAndDepthInPlace(Ciphertext<DCRTPoly>& ciphertext1,
                                                    Ciphertext<DCRTPoly>& ciphertext2) const {
    const uint32_t c1lvl           = ciphertext1->GetLevel();
    const uint32_t c2lvl           = ciphertext2->GetLevel();
    const uint32_t c1depth         = ciphertext1->GetNoiseScaleDeg();
    const uint32_t c2depth         = ciphertext2->GetNoiseScaleDeg();
    const uint32_t sizeQl1         = ciphertext1->GetElements()[0].GetNumOfElements();
    const uint32_t sizeQl2         = ciphertext2->GetElements()[0].GetNumOfElements();
    const auto cryptoParams        = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(ciphertext1->GetCryptoParameters());
    const uint32_t compositeDegree = cryptoParams->GetCompositeDegree();

    // Moves `low` (level lowLvl, degree lowDepth, lowSizeQl towers) to the level of
    // `high` (level highLvl, degree highDepth). Requires lowLvl < highLvl.
    const auto raiseToLevelOf = [&](Ciphertext<DCRTPoly>& low, const Ciphertext<DCRTPoly>& high, uint32_t lowLvl,
                                    uint32_t highLvl, uint32_t lowDepth, uint32_t highDepth, uint32_t lowSizeQl) {
        // Product of the mod-reduce factors of the last compositeDegree towers of `low`,
        // i.e. what the next ModReduceInternalInPlace(low, compositeDegree) divides out.
        const auto lastTowersFactor = [&]() {
            double q = cryptoParams->GetModReduceFactor(lowSizeQl - 1);
            for (uint32_t j = 1; j < compositeDegree; ++j)
                q *= cryptoParams->GetModReduceFactor(lowSizeQl - j - 1);
            return q;
        };

        if (lowDepth == 2) {
            if (highDepth == 2) {
                const double sfLow    = low->GetScalingFactor();
                const double sfTarget = high->GetScalingFactor();
                const double sfLevel  = cryptoParams->GetScalingFactorReal(lowLvl);
                const double q        = lastTowersFactor();
                EvalMultCoreInPlace(low, sfTarget / sfLow * q / sfLevel);
                ModReduceInternalInPlace(low, compositeDegree);
                if (lowLvl + compositeDegree < highLvl)
                    LevelReduceInternalInPlace(low, highLvl - lowLvl - compositeDegree);
                low->SetScalingFactor(high->GetScalingFactor());
            }
            else if (lowLvl + compositeDegree == highLvl) {
                // One rescale away from `high`: a plain mod-reduce is enough.
                ModReduceInternalInPlace(low, compositeDegree);
            }
            else {
                const double sfLow    = low->GetScalingFactor();
                const double sfTarget = cryptoParams->GetScalingFactorRealBig(highLvl - compositeDegree);
                const double sfLevel  = cryptoParams->GetScalingFactorReal(lowLvl);
                const double q        = lastTowersFactor();
                EvalMultCoreInPlace(low, sfTarget / sfLow * q / sfLevel);
                ModReduceInternalInPlace(low, compositeDegree);
                if (lowLvl + 2 * compositeDegree < highLvl)
                    LevelReduceInternalInPlace(low, highLvl - lowLvl - 2 * compositeDegree);
                ModReduceInternalInPlace(low, compositeDegree);
                low->SetScalingFactor(high->GetScalingFactor());
            }
        }
        else if (highDepth == 2) {
            const double sfLow    = low->GetScalingFactor();
            const double sfTarget = high->GetScalingFactor();
            const double sfLevel  = cryptoParams->GetScalingFactorReal(lowLvl);
            EvalMultCoreInPlace(low, sfTarget / sfLow / sfLevel);
            LevelReduceInternalInPlace(low, highLvl - lowLvl);
            low->SetScalingFactor(sfTarget);
        }
        else {
            const double sfLow    = low->GetScalingFactor();
            const double sfTarget = cryptoParams->GetScalingFactorRealBig(highLvl - compositeDegree);
            const double sfLevel  = cryptoParams->GetScalingFactorReal(lowLvl);
            EvalMultCoreInPlace(low, sfTarget / sfLow / sfLevel);
            if (lowLvl + compositeDegree < highLvl)
                LevelReduceInternalInPlace(low, highLvl - lowLvl - compositeDegree);
            ModReduceInternalInPlace(low, compositeDegree);
            low->SetScalingFactor(high->GetScalingFactor());
        }
    };

    if (c1lvl < c2lvl) {
        raiseToLevelOf(ciphertext1, ciphertext2, c1lvl, c2lvl, c1depth, c2depth, sizeQl1);
    }
    else if (c1lvl > c2lvl) {
        raiseToLevelOf(ciphertext2, ciphertext1, c2lvl, c1lvl, c2depth, c1depth, sizeQl2);
    }
    else if (c1depth < c2depth) {
        EvalMultCoreInPlace(ciphertext1, 1.0);
    }
    else if (c2depth < c1depth) {
        EvalMultCoreInPlace(ciphertext2, 1.0);
    }
}

// Aligns the two ciphertexts with AdjustLevelsAndDepthInPlace and then, if the first
// one ends up with noise-scale degree 2, mod-reduces BOTH ciphertexts by
// GetCompositeDegree() levels. Only ciphertext1's degree is inspected.
void LeveledSHECKKSRNS::AdjustLevelsAndDepthToOneInPlace(Ciphertext<DCRTPoly>& ciphertext1,
                                                         Ciphertext<DCRTPoly>& ciphertext2) const {
    AdjustLevelsAndDepthInPlace(ciphertext1, ciphertext2);

    if (ciphertext1->GetNoiseScaleDeg() != 2)
        return;

    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(ciphertext1->GetCryptoParameters());
    ModReduceInternalInPlace(ciphertext1, cryptoParams->GetCompositeDegree());
    ModReduceInternalInPlace(ciphertext2, cryptoParams->GetCompositeDegree());
}

// Core of the multiplication by a real constant (no preliminary rescaling).
// Every ciphertext element is multiplied by the per-tower factors returned by
// GetElementForEvalMult; afterwards the noise-scale degree is incremented and the
// scaling factor is multiplied by GetScalingFactorReal() at the ciphertext's level.
void LeveledSHECKKSRNS::EvalMultCoreInPlace(Ciphertext<DCRTPoly>& ciphertext, double operand) const {
    auto& cv          = ciphertext->GetElements();
    auto towerFactors = GetElementForEvalMult(ciphertext, operand);
    for (auto& poly : cv)
        poly = poly * towerFactors;

    ciphertext->SetNoiseScaleDeg(ciphertext->GetNoiseScaleDeg() + 1);

    auto params                = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(ciphertext->GetCryptoParameters());
    const double levelScFactor = params->GetScalingFactorReal(ciphertext->GetLevel());
    ciphertext->SetScalingFactor(ciphertext->GetScalingFactor() * levelScFactor);
}

// Core of the multiplication by a complex constant (no preliminary rescaling).
//
// A monomial with a single coefficient 1 at position (cyclotomic order / 4) is built
// and converted to EVALUATION format; it is multiplied by the per-tower factors of the
// imaginary part. Each ciphertext element c is then replaced by
//     c * factors(Re) + c * (monomial * factors(Im)).
// Afterwards the noise-scale degree is incremented and the scaling factor is multiplied
// by GetScalingFactorReal() at the ciphertext's level.
void LeveledSHECKKSRNS::EvalMultCoreInPlace(Ciphertext<DCRTPoly>& ciphertext, std::complex<double> operand) const {
    auto& cv = ciphertext->GetElements();

    // Monomial used to carry the imaginary part.
    const auto& ringParams = cv[0].GetParams();
    NativePoly unitMonomial(ringParams->GetParams()[0], Format::COEFFICIENT, true);
    unitMonomial[ringParams->GetCyclotomicOrder() >> 2] = NativeInteger(1);

    DCRTPoly unitMonomialDCRT(ringParams, Format::COEFFICIENT, true);
    unitMonomialDCRT = unitMonomial;
    unitMonomialDCRT.SetFormat(Format::EVALUATION);

    auto reFactors = GetElementForEvalMult(ciphertext, operand.real());
    auto imFactors = unitMonomialDCRT * GetElementForEvalMult(ciphertext, operand.imag());

    for (auto& poly : cv)
        poly = (poly * reFactors) + (poly * imFactors);

    ciphertext->SetNoiseScaleDeg(ciphertext->GetNoiseScaleDeg() + 1);

    auto params                = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(ciphertext->GetCryptoParameters());
    const double levelScFactor = params->GetScalingFactorReal(ciphertext->GetLevel());
    ciphertext->SetScalingFactor(ciphertext->GetScalingFactor() * levelScFactor);
}

// Maps a rotation index to an automorphism index by forwarding both arguments to
// FindAutomorphismIndex2nComplex.
//   index - rotation index
//   m     - second argument of FindAutomorphismIndex2nComplex (other call sites in this
//           file pass the cyclotomic order)
uint32_t LeveledSHECKKSRNS::FindAutomorphismIndex(uint32_t index, uint32_t m) const {
    const uint32_t autoIndex = FindAutomorphismIndex2nComplex(index, m);
    return autoIndex;
}

}  // namespace lbcrypto