Program Listing for File dcrtpoly-interface.h

Return to documentation for file (core/include/lattice/hal/dcrtpoly-interface.h)

//==================================================================================
// BSD 2-Clause License
//
// Copyright (c) 2014-2023, NJIT, Duality Technologies Inc. and other contributors
//
// All rights reserved.
//
// Author TPOC: contact@openfhe.org
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
//    list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
//    this list of conditions and the following disclaimer in the documentation
//    and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//==================================================================================

/*
  Defines the interface that any DCRT (double-CRT) polynomial implementation must implement in order to work in OpenFHE.
 */

#ifndef LBCRYPTO_INC_LATTICE_HAL_DCRTPOLYINTERFACE_H
#define LBCRYPTO_INC_LATTICE_HAL_DCRTPOLYINTERFACE_H

#include "lattice/hal/default/ildcrtparams.h"
#include "lattice/ilelement.h"

#include "math/math-hal.h"
#include "math/distrgen.h"

#include "utils/inttypes.h"
#include "utils/exception.h"

#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

namespace lbcrypto {

// TODO: CRTP with ILElement to remove virtual overhead

/**
 * @brief CRTP-style interface that a DCRT (RNS) polynomial implementation must satisfy to work in OpenFHE.
 *
 * Members of this class fall into four groups:
 *  - forwarding members that call the same-named member on GetDerived();
 *  - ILElement members re-declared pure virtual (`override = 0`) so DerivedType must implement them;
 *  - new pure virtual members specific to DCRT polynomials;
 *  - ILElement members marked `final` that throw, because the operation is not supported for DCRTPoly.
 *
 * @tparam DerivedType       concrete implementation; GetDerived() static_casts `*this` to it (CRTP)
 * @tparam BigVecType        multiprecision vector type; its `Integer` is BigIntType
 * @tparam LilVecType        per-tower vector type; its `Integer` is LilIntType
 * @tparam RNSContainerType  polynomial template, instantiated with LilVecType for a tower (TowerType)
 *                           and with BigVecType for the interpolated polynomial (PolyLargeType)
 */
template <typename DerivedType, typename BigVecType, typename LilVecType,
          template <typename LVT> typename RNSContainerType>
class DCRTPolyInterface : public ILElement<DerivedType, BigVecType> {
public:
    /// Integer type of BigVecType.
    using BigIntType    = typename BigVecType::Integer;
    /// Parameter set type (ILDCRTParams over BigIntType).
    using Params        = ILDCRTParams<BigIntType>;
    /// Integer type of LilVecType (one tower's coefficients).
    using LilIntType    = typename LilVecType::Integer;
    /// One RNS tower.
    using TowerType     = RNSContainerType<LilVecType>;
    /// Polynomial over BigVecType; the return type of CRTInterpolate().
    using PolyLargeType = RNSContainerType<BigVecType>;
    // Random generators over the tower vector type.
    using DggType       = DiscreteGaussianGeneratorImpl<LilVecType>;
    using DugType       = DiscreteUniformGeneratorImpl<LilVecType>;
    using TugType       = TernaryUniformGeneratorImpl<LilVecType>;
    using BugType       = BinaryUniformGeneratorImpl<LilVecType>;

    /// CRTP downcast to the concrete implementation.
    DerivedType& GetDerived() {
        return static_cast<DerivedType&>(*this);
    }

    /// CRTP downcast to the concrete implementation (const overload).
    const DerivedType& GetDerived() const {
        return static_cast<DerivedType const&>(*this);
    }

    /**
     * @brief Returns a factory that builds `DerivedType(params, format, true)` on every call.
     * @param params parameter set; captured by value (shared ownership) in the returned functor
     * @param format format handed to every constructed element
     * NOTE(review): the meaning of the trailing `true` is defined by DerivedType's constructor
     * (presumably "initialize to zero") — TODO confirm.
     */
    static std::function<DerivedType()> Allocator(const std::shared_ptr<Params>& params, Format format) {
        return [=]() {
            return DerivedType(params, format, true);
        };
    }

