// Program Listing for File ildcrtparams.h
// (Return to documentation for file core/include/lattice/hal/default/ildcrtparams.h)
//==================================================================================
// BSD 2-Clause License
//
// Copyright (c) 2014-2023, NJIT, Duality Technologies Inc. and other contributors
//
// All rights reserved.
//
// Author TPOC: contact@openfhe.org
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//==================================================================================
/*
Wraps the parameters for integer lattice operations that use the double-CRT representation. Inherits from ElemParams.
*/
#ifndef LBCRYPTO_INC_LATTICE_ILDCRTPARAMS_H
#define LBCRYPTO_INC_LATTICE_ILDCRTPARAMS_H
#include "lattice/hal/elemparams.h"
#include "lattice/hal/default/ilparams.h"
#include "math/hal/basicint.h"
#include "math/math-hal.h"
#include "math/nbtheory-impl.h"
#include "utils/exception.h"
#include "utils/inttypes.h"
#include <iomanip>
#include <memory>
#include <string>
#include <utility>
#include <vector>
namespace lbcrypto {
/**
 * @brief Element parameters for an integer lattice in double-CRT form.
 *
 * Stores one ILNativeParams per limb (tower) in m_params, plus an "original"
 * modulus. The composite modulus held in the ElemParams base
 * (m_ciphertextModulus) is set to the product of the limb moduli by the
 * constructors (except the corder == 0 early-return paths). It is kept in sync
 * by PopLastParam, PopFirstParam and RecalculateModulus.
 *
 * NOTE: replacing a limb through the non-const operator[] does NOT update the
 * composite modulus; call RecalculateModulus() afterwards.
 *
 * @tparam IntType multiprecision integer type used for the composite modulus.
 */
template <typename IntType>
class ILDCRTParams final : public ElemParams<IntType> {
public:
    using Integer        = IntType;
    using ILNativeParams = ILParamsImpl<NativeInteger>;

    /**
     * @brief Builds limbs until their product reaches at least @p modulus.
     *
     * The first limb modulus is LastPrime<NativeInteger>(MAX_MODULUS_SIZE, corder).
     * Each further limb uses PreviousPrime of the preceding one. Limbs are added
     * while the running product is still < @p modulus.
     *
     * @param corder      cyclotomic order; 0 creates an empty shell that parameter
     *                    generation populates later (no limbs are built).
     * @param modulus     lower bound for the composite modulus; also stored as the
     *                    original modulus.
     * @param rootOfUnity currently unused; limbs are built without an explicit root.
     */
    ILDCRTParams(uint32_t corder, const IntType& modulus, const IntType& rootOfUnity = IntType(0))
        : ElemParams<IntType>(corder, modulus), m_originalModulus(modulus) {
        // NOTE params generation uses this constructor to make an empty params that
        // it will later populate during the gen process. For that special case...
        // we don't populate, and we just return
        if (corder == 0)
            return;
        auto q{LastPrime<NativeInteger>(MAX_MODULUS_SIZE, corder)};
        m_params.reserve(32);
        m_params.push_back(std::make_shared<ILNativeParams>(corder, q));
        IntType compositeModulus(1);
        // The product is updated before the comparison, so the loop stops as soon as
        // the limbs built so far cover the requested modulus.
        while ((compositeModulus *= IntType(q.template ConvertToInt<BasicInteger>())) < modulus)
            m_params.push_back(std::make_shared<ILNativeParams>(corder, (q = PreviousPrime(q, corder))));
        ElemParams<IntType>::m_ciphertextModulus = compositeModulus;
    }

