Program Listing for File rlwe-cryptoparameters.h

Return to documentation for file (pke/include/schemebase/rlwe-cryptoparameters.h)

//==================================================================================
// BSD 2-Clause License
//
// Copyright (c) 2014-2022, NJIT, Duality Technologies Inc. and other contributors
//
// All rights reserved.
//
// Author TPOC: contact@openfhe.org
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
//    list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
//    this list of conditions and the following disclaimer in the documentation
//    and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//==================================================================================

/*
  Ring Learning With Errors (RLWE) crypto parameters
 */

#ifndef LBCRYPTO_RLWE_CRYPTOPARAMETERS_H
#define LBCRYPTO_RLWE_CRYPTOPARAMETERS_H

#include <memory>
#include <string>

#include "lattice/lat-hal.h"
#include "schemebase/base-cryptoparameters.h"
#include "constants.h"
#include "lattice/constants-lattice.h"

// TODO - temp include for the SecurityLevel
#include "lattice/stdlatticeparms.h"

namespace lbcrypto {

/**
 * @brief Cryptographic parameters common to the RLWE-based schemes.
 *
 * Adds the following to CryptoParametersBase:
 * - the discrete Gaussian error settings (regular and flooding);
 * - the assurance measure, noise scale and digit size;
 * - the noise estimate and the depth/operation-count estimates;
 * - the highest secret-key power for relinearization and the secret-key distribution;
 * - the HomomorphicEncryption.org security level;
 * - the PRE / multiparty / execution / decryption-noise modes.
 *
 * Every field is written by save(), read by load() and compared by CompareTo().
 *
 * @tparam Element lattice element type; must provide the nested types Element::Params and
 *         Element::DggType.
 */
template <class Element>
class CryptoParametersRLWE : public CryptoParametersBase<Element> {
public:
    CryptoParametersRLWE() = default;

    /**
     * @brief Copies all RLWE fields from rhs and re-initializes both discrete Gaussian generators
     *        with the copied standard deviations.
     *
     * NOTE(review): the base part is rebuilt from rhs.GetElementParams() and
     * rhs.GetPlaintextModulus() only. Any other encoding settings held by rhs are presumably not
     * carried over. Confirm against CryptoParametersBase before relying on this as a full copy.
     */
    CryptoParametersRLWE(const CryptoParametersRLWE& rhs)
        : CryptoParametersBase<Element>(rhs.GetElementParams(), rhs.GetPlaintextModulus()) {
        // error distributions: the generators are configured from the values just copied
        m_distributionParameter         = rhs.m_distributionParameter;
        m_floodingDistributionParameter = rhs.m_floodingDistributionParameter;
        m_dgg.SetStd(m_distributionParameter);
        m_dggFlooding.SetStd(m_floodingDistributionParameter);

        m_assuranceMeasure = rhs.m_assuranceMeasure;
        m_noiseScale       = rhs.m_noiseScale;
        m_digitSize        = rhs.m_digitSize;

        // estimates / operation counts
        m_noiseEstimate       = rhs.m_noiseEstimate;
        m_multiplicativeDepth = rhs.m_multiplicativeDepth;
        m_evalAddCount        = rhs.m_evalAddCount;
        m_keySwitchCount      = rhs.m_keySwitchCount;
        m_PRENumHops          = rhs.m_PRENumHops;

        m_maxRelinSkDeg = rhs.m_maxRelinSkDeg;
        m_secretKeyDist = rhs.m_secretKeyDist;
        m_stdLevel      = rhs.m_stdLevel;

        // modes
        m_PREMode             = rhs.m_PREMode;
        m_multipartyMode      = rhs.m_multipartyMode;
        m_executionMode       = rhs.m_executionMode;
        m_decryptionNoiseMode = rhs.m_decryptionNoiseMode;

        // NOISE_FLOODING_DECRYPT / threshold settings
        m_statisticalSecurity   = rhs.m_statisticalSecurity;
        m_numAdversarialQueries = rhs.m_numAdversarialQueries;
        m_thresholdNumOfParties = rhs.m_thresholdNumOfParties;
    }

