Program Listing for File dcrtpoly.h
↰ Return to documentation for file (core/include/lattice/hal/default/dcrtpoly.h)
//==================================================================================
// BSD 2-Clause License
//
// Copyright (c) 2014-2023, NJIT, Duality Technologies Inc. and other contributors
//
// All rights reserved.
//
// Author TPOC: contact@openfhe.org
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//==================================================================================
/*
Represents integer lattice elements with double-CRT
*/
#ifndef LBCRYPTO_INC_LATTICE_HAL_DEFAULT_DCRTPOLY_H
#define LBCRYPTO_INC_LATTICE_HAL_DEFAULT_DCRTPOLY_H
#include "lattice/hal/default/ildcrtparams.h"
#include "lattice/hal/default/poly.h"
#include "lattice/hal/dcrtpoly-interface.h"
#include "math/math-hal.h"
#include "math/distrgen.h"
#include "utils/exception.h"
#include "utils/inttypes.h"
#include "utils/parallel.h"
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>
namespace lbcrypto {
/**
 * @brief Default-backend double-CRT polynomial.
 *
 * Stores a shared ILDCRTParams parameter set, a Format tag and a vector of
 * PolyImpl<NativeVector> "towers". The params-based constructor creates one
 * tower per entry returned by Params::GetParams().
 *
 * Most members are only declared in this header and are defined out of line;
 * the comments on those declarations are limited to what the signatures show.
 *
 * @tparam VecType vector type backing the large-integer polynomial view
 *         (PolyLargeType); its nested Integer type is used for big-integer
 *         operands.
 */
template <typename VecType>
class DCRTPolyImpl final : public DCRTPolyInterface<DCRTPolyImpl<VecType>, VecType, NativeVector, PolyImpl> {
public:
    using Vector                = VecType;
    using Integer               = typename VecType::Integer;
    using Params                = ILDCRTParams<Integer>;
    using PolyType              = PolyImpl<NativeVector>;
    using PolyLargeType         = PolyImpl<VecType>;
    using DCRTPolyType          = DCRTPolyImpl<VecType>;
    using DCRTPolyInterfaceType = DCRTPolyInterface<DCRTPolyImpl<VecType>, VecType, NativeVector, PolyImpl>;
    using Precomputations       = typename DCRTPolyInterfaceType::CRTBasisExtensionPrecomputations;
    using DggType               = typename DCRTPolyInterfaceType::DggType;
    using DugType               = typename DCRTPolyInterfaceType::DugType;
    using TugType               = typename DCRTPolyInterfaceType::TugType;
    using BugType               = typename DCRTPolyInterfaceType::BugType;

    // ---------------------------------------------------------------------
    // Construction, copy and move
    // ---------------------------------------------------------------------

    /// Default state: default-constructed Params, Format::EVALUATION and no towers
    /// (see the in-class member initializers at the bottom of the class).
    DCRTPolyImpl() = default;

    /// Copy constructor: copies the params pointer, the format and every tower.
    /// NOTE(review): declared noexcept although copying the tower vector may allocate;
    /// an allocation failure here would terminate the program instead of propagating.
    DCRTPolyImpl(const DCRTPolyType& e) noexcept : m_params{e.m_params}, m_format{e.m_format}, m_vectors{e.m_vectors} {}

    /// Copy assignment: member-wise copy of params pointer, format and towers.
    /// The same noexcept caveat as for the copy constructor applies.
    DCRTPolyType& operator=(const DCRTPolyType& rhs) noexcept override {
        m_params  = rhs.m_params;
        m_format  = rhs.m_format;
        m_vectors = rhs.m_vectors;
        return *this;
    }

    /// Construction / assignment from a single large-integer polynomial (defined out of line).
    DCRTPolyImpl(const PolyLargeType& e, const std::shared_ptr<Params>& params) noexcept;
    DCRTPolyType& operator=(const PolyLargeType& rhs) noexcept;

    /// Construction / assignment from a single native-integer polynomial (defined out of line).
    DCRTPolyImpl(const PolyType& e, const std::shared_ptr<Params>& params) noexcept;
    DCRTPolyType& operator=(const PolyType& rhs) noexcept;

    /// Move constructor: steals the params pointer and the tower vector, copies the format.
    /// The moved-from element is left with a null params pointer.
    DCRTPolyImpl(DCRTPolyType&& e) noexcept
        : m_params{std::move(e.m_params)}, m_format{e.m_format}, m_vectors{std::move(e.m_vectors)} {}

