Program Listing for File cryptocontext.cpp

Return to documentation for file (pke/lib/cryptocontext.cpp)

//==================================================================================
// BSD 2-Clause License
//
// Copyright (c) 2014-2022, NJIT, Duality Technologies Inc. and other contributors
//
// All rights reserved.
//
// Author TPOC: contact@openfhe.org
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
//    list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
//    this list of conditions and the following disclaimer in the documentation
//    and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//==================================================================================

/*
  Control for encryption operations
 */

#include "cryptocontext.h"
#include "key/privatekey.h"
#include "key/publickey.h"
#include "math/chebyshev.h"
#include "scheme/ckksrns/ckksrns-cryptoparameters.h"
#include "schemerns/rns-scheme.h"

namespace lbcrypto {

// Process-wide registry of EvalMult key vectors, keyed by the key tag of the secret key
// the vectors were generated for (see EvalMultKeyGen / EvalMultKeysGen below).
template <typename Element>
std::map<std::string, std::vector<EvalKey<Element>>> CryptoContextImpl<Element>::s_evalMultKeyMap{};

// Process-wide registry of automorphism (rotation/summation) keys: key tag -> shared map of
// automorphism index -> EvalKey. The inner maps are held by shared_ptr so accessors such as
// GetEvalAutomorphismKeyMapPtr can hand them out without copying.
template <typename Element>
std::map<std::string, std::shared_ptr<std::map<uint32_t, EvalKey<Element>>>>
    CryptoContextImpl<Element>::s_evalAutomorphismKeyMap{};

// Resets the process-wide state kept for this Element type:
//  - all stored automorphism keys, then all stored EvalMult keys;
//  - PackedEncoding's static data (PackedEncoding::Destroy());
//  - the cached transform data of every math backend compiled into the library.
template <typename Element>
void CryptoContextImpl<Element>::ClearStaticMapsAndVectors() {
    ClearEvalAutomorphismKeys();
    ClearEvalMultKeys();

    PackedEncoding::Destroy();

    // One Reset() per enabled math backend; the native backend is always present.
    NatChineseRemainderTransformFTT<NativeVector>().Reset();
#ifdef WITH_BE2
    bigintfxd::ChineseRemainderTransformFTTFxd<M2Vector>().Reset();
#endif
#ifdef WITH_BE4
    bigintdyn::ChineseRemainderTransformFTTDyn<M4Vector>().Reset();
#endif
#ifdef WITH_NTL
    NTL::ChineseRemainderTransformFTTNtl<M6Vector>().Reset();
#endif
}

// Propagates the key-switching technique chosen in the crypto parameters into the scheme.
// Both the scheme and the parameter object must be RNS-based; otherwise this throws.
template <typename Element>
void CryptoContextImpl<Element>::SetKSTechniqueInScheme() {
    const auto rnsScheme = std::dynamic_pointer_cast<SchemeRNS>(m_scheme);
    if (!rnsScheme)
        OPENFHE_THROW("The scheme is not RNS-based");

    const auto rnsParams = std::dynamic_pointer_cast<const CryptoParametersRNS>(m_params);
    if (!rnsParams)
        OPENFHE_THROW("The parameter object is not RNS-based");

    const auto technique = rnsParams->GetKeySwitchTechnique();
    rnsScheme->SetKeySwitchingTechnique(technique);
}

// SHE MULTIPLICATION
// Generates a single EvalMult key for `key` and stores it (as a one-element vector) under the
// key's tag. Does nothing if keys are already stored for that tag.
template <typename Element>
void CryptoContextImpl<Element>::EvalMultKeyGen(const PrivateKey<Element>& key) {
    ValidateKey(key);

    auto& registry = s_evalMultKeyMap;
    if (registry.count(key->GetKeyTag()) != 0)
        return;  // already generated for this tag

    registry[key->GetKeyTag()] = {GetScheme()->EvalMultKeyGen(key)};
}

// Generates the vector of EvalMult keys returned by the scheme's EvalMultKeysGen for `key` and
// stores it under the key's tag. Does nothing if keys are already stored for that tag.
template <typename Element>
void CryptoContextImpl<Element>::EvalMultKeysGen(const PrivateKey<Element>& key) {
    ValidateKey(key);

    auto& registry = s_evalMultKeyMap;
    if (registry.count(key->GetKeyTag()) != 0)
        return;  // already generated for this tag

    registry[key->GetKeyTag()] = GetScheme()->EvalMultKeysGen(key);
}

// Drops every stored EvalMult key vector, for all key tags.
template <typename Element>
void CryptoContextImpl<Element>::ClearEvalMultKeys() {
    s_evalMultKeyMap.clear();
}

// Drops the EvalMult key vector stored under `keyTag`; a no-op if the tag is unknown.
template <typename Element>
void CryptoContextImpl<Element>::ClearEvalMultKeys(const std::string& keyTag) {
    // erase-by-key removes the entry if present and does nothing otherwise
    s_evalMultKeyMap.erase(keyTag);
}

// Drops every stored EvalMult key vector whose first key was generated in crypto context `cc`.
// An entry is attributed to a context through its first key only. Entries that are empty, or
// whose first key is null, cannot be attributed to any context, so they are kept.
template <typename Element>
void CryptoContextImpl<Element>::ClearEvalMultKeys(const CryptoContext<Element>& cc) {
    for (auto it = CryptoContextImpl<Element>::s_evalMultKeyMap.begin();
         it != CryptoContextImpl<Element>::s_evalMultKeyMap.end();) {
        const auto& keys = it->second;
        // Guard against indexing an empty vector or dereferencing a null key.
        // Both were undefined behavior before.
        const bool belongsToCc = !keys.empty() && keys[0] != nullptr && keys[0]->GetCryptoContext() == cc;
        if (belongsToCc) {
            it = CryptoContextImpl<Element>::s_evalMultKeyMap.erase(it);
        }
        else {
            ++it;
        }
    }
}

// Stores `vectorToInsert` as the EvalMult key vector for `keyTag`.
// When `keyTag` is empty, the tag of the first key in the vector is used instead.
// Throws if:
//  - the tag cannot be determined (empty vector and empty keyTag), or
//  - a key vector is already stored under the tag. Existing vectors are never overridden.
template <typename Element>
void CryptoContextImpl<Element>::InsertEvalMultKey(const std::vector<EvalKey<Element>>& vectorToInsert,
                                                   const std::string& keyTag) {
    // Reading vectorToInsert[0] from an empty vector was undefined behavior.
    if (keyTag.empty() && vectorToInsert.empty())
        OPENFHE_THROW("Can not save an empty EvalMultKeys vector without a keyTag");

    const std::string& tag = (keyTag.empty()) ? vectorToInsert[0]->GetKeyTag() : keyTag;

    // try_emplace inserts only if the tag is absent (single lookup, no copy when it is present)
    if (!CryptoContextImpl<Element>::s_evalMultKeyMap.try_emplace(tag, vectorToInsert).second) {
        // we do not allow to override the existing key vector if its keyTag is identical to the keyTag of the new keys
        OPENFHE_THROW("Can not save a EvalMultKeys vector as there is a key vector for the given keyTag");
    }
}

// ADVANCED SHE

// Generates the automorphism keys needed by EvalSum for `privateKey` and merges them into the
// automorphism key map stored under the key's tag.
template <typename Element>
void CryptoContextImpl<Element>::EvalSumKeyGen(const PrivateKey<Element> privateKey) {
    ValidateKey(privateKey);
    auto sumKeys = GetScheme()->EvalSumKeyGen(privateKey);
    InsertEvalAutomorphismKey(sumKeys, privateKey->GetKeyTag());
}

// Generates the automorphism keys needed by EvalSumRows for `privateKey`, stores them under the
// key's tag, and returns a new map holding only the keys for the indices the scheme reported.
template <typename Element>
std::shared_ptr<std::map<uint32_t, EvalKey<Element>>> CryptoContextImpl<Element>::EvalSumRowsKeyGen(
    const PrivateKey<Element> privateKey, uint32_t rowSize, uint32_t subringDim) {
    ValidateKey(privateKey);

    std::vector<uint32_t> generatedIndices;  // filled by the scheme call below
    auto rowKeys = GetScheme()->EvalSumRowsKeyGen(privateKey, rowSize, subringDim, generatedIndices);

    const std::string& tag = privateKey->GetKeyTag();
    InsertEvalAutomorphismKey(rowKeys, tag);
    return GetPartialEvalAutomorphismKeyMapPtr(tag, generatedIndices);
}

// TODO: this is here for backwards compatibility; should remove in v2.0
// Legacy overload: `publicKey` is ignored. Forwards to EvalSumRowsKeyGen(privateKey, rowSize, subringDim).
// Intentionally not declared `inline`. This out-of-line template member is defined only in this
// .cpp, and an inline declaration visible in just one translation unit is ill-formed (no
// diagnostic required) for every other TU that uses the function.
template <typename Element>
std::shared_ptr<std::map<uint32_t, EvalKey<Element>>> CryptoContextImpl<Element>::EvalSumRowsKeyGen(
    const PrivateKey<Element> privateKey, [[maybe_unused]] const PublicKey<Element> publicKey, uint32_t rowSize,
    uint32_t subringDim) {
    return CryptoContextImpl<Element>::EvalSumRowsKeyGen(privateKey, rowSize, subringDim);
}

// Generates the automorphism keys needed by EvalSumCols for `privateKey`, stores them under the
// key's tag, and returns a new map holding only the keys for the indices the scheme reported.
template <typename Element>
std::shared_ptr<std::map<uint32_t, EvalKey<Element>>> CryptoContextImpl<Element>::EvalSumColsKeyGen(
    const PrivateKey<Element> privateKey) {
    ValidateKey(privateKey);

    std::vector<uint32_t> generatedIndices;  // filled by the scheme call below
    auto colKeys = GetScheme()->EvalSumColsKeyGen(privateKey, generatedIndices);

    const std::string& tag = privateKey->GetKeyTag();
    InsertEvalAutomorphismKey(colKeys, tag);
    return GetPartialEvalAutomorphismKeyMapPtr(tag, generatedIndices);
}

// EvalSum keys are stored as automorphism keys, so this is an alias of GetEvalAutomorphismKeyMap.
template <typename Element>
const std::map<uint32_t, EvalKey<Element>>& CryptoContextImpl<Element>::GetEvalSumKeyMap(const std::string& keyTag) {
    const auto& automorphismKeys = GetEvalAutomorphismKeyMap(keyTag);
    return automorphismKeys;
}

// Returns a mutable reference to the whole EvalMult key registry (tag -> key vector).
template <typename Element>
std::map<std::string, std::vector<EvalKey<Element>>>& CryptoContextImpl<Element>::GetAllEvalMultKeys() {
    auto& registry = s_evalMultKeyMap;
    return registry;
}

// Returns the EvalMult key vector stored under `keyTag`. Throws if none was generated/inserted.
template <typename Element>
const std::vector<EvalKey<Element>>& CryptoContextImpl<Element>::GetEvalMultKeyVector(const std::string& keyTag) {
    const auto found = s_evalMultKeyMap.find(keyTag);
    const bool missing = (found == s_evalMultKeyMap.end());
    if (missing) {
        OPENFHE_THROW(std::string("Call EvalMultKeyGen() to have EvalMultKey available for ID [") + keyTag + "].");
    }
    return found->second;
}

// Returns a mutable reference to the whole automorphism key registry (tag -> shared index map).
template <typename Element>
std::map<std::string, std::shared_ptr<std::map<uint32_t, EvalKey<Element>>>>&
CryptoContextImpl<Element>::GetAllEvalAutomorphismKeys() {
    auto& registry = s_evalAutomorphismKeyMap;
    return registry;
}

// Returns the shared automorphism key map stored under `keyTag` (no copy is made).
// Throws if no automorphism keys exist for that tag.
template <typename Element>
std::shared_ptr<std::map<uint32_t, EvalKey<Element>>> CryptoContextImpl<Element>::GetEvalAutomorphismKeyMapPtr(
    const std::string& keyTag) {
    const auto found = s_evalAutomorphismKeyMap.find(keyTag);
    const bool missing = (found == s_evalAutomorphismKeyMap.end());
    if (missing) {
        OPENFHE_THROW("EvalAutomorphismKeys are not generated for ID [" + keyTag + "].");
    }
    return found->second;
}

// Returns a NEW map holding only the automorphism keys for `indexList`, taken from the map
// stored under `keyTag`. The keys themselves are shared (EvalKey handles are copied, not cloned).
// Throws if:
//  - indexList is empty;
//  - no keys exist for keyTag;
//  - any requested index has no key.
template <typename Element>
std::shared_ptr<std::map<uint32_t, EvalKey<Element>>> CryptoContextImpl<Element>::GetPartialEvalAutomorphismKeyMapPtr(
    const std::string& keyTag, const std::vector<uint32_t>& indexList) {
    if (indexList.empty())
        OPENFHE_THROW("indexList is empty");

    const std::shared_ptr<std::map<uint32_t, EvalKey<Element>>> keyMap =
        CryptoContextImpl<Element>::GetEvalAutomorphismKeyMapPtr(keyTag);

    // Build directly into the shared result. The old code filled a local map and then
    // copied it in full into make_shared.
    auto retMap = std::make_shared<std::map<uint32_t, EvalKey<Element>>>();
    for (const uint32_t indx : indexList) {
        const auto it = keyMap->find(indx);
        if (it == keyMap->end()) {
            OPENFHE_THROW("Key is not generated for index [" + std::to_string(indx) + "] and keyTag [" + keyTag + "]");
        }
        retMap->emplace(indx, it->second);
    }
    return retMap;
}

// EvalSum keys live in the automorphism key registry; this returns that registry.
template <typename Element>
std::map<std::string, std::shared_ptr<std::map<uint32_t, EvalKey<Element>>>>&
CryptoContextImpl<Element>::GetAllEvalSumKeys() {
    return s_evalAutomorphismKeyMap;
}

// EvalSum keys are automorphism keys: clears the entire automorphism key registry.
template <typename Element>
void CryptoContextImpl<Element>::ClearEvalSumKeys() {
    ClearEvalAutomorphismKeys();
}

// EvalSum keys are automorphism keys: clears the automorphism keys stored under `keyTag`.
template <typename Element>
void CryptoContextImpl<Element>::ClearEvalSumKeys(const std::string& keyTag) {
    ClearEvalAutomorphismKeys(keyTag);
}

// EvalSum keys are automorphism keys: clears the automorphism keys that belong to context `cc`.
template <typename Element>
void CryptoContextImpl<Element>::ClearEvalSumKeys(const CryptoContext<Element> cc) {
    ClearEvalAutomorphismKeys(cc);
}

// SHE AUTOMORPHISM

// Generates rotation (EvalAtIndex) keys for every index in `indexList` and merges them into the
// automorphism key map stored under the private key's tag.
template <typename Element>
void CryptoContextImpl<Element>::EvalAtIndexKeyGen(const PrivateKey<Element> privateKey,
                                                   const std::vector<int32_t>& indexList) {
    ValidateKey(privateKey);
    auto rotationKeys = GetScheme()->EvalAtIndexKeyGen(privateKey, indexList);
    InsertEvalAutomorphismKey(rotationKeys, privateKey->GetKeyTag());
}

// Drops every stored automorphism key map, for all key tags.
template <typename Element>
void CryptoContextImpl<Element>::ClearEvalAutomorphismKeys() {
    s_evalAutomorphismKeyMap.clear();
}

// Drops the automorphism key map stored under `keyTag`; a no-op if the tag is unknown.
template <typename Element>
void CryptoContextImpl<Element>::ClearEvalAutomorphismKeys(const std::string& keyTag) {
    // erase-by-key removes the entry if present and does nothing otherwise
    s_evalAutomorphismKeyMap.erase(keyTag);
}

// Drops every stored automorphism key map whose first key was generated in crypto context `cc`.
// An entry is attributed to a context through its first key only. Null or empty maps, and maps
// whose first key is null, cannot be attributed to any context, so they are kept.
template <typename Element>
void CryptoContextImpl<Element>::ClearEvalAutomorphismKeys(const CryptoContext<Element> cc) {
    for (auto it = CryptoContextImpl<Element>::s_evalAutomorphismKeyMap.begin();
         it != CryptoContextImpl<Element>::s_evalAutomorphismKeyMap.end();) {
        const auto& keyMap = it->second;
        // Guard against dereferencing a null map, begin() of an empty map (== end()), or a null key.
        // All were undefined behavior before.
        const bool belongsToCc = keyMap != nullptr && !keyMap->empty() && keyMap->begin()->second != nullptr &&
                                 keyMap->begin()->second->GetCryptoContext() == cc;
        if (belongsToCc) {
            it = CryptoContextImpl<Element>::s_evalAutomorphismKeyMap.erase(it);
        }
        else {
            ++it;
        }
    }
}

// Returns the set of automorphism indices that already have a key under `keyTag`.
// Returns an empty set when nothing is stored for that tag.
template <typename Element>
std::set<uint32_t> CryptoContextImpl<Element>::GetExistingEvalAutomorphismKeyIndices(const std::string& keyTag) {
    std::set<uint32_t> indices;

    const auto found = s_evalAutomorphismKeyMap.find(keyTag);
    if (found == s_evalAutomorphismKeyMap.end())
        return indices;

    // std::map iterates in ascending key order, so appending at end() is always the right hint
    for (const auto& entry : *(found->second))
        indices.emplace_hint(indices.end(), entry.first);

    return indices;
}

// Returns the values of `newValues` that do not appear in `oldValues` (set difference new \ old).
template <typename Element>
std::set<uint32_t> CryptoContextImpl<Element>::GetUniqueValues(const std::set<uint32_t>& oldValues,
                                                               const std::set<uint32_t>& newValues) {
    std::set<uint32_t> onlyInNew;
    for (const uint32_t value : newValues) {
        if (oldValues.find(value) == oldValues.end())
            onlyInNew.emplace_hint(onlyInNew.end(), value);  // newValues is sorted, so end() is the right hint
    }
    return onlyInNew;
}

// Merges `mapToInsert` into the automorphism key map stored under `keyTag`.
// When `keyTag` is empty, the tag of the first key in `mapToInsert` is used instead.
//  - A null or empty `mapToInsert` is a no-op.
//  - If nothing (or only a null/empty map) is stored under the tag, `mapToInsert` itself is
//    stored. The shared_ptr is shared, not copied.
//  - Otherwise only the indices missing from the stored map are added. Existing keys are never
//    overwritten.
template <typename Element>
void CryptoContextImpl<Element>::InsertEvalAutomorphismKey(
    const std::shared_ptr<std::map<uint32_t, EvalKey<Element>>> mapToInsert, const std::string& keyTag) {
    // nothing to insert (a null map used to be dereferenced here)
    if (mapToInsert == nullptr || mapToInsert->empty()) {
        return;
    }

    const std::string& id = (keyTag.empty()) ? mapToInsert->begin()->second->GetKeyTag() : keyTag;

    auto keyMapIt = CryptoContextImpl<Element>::s_evalAutomorphismKeyMap.find(id);
    if (keyMapIt == CryptoContextImpl<Element>::s_evalAutomorphismKeyMap.end() || keyMapIt->second == nullptr ||
        keyMapIt->second->empty()) {
        // there are no keys for the given id, so we insert the full mapToInsert
        CryptoContextImpl<Element>::s_evalAutomorphismKeyMap[id] = mapToInsert;
        return;
    }

    // Already the stored map: nothing new to add. Range insert must not take iterators into
    // the destination map.
    if (keyMapIt->second == mapToInsert)
        return;

    // std::map::insert(first, last) adds only keys that are absent and leaves existing entries
    // untouched. This equals "insert the indices of mapToInsert not already in the existing map".
    keyMapIt->second->insert(mapToInsert->begin(), mapToInsert->end());
}

// Sums the first `batchSize` slots of `ciphertext` using the EvalSum (automorphism) keys stored
// under the ciphertext's key tag.
template <typename Element>
Ciphertext<Element> CryptoContextImpl<Element>::EvalSum(ConstCiphertext<Element>& ciphertext,
                                                        uint32_t batchSize) const {
    ValidateCiphertext(ciphertext);
    const auto& sumKeys = GetEvalAutomorphismKeyMap(ciphertext->GetKeyTag());
    const auto& scheme  = GetScheme();
    return scheme->EvalSum(ciphertext, batchSize, sumKeys);
}

// Row-wise summation of `ciphertext` with caller-supplied keys (e.g. from EvalSumRowsKeyGen).
template <typename Element>
Ciphertext<Element> CryptoContextImpl<Element>::EvalSumRows(ConstCiphertext<Element>& ciphertext, uint32_t numRows,
                                                            const std::map<uint32_t, EvalKey<Element>>& evalSumKeys,
                                                            uint32_t subringDim) const {
    ValidateCiphertext(ciphertext);
    const auto& scheme = GetScheme();
    return scheme->EvalSumRows(ciphertext, numRows, evalSumKeys, subringDim);
}

// Column-wise summation of `ciphertext`. Uses the stored automorphism keys for the ciphertext's
// tag together with the caller-supplied `evalSumKeysRight` (e.g. from EvalSumColsKeyGen).
template <typename Element>
Ciphertext<Element> CryptoContextImpl<Element>::EvalSumCols(
    ConstCiphertext<Element>& ciphertext, uint32_t numCols,
    const std::map<uint32_t, EvalKey<Element>>& evalSumKeysRight) const {
    ValidateCiphertext(ciphertext);
    const auto& storedKeys = GetEvalAutomorphismKeyMap(ciphertext->GetKeyTag());
    const auto& scheme     = GetScheme();
    return scheme->EvalSumCols(ciphertext, numCols, storedKeys, evalSumKeysRight);
}

// Rotates `ciphertext` by `index` slots using the automorphism keys stored under its key tag.
template <typename Element>
Ciphertext<Element> CryptoContextImpl<Element>::EvalAtIndex(ConstCiphertext<Element>& ciphertext, int32_t index) const {
    ValidateCiphertext(ciphertext);

    // A zero rotation is the identity. Return a copy BEFORE the key lookup, so index 0 works
    // even when no automorphism keys have been generated for this tag.
    if (index == 0)
        return ciphertext->Clone();

    const auto& rotationKeys = GetEvalAutomorphismKeyMap(ciphertext->GetKeyTag());
    return GetScheme()->EvalAtIndex(ciphertext, index, rotationKeys);
}

// Merges the ciphertexts in `ciphertextVector` into one ciphertext. Uses the automorphism keys
// stored under the key tag of the first ciphertext. Throws if the vector is empty or the first
// ciphertext is invalid for this context.
template <typename Element>
Ciphertext<Element> CryptoContextImpl<Element>::EvalMerge(
    const std::vector<Ciphertext<Element>>& ciphertextVector) const {
    if (ciphertextVector.empty())
        OPENFHE_THROW("Input ciphertext vector is empty");
    ValidateCiphertext(ciphertextVector[0]);
    // Bind by const reference. A plain `auto` copied the entire key map, with one shared_ptr
    // refcount increment per key.
    const auto& evalAutomorphismKeys =
        CryptoContextImpl<Element>::GetEvalAutomorphismKeyMap(ciphertextVector[0]->GetKeyTag());
    return GetScheme()->EvalMerge(ciphertextVector, evalAutomorphismKeys);
}

// Inner product of two ciphertexts over `batchSize` slots. Both must share a key tag.
// Uses the stored automorphism keys and the first stored EvalMult key for that tag.
template <typename Element>
Ciphertext<Element> CryptoContextImpl<Element>::EvalInnerProduct(ConstCiphertext<Element>& ct1,
                                                                 ConstCiphertext<Element>& ct2,
                                                                 uint32_t batchSize) const {
    ValidateCiphertext(ct1);
    const bool sameTag = (ct2 != nullptr) && (ct1->GetKeyTag() == ct2->GetKeyTag());
    if (!sameTag)
        OPENFHE_THROW("Information was not generated with this crypto context");

    // lookup order is kept: automorphism keys first, then EvalMult keys
    const auto& sumKeys  = GetEvalAutomorphismKeyMap(ct1->GetKeyTag());
    const auto& multKeys = GetEvalMultKeyVector(ct1->GetKeyTag());
    return GetScheme()->EvalInnerProduct(ct1, ct2, batchSize, sumKeys, multKeys[0]);
}

// Inner product of a ciphertext and a plaintext over `batchSize` slots. Uses the automorphism
// keys stored under the ciphertext's key tag.
template <typename Element>
Ciphertext<Element> CryptoContextImpl<Element>::EvalInnerProduct(ConstCiphertext<Element>& ct1, ConstPlaintext& ct2,
                                                                 uint32_t batchSize) const {
    ValidateCiphertext(ct1);
    if (!ct2)
        OPENFHE_THROW("Information was not generated with this crypto context");

    const auto& sumKeys = GetEvalAutomorphismKeyMap(ct1->GetKeyTag());
    return GetScheme()->EvalInnerProduct(ct1, ct2, batchSize, sumKeys);
}

// Generic version: creates the empty plaintext that decryption writes into.
//  - CKKS packed encoding: built directly on the ciphertext's element params `evp`.
//  - Other encodings: built on a NativePoly ring with evp's cyclotomic order and the plaintext
//    modulus from `ep`.
template <typename Element>
Plaintext CryptoContextImpl<Element>::GetPlaintextForDecrypt(PlaintextEncodings pte, std::shared_ptr<ParmType> evp,
                                                             EncodingParams ep, CKKSDataType cdt) {
    if (pte == CKKS_PACKED_ENCODING)
        return PlaintextFactory::MakePlaintext(pte, evp, ep, INVALID_SCHEME, cdt);

    auto nativeParams =
        std::make_shared<typename NativePoly::Params>(evp->GetCyclotomicOrder(), ep->GetPlaintextModulus(), 1);
    return PlaintextFactory::MakePlaintext(pte, nativeParams, ep);
}

// Generic (non-DCRTPoly) decryption of `ciphertext` with `privateKey`.
// On success the decoded plaintext is moved into *plaintext.
//  - CKKS-encoded data is decrypted into a Poly unless Element is NativePoly.
//  - Everything else is decrypted into a NativePoly.
// Returns the scheme's DecryptResult. If result.isValid is false, *plaintext is left untouched.
// Throws if:
//  - ciphertext or plaintext is null;
//  - the key fails ValidateKey;
//  - (CKKS only) the context's parameters are not RNS-based.
template <typename Element>
DecryptResult CryptoContextImpl<Element>::Decrypt(ConstCiphertext<Element>& ciphertext,
                                                  const PrivateKey<Element>& privateKey, Plaintext* plaintext) {
    if (ciphertext == nullptr)
        OPENFHE_THROW("ciphertext is empty");
    if (plaintext == nullptr)
        OPENFHE_THROW("plaintext is empty");
    ValidateKey(privateKey);

    // determine which type of plaintext that you need to decrypt into
    Plaintext decrypted = CryptoContextImpl<Element>::GetPlaintextForDecrypt(
        ciphertext->GetEncodingType(), ciphertext->GetElements()[0].GetParams(), this->GetEncodingParams(),
        this->GetCKKSDataType());

    const bool isCKKS = (ciphertext->GetEncodingType() == CKKS_PACKED_ENCODING);

    DecryptResult result;
    if (isCKKS && (typeid(Element) != typeid(NativePoly))) {
        result = GetScheme()->Decrypt(ciphertext, privateKey, &decrypted->GetElement<Poly>());
    }
    else {
        result = GetScheme()->Decrypt(ciphertext, privateKey, &decrypted->GetElement<NativePoly>());
    }

    if (!result.isValid)  // TODO (dsuponit): why don't we throw an exception here?
        return result;

    decrypted->SetScalingFactorInt(result.scalingFactorInt);

    if (isCKKS) {
        auto decryptedCKKS = std::dynamic_pointer_cast<CKKSPackedEncoding>(decrypted);
        decryptedCKKS->SetNoiseScaleDeg(ciphertext->GetNoiseScaleDeg());
        decryptedCKKS->SetLevel(ciphertext->GetLevel());
        decryptedCKKS->SetScalingFactor(ciphertext->GetScalingFactor());
        decryptedCKKS->SetSlots(ciphertext->GetSlots());

        const auto cryptoParamsCKKS = std::dynamic_pointer_cast<CryptoParametersRNS>(this->GetCryptoParameters());
        // Decode needs the RNS scaling technique/execution mode. Fail with a clear error instead of
        // dereferencing a null pointer when the cast does not succeed.
        if (cryptoParamsCKKS == nullptr)
            OPENFHE_THROW("The parameter object is not RNS-based");

        decryptedCKKS->Decode(ciphertext->GetNoiseScaleDeg(), ciphertext->GetScalingFactor(),
                              cryptoParamsCKKS->GetScalingTechnique(), cryptoParamsCKKS->GetExecutionMode());
    }
    else {
        decrypted->Decode();
    }

    *plaintext = std::move(decrypted);
    return result;
}

//------------------------------------------------------------------------------
// Advanced SHE CHEBYSHEV SERIES EXAMPLES
//------------------------------------------------------------------------------

// Approximates func over [a, b] with a Chebyshev series of the given degree and evaluates that
// series homomorphically on `ciphertext`.
template <typename Element>
Ciphertext<Element> CryptoContextImpl<Element>::EvalChebyshevFunction(std::function<double(double)> func,
                                                                      ConstCiphertext<Element>& ciphertext, double a,
                                                                      double b, uint32_t degree) const {
    auto seriesCoefficients = EvalChebyshevCoefficients(func, a, b, degree);
    return EvalChebyshevSeries(ciphertext, seriesCoefficients, a, b);
}

// Chebyshev approximation of sin(x) on [a, b], evaluated on `ciphertext`.
template <typename Element>
Ciphertext<Element> CryptoContextImpl<Element>::EvalSin(ConstCiphertext<Element>& ciphertext, double a, double b,
                                                        uint32_t degree) const {
    const auto sine = [](double x) -> double { return std::sin(x); };
    return EvalChebyshevFunction(sine, ciphertext, a, b, degree);
}

// Chebyshev approximation of cos(x) on [a, b], evaluated on `ciphertext`.
template <typename Element>
Ciphertext<Element> CryptoContextImpl<Element>::EvalCos(ConstCiphertext<Element>& ciphertext, double a, double b,
                                                        uint32_t degree) const {
    const auto cosine = [](double x) -> double { return std::cos(x); };
    return EvalChebyshevFunction(cosine, ciphertext, a, b, degree);
}

// Chebyshev approximation of the logistic function 1 / (1 + e^-x) on [a, b], evaluated on `ciphertext`.
template <typename Element>
Ciphertext<Element> CryptoContextImpl<Element>::EvalLogistic(ConstCiphertext<Element>& ciphertext, double a, double b,
                                                             uint32_t degree) const {
    const auto logistic = [](double x) -> double { return 1.0 / (1.0 + std::exp(-x)); };
    return EvalChebyshevFunction(logistic, ciphertext, a, b, degree);
}

// Chebyshev approximation of the reciprocal 1/x on [a, b], evaluated on `ciphertext`.
// NOTE(review): the interval should presumably exclude 0; this function does not check it.
template <typename Element>
Ciphertext<Element> CryptoContextImpl<Element>::EvalDivide(ConstCiphertext<Element>& ciphertext, double a, double b,
                                                           uint32_t degree) const {
    const auto reciprocal = [](double x) -> double { return 1.0 / x; };
    return EvalChebyshevFunction(reciprocal, ciphertext, a, b, degree);
}

}  // namespace lbcrypto