    /**
     * @brief Returns a factory that samples a new element from a discrete Gaussian on every call.
     * A fresh DggType(stddev) is constructed inside each invocation of the functor.
     * @param params       parameter set captured by the functor
     * @param resultFormat format of the produced elements
     * @param stddev       standard deviation passed to DggType
     */
    static std::function<DerivedType()> MakeDiscreteGaussianCoefficientAllocator(const std::shared_ptr<Params>& params,
                                                                                 Format resultFormat, double stddev) {
        return [=]() {
            DggType dgg(stddev);
            return DerivedType(dgg, params, resultFormat);
        };
    }

    /**
     * @brief Returns a factory that builds an element from a fresh DugType on every call.
     * @param params parameter set captured by the functor
     * @param format format of the produced elements
     */
    static std::function<DerivedType()> MakeDiscreteUniformAllocator(const std::shared_ptr<Params>& params,
                                                                     Format format) {
        return [=]() {
            DugType dug;
            return DerivedType(dug, params, format);
        };
    }

    // NOTE: forwarding members such as this one call the same-named member on GetDerived().
    // DerivedType must declare that member (hiding this one); otherwise name lookup finds this
    // function again and the call recurses without end.
    /// Forwards to DerivedType::CloneTowers(startTower, endTower).
    DerivedType CloneTowers(uint32_t startTower, uint32_t endTower) {
        return this->GetDerived().CloneTowers(startTower, endTower);
    }

    /// Returns a copy of the derived object.
    DerivedType Clone() const final {
        return DerivedType(this->GetDerived());
    }

    /// Returns a default-constructed DerivedType.
    DerivedType CloneEmpty() const final {
        return DerivedType();
    }

    /// Returns a DerivedType built from this element's parameters and format only.
    DerivedType CloneParametersOnly() const final {
        return DerivedType(this->GetDerived().GetParams(), this->GetDerived().GetFormat());
    }

    DerivedType CloneWithNoise(const DiscreteGaussianGeneratorImpl<BigVecType>& dgg, Format format) const override = 0;

    Format GetFormat() const override = 0;

    /// Forwards to DerivedType::GetParams().
    const std::shared_ptr<Params>& GetParams() const {
        return this->GetDerived().GetParams();
    }

    /// Cyclotomic order read from the parameter set.
    usint GetCyclotomicOrder() const final {
        return this->GetDerived().GetParams()->GetCyclotomicOrder();
    }

    /// Ring dimension read from the parameter set.
    usint GetRingDimension() const {
        return this->GetDerived().GetParams()->GetRingDimension();
    }

    /// Modulus read from the parameter set.
    const BigIntType& GetModulus() const final {
        return this->GetDerived().GetParams()->GetModulus();
    }

    /// Original modulus read from the parameter set.
    const BigIntType& GetOriginalModulus() const {
        return this->GetDerived().GetParams()->GetOriginalModulus();
    }

    // TODO: this doesn't look right (original authors' note; an earlier placeholder returned BigIntType(0)).
    /// Root of unity read from the parameter set.
    const BigIntType GetRootOfUnity() const {
        return this->GetDerived().GetParams()->GetRootOfUnity();
    }

    /// Length reported as the ring dimension from the parameter set.
    usint GetLength() const final {
        return this->GetDerived().GetParams()->GetRingDimension();
    }

    // Element-wise coefficient access is not supported for DCRTPoly; every overload throws.
    BigIntType& at(usint i) final {
        OPENFHE_THROW("at() not implemented for DCRTPoly");
    }
    const BigIntType& at(usint i) const final {
        OPENFHE_THROW("const at() not implemented for DCRTPoly");
    }

    BigIntType& operator[](usint i) final {
        OPENFHE_THROW("operator[] not implemented for DCRTPoly");
    }
    const BigIntType& operator[](usint i) const final {
        OPENFHE_THROW("const operator[] not implemented for DCRTPoly");
    }

    /// Forwards to DerivedType::GetAllElements() (const).
    const std::vector<TowerType>& GetAllElements() const {
        return this->GetDerived().GetAllElements();
    }
    /// Forwards to DerivedType::GetAllElements() (mutable).
    std::vector<TowerType>& GetAllElements() {
        return this->GetDerived().GetAllElements();
    }