    /**
     * @brief Builds RLWE parameters from explicit settings.
     *
     * This constructor does not set the flooding distribution parameter. It keeps its default of 0,
     * and its generator stays at 0 as well, until SetFloodingDistributionParameter() is called.
     *
     * @param params                element parameters, forwarded to CryptoParametersBase
     * @param encodingParams        encoding parameters, forwarded to CryptoParametersBase
     * @param distributionParameter standard deviation of the discrete Gaussian distribution
     * @param assuranceMeasure      assurance measure alpha
     * @param stdLevel              security level according to the HomomorphicEncryption.org standard
     * @param digitSize             digit size
     * @param maxRelinSkDeg         highest power of the secret key for which a relinearization key
     *                              is generated
     * @param secretKeyDist         distribution from which the secret polynomials are generated
     * @param PREMode               security mode used for proxy re-encryption
     * @param multipartyMode        security mode used for multiparty decryption
     * @param executionMode         execution mode used for NOISE_FLOODING_DECRYPT mode in CKKS
     * @param decryptionNoiseMode   noise mode used for decryption in CKKS
     * @param noiseScale            noise scale
     * @param statisticalSecurity   adversary success probability is bounded by
     *                              2^(-statisticalSecurity) in NOISE_FLOODING_DECRYPT mode
     * @param numAdversarialQueries expected number of adversarial queries in NOISE_FLOODING_DECRYPT mode
     * @param thresholdNumOfParties threshold number of parties (stored as given)
     */
    CryptoParametersRLWE(std::shared_ptr<typename Element::Params> params, EncodingParams encodingParams,
                         float distributionParameter, float assuranceMeasure, SecurityLevel stdLevel, uint32_t digitSize,
                         int maxRelinSkDeg = 2, SecretKeyDist secretKeyDist = GAUSSIAN,
                         ProxyReEncryptionMode PREMode = INDCPA, MultipartyMode multipartyMode = FIXED_NOISE_MULTIPARTY,
                         ExecutionMode executionMode             = EXEC_EVALUATION,
                         DecryptionNoiseMode decryptionNoiseMode = FIXED_NOISE_DECRYPT, PlaintextModulus noiseScale = 1,
                         uint32_t statisticalSecurity = 30, uint32_t numAdversarialQueries = 1,
                         uint32_t thresholdNumOfParties = 1)
        : CryptoParametersBase<Element>(params, encodingParams) {
        m_distributionParameter = distributionParameter;
        m_dgg.SetStd(m_distributionParameter);

        m_assuranceMeasure = assuranceMeasure;
        m_noiseScale       = noiseScale;
        m_digitSize        = digitSize;
        m_maxRelinSkDeg    = maxRelinSkDeg;
        m_secretKeyDist    = secretKeyDist;
        m_stdLevel         = stdLevel;

        m_PREMode             = PREMode;
        m_multipartyMode      = multipartyMode;
        m_executionMode       = executionMode;
        m_decryptionNoiseMode = decryptionNoiseMode;

        m_statisticalSecurity   = statisticalSecurity;
        m_numAdversarialQueries = numAdversarialQueries;
        m_thresholdNumOfParties = thresholdNumOfParties;
    }

    ~CryptoParametersRLWE() override = default;

    // ---------------------------------------------------------------------------------------------
    // Getters
    // ---------------------------------------------------------------------------------------------

    /// @return standard deviation of the discrete Gaussian distribution
    float GetDistributionParameter() const {
        return m_distributionParameter;
    }

    /// @return standard deviation of the flooding discrete Gaussian distribution
    double GetFloodingDistributionParameter() const {
        return m_floodingDistributionParameter;
    }

    /// @return assurance measure alpha
    float GetAssuranceMeasure() const {
        return m_assuranceMeasure;
    }

    /// @return noise scale
    PlaintextModulus GetNoiseScale() const {
        return m_noiseScale;
    }

    /// @return digit size
    uint32_t GetDigitSize() const override {
        return m_digitSize;
    }

    /// @return stored noise estimate (virtual so derived parameter sets may override it)
    virtual double GetNoiseEstimate() const {
        return m_noiseEstimate;
    }
    /// @return stored multiplicative depth
    virtual uint32_t GetMultiplicativeDepth() const {
        return m_multiplicativeDepth;
    }
    /// @return stored EvalAdd count
    virtual uint32_t GetEvalAddCount() const {
        return m_evalAddCount;
    }
    /// @return stored key-switch count
    virtual uint32_t GetKeySwitchCount() const {
        return m_keySwitchCount;
    }
    /// @return stored number of PRE hops
    virtual uint32_t GetPRENumHops() const {
        return m_PRENumHops;
    }

    /// @return highest power of the secret key for which a relinearization key is generated
    uint32_t GetMaxRelinSkDeg() const override {
        return m_maxRelinSkDeg;
    }

    /// @return distribution from which the secret polynomials are generated
    SecretKeyDist GetSecretKeyDist() const {
        return m_secretKeyDist;
    }

    /// @return security mode used for proxy re-encryption
    ProxyReEncryptionMode GetPREMode() const {
        return m_PREMode;
    }