    /// Move assignment: steals the params pointer and the tower vector, copies the format.
    DCRTPolyType& operator=(DCRTPolyType&& rhs) noexcept override {
        m_params  = std::move(rhs.m_params);
        m_format  = rhs.m_format;  // Format is a plain enum; std::move would be a no-op cast
        m_vectors = std::move(rhs.m_vectors);
        return *this;
    }

    /// Builds an element from an explicit list of towers (defined out of line).
    explicit DCRTPolyImpl(const std::vector<PolyType>& elements);

    /**
     * @brief Creates one tower per entry of params->GetParams().
     * @param params parameter set shared by the new element; dereferenced, so must not be null
     * @param format format assigned to the element and forwarded to every tower
     * @param initializeElementToZero forwarded unchanged to each tower's constructor
     *
     * NOTE(review): declared noexcept although reserve()/emplace_back() may allocate.
     */
    DCRTPolyImpl(const std::shared_ptr<Params>& params, Format format = Format::EVALUATION,
                 bool initializeElementToZero = false) noexcept
        : m_params{params}, m_format{format} {
        m_vectors.reserve(m_params->GetParams().size());
        for (const auto& p : m_params->GetParams())
            m_vectors.emplace_back(p, m_format, initializeElementToZero);
    }

    /// Constructors taking a random generator (Dgg/Bug/Tug/Dug types come from the interface);
    /// defined out of line.
    DCRTPolyImpl(const DggType& dgg, const std::shared_ptr<Params>& p, Format f = Format::EVALUATION);
    DCRTPolyImpl(const BugType& bug, const std::shared_ptr<Params>& p, Format f = Format::EVALUATION);
    DCRTPolyImpl(const TugType& tug, const std::shared_ptr<Params>& p, Format f = Format::EVALUATION, uint32_t h = 0);
    DCRTPolyImpl(DugType& dug, const std::shared_ptr<Params>& p, Format f = Format::EVALUATION);

    /// Assignment from literal / container values (defined out of line).
    DCRTPolyType& operator=(std::initializer_list<uint64_t> rhs) noexcept override;
    DCRTPolyType& operator=(uint64_t val) noexcept;
    DCRTPolyType& operator=(const std::vector<int64_t>& rhs) noexcept;
    DCRTPolyType& operator=(const std::vector<int32_t>& rhs) noexcept;
    DCRTPolyType& operator=(std::initializer_list<std::string> rhs) noexcept;

    /// Cloning helpers (defined out of line).
    DCRTPolyType CloneWithNoise(const DiscreteGaussianGeneratorImpl<VecType>& dgg, Format format) const override;
    DCRTPolyType CloneTowers(uint32_t startTower, uint32_t endTower) const;

    // ---------------------------------------------------------------------
    // Comparison and in-place arithmetic
    // ---------------------------------------------------------------------

    bool operator==(const DCRTPolyType& rhs) const override;

    DCRTPolyType& operator+=(const DCRTPolyType& rhs) override;
    DCRTPolyType& operator+=(const Integer& rhs) override;
    DCRTPolyType& operator+=(const NativeInteger& rhs) override;

    DCRTPolyType& operator-=(const DCRTPolyType& rhs) override;
    DCRTPolyType& operator-=(const Integer& rhs) override;
    DCRTPolyType& operator-=(const NativeInteger& rhs) override;

    /**
     * @brief In-place tower-wise product: m_vectors[i] *= rhs.m_vectors[i] for every tower of *this.
     * @param rhs right-hand operand; must provide at least as many towers as *this
     *        (any additional towers of rhs are not read)
     * @return *this
     * @throws via OPENFHE_THROW if rhs has fewer towers than *this (would otherwise read out of bounds)
     *
     * Unlike Times(), ring dimension, format and moduli are not validated here.
     */
    DCRTPolyType& operator*=(const DCRTPolyType& rhs) override {
        size_t size{m_vectors.size()};
        if (rhs.m_vectors.size() < size)
            OPENFHE_THROW("tower size mismatch; cannot multiply");
#pragma omp parallel for num_threads(OpenFHEParallelControls.GetThreadLimit(size))
        for (size_t i = 0; i < size; ++i)
            m_vectors[i] *= rhs.m_vectors[i];
        return *this;
    }

    DCRTPolyType& operator*=(const Integer& rhs) override;
    DCRTPolyType& operator*=(const NativeInteger& rhs) override;

    DCRTPolyType Negate() const override;
    DCRTPolyType operator-() const override;