    /// Number of towers, i.e. GetAllElements().size().
    usint GetNumOfElements() const {
        return this->GetDerived().GetAllElements().size();
    }

    /// Tower at index @p i; no bounds check (uses vector::operator[]).
    const TowerType& GetElementAtIndex(usint i) const {
        return this->GetDerived().GetAllElements()[i];
    }

    /// Forwards to DerivedType::SetElementAtIndex with a const reference.
    void SetElementAtIndex(usint index, const TowerType& element) {
        this->GetDerived().SetElementAtIndex(index, element);
    }

    /// Forwards to DerivedType::SetElementAtIndex, preserving the rvalue so the tower can be moved.
    void SetElementAtIndex(usint index, TowerType&& element) {
        // A named rvalue reference is an lvalue; without std::move the const& overload
        // would be chosen and the tower copied.
        this->GetDerived().SetElementAtIndex(index, std::move(element));
    }

    // Progress marker left by the original authors ("Yuriy and I stopped here!");
    // presumably the members below have not yet been revisited in the CRTP refactor.

    std::vector<DerivedType> BaseDecompose(usint baseBits, bool evalModeAnswer) const override = 0;

    std::vector<DerivedType> PowersOfBase(usint baseBits) const override = 0;

    /// Forwards to DerivedType::CRTDecompose(baseBits).
    std::vector<DerivedType> CRTDecompose(uint32_t baseBits) const {
        return this->GetDerived().CRTDecompose(baseBits);
    }

    // Assignment operators: the ILElement ones are pure virtual; the others forward to DerivedType.
    DerivedType& operator=(const TowerType& rhs) {
        return this->GetDerived().operator=(rhs);
    }

    DerivedType& operator=(const DerivedType& rhs) override = 0;

    DerivedType& operator=(DerivedType&& rhs) override = 0;

    DerivedType& operator=(std::initializer_list<uint64_t> rhs) override = 0;

    DerivedType& operator=(uint64_t val) {
        return this->GetDerived().operator=(val);
    }

    DerivedType& operator=(const std::vector<int64_t>& rhs) {
        return this->GetDerived().operator=(rhs);
    }

    DerivedType& operator=(const std::vector<int32_t>& rhs) {
        return this->GetDerived().operator=(rhs);
    }

    DerivedType& operator=(std::initializer_list<std::string> rhs) {
        return this->GetDerived().operator=(rhs);
    }

    DerivedType operator-() const override = 0;

    bool operator==(const DerivedType& rhs) const override = 0;

    DerivedType& operator+=(const DerivedType& rhs) override = 0;

    DerivedType& operator-=(const DerivedType& rhs) override = 0;

    DerivedType AutomorphismTransform(uint32_t i) const override = 0;

    DerivedType AutomorphismTransform(uint32_t i, const std::vector<uint32_t>& vec) const override = 0;

    /**
     * @brief Transpose, computed as AutomorphismTransform(GetCyclotomicOrder() - 1).
     * @throws OpenFHEException if the element is in COEFFICIENT format.
     */
    DerivedType Transpose() const final {
        if (this->GetDerived().GetFormat() == Format::COEFFICIENT) {
            OPENFHE_THROW(
                "DCRTPolyInterface element transposition is currently "
                "implemented only in the Evaluation representation.");
        }
        return this->GetDerived().AutomorphismTransform(this->GetDerived().GetCyclotomicOrder() - 1);
    }

    // Arithmetic: ILElement overloads are pure virtual; vector/native overloads forward to DerivedType.
    DerivedType Plus(const DerivedType& rhs) const override = 0;

    DerivedType Times(const DerivedType& rhs) const override = 0;

    DerivedType Minus(const DerivedType& rhs) const override = 0;

    DerivedType Plus(const BigIntType& rhs) const override = 0;

    DerivedType Plus(const std::vector<BigIntType>& rhs) const {
        return this->GetDerived().Plus(rhs);
    }

    DerivedType Minus(const BigIntType& rhs) const override = 0;

