Program Listing for File keyswitch-hybrid.cpp

Return to documentation for file (pke/lib/keyswitch/keyswitch-hybrid.cpp)

//==================================================================================
// BSD 2-Clause License
//
// Copyright (c) 2014-2022, NJIT, Duality Technologies Inc. and other contributors
//
// All rights reserved.
//
// Author TPOC: contact@openfhe.org
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
//    list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
//    this list of conditions and the following disclaimer in the documentation
//    and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//==================================================================================

#define PROFILE

#include "keyswitch/keyswitch-hybrid.h"

#include "key/privatekey.h"
#include "key/publickey.h"
#include "key/evalkeyrelin.h"
#include "scheme/ckksrns/ckksrns-cryptoparameters.h"
#include "ciphertext.h"

namespace lbcrypto {

// Builds a key-switching key from oldKey to newKey for the single-key setting.
// This is a convenience overload: it delegates to the three-argument version with
// an empty previous evaluation key, which makes that version sample its own "a" components.
EvalKey<DCRTPoly> KeySwitchHYBRID::KeySwitchGenInternal(const PrivateKey<DCRTPoly> oldKey,
                                                        const PrivateKey<DCRTPoly> newKey) const {
    // No previous key -> the "single-key HE" branch of the three-argument overload is taken.
    const EvalKey<DCRTPoly> noPreviousKey = nullptr;
    return KeySwitchHYBRID::KeySwitchGenInternal(oldKey, newKey, noPreviousKey);
}

/**
 * Generates a hybrid (digit-decomposed) key-switching key from oldKey to newKey.
 *
 * For every digit (a block of GetNumPerPartQ() consecutive towers of Q) one pair (a, b)
 * over the extended basis QP is produced, with
 *     b_i = -a_i * sNew_i + ns * e_i                    for towers outside the digit,
 *     b_i = -a_i * sNew_i + PModq[i] * sOld_i + ns * e_i for towers inside the digit.
 *
 * @param oldKey secret key the generated key switches away from
 * @param newKey secret key the generated key switches to; its key tag is copied to the result
 * @param ekPrev nullptr for single-key HE (each "a" is sampled from a DugType generator);
 *               otherwise the "a" vector of ekPrev is reused (threshold HE)
 * @return evaluation key whose A/B vectors hold one polynomial per digit
 */
EvalKey<DCRTPoly> KeySwitchHYBRID::KeySwitchGenInternal(const PrivateKey<DCRTPoly> oldKey,
                                                        const PrivateKey<DCRTPoly> newKey,
                                                        const EvalKey<DCRTPoly> ekPrev) const {
    EvalKeyRelin<DCRTPoly> ek(std::make_shared<EvalKeyRelinImpl<DCRTPoly>>(newKey->GetCryptoContext()));

    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersRNS>(newKey->GetCryptoParameters());

    const std::shared_ptr<ParmType> paramsQ  = cryptoParams->GetElementParams();
    const std::shared_ptr<ParmType> paramsQP = cryptoParams->GetParamsQP();

    const size_t sizeQ  = paramsQ->GetParams().size();
    const size_t sizeQP = paramsQP->GetParams().size();

    // sOld is only read, so bind it instead of deep-copying every tower.
    const DCRTPoly& sOld = oldKey->GetPrivateElement();
    DCRTPoly sNew        = newKey->GetPrivateElement().Clone();

    // skNew is currently in basis Q. This extends it to basis QP.
    sNew.SetFormat(Format::COEFFICIENT);

    DCRTPoly sNewExt(paramsQP, Format::COEFFICIENT, true);

    // The part with basis Q
    for (size_t i = 0; i < sizeQ; i++) {
        sNewExt.SetElementAtIndex(i, sNew.GetElementAtIndex(i));
    }

    // The part with basis P: the tower at index 0 is switched to each modulus p_j
    for (size_t j = sizeQ; j < sizeQP; j++) {
        const NativeInteger& pj    = paramsQP->GetParams()[j]->GetModulus();
        const NativeInteger& rootj = paramsQP->GetParams()[j]->GetRootOfUnity();
        auto sNew0                 = sNew.GetElementAtIndex(0);
        sNew0.SwitchModulus(pj, rootj, 0, 0);
        sNewExt.SetElementAtIndex(j, std::move(sNew0));
    }

    sNewExt.SetFormat(Format::EVALUATION);

    const auto ns      = cryptoParams->GetNoiseScale();
    const DggType& dgg = cryptoParams->GetDiscreteGaussianGenerator();
    DugType dug;

    const size_t numPartQ = cryptoParams->GetNumPartQ();

    std::vector<DCRTPoly> av(numPartQ);
    std::vector<DCRTPoly> bv(numPartQ);

    const std::vector<NativeInteger>& PModq = cryptoParams->GetPModq();
    const size_t numPerPartQ                = cryptoParams->GetNumPerPartQ();

    for (size_t part = 0; part < numPartQ; ++part) {
        // "a" is sampled before "e"; keep this order so the random streams are unchanged.
        DCRTPoly a = (ekPrev == nullptr) ? DCRTPoly(dug, paramsQP, Format::EVALUATION) :  // single-key HE
                                           ekPrev->GetAVector()[part];                     // threshold HE
        DCRTPoly e(dgg, paramsQP, Format::EVALUATION);
        DCRTPoly b(paramsQP, Format::EVALUATION, true);

        // starting and ending position of current part (the last part may be shorter)
        const size_t startPartIdx = numPerPartQ * part;
        const size_t endPartIdx   = (sizeQ > (startPartIdx + numPerPartQ)) ? (startPartIdx + numPerPartQ) : sizeQ;

        for (size_t i = 0; i < sizeQP; ++i) {
            // Bind towers by reference: copying them allocated several polynomials per tower.
            const auto& ai    = a.GetElementAtIndex(i);
            const auto& ei    = e.GetElementAtIndex(i);
            const auto& sNewi = sNewExt.GetElementAtIndex(i);

            if (i < startPartIdx || i >= endPartIdx) {
                b.SetElementAtIndex(i, -ai * sNewi + ns * ei);
            }
            else {
                // P * sOld is only applied for the current part
                const auto& sOldi = sOld.GetElementAtIndex(i);
                b.SetElementAtIndex(i, -ai * sNewi + PModq[i] * sOldi + ns * ei);
            }
        }

        // a and b are not used after this point; move instead of copying.
        av[part] = std::move(a);
        bv[part] = std::move(b);
    }

    ek->SetAVector(std::move(av));
    ek->SetBVector(std::move(bv));
    ek->SetKeyTag(newKey->GetKeyTag());
    return ek;
}

/**
 * Generates a hybrid (digit-decomposed) key-switching key from a secret key to a public key.
 *
 * For every digit (a block of GetNumPerPartQ() consecutive towers of Q), fresh u, e0, e1 are
 * sampled over QP and, with (p0, p1) the public elements of newKey,
 *     a_i = p1_i * u_i + ns * e1_i
 *     b_i = p0_i * u_i + ns * e0_i (+ PModq[i] * sOld_i on the towers of the current digit).
 *
 * @param oldKey secret key the generated key switches away from
 * @param newKey public key the generated key switches to; its key tag is copied to the result
 * @return evaluation key whose A/B vectors hold one polynomial per digit
 */
EvalKey<DCRTPoly> KeySwitchHYBRID::KeySwitchGenInternal(const PrivateKey<DCRTPoly> oldKey,
                                                        const PublicKey<DCRTPoly> newKey) const {
    EvalKeyRelin<DCRTPoly> ek = std::make_shared<EvalKeyRelinImpl<DCRTPoly>>(newKey->GetCryptoContext());

    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersRNS>(newKey->GetCryptoParameters());

    const std::shared_ptr<ParmType> paramsQ  = cryptoParams->GetElementParams();
    const std::shared_ptr<ParmType> paramsQP = cryptoParams->GetParamsQP();

    const size_t sizeQ  = paramsQ->GetParams().size();
    const size_t sizeQP = paramsQP->GetParams().size();

    // Read-only inputs: bind them instead of deep-copying whole polynomials.
    const DCRTPoly& sOld = oldKey->GetPrivateElement();

    // NOTE(review): the public elements are indexed up to sizeQP below, so they are
    // assumed to already be defined over QP.
    const std::vector<DCRTPoly>& newPk = newKey->GetPublicElements();
    const DCRTPoly& newp0              = newPk.at(0);
    const DCRTPoly& newp1              = newPk.at(1);

    const auto ns      = cryptoParams->GetNoiseScale();
    const DggType& dgg = cryptoParams->GetDiscreteGaussianGenerator();
    TugType tug;

    const size_t numPartQ = cryptoParams->GetNumPartQ();

    std::vector<DCRTPoly> av(numPartQ);
    std::vector<DCRTPoly> bv(numPartQ);

    const std::vector<NativeInteger>& PModq = cryptoParams->GetPModq();
    const size_t numPerPartQ                = cryptoParams->GetNumPerPartQ();

    for (size_t part = 0; part < numPartQ; part++) {
        // Sampling order (u, e0, e1) is kept unchanged so the random streams are identical.
        DCRTPoly u = (cryptoParams->GetSecretKeyDist() == GAUSSIAN) ? DCRTPoly(dgg, paramsQP, Format::EVALUATION) :
                                                                      DCRTPoly(tug, paramsQP, Format::EVALUATION);

        DCRTPoly e0(dgg, paramsQP, Format::EVALUATION);
        DCRTPoly e1(dgg, paramsQP, Format::EVALUATION);

        DCRTPoly a(paramsQP, Format::EVALUATION, true);
        DCRTPoly b(paramsQP, Format::EVALUATION, true);

        // starting and ending position of current part (the last part may be shorter)
        const size_t startPartIdx = numPerPartQ * part;
        const size_t endPartIdx   = (sizeQ > startPartIdx + numPerPartQ) ? (startPartIdx + numPerPartQ) : sizeQ;

        for (size_t i = 0; i < sizeQP; i++) {
            // Bind towers by reference: copying them allocated five polynomials per tower.
            const auto& e0i = e0.GetElementAtIndex(i);
            const auto& e1i = e1.GetElementAtIndex(i);

            const auto& ui = u.GetElementAtIndex(i);

            const auto& newp0i = newp0.GetElementAtIndex(i);
            const auto& newp1i = newp1.GetElementAtIndex(i);

            a.SetElementAtIndex(i, newp1i * ui + ns * e1i);

            if (i < startPartIdx || i >= endPartIdx) {
                b.SetElementAtIndex(i, newp0i * ui + ns * e0i);
            }
            else {
                // P * sOld is only applied for the current part
                const auto& sOldi = sOld.GetElementAtIndex(i);
                b.SetElementAtIndex(i, newp0i * ui + ns * e0i + PModq[i] * sOldi);
            }
        }

        // a and b are not used after this point; move instead of copying.
        av[part] = std::move(a);
        bv[part] = std::move(b);
    }

    ek->SetAVector(std::move(av));
    ek->SetBVector(std::move(bv));
    ek->SetKeyTag(newKey->GetKeyTag());

    return ek;
}

/**
 * Key-switches a ciphertext in place.
 *
 * With two elements (c0, c1), c1 is switched via KeySwitchCore into (b', a') and the ciphertext
 * becomes (c0 + b', a'). With more elements, c2 is switched and the ciphertext becomes
 * (c0 + b', c1 + a'). In both cases the ciphertext is truncated to two elements.
 *
 * @param ciphertext ciphertext modified in place
 * @param ek key-switching key passed to KeySwitchCore
 */
void KeySwitchHYBRID::KeySwitchInPlace(Ciphertext<DCRTPoly>& ciphertext, const EvalKey<DCRTPoly> ek) const {
    std::vector<DCRTPoly>& cv = ciphertext->GetElements();

    std::shared_ptr<std::vector<DCRTPoly>> ba = (cv.size() == 2) ? KeySwitchCore(cv[1], ek) : KeySwitchCore(cv[2], ek);

    cv[0].SetFormat((*ba)[0].GetFormat());
    cv[0] += (*ba)[0];

    if (cv.size() > 2) {
        // cv[1] is accumulated into, so it must match the format of the switched term.
        cv[1].SetFormat((*ba)[1].GetFormat());
        cv[1] += (*ba)[1];
    }
    else {
        // cv[1] is fully overwritten: converting its format first would be a wasted transform.
        // ba is a local result that is not read again, so its element can be moved out.
        cv[1] = std::move((*ba)[1]);
    }
    cv.resize(2);
}

/**
 * Lifts every element of a ciphertext from basis Ql to the extended basis QlP, multiplying it
 * by the precomputed P mod q_i constants on the Ql towers; the P towers are left as zero.
 *
 * @param ciphertext input ciphertext over Ql
 * @param addFirst when false, the first element is left as an all-zero polynomial over QlP
 * @return a new ciphertext (CloneZero of the input) holding the extended elements
 */
Ciphertext<DCRTPoly> KeySwitchHYBRID::KeySwitchExt(ConstCiphertext<DCRTPoly> ciphertext, bool addFirst) const {
    // Only RNS-level getters are used, so cast to the common RNS base like the rest of this file.
    // The CKKS-specific cast returned nullptr (and crashed below) for non-CKKS RNS parameters.
    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersRNS>(ciphertext->GetCryptoParameters());

    const std::vector<DCRTPoly>& cv = ciphertext->GetElements();

    const auto paramsQl  = cv[0].GetParams();
    const auto paramsP   = cryptoParams->GetParamsP();
    const auto paramsQlP = cv[0].GetExtendedCRTBasis(paramsP);

    // Loop-invariant: the same P mod q_i constants apply to every element.
    const std::vector<NativeInteger>& PModq = cryptoParams->GetPModq();

    const size_t sizeQl = paramsQl->GetParams().size();
    const size_t sizeCv = cv.size();

    std::vector<DCRTPoly> resultElements(sizeCv);
    for (size_t k = 0; k < sizeCv; k++) {
        resultElements[k] = DCRTPoly(paramsQlP, Format::EVALUATION, true);
        if (addFirst || k > 0) {
            auto cMult = cv[k].TimesNoCheck(PModq);
            for (size_t i = 0; i < sizeQl; i++) {
                resultElements[k].SetElementAtIndex(i, cMult.GetElementAtIndex(i));
            }
        }
    }

    Ciphertext<DCRTPoly> result = ciphertext->CloneZero();
    result->SetElements(std::move(resultElements));
    return result;
}

/**
 * Brings the first two elements of a ciphertext over the extended basis QlP back down to Ql
 * using DCRTPoly::ApproxModDown.
 *
 * @param ciphertext ciphertext whose elements are over QlP
 * @return a new ciphertext (CloneZero of the input) holding the two reduced elements
 */
Ciphertext<DCRTPoly> KeySwitchHYBRID::KeySwitchDown(ConstCiphertext<DCRTPoly> ciphertext) const {
    // Only RNS-level getters are used, so cast to the common RNS base like the rest of this file.
    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersRNS>(ciphertext->GetCryptoParameters());

    // Bind by reference: only elements 0 and 1 are read, so copying the whole vector is waste.
    const std::vector<DCRTPoly>& cTilda = ciphertext->GetElements();

    const auto paramsP   = cryptoParams->GetParamsP();
    const auto paramsQlP = cTilda[0].GetParams();

    // TODO : (Andrey) precompute paramsQl in cryptoparameters
    // Ql is the leading part of QlP (everything except the trailing P towers).
    const size_t sizeQl = paramsQlP->GetParams().size() - paramsP->GetParams().size();
    std::vector<NativeInteger> moduliQ(sizeQl);
    std::vector<NativeInteger> rootsQ(sizeQl);
    for (size_t i = 0; i < sizeQl; i++) {
        moduliQ[i] = paramsQlP->GetParams()[i]->GetModulus();
        rootsQ[i]  = paramsQlP->GetParams()[i]->GetRootOfUnity();
    }
    auto paramsQl = std::make_shared<typename DCRTPoly::Params>(2 * paramsQlP->GetRingDimension(), moduliQ, rootsQ);

    const PlaintextModulus t = (cryptoParams->GetNoiseScale() == 1) ? 0 : cryptoParams->GetPlaintextModulus();

    // Single definition of the (long) ApproxModDown call shared by both elements.
    auto modDown = [&](const DCRTPoly& element) {
        return element.ApproxModDown(paramsQl, paramsP, cryptoParams->GetPInvModq(),
                                     cryptoParams->GetPInvModqPrecon(), cryptoParams->GetPHatInvModp(),
                                     cryptoParams->GetPHatInvModpPrecon(), cryptoParams->GetPHatModq(),
                                     cryptoParams->GetModqBarrettMu(), cryptoParams->GettInvModp(),
                                     cryptoParams->GettInvModpPrecon(), t, cryptoParams->GettModqPrecon());
    };

    std::vector<DCRTPoly> reduced;
    reduced.reserve(2);
    reduced.push_back(modDown(cTilda[0]));
    reduced.push_back(modDown(cTilda[1]));

    Ciphertext<DCRTPoly> result = ciphertext->CloneZero();
    result->SetElements(std::move(reduced));
    return result;
}

/**
 * Brings only the first element of a ciphertext over the extended basis QlP back down to Ql
 * using DCRTPoly::ApproxModDown.
 *
 * @param ciphertext ciphertext whose elements are over QlP
 * @return the reduced first element over Ql
 */
DCRTPoly KeySwitchHYBRID::KeySwitchDownFirstElement(ConstCiphertext<DCRTPoly> ciphertext) const {
    // Only RNS-level getters are used, so cast to the common RNS base like the rest of this file.
    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersRNS>(ciphertext->GetCryptoParameters());

    const std::vector<DCRTPoly>& cTilda = ciphertext->GetElements();

    const auto paramsP   = cryptoParams->GetParamsP();
    const auto paramsQlP = cTilda[0].GetParams();

    // TODO : (Andrey) precompute paramsQl in cryptoparameters
    // Ql is the leading part of QlP (everything except the trailing P towers).
    const size_t sizeQl = paramsQlP->GetParams().size() - paramsP->GetParams().size();
    std::vector<NativeInteger> moduliQ(sizeQl);
    std::vector<NativeInteger> rootsQ(sizeQl);
    for (size_t i = 0; i < sizeQl; i++) {
        moduliQ[i] = paramsQlP->GetParams()[i]->GetModulus();
        rootsQ[i]  = paramsQlP->GetParams()[i]->GetRootOfUnity();
    }
    auto paramsQl = std::make_shared<typename DCRTPoly::Params>(2 * paramsQlP->GetRingDimension(), moduliQ, rootsQ);

    const PlaintextModulus t = (cryptoParams->GetNoiseScale() == 1) ? 0 : cryptoParams->GetPlaintextModulus();

    return cTilda[0].ApproxModDown(paramsQl, paramsP, cryptoParams->GetPInvModq(), cryptoParams->GetPInvModqPrecon(),
                                   cryptoParams->GetPHatInvModp(), cryptoParams->GetPHatInvModpPrecon(),
                                   cryptoParams->GetPHatModq(), cryptoParams->GetModqBarrettMu(),
                                   cryptoParams->GettInvModp(), cryptoParams->GettInvModpPrecon(), t,
                                   cryptoParams->GettModqPrecon());
}

/**
 * Hybrid key-switching of a single polynomial.
 *
 * Runs the two stages in sequence:
 * 1. EvalKeySwitchPrecomputeCore decomposes `a` into digits over the extended basis.
 * 2. EvalFastKeySwitchCore applies evalKey to those digits and brings the result back to a's parameters.
 *
 * @param a polynomial to switch
 * @param evalKey key-switching key (also supplies the crypto parameters)
 * @return the two-element vector produced by EvalFastKeySwitchCore
 */
std::shared_ptr<std::vector<DCRTPoly>> KeySwitchHYBRID::KeySwitchCore(const DCRTPoly& a,
                                                                      const EvalKey<DCRTPoly> evalKey) const {
    const auto& targetParams = a.GetParams();
    auto extendedDigits      = EvalKeySwitchPrecomputeCore(a, evalKey->GetCryptoParameters());
    return EvalFastKeySwitchCore(std::move(extendedDigits), evalKey, targetParams);
}

/**
 * Digit decomposition step of hybrid key switching.
 *
 * Splits `c` (over basis Ql) into digits of `alpha = GetNumPerPartQ()` consecutive towers.
 * Each digit is then expanded to the full extended basis QlP:
 * - towers belonging to the digit are copied verbatim from `c`;
 * - every other tower (the remaining Ql towers and all P towers) comes from
 *   ApproxSwitchCRTBasis applied to that digit.
 *
 * @param c polynomial to decompose; its towers are read in EVALUATION form
 * @param cryptoParamsBase crypto parameters, cast to CryptoParametersRNS
 * @return one polynomial over QlP per digit
 */
std::shared_ptr<std::vector<DCRTPoly>> KeySwitchHYBRID::EvalKeySwitchPrecomputeCore(
    const DCRTPoly& c, std::shared_ptr<CryptoParametersBase<DCRTPoly>> cryptoParamsBase) const {
    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersRNS>(cryptoParamsBase);

    const std::shared_ptr<ParmType> paramsQl  = c.GetParams();
    const std::shared_ptr<ParmType> paramsP   = cryptoParams->GetParamsP();
    const std::shared_ptr<ParmType> paramsQlP = c.GetExtendedCRTBasis(paramsP);

    size_t sizeQl  = paramsQl->GetParams().size();
    size_t sizeP   = paramsP->GetParams().size();
    size_t sizeQlP = sizeQl + sizeP;

    uint32_t alpha = cryptoParams->GetNumPerPartQ();
    // The number of digits of the current ciphertext
    // (ceil(sizeQl / alpha), capped at the number of Q partitions)
    uint32_t numPartQl = ceil((static_cast<double>(sizeQl)) / alpha);
    if (numPartQl > cryptoParams->GetNumberOfQPartitions())
        numPartQl = cryptoParams->GetNumberOfQPartitions();

    std::vector<DCRTPoly> partsCt(numPartQl);

    // Digit decomposition
    // Zero-padding and split
    for (uint32_t part = 0; part < numPartQl; part++) {
        if (part == numPartQl - 1) {
            // The last digit may hold fewer than alpha towers, so build parameters that
            // contain only the first sizePartQl moduli of that partition.
            auto paramsPartQ = cryptoParams->GetParamsPartQ(part);

            uint32_t sizePartQl = sizeQl - alpha * part;

            std::vector<NativeInteger> moduli(sizePartQl);
            std::vector<NativeInteger> roots(sizePartQl);

            for (uint32_t i = 0; i < sizePartQl; i++) {
                moduli[i] = paramsPartQ->GetParams()[i]->GetModulus();
                roots[i]  = paramsPartQ->GetParams()[i]->GetRootOfUnity();
            }

            auto params = DCRTPoly::Params(paramsPartQ->GetCyclotomicOrder(), moduli, roots);

            partsCt[part] = DCRTPoly(std::make_shared<ParmType>(params), Format::EVALUATION, true);
        }
        else {
            partsCt[part] = DCRTPoly(cryptoParams->GetParamsPartQ(part), Format::EVALUATION, true);
        }

        // Copy towers [alpha * part, alpha * part + sizePartQl) of c into this digit.
        usint sizePartQl   = partsCt[part].GetNumOfElements();
        usint startPartIdx = alpha * part;
        for (uint32_t i = 0, idx = startPartIdx; i < sizePartQl; i++, idx++) {
            partsCt[part].SetElementAtIndex(i, c.GetElementAtIndex(idx));
        }
    }

    std::vector<DCRTPoly> partsCtCompl(numPartQl);
    std::vector<DCRTPoly> partsCtExt(numPartQl);

    for (uint32_t part = 0; part < numPartQl; part++) {
        // The basis switch is run on a COEFFICIENT-format clone so partsCt[part] stays in
        // EVALUATION form for the verbatim copy further below.
        auto partCtClone = partsCt[part].Clone();
        partCtClone.SetFormat(Format::COEFFICIENT);

        uint32_t sizePartQl = partsCt[part].GetNumOfElements();
        partsCtCompl[part]  = partCtClone.ApproxSwitchCRTBasis(
            cryptoParams->GetParamsPartQ(part), cryptoParams->GetParamsComplPartQ(sizeQl - 1, part),
            cryptoParams->GetPartQlHatInvModq(part, sizePartQl - 1),
            cryptoParams->GetPartQlHatInvModqPrecon(part, sizePartQl - 1),
            cryptoParams->GetPartQlHatModp(sizeQl - 1, part),
            cryptoParams->GetmodComplPartqBarrettMu(sizeQl - 1, part));

        partsCtCompl[part].SetFormat(Format::EVALUATION);

        partsCtExt[part] = DCRTPoly(paramsQlP, Format::EVALUATION, true);

        // Assemble the digit over QlP. Per the index mapping below, the complement holds
        // [Ql towers before the digit | Ql towers after the digit | P towers].
        usint startPartIdx = alpha * part;
        usint endPartIdx   = startPartIdx + sizePartQl;
        // Towers before the digit: taken from the complement at the same index.
        for (usint i = 0; i < startPartIdx; i++) {
            partsCtExt[part].SetElementAtIndex(i, partsCtCompl[part].GetElementAtIndex(i));
        }
        // Towers of the digit itself: copied unchanged from the decomposed input.
        for (usint i = startPartIdx, idx = 0; i < endPartIdx; i++, idx++) {
            partsCtExt[part].SetElementAtIndex(i, partsCt[part].GetElementAtIndex(idx));
        }
        // Remaining Ql towers and all P towers: complement shifted back by the digit size.
        for (usint i = endPartIdx; i < sizeQlP; ++i) {
            partsCtExt[part].SetElementAtIndex(i, partsCtCompl[part].GetElementAtIndex(i - sizePartQl));
        }
    }

    return std::make_shared<std::vector<DCRTPoly>>(std::move(partsCtExt));
}

std::shared_ptr<std::vector<DCRTPoly>> KeySwitchHYBRID::EvalFastKeySwitchCore(
    const std::shared_ptr<std::vector<DCRTPoly>> digits, const EvalKey<DCRTPoly> evalKey,
    const std::shared_ptr<ParmType> paramsQl) const {
    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersRNS>(evalKey->GetCryptoParameters());

    std::shared_ptr<std::vector<DCRTPoly>> cTilda = EvalFastKeySwitchCoreExt(digits, evalKey, paramsQl);

    PlaintextModulus t = (cryptoParams->GetNoiseScale() == 1) ? 0 : cryptoParams->GetPlaintextModulus();

    DCRTPoly ct0 = (*cTilda)[0].ApproxModDown(paramsQl, cryptoParams->GetParamsP(), cryptoParams->GetPInvModq(),
                                              cryptoParams->GetPInvModqPrecon(), cryptoParams->GetPHatInvModp(),
                                              cryptoParams->GetPHatInvModpPrecon(), cryptoParams->GetPHatModq(),
                                              cryptoParams->GetModqBarrettMu(), cryptoParams->GettInvModp(),
                                              cryptoParams->GettInvModpPrecon(), t, cryptoParams->GettModqPrecon());

    DCRTPoly ct1 = (*cTilda)[1].ApproxModDown(paramsQl, cryptoParams->GetParamsP(), cryptoParams->GetPInvModq(),
                                              cryptoParams->GetPInvModqPrecon(), cryptoParams->GetPHatInvModp(),
                                              cryptoParams->GetPHatInvModpPrecon(), cryptoParams->GetPHatModq(),
                                              cryptoParams->GetModqBarrettMu(), cryptoParams->GettInvModp(),
                                              cryptoParams->GettInvModpPrecon(), t, cryptoParams->GettModqPrecon());

    return std::make_shared<std::vector<DCRTPoly>>(std::initializer_list<DCRTPoly>{std::move(ct0), std::move(ct1)});
}

/**
 * Inner-product step of hybrid key switching, performed over the extended basis QlP.
 *
 * Computes cTilda0 = sum_j digit_j * b_j and cTilda1 = sum_j digit_j * a_j tower by tower.
 * The key (a_j, b_j) is defined over the full QP basis, so for the P towers the key index is
 * offset by sizeQ rather than sizeQl.
 *
 * @param digits extended digits over QlP produced by EvalKeySwitchPrecomputeCore (must be non-empty)
 * @param evalKey key-switching key providing the A/B vectors (at least digits->size() entries)
 * @param paramsQl parameters of the Ql part; only its tower count is used here
 * @return a two-element vector {cTilda0, cTilda1} over QlP
 */
std::shared_ptr<std::vector<DCRTPoly>> KeySwitchHYBRID::EvalFastKeySwitchCoreExt(
    const std::shared_ptr<std::vector<DCRTPoly>> digits, const EvalKey<DCRTPoly> evalKey,
    const std::shared_ptr<ParmType> paramsQl) const {
    const auto cryptoParams         = std::dynamic_pointer_cast<CryptoParametersRNS>(evalKey->GetCryptoParameters());
    const std::vector<DCRTPoly>& bv = evalKey->GetBVector();
    const std::vector<DCRTPoly>& av = evalKey->GetAVector();

    const std::shared_ptr<ParmType> paramsP   = cryptoParams->GetParamsP();
    const std::shared_ptr<ParmType> paramsQlP = (*digits)[0].GetParams();

    const size_t sizeQl  = paramsQl->GetParams().size();
    const size_t sizeQlP = paramsQlP->GetParams().size();
    const size_t sizeQ   = cryptoParams->GetElementParams()->GetParams().size();

    DCRTPoly cTilda0(paramsQlP, Format::EVALUATION, true);
    DCRTPoly cTilda1(paramsQlP, Format::EVALUATION, true);

    for (uint32_t j = 0; j < digits->size(); j++) {
        const DCRTPoly& cj = (*digits)[j];
        const DCRTPoly& bj = bv[j];
        const DCRTPoly& aj = av[j];

        // Ql towers: digit and key share the same tower index.
        for (usint i = 0; i < sizeQl; i++) {
            const auto& cji = cj.GetElementAtIndex(i);
            const auto& aji = aj.GetElementAtIndex(i);
            const auto& bji = bj.GetElementAtIndex(i);

            cTilda0.SetElementAtIndex(i, cTilda0.GetElementAtIndex(i) + cji * bji);
            cTilda1.SetElementAtIndex(i, cTilda1.GetElementAtIndex(i) + cji * aji);
        }
        // P towers: they start at sizeQl in the digit but at sizeQ in the QP-based key.
        for (usint i = sizeQl, idx = sizeQ; i < sizeQlP; i++, idx++) {
            const auto& cji = cj.GetElementAtIndex(i);
            const auto& aji = aj.GetElementAtIndex(idx);
            const auto& bji = bj.GetElementAtIndex(idx);

            cTilda0.SetElementAtIndex(i, cTilda0.GetElementAtIndex(i) + cji * bji);
            cTilda1.SetElementAtIndex(i, cTilda1.GetElementAtIndex(i) + cji * aji);
        }
    }

    // Move the accumulators into the result: an initializer_list would copy both polynomials
    // because its elements are const.
    auto result = std::make_shared<std::vector<DCRTPoly>>();
    result->reserve(2);
    result->push_back(std::move(cTilda0));
    result->push_back(std::move(cTilda1));
    return result;
}

}  // namespace lbcrypto