    // ---------------------------------------------------------------------
    // Decomposition and automorphisms (defined out of line)
    // ---------------------------------------------------------------------

    std::vector<DCRTPolyType> BaseDecompose(usint baseBits, bool evalModeAnswer) const override;
    std::vector<DCRTPolyType> PowersOfBase(usint baseBits) const override;
    std::vector<DCRTPolyType> CRTDecompose(uint32_t baseBits) const;

    DCRTPolyType AutomorphismTransform(uint32_t i) const override;
    DCRTPolyType AutomorphismTransform(uint32_t i, const std::vector<uint32_t>& vec) const override;

    // ---------------------------------------------------------------------
    // Out-of-place arithmetic
    // ---------------------------------------------------------------------

    DCRTPolyType Plus(const Integer& rhs) const override;
    DCRTPolyType Plus(const std::vector<Integer>& rhs) const;

    /**
     * @brief Tower-wise sum of *this and rhs, returned as a new element built on this element's params.
     * @param rhs right-hand operand
     * @return new element whose i-th tower is m_vectors[i].PlusNoCheck(rhs.m_vectors[i])
     * @throws via OPENFHE_THROW when ring dimensions, formats, tower counts or the first
     *         towers' moduli differ
     *
     * Only the modulus of tower 0 is compared. When both operands have no towers the
     * modulus check is skipped (there is nothing to index) and no tower is computed.
     */
    DCRTPolyType Plus(const DCRTPolyType& rhs) const override {
        if (m_params->GetRingDimension() != rhs.m_params->GetRingDimension())
            OPENFHE_THROW("RingDimension missmatch");
        if (m_format != rhs.m_format)
            OPENFHE_THROW("Format missmatch");
        size_t size{m_vectors.size()};
        if (size != rhs.m_vectors.size())
            OPENFHE_THROW("tower size mismatch; cannot add");
        // size == rhs size at this point; guard so an element without towers is never indexed
        if (size != 0 && m_vectors[0].GetModulus() != rhs.m_vectors[0].GetModulus())
            OPENFHE_THROW("Modulus missmatch");
        DCRTPolyType tmp(m_params, m_format);
#pragma omp parallel for num_threads(OpenFHEParallelControls.GetThreadLimit(size))
        for (size_t i = 0; i < size; ++i)
            tmp.m_vectors[i] = m_vectors[i].PlusNoCheck(rhs.m_vectors[i]);
        return tmp;
    }

    DCRTPolyType Minus(const DCRTPolyType& rhs) const override;
    DCRTPolyType Minus(const Integer& rhs) const override;
    DCRTPolyType Minus(const std::vector<Integer>& rhs) const;

    /**
     * @brief Tower-wise product of *this and rhs, returned as a new element built on this element's params.
     * @param rhs right-hand operand
     * @return new element whose i-th tower is m_vectors[i].TimesNoCheck(rhs.m_vectors[i])
     * @throws via OPENFHE_THROW when ring dimensions differ, either operand is not in
     *         Format::EVALUATION, tower counts differ, or the first towers' moduli differ
     *
     * Only the modulus of tower 0 is compared. When both operands have no towers the
     * modulus check is skipped (there is nothing to index) and no tower is computed.
     */
    DCRTPolyType Times(const DCRTPolyType& rhs) const override {
        if (m_params->GetRingDimension() != rhs.m_params->GetRingDimension())
            OPENFHE_THROW("RingDimension missmatch");
        if (m_format != Format::EVALUATION || rhs.m_format != Format::EVALUATION)
            OPENFHE_THROW("operator* for DCRTPolyImpl supported only in Format::EVALUATION");
        size_t size{m_vectors.size()};
        if (size != rhs.m_vectors.size())
            OPENFHE_THROW("tower size mismatch; cannot multiply");
        // size == rhs size at this point; guard so an element without towers is never indexed
        if (size != 0 && m_vectors[0].GetModulus() != rhs.m_vectors[0].GetModulus())
            OPENFHE_THROW("Modulus missmatch");
        DCRTPolyType tmp(m_params, m_format);
#pragma omp parallel for num_threads(OpenFHEParallelControls.GetThreadLimit(size))
        for (size_t i = 0; i < size; ++i)
            tmp.m_vectors[i] = m_vectors[i].TimesNoCheck(rhs.m_vectors[i]);
        return tmp;
    }