    DerivedType Minus(const std::vector<BigIntType>& rhs) const {
        return this->GetDerived().Minus(rhs);
    }

    DerivedType Times(const BigIntType& rhs) const override = 0;

    DerivedType Times(NativeInteger::SignedNativeInt rhs) const override = 0;

#if NATIVEINT != 64
    // Only declared when NATIVEINT != 64 (presumably because SignedNativeInt is then not int64_t).
    DerivedType Times(int64_t rhs) const {
        return this->GetDerived().Times(rhs);
    }
#endif

    DerivedType Times(const std::vector<NativeInteger>& rhs) const {
        return this->GetDerived().Times(rhs);
    }

    DerivedType TimesNoCheck(const std::vector<NativeInteger>& rhs) const {
        return this->GetDerived().TimesNoCheck(rhs);
    }

    DerivedType Times(const std::vector<BigIntType>& rhs) const {
        return this->GetDerived().Times(rhs);
    }

    // Not supported for DCRTPoly; always throws.
    DerivedType MultiplyAndRound(const BigIntType& p, const BigIntType& q) const final {
        OPENFHE_THROW("MultiplyAndRound not implemented for DCRTPoly");
    }

    // Not supported for DCRTPoly; always throws.
    DerivedType DivideAndRound(const BigIntType& q) const final {
        OPENFHE_THROW("DivideAndRound not implemented for DCRTPoly");
    }

    virtual DerivedType Negate() const = 0;

    // Compound assignment with a big integer (ILElement) or a tower-sized integer (DCRT-specific).
    DerivedType& operator+=(const BigIntType& rhs) override = 0;
    virtual DerivedType& operator+=(const LilIntType& rhs) = 0;

    DerivedType& operator-=(const BigIntType& rhs) override = 0;
    virtual DerivedType& operator-=(const LilIntType& rhs) = 0;

    DerivedType& operator*=(const BigIntType& rhs) override = 0;
    virtual DerivedType& operator*=(const LilIntType& rhs) = 0;

    DerivedType& operator*=(const DerivedType& rhs) override = 0;

    // multiplicative inverse operation
    DerivedType MultiplicativeInverse() const override = 0;

    // Not supported for DCRTPoly; always throws.
    DerivedType ModByTwo() const final {
        OPENFHE_THROW("ModByTwo not implemented for DCRTPoly");
    }

    // Not supported for DCRTPoly; always throws.
    DerivedType Mod(const BigIntType& modulus) const final {
        OPENFHE_THROW("Mod of a BigIntType not implemented for DCRTPoly");
    }

    // Not supported for DCRTPoly: there is no single BigVecType holding all coefficients.
    const BigVecType& GetValues() const final {
        OPENFHE_THROW("GetValues not implemented for DCRTPoly");
    }

    // Not supported for DCRTPoly; always throws.
    void SetValues(const BigVecType& values, Format format) {
        OPENFHE_THROW("SetValues not implemented for DCRTPoly");
    }

    virtual void SetValuesToZero() = 0;

    virtual void SetValuesModSwitch(const DerivedType& element, const NativeInteger& modulus) = 0;

    void AddILElementOne() override = 0;

    // Not supported for DCRTPoly; always throws.
    DerivedType AddRandomNoise(const BigIntType& modulus) const {
        OPENFHE_THROW("AddRandomNoise is not currently implemented for DCRTPoly");
    }

    // Not supported for DCRTPoly; always throws.
    void MakeSparse(uint32_t wFactor) final {
        OPENFHE_THROW("MakeSparse is not currently implemented for DCRTPoly");
    }

    bool IsEmpty() const override = 0;

    // ---------------------------------------------------------------------------------------------
    // Tower management, CRT interpolation and RNS basis operations. All pure virtual; the exact
    // semantics are defined by DerivedType. Parameter names follow the precomputed-constant naming
    // used by callers (e.g. a "...Precon" vector presumably holds precomputations for the matching
    // modular multiplication — TODO confirm against the implementation).
    // ---------------------------------------------------------------------------------------------

    virtual void DropLastElement() = 0;

