Program Listing for File nbtheory-impl.h

Return to documentation for file (core/include/math/nbtheory-impl.h)

//==================================================================================
// BSD 2-Clause License
//
// Copyright (c) 2014-2023, NJIT, Duality Technologies Inc. and other contributors
//
// All rights reserved.
//
// Author TPOC: contact@openfhe.org
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
//    list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
//    this list of conditions and the following disclaimer in the documentation
//    and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//==================================================================================

/*
  This code provides number theory utilities
 */

#ifndef LBCRYPTO_INC_MATH_NBTHEORY_IMPL_H
#define LBCRYPTO_INC_MATH_NBTHEORY_IMPL_H

#define _USE_MATH_DEFINES

#include "math/distributiongenerator.h"
#include "math/nbtheory.h"

#include "utils/debug.h"
#include "utils/exception.h"
#include "utils/inttypes.h"

#include <cmath>
#include <limits>
#include <set>
#include <string>
#include <type_traits>
#include <vector>

namespace lbcrypto {

/*
 Generates a random number between 0 and n.
 Input: BigInteger n.
 Output: Randomly generated BigInteger  between 0 and n.
 */
template <typename IntType>
static IntType RNG(const IntType& modulus) {
    constexpr uint32_t chunk_min{0};
    constexpr uint32_t chunk_max{std::numeric_limits<uint32_t>::max()};
    constexpr uint32_t chunk_width{std::numeric_limits<uint32_t>::digits};
    static std::uniform_int_distribution<uint32_t> distribution(chunk_min, chunk_max);

    uint32_t chunksPerValue{(modulus.GetMSB() - 1) / chunk_width};
    uint32_t shiftChunk{chunksPerValue * chunk_width};
    std::uniform_int_distribution<uint32_t>::param_type bound(chunk_min, (modulus >> shiftChunk).ConvertToInt());

    while (true) {
        IntType result{};
        for (uint32_t i{0}, shift{0}; i < chunksPerValue; ++i, shift += chunk_width)
            result += IntType{distribution(PseudoRandomNumberGenerator::GetPRNG())} << shift;
        result += IntType{distribution(PseudoRandomNumberGenerator::GetPRNG(), bound)} << shiftChunk;
        if (result < modulus)
            return result;
    }
}
/*
 A witness function used for the Miller-Rabin Primality test.
 Inputs: a is a randomly generated witness between 2 and p-1,
 p is the number to be tested for primality,
 s and d satisfy p-1 = ((2^s) * d), d is odd.
 Output: true if p is composite,
 false if p is likely prime
 */
template <typename IntType>
static bool WitnessFunction(const IntType& a, const IntType& d, usint s, const IntType& p) {
    IntType mod  = a.ModExp(d, p);
    bool prevMod = false;
    for (usint i = 0; i < s; ++i) {
        prevMod = (mod != IntType(1) && mod != p - IntType(1));
        mod.ModMulFastEq(mod, p);
        if (mod == IntType(1) && prevMod)
            return true;
    }
    return (mod != IntType(1));
}

/*
 A helper function to RootOfUnity function. This finds a generator for a given
 prime q. Input: BigInteger q which is a prime. Output: A generator of prime q
 */
template <typename IntType>
static IntType FindGenerator(const IntType& q) {
    IntType qm1(q - IntType(1));
    IntType qm2(q - IntType(2));
    std::set<IntType> primeFactors;
    PrimeFactorize<IntType>(qm1, primeFactors);
    usint cnt;
    IntType gen;
    do {
        cnt = 0;
        gen = RNG(qm2) + IntType(1);
        for (auto it = primeFactors.begin(); it != primeFactors.end(); ++it, ++cnt) {
            if (gen.ModExp(qm1 / (*it), q) == IntType(1))
                break;
        }
    } while (cnt != primeFactors.size());
    return gen;
}

/*
 A helper function for arbitrary cyclotomics. This finds a generator for any
 composite q (cyclic group). Input: BigInteger q (cyclic group). Output: A
 generator of q
 */
template <typename IntType>
IntType FindGeneratorCyclic(const IntType& q) {
    IntType phi_q(GetTotient(q.ConvertToInt()));
    IntType phi_q_m1(GetTotient(q.ConvertToInt()));
    std::set<IntType> primeFactors;
    PrimeFactorize<IntType>(phi_q, primeFactors);
    usint cnt;
    IntType gen;
    do {
        cnt = 0;
        gen = RNG(phi_q_m1) + IntType(1);  // gen is random in [1, phi(q)]

        // Generator must lie in the group!
        if (GreatestCommonDivisor<IntType>(gen, q) != IntType(1))
            continue;

        // Order of a generator cannot divide any co-factor
        for (auto it = primeFactors.begin(); it != primeFactors.end(); ++it, ++cnt) {
            if (gen.ModExp(phi_q / (*it), q) == IntType(1))
                break;
        }
    } while (cnt != primeFactors.size());
    return gen;
}

/*
 A helper function for arbitrary cyclotomics. Checks if g is a generator of q
 (supports any cyclic group, not just prime-modulus groups) Input: Candidate
 generator g and modulus q Output: returns true if g is a generator for q
 */
template <typename IntType>
bool IsGenerator(const IntType& g, const IntType& q) {
    IntType qm1(GetTotient(q.ConvertToInt()));
    std::set<IntType> primeFactors;
    PrimeFactorize<IntType>(qm1, primeFactors);
    usint cnt = 0;
    for (auto it = primeFactors.begin(); it != primeFactors.end(); ++it, ++cnt) {
        if (g.ModExp(qm1 / (*it), q) == IntType(1))
            break;
    }
    return cnt == primeFactors.size();
}

/*
 finds roots of unity for given input.  Assumes the the input is a power of two.
 Mostly likely does not give correct results otherwise. input:  m as number
 which is cyclotomic(in format of int), modulo which is used to find generator
 (in format of BigInteger)

 output:  root of unity (in format of BigInteger)
 */
template <typename IntType>
IntType RootOfUnity(usint m, const IntType& modulo) {
    IntType M(m);
    if ((modulo - IntType(1)).Mod(M) != IntType(0)) {
        std::string errMsg =
            "Please provide a primeModulus(q) and a cyclotomic number(m) "
            "satisfying the condition: (q-1)/m is an integer. The values of "
            "primeModulus = " +
            modulo.ToString() + " and m = " + std::to_string(m) + " do not satisfy this condition";
        OPENFHE_THROW(errMsg);
    }

    IntType gen    = FindGenerator(modulo);
    IntType result = gen.ModExp((modulo - IntType(1)).DividedBy(M), modulo);
    if (result == IntType(1))
        result = RootOfUnity(m, modulo);

    /*
   * At this point, result contains a primitive root of unity. However,
   * we want to return the minimum root of unity, to avoid different
   * crypto contexts having different roots of unity for the same
   * cyclotomic order and moduli. Therefore, we are going to cycle over
   * all primitive roots of unity and select the smallest one (minRU).
   *
   * To cycle over all primitive roots of unity, we raise the root of
   * unity in result to all the powers that are co-prime to the
   * cyclotomic order. In power-of-two cyclotomics, this will be the
   * set of all odd powers, but here we use a more general routine
   * to support arbitrary cyclotomics.
   *
   */

    IntType mu(modulo.ComputeMu());
    IntType x(1);
    x.ModMulEq(result, modulo, mu);

    std::vector<IntType> coprimes = GetTotientList<IntType>(m);
    IntType minRU(x);
    IntType curPowIdx(1);
    for (size_t i = 0; i < coprimes.size(); ++i) {
        auto nextPowIdx = coprimes[i];
        IntType diffPow(nextPowIdx - curPowIdx);
        for (IntType j(0); j < diffPow; j += IntType(1))
            x.ModMulEq(result, modulo, mu);
        if (x < minRU && x != IntType(1))
            minRU = x;
        curPowIdx = nextPowIdx;
    }
    return minRU;
}

template <typename IntType>
std::vector<IntType> RootsOfUnity(usint m, const std::vector<IntType>& moduli) {
    std::vector<IntType> rootsOfUnity(moduli.size());
    for (size_t i = 0; i < moduli.size(); ++i)
        rootsOfUnity[i] = RootOfUnity(m, moduli[i]);
    return rootsOfUnity;
}

template <typename IntType>
IntType GreatestCommonDivisor(const IntType& a, const IntType& b) {
    static const IntType ZERO(0);
    auto m_a(a);
    auto m_b(b);
    while (m_b != ZERO) {
        auto tmp(m_b);
        m_b = m_a % m_b;
        m_a = tmp;
    }
    return m_a;
}

/*
 The Miller-Rabin Primality Test
 Input: p the number to be tested for primality.
 Output: true if p is prime,
 false if p is not prime
 */
template <typename IntType>
bool MillerRabinPrimalityTest(const IntType& p, const usint niter) {
    static const IntType ZERO(0);
    static const IntType TWO(2);
    static const IntType THREE(3);
    static const IntType FIVE(5);

    if (p == TWO || p == THREE || p == FIVE)
        return true;
    if (p < TWO || (p.Mod(TWO) == ZERO))
        return false;

    IntType d(p - IntType(1));
    usint s(0);
    while (d.Mod(TWO) == ZERO) {
        // d.DividedByEq(TWO);
        d.RShiftEq(1);
        ++s;
    }
    for (usint i = 0; i < niter; ++i) {
        if (WitnessFunction(RNG(p - THREE).ModAdd(TWO, p), d, s, p))
            return false;
    }
    return true;
}

/*
 The Pollard Rho factorization of a number n.
 Input: n the number to be factorized.
 Output: a factor of n.
 */
template <typename IntType>
const IntType PollardRhoFactorization(const IntType& n) {
    if (n.Mod(IntType(2)) == IntType(0))
        return IntType(2);
    IntType divisor(1);
    IntType c(RNG(n));
    IntType x(RNG(n));
    IntType xx(x);
    IntType mu(n.ComputeMu());
    do {
        x       = x.ModMul(x, n, mu).ModAdd(c, n, mu);
        xx      = xx.ModMul(xx, n, mu).ModAdd(c, n, mu);
        xx      = xx.ModMul(xx, n, mu).ModAdd(c, n, mu);
        divisor = GreatestCommonDivisor((x > xx) ? x - xx : xx - x, n);
    } while (divisor == IntType(1));
    return divisor;
}

/*
 Recursively factorizes and find the distinct primefactors of a number
 Input: n is the number to be prime factorized,
 primeFactors is a set of prime factors of n.
 */
template <typename IntType>
void PrimeFactorize(IntType n, std::set<IntType>& primeFactors) {
    if (n == IntType(0) || n == IntType(1))
        return;
    if (MillerRabinPrimalityTest(n)) {
        primeFactors.insert(n);
        return;
    }

    IntType divisor(PollardRhoFactorization(n));
    PrimeFactorize(divisor, primeFactors);
    PrimeFactorize(n / divisor, primeFactors);
}

template <typename IntType>
IntType FirstPrime(uint32_t nBits, uint32_t m) {
    if constexpr (std::is_same_v<IntType, NativeInteger>) {
        if (nBits > MAX_MODULUS_SIZE)
            OPENFHE_THROW(std::string(__func__) + ": Requested bit length " + std::to_string(nBits) +
                          " exceeds maximum allowed length " + std::to_string(MAX_MODULUS_SIZE));
    }

    IntType M(m);
    IntType q(IntType(1) << nBits);
    IntType r(q.Mod(M));
    IntType qNew(q + IntType(1) - r);
    if (r > IntType(0))
        qNew += M;
    while (!MillerRabinPrimalityTest(qNew)) {
        if ((qNew += M) < q)
            OPENFHE_THROW(std::string(__func__) + ": overflow growing candidate");
    }
    return qNew;
}

template <typename IntType>
IntType LastPrime(uint32_t nBits, uint32_t m) {
    if constexpr (std::is_same_v<IntType, NativeInteger>) {
        if (nBits > MAX_MODULUS_SIZE)
            OPENFHE_THROW(std::string(__func__) + ": Requested bit length " + std::to_string(nBits) +
                          " exceeds maximum allowed length " + std::to_string(MAX_MODULUS_SIZE));
    }

    IntType M(m);
    IntType q(IntType(1) << nBits);
    IntType r(q.Mod(M));
    IntType qNew(q + IntType(1) - r);
    if (r < IntType(2))
        qNew -= M;
    while (!MillerRabinPrimalityTest(qNew)) {
        if ((qNew -= M) > q)
            OPENFHE_THROW(std::string(__func__) + ": overflow shrinking candidate");
    }

    if (qNew.GetMSB() != nBits)
        OPENFHE_THROW(std::string(__func__) + ": Requested " + std::to_string(nBits) + " bits, but returned " +
                      std::to_string(qNew.GetMSB()) + ". Please adjust parameters.");

    return qNew;
}

template <typename IntType>
IntType NextPrime(const IntType& q, uint32_t m) {
    IntType M(m), qNew(q + M);
    while (!MillerRabinPrimalityTest(qNew)) {
        if ((qNew += M) < q)
            OPENFHE_THROW(std::string(__func__) + ": overflow growing candidate");
    }
    return qNew;
}

template <typename IntType>
IntType PreviousPrime(const IntType& q, uint32_t m) {
    IntType M(m), qNew(q - M);
    while (!MillerRabinPrimalityTest(qNew)) {
        if ((qNew -= M) > q)
            OPENFHE_THROW(std::string(__func__) + ": overflow shrinking candidate");
    }
    return qNew;
}

template <typename IntType>
IntType NextPowerOfTwo(IntType n) {
    usint result = ceil(log2(n));
    return result;
}

/*Naive Loop to find coprimes to n*/
template <typename IntType>
std::vector<IntType> GetTotientList(const IntType& n) {
    std::vector<IntType> result;
    static const IntType one(1);
    for (IntType i = one; i < n; i = i + one) {
        if (GreatestCommonDivisor(i, n) == one)
            result.push_back(i);
    }
    return result;
}

/* Calculate the remainder from polynomial division */
template <typename IntVector>
IntVector PolyMod(const IntVector& dividend, const IntVector& divisor, const typename IntVector::Integer& modulus) {
    auto mu(modulus.ComputeMu());
    usint divisorLength(divisor.GetLength());
    usint dividendLength(dividend.GetLength());
    usint runs(dividendLength - divisorLength + 1);
    IntVector runningDividend(dividend);
    for (usint i = 0; i < runs; ++i) {
        // get the highest degree coeff
        auto divConst(runningDividend[dividendLength - 1]);
        usint divisorPtr(divisorLength - 1);
        for (usint j = 0; j < dividendLength - i - 1; j++) {
            auto& rdtmp1 = runningDividend[dividendLength - 1 - j];
            rdtmp1       = runningDividend[dividendLength - 2 - j];
            if (divisorPtr > j)
                rdtmp1.ModSubEq(divisor[divisorPtr - 1 - j] * divConst, modulus, mu);
        }
    }

    IntVector result(divisorLength - 1, modulus);
    for (usint i = 0, j = runs; i < divisorLength - 1; ++i, ++j)
        result[i] = runningDividend[j];
    return result;
}

template <typename IntVector>
IntVector PolynomialMultiplication(const IntVector& a, const IntVector& b) {
    usint degreeA(a.GetLength());
    usint degreeB(b.GetLength());
    usint degreeResultant(degreeA + degreeB - 1);
    const auto& modulus = a.GetModulus();
    IntVector result(degreeResultant, modulus);
    for (usint i = 0; i < degreeA; i++) {
        for (usint j = 0; j < degreeB; j++) {
            result[i + j].ModAddEq(a[i] * b[j], modulus);
        }
    }
    return result;
}

template <typename IntVector>
IntVector GetCyclotomicPolynomial(usint m, const typename IntVector::Integer& modulus) {
    auto intCP = GetCyclotomicPolynomialRecursive(m);
    IntVector result(intCP.size(), modulus);
    for (usint i = 0; i < intCP.size(); i++) {
        auto val = intCP[i];
        if (val > -1) {
            result[i] = typename IntVector::Integer(val);
        }
        else {
            result[i] = modulus - typename IntVector::Integer(-val);
        }
    }
    return result;
}

template <typename IntVector>
typename IntVector::Integer SyntheticRemainder(const IntVector& dividend, const typename IntVector::Integer& a,
                                               const typename IntVector::Integer& modulus) {
    auto mu  = modulus.ComputeMu();
    auto val = dividend[dividend.GetLength() - 1];
    for (int i = dividend.GetLength() - 2; i >= 0; --i)
        val = (dividend[i] + a * val).Mod(modulus, mu);
    return val;
}

template <typename IntVector>
IntVector SyntheticPolyRemainder(const IntVector& dividend, const IntVector& aList,
                                 const typename IntVector::Integer& modulus) {
    IntVector result(aList.GetLength(), modulus);
    for (usint i = 0; i < aList.GetLength(); ++i)
        result[i] = SyntheticRemainder(dividend, aList[i], modulus);
    return result;
}

template <typename IntVector>
IntVector PolynomialPower(const IntVector& input, usint power) {
    usint finalDegree = (input.GetLength() - 1) * power;
    IntVector finalPoly(finalDegree + 1, input.GetModulus());
    for (usint i = 0; i < input.GetLength(); ++i)
        finalPoly[i * power] = input[i];
    return finalPoly;
}

template <typename IntVector>
IntVector SyntheticPolynomialDivision(const IntVector& dividend, const typename IntVector::Integer& a,
                                      const typename IntVector::Integer& modulus) {
    auto mu(modulus.ComputeMu());
    usint n(dividend.GetLength() - 1);
    IntVector result(n, modulus);
    result[n - 1] = dividend[n];
    auto val(dividend[n]);
    for (int i = n - 1; i > 0; i--) {
        val           = (val * a + dividend[i]).Mod(modulus, mu);
        result[i - 1] = val;
    }
    return result;
}

}  // namespace lbcrypto

#endif