    /// @return security mode used for multiparty decryption
    MultipartyMode GetMultipartyMode() const {
        return m_multipartyMode;
    }

    /// @return execution mode used for NOISE_FLOODING_DECRYPT mode in CKKS
    ExecutionMode GetExecutionMode() const {
        return m_executionMode;
    }

    /// @return noise mode used for decryption in CKKS
    // const-qualified like every other getter so it can be called on const parameter objects
    DecryptionNoiseMode GetDecryptionNoiseMode() const {
        return m_decryptionNoiseMode;
    }

    /// @return security level according to the HomomorphicEncryption.org standard
    SecurityLevel GetStdLevel() const {
        return m_stdLevel;
    }

    /// @return the discrete Gaussian generator configured with GetDistributionParameter()
    const typename Element::DggType& GetDiscreteGaussianGenerator() const {
        return m_dgg;
    }

    /// @return mutable reference to the flooding generator configured with
    ///         GetFloodingDistributionParameter()
    // NOTE(review): returned non-const, presumably because sampling mutates generator state - confirm
    typename Element::DggType& GetFloodingDiscreteGaussianGenerator() {
        return m_dggFlooding;
    }

    /// @return statistical security for NOISE_FLOODING_DECRYPT mode (stored as double)
    double GetStatisticalSecurity() const {
        return m_statisticalSecurity;
    }

    /// @return expected number of adversarial queries for NOISE_FLOODING_DECRYPT mode (stored as double)
    double GetNumAdversarialQueries() const {
        return m_numAdversarialQueries;
    }

    /// @return threshold number of parties
    uint32_t GetThresholdNumOfParties() const {
        return m_thresholdNumOfParties;
    }

    // ---------------------------------------------------------------------------------------------
    // Setters
    // ---------------------------------------------------------------------------------------------

    /// Sets the Gaussian standard deviation and reconfigures the matching generator.
    void SetDistributionParameter(float distributionParameter) {
        m_distributionParameter = distributionParameter;
        m_dgg.SetStd(m_distributionParameter);
    }

    /// Sets the flooding Gaussian standard deviation and reconfigures the flooding generator.
    void SetFloodingDistributionParameter(double distributionParameter) {
        m_floodingDistributionParameter = distributionParameter;
        m_dggFlooding.SetStd(m_floodingDistributionParameter);
    }

    void SetAssuranceMeasure(float assuranceMeasure) {
        m_assuranceMeasure = assuranceMeasure;
    }

    void SetStdLevel(SecurityLevel securityLevel) {
        m_stdLevel = securityLevel;
    }

    void SetNoiseScale(PlaintextModulus noiseScale) {
        m_noiseScale = noiseScale;
    }

    void SetDigitSize(uint32_t digitSize) {
        m_digitSize = digitSize;
    }

    void SetNoiseEstimate(double noiseEstimate) {
        m_noiseEstimate = noiseEstimate;
    }
    void SetMultiplicativeDepth(uint32_t multiplicativeDepth) {
        m_multiplicativeDepth = multiplicativeDepth;
    }
    void SetEvalAddCount(uint32_t evalAddCount) {
        m_evalAddCount = evalAddCount;
    }
    void SetKeySwitchCount(uint32_t keySwitchCount) {
        m_keySwitchCount = keySwitchCount;
    }
    void SetPRENumHops(uint32_t PRENumHops) {
        m_PRENumHops = PRENumHops;
    }

    void SetMaxRelinSkDeg(uint32_t maxRelinSkDeg) {
        m_maxRelinSkDeg = maxRelinSkDeg;
    }

    void SetSecretKeyDist(SecretKeyDist secretKeyDist) {
        m_secretKeyDist = secretKeyDist;
    }

    void SetPREMode(ProxyReEncryptionMode PREMode) {
        m_PREMode = PREMode;
    }

    void SetMultipartyMode(MultipartyMode multipartyMode) {
        m_multipartyMode = multipartyMode;
    }

    void SetExecutionMode(ExecutionMode executionMode) {
        m_executionMode = executionMode;
    }

    void SetDecryptionNoiseMode(DecryptionNoiseMode decryptionNoiseMode) {
        m_decryptionNoiseMode = decryptionNoiseMode;
    }

    void SetStatisticalSecurity(uint32_t statisticalSecurity) {
        m_statisticalSecurity = statisticalSecurity;
    }

    void SetNumAdversarialQueries(uint32_t numAdversarialQueries) {
        m_numAdversarialQueries = numAdversarialQueries;
    }