    virtual void DropLastElements(size_t i) = 0;

    virtual void DropLastElementAndScale(const std::vector<NativeInteger>& QlQlInvModqlDivqlModq,
                                         const std::vector<NativeInteger>& qlInvModq) = 0;

    virtual void ModReduce(const NativeInteger& t, const std::vector<NativeInteger>& tModqPrecon,
                           const NativeInteger& negtInvModq, const NativeInteger& negtInvModqPrecon,
                           const std::vector<NativeInteger>& qlInvModq,
                           const std::vector<NativeInteger>& qlInvModqPrecon) = 0;

    virtual PolyLargeType CRTInterpolate() const = 0;

    virtual TowerType DecryptionCRTInterpolate(PlaintextModulus ptm) const = 0;

    virtual TowerType ToNativePoly() const = 0;

    virtual PolyLargeType CRTInterpolateIndex(usint i) const = 0;

    virtual BigIntType GetWorkingModulus() const = 0;

    virtual std::shared_ptr<Params> GetExtendedCRTBasis(const std::shared_ptr<Params>& paramsP) const = 0;

    virtual void TimesQovert(const std::shared_ptr<Params>& paramsQ, const std::vector<NativeInteger>& tInvModq,
                             const NativeInteger& t, const NativeInteger& NegQModt,
                             const NativeInteger& NegQModtPrecon) = 0;

    virtual DerivedType ApproxSwitchCRTBasis(const std::shared_ptr<Params>& paramsQ,
                                             const std::shared_ptr<Params>& paramsP,
                                             const std::vector<NativeInteger>& QHatInvModq,
                                             const std::vector<NativeInteger>& QHatInvModqPrecon,
                                             const std::vector<std::vector<NativeInteger>>& QHatModp,
                                             const std::vector<DoubleNativeInt>& modpBarrettMu) const = 0;

    virtual void ApproxModUp(const std::shared_ptr<Params>& paramsQ, const std::shared_ptr<Params>& paramsP,
                             const std::shared_ptr<Params>& paramsQP, const std::vector<NativeInteger>& QHatInvModq,
                             const std::vector<NativeInteger>& QHatInvModqPrecon,
                             const std::vector<std::vector<NativeInteger>>& QHatModp,
                             const std::vector<DoubleNativeInt>& modpBarrettMu) = 0;

    virtual DerivedType ApproxModDown(
        const std::shared_ptr<Params>& paramsQ, const std::shared_ptr<Params>& paramsP,
        const std::vector<NativeInteger>& PInvModq, const std::vector<NativeInteger>& PInvModqPrecon,
        const std::vector<NativeInteger>& PHatInvModp, const std::vector<NativeInteger>& PHatInvModpPrecon,
        const std::vector<std::vector<NativeInteger>>& PHatModq, const std::vector<DoubleNativeInt>& modqBarrettMu,
        const std::vector<NativeInteger>& tInvModp, const std::vector<NativeInteger>& tInvModpPrecon,
        const NativeInteger& t, const std::vector<NativeInteger>& tModqPrecon) const = 0;

    virtual DerivedType SwitchCRTBasis(const std::shared_ptr<Params>& paramsP,
                                       const std::vector<NativeInteger>& QHatInvModq,
                                       const std::vector<NativeInteger>& QHatInvModqPrecon,
                                       const std::vector<std::vector<NativeInteger>>& QHatModp,
                                       const std::vector<std::vector<NativeInteger>>& alphaQModp,
                                       const std::vector<DoubleNativeInt>& modpBarrettMu,
                                       const std::vector<double>& qInv) const = 0;

    virtual void ExpandCRTBasis(const std::shared_ptr<Params>& paramsQP, const std::shared_ptr<Params>& paramsP,
                                const std::vector<NativeInteger>& QHatInvModq,
                                const std::vector<NativeInteger>& QHatInvModqPrecon,
                                const std::vector<std::vector<NativeInteger>>& QHatModp,
                                const std::vector<std::vector<NativeInteger>>& alphaQModp,
                                const std::vector<DoubleNativeInt>& modpBarrettMu, const std::vector<double>& qInv,
                                Format resultFormat) = 0;

