Program Listing for File keyswitch-hybrid.cpp

Return to documentation for file (pke/lib/keyswitch/keyswitch-hybrid.cpp)

//==================================================================================
// BSD 2-Clause License
//
// Copyright (c) 2014-2022, NJIT, Duality Technologies Inc. and other contributors
//
// All rights reserved.
//
// Author TPOC: contact@openfhe.org
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
//    list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
//    this list of conditions and the following disclaimer in the documentation
//    and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//==================================================================================

#include "ciphertext.h"
#include "key/evalkeyrelin.h"
#include "key/privatekey.h"
#include "key/publickey.h"
#include "keyswitch/keyswitch-hybrid.h"
#include "scheme/ckksrns/ckksrns-cryptoparameters.h"

namespace lbcrypto {

// Secret-key to secret-key key-switching key generation without a previous key
// (single-key setting): forwards to the three-argument overload with no ekPrev.
EvalKey<DCRTPoly> KeySwitchHYBRID::KeySwitchGenInternal(const PrivateKey<DCRTPoly> oldKey,
                                                        const PrivateKey<DCRTPoly> newKey) const {
    const EvalKey<DCRTPoly> noPreviousKey = nullptr;
    // Qualified call: always use this class's implementation, never a derived override.
    return KeySwitchHYBRID::KeySwitchGenInternal(oldKey, newKey, noPreviousKey);
}

/**
 * Generates a hybrid (digit-decomposition) key-switching key from oldKey to newKey.
 *
 * The new secret is first extended from basis Q to basis QP. Then, for each of the
 * numPartQ digits (consecutive groups of numPerPartQ Q-towers), a pair (a, b) over QP
 * is produced with, tower by tower,
 *   b_i = -a_i * sNew_i + ns * e_i                      for towers outside the digit,
 *   b_i = -a_i * sNew_i + ns * e_i + PModq[i] * sOld_i  for towers inside the digit.
 *
 * @param oldKey  secret key being switched away from
 * @param newKey  secret key being switched to; supplies the crypto parameters, the
 *                crypto context and the key tag of the result
 * @param ekPrev  optional previous evaluation key. When nullptr, each digit's "a" is
 *                sampled uniformly (single-key HE); otherwise the "a" vectors of ekPrev
 *                are reused (threshold HE).
 * @return an EvalKeyRelin holding numPartQ (a, b) pairs
 */
EvalKey<DCRTPoly> KeySwitchHYBRID::KeySwitchGenInternal(const PrivateKey<DCRTPoly> oldKey,
                                                        const PrivateKey<DCRTPoly> newKey,
                                                        const EvalKey<DCRTPoly> ekPrev) const {
    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersRNS>(newKey->GetCryptoParameters());
    const auto& paramsQ     = cryptoParams->GetElementParams();
    const auto& paramsQP    = cryptoParams->GetParamsQP();
    const auto& pparamsQP   = paramsQP->GetParams();

    // skNew is currently in basis Q. This extends it to basis QP.
    // Q towers are taken from sNew directly (converted to EVALUATION); the extra P towers
    // are obtained by switching the modulus of tower 0 (in COEFFICIENT form) to each p_i.

    DCRTPoly sNewExt(paramsQP, Format::EVALUATION, true);
    const auto& sNew = newKey->GetPrivateElement();

    auto sNew0 = sNew.GetElementAtIndex(0);
    sNew0.SetFormat(Format::COEFFICIENT);

    const uint32_t sizeQ  = paramsQ->GetParams().size();
    const uint32_t sizeQP = paramsQP->GetParams().size();

#pragma omp parallel for num_threads(OpenFHEParallelControls.GetThreadLimit(sizeQP))
    for (uint32_t i = 0; i < sizeQP; ++i) {
        if (i < sizeQ) {
            auto tmp = sNew.GetElementAtIndex(i);
            tmp.SetFormat(Format::EVALUATION);
            sNewExt.SetElementAtIndex(i, std::move(tmp));
        }
        else {
            auto tmp = sNew0;
            tmp.SwitchModulus(pparamsQP[i]->GetModulus(), pparamsQP[i]->GetRootOfUnity(), 0, 0);
            tmp.SetFormat(Format::EVALUATION);
            sNewExt.SetElementAtIndex(i, std::move(tmp));
        }
    }

    const auto ns = cryptoParams->GetNoiseScale();

    const uint32_t numPerPartQ = cryptoParams->GetNumPerPartQ();
    const uint32_t numPartQ    = cryptoParams->GetNumPartQ();
    std::vector<DCRTPoly> av(numPartQ);
    std::vector<DCRTPoly> bv(numPartQ);

    DugType dug;
    auto dgg = cryptoParams->GetDiscreteGaussianGenerator();

    const auto& sOld  = oldKey->GetPrivateElement();
    const auto& PModq = cryptoParams->GetPModq();

    // `dgg` must be firstprivate, not private: OpenMP default-constructs private copies of
    // class-type variables, which would discard the generator configuration obtained from
    // cryptoParams above. firstprivate copy-constructs each thread's instance from `dgg`.
    // `dug` is itself default-constructed above, so a private (default-constructed) copy
    // is equivalent to it.
#pragma omp parallel for num_threads(OpenFHEParallelControls.GetThreadLimit(numPartQ)) private(dug) firstprivate(dgg)
    for (uint32_t part = 0; part < numPartQ; ++part) {
        auto a = (ekPrev == nullptr) ? DCRTPoly(dug, paramsQP, Format::EVALUATION) :  // single-key HE
                                       ekPrev->GetAVector()[part];                                      // threshold HE
        DCRTPoly e(dgg, paramsQP, Format::EVALUATION);
        DCRTPoly b(paramsQP, Format::EVALUATION, true);

        // Tower range [startPartIdx, endPartIdx) of Q covered by this digit.
        const uint32_t startPartIdx = numPerPartQ * part;
        const uint32_t endPartIdx   = (sizeQ > (startPartIdx + numPerPartQ)) ? (startPartIdx + numPerPartQ) : sizeQ;

        for (uint32_t i = 0; i < sizeQP; ++i) {
            const auto& ai  = a.GetElementAtIndex(i);
            const auto& ei  = e.GetElementAtIndex(i);
            const auto& sni = sNewExt.GetElementAtIndex(i);

            if (i < startPartIdx || i >= endPartIdx) {
                b.SetElementAtIndex(i, (-ai * sni) + (ns * ei));
            }
            else {
                const auto& soi = sOld.GetElementAtIndex(i);
                b.SetElementAtIndex(i, (-ai * sni) + (ns * ei) + (PModq[i] * soi));
            }
        }
        av[part] = std::move(a);
        bv[part] = std::move(b);
    }

    EvalKeyRelin<DCRTPoly> ek(std::make_shared<EvalKeyRelinImpl<DCRTPoly>>(newKey->GetCryptoContext()));
    ek->SetAVector(std::move(av));
    ek->SetBVector(std::move(bv));
    ek->SetKeyTag(newKey->GetKeyTag());
    return ek;
}

/**
 * Generates a hybrid (digit-decomposition) key-switching key from the secret oldKey
 * to the public key newKey.
 *
 * For each of the numPartQ digits (consecutive groups of numPerPartQ Q-towers) a fresh
 * u (Gaussian if the secret-key distribution is GAUSSIAN, ternary otherwise) and errors
 * e0, e1 are sampled over QP, and, tower by tower,
 *   a_i = p1_i * u_i + ns * e1_i
 *   b_i = p0_i * u_i + ns * e0_i                        for towers outside the digit,
 *   b_i = p0_i * u_i + ns * e0_i + PModq[i] * sOld_i    for towers inside the digit,
 * where (p0, p1) are the public elements of newKey.
 *
 * @param oldKey  secret key being switched away from
 * @param newKey  public key being switched to; supplies the crypto parameters, the
 *                crypto context and the key tag of the result
 * @return an EvalKeyRelin holding numPartQ (a, b) pairs
 */
EvalKey<DCRTPoly> KeySwitchHYBRID::KeySwitchGenInternal(const PrivateKey<DCRTPoly> oldKey,
                                                        const PublicKey<DCRTPoly> newKey) const {
    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersRNS>(newKey->GetCryptoParameters());
    const auto& paramsQ     = cryptoParams->GetElementParams();
    const auto& paramsQP    = cryptoParams->GetParamsQP();

    const uint32_t sizeQ  = paramsQ->GetParams().size();
    const uint32_t sizeQP = paramsQP->GetParams().size();

    const auto ns = cryptoParams->GetNoiseScale();

    const uint32_t numPerPartQ = cryptoParams->GetNumPerPartQ();
    const uint32_t numPartQ    = cryptoParams->GetNumPartQ();
    std::vector<DCRTPoly> av(numPartQ);
    std::vector<DCRTPoly> bv(numPartQ);

    TugType tug;
    auto dgg = cryptoParams->GetDiscreteGaussianGenerator();

    const auto& sOld  = oldKey->GetPrivateElement();
    const auto& newp0 = newKey->GetPublicElements().at(0);
    const auto& newp1 = newKey->GetPublicElements().at(1);
    const auto& PModq = cryptoParams->GetPModq();

    // `dgg` must be firstprivate, not private: OpenMP default-constructs private copies of
    // class-type variables, which would discard the generator configuration obtained from
    // cryptoParams above. firstprivate copy-constructs each thread's instance from `dgg`.
    // `tug` is itself default-constructed above, so a private (default-constructed) copy
    // is equivalent to it.
#pragma omp parallel for num_threads(OpenFHEParallelControls.GetThreadLimit(numPartQ)) private(tug) firstprivate(dgg)
    for (uint32_t part = 0; part < numPartQ; ++part) {
        auto u = (cryptoParams->GetSecretKeyDist() == GAUSSIAN) ? DCRTPoly(dgg, paramsQP, Format::EVALUATION) :
                                                                  DCRTPoly(tug, paramsQP, Format::EVALUATION);
        DCRTPoly e0(dgg, paramsQP, Format::EVALUATION);
        DCRTPoly e1(dgg, paramsQP, Format::EVALUATION);
        DCRTPoly a(paramsQP, Format::EVALUATION, true);
        DCRTPoly b(paramsQP, Format::EVALUATION, true);

        // starting and ending position of current part
        const uint32_t startPartIdx = numPerPartQ * part;
        const uint32_t endPartIdx   = (sizeQ > startPartIdx + numPerPartQ) ? (startPartIdx + numPerPartQ) : sizeQ;

        for (uint32_t i = 0; i < sizeQP; ++i) {
            const auto& ui = u.GetElementAtIndex(i);

            const auto& e0i = e0.GetElementAtIndex(i);
            const auto& e1i = e1.GetElementAtIndex(i);

            const auto& newp0i = newp0.GetElementAtIndex(i);
            const auto& newp1i = newp1.GetElementAtIndex(i);

            a.SetElementAtIndex(i, newp1i * ui + ns * e1i);

            if (i < startPartIdx || i >= endPartIdx) {
                b.SetElementAtIndex(i, (newp0i * ui) + (ns * e0i));
            }
            else {
                const auto& soi = sOld.GetElementAtIndex(i);
                b.SetElementAtIndex(i, (newp0i * ui) + (ns * e0i) + (PModq[i] * soi));
            }
        }
        av[part] = std::move(a);
        bv[part] = std::move(b);
    }

    EvalKeyRelin<DCRTPoly> ek = std::make_shared<EvalKeyRelinImpl<DCRTPoly>>(newKey->GetCryptoContext());
    ek->SetAVector(std::move(av));
    ek->SetBVector(std::move(bv));
    ek->SetKeyTag(newKey->GetKeyTag());
    return ek;
}

// Key-switches the last component of `ciphertext` with `ek` in place. The switched pair
// is added into c0 (and into c1 when there are more than two components; otherwise it
// replaces c1, which was the component that got switched). Afterwards the ciphertext
// has exactly two components.
void KeySwitchHYBRID::KeySwitchInPlace(Ciphertext<DCRTPoly>& ciphertext, const EvalKey<DCRTPoly> ek) const {
    auto& components    = ciphertext->GetElements();
    const auto switched = KeySwitchCore(components.back(), ek);
    const bool extended = components.size() > 2;

    components[0].SetFormat(Format::EVALUATION);
    components[0] += (*switched)[0];

    if (!extended) {
        // Two-component ciphertext: c1 was the switched component, so it is overwritten.
        components[1] = (*switched)[1];
    }
    else {
        components[1].SetFormat(Format::EVALUATION);
        components[1] += (*switched)[1];
    }

    components.resize(2);
}

// Lifts every component of `ciphertext` from basis Ql to basis QlP. A lifted component
// holds (component * PModq) in its Ql towers and zeros in its P towers. When `addFirst`
// is false the first component is left entirely zero.
Ciphertext<DCRTPoly> KeySwitchHYBRID::KeySwitchExt(ConstCiphertext<DCRTPoly> ciphertext, bool addFirst) const {
    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(ciphertext->GetCryptoParameters());

    const auto& inputs = ciphertext->GetElements();
    const auto& PModq  = cryptoParams->GetPModq();

    const auto paramsP   = cryptoParams->GetParamsP();
    const auto paramsQl  = inputs[0].GetParams();
    const auto paramsQlP = inputs[0].GetExtendedCRTBasis(paramsP);

    const uint32_t numComponents = inputs.size();
    const uint32_t numTowersQl   = paramsQl->GetParams().size();
    // Index of the first component that receives scaled data (0 if addFirst, else 1).
    const uint32_t firstScaled = addFirst ? 0 : 1;

    std::vector<DCRTPoly> lifted(numComponents);
    for (uint32_t k = 0; k < numComponents; ++k) {
        lifted[k] = DCRTPoly(paramsQlP, Format::EVALUATION, true);
        if (k < firstScaled)
            continue;

        auto scaled = inputs[k].TimesNoCheck(PModq);
        for (uint32_t tower = 0; tower < numTowersQl; ++tower)
            lifted[k].SetElementAtIndex(tower, std::move(scaled.GetElementAtIndex(tower)));
    }

    auto output = ciphertext->CloneEmpty();
    output->SetElements(std::move(lifted));
    return output;
}

// Brings the first two components of `ciphertext` from basis QlP back down to Ql using
// ApproxModDown, returning a new two-component ciphertext. Ql is rebuilt from the
// leading (size(QlP) - size(P)) towers of the input.
Ciphertext<DCRTPoly> KeySwitchHYBRID::KeySwitchDown(ConstCiphertext<DCRTPoly> ciphertext) const {
    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(ciphertext->GetCryptoParameters());
    const auto& components  = ciphertext->GetElements();
    const auto paramsQlP    = components[0].GetParams();
    const auto paramsP      = cryptoParams->GetParamsP();

    // TODO : (Andrey) precompute paramsQl in cryptoparameters
    const auto& towersQlP = paramsQlP->GetParams();
    const uint32_t sizeQl = towersQlP.size() - paramsP->GetParams().size();
    std::vector<NativeInteger> moduliQ(sizeQl);
    std::vector<NativeInteger> rootsQ(sizeQl);
    for (uint32_t idx = 0; idx < sizeQl; ++idx) {
        moduliQ[idx] = towersQlP[idx]->GetModulus();
        rootsQ[idx]  = towersQlP[idx]->GetRootOfUnity();
    }
    const auto paramsQl = std::make_shared<ParmType>(paramsQlP->GetCyclotomicOrder(), moduliQ, rootsQ);

    // t is 0 when the noise scale is 1; otherwise it is the plaintext modulus.
    const PlaintextModulus t = (cryptoParams->GetNoiseScale() == 1) ? 0 : cryptoParams->GetPlaintextModulus();

    // The same ApproxModDown configuration is used for both components.
    const auto modDown = [&](const DCRTPoly& poly) {
        return poly.ApproxModDown(paramsQl, cryptoParams->GetParamsP(), cryptoParams->GetPInvModq(),
                                  cryptoParams->GetPInvModqPrecon(), cryptoParams->GetPHatInvModp(),
                                  cryptoParams->GetPHatInvModpPrecon(), cryptoParams->GetPHatModq(),
                                  cryptoParams->GetModqBarrettMu(), cryptoParams->GettInvModp(),
                                  cryptoParams->GettInvModpPrecon(), t, cryptoParams->GettModqPrecon());
    };

    std::vector<DCRTPoly> reduced(2);
    reduced[0] = modDown(components[0]);
    reduced[1] = modDown(components[1]);

    auto output = ciphertext->CloneEmpty();
    output->SetElements(std::move(reduced));
    return output;
}

// Brings only the first component of `ciphertext` from basis QlP down to Ql using
// ApproxModDown and returns it. Ql is rebuilt from the leading
// (size(QlP) - size(P)) towers of that component.
DCRTPoly KeySwitchHYBRID::KeySwitchDownFirstElement(ConstCiphertext<DCRTPoly> ciphertext) const {
    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(ciphertext->GetCryptoParameters());
    const auto& first       = ciphertext->GetElements()[0];
    const auto paramsQlP    = first.GetParams();
    const auto paramsP      = cryptoParams->GetParamsP();

    // TODO : (Andrey) precompute paramsQl in cryptoparameters
    const auto& towersQlP = paramsQlP->GetParams();
    const uint32_t sizeQl = towersQlP.size() - paramsP->GetParams().size();
    std::vector<NativeInteger> moduliQ(sizeQl);
    std::vector<NativeInteger> rootsQ(sizeQl);
    for (uint32_t idx = 0; idx < sizeQl; ++idx) {
        moduliQ[idx] = towersQlP[idx]->GetModulus();
        rootsQ[idx]  = towersQlP[idx]->GetRootOfUnity();
    }
    const auto paramsQl = std::make_shared<ParmType>(paramsQlP->GetCyclotomicOrder(), moduliQ, rootsQ);

    // t is 0 when the noise scale is 1; otherwise it is the plaintext modulus.
    const PlaintextModulus t = (cryptoParams->GetNoiseScale() == 1) ? 0 : cryptoParams->GetPlaintextModulus();

    return first.ApproxModDown(paramsQl, cryptoParams->GetParamsP(), cryptoParams->GetPInvModq(),
                               cryptoParams->GetPInvModqPrecon(), cryptoParams->GetPHatInvModp(),
                               cryptoParams->GetPHatInvModpPrecon(), cryptoParams->GetPHatModq(),
                               cryptoParams->GetModqBarrettMu(), cryptoParams->GettInvModp(),
                               cryptoParams->GettInvModpPrecon(), t, cryptoParams->GettModqPrecon());
}

// Full hybrid key-switching core on a single polynomial `a`. It first runs the digit
// decomposition (EvalKeySwitchPrecomputeCore), then the key inner product plus mod-down
// back to a's basis (EvalFastKeySwitchCore). Returns the resulting pair of polynomials.
std::shared_ptr<std::vector<DCRTPoly>> KeySwitchHYBRID::KeySwitchCore(const DCRTPoly& a,
                                                                      const EvalKey<DCRTPoly> evalKey) const {
    const auto decomposed = EvalKeySwitchPrecomputeCore(a, evalKey->GetCryptoParameters());
    return EvalFastKeySwitchCore(decomposed, evalKey, a.GetParams());
}

/**
 * Digit-decomposition (precomputation) step of hybrid key switching.
 *
 * Splits the towers of c (basis Ql) into digits of alpha = GetNumPerPartQ() consecutive
 * towers each; the last digit may be shorter. For each digit it builds a DCRTPoly over the
 * extended basis QlP. The towers belonging to the digit are copied unchanged from c. All
 * other towers come from ApproxSwitchCRTBasis, which maps the digit to its complement basis
 * (the QlP towers outside the digit).
 *
 * @param c                 polynomial to decompose; its params define Ql.
 *                          NOTE(review): presumably expected in EVALUATION format (its towers
 *                          are placed into EVALUATION-format polys before the switch to
 *                          COEFFICIENT) — confirm.
 * @param cryptoParamsBase  crypto parameters; dynamically cast to CryptoParametersRNS
 * @return numPartQl digit polynomials, each over QlP in EVALUATION format
 */
std::shared_ptr<std::vector<DCRTPoly>> KeySwitchHYBRID::EvalKeySwitchPrecomputeCore(
    const DCRTPoly& c, std::shared_ptr<CryptoParametersBase<DCRTPoly>> cryptoParamsBase) const {
    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersRNS>(cryptoParamsBase);

    const auto paramsQl  = c.GetParams();
    const auto paramsP   = cryptoParams->GetParamsP();
    const auto paramsQlP = c.GetExtendedCRTBasis(paramsP);

    const uint32_t sizeQl  = paramsQl->GetParams().size();
    const uint32_t sizeP   = paramsP->GetParams().size();
    const uint32_t sizeQlP = sizeQl + sizeP;
    const uint32_t alpha   = cryptoParams->GetNumPerPartQ();
    // The number of digits of the current ciphertext
    uint32_t numPartQl = std::ceil(static_cast<double>(sizeQl) / alpha);
    // Never exceed the number of Q partitions the parameters were generated for.
    if (numPartQl > cryptoParams->GetNumberOfQPartitions())
        numPartQl = cryptoParams->GetNumberOfQPartitions();

    auto result = std::make_shared<std::vector<DCRTPoly>>(numPartQl);

    // Digit decomposition
    // Zero-padding and split
    for (uint32_t part = 0; part < numPartQl; ++part) {
        DCRTPoly partsCt;
        if (part == numPartQl - 1) {
            // The last digit may hold fewer than alpha towers at the current level, so its
            // params keep only the first sizePartQl moduli/roots of partition `part`.
            const auto& paramsPartQ = cryptoParams->GetParamsPartQ(part);

            uint32_t sizePartQl = sizeQl - alpha * part;
            std::vector<NativeInteger> moduli(sizePartQl);
            std::vector<NativeInteger> roots(sizePartQl);
            for (uint32_t i = 0; i < sizePartQl; ++i) {
                moduli[i] = paramsPartQ->GetParams()[i]->GetModulus();
                roots[i]  = paramsPartQ->GetParams()[i]->GetRootOfUnity();
            }
            auto&& params = std::make_shared<ParmType>(paramsPartQ->GetCyclotomicOrder(), moduli, roots);
            partsCt       = DCRTPoly(params, Format::EVALUATION, true);
        }
        else {
            partsCt = DCRTPoly(cryptoParams->GetParamsPartQ(part), Format::EVALUATION, true);
        }

        // Copy this digit's towers [startPartIdx, startPartIdx + sizePartQl) out of c.
        const uint32_t sizePartQl   = partsCt.GetNumOfElements();
        const uint32_t startPartIdx = alpha * part;
        for (uint32_t i = 0, idx = startPartIdx; i < sizePartQl; ++i, ++idx)
            partsCt.SetElementAtIndex(i, c.GetElementAtIndex(idx));

        // Extend the digit (in COEFFICIENT form) to its complement basis. The precomputed
        // tables are selected by partition `part`, the current level (sizeQl - 1) and the
        // digit length (sizePartQl - 1).
        partsCt.SetFormat(Format::COEFFICIENT);
        auto partsCtCompl = partsCt.ApproxSwitchCRTBasis(cryptoParams->GetParamsPartQ(part),
                                                         cryptoParams->GetParamsComplPartQ(sizeQl - 1, part),
                                                         cryptoParams->GetPartQlHatInvModq(part, sizePartQl - 1),
                                                         cryptoParams->GetPartQlHatInvModqPrecon(part, sizePartQl - 1),
                                                         cryptoParams->GetPartQlHatModp(sizeQl - 1, part),
                                                         cryptoParams->GetmodComplPartqBarrettMu(sizeQl - 1, part));
        partsCtCompl.SetFormat(Format::EVALUATION);

        (*result)[part] = DCRTPoly(paramsQlP, Format::EVALUATION, true);

        // Assemble the QlP digit. Complement towers come before the digit, then the digit's
        // original towers from c, then the remaining complement towers. The complement has
        // no slots for the digit, so those last towers sit sizePartQl positions lower in
        // partsCtCompl.
        const uint32_t endPartIdx = startPartIdx + sizePartQl;
        for (uint32_t i = 0; i < startPartIdx; ++i)
            (*result)[part].SetElementAtIndex(i, std::move(partsCtCompl.GetElementAtIndex(i)));
        for (uint32_t i = startPartIdx; i < endPartIdx; ++i)
            (*result)[part].SetElementAtIndex(i, c.GetElementAtIndex(i));
        for (uint32_t i = endPartIdx; i < sizeQlP; ++i)
            (*result)[part].SetElementAtIndex(i, std::move(partsCtCompl.GetElementAtIndex(i - sizePartQl)));
    }
    return result;
}

std::shared_ptr<std::vector<DCRTPoly>> KeySwitchHYBRID::EvalFastKeySwitchCore(
    const std::shared_ptr<std::vector<DCRTPoly>> digits, const EvalKey<DCRTPoly> evalKey,
    const std::shared_ptr<ParmType> paramsQl) const {
    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersRNS>(evalKey->GetCryptoParameters());

    const PlaintextModulus t = (cryptoParams->GetNoiseScale() == 1) ? 0 : cryptoParams->GetPlaintextModulus();

    auto result  = EvalFastKeySwitchCoreExt(digits, evalKey, paramsQl);
    (*result)[0] = (*result)[0].ApproxModDown(paramsQl, cryptoParams->GetParamsP(), cryptoParams->GetPInvModq(),
                                              cryptoParams->GetPInvModqPrecon(), cryptoParams->GetPHatInvModp(),
                                              cryptoParams->GetPHatInvModpPrecon(), cryptoParams->GetPHatModq(),
                                              cryptoParams->GetModqBarrettMu(), cryptoParams->GettInvModp(),
                                              cryptoParams->GettInvModpPrecon(), t, cryptoParams->GettModqPrecon());
    (*result)[1] = (*result)[1].ApproxModDown(paramsQl, cryptoParams->GetParamsP(), cryptoParams->GetPInvModq(),
                                              cryptoParams->GetPInvModqPrecon(), cryptoParams->GetPHatInvModp(),
                                              cryptoParams->GetPHatInvModpPrecon(), cryptoParams->GetPHatModq(),
                                              cryptoParams->GetModqBarrettMu(), cryptoParams->GettInvModp(),
                                              cryptoParams->GettInvModpPrecon(), t, cryptoParams->GettModqPrecon());
    return result;
}

// Computes, over the extended basis QlP, the sums
//   out0 = sum_j digits[j] * b_j   and   out1 = sum_j digits[j] * a_j
// where (a_j, b_j) are the evalKey vectors. The key lives over the full QP basis, so a
// tower index i >= sizeQl (a P tower) is shifted by delta = size(Q) - sizeQl to skip the
// Q towers not present at the current level.
std::shared_ptr<std::vector<DCRTPoly>> KeySwitchHYBRID::EvalFastKeySwitchCoreExt(
    const std::shared_ptr<std::vector<DCRTPoly>> digits, const EvalKey<DCRTPoly> evalKey,
    const std::shared_ptr<ParmType> paramsQl) const {
    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersRNS>(evalKey->GetCryptoParameters());

    const auto paramsQlP     = digits->front().GetParams();
    const uint32_t sizeQlP   = paramsQlP->GetParams().size();
    const uint32_t sizeQl    = paramsQl->GetParams().size();
    const uint32_t numDigits = digits->size();
    const uint32_t delta     = cryptoParams->GetElementParams()->GetParams().size() - sizeQl;

    const auto& keyA = evalKey->GetAVector();
    const auto& keyB = evalKey->GetBVector();

    auto sums = std::make_shared<std::vector<DCRTPoly>>();
    sums->reserve(2);
    sums->emplace_back(paramsQlP, Format::EVALUATION, true);
    sums->emplace_back(paramsQlP, Format::EVALUATION, true);
    auto& sumB = (*sums)[0];
    auto& sumA = (*sums)[1];

    for (uint32_t j = 0; j < numDigits; ++j) {
        const auto& digit = (*digits)[j];
        const auto& bj    = keyB[j];
        const auto& aj    = keyA[j];
#pragma omp parallel for num_threads(OpenFHEParallelControls.GetThreadLimit(sizeQlP))
        for (uint32_t i = 0; i < sizeQlP; ++i) {
            const uint32_t keyIdx = (i < sizeQl) ? i : i + delta;
            const auto& dji       = digit.GetElementAtIndex(i);
            sumB.SetElementAtIndex(i, sumB.GetElementAtIndex(i) + dji * bj.GetElementAtIndex(keyIdx));
            sumA.SetElementAtIndex(i, sumA.GetElementAtIndex(i) + dji * aj.GetElementAtIndex(keyIdx));
        }
    }

    return sums;
}

}  // namespace lbcrypto