    DCRTPolyType Times(const Integer& rhs) const override;
    DCRTPolyType Times(const std::vector<Integer>& rhs) const;
    DCRTPolyType Times(NativeInteger::SignedNativeInt rhs) const override;
#if NATIVEINT != 64
    /// Only compiled when NATIVEINT != 64: forwards to the SignedNativeInt overload after a static_cast.
    DCRTPolyType Times(int64_t rhs) const {
        return Times(static_cast<NativeInteger::SignedNativeInt>(rhs));
    }
#endif
    DCRTPolyType Times(const std::vector<NativeInteger>& rhs) const;
    DCRTPolyType TimesNoCheck(const std::vector<NativeInteger>& rhs) const;

    DCRTPolyType MultiplicativeInverse() const override;
    bool InverseExists() const override;

    // ---------------------------------------------------------------------
    // State queries and tower management (defined out of line)
    // ---------------------------------------------------------------------

    bool IsEmpty() const override;
    void SetValuesToZero() override;
    void AddILElementOne() override;
    void DropLastElement() override;
    void DropLastElements(size_t i) override;
    void DropLastElementAndScale(const std::vector<NativeInteger>& QlQlInvModqlDivqlModq,
                                 const std::vector<NativeInteger>& qlInvModq) override;
    void ModReduce(const NativeInteger& t, const std::vector<NativeInteger>& tModqPrecon,
                   const NativeInteger& negtInvModq, const NativeInteger& negtInvModqPrecon,
                   const std::vector<NativeInteger>& qlInvModq,
                   const std::vector<NativeInteger>& qlInvModqPrecon) override;

    // ---------------------------------------------------------------------
    // Interpolation / conversion to single-modulus polynomials (defined out of line)
    // ---------------------------------------------------------------------

    PolyLargeType CRTInterpolate() const override;
    PolyType DecryptionCRTInterpolate(PlaintextModulus ptm) const override;
    PolyType ToNativePoly() const override;
    PolyLargeType CRTInterpolateIndex(usint i) const override;
    Integer GetWorkingModulus() const override;
    void SetValuesModSwitch(const DCRTPolyType& element, const NativeInteger& modulus) override;
    std::shared_ptr<Params> GetExtendedCRTBasis(const std::shared_ptr<Params>& paramsP) const override;

    // ---------------------------------------------------------------------
    // CRT-basis switching, extension and scaling routines.
    // Only declared here; NOTE(review): the meaning and expected sizes of the
    // precomputed-table arguments are defined by the out-of-line implementation
    // and its callers - confirm there before changing any call site.
    // ---------------------------------------------------------------------

    void TimesQovert(const std::shared_ptr<Params>& paramsQ, const std::vector<NativeInteger>& tInvModq,
                     const NativeInteger& t, const NativeInteger& NegQModt,
                     const NativeInteger& NegQModtPrecon) override;

    DCRTPolyType ApproxSwitchCRTBasis(const std::shared_ptr<Params>& paramsQ, const std::shared_ptr<Params>& paramsP,
                                      const std::vector<NativeInteger>& QHatInvModq,
                                      const std::vector<NativeInteger>& QHatInvModqPrecon,
                                      const std::vector<std::vector<NativeInteger>>& QHatModp,
                                      const std::vector<DoubleNativeInt>& modpBarrettMu) const override;

    void ApproxModUp(const std::shared_ptr<Params>& paramsQ, const std::shared_ptr<Params>& paramsP,
                     const std::shared_ptr<Params>& paramsQP, const std::vector<NativeInteger>& QHatInvModq,
                     const std::vector<NativeInteger>& QHatInvModqPrecon,
                     const std::vector<std::vector<NativeInteger>>& QHatModp,
                     const std::vector<DoubleNativeInt>& modpBarrettMu) override;

    DCRTPolyType ApproxModDown(
        const std::shared_ptr<Params>& paramsQ, const std::shared_ptr<Params>& paramsP,
        const std::vector<NativeInteger>& PInvModq, const std::vector<NativeInteger>& PInvModqPrecon,
        const std::vector<NativeInteger>& PHatInvModp, const std::vector<NativeInteger>& PHatInvModpPrecon,
        const std::vector<std::vector<NativeInteger>>& PHatModq, const std::vector<DoubleNativeInt>& modqBarrettMu,
        const std::vector<NativeInteger>& tInvModp, const std::vector<NativeInteger>& tInvModpPrecon,
        const NativeInteger& t, const std::vector<NativeInteger>& tModqPrecon) const override;

