Program Listing for File rlwe-mp.cpp

Return to documentation for file (pke/lib/schemelet/rlwe-mp.cpp)

//==================================================================================
// BSD 2-Clause License
//
// Copyright (c) 2014-2025, NJIT, Duality Technologies Inc. and other contributors
//
// All rights reserved.
//
// Author TPOC: contact@openfhe.org
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
//    list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
//    this list of conditions and the following disclaimer in the documentation
//    and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//==================================================================================

#include "schemebase/rlwe-cryptoparameters.h"
#include "schemelet/rlwe-mp.h"
#include "cryptocontext.h"

#include <stdexcept>
#include <stdint.h>
#include <vector>

// Permutes vals[offset, offset + count) into bit-reversed order: the element at local index i is
// exchanged with the element at local index rev(i), where rev reverses the low log2(count) bits.
// The incremental reversed-counter update below is only valid when count is 0 or a power of two:
// for any other length `bit` eventually reaches 0, after which `j >= bit` is always true and the
// inner loop never terminates. Such lengths are therefore rejected before anything is modified.
template <typename typeT>
static void BitReverseRange(typeT& vals, uint32_t offset, uint32_t count) {
    if ((count & (count - 1)) != 0)
        throw std::invalid_argument("BitReverse: length must be a power of two");

    for (uint32_t i = 1, j = 0; i < count; ++i) {
        // Advance j as a bit-reversed counter: clear the run of set high bits, then set the next one.
        uint32_t bit = count >> 1;
        for (; j >= bit; bit >>= 1)
            j -= bit;
        j += bit;
        if (i < j) {  // swap each (i, rev(i)) pair exactly once
            auto t           = vals[offset + i];
            vals[offset + i] = vals[offset + j];
            vals[offset + j] = t;
        }
    }
}

// Bit-reverses the whole container in place. vals.size() must be 0 or a power of two
// (throws std::invalid_argument otherwise).
template <typename typeT>
static void BitReverse(typeT& vals) {
    BitReverseRange(vals, 0, static_cast<uint32_t>(vals.size()));
}

// Bit-reverses the first and second halves of the container independently, in place.
// The half length vals.size() / 2 must be 0 or a power of two (throws std::invalid_argument otherwise).
// For an odd vals.size() the final element is left untouched.
template <typename typeT>
static void BitReverseTwoHalves(typeT& vals) {
    const uint32_t half = static_cast<uint32_t>(vals.size() / 2);
    BitReverseRange(vals, 0, half);
    BitReverseRange(vals, half, half);
}