// the code below is from cryptocontext-impl.cpp
namespace lbcrypto {

// DCRTPoly specialization: creates the empty plaintext that decryption writes into.
//  - CKKS data spread over more than one RNS tower: uses a multiprecision Poly ring.
//  - All other cases: use a NativePoly ring.
// Both rings use evp's cyclotomic order and the plaintext modulus from `ep`.
template <>
Plaintext CryptoContextImpl<DCRTPoly>::GetPlaintextForDecrypt(PlaintextEncodings pte, std::shared_ptr<ParmType> evp,
                                                              EncodingParams ep, CKKSDataType cdt) {
    const bool multiTowerCKKS = (pte == CKKS_PACKED_ENCODING) && (evp->GetParams().size() > 1);
    if (!multiTowerCKKS) {
        auto nativeParams =
            std::make_shared<NativePoly::Params>(evp->GetCyclotomicOrder(), ep->GetPlaintextModulus(), 1);
        return PlaintextFactory::MakePlaintext(pte, nativeParams, ep, INVALID_SCHEME, cdt);
    }

    auto bigParams = std::make_shared<Poly::Params>(evp->GetCyclotomicOrder(), ep->GetPlaintextModulus(), 1);
    return PlaintextFactory::MakePlaintext(pte, bigParams, ep, INVALID_SCHEME, cdt);
}

// DCRTPoly decryption of `ciphertext` with `privateKey`.
// On success the decoded plaintext is moved into *plaintext.
//  - CKKS data over more than one RNS tower is decrypted into a Poly.
//  - Everything else is decrypted into a NativePoly.
// Returns the scheme's DecryptResult. If result.isValid is false, *plaintext is left untouched.
// Throws if:
//  - ciphertext or plaintext is null;
//  - the key is null or belongs to another context;
//  - (CKKS only) the parameters are not CryptoParametersCKKSRNS.
template <>
DecryptResult CryptoContextImpl<DCRTPoly>::Decrypt(ConstCiphertext<DCRTPoly>& ciphertext,
                                                   const PrivateKey<DCRTPoly>& privateKey, Plaintext* plaintext) {
    if (ciphertext == nullptr)
        OPENFHE_THROW("ciphertext is empty");
    if (plaintext == nullptr)
        OPENFHE_THROW("plaintext is empty");
    if (privateKey == nullptr || Mismatched(privateKey->GetCryptoContext()))
        OPENFHE_THROW("Information was not generated with this crypto context");

    // determine which type of plaintext that you need to decrypt into
    Plaintext decrypted = CryptoContextImpl<DCRTPoly>::GetPlaintextForDecrypt(
        ciphertext->GetEncodingType(), ciphertext->GetElements()[0].GetParams(), this->GetEncodingParams(),
        this->GetCKKSDataType());

    const bool isCKKS = (ciphertext->GetEncodingType() == CKKS_PACKED_ENCODING);

    DecryptResult result;
    if (isCKKS && (ciphertext->GetElements()[0].GetParams()->GetParams().size() > 1))  // more than one tower in DCRTPoly
        result = GetScheme()->Decrypt(ciphertext, privateKey, &decrypted->GetElement<Poly>());
    else
        result = GetScheme()->Decrypt(ciphertext, privateKey, &decrypted->GetElement<NativePoly>());

    if (!result.isValid)
        return result;

    decrypted->SetScalingFactorInt(result.scalingFactorInt);

    if (isCKKS) {
        auto decryptedCKKS = std::dynamic_pointer_cast<CKKSPackedEncoding>(decrypted);
        decryptedCKKS->SetNoiseScaleDeg(ciphertext->GetNoiseScaleDeg());
        decryptedCKKS->SetLevel(ciphertext->GetLevel());
        decryptedCKKS->SetScalingFactor(ciphertext->GetScalingFactor());
        decryptedCKKS->SetSlots(ciphertext->GetSlots());

        const auto cryptoParamsCKKS = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(this->GetCryptoParameters());
        // fail with a clear error instead of dereferencing a null pointer below
        if (cryptoParamsCKKS == nullptr)
            OPENFHE_THROW("The parameter object is not of the CryptoParametersCKKSRNS type");

        decryptedCKKS->Decode(ciphertext->GetNoiseScaleDeg(), ciphertext->GetScalingFactor(),
                              cryptoParamsCKKS->GetScalingTechnique(), cryptoParamsCKKS->GetExecutionMode());
    }
    else {
        decrypted->Decode();
    }

    *plaintext = std::move(decrypted);
    return result;
}

// Combines the partial decryptions in `partialCiphertextVec` into the final plaintext. On success
// the decoded plaintext is moved into *plaintext.
//  - An empty input vector returns a default-constructed DecryptResult.
//  - If result.isValid is false, *plaintext is left untouched.
// Throws if:
//  - plaintext is null;
//  - any partial ciphertext fails ValidateCiphertext;
//  - encodings differ between partial ciphertexts;
//  - (CKKS only) the parameters are not CryptoParametersCKKSRNS.
template <>
DecryptResult CryptoContextImpl<DCRTPoly>::MultipartyDecryptFusion(
    const std::vector<Ciphertext<DCRTPoly>>& partialCiphertextVec, Plaintext* plaintext) const {
    // same contract as Decrypt(): *plaintext is written below, so a null pointer must be rejected
    if (plaintext == nullptr)
        OPENFHE_THROW("plaintext is empty");

    DecryptResult result;

    // Make sure we're processing ciphertexts.
    size_t last_ciphertext = partialCiphertextVec.size();
    if (last_ciphertext < 1)
        return result;

    for (size_t i = 0; i < last_ciphertext; i++) {
        ValidateCiphertext(partialCiphertextVec[i]);
        if (partialCiphertextVec[i]->GetEncodingType() != partialCiphertextVec[0]->GetEncodingType())
            OPENFHE_THROW("Ciphertexts have mismatched encoding types");
    }

    const auto& first = partialCiphertextVec[0];
    const bool isCKKS = (first->GetEncodingType() == CKKS_PACKED_ENCODING);

    // determine which type of plaintext that you need to decrypt into
    Plaintext decrypted = CryptoContextImpl<DCRTPoly>::GetPlaintextForDecrypt(
        first->GetEncodingType(), first->GetElements()[0].GetParams(), this->GetEncodingParams(),
        this->GetCKKSDataType());

    if (isCKKS && (first->GetElements()[0].GetParams()->GetParams().size() > 1))  // more than one tower
        result = GetScheme()->MultipartyDecryptFusion(partialCiphertextVec, &decrypted->GetElement<Poly>());
    else
        result = GetScheme()->MultipartyDecryptFusion(partialCiphertextVec, &decrypted->GetElement<NativePoly>());

    if (!result.isValid)
        return result;

    decrypted->SetScalingFactorInt(result.scalingFactorInt);

    if (isCKKS) {
        auto decryptedCKKS = std::dynamic_pointer_cast<CKKSPackedEncoding>(decrypted);
        decryptedCKKS->SetSlots(first->GetSlots());
        const auto cryptoParamsCKKS = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(this->GetCryptoParameters());
        // fail with a clear error instead of dereferencing a null pointer below
        if (cryptoParamsCKKS == nullptr)
            OPENFHE_THROW("The parameter object is not of the CryptoParametersCKKSRNS type");
        decryptedCKKS->Decode(first->GetNoiseScaleDeg(), first->GetScalingFactor(),
                              cryptoParamsCKKS->GetScalingTechnique(), cryptoParamsCKKS->GetExecutionMode());
    }
    else {
        decrypted->Decode();
    }

    *plaintext = std::move(decrypted);

    return result;
}

// Interactive multiparty bootstrapping: forwards to the scheme's IntMPBootAdjustScale.
template <typename Element>
Ciphertext<Element> CryptoContextImpl<Element>::IntMPBootAdjustScale(ConstCiphertext<Element>& ciphertext) const {
    const auto& scheme = GetScheme();
    return scheme->IntMPBootAdjustScale(ciphertext);
}

// Interactive multiparty bootstrapping: generates the common random element from `publicKey`.
// Requires this context's parameters to be CryptoParametersCKKSRNS; throws otherwise.
template <typename Element>
Ciphertext<Element> CryptoContextImpl<Element>::IntMPBootRandomElementGen(const PublicKey<Element> publicKey) const {
    const auto ckksParams = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(this->GetCryptoParameters());
    if (!ckksParams)
        OPENFHE_THROW("The parameter object is not of the CryptoParametersCKKSRNS type");

    const auto& scheme = GetScheme();
    return scheme->IntMPBootRandomElementGen(ckksParams, publicKey);
}

// Interactive multiparty bootstrapping: generates the common random element from `ciphertext`.
// The ciphertext's own parameters must be CryptoParametersCKKSRNS; throws otherwise.
template <typename Element>
Ciphertext<Element> CryptoContextImpl<Element>::IntMPBootRandomElementGen(ConstCiphertext<Element>& ciphertext) const {
    const auto ckksParams = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(ciphertext->GetCryptoParameters());
    if (!ckksParams)
        OPENFHE_THROW("The parameter object is not of the CryptoParametersCKKSRNS type");

    const auto& scheme = GetScheme();
    return scheme->IntMPBootRandomElementGen(ckksParams, ciphertext);
}

// Interactive multiparty bootstrapping: forwards this party's masked decryption share request to the scheme.
template <typename Element>
std::vector<Ciphertext<Element>> CryptoContextImpl<Element>::IntMPBootDecrypt(const PrivateKey<Element> privateKey,
                                                                              ConstCiphertext<Element>& ciphertext,
                                                                              ConstCiphertext<Element>& a) const {
    const auto& scheme = GetScheme();
    return scheme->IntMPBootDecrypt(privateKey, ciphertext, a);
}

// Interactive multiparty bootstrapping: forwards the aggregation of all parties' share pairs to the scheme.
template <typename Element>
std::vector<Ciphertext<Element>> CryptoContextImpl<Element>::IntMPBootAdd(
    std::vector<std::vector<Ciphertext<Element>>>& sharesPairVec) const {
    const auto& scheme = GetScheme();
    return scheme->IntMPBootAdd(sharesPairVec);
}

// Interactive multiparty bootstrapping: forwards the final re-encryption step to the scheme.
template <typename Element>
Ciphertext<Element> CryptoContextImpl<Element>::IntMPBootEncrypt(const PublicKey<Element> publicKey,
                                                                 const std::vector<Ciphertext<Element>>& sharesPair,
                                                                 ConstCiphertext<Element>& a,
                                                                 ConstCiphertext<Element>& ciphertext) const {
    const auto& scheme = GetScheme();
    return scheme->IntMPBootEncrypt(publicKey, sharesPair, a, ciphertext);
}

// Function for sharing and recovery of secret for Threshold FHE with aborts
/**
 * Splits the private key of party `index` into shares for the other parties (threshold FHE with aborts).
 *
 * @param sk         private key to be shared
 * @param N          total number of parties; must be at least 3
 * @param threshold  number of shares required for recovery; must be a strict majority (> N/2)
 * @param index      index of the party owning sk; that index receives no share
 * @param shareType  "additive": N-1 shares (EVALUATION format) whose sum is sk;
 *                   "shamir":   f(i) for a random polynomial f of degree threshold-1 with f(0) = sk
 *                               (COEFFICIENT format)
 * @return map from party index in [1, N] (excluding `index`) to the share for that party
 * @throws OpenFHEException on invalid N / threshold / index, on an RNS modulus not larger than N,
 *         or on an unknown shareType
 */
template <>
std::unordered_map<uint32_t, DCRTPoly> CryptoContextImpl<DCRTPoly>::ShareKeys(const PrivateKey<DCRTPoly>& sk,
                                                                              uint32_t N, uint32_t threshold,
                                                                              uint32_t index,
                                                                              const std::string& shareType) const {
    // conditions on N and threshold for security with aborts.
    // With N == 2 only a single share is handed out, which can never reach a majority threshold.
    if (N < 3)
        OPENFHE_THROW("Number of parties needs to be at least 3 for aborts");

    if (threshold <= N / 2)
        OPENFHE_THROW("Threshold required to be majority (more than N/2)");

    const auto cryptoParams = sk->GetCryptoContext()->GetCryptoParameters();
    auto elementParams      = cryptoParams->GetElementParams();
    auto vecSize            = elementParams->GetParams().size();
    auto ring_dimension     = elementParams->GetRingDimension();

    // condition for inverse in lagrange coeff to exist.
    for (size_t k = 0; k < vecSize; ++k) {
        auto modq_k = elementParams->GetParams()[k]->GetModulus();
        if (N >= modq_k)
            OPENFHE_THROW("Number of parties N needs to be less than DCRTPoly moduli");
    }

    // secret sharing
    std::unordered_map<uint32_t, DCRTPoly> SecretShares;

    if (shareType == "additive") {
        // Exactly N-1 shares are produced, one for every party except `index`; an index outside [1, N]
        // would require N shares and read past the end of SecretSharesVec.
        if (index == 0 || index > N)
            OPENFHE_THROW("Party index needs to be in the range [1, N] for additive sharing");

        // generate N-2 uniformly random shares and create the last share as sk - (sk_1 + ... + sk_N-2)
        typename DCRTPoly::DugType dug;
        DCRTPoly rsum(dug, elementParams, Format::EVALUATION);

        const uint32_t num_of_shares = N - 1;
        std::vector<DCRTPoly> SecretSharesVec;
        SecretSharesVec.reserve(num_of_shares);
        SecretSharesVec.push_back(rsum);
        for (uint32_t s = 1; s < num_of_shares - 1; ++s) {
            DCRTPoly r(dug, elementParams, Format::EVALUATION);  // fresh uniform r for each share
            rsum += r;
            SecretSharesVec.push_back(std::move(r));
        }
        SecretSharesVec.push_back(sk->GetPrivateElement() - rsum);

        for (uint32_t i = 1, ctr = 0; i <= N; ++i) {
            if (i != index)
                SecretShares.emplace(i, std::move(SecretSharesVec[ctr++]));
        }
    }
    else if (shareType == "shamir") {
        // Coefficients of f(x) = sk + a_1 x + ... + a_{threshold-1} x^{threshold-1}, all in COEFFICIENT format;
        // the constant term is the secret, the others are uniformly random.
        std::vector<DCRTPoly> fs;
        fs.reserve(threshold);
        fs.push_back(sk->GetPrivateElement());
        fs.back().SetFormat(Format::COEFFICIENT);

        typename DCRTPoly::DugType dug;
        for (uint32_t t = 1; t < threshold; ++t)
            fs.emplace_back(dug, elementParams, Format::COEFFICIENT);

        // The share of party i is f(i), evaluated coefficient-wise in every RNS tower.
        for (uint32_t i = 1; i <= N; ++i) {
            if (i == index)
                continue;

            DCRTPoly feval(elementParams, Format::COEFFICIENT, true);
            for (size_t k = 0; k < vecSize; ++k) {
                auto modq_k = elementParams->GetParams()[k]->GetModulus();

                NativePoly fevalpoly(elementParams->GetParams()[k], Format::COEFFICIENT, true);

                NativeInteger powtemp(1);  // i^t mod q_k
                for (uint32_t t = 1; t < threshold; ++t) {
                    powtemp = powtemp.ModMul(NativeInteger(i), modq_k);

                    const auto& fst = fs[t].GetElementAtIndex(k);
                    // Accumulate modularly: plain += on NativeInteger would let the sum grow past q_k
                    // (and overflow 64 bits for large thresholds).
                    for (size_t d = 0; d < ring_dimension; ++d)
                        fevalpoly[d].ModAddFastEq(powtemp.ModMul(fst[d], modq_k), modq_k);
                }
                fevalpoly += fs[0].GetElementAtIndex(k);

                feval.SetElementAtIndex(k, std::move(fevalpoly));
            }
            SecretShares.emplace(i, std::move(feval));
        }
    }
    else {
        OPENFHE_THROW("Unsupported secret sharing type: " + shareType);
    }
    return SecretShares;
}

/**
 * Recovers a private key from shares produced by ShareKeys (threshold FHE with aborts).
 *
 * @param sk         key whose private element is overwritten with the recovered secret
 * @param sk_shares  map from party index to share; only keys in [1, N] are used for recovery
 * @param N          total number of parties; must be at least 3
 * @param threshold  threshold used when sharing; must be a strict majority (> N/2)
 * @param shareType  "additive" or "shamir", matching the type passed to ShareKeys
 * @throws OpenFHEException on too few shares, invalid N / threshold, an RNS modulus not larger
 *         than N, or an unknown shareType
 */
template <>
void CryptoContextImpl<DCRTPoly>::RecoverSharedKey(PrivateKey<DCRTPoly>& sk,
                                                   std::unordered_map<uint32_t, DCRTPoly>& sk_shares, uint32_t N,
                                                   uint32_t threshold, const std::string& shareType) const {
    if (sk_shares.size() < threshold)
        OPENFHE_THROW("Number of shares available less than threshold of the sharing scheme");

    // conditions on N and threshold for security with aborts
    if (N < 3)
        OPENFHE_THROW("Number of parties needs to be at least 3 for aborts");

    if (threshold <= N / 2)
        OPENFHE_THROW("Threshold required to be majority (more than N/2)");

    const auto& cryptoParams  = sk->GetCryptoContext()->GetCryptoParameters();
    const auto& elementParams = cryptoParams->GetElementParams();
    size_t vecSize            = elementParams->GetParams().size();

    // condition for inverse in lagrange coeff to exist.
    for (size_t k = 0; k < vecSize; k++) {
        auto modq_k = elementParams->GetParams()[k]->GetModulus();
        if (N >= modq_k)
            OPENFHE_THROW("Number of parties N needs to be less than DCRTPoly moduli");
    }

    // indexes (in increasing order) of the parties whose shares are available
    std::vector<uint32_t> client_indexes;
    client_indexes.reserve(N);
    for (uint32_t i = 1; i <= N; ++i) {
        if (sk_shares.find(i) != sk_shares.end())
            client_indexes.push_back(i);
    }
    const uint32_t client_indexes_size = client_indexes.size();

    if (client_indexes_size < threshold)
        OPENFHE_THROW("Not enough shares to recover the secret");

    if (shareType == "additive") {
        // Additive sharing produces N-1 shares and only their complete sum equals the secret,
        // so all of them are needed irrespective of the threshold.
        const uint32_t num_of_shares = N - 1;
        if (client_indexes_size < num_of_shares)
            OPENFHE_THROW("Additive secret sharing requires all N-1 shares to recover the secret");

        DCRTPoly sum_of_elems(elementParams, Format::EVALUATION, true);
        for (uint32_t i = 0; i < num_of_shares; ++i)
            sum_of_elems += sk_shares.at(client_indexes[i]);
        sk->SetPrivateElement(std::move(sum_of_elems));
    }
    else if (shareType == "shamir") {
        // Lagrange interpolation at x = 0, per RNS tower q_k:
        //   sk = sum_j L_j * f(c_j),   L_j = prod_{i != j} c_i * (c_i - c_j)^{-1}  (mod q_k)
        // L_j is a scalar, so every share is scaled coefficient-wise; shares are expected in
        // COEFFICIENT format, as produced by ShareKeys.
        DCRTPoly lagrange_sum_of_elems(elementParams, Format::COEFFICIENT, true);
        for (size_t k = 0; k < vecSize; ++k) {
            auto modq_k = elementParams->GetParams()[k]->GetModulus();
            NativePoly lagrange_sum_of_elems_poly(elementParams->GetParams()[k], Format::COEFFICIENT, true);

            for (uint32_t j = 0; j < client_indexes_size; ++j) {
                const uint32_t cj = client_indexes[j];

                NativeInteger lagrange_coeff(1);
                for (uint32_t i = 0; i < client_indexes_size; ++i) {
                    const uint32_t ci = client_indexes[i];
                    if (ci == cj)
                        continue;
                    // (ci - cj) mod q_k, non-zero because both indexes lie in [1, N] and N < q_k
                    NativeInteger denominator = (cj < ci) ? NativeInteger(ci - cj) : modq_k - NativeInteger(cj - ci);
                    NativeInteger factor      = NativeInteger(ci).ModMul(denominator.ModInverse(modq_k), modq_k);
                    lagrange_coeff            = lagrange_coeff.ModMul(factor, modq_k);
                }

                lagrange_sum_of_elems_poly += sk_shares.at(cj).GetElementAtIndex(k).Times(lagrange_coeff);
            }
            lagrange_sum_of_elems.SetElementAtIndex(k, std::move(lagrange_sum_of_elems_poly));
        }
        lagrange_sum_of_elems.SetFormat(Format::EVALUATION);
        sk->SetPrivateElement(std::move(lagrange_sum_of_elems));
    }
    else {
        OPENFHE_THROW("Unsupported secret sharing type: " + shareType);
    }
}

// explicit template instantiations (including the instantiations required for pybind11 binding)
// clang-format off
template class CryptoContextImpl<DCRTPoly>;

// Explicitly instantiates, for the coefficient type VECTOR_TYPE, the member templates that take a
// std::vector<VECTOR_TYPE> of coefficients/weights: the EvalChebyshevSeries*, EvalLinearWSum*
// and EvalPoly* families listed below.
#define INSTANTIATE_FUNCTION_TEMPLATES(VECTOR_TYPE) \
    template Ciphertext<DCRTPoly> CryptoContextImpl<DCRTPoly>::EvalChebyshevSeries<VECTOR_TYPE>(ConstCiphertext<DCRTPoly>&, const std::vector<VECTOR_TYPE>&, double, double) const; \
    template Ciphertext<DCRTPoly> CryptoContextImpl<DCRTPoly>::EvalChebyshevSeriesLinear<VECTOR_TYPE>(ConstCiphertext<DCRTPoly>&, const std::vector<VECTOR_TYPE>&, double, double) const; \
    template Ciphertext<DCRTPoly> CryptoContextImpl<DCRTPoly>::EvalChebyshevSeriesPS<VECTOR_TYPE>(ConstCiphertext<DCRTPoly>&, const std::vector<VECTOR_TYPE>&, double, double) const; \
    template Ciphertext<DCRTPoly> CryptoContextImpl<DCRTPoly>::EvalLinearWSum<VECTOR_TYPE>(std::vector<ReadOnlyCiphertext<DCRTPoly>>&, const std::vector<VECTOR_TYPE>&) const; \
    template Ciphertext<DCRTPoly> CryptoContextImpl<DCRTPoly>::EvalLinearWSumMutable<VECTOR_TYPE>(const std::vector<VECTOR_TYPE>&, std::vector<Ciphertext<DCRTPoly>>&) const; \
    template Ciphertext<DCRTPoly> CryptoContextImpl<DCRTPoly>::EvalPoly<VECTOR_TYPE>(ConstCiphertext<DCRTPoly>&, const std::vector<VECTOR_TYPE>&) const; \
    template Ciphertext<DCRTPoly> CryptoContextImpl<DCRTPoly>::EvalPolyLinear<VECTOR_TYPE>(ConstCiphertext<DCRTPoly>&, const std::vector<VECTOR_TYPE>&) const; \
    template Ciphertext<DCRTPoly> CryptoContextImpl<DCRTPoly>::EvalPolyPS<VECTOR_TYPE>(ConstCiphertext<DCRTPoly>&, const std::vector<VECTOR_TYPE>&) const;

// Coefficient types for which the templates above are instantiated.
INSTANTIATE_FUNCTION_TEMPLATES(int64_t)
INSTANTIATE_FUNCTION_TEMPLATES(double)
INSTANTIATE_FUNCTION_TEMPLATES(std::complex<double>)
// clang-format on

}  // namespace lbcrypto