    DCRTPolyType SwitchCRTBasis(const std::shared_ptr<Params>& paramsP, const std::vector<NativeInteger>& QHatInvModq,
                                const std::vector<NativeInteger>& QHatInvModqPrecon,
                                const std::vector<std::vector<NativeInteger>>& QHatModp,
                                const std::vector<std::vector<NativeInteger>>& alphaQModp,
                                const std::vector<DoubleNativeInt>& modpBarrettMu,
                                const std::vector<double>& qInv) const override;

    void ExpandCRTBasis(const std::shared_ptr<Params>& paramsQP, const std::shared_ptr<Params>& paramsP,
                        const std::vector<NativeInteger>& QHatInvModq,
                        const std::vector<NativeInteger>& QHatInvModqPrecon,
                        const std::vector<std::vector<NativeInteger>>& QHatModp,
                        const std::vector<std::vector<NativeInteger>>& alphaQModp,
                        const std::vector<DoubleNativeInt>& modpBarrettMu, const std::vector<double>& qInv,
                        Format resultFormat) override;

    void ExpandCRTBasisReverseOrder(const std::shared_ptr<Params>& paramsQP, const std::shared_ptr<Params>& paramsP,
                                    const std::vector<NativeInteger>& QHatInvModq,
                                    const std::vector<NativeInteger>& QHatInvModqPrecon,
                                    const std::vector<std::vector<NativeInteger>>& QHatModp,
                                    const std::vector<std::vector<NativeInteger>>& alphaQModp,
                                    const std::vector<DoubleNativeInt>& modpBarrettMu, const std::vector<double>& qInv,
                                    Format resultFormat) override;

    void FastExpandCRTBasisPloverQ(const Precomputations& precomputed) override;

    void ExpandCRTBasisQlHat(const std::shared_ptr<Params>& paramsQ, const std::vector<NativeInteger>& QlHatModq,
                             const std::vector<NativeInteger>& QlHatModqPrecon, const usint sizeQ) override;

    PolyType ScaleAndRound(const NativeInteger& t, const std::vector<NativeInteger>& tQHatInvModqDivqModt,
                           const std::vector<NativeInteger>& tQHatInvModqDivqModtPrecon,
                           const std::vector<NativeInteger>& tQHatInvModqBDivqModt,
                           const std::vector<NativeInteger>& tQHatInvModqBDivqModtPrecon,
                           const std::vector<double>& tQHatInvModqDivqFrac,
                           const std::vector<double>& tQHatInvModqBDivqFrac) const override;

    DCRTPolyType ApproxScaleAndRound(const std::shared_ptr<Params>& paramsP,
                                     const std::vector<std::vector<NativeInteger>>& tPSHatInvModsDivsModp,
                                     const std::vector<DoubleNativeInt>& modpBarretMu) const override;

    DCRTPolyType ScaleAndRound(const std::shared_ptr<Params>& paramsOutput,
                               const std::vector<std::vector<NativeInteger>>& tOSHatInvModsDivsModo,
                               const std::vector<double>& tOSHatInvModsDivsFrac,
                               const std::vector<DoubleNativeInt>& modoBarretMu) const override;

    PolyType ScaleAndRound(const std::vector<NativeInteger>& moduliQ, const NativeInteger& t,
                           const NativeInteger& tgamma, const std::vector<NativeInteger>& tgammaQHatModq,
                           const std::vector<NativeInteger>& tgammaQHatModqPrecon,
                           const std::vector<NativeInteger>& negInvqModtgamma,
                           const std::vector<NativeInteger>& negInvqModtgammaPrecon) const override;

    void ScaleAndRoundPOverQ(const std::shared_ptr<Params>& paramsQ,
                             const std::vector<NativeInteger>& pInvModq) override;

    void FastBaseConvqToBskMontgomery(
        const std::shared_ptr<Params>& paramsQBsk, const std::vector<NativeInteger>& moduliQ,
        const std::vector<NativeInteger>& moduliBsk, const std::vector<DoubleNativeInt>& modbskBarrettMu,
        const std::vector<NativeInteger>& mtildeQHatInvModq, const std::vector<NativeInteger>& mtildeQHatInvModqPrecon,
        const std::vector<std::vector<NativeInteger>>& QHatModbsk, const std::vector<uint64_t>& QHatModmtilde,
        const std::vector<NativeInteger>& QModbsk, const std::vector<NativeInteger>& QModbskPrecon,
        const uint64_t& negQInvModmtilde, const std::vector<NativeInteger>& mtildeInvModbsk,
        const std::vector<NativeInteger>& mtildeInvModbskPrecon) override;