namespace lbcrypto {

namespace {

static std::vector<DCRTPoly> ModSwitchUp(const std::vector<Poly>& input, const BigInteger& Qfrom, const BigInteger& Qto,
                                         const std::shared_ptr<ILDCRTParams<DCRTPoly::Integer>>& ep) {
    Poly bPoly = input[0];
    bPoly.SwitchModulus(Qto, 1, 0, 0);

    Poly aPoly = input[1];
    aPoly.SwitchModulus(Qto, 1, 0, 0);

    std::vector<DCRTPoly> output{DCRTPoly(bPoly.MultiplyAndRound(Qto, Qfrom), ep),
                                 DCRTPoly(aPoly.MultiplyAndRound(Qto, Qfrom), ep)};
    output[0].SetFormat(Format::EVALUATION);
    output[1].SetFormat(Format::EVALUATION);

    return output;
}

static std::vector<DCRTPoly> ModSwitchDown(const std::vector<Poly>& input, const BigInteger& Qfrom,
                                           const BigInteger& Qto,
                                           const std::shared_ptr<ILDCRTParams<DCRTPoly::Integer>>& ep) {
    Poly bPoly = input[0].MultiplyAndRound(Qto, Qfrom);
    bPoly.SwitchModulus(Qto, 1, 0, 0);

    Poly aPoly = input[1].MultiplyAndRound(Qto, Qfrom);
    aPoly.SwitchModulus(Qto, 1, 0, 0);

    std::vector<DCRTPoly> output{DCRTPoly(bPoly, ep), DCRTPoly(aPoly, ep)};
    output[0].SetFormat(Format::EVALUATION);
    output[1].SetFormat(Format::EVALUATION);

    return output;
}

}  // namespace

// Returns a fresh copy of the private key's element parameters with the last `level` RNS towers
// removed (the key's own parameters are not modified).
// Throws std::invalid_argument if the key's crypto parameters are not RLWE parameters, or if
// `level` is not smaller than the number of towers (that would leave no towers, or pop more
// towers than exist).
std::shared_ptr<ILDCRTParams<DCRTPoly::Integer>> SchemeletRLWEMP::GetElementParams(
    const PrivateKey<DCRTPoly>& privateKey, uint32_t level) {
    const auto cryptoParams =
        std::dynamic_pointer_cast<CryptoParametersRLWE<DCRTPoly>>(privateKey->GetCryptoParameters());
    if (!cryptoParams)
        throw std::invalid_argument("GetElementParams: private key does not carry RLWE crypto parameters");

    auto ep = std::make_shared<ILDCRTParams<DCRTPoly::Integer>>(*(cryptoParams->GetElementParams()));
    if (level >= ep->GetParams().size())
        throw std::invalid_argument("GetElementParams: level must be smaller than the number of RNS towers");

    for (uint32_t i = 0; i < level; ++i)
        ep->PopLastParam();

    return ep;
}

std::vector<Poly> SchemeletRLWEMP::EncryptCoeff(std::vector<int64_t> input, const BigInteger& Q, const BigInteger& p,
                                                const PrivateKey<DCRTPoly>& privateKey,
                                                const std::shared_ptr<ILDCRTParams<DCRTPoly::Integer>>& ep,
                                                bool bitReverse) {
    const auto cryptoParams =
        std::dynamic_pointer_cast<CryptoParametersRLWE<DCRTPoly>>(privateKey->GetCryptoParameters());

    DugType dug;
    DCRTPoly a(dug, ep, Format::EVALUATION);
    DCRTPoly e(cryptoParams->GetDiscreteGaussianGenerator(), ep, Format::EVALUATION);

    auto scopy(privateKey->GetPrivateElement());
    scopy.DropLastElements(scopy.GetParams()->GetParams().size() - ep->GetParams().size());

    DCRTPoly b = e - a * scopy;  // encryption of 0 using Q'

    a.SetFormat(Format::COEFFICIENT);
    auto aPoly = a.CRTInterpolate();
    b.SetFormat(Format::COEFFICIENT);
    auto bPoly = b.CRTInterpolate();

    BigInteger bigQPrime = b.GetModulus();

    // Do modulus switching from Q' to Q
    if (Q < bigQPrime) {
        bPoly = bPoly.MultiplyAndRound(Q, bigQPrime);
        bPoly.SwitchModulus(Q, 1, 0, 0);

        aPoly = aPoly.MultiplyAndRound(Q, bigQPrime);
        aPoly.SwitchModulus(Q, 1, 0, 0);
    }
    else {
        bPoly.SwitchModulus(Q, 1, 0, 0);
        bPoly = bPoly.MultiplyAndRound(Q, bigQPrime);

        aPoly.SwitchModulus(Q, 1, 0, 0);
        aPoly = aPoly.MultiplyAndRound(Q, bigQPrime);
    }

    auto mPoly = bPoly;
    mPoly.SetValuesToZero();

    auto delta   = Q / p;
    uint32_t gap = mPoly.GetLength() / (2.0 * input.size());

    // Input here is not yet padded up to the ring dimension
    if (bitReverse) {
        if (gap == 0) {
            BitReverseTwoHalves(input);
        }
        else {
            BitReverse(input);
        }
    }

    gap                  = (gap == 0) ? 1 : gap;
    const uint32_t limit = input.size() < mPoly.GetLength() ? input.size() : mPoly.GetLength();
    for (uint32_t i = 0; i < limit; ++i) {
        auto entry     = (input[i] < 0) ? mPoly.GetModulus() - BigInteger(static_cast<uint64_t>(llabs(input[i]))) :
                                          BigInteger{input[i]};
        mPoly[i * gap] = delta * entry;
        if (gap > 1) {
            mPoly[(i + limit) * gap] = delta * entry;
        }
    }

    return {bPoly += mPoly, aPoly};
}

// Decrypts a coefficient-encoded RLWE ciphertext {b, a} modulo Q and returns centered integers mod p.
// The ciphertext is moved from Q to Q' = ep->GetModulus(), and b + a*s is computed over ep. For a
// ciphertext produced by EncryptCoeff this is delta * m + e. The result is switched back
// Q' -> Q -> p. `length` values are read at stride gap = floor(N / (2 * numSlots)), clamped to 1.
// length == 0 means numSlots. Values above p/2 are returned as negatives. If bitReverse is set,
// the output is bit-reversed: in two halves when numSlots < length, otherwise as a whole.
// Throws std::invalid_argument if numSlots is 0 or if the requested reads exceed the ring dimension.
std::vector<int64_t> SchemeletRLWEMP::DecryptCoeff(const std::vector<Poly>& input, const BigInteger& Q,
                                                   const BigInteger& p, const PrivateKey<DCRTPoly>& privateKey,
                                                   const std::shared_ptr<ILDCRTParams<DCRTPoly::Integer>>& ep,
                                                   uint32_t numSlots, uint32_t length, bool bitReverse) {
    // numSlots is a divisor in the stride computation below.
    if (numSlots == 0)
        throw std::invalid_argument("DecryptCoeff: numSlots must be positive");

    const auto& bigQPrime = ep->GetModulus();

    auto ba = (Q < bigQPrime) ? ModSwitchUp(input, Q, bigQPrime, ep) : ModSwitchDown(input, Q, bigQPrime, ep);

    // Restrict the secret key to the same number of towers as ep.
    auto scopy(privateKey->GetPrivateElement());
    scopy.DropLastElements(scopy.GetParams()->GetParams().size() - ep->GetParams().size());

    auto m = ba[0] + ba[1] * scopy;

    m.SetFormat(Format::COEFFICIENT);

    auto mPoly             = m.CRTInterpolate();
    const uint32_t ringDim = mPoly.GetLength();

    // 64-bit denominator so 2 * numSlots cannot wrap to 0. A zero stride (numSlots > N/2) is
    // clamped to 1, as EncryptCoeff does; otherwise every output would read coefficient 0.
    uint32_t gap = static_cast<uint32_t>(ringDim / (static_cast<uint64_t>(2) * numSlots));
    gap          = (gap == 0) ? 1 : gap;

    if (Q < bigQPrime) {
        mPoly = mPoly.MultiplyAndRound(Q, bigQPrime);
        mPoly.SwitchModulus(Q, 1, 0, 0);
    }
    else {
        mPoly.SwitchModulus(Q, 1, 0, 0);
        mPoly = mPoly.MultiplyAndRound(Q, bigQPrime);
    }

    mPoly = mPoly.MultiplyAndRound(p, Q);
    mPoly.SwitchModulus(p, 1, 0, 0);

    const BigInteger half = p >> 1;

    length = (length == 0) ? numSlots : length;

    // The last read is at (length - 1) * gap; it must stay inside the polynomial.
    if (static_cast<uint64_t>(length - 1) * gap >= ringDim)
        throw std::invalid_argument("DecryptCoeff: length * stride exceeds the ring dimension");

    std::vector<int64_t> output(length);
    for (uint32_t i = 0; i < length; ++i) {
        const auto& coeff = mPoly[i * gap];
        output[i] = (coeff > half) ? -(p - coeff).ConvertToInt<int64_t>() : coeff.ConvertToInt<int64_t>();
    }

    if (bitReverse) {
        if (numSlots < length)
            BitReverseTwoHalves(output);
        else
            BitReverse(output);
    }

    return output;
}

void SchemeletRLWEMP::ModSwitch(std::vector<Poly>& input, const BigInteger& Q1, const BigInteger& Q2) {
    input[0] = input[0].MultiplyAndRound(Q1, Q2);
    input[0].SwitchModulus(Q1, 1, 0, 0);
    input[1] = input[1].MultiplyAndRound(Q1, Q2);
    input[1].SwitchModulus(Q1, 1, 0, 0);
}

// Wraps RLWE coefficients (modulo Bigq) into a CKKS ciphertext at the given level.
// A one-element placeholder plaintext is encrypted at `level` with its length set to `slots`. The
// resulting ciphertext serves as a container whose components are then replaced by `coeffs`,
// rescaled from Bigq to the CKKS modulus of that level.
Ciphertext<DCRTPoly> SchemeletRLWEMP::ConvertRLWEToCKKS(const CryptoContextImpl<DCRTPoly>& cc,
                                                        const std::vector<Poly>& coeffs,
                                                        const PublicKey<DCRTPoly>& pubKey, const BigInteger& Bigq,
                                                        uint32_t slots, uint32_t level) {
    const std::vector<std::complex<double>> placeholder(1);
    auto plaintext = cc.MakeCKKSPackedPlaintext(placeholder, 1, level);
    plaintext->SetLength(slots);

    auto result = cc.Encrypt(pubKey, plaintext);

    // Towers (and their product modulus) that the CKKS ciphertext uses at this level.
    const auto elementParams = plaintext->GetElement<DCRTPoly>().GetParams();
    const auto& ckksModulus  = elementParams->GetModulus();

    if (ckksModulus > Bigq)
        result->SetElements(ModSwitchUp(coeffs, Bigq, ckksModulus, elementParams));
    else
        result->SetElements(ModSwitchDown(coeffs, Bigq, ckksModulus, elementParams));

    return result;
}

// Extracts the first two components of a CKKS ciphertext as single-modulus polynomials modulo Q.
// Each component is copied, converted to COEFFICIENT format, and CRT-interpolated over the
// ciphertext modulus Q'. It is then switched from Q' to Q: rescaled then reduced when Q < Q',
// lifted then rescaled otherwise.
// Returns {b, a}.
std::vector<Poly> SchemeletRLWEMP::ConvertCKKSToRLWE(ConstCiphertext<DCRTPoly>& ctxt, const BigInteger& Q) {
    const auto& components = ctxt->GetElements();
    const BigInteger QPrime = components[0].GetModulus();

    auto toRLWE = [&Q, &QPrime](DCRTPoly component) {
        component.SetFormat(Format::COEFFICIENT);
        auto poly = component.CRTInterpolate();
        if (!(Q < QPrime)) {
            poly.SwitchModulus(Q, 1, 0, 0);
            poly = poly.MultiplyAndRound(Q, QPrime);
        }
        else {
            poly = poly.MultiplyAndRound(Q, QPrime);
            poly.SwitchModulus(Q, 1, 0, 0);
        }
        return poly;
    };

    return {toRLWE(components[0]), toRLWE(components[1])};
}

// Returns the product of the first lvls + 1 RNS tower moduli of the public key's first element,
// i.e. params[0] * params[1] * ... * params[lvls].
// Throws std::invalid_argument if lvls is not smaller than the number of towers. The previous
// implementation indexed past the end of the tower list in that case.
BigInteger SchemeletRLWEMP::GetQPrime(const PublicKey<DCRTPoly>& pubKey, uint32_t lvls) {
    const auto& params = pubKey->GetPublicElements()[0].GetParams()->GetParams();
    if (lvls >= params.size())
        throw std::invalid_argument("GetQPrime: lvls must be smaller than the number of RNS towers");

    BigInteger QPrime = params[0]->GetModulus();
    for (uint32_t i = 1; i <= lvls; ++i)
        QPrime *= params[i]->GetModulus();

    return QPrime;
}

}  // namespace lbcrypto