    void SetThresholdNumOfParties(uint32_t thresholdNumOfParties) {
        m_thresholdNumOfParties = thresholdNumOfParties;
    }

    // ---------------------------------------------------------------------------------------------
    // Serialization (cereal)
    // ---------------------------------------------------------------------------------------------

    /**
     * @brief Writes the base part followed by every RLWE field.
     *
     * The field order and the nvp names must stay in sync with load().
     */
    template <class Archive>
    void save(Archive& ar, std::uint32_t const version) const {
        ar(::cereal::base_class<CryptoParametersBase<Element>>(this));
        ar(::cereal::make_nvp("dp", m_distributionParameter));
        ar(::cereal::make_nvp("am", m_assuranceMeasure));
        ar(::cereal::make_nvp("ns", m_noiseScale));
        ar(::cereal::make_nvp("rw", m_digitSize));
        ar(::cereal::make_nvp("nest", m_noiseEstimate));
        ar(::cereal::make_nvp("muld", m_multiplicativeDepth));
        ar(::cereal::make_nvp("addc", m_evalAddCount));
        ar(::cereal::make_nvp("kswc", m_keySwitchCount));
        ar(::cereal::make_nvp("phops", m_PRENumHops));
        ar(::cereal::make_nvp("md", m_maxRelinSkDeg));
        ar(::cereal::make_nvp("mo", m_secretKeyDist));
        ar(::cereal::make_nvp("pmo", m_PREMode));
        ar(::cereal::make_nvp("mmo", m_multipartyMode));
        ar(::cereal::make_nvp("exm", m_executionMode));
        ar(::cereal::make_nvp("dnm", m_decryptionNoiseMode));
        ar(::cereal::make_nvp("slv", m_stdLevel));
        ar(::cereal::make_nvp("fdp", m_floodingDistributionParameter));
        ar(::cereal::make_nvp("ss", m_statisticalSecurity));
        ar(::cereal::make_nvp("aq", m_numAdversarialQueries));
        ar(::cereal::make_nvp("tp", m_thresholdNumOfParties));
    }

    /**
     * @brief Reads the fields in the same order as save().
     *
     * The Gaussian generators are not serialized. After loading, they are reconfigured from the
     * restored standard deviations.
     */
    template <class Archive>
    void load(Archive& ar, std::uint32_t const version) {
        ar(::cereal::base_class<CryptoParametersBase<Element>>(this));
        ar(::cereal::make_nvp("dp", m_distributionParameter));
        ar(::cereal::make_nvp("am", m_assuranceMeasure));
        ar(::cereal::make_nvp("ns", m_noiseScale));
        ar(::cereal::make_nvp("rw", m_digitSize));
        ar(::cereal::make_nvp("nest", m_noiseEstimate));
        ar(::cereal::make_nvp("muld", m_multiplicativeDepth));
        ar(::cereal::make_nvp("addc", m_evalAddCount));
        ar(::cereal::make_nvp("kswc", m_keySwitchCount));
        ar(::cereal::make_nvp("phops", m_PRENumHops));
        ar(::cereal::make_nvp("md", m_maxRelinSkDeg));
        ar(::cereal::make_nvp("mo", m_secretKeyDist));
        ar(::cereal::make_nvp("pmo", m_PREMode));
        ar(::cereal::make_nvp("mmo", m_multipartyMode));
        ar(::cereal::make_nvp("exm", m_executionMode));
        ar(::cereal::make_nvp("dnm", m_decryptionNoiseMode));
        ar(::cereal::make_nvp("slv", m_stdLevel));
        ar(::cereal::make_nvp("fdp", m_floodingDistributionParameter));
        ar(::cereal::make_nvp("ss", m_statisticalSecurity));
        ar(::cereal::make_nvp("aq", m_numAdversarialQueries));
        ar(::cereal::make_nvp("tp", m_thresholdNumOfParties));

        m_dgg.SetStd(m_distributionParameter);
        m_dggFlooding.SetStd(m_floodingDistributionParameter);
    }

    std::string SerializedObjectName() const override {
        return "CryptoParametersRLWE";
    }

protected:
    // standard deviation in Discrete Gaussian Distribution
    float m_distributionParameter = 0;
    // standard deviation in Discrete Gaussian Distribution with Flooding
    double m_floodingDistributionParameter = 0;
    // assurance measure alpha
    float m_assuranceMeasure = 0;
    // noise scale
    PlaintextModulus m_noiseScale = 1;
    // digit size
    uint32_t m_digitSize = 1;