    virtual void ExpandCRTBasisReverseOrder(const std::shared_ptr<Params>& paramsQP,
                                            const std::shared_ptr<Params>& paramsP,
                                            const std::vector<NativeInteger>& QHatInvModq,
                                            const std::vector<NativeInteger>& QHatInvModqPrecon,
                                            const std::vector<std::vector<NativeInteger>>& QHatModp,
                                            const std::vector<std::vector<NativeInteger>>& alphaQModp,
                                            const std::vector<DoubleNativeInt>& modpBarrettMu,
                                            const std::vector<double>& qInv, Format resultFormat) = 0;

    /**
     * @brief Bundle of precomputed values passed to FastExpandCRTBasisPloverQ().
     * The constructor copies each argument into the member of the same name (minus the trailing `0`).
     */
    struct CRTBasisExtensionPrecomputations {
        // TODO (dsuponit) and (pascoec): make the data members private to enforce their constantness and add getters.
        std::shared_ptr<Params> paramsQlPl;
        std::shared_ptr<Params> paramsPl;
        std::shared_ptr<Params> paramsQl;
        std::vector<NativeInteger> mPlQHatInvModq;
        std::vector<NativeInteger> mPlQHatInvModqPrecon;
        std::vector<std::vector<NativeInteger>> qInvModp;
        std::vector<DoubleNativeInt> modpBarrettMu;
        std::vector<NativeInteger> PlHatInvModp;
        std::vector<NativeInteger> PlHatInvModpPrecon;
        std::vector<std::vector<NativeInteger>> PlHatModq;
        std::vector<std::vector<NativeInteger>> alphaPlModq;
        std::vector<DoubleNativeInt> modqBarrettMu;
        std::vector<double> pInv;

        // clang-format off
        CRTBasisExtensionPrecomputations(
            const std::shared_ptr<Params>& paramsQlPl0,
            const std::shared_ptr<Params>& paramsPl0,
            const std::shared_ptr<Params>& paramsQl0,
            const std::vector<NativeInteger>& mPlQHatInvModq0,
            const std::vector<NativeInteger>& mPlQHatInvModqPrecon0,
            const std::vector<std::vector<NativeInteger>>& qInvModp0,
            const std::vector<DoubleNativeInt>& modpBarrettMu0,
            const std::vector<NativeInteger>& PlHatInvModp0,
            const std::vector<NativeInteger>& PlHatInvModpPrecon0,
            const std::vector<std::vector<NativeInteger>>& PlHatModq0,
            const std::vector<std::vector<NativeInteger>>& alphaPlModq0,
            const std::vector<DoubleNativeInt>& modqBarrettMu0,
            const std::vector<double>& pInv0)
            : paramsQlPl(paramsQlPl0),
              paramsPl(paramsPl0),
              paramsQl(paramsQl0),
              mPlQHatInvModq(mPlQHatInvModq0),
              mPlQHatInvModqPrecon(mPlQHatInvModqPrecon0),
              qInvModp(qInvModp0),
              modpBarrettMu(modpBarrettMu0),
              PlHatInvModp(PlHatInvModp0),
              PlHatInvModpPrecon(PlHatInvModpPrecon0),
              PlHatModq(PlHatModq0),
              alphaPlModq(alphaPlModq0),
              modqBarrettMu(modqBarrettMu0),
              pInv(pInv0) {}
        // clang-format on
    };
    // (The former C-style `typedef struct X X;` was redundant in C++ and has been removed;
    //  the struct name is usable directly.)

    virtual void FastExpandCRTBasisPloverQ(const CRTBasisExtensionPrecomputations& precomputed) = 0;

    virtual void ExpandCRTBasisQlHat(const std::shared_ptr<Params>& paramsQ,
                                     const std::vector<NativeInteger>& QlHatModq,
                                     const std::vector<NativeInteger>& QlHatModqPrecon, const usint sizeQ) = 0;

