Program Listing for File bgvrns-leveledshe.cpp

Return to documentation for file (pke/lib/scheme/bgvrns/bgvrns-leveledshe.cpp)

//==================================================================================
// BSD 2-Clause License
//
// Copyright (c) 2014-2022, NJIT, Duality Technologies Inc. and other contributors
//
// All rights reserved.
//
// Author TPOC: contact@openfhe.org
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
//    list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
//    this list of conditions and the following disclaimer in the documentation
//    and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//==================================================================================

/*
BGV implementation. See https://eprint.iacr.org/2021/204 for details.
 */

#define PROFILE

#include "scheme/bgvrns/bgvrns-leveledshe.h"

#include "scheme/bgvrns/bgvrns-cryptoparameters.h"
#include "ciphertext.h"

namespace lbcrypto {

/**
 * Mod-reduces (rescales) the ciphertext in place by removing its last `levels` RNS towers.
 *
 * For every polynomial of the ciphertext, DCRTPoly::ModReduce is called once per dropped
 * tower, from the highest tower index downwards. This is presumably BGV modulus switching
 * as described in eprint 2021/204 — see DCRTPoly::ModReduce for the exact arithmetic.
 *
 * Metadata updates:
 *  - the level is increased by `levels`;
 *  - the noise scale degree is decreased by `levels`;
 *  - under FLEXIBLEAUTO / FLEXIBLEAUTOEXT, the tracked integer scaling factor is multiplied
 *    (mod t) by the inverse of the mod-reduce factor of each dropped tower.
 *
 * @param ciphertext ciphertext to rescale in place.
 * @param levels     number of towers to drop.
 * @throws via OPENFHE_THROW if the ciphertext has no elements, or if it does not have
 *         strictly more than `levels` towers (at least one tower must remain).
 */
void LeveledSHEBGVRNS::ModReduceInternalInPlace(Ciphertext<DCRTPoly>& ciphertext, size_t levels) const {
    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersBGVRNS>(ciphertext->GetCryptoParameters());

    const auto t = ciphertext->GetCryptoParameters()->GetPlaintextModulus();

    std::vector<DCRTPoly>& cv = ciphertext->GetElements();

    // Treat an element-less ciphertext as having zero towers, so that it is rejected below
    // instead of indexing an empty vector.
    const size_t sizeQl = cv.empty() ? 0 : cv[0].GetNumOfElements();
    if (sizeQl <= levels) {
        std::string errMsg = "ERROR: Not enough towers to support ModReduce.";
        OPENFHE_THROW(errMsg);
    }

    // Towers with index >= firstDropped are removed. Because sizeQl > levels, firstDropped >= 1,
    // so the unsigned countdown below always terminates without wrapping around.
    const size_t firstDropped = sizeQl - levels;
    for (auto& c : cv) {
        // Tower i is the last remaining tower at the moment it is removed, so its
        // precomputed constants are selected by index i.
        for (size_t i = sizeQl - 1; i >= firstDropped; --i) {
            c.ModReduce(t, cryptoParams->GettModqPrecon(), cryptoParams->GetNegtInvModq(i),
                        cryptoParams->GetNegtInvModqPrecon(i), cryptoParams->GetqlInvModq(i),
                        cryptoParams->GetqlInvModqPrecon(i));
        }
    }

    ciphertext->SetLevel(ciphertext->GetLevel() + levels);
    ciphertext->SetNoiseScaleDeg(ciphertext->GetNoiseScaleDeg() - levels);

    if (cryptoParams->GetScalingTechnique() == FLEXIBLEAUTO || cryptoParams->GetScalingTechnique() == FLEXIBLEAUTOEXT) {
        // Undo, mod t, the factor contributed by each dropped tower (top tower first).
        for (size_t i = 0; i < levels; ++i) {
            NativeInteger modReduceFactor    = cryptoParams->GetModReduceFactorInt(sizeQl - 1 - i);
            NativeInteger modReduceFactorInv = modReduceFactor.ModInverse(t);
            ciphertext->SetScalingFactorInt(ciphertext->GetScalingFactorInt().ModMul(modReduceFactorInv, t));
        }
    }
}

/**
 * Drops the last `levels` RNS towers from every polynomial of the ciphertext and
 * advances the ciphertext level by `levels`.
 *
 * No rescaling is performed: the noise scale degree and the scaling factor are not touched.
 *
 * @param ciphertext ciphertext whose towers are dropped in place.
 * @param levels     number of towers to remove from each polynomial.
 */
void LeveledSHEBGVRNS::LevelReduceInternalInPlace(Ciphertext<DCRTPoly>& ciphertext, size_t levels) const {
    for (DCRTPoly& poly : ciphertext->GetElements()) {
        poly.DropLastElements(levels);
    }

    const auto newLevel = ciphertext->GetLevel() + levels;
    ciphertext->SetLevel(newLevel);
}

/**
 * Aligns two ciphertexts so that both end up at the same level and the same noise scale
 * degree. At most one of the two ciphertexts is modified.
 *
 * Different levels: the ciphertext at the lower level (i.e. with more towers) is brought to
 * the other one's level. Depending on the two degrees, it is:
 *  - multiplied by a correction constant, equal to (target scaling factor) * (own scaling
 *    factor)^-1 mod t, times its last tower's mod-reduce factor when a rescale follows
 *    immediately;
 *  - mod-reduced and/or level-reduced;
 *  - in most branches, assigned the target's integer scaling factor.
 *
 * Equal levels: the ciphertext with the smaller noise scale degree is multiplied by its own
 * scaling factor, which increments its degree by one.
 *
 * Crypto parameters are taken from ciphertext1.
 *
 * NOTE(review): branches test only for degree == 2 and treat any other degree as 1;
 * presumably degrees are always 1 or 2 at this point — confirm with callers.
 *
 * @param ciphertext1 first operand, possibly adjusted in place.
 * @param ciphertext2 second operand, possibly adjusted in place.
 */
void LeveledSHEBGVRNS::AdjustLevelsAndDepthInPlace(Ciphertext<DCRTPoly>& ciphertext1,
                                                   Ciphertext<DCRTPoly>& ciphertext2) const {
    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersBGVRNS>(ciphertext1->GetCryptoParameters());

    const NativeInteger t(cryptoParams->GetPlaintextModulus());

    // Adjusts `low` (the operand at the lower level, i.e. with more towers) in place so that
    // its level and noise scale degree match `high`, which is only read. The two orderings
    // of the level comparison are exact mirror images, so this single routine serves both.
    // Degree transitions noted below assume BASE_NUM_LEVELS_TO_DROP == 1.
    auto adjustLowerToHigher = [&](Ciphertext<DCRTPoly>& low, const Ciphertext<DCRTPoly>& high) {
        const usint lowLvl    = low->GetLevel();
        const usint highLvl   = high->GetLevel();
        const usint lowDepth  = low->GetNoiseScaleDeg();
        const usint highDepth = high->GetNoiseScaleDeg();
        const auto sizeQlLow  = low->GetElements()[0].GetNumOfElements();

        if (lowDepth == 2) {
            if (highDepth == 2) {
                // Degree 2 -> 3 (correction mult) -> 2 (one rescale), then drop the remaining gap.
                NativeInteger scfLow    = low->GetScalingFactorInt();
                NativeInteger scfHigh   = high->GetScalingFactorInt();
                NativeInteger qlModt    = cryptoParams->GetModReduceFactorInt(sizeQlLow - 1);
                NativeInteger scfLowInv = scfLow.ModInverse(t);

                EvalMultCoreInPlace(low, scfHigh.ModMul(scfLowInv, t).ModMul(qlModt, t).ConvertToInt());
                ModReduceInternalInPlace(low, BASE_NUM_LEVELS_TO_DROP);
                if (lowLvl + 1 < highLvl) {
                    LevelReduceInternalInPlace(low, highLvl - lowLvl - 1);
                }
                low->SetScalingFactorInt(high->GetScalingFactorInt());
            }
            else if (lowLvl + 1 == highLvl) {
                // A single rescale both lowers the degree and closes the one-level gap.
                ModReduceInternalInPlace(low, BASE_NUM_LEVELS_TO_DROP);
            }
            else {
                // Degree 2 -> 3 -> 2 -> 1: two rescales surrounding the tower drop.
                NativeInteger scfLow    = low->GetScalingFactorInt();
                NativeInteger scfHigh   = cryptoParams->GetScalingFactorIntBig(highLvl - 1);
                NativeInteger qlModt    = cryptoParams->GetModReduceFactorInt(sizeQlLow - 1);
                NativeInteger scfLowInv = scfLow.ModInverse(t);

                EvalMultCoreInPlace(low, scfHigh.ModMul(scfLowInv, t).ModMul(qlModt, t).ConvertToInt());
                ModReduceInternalInPlace(low, BASE_NUM_LEVELS_TO_DROP);
                if (lowLvl + 2 < highLvl) {
                    LevelReduceInternalInPlace(low, highLvl - lowLvl - 2);
                }
                ModReduceInternalInPlace(low, BASE_NUM_LEVELS_TO_DROP);
                low->SetScalingFactorInt(high->GetScalingFactorInt());
            }
        }
        else if (highDepth == 2) {
            // Degree 1 -> 2 via the correction mult; no rescale, only drop towers.
            NativeInteger scfLow    = low->GetScalingFactorInt();
            NativeInteger scfHigh   = high->GetScalingFactorInt();
            NativeInteger scfLowInv = scfLow.ModInverse(t);

            EvalMultCoreInPlace(low, scfHigh.ModMul(scfLowInv, t).ConvertToInt());
            LevelReduceInternalInPlace(low, highLvl - lowLvl);
            low->SetScalingFactorInt(scfHigh);
        }
        else {
            // Degree 1 -> 2 -> 1: correction mult, drop all but one level, then rescale once.
            NativeInteger scfLow    = low->GetScalingFactorInt();
            NativeInteger scfHigh   = cryptoParams->GetScalingFactorIntBig(highLvl - 1);
            NativeInteger scfLowInv = scfLow.ModInverse(t);

            EvalMultCoreInPlace(low, scfHigh.ModMul(scfLowInv, t).ConvertToInt());
            if (lowLvl + 1 < highLvl) {
                LevelReduceInternalInPlace(low, highLvl - lowLvl - 1);
            }
            ModReduceInternalInPlace(low, BASE_NUM_LEVELS_TO_DROP);
            low->SetScalingFactorInt(high->GetScalingFactorInt());
        }
    };

    const usint c1lvl = ciphertext1->GetLevel();
    const usint c2lvl = ciphertext2->GetLevel();

    if (c1lvl < c2lvl) {
        adjustLowerToHigher(ciphertext1, ciphertext2);
    }
    else if (c2lvl < c1lvl) {
        adjustLowerToHigher(ciphertext2, ciphertext1);
    }
    else {
        // Same level: raise the smaller noise scale degree by multiplying that ciphertext
        // by its own scaling factor (EvalMultCoreInPlace increments the degree).
        const usint c1depth = ciphertext1->GetNoiseScaleDeg();
        const usint c2depth = ciphertext2->GetNoiseScaleDeg();
        if (c1depth < c2depth) {
            NativeInteger scf = ciphertext1->GetScalingFactorInt();
            EvalMultCoreInPlace(ciphertext1, scf.ConvertToInt());
        }
        else if (c2depth < c1depth) {
            NativeInteger scf = ciphertext2->GetScalingFactorInt();
            EvalMultCoreInPlace(ciphertext2, scf.ConvertToInt());
        }
    }
}

/**
 * Aligns the two ciphertexts with AdjustLevelsAndDepthInPlace. If the resulting noise scale
 * degree of ciphertext1 is 2, both ciphertexts are then mod-reduced once by
 * BASE_NUM_LEVELS_TO_DROP levels. The intent is presumably to leave both at degree 1, which
 * holds if BASE_NUM_LEVELS_TO_DROP == 1.
 *
 * @param ciphertext1 first operand, adjusted in place.
 * @param ciphertext2 second operand, adjusted in place.
 */
void LeveledSHEBGVRNS::AdjustLevelsAndDepthToOneInPlace(Ciphertext<DCRTPoly>& ciphertext1,
                                                        Ciphertext<DCRTPoly>& ciphertext2) const {
    AdjustLevelsAndDepthInPlace(ciphertext1, ciphertext2);

    // Only ciphertext1's degree is inspected; after alignment both presumably share it.
    const bool atDegreeTwo = ciphertext1->GetNoiseScaleDeg() == 2;
    if (!atDegreeTwo) {
        return;
    }

    ModReduceInternalInPlace(ciphertext1, BASE_NUM_LEVELS_TO_DROP);
    ModReduceInternalInPlace(ciphertext2, BASE_NUM_LEVELS_TO_DROP);
}

/**
 * Multiplies every polynomial of the ciphertext by an integer constant, in place, and
 * increments the ciphertext's noise scale degree by one.
 *
 * Under FLEXIBLEAUTO / FLEXIBLEAUTOEXT, the tracked integer scaling factor is also
 * multiplied by `constant` modulo the plaintext modulus t.
 *
 * @param ciphertext ciphertext to scale in place.
 * @param constant   multiplier applied to each element polynomial.
 */
void LeveledSHEBGVRNS::EvalMultCoreInPlace(Ciphertext<DCRTPoly>& ciphertext, const NativeInteger& constant) const {
    const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersBGVRNS>(ciphertext->GetCryptoParameters());

    for (auto& element : ciphertext->GetElements()) {
        element *= constant;
    }

    ciphertext->SetNoiseScaleDeg(ciphertext->GetNoiseScaleDeg() + 1);

    const auto technique = cryptoParams->GetScalingTechnique();
    if (technique == FLEXIBLEAUTO || technique == FLEXIBLEAUTOEXT) {
        // t is only needed to keep the tracked scaling factor reduced mod t.
        const NativeInteger t(cryptoParams->GetPlaintextModulus());
        ciphertext->SetScalingFactorInt(ciphertext->GetScalingFactorInt().ModMul(constant, t));
    }
}

/**
 * Returns the automorphism index for (`index`, `m`) by delegating to FindAutomorphismIndex2n.
 *
 * NOTE(review): the `2n` suffix suggests the power-of-two cyclotomic case — confirm against
 * that helper's documentation.
 *
 * @param index input index forwarded unchanged.
 * @param m     order parameter forwarded unchanged.
 * @return whatever FindAutomorphismIndex2n(index, m) returns.
 */
usint LeveledSHEBGVRNS::FindAutomorphismIndex(usint index, usint m) const {
    const usint automorphismIndex = FindAutomorphismIndex2n(index, m);
    return automorphismIndex;
}

}  // namespace lbcrypto