    /**
     * @brief Builds @p depth limbs of (at most) @p bits bits each.
     *
     * @param corder cyclotomic order; 0 returns immediately with no limbs.
     * @param depth  number of limbs; note that depth == 0 still yields one limb.
     * @param bits   bit size passed to LastPrime; must be <= MAX_MODULUS_SIZE.
     * @throws via OPENFHE_THROW when bits > MAX_MODULUS_SIZE.
     */
    explicit ILDCRTParams(uint32_t corder = 0, uint32_t depth = 1, uint32_t bits = MAX_MODULUS_SIZE)
        : ElemParams<IntType>(corder, 0) {
        if (corder == 0)
            return;
        if (bits > MAX_MODULUS_SIZE)
            OPENFHE_THROW("Invalid bits for ILDCRTParams");
        auto q{LastPrime<NativeInteger>(bits, corder)};
        m_params.reserve(depth);
        m_params.push_back(std::make_shared<ILNativeParams>(corder, q));
        IntType compositeModulus(q.template ConvertToInt<BasicInteger>());
        for (uint32_t limb = 1; limb < depth; ++limb) {
            m_params.push_back(std::make_shared<ILNativeParams>(corder, (q = PreviousPrime(q, corder))));
            compositeModulus *= IntType(q.template ConvertToInt<BasicInteger>());
        }
        ElemParams<IntType>::m_ciphertextModulus = compositeModulus;
    }

    /**
     * @brief Builds one limb per (modulus, root of unity) pair.
     * @throws via OPENFHE_THROW when the two vectors differ in length.
     */
    ILDCRTParams(uint32_t corder, const std::vector<NativeInteger>& moduli,
                 const std::vector<NativeInteger>& rootsOfUnity)
        : ElemParams<IntType>(corder, 0) {
        size_t limbs{moduli.size()};
        if (limbs != rootsOfUnity.size())
            OPENFHE_THROW("sizes of moduli and roots of unity do not match 1");
        m_params.reserve(limbs);
        IntType compositeModulus(1);
        for (size_t i = 0; i < limbs; ++i) {
            m_params.push_back(std::make_shared<ILNativeParams>(corder, moduli[i], rootsOfUnity[i]));
            compositeModulus *= IntType(moduli[i].template ConvertToInt<BasicInteger>());
        }
        ElemParams<IntType>::m_ciphertextModulus = compositeModulus;
    }

    /**
     * @brief Builds one limb per index from moduli/roots and their "big" counterparts.
     *
     * Only @p moduli contribute to the composite modulus; the big moduli are
     * forwarded to each ILNativeParams (see RecalculateBigModulus).
     * @throws via OPENFHE_THROW when the four vectors differ in length.
     */
    ILDCRTParams(uint32_t corder, const std::vector<NativeInteger>& moduli,
                 const std::vector<NativeInteger>& rootsOfUnity, const std::vector<NativeInteger>& moduliBig,
                 const std::vector<NativeInteger>& rootsOfUnityBig, const IntType& inputOriginalModulus = IntType(0))
        : ElemParams<IntType>(corder, 0), m_originalModulus(inputOriginalModulus) {
        size_t limbs{moduli.size()};
        if (limbs != rootsOfUnity.size() || limbs != moduliBig.size() || limbs != rootsOfUnityBig.size())
            OPENFHE_THROW("sizes of moduli and roots of unity do not match 2");
        m_params.reserve(limbs);
        IntType compositeModulus(1);
        for (size_t i = 0; i < limbs; ++i) {
            m_params.push_back(
                std::make_shared<ILNativeParams>(corder, moduli[i], rootsOfUnity[i], moduliBig[i], rootsOfUnityBig[i]));
            compositeModulus *= IntType(moduli[i].template ConvertToInt<BasicInteger>());
        }
        ElemParams<IntType>::m_ciphertextModulus = compositeModulus;
    }

    /**
     * @brief Builds one limb per modulus, without explicit roots of unity.
     */
    ILDCRTParams(uint32_t corder, const std::vector<NativeInteger>& moduli,
                 const IntType& inputOriginalModulus = IntType(0))
        : ElemParams<IntType>(corder, 0), m_originalModulus(inputOriginalModulus) {
        size_t limbs{moduli.size()};
        m_params.reserve(limbs);
        IntType compositeModulus(1);
        for (size_t i = 0; i < limbs; ++i) {
            m_params.push_back(std::make_shared<ILNativeParams>(corder, moduli[i]));
            compositeModulus *= IntType(moduli[i].template ConvertToInt<BasicInteger>());
        }
        ElemParams<IntType>::m_ciphertextModulus = compositeModulus;
    }

