Program Listing for File keyswitch-bv.cpp

Return to documentation for file (pke/lib/keyswitch/keyswitch-bv.cpp)

//==================================================================================
// BSD 2-Clause License
//
// Copyright (c) 2014-2022, NJIT, Duality Technologies Inc. and other contributors
//
// All rights reserved.
//
// Author TPOC: contact@openfhe.org
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
//    list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
//    this list of conditions and the following disclaimer in the documentation
//    and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//==================================================================================

#define PROFILE

#include "keyswitch/keyswitch-bv.h"

#include "key/privatekey.h"
#include "key/publickey.h"
#include "key/evalkeyrelin.h"
#include "schemerns/rns-cryptoparameters.h"
#include "cryptocontext.h"

namespace lbcrypto {

/**
 * Builds a BV key-switching key from oldKey to newKey.
 *
 * For every flattened slot j the key holds a pair (a_j, b_j) with
 *   b_j = filtered_j - (a_j * sNew + ns * e_j),
 * where filtered_j is a zero DCRTPoly except for one tower, which carries either the whole
 * tower of sOld (digitSize == 0) or one PowersOfBase(digitSize) piece of that tower.
 * a_j is sampled uniformly (dug), e_j from the Gaussian generator (dgg).
 *
 * @param oldKey secret key being switched away from
 * @param newKey secret key being switched to (also supplies parameters and key tag)
 * @return relinearization-style evaluation key with the A and B vectors filled in
 */
EvalKey<DCRTPoly> KeySwitchBV::KeySwitchGenInternal(const PrivateKey<DCRTPoly> oldKey,
                                                    const PrivateKey<DCRTPoly> newKey) const {
    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersRNS>(newKey->GetCryptoParameters());

    const DCRTPoly& sNew     = newKey->GetPrivateElement();
    const DCRTPoly& sOld     = oldKey->GetPrivateElement();
    const auto elementParams = sNew.GetParams();

    const auto ns      = cryptoParams->GetNoiseScale();
    const DggType& dgg = cryptoParams->GetDiscreteGaussianGenerator();
    DugType dug;

    const usint digitSize = cryptoParams->GetDigitSize();
    const usint numTowers = sOld.GetNumOfElements();

    // towerOffset[t] = index of tower t's first digit in the flattened key vectors.
    // Without digit decomposition every tower is a single slot, so no offsets are needed.
    std::vector<usint> towerOffset;
    usint totalSlots = numTowers;
    if (digitSize > 0) {
        towerOffset.reserve(numTowers);
        totalSlots = 0;
        for (usint t = 0; t < numTowers; ++t) {
            const usint bits = sOld.GetElementAtIndex(t).GetModulus().GetLengthForBase(2);
            towerOffset.push_back(totalSlots);
            totalSlots += bits / digitSize + ((bits % digitSize != 0) ? 1 : 0);
        }
    }

    std::vector<DCRTPoly> av(totalSlots);
    std::vector<DCRTPoly> bv(totalSlots);

    // Fills one slot: sample a, then e, and store (a, filtered - (a * sNew + ns * e)).
    auto fillSlot = [&](usint slot, usint tower, const DCRTPoly::PolyType& piece) {
        DCRTPoly filtered(elementParams, Format::EVALUATION, true);
        filtered.SetElementAtIndex(tower, piece);

        DCRTPoly a(dug, elementParams, Format::EVALUATION);
        DCRTPoly e(dgg, elementParams, Format::EVALUATION);

        av[slot] = std::move(a);
        bv[slot] = filtered - (av[slot] * sNew + ns * e);
    };

    for (usint t = 0; t < numTowers; ++t) {
        if (digitSize == 0) {
            fillSlot(t, t, sOld.GetElementAtIndex(t));
            continue;
        }
        const std::vector<DCRTPoly::PolyType> pieces = sOld.GetElementAtIndex(t).PowersOfBase(digitSize);
        for (size_t k = 0; k < pieces.size(); ++k)
            fillSlot(towerOffset[t] + k, t, pieces[k]);
    }

    EvalKeyRelin<DCRTPoly> ek(std::make_shared<EvalKeyRelinImpl<DCRTPoly>>(newKey->GetCryptoContext()));
    ek->SetAVector(std::move(av));
    ek->SetBVector(std::move(bv));
    ek->SetKeyTag(newKey->GetKeyTag());

    return ek;
}

/**
 * Builds a BV key-switching key from oldKey to newKey, optionally reusing the "a" components
 * of an existing key (threshold HE).
 *
 * sOld is taken from oldKey with its last GetKeyGenLevel() towers dropped. For every flattened
 * slot j the key holds (a_j, b_j) with b_j = filtered_j - (a_j * sNew + ns * e_j), where
 * filtered_j carries one tower (or one PowersOfBase(digitSize) piece of a tower) of sOld.
 *
 * @param oldKey secret key being switched away from (also supplies crypto parameters)
 * @param newKey secret key being switched to (supplies element params and key tag)
 * @param ek     nullptr for single-key HE (a_j sampled uniformly); otherwise its A vector
 *               provides a_j for every slot (threshold HE)
 * @return evaluation key with A and B vectors filled in
 * @throws std::out_of_range if ek's A vector has fewer components than the number of slots
 */
EvalKey<DCRTPoly> KeySwitchBV::KeySwitchGenInternal(const PrivateKey<DCRTPoly> oldKey,
                                                    const PrivateKey<DCRTPoly> newKey,
                                                    const EvalKey<DCRTPoly> ek) const {
    EvalKeyRelin<DCRTPoly> evalKey(std::make_shared<EvalKeyRelinImpl<DCRTPoly>>(newKey->GetCryptoContext()));

    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersRNS>(oldKey->GetCryptoParameters());

    const DCRTPoly& sNew = newKey->GetPrivateElement();
    auto elementParams   = sNew.GetParams();
    // Work on a copy of the old secret with its last GetKeyGenLevel() towers removed.
    DCRTPoly sOld = oldKey->GetPrivateElement();
    sOld.DropLastElements(oldKey->GetCryptoContext()->GetKeyGenLevel());

    const auto ns      = cryptoParams->GetNoiseScale();
    const DggType& dgg = cryptoParams->GetDiscreteGaussianGenerator();
    DugType dug;

    usint digitSize = cryptoParams->GetDigitSize();

    usint sizeSOld = sOld.GetNumOfElements();
    usint nWindows = 0;
    std::vector<usint> arrWindows;
    arrWindows.reserve(sizeSOld);
    if (digitSize > 0) {
        // arrWindows[i] = index of tower i's first digit; tower i contributes
        // ceil(log2(q_i) / digitSize) digits.
        for (usint i = 0; i < sizeSOld; i++) {
            usint sOldMSB    = sOld.GetElementAtIndex(i).GetModulus().GetLengthForBase(2);
            usint curWindows = sOldMSB / digitSize;
            if (sOldMSB % digitSize > 0)
                curWindows++;
            arrWindows.push_back(nWindows);
            nWindows += curWindows;
        }
    }
    else {
        nWindows = sizeSOld;
    }

    std::vector<DCRTPoly> av(nWindows);
    std::vector<DCRTPoly> bv(nWindows);

    // "a" part for slot idx: a fresh uniform sample for single-key HE, or the matching component
    // of the supplied key for threshold HE. at() turns a shared key with too few components
    // into std::out_of_range instead of an out-of-bounds read.
    auto makeA = [&](usint idx) -> DCRTPoly {
        if (ek == nullptr)  // single-key HE
            return DCRTPoly(dug, elementParams, Format::EVALUATION);
        return ek->GetAVector().at(idx);  // threshold HE
    };

    // Fills slot idx: a = makeA(idx), then sample e, then b = filtered - (a * sNew + ns * e).
    auto fillSlot = [&](usint idx, usint tower, const DCRTPoly::PolyType& piece) {
        DCRTPoly filtered(elementParams, Format::EVALUATION, true);
        filtered.SetElementAtIndex(tower, piece);

        av[idx] = makeA(idx);

        DCRTPoly e(dgg, elementParams, Format::EVALUATION);
        bv[idx] = filtered - (av[idx] * sNew + ns * e);
    };

    if (digitSize > 0) {
        for (usint i = 0; i < sizeSOld; i++) {
            std::vector<DCRTPoly::PolyType> sOldDecomposed = sOld.GetElementAtIndex(i).PowersOfBase(digitSize);
            for (usint k = 0; k < sOldDecomposed.size(); k++)
                fillSlot(k + arrWindows[i], i, sOldDecomposed[k]);
        }
    }
    else {
        for (usint i = 0; i < sizeSOld; i++)
            fillSlot(i, i, sOld.GetElementAtIndex(i));
    }

    evalKey->SetAVector(std::move(av));
    evalKey->SetBVector(std::move(bv));
    evalKey->SetKeyTag(newKey->GetKeyTag());

    return evalKey;
}

/**
 * Builds a BV key-switching key from secret key oldSk to the owner of public key newPk,
 * without knowing the new secret.
 *
 * Each component encrypts one tower (or one PowersOfBase(digitSize) piece of a tower) of
 * sOld under newPk = (p0, p1):
 *   b_j = p0 * u + ns * e0 + filtered_j,   a_j = p1 * u + ns * e1,
 * with u Gaussian or ternary depending on the configured secret-key distribution.
 *
 * @param oldSk secret key being switched away from
 * @param newPk public key of the target secret (supplies parameters and key tag)
 * @return evaluation key with A and B vectors filled in
 */
EvalKey<DCRTPoly> KeySwitchBV::KeySwitchGenInternal(const PrivateKey<DCRTPoly> oldSk,
                                                    const PublicKey<DCRTPoly> newPk) const {
    EvalKeyRelin<DCRTPoly> ek = std::make_shared<EvalKeyRelinImpl<DCRTPoly>>(newPk->GetCryptoContext());

    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersRNS>(newPk->GetCryptoParameters());

    const auto ns                = cryptoParams->GetNoiseScale();
    const DCRTPoly::DggType& dgg = cryptoParams->GetDiscreteGaussianGenerator();
    DCRTPoly::TugType tug;
    // Loop-invariant: distribution used for the ephemeral u of every component.
    const bool gaussianU = (cryptoParams->GetSecretKeyDist() == GAUSSIAN);

    const DCRTPoly& sOld = oldSk->GetPrivateElement();
    const usint sizeSOld = sOld.GetNumOfElements();

    uint32_t digitSize = cryptoParams->GetDigitSize();

    const DCRTPoly& newp0 = newPk->GetPublicElements().at(0);
    const DCRTPoly& newp1 = newPk->GetPublicElements().at(1);
    auto elementParams    = newp0.GetParams();

    // Exact number of key components: one per tower without decomposition, otherwise
    // ceil(log2(q_i) / digitSize) per tower i (numTowers * digitSize is not that count).
    size_t nWindows = sizeSOld;
    if (digitSize > 0) {
        nWindows = 0;
        for (usint i = 0; i < sizeSOld; i++) {
            usint sOldMSB = sOld.GetElementAtIndex(i).GetModulus().GetLengthForBase(2);
            nWindows += sOldMSB / digitSize;
            if (sOldMSB % digitSize > 0)
                nWindows++;
        }
    }

    std::vector<DCRTPoly> av;
    std::vector<DCRTPoly> bv;
    av.reserve(nWindows);
    bv.reserve(nWindows);

    // Encrypts one component under newPk and appends it; samples u, then e0, then e1.
    auto appendComponent = [&](usint tower, const DCRTPoly::PolyType& piece) {
        DCRTPoly filtered(elementParams, Format::EVALUATION, true);
        filtered.SetElementAtIndex(tower, piece);

        DCRTPoly u = gaussianU ? DCRTPoly(dgg, elementParams, Format::EVALUATION) :
                                 DCRTPoly(tug, elementParams, Format::EVALUATION);

        DCRTPoly e0(dgg, elementParams, Format::EVALUATION);
        DCRTPoly c0 = newp0 * u + ns * e0 + filtered;
        bv.push_back(std::move(c0));

        DCRTPoly e1(dgg, elementParams, Format::EVALUATION);
        DCRTPoly c1 = newp1 * u + ns * e1;
        av.push_back(std::move(c1));
    };

    if (digitSize > 0) {
        for (usint i = 0; i < sizeSOld; i++) {
            std::vector<DCRTPoly::PolyType> sOldDecomposed = sOld.GetElementAtIndex(i).PowersOfBase(digitSize);
            for (size_t k = 0; k < sOldDecomposed.size(); k++)
                appendComponent(i, sOldDecomposed[k]);
        }
    }
    else {
        for (usint i = 0; i < sizeSOld; i++)
            appendComponent(i, sOld.GetElementAtIndex(i));
    }

    ek->SetAVector(std::move(av));
    ek->SetBVector(std::move(bv));
    ek->SetKeyTag(newPk->GetKeyTag());

    return ek;
}

/**
 * Key-switches a ciphertext in place and leaves it with exactly two elements.
 *
 * Element 1 is switched when the ciphertext has exactly two elements; otherwise element 2 is.
 * The switched pair (d0, d1) is folded back as c0 += d0 and either c1 += d1 (more than two
 * elements) or c1 = d1 (two elements).
 *
 * @param ciphertext ciphertext modified in place
 * @param ek         key-switching key passed to KeySwitchCore
 */
void KeySwitchBV::KeySwitchInPlace(Ciphertext<DCRTPoly>& ciphertext, const EvalKey<DCRTPoly> ek) const {
    std::vector<DCRTPoly>& elements = ciphertext->GetElements();
    const bool twoElements          = (elements.size() == 2);

    const auto switched    = KeySwitchCore(twoElements ? elements[1] : elements[2], ek);
    const DCRTPoly& delta0 = (*switched)[0];
    const DCRTPoly& delta1 = (*switched)[1];

    elements[0].SetFormat(Format::EVALUATION);
    elements[0] += delta0;

    if (elements.size() <= 2) {
        elements[1] = delta1;
    }
    else {
        elements[1].SetFormat(Format::EVALUATION);
        elements[1] += delta1;
    }

    elements.resize(2);
}

/**
 * Key-switches a single polynomial: decomposes `a` with EvalKeySwitchPrecomputeCore and then
 * combines the digits with evalKey via EvalFastKeySwitchCore at a's own parameters.
 *
 * @param a       polynomial to switch
 * @param evalKey key-switching key (also supplies the crypto parameters)
 * @return the two resulting polynomials, as produced by EvalFastKeySwitchCore
 */
std::shared_ptr<std::vector<DCRTPoly>> KeySwitchBV::KeySwitchCore(const DCRTPoly& a,
                                                                  const EvalKey<DCRTPoly> evalKey) const {
    auto digits = EvalKeySwitchPrecomputeCore(a, evalKey->GetCryptoParameters());
    return EvalFastKeySwitchCore(std::move(digits), evalKey, a.GetParams());
}

std::shared_ptr<std::vector<DCRTPoly>> KeySwitchBV::EvalKeySwitchPrecomputeCore(
    const DCRTPoly& c, std::shared_ptr<CryptoParametersBase<DCRTPoly>> cryptoParamsBase) const {
    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersRNS>(cryptoParamsBase);
    return std::make_shared<std::vector<DCRTPoly>>(c.CRTDecompose(cryptoParams->GetDigitSize()));
}

std::shared_ptr<std::vector<DCRTPoly>> KeySwitchBV::EvalFastKeySwitchCore(
    const std::shared_ptr<std::vector<DCRTPoly>> digits, const EvalKey<DCRTPoly> evalKey,
    const std::shared_ptr<ParmType> paramsQl) const {
    std::vector<DCRTPoly> bv(evalKey->GetBVector());
    std::vector<DCRTPoly> av(evalKey->GetAVector());

    auto sizeQ    = bv[0].GetParams()->GetParams().size();
    auto sizeQl   = paramsQl->GetParams().size();
    size_t diffQl = sizeQ - sizeQl;

    for (size_t k = 0; k < bv.size(); k++) {
        av[k].DropLastElements(diffQl);
        bv[k].DropLastElements(diffQl);
    }

    DCRTPoly ct1 = (av[0] *= (*digits)[0]);
    DCRTPoly ct0 = (bv[0] *= (*digits)[0]);

    for (usint i = 1; i < (*digits).size(); ++i) {
        ct0 += (bv[i] *= (*digits)[i]);
        ct1 += (av[i] *= (*digits)[i]);
    }

    return std::make_shared<std::vector<DCRTPoly>>(std::initializer_list<DCRTPoly>{std::move(ct0), std::move(ct1)});
}

}  // namespace lbcrypto