    virtual TowerType ScaleAndRound(const NativeInteger& t, const std::vector<NativeInteger>& tQHatInvModqDivqModt,
                                    const std::vector<NativeInteger>& tQHatInvModqDivqModtPrecon,
                                    const std::vector<NativeInteger>& tQHatInvModqBDivqModt,
                                    const std::vector<NativeInteger>& tQHatInvModqBDivqModtPrecon,
                                    const std::vector<double>& tQHatInvModqDivqFrac,
                                    const std::vector<double>& tQHatInvModqBDivqFrac) const = 0;

    virtual DerivedType ApproxScaleAndRound(const std::shared_ptr<Params>& paramsP,
                                            const std::vector<std::vector<NativeInteger>>& tPSHatInvModsDivsModp,
                                            const std::vector<DoubleNativeInt>& modpBarretMu) const = 0;

    virtual DerivedType ScaleAndRound(const std::shared_ptr<Params>& paramsOutput,
                                      const std::vector<std::vector<NativeInteger>>& tOSHatInvModsDivsModo,
                                      const std::vector<double>& tOSHatInvModsDivsFrac,
                                      const std::vector<DoubleNativeInt>& modoBarretMu) const = 0;

    virtual TowerType ScaleAndRound(const std::vector<NativeInteger>& moduliQ, const NativeInteger& t,
                                    const NativeInteger& tgamma, const std::vector<NativeInteger>& tgammaQHatModq,
                                    const std::vector<NativeInteger>& tgammaQHatModqPrecon,
                                    const std::vector<NativeInteger>& negInvqModtgamma,
                                    const std::vector<NativeInteger>& negInvqModtgammaPrecon) const = 0;

    virtual void ScaleAndRoundPOverQ(const std::shared_ptr<Params>& paramsQ,
                                     const std::vector<NativeInteger>& pInvModq) = 0;

    virtual void FastBaseConvqToBskMontgomery(
        const std::shared_ptr<Params>& paramsQBsk, const std::vector<NativeInteger>& moduliQ,
        const std::vector<NativeInteger>& moduliBsk, const std::vector<DoubleNativeInt>& modbskBarrettMu,
        const std::vector<NativeInteger>& mtildeQHatInvModq, const std::vector<NativeInteger>& mtildeQHatInvModqPrecon,
        const std::vector<std::vector<NativeInteger>>& QHatModbsk, const std::vector<uint64_t>& QHatModmtilde,
        const std::vector<NativeInteger>& QModbsk, const std::vector<NativeInteger>& QModbskPrecon,
        const uint64_t& negQInvModmtilde, const std::vector<NativeInteger>& mtildeInvModbsk,
        const std::vector<NativeInteger>& mtildeInvModbskPrecon) = 0;

    virtual void FastRNSFloorq(
        const NativeInteger& t, const std::vector<NativeInteger>& moduliQ, const std::vector<NativeInteger>& moduliBsk,
        const std::vector<DoubleNativeInt>& modbskBarrettMu, const std::vector<NativeInteger>& tQHatInvModq,
        const std::vector<NativeInteger>& tQHatInvModqPrecon, const std::vector<std::vector<NativeInteger>>& QHatModbsk,
        const std::vector<std::vector<NativeInteger>>& qInvModbsk, const std::vector<NativeInteger>& tQInvModbsk,
        const std::vector<NativeInteger>& tQInvModbskPrecon) = 0;

    virtual void FastBaseConvSK(
        const std::shared_ptr<Params>& paramsQ, const std::vector<DoubleNativeInt>& modqBarrettMu,
        const std::vector<NativeInteger>& moduliBsk, const std::vector<DoubleNativeInt>& modbskBarrettMu,
        const std::vector<NativeInteger>& BHatInvModb, const std::vector<NativeInteger>& BHatInvModbPrecon,
        const std::vector<NativeInteger>& BHatModmsk, const NativeInteger& BInvModmsk,
        const NativeInteger& BInvModmskPrecon, const std::vector<std::vector<NativeInteger>>& BHatModq,
        const std::vector<NativeInteger>& BModq, const std::vector<NativeInteger>& BModqPrecon) = 0;

    void SwitchFormat() override = 0;

