Program Listing for File poly-interface.h

Return to documentation for file (core/include/lattice/hal/poly-interface.h)

//==================================================================================
// BSD 2-Clause License
//
// Copyright (c) 2014-2023, NJIT, Duality Technologies Inc. and other contributors
//
// All rights reserved.
//
// Author TPOC: contact@openfhe.org
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
//    list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
//    this list of conditions and the following disclaimer in the documentation
//    and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//==================================================================================

/*
  Defines an interface that any DCRT Polynomial implementation must implement in order to work in OpenFHE.
 */

#ifndef LBCRYPTO_INC_LATTICE_HAL_POLYINTERFACE_H
#define LBCRYPTO_INC_LATTICE_HAL_POLYINTERFACE_H

#include "lattice/ilelement.h"
#include "lattice/hal/default/ilparams.h"

#include "math/math-hal.h"
#include "math/distrgen.h"
#include "math/nbtheory.h"

#include "utils/inttypes.h"
#include "utils/exception.h"

#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

namespace lbcrypto {

template <typename DerivedType, typename VecType, template <typename LVT> typename ContainerType>
class PolyInterface : public ILElement<DerivedType, VecType> {
public:
    using Vector     = VecType;
    using Integer    = typename VecType::Integer;
    using Params     = ILParamsImpl<Integer>;
    using PolyNative = ContainerType<NativeVector>;
    using DggType    = DiscreteGaussianGeneratorImpl<VecType>;
    using DugType    = DiscreteUniformGeneratorImpl<VecType>;
    using TugType    = TernaryUniformGeneratorImpl<VecType>;
    using BugType    = BinaryUniformGeneratorImpl<VecType>;

    DerivedType& GetDerived() {
        return static_cast<DerivedType&>(*this);
    }

    const DerivedType& GetDerived() const {
        return static_cast<DerivedType const&>(*this);
    }

    inline static std::function<DerivedType()> Allocator(const std::shared_ptr<Params>& params, Format format) {
        return [=]() {
            return DerivedType(params, format, true);
        };
    }

    inline static std::function<DerivedType()> MakeDiscreteGaussianCoefficientAllocator(
        const std::shared_ptr<Params>& params, Format resultFormat, double stddev) {
        return [=]() {
            DggType dgg(stddev);
            return DerivedType(dgg, params, resultFormat);
        };
    }

    inline static std::function<DerivedType()> MakeDiscreteUniformAllocator(const std::shared_ptr<Params>& params,
                                                                            Format format) {
        return [=]() {
            DugType dug;
            return DerivedType(dug, params, format);
        };
    }

    DerivedType& operator=(const DerivedType& rhs) override = 0;
    DerivedType& operator=(DerivedType&& rhs) override      = 0;
    DerivedType& operator=(const std::vector<int32_t>& rhs) {
        return this->GetDerived().operator=(rhs);
    }
    DerivedType& operator=(const std::vector<int64_t>& rhs) {
        return this->GetDerived().operator=(rhs);
    }
    DerivedType& operator=(std::initializer_list<uint64_t> rhs) override = 0;
    DerivedType& operator=(std::initializer_list<std::string> rhs) {
        return this->GetDerived().operator=(rhs);
    }
    DerivedType& operator=(uint64_t rhs) {
        return this->GetDerived().operator=(rhs);
    }

    Format GetFormat() const override {
        return this->GetDerived().GetFormat();
    }

    const std::shared_ptr<Params>& GetParams() const {
        return this->GetDerived().GetParams();
    }

    usint GetRingDimension() const {
        return this->GetDerived().GetParams()->GetRingDimension();
    }

    const Integer& GetRootOfUnity() const {
        return this->GetDerived().GetParams()->GetRootOfUnity();
    }

    const Integer& GetModulus() const final {
        return this->GetDerived().GetParams()->GetModulus();
    }

    usint GetCyclotomicOrder() const final {
        return this->GetDerived().GetParams()->GetCyclotomicOrder();
    }

    usint GetLength() const final {
        //        if (this->GetDerived().IsEmpty())
        //            OPENFHE_THROW("No values in PolyImpl");
        return this->GetDerived().GetValues().GetLength();
    }

    const VecType& GetValues() const override = 0;

    Integer& at(usint i) override             = 0;
    const Integer& at(usint i) const override = 0;

    Integer& operator[](usint i) override {
        return this->GetDerived()[i];
    }

    const Integer& operator[](usint i) const override {
        return this->GetDerived()[i];
    }

    DerivedType Plus(const DerivedType& rhs) const override {
        return this->GetDerived().Plus(rhs);
    }

    DerivedType Minus(const DerivedType& element) const override = 0;

    DerivedType Times(const DerivedType& element) const override = 0;

    DerivedType TimesNoCheck(const DerivedType& rhs) const {
        return this->GetDerived().Times(rhs);
    }
    DerivedType Plus(const Integer& element) const override = 0;

    DerivedType Minus(const Integer& element) const override = 0;