    // stored estimates / operation counts (plain values set via the setters above)
    double m_noiseEstimate{0};
    uint32_t m_multiplicativeDepth{1};
    uint32_t m_evalAddCount{0};
    uint32_t m_keySwitchCount{0};
    uint32_t m_PRENumHops{0};

    // the highest power of secret key for which relinearization key is generated
    uint32_t m_maxRelinSkDeg = 2;
    // specifies whether the secret polynomials are generated from discrete
    // Gaussian distribution or ternary distribution with the norm of unity
    SecretKeyDist m_secretKeyDist = GAUSSIAN;
    // Security level according to the HomomorphicEncryption.org standard
    SecurityLevel m_stdLevel = HEStd_NotSet;

    // m_dgg gets the same default value as m_distributionParameter does
    typename Element::DggType m_dgg = typename Element::DggType(0);
    // m_dggFlooding gets the same default value as m_floodingDistributionParameter does
    typename Element::DggType m_dggFlooding = typename Element::DggType(0);

    // specifies the security mode used for PRE
    ProxyReEncryptionMode m_PREMode = NOT_SET;

    // specifies the security mode used for multiparty decryption
    MultipartyMode m_multipartyMode = FIXED_NOISE_MULTIPARTY;

    // specifies the execution mode used for NOISE_FLOODING_DECRYPT mode in CKKS
    ExecutionMode m_executionMode = EXEC_EVALUATION;

    // specifies the noise mode used for decryption in CKKS
    DecryptionNoiseMode m_decryptionNoiseMode = FIXED_NOISE_DECRYPT;

    // Statistical security of CKKS in NOISE_FLOODING_DECRYPT mode. This is the bound on the probability of success
    // that any adversary can have. Specifically, they have a probability of success of at most 2^(-statisticalSecurity).
    double m_statisticalSecurity = 30;

    // This is the number of adversarial queries a user is expecting for their application, which we use to ensure
    // security of CKKS in NOISE_FLOODING_DECRYPT mode.
    double m_numAdversarialQueries = 1;

    // threshold number of parties
    uint32_t m_thresholdNumOfParties = 1;

    /**
     * @brief Field-by-field equality against another parameter object.
     *
     * @return true only if rhs is a CryptoParametersRLWE (or derived from it), the base parts
     *         compare equal, and every RLWE field matches. Floating-point fields are compared
     *         with exact ==.
     */
    bool CompareTo(const CryptoParametersBase<Element>& rhs) const override {
        const auto* el = dynamic_cast<const CryptoParametersRLWE*>(&rhs);
        if (el == nullptr)
            return false;

        return CryptoParametersBase<Element>::CompareTo(rhs) &&
               m_distributionParameter == el->m_distributionParameter &&
               m_floodingDistributionParameter == el->m_floodingDistributionParameter &&
               m_assuranceMeasure == el->m_assuranceMeasure &&
               m_noiseScale == el->m_noiseScale &&
               m_digitSize == el->m_digitSize &&
               m_noiseEstimate == el->m_noiseEstimate &&
               m_multiplicativeDepth == el->m_multiplicativeDepth &&
               m_evalAddCount == el->m_evalAddCount &&
               m_keySwitchCount == el->m_keySwitchCount &&
               m_PRENumHops == el->m_PRENumHops &&
               m_maxRelinSkDeg == el->m_maxRelinSkDeg &&
               m_secretKeyDist == el->m_secretKeyDist &&
               m_stdLevel == el->m_stdLevel &&
               m_PREMode == el->m_PREMode &&
               m_multipartyMode == el->m_multipartyMode &&
               m_executionMode == el->m_executionMode &&
               // previously omitted even though it is copied and serialized like every other mode
               m_decryptionNoiseMode == el->m_decryptionNoiseMode &&
               m_statisticalSecurity == el->m_statisticalSecurity &&
               m_numAdversarialQueries == el->m_numAdversarialQueries &&
               m_thresholdNumOfParties == el->m_thresholdNumOfParties;
    }

    /**
     * @brief Prints the base parameters, then a one-line summary of the main RLWE settings.
     *
     * The summary ends with std::endl, so the stream is flushed.
     */
    void PrintParameters(std::ostream& os) const override {
        CryptoParametersBase<Element>::PrintParameters(os);

        os << "Distrib parm " << GetDistributionParameter() << ", Assurance measure " << GetAssuranceMeasure()
           << ", Noise scale " << GetNoiseScale() << ", Digit Size " << GetDigitSize() << ", SecretKeyDist "
           << GetSecretKeyDist() << ", Standard security level " << GetStdLevel() << std::endl;
    }
};

}  // namespace lbcrypto

#endif