    virtual void OverrideFormat(const Format f) = 0;

    // Not supported for DCRTPoly (use SwitchModulusAtIndex for a single tower); always throws.
    void SwitchModulus(const BigIntType& modulus, const BigIntType& rootOfUnity, const BigIntType& modulusArb,
                       const BigIntType& rootOfUnityArb) final {
        OPENFHE_THROW("SwitchModulus not implemented for DCRTPoly");
    }

    virtual void SwitchModulusAtIndex(size_t index, const BigIntType& modulus, const BigIntType& rootOfUnity) = 0;

    bool InverseExists() const override = 0;

    /// Norm of the CRT-interpolated polynomial (CRTInterpolate() followed by PolyLargeType::Norm()).
    double Norm() const final {
        return PolyLargeType(this->GetDerived().CRTInterpolate()).Norm();
    }

    /// Forwards to DerivedType::GetElementName().
    const std::string GetElementName() const {
        return this->GetDerived().GetElementName();
    }

protected:
    // NOTE: access specifiers do not apply to friend declarations. The operators below are
    // "hidden friends": they are found by argument-dependent lookup whenever a DerivedType
    // (whose base class is this template) is an operand.

    /// Prints every tower as "<index>: <tower>", one tower per line (no trailing newline, no flush).
    friend inline std::ostream& operator<<(std::ostream& os, const DerivedType& vec) {
        const auto& towers = vec.GetAllElements();
        for (size_t i = 0; i < towers.size(); ++i) {
            if (i != 0) {
                os << '\n';
            }
            os << i << ": " << towers[i];
        }
        return os;
    }

    friend inline DerivedType operator+(const DerivedType& a, const DerivedType& b) {
        return a.Plus(b);
    }
    friend inline DerivedType operator+(const DerivedType& a, const BigIntType& b) {
        return a.Plus(b);
    }

    friend inline DerivedType operator+(const BigIntType& a, const DerivedType& b) {
        return b.Plus(a);
    }

    friend inline DerivedType operator+(const DerivedType& a, const std::vector<BigIntType>& b) {
        return a.Plus(b);
    }

    friend inline DerivedType operator+(const std::vector<BigIntType>& a, const DerivedType& b) {
        return b.Plus(a);
    }

    friend inline DerivedType operator-(const DerivedType& a, const DerivedType& b) {
        return a.Minus(b);
    }

    friend inline DerivedType operator-(const DerivedType& a, const std::vector<BigIntType>& b) {
        return a.Minus(b);
    }

    // NOTE(review): this returns b.Minus(a), i.e. b - a rather than a - b. Subtraction is not
    // commutative, so confirm this is intended before relying on `vector - DerivedType`.
    friend inline DerivedType operator-(const std::vector<BigIntType>& a, const DerivedType& b) {
        return b.Minus(a);
    }

    friend inline DerivedType operator-(const DerivedType& a, const BigIntType& b) {
        return a.Minus(b);
    }

    friend inline DerivedType operator*(const DerivedType& a, const DerivedType& b) {
        return a.Times(b);
    }

    friend inline DerivedType operator*(const DerivedType& a, const BigIntType& b) {
        return a.Times(b);
    }

    friend inline DerivedType operator*(const DerivedType& a, const std::vector<BigIntType>& b) {
        return a.Times(b);
    }

    friend inline DerivedType operator*(const BigIntType& a, const DerivedType& b) {
        return b.Times(a);
    }

    /// Scalar multiplication; the int64_t is converted to NativeInteger::SignedNativeInt.
    friend inline DerivedType operator*(const DerivedType& a, int64_t b) {
        return a.Times(static_cast<NativeInteger::SignedNativeInt>(b));
    }

    /// Scalar multiplication; the int64_t is converted to NativeInteger::SignedNativeInt.
    friend inline DerivedType operator*(int64_t a, const DerivedType& b) {
        return b.Times(static_cast<NativeInteger::SignedNativeInt>(a));
    }
};

}  // namespace lbcrypto

#endif  // LBCRYPTO_INC_LATTICE_HAL_DCRTPOLYINTERFACE_H