    DerivedType Times(const Integer& element) const override = 0;

    DerivedType Times(NativeInteger::SignedNativeInt element) const override = 0;

#if NATIVEINT != 64
    DerivedType Times(int64_t rhs) const {
        return this->GetDerived().Times(rhs);
    }
#endif

    DerivedType MultiplyAndRound(const Integer& p, const Integer& q) const override = 0;

    DerivedType DivideAndRound(const Integer& q) const override = 0;

    virtual DerivedType Negate() const = 0;

    DerivedType operator-() const override = 0;

    DerivedType& operator+=(const Integer& element) override = 0;

    DerivedType& operator-=(const Integer& element) override = 0;

    DerivedType& operator*=(const Integer& element) override = 0;

    DerivedType& operator+=(const DerivedType& rhs) override = 0;

    DerivedType& operator-=(const DerivedType& rhs) override = 0;

    DerivedType& operator*=(const DerivedType& element) override = 0;

    bool operator==(const DerivedType& rhs) const override = 0;

    void AddILElementOne() override = 0;

    DerivedType AutomorphismTransform(uint32_t i) const override = 0;

    DerivedType AutomorphismTransform(uint32_t i, const std::vector<uint32_t>& vec) const override = 0;

    inline DerivedType Transpose() const final {
        if (this->GetDerived().GetFormat() == Format::COEFFICIENT) {
            OPENFHE_THROW(
                "PolyInterface element transposition is currently "
                "implemented only in the Evaluation representation.");
        }
        return this->GetDerived().AutomorphismTransform(this->GetDerived().GetCyclotomicOrder() - 1);
    }

    DerivedType MultiplicativeInverse() const override = 0;

    DerivedType ModByTwo() const override = 0;

    DerivedType Mod(const Integer& modulus) const override = 0;

    void SwitchModulus(const Integer& modulus, const Integer& rootOfUnity, const Integer& modulusArb,
                       const Integer& rootOfUnityArb) override = 0;

    void SwitchFormat() override = 0;

    virtual void OverrideFormat(const Format f) = 0;

    void MakeSparse(uint32_t wFactor) override = 0;

    bool IsEmpty() const override = 0;

    bool InverseExists() const override = 0;

    double Norm() const override = 0;

    std::vector<DerivedType> BaseDecompose(usint baseBits, bool evalModeAnswer) const override = 0;

    std::vector<DerivedType> PowersOfBase(usint baseBits) const override = 0;

    virtual void SetValues(const VecType& values, Format format) = 0;
    virtual void SetValues(VecType&& values, Format format)      = 0;

    virtual void SetValuesToZero() = 0;
    virtual void SetValuesToMax()  = 0;

    DerivedType CRTInterpolate() const {
        return this->GetDerived();
    }

    virtual PolyNative DecryptionCRTInterpolate(PlaintextModulus ptm) const = 0;

    virtual PolyNative ToNativePoly() const = 0;

    DerivedType Clone() const final {
        return DerivedType(this->GetDerived());
    }

    DerivedType CloneEmpty() const final {
        return DerivedType();
    }

    DerivedType CloneParametersOnly() const final {
        return DerivedType(this->GetDerived().GetParams(), this->GetDerived().GetFormat());
    }

    DerivedType CloneWithNoise(const DggType& dgg, Format format) const final {
        return DerivedType(dgg, this->GetDerived().GetParams(), this->GetDerived().GetFormat());
    }

    const std::string GetElementName() const {
        return this->GetDerived().GetElementName();
    }

protected:
    friend inline std::ostream& operator<<(std::ostream& os, const DerivedType& vec) {
        os << (vec.GetFormat() == Format::EVALUATION ? "EVAL: " : "COEF: ") << vec.GetValues();
        return os;
    }

    friend inline DerivedType operator+(const DerivedType& a, const DerivedType& b) {
        return a.Plus(b);
    }
    friend inline DerivedType operator+(const DerivedType& a, const Integer& b) {
        return a.Plus(b);
    }

    friend inline DerivedType operator+(const Integer& a, const DerivedType& b) {
        return b.Plus(a);
    }

    friend inline DerivedType operator-(const DerivedType& a, const DerivedType& b) {
        return a.Minus(b);
    }

    friend inline DerivedType operator-(const DerivedType& a, const Integer& b) {
        return a.Minus(b);
    }

    friend inline DerivedType operator*(const DerivedType& a, const DerivedType& b) {
        return a.Times(b);
    }

    friend inline DerivedType operator*(const DerivedType& a, const Integer& b) {
        return a.Times(b);
    }

    friend inline DerivedType operator*(const Integer& a, const DerivedType& b) {
        return b.Times(a);
    }

    friend inline DerivedType operator*(const DerivedType& a, int64_t b) {
        return a.Times((NativeInteger::SignedNativeInt)b);
    }

    friend inline DerivedType operator*(int64_t a, const DerivedType& b) {
        return b.Times((NativeInteger::SignedNativeInt)a);
    }
};

}  // namespace lbcrypto

#endif