Program Listing for File lwe-pke.cpp
↰ Return to documentation for file (binfhe/lib/lwe-pke.cpp)
//==================================================================================
// BSD 2-Clause License
//
// Copyright (c) 2014-2022, NJIT, Duality Technologies Inc. and other contributors
//
// All rights reserved.
//
// Author TPOC: contact@openfhe.org
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//==================================================================================
#include "lwe-pke.h"
#include "math/binaryuniformgenerator.h"
#include "math/discreteuniformgenerator.h"
#include "math/ternaryuniformgenerator.h"
#include "utils/parallel.h"
namespace lbcrypto {
// the main rounding operation used in ModSwitch (as described in Section 3 of
// https://eprint.iacr.org/2014/816) The idea is that Round(x) = 0.5 + Floor(x)
static inline NativeInteger RoundqQ(NativeInteger v, NativeInteger q, NativeInteger Q) {
return NativeInteger(static_cast<BasicInteger>(
std::floor(0.5 + v.ConvertToDouble() * q.ConvertToDouble() / Q.ConvertToDouble())))
.Mod(q);
}
LWEPrivateKey LWEEncryptionScheme::KeyGen(uint32_t size, NativeInteger modulus) const {
TernaryUniformGeneratorImpl<NativeVector> tug;
return std::make_shared<LWEPrivateKeyImpl>(tug.GenerateVector(size, modulus));
}
LWEPrivateKey LWEEncryptionScheme::KeyGenGaussian(uint32_t size, NativeInteger modulus) const {
DiscreteGaussianGeneratorImpl<NativeVector> dgg(3.19);
return std::make_shared<LWEPrivateKeyImpl>(dgg.GenerateVector(size, modulus));
}
// size is the ring dimension N, modulus is the large Q used in RGSW encryption of bootstrapping.
LWEKeyPair LWEEncryptionScheme::KeyGenPair(const std::shared_ptr<LWECryptoParams>& params) const {
uint32_t dim = params->GetN();
auto modulus = params->GetQ();
// generate secret vector skN of ring dimension N
auto skN = (params->GetKeyDist() == GAUSSIAN) ? KeyGenGaussian(dim, modulus) : KeyGen(dim, modulus);
// generate public key pkN corresponding to secret key skN
auto pkN = PubKeyGen(params, skN);
// return the public key (A, v), private key sk pair
return std::make_shared<LWEKeyPairImpl>(std::move(pkN), std::move(skN));
}
// size is the ring dimension N, modulus is the large Q used in RGSW encryption of bootstrapping.
LWEPublicKey LWEEncryptionScheme::PubKeyGen(const std::shared_ptr<LWECryptoParams>& params,
ConstLWEPrivateKey& skN) const {
const uint32_t dim = params->GetN();
const auto modulus = params->GetQ();
const auto mu = modulus.ComputeMu();
const auto& ske = skN->GetElement();
std::vector<NativeVector> A(dim);
auto v = params->GetDgg().GenerateVector(dim, modulus);
DiscreteUniformGeneratorImpl<NativeVector> dug(modulus);
// compute v = As + e
#pragma omp parallel for num_threads(OpenFHEParallelControls.GetThreadLimit(dim)) firstprivate(dug)
for (uint32_t j = 0; j < dim; ++j) {
A[j] = dug.GenerateVector(dim);
for (uint32_t i = 0; i < dim; ++i)
v[j].ModAddFastEq(A[j][i].ModMulFast(ske[i], modulus, mu), modulus);
}
return std::make_shared<LWEPublicKeyImpl>(std::move(A), std::move(v));
}
// classical LWE encryption
// a is a randomly uniform vector of dimension n; with integers mod q
// b = a*s + e + m floor(q/4) is an integer mod q
LWECiphertext LWEEncryptionScheme::Encrypt(const std::shared_ptr<LWECryptoParams>& params, ConstLWEPrivateKey& sk,
LWEPlaintext m, LWEPlaintextModulus p, NativeInteger q) const {
if (q % p != 0 && q.ConvertToInt() & (1 == 0))
OPENFHE_THROW("plaintext modulus p must divide ciphertext modulus q");
NativeVector s = sk->GetElement();
s.SwitchModulus(q);
DiscreteUniformGeneratorImpl<NativeVector> dug;
const uint32_t n = s.GetLength();
NativeVector a = dug.GenerateVector(n, q);
NativeInteger b = (m % p) * (q / p) + params->GetDgg().GenerateInteger(q);
NativeInteger mu = q.ComputeMu();
for (uint32_t i = 0; i < n; ++i)
b += a[i].ModMulFast(s[i], q, mu);
return std::make_shared<LWECiphertextImpl>(std::move(a), b.Mod(q), p);
}
// classical public key LWE encryption
// a = As' + e' of dimension n; with integers mod q
// b = vs' + e" + m floor(q/4) is an integer mod q
LWECiphertext LWEEncryptionScheme::EncryptN(const std::shared_ptr<LWECryptoParams>& params, ConstLWEPublicKey& pk,
LWEPlaintext m, LWEPlaintextModulus p, NativeInteger q) const {
if (q % p != 0 && q.ConvertToInt() & (1 == 0))
OPENFHE_THROW("plaintext modulus p must divide ciphertext modulus q");
auto bp = pk->Getv();
bp.SwitchModulus(q); // todo : this is probably not required
uint32_t N = bp.GetLength();
TernaryUniformGeneratorImpl<NativeVector> tug;
NativeVector sp = tug.GenerateVector(N, q);
// compute a in the ciphertext (a, b)
const auto& dgg = params->GetDgg();
auto a = dgg.GenerateVector(N, q);
auto& A = pk->GetA();
for (uint32_t j = 0; j < N; ++j) {
// columnwise a = A_1s1 + ... + A_NsN
a.ModAddEq(A[j].ModMul(sp[j]));
}
// compute b in ciphertext (a,b)
NativeInteger mu = q.ComputeMu();
NativeInteger b = (m % p) * (q / p) + dgg.GenerateInteger(q);
if (b >= q)
b.ModEq(q);
for (uint32_t i = 0; i < N; ++i)
b.ModAddFastEq(bp[i].ModMulFast(sp[i], q, mu), q);
return std::make_shared<LWECiphertextImpl>(std::move(a), b, p);
}
// convert ciphertext with modulus Q and dimension N to ciphertext with modulus q and dimension n
LWECiphertext LWEEncryptionScheme::SwitchCTtoqn(const std::shared_ptr<LWECryptoParams>& params,
ConstLWESwitchingKey& ksk, ConstLWECiphertext& ct) const {
// Modulus switching to a middle step Q'
auto ctMS = ModSwitch(params->GetqKS(), ct);
// Key switching
auto ctKS = KeySwitch(params, ksk, ctMS);
// Modulus switching
return ModSwitch(params->Getq(), ctKS);
}
// classical LWE decryption
// m_result = Round(4/q * (b - a*s))
void LWEEncryptionScheme::Decrypt(const std::shared_ptr<LWECryptoParams>& params, ConstLWEPrivateKey& sk,
ConstLWECiphertext& ct, LWEPlaintext* result, LWEPlaintextModulus p) const {
if (sk == nullptr)
OPENFHE_THROW("PrivateKey is empty");
else if (ct == nullptr)
OPENFHE_THROW("Ciphertext is empty");
else if (result == nullptr)
OPENFHE_THROW("result is nullptr");
// TODO in the future we should add a check to make sure sk parameters match
// the ct parameters
// Create local variables to speed up the computations
auto q = ct->GetModulus();
if (q % (p * 2) != 0 && q.ConvertToInt() & (1 == 0))
OPENFHE_THROW("plaintext modulus p*2 must divide ciphertext modulus q");
const auto& a = ct->GetA();
auto s = sk->GetElement();
uint32_t n = s.GetLength();
auto mu = q.ComputeMu();
s.SwitchModulus(q);
NativeInteger inner(0);
for (uint32_t i = 0; i < n; ++i) {
inner += a[i].ModMulFast(s[i], q, mu);
}
inner.ModEq(q);
NativeInteger r = ct->GetB();
r.ModSubFastEq(inner, q);
// Alternatively, rounding can be done as
// *result = (r.MultiplyAndRound(NativeInteger(4),q)).ConvertToInt();
// But the method below is a more efficient way of doing the rounding
// the idea is that Round(4/q x) = q/8 + Floor(4/q x)
r.ModAddFastEq((q / (p * 2)), q);
*result = ((NativeInteger(p) * r) / q).ConvertToInt();
#if defined(WITH_NOISE_DEBUG)
double error =
(static_cast<double>(p) * (r.ConvertToDouble() - q.ConvertToDouble() / (p * 2))) / q.ConvertToDouble() -
static_cast<double>(*result);
std::cerr << error * q.ConvertToDouble() / static_cast<double>(p) << std::endl;
#endif
}
void LWEEncryptionScheme::EvalAddEq(LWECiphertext& ct1, ConstLWECiphertext& ct2) const {
ct1->GetA().ModAddEq(ct2->GetA());
ct1->SetB(ct1->GetB().ModAddFast(ct2->GetB(), ct1->GetModulus()));
}
void LWEEncryptionScheme::EvalAddConstEq(LWECiphertext& ct, NativeInteger cnst) const {
ct->SetB(ct->GetB().ModAddFast(cnst, ct->GetModulus()));
}
void LWEEncryptionScheme::EvalSubEq(LWECiphertext& ct1, ConstLWECiphertext& ct2) const {
ct1->GetA().ModSubEq(ct2->GetA());
ct1->SetB(ct1->GetB().ModSubFast(ct2->GetB(), ct1->GetModulus()));
}
void LWEEncryptionScheme::EvalSubEq2(ConstLWECiphertext& ct1, LWECiphertext& ct2) const {
ct2->GetA() = ct1->GetA().ModSub(ct2->GetA());
ct2->SetB(ct1->GetB().ModSubFast(ct2->GetB(), ct1->GetModulus()));
}
void LWEEncryptionScheme::EvalSubConstEq(LWECiphertext& ct, NativeInteger cnst) const {
ct->SetB(ct->GetB().ModSubFast(cnst, ct->GetModulus()));
}
void LWEEncryptionScheme::EvalMultConstEq(LWECiphertext& ct1, NativeInteger cnst) const {
ct1->GetA().ModMulEq(cnst);
ct1->SetB(ct1->GetB().ModMulFast(cnst, ct1->GetModulus()));
}
// Modulus switching - directly applies the scale-and-round operation RoundQ
LWECiphertext LWEEncryptionScheme::ModSwitch(NativeInteger q, ConstLWECiphertext& ctQ) const {
uint32_t n = ctQ->GetLength();
auto Q = ctQ->GetModulus();
NativeVector a(n, q);
for (uint32_t i = 0; i < n; ++i)
a[i] = RoundqQ(ctQ->GetA()[i], q, Q);
return std::make_shared<LWECiphertextImpl>(std::move(a), RoundqQ(ctQ->GetB(), q, Q));
}
// Switching key as described in Section 3 of https://eprint.iacr.org/2014/816
LWESwitchingKey LWEEncryptionScheme::KeySwitchGen(const std::shared_ptr<LWECryptoParams>& params,
ConstLWEPrivateKey& sk, ConstLWEPrivateKey& skN) const {
NativeInteger qKS(params->GetqKS());
NativeInteger baseKS(params->GetBaseKS());
NativeInteger value{1};
const uint32_t digitCount = std::ceil(std::log(qKS.ConvertToDouble()) / std::log(baseKS.ConvertToDouble()));
std::vector<NativeInteger> digitsKS(digitCount);
for (uint32_t i = 0; i < digitCount; ++i) {
digitsKS[i] = value;
value *= baseKS;
}
// newSK stores negative values using modulus q
// we need to switch to modulus Q
NativeVector sv(sk->GetElement());
sv.SwitchModulus(qKS);
NativeVector svN(skN->GetElement());
svN.SwitchModulus(qKS);
DiscreteUniformGeneratorImpl<NativeVector> dug(qKS);
NativeInteger mu(qKS.ComputeMu());
const uint32_t N(params->GetN());
const uint32_t m(baseKS.ConvertToInt<uint32_t>());
const uint32_t n(params->Getn());
std::vector<std::vector<std::vector<NativeVector>>> resultVecA(N);
std::vector<std::vector<std::vector<NativeInteger>>> resultVecB(N);
#if !defined(__MINGW32__) && !defined(__MINGW64__)
#pragma omp parallel for num_threads(OpenFHEParallelControls.GetThreadLimit(N)) firstprivate(dug)
#endif
for (uint32_t i = 0; i < N; ++i) {
std::vector<std::vector<NativeVector>> vector1A;
vector1A.reserve(m);
std::vector<std::vector<NativeInteger>> vector1B;
vector1B.reserve(m);
for (uint32_t j = 0; j < m; ++j) {
std::vector<NativeVector> vector2A;
vector2A.reserve(digitCount);
std::vector<NativeInteger> vector2B;
vector2B.reserve(digitCount);
for (uint32_t k = 0; k < digitCount; ++k) {
vector2A.emplace_back(dug.GenerateVector(n));
NativeVector& a = vector2A.back();
NativeInteger b =
(params->GetDggKS().GenerateInteger(qKS)).ModAdd(svN[i].ModMul(j * digitsKS[k], qKS), qKS);
#if NATIVEINT == 32
for (uint32_t i = 0; i < n; ++i)
b.ModAddFastEq(a[i].ModMulFast(sv[i], qKS, mu), qKS);
#else
for (uint32_t i = 0; i < n; ++i)
b += a[i].ModMulFast(sv[i], qKS, mu);
b.ModEq(qKS);
#endif
vector2B.emplace_back(b);
}
vector1A.push_back(std::move(vector2A));
vector1B.push_back(std::move(vector2B));
}
resultVecA[i] = std::move(vector1A);
resultVecB[i] = std::move(vector1B);
}
return std::make_shared<LWESwitchingKeyImpl>(std::move(resultVecA), std::move(resultVecB));
}
// the key switching operation as described in Section 3 of
// https://eprint.iacr.org/2014/816
LWECiphertext LWEEncryptionScheme::KeySwitch(const std::shared_ptr<LWECryptoParams>& params, ConstLWESwitchingKey& K,
ConstLWECiphertext& ctQN) const {
const uint32_t n(params->Getn());
const uint32_t N(params->GetN());
NativeInteger Q(params->GetqKS());
NativeInteger::Integer baseKS(params->GetBaseKS());
const uint32_t digitCount = std::ceil(std::log(Q.ConvertToDouble()) / std::log(static_cast<double>(baseKS)));
NativeVector a(n, Q);
NativeInteger b(ctQN->GetB());
for (uint32_t i = 0; i < N; ++i) {
auto& refA = K->GetElementsA()[i];
auto& refB = K->GetElementsB()[i];
NativeInteger::Integer atmp(ctQN->GetA()[i].ConvertToInt());
for (uint32_t j = 0; j < digitCount; ++j) {
const auto a0 = (atmp % baseKS);
atmp /= baseKS;
b.ModSubFastEq(refB[a0][j], Q);
auto& refAj = refA[a0][j];
for (uint32_t k = 0; k < n; ++k)
a[k].ModSubFastEq(refAj[k], Q);
}
}
return std::make_shared<LWECiphertextImpl>(std::move(a), b);
}
// noiseless LWE embedding
// a is a zero vector of dimension n; with integers mod q
// b = m floor(q/4) is an integer mod q
LWECiphertext LWEEncryptionScheme::NoiselessEmbedding(const std::shared_ptr<LWECryptoParams>& params,
LWEPlaintext m) const {
NativeInteger q(params->Getq());
return std::make_shared<LWECiphertextImpl>(NativeVector(params->Getn(), q), (q >> 2) * m);
}
}; // namespace lbcrypto