Program Listing for File base-multiparty.cpp

Return to documentation for file (pke/lib/schemebase/base-multiparty.cpp)

//==================================================================================
// BSD 2-Clause License
//
// Copyright (c) 2014-2022, NJIT, Duality Technologies Inc. and other contributors
//
// All rights reserved.
//
// Author TPOC: contact@openfhe.org
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
//    list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
//    this list of conditions and the following disclaimer in the documentation
//    and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//==================================================================================

#include "cryptocontext.h"
#include "key/evalkey.h"
#include "key/evalkeyrelin.h"
#include "key/privatekey.h"
#include "key/publickey.h"
#include "schemebase/base-multiparty.h"
#include "schemebase/base-pke.h"
#include "schemebase/base-scheme.h"
#include "schemebase/rlwe-cryptoparameters.h"

#include <iostream>
#include <map>
#include <memory>
#include <utility>
#include <vector>

namespace lbcrypto {

// makeSparse is not used by this scheme
/**
 * Builds a single key pair from several parties' secret keys.
 *
 * The resulting secret element is the sum of every private element in
 * privateKeyVec. A new public key (b, a) is derived from it, with a sampled
 * uniformly and b = ns * e - a * s, where e is Gaussian noise.
 *
 * @param cc            crypto context the new keys belong to
 * @param privateKeyVec secret keys whose elements are summed
 * @param makeSparse    ignored
 * @return key pair holding the summed secret and the matching public key
 */
template <class Element>
KeyPair<Element> MultipartyBase<Element>::MultipartyKeyGen(CryptoContext<Element> cc,
                                                           const std::vector<PrivateKey<Element>>& privateKeyVec,
                                                           bool makeSparse) {
    const auto cryptoParams  = std::dynamic_pointer_cast<CryptoParametersRLWE<Element>>(cc->GetCryptoParameters());
    const auto elementParams = cryptoParams->GetElementParams();

    // Accumulate all secret shares; the trailing ctor flag presumably yields a zero start value.
    Element sJoint(elementParams, Format::EVALUATION, true);
    for (const auto& share : privateKeyVec) {
        sJoint += share->GetPrivateElement();
    }

    // Fresh public component a and Gaussian noise e.
    DugType uniformGen;
    Element aPub(uniformGen, elementParams, Format::EVALUATION);
    Element noise(cryptoParams->GetDiscreteGaussianGenerator(), elementParams, Format::EVALUATION);
    NativeInteger noiseScale = cryptoParams->GetNoiseScale();

    // In place: noise <- ns * e - a * s, which becomes b.
    noise *= noiseScale;
    noise -= aPub * sJoint;
    Element bPub(std::move(noise));

    KeyPair<Element> result(std::make_shared<PublicKeyImpl<Element>>(cc),
                            std::make_shared<PrivateKeyImpl<Element>>(cc));
    result.secretKey->SetPrivateElement(std::move(sJoint));
    result.publicKey->SetPublicElements({std::move(bPub), std::move(aPub)});
    return result;
}

/**
 * Generates one party's key pair on top of an existing public key.
 *
 * A secret share s is sampled over the public-key parameter set (ParamsPK),
 * following the configured secret-key distribution. The public "a" component
 * is copied from publicKey, and b = ns * e - a * s. When fresh is false, the
 * b component of publicKey is added to b, producing a joint key. The original
 * comment notes that a joint key is computed when PRE is not used.
 *
 * Before the secret is stored, its extra towers beyond the ciphertext modulus
 * Q are dropped.
 *
 * @param cc         crypto context
 * @param publicKey  existing public key whose elements [0] = b and [1] = a are read
 * @param makeSparse ignored
 * @param fresh      if false, accumulate onto publicKey's b component
 * @throws if ParamsPK were not precomputed or the secret-key distribution is unknown
 */
template <class Element>
KeyPair<Element> MultipartyBase<Element>::MultipartyKeyGen(CryptoContext<Element> cc,
                                                           const PublicKey<Element> publicKey, bool makeSparse,
                                                           bool fresh) {
    const auto cryptoParams  = std::dynamic_pointer_cast<CryptoParametersRLWE<Element>>(cc->GetCryptoParameters());
    const auto elementParams = cryptoParams->GetElementParams();
    const auto paramsPK      = cryptoParams->GetParamsPK();
    if (paramsPK == nullptr) {
        OPENFHE_THROW("PrecomputeCRTTables() must be called before using precomputed params.");
    }

    const DggType& dgg = cryptoParams->GetDiscreteGaussianGenerator();
    TugType tug;

    // Secret share over the (possibly larger) public-key basis.
    Element sShare;
    const auto dist = cryptoParams->GetSecretKeyDist();
    if (dist == GAUSSIAN) {
        sShare = Element(dgg, paramsPK, Format::EVALUATION);
    }
    else if (dist == UNIFORM_TERNARY) {
        sShare = Element(tug, paramsPK, Format::EVALUATION);
    }
    else if (dist == SPARSE_TERNARY || dist == SPARSE_ENCAPSULATED) {
        sShare = Element(tug, paramsPK, Format::EVALUATION, 192);
    }
    else {
        OPENFHE_THROW("Unknown SecretKeyDist.");
    }

    const auto& pkElems = publicKey->GetPublicElements();
    Element aPub(pkElems[1]);
    Element noise(dgg, paramsPK, Format::EVALUATION);
    NativeInteger noiseScale = cryptoParams->GetNoiseScale();

    // b = ns * e - a * s, built in place inside the noise element.
    noise *= noiseScale;
    noise -= aPub * sShare;
    Element bPub(std::move(noise));
    if (!fresh) {
        // Joint key: fold in the existing public key's b component.
        bPub += pkElems[0];
    }

    // Store the secret only over Q: drop any extra towers the PK basis carries.
    const auto towersQ  = elementParams->GetParams().size();
    const auto towersPK = paramsPK->GetParams().size();
    if (towersPK > towersQ) {
        sShare.DropLastElements(towersPK - towersQ);
    }

    KeyPair<Element> result(std::make_shared<PublicKeyImpl<Element>>(cc),
                            std::make_shared<PrivateKeyImpl<Element>>(cc));
    result.secretKey->SetPrivateElement(std::move(sShare));
    result.publicKey->SetPublicElements({std::move(bPub), std::move(aPub)});
    return result;
}

/**
 * Multiparty key-switch key generation.
 *
 * Delegates to the scheme's KeySwitchGen overload that takes an existing
 * evaluation key, passing all three arguments through unchanged.
 *
 * @param oldPrivateKey key being switched from (also supplies the crypto context)
 * @param newPrivateKey key being switched to
 * @param evalKey       existing evaluation key forwarded to KeySwitchGen
 * @return the key-switch key produced by the scheme
 */
template <class Element>
EvalKey<Element> MultipartyBase<Element>::MultiKeySwitchGen(const PrivateKey<Element> oldPrivateKey,
                                                            const PrivateKey<Element> newPrivateKey,
                                                            const EvalKey<Element> evalKey) const {
    const auto context = oldPrivateKey->GetCryptoContext();
    const auto& scheme = context->GetScheme();
    return scheme->KeySwitchGen(oldPrivateKey, newPrivateKey, evalKey);
}

/**
 * Generates this party's automorphism keys for each index in indexList.
 *
 * For every automorphism index k, the secret s is permuted by the inverse
 * automorphism (k^{-1} mod 2N). A key-switch key from s to the permuted
 * secret is then generated through MultiKeySwitchGen, using the evaluation
 * key stored for k in evalKeyMap.
 *
 * @param privateKey this party's secret key
 * @param evalKeyMap previously generated automorphism keys, keyed by index;
 *                   must contain every entry of indexList
 * @param indexList  automorphism indices to generate keys for
 * @return map from automorphism index to this party's key
 * @throws if evalKeyMap is null, indexList is larger than N - 1, or an index
 *         has no entry in evalKeyMap
 */
template <class Element>
std::shared_ptr<std::map<uint32_t, EvalKey<Element>>> MultipartyBase<Element>::MultiEvalAutomorphismKeyGen(
    const PrivateKey<Element> privateKey, const std::shared_ptr<std::map<uint32_t, EvalKey<Element>>> evalKeyMap,
    const std::vector<uint32_t>& indexList) const {
    if (!evalKeyMap)
        OPENFHE_THROW("Input evaluation key map is nullptr");

    const Element& s = privateKey->GetPrivateElement();
    uint32_t N       = s.GetRingDimension();

    if (indexList.size() > N - 1)
        OPENFHE_THROW("size exceeds the ring dimension");

    const auto cc = privateKey->GetCryptoContext();

    auto result = std::make_shared<std::map<uint32_t, EvalKey<Element>>>();

    for (uint32_t autoIndex : indexList) {
        // Look the key up first so a missing index fails before the costly permutation.
        auto evalKeyIterator = evalKeyMap->find(autoIndex);
        if (evalKeyIterator == evalKeyMap->end()) {
            OPENFHE_THROW("EvalKey for index [" + std::to_string(autoIndex) + "] is not found.");
        }

        // Permute the secret by the inverse automorphism k^{-1} mod 2N.
        uint32_t index = NativeInteger(autoIndex).ModInverse(2 * N).ConvertToInt();
        std::vector<uint32_t> vec(N);
        PrecomputeAutoMap(N, index, &vec);

        PrivateKey<Element> privateKeyPermuted = std::make_shared<PrivateKeyImpl<Element>>(cc);
        privateKeyPermuted->SetPrivateElement(s.AutomorphismTransform(index, vec));

        (*result)[autoIndex] = MultiKeySwitchGen(privateKey, privateKeyPermuted, evalKeyIterator->second);
    }
    return result;
}

/**
 * Generates this party's rotation keys for the given slot rotations.
 *
 * Each rotation amount is converted to an automorphism index of the
 * cyclotomic order M. CKKS uses the complex-slot mapping; the other schemes
 * use the plain power-of-two mapping. Generation is then delegated to
 * MultiEvalAutomorphismKeyGen.
 *
 * @param privateKey this party's secret key
 * @param evalKeyMap previously generated automorphism keys, keyed by automorphism index
 * @param indexList  rotation amounts (may be negative)
 * @return map from automorphism index to this party's key
 */
template <class Element>
std::shared_ptr<std::map<uint32_t, EvalKey<Element>>> MultipartyBase<Element>::MultiEvalAtIndexKeyGen(
    const PrivateKey<Element> privateKey, const std::shared_ptr<std::map<uint32_t, EvalKey<Element>>> evalKeyMap,
    const std::vector<int32_t>& indexList) const {
    const auto context     = privateKey->GetCryptoContext();
    const uint32_t cycloM  = privateKey->GetCryptoParameters()->GetElementParams()->GetCyclotomicOrder();
    const bool complexMap  = isCKKS(context->getSchemeId());

    std::vector<uint32_t> autoIndices;
    autoIndices.reserve(indexList.size());
    for (int32_t rotation : indexList) {
        if (complexMap)
            autoIndices.push_back(FindAutomorphismIndex2nComplex(rotation, cycloM));
        else
            autoIndices.push_back(FindAutomorphismIndex2n(rotation, cycloM));
    }

    return MultiEvalAutomorphismKeyGen(privateKey, evalKeyMap, autoIndices);
}

/**
 * Generates this party's EvalSum keys.
 *
 * The automorphism indices are 5, 5^2, 5^4, ... (mod M), one per halving
 * step, with ceil(log2(batchSize)) - 1 entries. A final index follows: the
 * next power when 2 * batchSize < M, otherwise M - 1. These keys are then
 * produced through MultiEvalAutomorphismKeyGen.
 *
 * @param privateKey this party's secret key
 * @param evalKeyMap previously generated EvalSum keys, keyed by automorphism index
 * @return map from automorphism index to this party's key (empty for batchSize <= 1)
 */
template <class Element>
std::shared_ptr<std::map<uint32_t, EvalKey<Element>>> MultipartyBase<Element>::MultiEvalSumKeyGen(
    const PrivateKey<Element> privateKey,
    const std::shared_ptr<std::map<uint32_t, EvalKey<Element>>> evalKeyMap) const {
    const auto cryptoParams = privateKey->GetCryptoParameters();

    uint32_t batchSize = cryptoParams->GetEncodingParams()->GetBatchSize();
    uint32_t M         = cryptoParams->GetElementParams()->GetCyclotomicOrder();

    std::vector<uint32_t> indices;

    if (batchSize > 1) {
        // ceil(log2(batchSize)) for batchSize > 1 equals the bit width of (batchSize - 1);
        // computed exactly in integers instead of through floating point.
        uint32_t ceilLog2 = 0;
        for (uint32_t v = batchSize - 1; v != 0; v >>= 1)
            ++ceilLog2;
        const uint32_t isize = ceilLog2 - 1;
        indices.reserve(isize + 1);

        uint32_t g = 5;
        for (uint32_t i = 0; i < isize; i++) {
            indices.push_back(g);
            // Square in 64 bits: g < M, so g * g overflows 32 bits once M exceeds 2^16.
            g = static_cast<uint32_t>((static_cast<uint64_t>(g) * g) % M);
        }
        if (2 * static_cast<uint64_t>(batchSize) < M)
            indices.push_back(g);
        else
            indices.push_back(M - 1);
    }

    return MultiEvalAutomorphismKeyGen(privateKey, evalKeyMap, indices);
}

/**
 * Lead party's partial decryption: b = c0 + s * c1 + ns * e.
 *
 * e is noise-flooding noise drawn with standard deviation NoiseFlooding::MP_SD.
 * The ciphertext may carry fewer RNS towers than the secret key. In that case
 * the key's surplus towers are dropped, and e is sampled over the
 * ciphertext's own parameters so all operands are at the same level.
 *
 * @param ciphertext ciphertext to partially decrypt (elements c0, c1)
 * @param privateKey lead party's secret share
 * @return a ciphertext holding the single element b
 */
template <class Element>
Ciphertext<Element> MultipartyBase<Element>::MultipartyDecryptLead(ConstCiphertext<Element> ciphertext,
                                                                   const PrivateKey<Element> privateKey) const {
    const auto cryptoParams =
        std::dynamic_pointer_cast<CryptoParametersRLWE<Element>>(privateKey->GetCryptoParameters());
    const auto ns = cryptoParams->GetNoiseScale();

    const std::vector<Element>& cv = ciphertext->GetElements();

    // Align the secret share with the ciphertext's level.
    Element s(privateKey->GetPrivateElement());
    const size_t sizeQ  = s.GetParams()->GetParams().size();
    const size_t sizeQl = cv[0].GetParams()->GetParams().size();
    if (sizeQ > sizeQl)
        s.DropLastElements(sizeQ - sizeQl);

    // e is added to do noise flooding; sampled over the ciphertext's moduli.
    DggType dgg(NoiseFlooding::MP_SD);
    Element e(dgg, cv[0].GetParams(), Format::EVALUATION);

    Element b = cv[0] + s * cv[1] + ns * e;

    auto result = ciphertext->CloneEmpty();
    result->SetElement(std::move(b));
    return result;
}

/**
 * Non-lead party's partial decryption: b = s * c1 + ns * e.
 *
 * e is noise-flooding noise drawn with standard deviation NoiseFlooding::MP_SD.
 * The ciphertext may carry fewer RNS towers than the secret key. In that case
 * the key's surplus towers are dropped, and e is sampled over the
 * ciphertext's own parameters so all operands are at the same level.
 *
 * @param ciphertext ciphertext to partially decrypt (element c1 is used)
 * @param privateKey this party's secret share
 * @return a ciphertext holding the single element b
 */
template <class Element>
Ciphertext<Element> MultipartyBase<Element>::MultipartyDecryptMain(ConstCiphertext<Element> ciphertext,
                                                                   const PrivateKey<Element> privateKey) const {
    const auto cryptoParams =
        std::dynamic_pointer_cast<CryptoParametersRLWE<Element>>(privateKey->GetCryptoParameters());
    const auto ns = cryptoParams->GetNoiseScale();

    const std::vector<Element>& cv = ciphertext->GetElements();

    // Align the secret share with the ciphertext's level.
    Element s(privateKey->GetPrivateElement());
    const size_t sizeQ  = s.GetParams()->GetParams().size();
    const size_t sizeQl = cv[1].GetParams()->GetParams().size();
    if (sizeQ > sizeQl)
        s.DropLastElements(sizeQ - sizeQl);

    // e is added to do noise flooding; sampled over the ciphertext's moduli.
    DggType dgg(NoiseFlooding::MP_SD);
    Element e(dgg, cv[1].GetParams(), Format::EVALUATION);

    Element b = s * cv[1] + ns * e;

    auto result = ciphertext->CloneEmpty();
    result->SetElement(std::move(b));
    return result;
}

/**
 * Combines partial decryptions into the raw plaintext polynomial.
 *
 * The first element of every ciphertext in ciphertextVec is summed. The sum
 * is converted to coefficient format and written to *plaintext as a
 * NativePoly.
 *
 * @param ciphertextVec partial decryptions (lead + main shares); must be non-empty
 * @param plaintext     output polynomial; must not be null
 * @return DecryptResult carrying the length of the produced plaintext
 * @throws if ciphertextVec is empty or plaintext is null
 */
template <class Element>
DecryptResult MultipartyBase<Element>::MultipartyDecryptFusion(const std::vector<Ciphertext<Element>>& ciphertextVec,
                                                               NativePoly* plaintext) const {
    if (ciphertextVec.empty())
        OPENFHE_THROW("Input ciphertext vector is empty");
    if (plaintext == nullptr)
        OPENFHE_THROW("Output plaintext pointer is nullptr");

    Element b = ciphertextVec[0]->GetElements()[0];
    for (size_t i = 1; i < ciphertextVec.size(); i++)
        b += ciphertextVec[i]->GetElements()[0];

    b.SetFormat(Format::COEFFICIENT);

    *plaintext = b.ToNativePoly();

    return DecryptResult(plaintext->GetLength());
}

/**
 * Adds two public keys component-wise on b.
 *
 * The result is (b1 + b2, a1): the "a" component is taken from publicKey1.
 * The two keys presumably share a common "a"; this is not checked here.
 *
 * @param publicKey1 first public key (also supplies the crypto context and a)
 * @param publicKey2 second public key
 * @return new public key (b1 + b2, a1)
 */
template <class Element>
PublicKey<Element> MultipartyBase<Element>::MultiAddPubKeys(PublicKey<Element> publicKey1,
                                                            PublicKey<Element> publicKey2) const {
    PublicKey<Element> combined = std::make_shared<PublicKeyImpl<Element>>(publicKey1->GetCryptoContext());

    const auto& first  = publicKey1->GetPublicElements();
    const auto& second = publicKey2->GetPublicElements();

    std::vector<Element> elements{first[0] + second[0], first[1]};
    combined->SetPublicElements(std::move(elements));

    return combined;
}

/**
 * Adds two evaluation keys that share the same "a" vector.
 *
 * The result has A = evalKey1's A vector and B[i] = B1[i] + B2[i] for each
 * i < |A|.
 *
 * @param evalKey1 first key (supplies the crypto context and the A vector)
 * @param evalKey2 second key
 * @return new relinearization-style key holding the summed B vector
 * @throws if either B vector is shorter than evalKey1's A vector
 */
template <class Element>
EvalKey<Element> MultipartyBase<Element>::MultiAddEvalKeys(EvalKey<Element> evalKey1, EvalKey<Element> evalKey2) const {
    const auto cc = evalKey1->GetCryptoContext();

    const std::vector<Element>& a  = evalKey1->GetAVector();
    const std::vector<Element>& b1 = evalKey1->GetBVector();
    const std::vector<Element>& b2 = evalKey2->GetBVector();

    // b1[i] and b2[i] are read for every i < a.size(); reject inputs that would read out of range.
    if (b1.size() < a.size() || b2.size() < a.size())
        OPENFHE_THROW("MultiAddEvalKeys: evaluation keys have fewer B components than A components");

    std::vector<Element> b;
    b.reserve(a.size());
    for (size_t i = 0; i < a.size(); i++)
        b.push_back(b1[i] + b2[i]);

    EvalKey<Element> evalKeySum = std::make_shared<EvalKeyRelinImpl<Element>>(cc);
    evalKeySum->SetAVector(a);
    evalKeySum->SetBVector(std::move(b));
    return evalKeySum;
}

/**
 * Adds two multiplication evaluation keys component-wise.
 *
 * The result has A[i] = A1[i] + A2[i] and B[i] = B1[i] + B2[i] for each
 * i < |A1|.
 *
 * @param evalKey1 first key (supplies the crypto context and the length)
 * @param evalKey2 second key
 * @return new relinearization-style key holding the summed vectors
 * @throws if any of A2, B1, B2 is shorter than evalKey1's A vector
 */
template <class Element>
EvalKey<Element> MultipartyBase<Element>::MultiAddEvalMultKeys(EvalKey<Element> evalKey1,
                                                               EvalKey<Element> evalKey2) const {
    const auto cc = evalKey1->GetCryptoContext();

    const std::vector<Element>& a1 = evalKey1->GetAVector();
    const std::vector<Element>& a2 = evalKey2->GetAVector();
    const std::vector<Element>& b1 = evalKey1->GetBVector();
    const std::vector<Element>& b2 = evalKey2->GetBVector();

    const size_t len = a1.size();
    // Every other vector is indexed up to len; reject inputs that would read out of range.
    if (a2.size() < len || b1.size() < len || b2.size() < len)
        OPENFHE_THROW("MultiAddEvalMultKeys: evaluation keys have mismatched numbers of components");

    std::vector<Element> a;
    a.reserve(len);
    std::vector<Element> b;
    b.reserve(len);

    for (size_t i = 0; i < len; i++) {
        a.push_back(a1[i] + a2[i]);
        b.push_back(b1[i] + b2[i]);
    }

    EvalKey<Element> evalKeySum = std::make_shared<EvalKeyRelinImpl<Element>>(cc);
    evalKeySum->SetAVector(std::move(a));
    evalKeySum->SetBVector(std::move(b));
    return evalKeySum;
}

/**
 * Multiplies an evaluation key by this party's secret share.
 *
 * Each component is multiplied by s and fresh Gaussian noise scaled by the
 * noise scale is added: A'[i] = A[i] * s + ns * e_i and
 * B'[i] = B[i] * s + ns * e'_i. The noise draws happen in the order
 * e_0, e'_0, e_1, e'_1, ...
 *
 * @param privateKey this party's secret share s
 * @param evalKey    key to multiply (also supplies the crypto context)
 * @return new relinearization-style key with the multiplied vectors
 */
template <class Element>
EvalKey<Element> MultipartyBase<Element>::MultiMultEvalKey(PrivateKey<Element> privateKey,
                                                           EvalKey<Element> evalKey) const {
    const auto context = evalKey->GetCryptoContext();
    const auto params  = std::dynamic_pointer_cast<CryptoParametersRLWE<Element>>(context->GetCryptoParameters());

    const DggType& gaussian = params->GetDiscreteGaussianGenerator();
    const auto ringParams   = params->GetElementParams();
    const auto ns           = params->GetNoiseScale();
    const Element& s        = privateKey->GetPrivateElement();

    const std::vector<Element>& inA = evalKey->GetAVector();
    const std::vector<Element>& inB = evalKey->GetBVector();
    const size_t count              = inA.size();

    std::vector<Element> outA;
    std::vector<Element> outB;
    outA.reserve(count);
    outB.reserve(count);

    for (size_t k = 0; k < count; ++k) {
        Element noiseA(gaussian, ringParams, Format::EVALUATION);
        outA.push_back(inA[k] * s + ns * noiseA);
        Element noiseB(gaussian, ringParams, Format::EVALUATION);
        outB.push_back(inB[k] * s + ns * noiseB);
    }

    EvalKey<Element> product = std::make_shared<EvalKeyRelinImpl<Element>>(context);
    product->SetAVector(std::move(outA));
    product->SetBVector(std::move(outB));
    return product;
}

/**
 * Adds two automorphism key maps entry by entry.
 *
 * Only indices present in both maps appear in the result. Each result entry
 * is MultiAddEvalKeys(map1[k], map2[k]); indices missing from evalKeyMap2
 * are skipped.
 *
 * @param evalKeyMap1 first key map (drives iteration order)
 * @param evalKeyMap2 second key map
 * @return new map with the summed keys for the common indices
 */
template <class Element>
std::shared_ptr<std::map<uint32_t, EvalKey<Element>>> MultipartyBase<Element>::MultiAddEvalAutomorphismKeys(
    const std::shared_ptr<std::map<uint32_t, EvalKey<Element>>> evalKeyMap1,
    const std::shared_ptr<std::map<uint32_t, EvalKey<Element>>> evalKeyMap2) const {
    auto combined = std::make_shared<std::map<uint32_t, EvalKey<Element>>>();

    for (const auto& [autoIndex, key1] : *evalKeyMap1) {
        const auto match = evalKeyMap2->find(autoIndex);
        if (match == evalKeyMap2->end())
            continue;
        (*combined)[autoIndex] = MultiAddEvalKeys(key1, match->second);
    }

    return combined;
}

/**
 * Adds two EvalSum key maps entry by entry.
 *
 * Only indices present in both maps appear in the result. Each result entry
 * is MultiAddEvalKeys(map1[k], map2[k]); indices missing from evalKeyMap2
 * are skipped.
 *
 * @param evalKeyMap1 first key map (drives iteration order)
 * @param evalKeyMap2 second key map
 * @return new map with the summed keys for the common indices
 */
template <class Element>
std::shared_ptr<std::map<uint32_t, EvalKey<Element>>> MultipartyBase<Element>::MultiAddEvalSumKeys(
    const std::shared_ptr<std::map<uint32_t, EvalKey<Element>>> evalKeyMap1,
    const std::shared_ptr<std::map<uint32_t, EvalKey<Element>>> evalKeyMap2) const {
    auto summed = std::make_shared<std::map<uint32_t, EvalKey<Element>>>();

    for (const auto& [sumIndex, firstKey] : *evalKeyMap1) {
        const auto other = evalKeyMap2->find(sumIndex);
        if (other != evalKeyMap2->end()) {
            (*summed)[sumIndex] = MultiAddEvalKeys(firstKey, other->second);
        }
    }

    return summed;
}

}  // namespace lbcrypto

// the code below is from base-multiparty-impl.cpp
// Explicit instantiation: emits MultipartyBase<DCRTPoly> in this translation unit,
// since the member templates above are defined here rather than in the header.
namespace lbcrypto {
template class MultipartyBase<DCRTPoly>;
}  // namespace lbcrypto