    /**
     * @brief Adopts an existing list of limb parameters (shared, not cloned).
     */
    ILDCRTParams(uint32_t corder, const std::vector<std::shared_ptr<ILNativeParams>>& params,
                 const IntType& inputOriginalModulus = IntType(0))
        : ElemParams<IntType>(corder, 0), m_params(params), m_originalModulus(inputOriginalModulus) {
        RecalculateModulus();
    }

    ILDCRTParams(const ILDCRTParams& rhs)
        : ElemParams<IntType>(rhs), m_params(rhs.m_params), m_originalModulus(rhs.m_originalModulus) {}

    // The base subobject is moved too (passing the named `rhs` would copy it);
    // the base only consumes its own part, so rhs's own members are still intact.
    ILDCRTParams(ILDCRTParams&& rhs) noexcept
        : ElemParams<IntType>(std::move(rhs)),
          m_params(std::move(rhs.m_params)),
          m_originalModulus(std::move(rhs.m_originalModulus)) {}

    ILDCRTParams& operator=(const ILDCRTParams& rhs) {
        ElemParams<IntType>::operator=(rhs);
        m_params          = rhs.m_params;
        m_originalModulus = rhs.m_originalModulus;
        return *this;
    }

    ILDCRTParams& operator=(ILDCRTParams&& rhs) noexcept {
        // Guard against self-move, which would otherwise leave members unspecified.
        if (this != &rhs) {
            ElemParams<IntType>::operator=(std::move(rhs));
            m_params          = std::move(rhs.m_params);
            m_originalModulus = std::move(rhs.m_originalModulus);
        }
        return *this;
    }

    // ACCESSORS

    /// @return all limb parameters, in limb order.
    const std::vector<std::shared_ptr<ILNativeParams>>& GetParams() const {
        return m_params;
    }

    /**
     * @brief Returns the limb parameters with indices in [start, end] (both inclusive).
     *
     * The returned vector shares the ILNativeParams objects with this instance.
     * @throws via OPENFHE_THROW when end < start or end is not a valid limb index.
     */
    std::vector<std::shared_ptr<ILNativeParams>> GetParamPartition(uint32_t start, uint32_t end) const {
        // `end` is inclusive, so it must index an existing limb; end == size()
        // would step one past m_params.end().
        if (end < start || end >= m_params.size())
            OPENFHE_THROW("Incorrect parameters for GetParamPartition - (start: " + std::to_string(start) +
                          ", end:" + std::to_string(end) + ")");
        return std::vector<std::shared_ptr<ILNativeParams>>(m_params.begin() + start, m_params.begin() + end + 1);
    }

    const IntType& GetOriginalModulus() const {
        return m_originalModulus;
    }

    void SetOriginalModulus(const IntType& inputOriginalModulus) {
        m_originalModulus = inputOriginalModulus;
    }

    /// Unchecked limb access. Replacing a limb here does not update the
    /// composite modulus; call RecalculateModulus() afterwards.
    std::shared_ptr<ILNativeParams>& operator[](size_t i) {
        return m_params[i];
    }

    /// Unchecked read-only limb access.
    const std::shared_ptr<ILNativeParams>& operator[](size_t i) const {
        return m_params[i];
    }

    /**
     * @brief Removes the last limb and divides its modulus out of the composite modulus.
     * @throws via OPENFHE_THROW when there are no limbs.
     */
    void PopLastParam() {
        if (m_params.empty())
            OPENFHE_THROW("PopLastParam: no parameters to remove");
        ElemParams<IntType>::m_ciphertextModulus /=
            IntType(m_params.back()->GetModulus().template ConvertToInt<BasicInteger>());
        m_params.pop_back();
    }

