// Program listing for file transformnat.h
//
// Return to documentation for file (core/include/math/hal/intnat/transformnat.h)

//==================================================================================
// BSD 2-Clause License
//
// Copyright (c) 2014-2022, NJIT, Duality Technologies Inc. and other contributors
//
// All rights reserved.
//
// Author TPOC: contact@openfhe.org
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
//    list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
//    this list of conditions and the following disclaimer in the documentation
//    and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//==================================================================================

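/*
 This file declares the linear-transform classes of the native math backend (NTT, CRT-based
 transforms and the Bluestein FFT); their implementations are included from
 "math/hal/intnat/transformnat-impl.h" at the end of this header.
*/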

#ifndef LBCRYPTO_MATH_HAL_INTNAT_TRANSFORMNAT_H
#define LBCRYPTO_MATH_HAL_INTNAT_TRANSFORMNAT_H

#include "math/hal/transform.h"

#include "utils/inttypes.h"

#include <map>
#include <mutex>
#include <unordered_map>
#include <utility>
#include <vector>

namespace intnat {

/**
 * @brief Hash functor for std::pair values (usable as the Hash parameter of
 *        std::unordered_map / std::unordered_set).
 *
 * Each member of the pair is hashed with its own std::hash specialization and
 * the two digests are mixed with HashCombine().
 */
struct HashPair {
    /// Returns HashCombine(hash(p.first), hash(p.second)).
    template <class T1, class T2>
    size_t operator()(const std::pair<T1, T2>& p) const {
        const size_t firstDigest  = std::hash<T1>{}(p.first);
        const size_t secondDigest = std::hash<T2>{}(p.second);
        return HashCombine(firstDigest, secondDigest);
    }

    /// Mixes `rhs` into the seed `lhs` using the boost::hash_combine recipe:
    /// seed ^ (value + 0x9e3779b9 + (seed << 6) + (seed >> 2)).
    static size_t HashCombine(size_t lhs, size_t rhs) {
        const size_t mixed = rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
        return lhs ^ mixed;
    }
};

/**
 * @brief Number-theoretic transform (NTT) routines of the native math backend.
 *
 * Only declarations appear here; the definitions come from
 * "math/hal/intnat/transformnat-impl.h", included at the end of this header.
 *
 * Conventions visible in the signatures:
 *  - Non-InPlace overloads read a const `element` and write through the
 *    `result` pointer; `*InPlace` overloads operate on `*element`.
 *  - Overloads taking `precon*` arguments additionally receive precomputed
 *    companion tables/constants. NOTE(review): presumably Shoup-style
 *    precomputations for faster modular multiplication -- confirm in impl.
 *  - `cycloOrderInv` (inverse transforms) is, by its name, the inverse of the
 *    cyclotomic order, presumably used for final scaling -- TODO confirm.
 *
 * @tparam VecType vector type exposing a nested `Integer` type.
 */
template <typename VecType>
class NumberTheoreticTransformNat {
    // Alias for VecType's nested Integer type.
    using IntType = typename VecType::Integer;

public:
    // Forward transform of `element` using `rootOfUnityTable`; output via `result`.
    // NOTE(review): "Iterative" presumably denotes a non-recursive implementation.
    void ForwardTransformIterative(const VecType& element, const VecType& rootOfUnityTable, VecType* result);

    // Inverse counterpart of ForwardTransformIterative, using `rootOfUnityInverseTable`.
    void InverseTransformIterative(const VecType& element, const VecType& rootOfUnityInverseTable, VecType* result);

    // Forward transform whose output is, by name, in bit-reversed order.
    void ForwardTransformToBitReverse(const VecType& element, const VecType& rootOfUnityTable, VecType* result);
    // In-place variant of the above, operating on `*element`.
    void ForwardTransformToBitReverseInPlace(const VecType& rootOfUnityTable, VecType* element);

    // As ForwardTransformToBitReverse, additionally given a precomputed
    // companion table `preconRootOfUnityTable`.
    void ForwardTransformToBitReverse(const VecType& element, const VecType& rootOfUnityTable,
                                      const VecType& preconRootOfUnityTable, VecType* result);

    // In-place variant of the precomputed-table forward transform.
    void ForwardTransformToBitReverseInPlace(const VecType& rootOfUnityTable, const VecType& preconRootOfUnityTable,
                                             VecType* element);

    // Inverse transform taking input that is, by name, in bit-reversed order;
    // `cycloOrderInv` is the scaling constant described in the class comment.
    void InverseTransformFromBitReverse(const VecType& element, const VecType& rootOfUnityInverseTable,
                                        const IntType& cycloOrderInv, VecType* result);

    // In-place variant of the above, operating on `*element`.
    void InverseTransformFromBitReverseInPlace(const VecType& rootOfUnityInverseTable, const IntType& cycloOrderInv,
                                               VecType* element);

    // As InverseTransformFromBitReverse, additionally given precomputed
    // companions for both the inverse root table and `cycloOrderInv`.
    void InverseTransformFromBitReverse(const VecType& element, const VecType& rootOfUnityInverseTable,
                                        const VecType& preconRootOfUnityInverseTable, const IntType& cycloOrderInv,
                                        const IntType& preconCycloOrderInv, VecType* result);

    // In-place variant of the precomputed-table inverse transform.
    void InverseTransformFromBitReverseInPlace(const VecType& rootOfUnityInverseTable,
                                               const VecType& preconRootOfUnityInverseTable,
                                               const IntType& cycloOrderInv, const IntType& preconCycloOrderInv,
                                               VecType* element);
};

/**
 * @brief Native-backend implementation of lbcrypto::ChineseRemainderTransformFTTInterface.
 *
 * The class is `final`; the method definitions live in
 * "math/hal/intnat/transformnat-impl.h" (included at the end of this header).
 *
 * NOTE(review): the methods below are not marked `override`. If the base
 * interface declares them virtual, adding `override` would let the compiler
 * verify that the signatures match.
 *
 * All precomputed tables are `static` data members, so they are shared by every
 * instance of a given VecType specialization. This header declares no locking
 * around them. NOTE(review): confirm how the implementation guards concurrent
 * PreCompute/Reset calls.
 *
 * @tparam VecType vector type exposing a nested `Integer` type.
 */
template <typename VecType>
class ChineseRemainderTransformFTTNat final : public lbcrypto::ChineseRemainderTransformFTTInterface<VecType> {
    // Alias for VecType's nested Integer type.
    using IntType = typename VecType::Integer;

public:
    // Forward transform of `element` for cyclotomic order `CycloOrder` and root
    // of unity `rootOfUnity`. Output goes to `result`, by name in bit-reversed order.
    void ForwardTransformToBitReverse(const VecType& element, const IntType& rootOfUnity, const usint CycloOrder,
                                      VecType* result);

    // In-place variant of ForwardTransformToBitReverse, operating on `*element`.
    void ForwardTransformToBitReverseInPlace(const IntType& rootOfUnity, const usint CycloOrder, VecType* element);

    // Inverse transform of `element`, whose input is by name in bit-reversed
    // order; output goes to `result`.
    void InverseTransformFromBitReverse(const VecType& element, const IntType& rootOfUnity, const usint CycloOrder,
                                        VecType* result);

    // In-place variant of InverseTransformFromBitReverse, operating on `*element`.
    void InverseTransformFromBitReverseInPlace(const IntType& rootOfUnity, const usint CycloOrder, VecType* element);

    // Presumably populates the static tables below for a single `modulus` --
    // TODO confirm in transformnat-impl.h.
    void PreCompute(const IntType& rootOfUnity, const usint CycloOrder, const IntType& modulus);

    // Multi-modulus variant: one root of unity per entry of `moduliChain`.
    // NOTE(review): both vectors are taken by non-const reference even though
    // they look like inputs. Changing that requires updating the definition too.
    void PreCompute(std::vector<IntType>& rootOfUnity, const usint CycloOrder, std::vector<IntType>& moduliChain);

    // Presumably clears the static precomputed tables -- TODO confirm.
    void Reset();

    // Static caches keyed by modulus (IntType), each mapping to a VecType table.
    // Their contents are as the member names describe. See transformnat-impl.h
    // for the exact layout.

    // Inverse-of-cyclotomic-order values, by modulus.
    static std::map<IntType, VecType> m_cycloOrderInverseTableByModulus;

    // Precomputed companions of m_cycloOrderInverseTableByModulus.
    static std::map<IntType, VecType> m_cycloOrderInversePreconTableByModulus;

    // Bit-reversed root-of-unity tables, by modulus.
    static std::map<IntType, VecType> m_rootOfUnityReverseTableByModulus;

    // Bit-reversed inverse root-of-unity tables, by modulus.
    static std::map<IntType, VecType> m_rootOfUnityInverseReverseTableByModulus;

    // Precomputed companions of m_rootOfUnityReverseTableByModulus.
    static std::map<IntType, VecType> m_rootOfUnityPreconReverseTableByModulus;

    // Precomputed companions of m_rootOfUnityInverseReverseTableByModulus.
    static std::map<IntType, VecType> m_rootOfUnityInversePreconReverseTableByModulus;
};

// Key type used by the Bluestein FFT caches (BluesteinFFTNat): a pair of
// integers combining a modulus and a root of unity. Note that this is a type
// alias for std::pair, not a distinct struct.
// NOTE(review): presumably ordered as (modulus, root) -- confirm in transformnat-impl.h.
template <typename IntType>
using ModulusRoot = std::pair<IntType, IntType>;

// Key type holding two ModulusRoot values; it is the key of
// BluesteinFFTNat::m_RBTableByModulusRootPair.
template <typename IntType>
using ModulusRootPair = std::pair<ModulusRoot<IntType>, ModulusRoot<IntType>>;

/**
 * @brief Bluestein FFT routines of the native math backend.
 *
 * Only declarations appear here; the definitions come from
 * "math/hal/intnat/transformnat-impl.h". NOTE(review): by its name this
 * presumably implements Bluestein's (chirp-z) algorithm on top of an NTT over an
 * auxiliary (modulus, root) pair -- confirm in the implementation.
 *
 * All caches are `static` data members shared by every instance of a given
 * VecType specialization. This header declares no locking around them.
 *
 * @tparam VecType vector type exposing a nested `Integer` type.
 */
template <typename VecType>
class BluesteinFFTNat {
    // Alias for VecType's nested Integer type.
    using IntType = typename VecType::Integer;

public:
    // Forward transform of `element` for cyclotomic order `cycloOrder` and root
    // `root`. NOTE(review): presumably uses the default NTT (modulus, root)
    // recorded by PreComputeDefaultNTTModulusRoot -- confirm.
    VecType ForwardTransform(const VecType& element, const IntType& root, const usint cycloOrder);
    // As above, with the NTT (modulus, root) pair supplied explicitly.
    VecType ForwardTransform(const VecType& element, const IntType& root, const usint cycloOrder,
                             const ModulusRoot<IntType>& nttModulusRoot);

    // By name: returns `a` padded with zeros to length `finalSize` -- TODO confirm.
    VecType PadZeros(const VecType& a, const usint finalSize);

    // By name: returns the part of `a` between indices `lo` and `hi`.
    // NOTE(review): whether `hi` is inclusive is not visible here -- confirm.
    VecType Resize(const VecType& a, usint lo, usint hi);

    // Disabled declaration, kept commented out in the original source:
    // void PreComputeNTTModulus(usint cycloOrder, const std::vector<IntType>
    // &modulii);

    // Presumably records the default NTT (modulus, root) for `modulus` in
    // m_defaultNTTModulusRoot -- TODO confirm.
    void PreComputeDefaultNTTModulusRoot(usint cycloOrder, const IntType& modulus);

    // Presumably fills the root-of-unity tables for `nttModulusRoot` -- TODO confirm.
    void PreComputeRootTableForNTT(usint cycloOrder, const ModulusRoot<IntType>& nttModulusRoot);

    // Presumably fills m_powersTableByModulusRoot for `modulusRoot` -- TODO confirm.
    void PreComputePowers(usint cycloOrder, const ModulusRoot<IntType>& modulusRoot);

    // Presumably fills m_RBTableByModulusRootPair for `modulusRootPair` -- TODO confirm.
    void PreComputeRBTable(usint cycloOrder, const ModulusRootPair<IntType>& modulusRootPair);

    // Presumably clears the static caches -- TODO confirm.
    void Reset();

    // Root-of-unity tables, keyed by (modulus, root) pair (ModulusRoot).
    static std::map<ModulusRoot<IntType>, VecType> m_rootOfUnityTableByModulusRoot;

    // Inverse root-of-unity tables, keyed by (modulus, root) pair (ModulusRoot).
    static std::map<ModulusRoot<IntType>, VecType> m_rootOfUnityInverseTableByModulusRoot;

    // Tables of powers of the root, keyed by (modulus, root of unity) pair.
    static std::map<ModulusRoot<IntType>, VecType> m_powersTableByModulusRoot;

    // Forward transform of the power table, keyed by a ModulusRootPair (two
    // (modulus, root) pairs). NOTE(review): presumably the input ring's pair and
    // the NTT pair -- confirm.
    static std::map<ModulusRootPair<IntType>, VecType> m_RBTableByModulusRootPair;

private:
    // Default NTT (modulus, root) pair, keyed by the input modulus.
    static std::map<IntType, ModulusRoot<IntType>> m_defaultNTTModulusRoot;
};

/**
 * @brief Native-backend implementation of lbcrypto::ChineseRemainderTransformArbInterface.
 *
 * The class is `final`; the definitions live in
 * "math/hal/intnat/transformnat-impl.h". All caches are `static` data members
 * shared by every instance of a given VecType specialization. This header
 * declares no locking around them. NOTE(review): the public methods are not
 * marked `override`. Consider adding it if the base declares them virtual.
 *
 * @tparam VecType vector type exposing a nested `Integer` type.
 */
template <typename VecType>
class ChineseRemainderTransformArbNat final : public lbcrypto::ChineseRemainderTransformArbInterface<VecType> {
    // Alias for VecType's nested Integer type.
    using IntType = typename VecType::Integer;

public:
    // Registers the cyclotomic polynomial `poly` for ring modulus `mod`
    // (presumably stored in m_cyclotomicPolyMap -- TODO confirm).
    // NOTE: "Cylotomic" is a misspelling kept in the public name for API compatibility.
    void SetCylotomicPolynomial(const VecType& poly, const IntType& mod);

    // Forward transform of `element` for cyclotomic order `cycloOrder` with root
    // `root`. NOTE(review): `bigMod`/`bigRoot` presumably select the auxiliary
    // NTT modulus/root used internally -- confirm.
    VecType ForwardTransform(const VecType& element, const IntType& root, const IntType& bigMod, const IntType& bigRoot,
                             const usint cycloOrder);

    // Inverse counterpart of ForwardTransform, with the same parameters.
    VecType InverseTransform(const VecType& element, const IntType& root, const IntType& bigMod, const IntType& bigRoot,
                             const usint cycloOrder);

    // Presumably clears the static caches -- TODO confirm.
    void Reset();

    // Presumably precomputes the tables for cyclotomic order `cyclotoOrder` and
    // ring modulus `modulus` -- TODO confirm.
    void PreCompute(const usint cyclotoOrder, const IntType& modulus);

    // Records the NTT modulus/root (`nttMod`, `nttRoot`) to use for `modulus` at
    // cyclotomic order `cyclotoOrder`.
    void SetPreComputedNTTModulus(usint cyclotoOrder, const IntType& modulus, const IntType& nttMod,
                                  const IntType& nttRoot);

    // As SetPreComputedNTTModulus, but for the NTT used in polynomial division
    // (see m_DivisionNTTModulus / m_DivisionNTTRootOfUnity).
    void SetPreComputedNTTDivisionModulus(usint cyclotoOrder, const IntType& modulus, const IntType& nttMod,
                                          const IntType& nttRoot);

    // By name: an inverse of `cycloPoly` modulo `modulus`, with `power`
    // presumably bounding the truncation degree -- TODO confirm.
    VecType InversePolyMod(const VecType& cycloPoly, const IntType& modulus, usint power);

private:
    // Helper that pads `element` for order `cycloOrder`. `forward` selects the
    // forward or inverse direction.
    VecType Pad(const VecType& element, const usint cycloOrder, bool forward);

    // Helper that reduces the padded vector back for order `cycloOrder`. `forward`
    // selects the direction, and `bigMod`/`bigRoot` as in ForwardTransform.
    VecType Drop(const VecType& element, const usint cycloOrder, bool forward, const IntType& bigMod,
                 const IntType& bigRoot);

    // Cyclotomic polynomial, keyed by the polynomial ring's modulus.
    static std::map<IntType, VecType> m_cyclotomicPolyMap;

    // Forward NTT of the inverse of the cyclotomic polynomial, keyed by the
    // polynomial ring's modulus.
    static std::map<IntType, VecType> m_cyclotomicPolyReverseNTTMap;

    // Forward NTT of the cyclotomic polynomial, keyed by the polynomial ring's
    // modulus.
    static std::map<IntType, VecType> m_cyclotomicPolyNTTMap;

    // Root-of-unity table used in NTT-based polynomial division, keyed by modulus.
    static std::map<IntType, VecType> m_rootOfUnityDivisionTableByModulus;

    // Root-of-unity table for computing the forward NTT of the inverse
    // cyclotomic polynomial (used in NTT-based polynomial division), keyed by modulus.
    static std::map<IntType, VecType> m_rootOfUnityDivisionInverseTableByModulus;

    // NTT modulus used in NTT-based polynomial division, keyed by ring modulus.
    static std::map<IntType, IntType> m_DivisionNTTModulus;

    // Root of unity used in NTT-based polynomial division, keyed by ring modulus.
    static std::map<IntType, IntType> m_DivisionNTTRootOfUnity;

    // Dimension of the NTT in NTT-based polynomial division.
    // NOTE(review): keyed by usint, presumably the cyclotomic order -- confirm.
    static std::map<usint, usint> m_nttDivisionDim;
};

}  // namespace intnat

// class implementations
#include "math/hal/intnat/transformnat-impl.h"

#endif