    void FastRNSFloorq(const NativeInteger& t, const std::vector<NativeInteger>& moduliQ,
                       const std::vector<NativeInteger>& moduliBsk, const std::vector<DoubleNativeInt>& modbskBarrettMu,
                       const std::vector<NativeInteger>& tQHatInvModq,
                       const std::vector<NativeInteger>& tQHatInvModqPrecon,
                       const std::vector<std::vector<NativeInteger>>& QHatModbsk,
                       const std::vector<std::vector<NativeInteger>>& qInvModbsk,
                       const std::vector<NativeInteger>& tQInvModbsk,
                       const std::vector<NativeInteger>& tQInvModbskPrecon) override;

    void FastBaseConvSK(const std::shared_ptr<Params>& paramsQ, const std::vector<DoubleNativeInt>& modqBarrettMu,
                        const std::vector<NativeInteger>& moduliBsk,
                        const std::vector<DoubleNativeInt>& modbskBarrettMu,
                        const std::vector<NativeInteger>& BHatInvModb,
                        const std::vector<NativeInteger>& BHatInvModbPrecon,
                        const std::vector<NativeInteger>& BHatModmsk, const NativeInteger& BInvModmsk,
                        const NativeInteger& BInvModmskPrecon, const std::vector<std::vector<NativeInteger>>& BHatModq,
                        const std::vector<NativeInteger>& BModq,
                        const std::vector<NativeInteger>& BModqPrecon) override;

    void SwitchFormat() override;
    void SwitchModulusAtIndex(size_t index, const Integer& modulus, const Integer& rootOfUnity) override;

    // ---------------------------------------------------------------------
    // Serialization
    // ---------------------------------------------------------------------

    /// Writes the towers ("v"), the format ("f") and the params ("p") to the archive, in that order.
    /// The version argument is not used when saving.
    template <class Archive>
    void save(Archive& ar, std::uint32_t const version) const {
        ar(::cereal::make_nvp("v", m_vectors));
        ar(::cereal::make_nvp("f", m_format));
        ar(::cereal::make_nvp("p", m_params));
    }

    /// Reads the fields in the same order save() writes them.
    /// Throws via OPENFHE_THROW when the archive version is newer than SerializedVersion().
    template <class Archive>
    void load(Archive& ar, std::uint32_t const version) {
        if (version > SerializedVersion()) {
            OPENFHE_THROW("serialized object version " + std::to_string(version) +
                          " is from a later version of the library");
        }
        ar(::cereal::make_nvp("v", m_vectors));
        ar(::cereal::make_nvp("f", m_format));
        ar(::cereal::make_nvp("p", m_params));
    }

    /// Name of this element type.
    static const std::string GetElementName() {
        return "DCRTPolyImpl";
    }

    /// Name reported for the serialized object.
    std::string SerializedObjectName() const override {
        return "DCRTPoly";
    }

    /// Highest serialization version load() accepts.
    static uint32_t SerializedVersion() {
        return 1;
    }

    // ---------------------------------------------------------------------
    // Accessors
    // ---------------------------------------------------------------------

    /// Current format tag of the element.
    inline Format GetFormat() const final {
        return m_format;
    }

    /// Overwrites only the stored format tag; the towers themselves are not touched here.
    void OverrideFormat(const Format f) final {
        m_format = f;
    }

    /// Shared parameter set of the element.
    inline const std::shared_ptr<Params>& GetParams() const {
        return m_params;
    }

    /// Read-only access to the tower vector.
    const std::vector<PolyType>& GetAllElements() const {
        return m_vectors;
    }

    /// Mutable access to the tower vector; callers can modify towers (and the vector itself) directly.
    std::vector<PolyType>& GetAllElements() {
        return m_vectors;
    }

    /// Replaces the tower at index by copy. No bounds check: index must be < GetAllElements().size().
    void SetElementAtIndex(usint index, const PolyType& element) {
        m_vectors[index] = element;
    }

    /// Replaces the tower at index by move. No bounds check: index must be < GetAllElements().size().
    void SetElementAtIndex(usint index, PolyType&& element) {
        m_vectors[index] = std::move(element);
    }

protected:
    /// Parameter set shared between elements; default-initialized to a fresh default Params object.
    std::shared_ptr<Params> m_params{std::make_shared<DCRTPolyImpl::Params>()};
    /// Format tag of the element; defaults to Format::EVALUATION.
    Format m_format{Format::EVALUATION};
    /// The towers, one PolyType per CRT modulus.
    std::vector<PolyType> m_vectors;
};
} // namespace lbcrypto
#endif