    /**
     * @brief Removes the first limb and divides its modulus out of the composite modulus.
     * @throws via OPENFHE_THROW when there are no limbs.
     */
    void PopFirstParam() {
        if (m_params.empty())
            OPENFHE_THROW("PopFirstParam: no parameters to remove");
        ElemParams<IntType>::m_ciphertextModulus /=
            IntType(m_params.front()->GetModulus().template ConvertToInt<BasicInteger>());
        m_params.erase(m_params.begin());
    }

    ~ILDCRTParams() override = default;

    /**
     * @brief Equality: same dynamic type, equal base params, limb-wise equal
     *        ILNativeParams, and equal original modulus.
     */
    bool operator==(const ElemParams<IntType>& other) const override {
        const auto* dcrtParams = dynamic_cast<const ILDCRTParams*>(&other);
        if (!dcrtParams)
            return false;
        if (ElemParams<IntType>::operator==(other) == false)
            return false;
        if (m_params.size() != dcrtParams->m_params.size())
            return false;
        for (size_t i = 0; i < m_params.size(); ++i) {
            if (*m_params[i] != *dcrtParams->m_params[i])
                return false;
        }
        return (m_originalModulus == dcrtParams->GetOriginalModulus());
    }

    /// Resets the composite modulus to the product of all limb moduli
    /// (1 when there are no limbs).
    void RecalculateModulus() {
        ElemParams<IntType>::m_ciphertextModulus = 1;
        for (const auto& limb : m_params)
            ElemParams<IntType>::m_ciphertextModulus *=
                IntType(limb->GetModulus().template ConvertToInt<BasicInteger>());
    }

    /// Resets the big composite modulus to the product of all limbs' big moduli
    /// (1 when there are no limbs).
    void RecalculateBigModulus() {
        ElemParams<IntType>::m_bigCiphertextModulus = 1;
        for (const auto& limb : m_params)
            ElemParams<IntType>::m_bigCiphertextModulus *=
                IntType(limb->GetBigModulus().template ConvertToInt<BasicInteger>());
    }

    template <class Archive>
    void save(Archive& ar, std::uint32_t const version) const {
        ar(::cereal::base_class<ElemParams<IntType>>(this));
        ar(::cereal::make_nvp("p", m_params));
        ar(::cereal::make_nvp("m", m_originalModulus));
    }

    /// @throws via OPENFHE_THROW when the archive was written by a newer serialization version.
    template <class Archive>
    void load(Archive& ar, std::uint32_t const version) {
        if (version > SerializedVersion()) {
            OPENFHE_THROW("serialized object version " + std::to_string(version) +
                          " is from a later version of the library");
        }
        ar(::cereal::base_class<ElemParams<IntType>>(this));
        ar(::cereal::make_nvp("p", m_params));
        ar(::cereal::make_nvp("m", m_originalModulus));
    }

    std::string SerializedObjectName() const override {
        return "DCRTParams";
    }

    static uint32_t SerializedVersion() {
        return 1;
    }

private:
    std::ostream& doprint(std::ostream& out) const override {
        out << "ILDCRTParams ";
        ElemParams<IntType>::doprint(out);
        out << std::endl << " m_params:" << std::endl;
        for (size_t i = 0; i < m_params.size(); ++i)
            out << " " << i << ": " << *m_params[i];
        return out << " m_originalModulus: " << m_originalModulus << std::endl;
    }

    // per-limb (tower) parameters, one ILNativeParams per RNS modulus
    std::vector<std::shared_ptr<ILNativeParams>> m_params;
    // original modulus when being constructed from a Poly or when
    // ctor is passed that parameter
    // note originalModulus will be <= composite modulus
    // i.e. \Prod_i=0^k-1 m_params[i]->GetModulus()
    // note not using ElemParams::ciphertextModulus due to object stripping
    IntType m_originalModulus;
};
} // namespace lbcrypto
#endif