Program Listing for File ckkspackedencoding.cpp
↰ Return to documentation for file (pke/lib/encoding/ckkspackedencoding.cpp)
//==================================================================================
// BSD 2-Clause License
//
// Copyright (c) 2014-2022, NJIT, Duality Technologies Inc. and other contributors
//
// All rights reserved.
//
// Author TPOC: contact@openfhe.org
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//==================================================================================
#include "encoding/ckkspackedencoding.h"
#include "lattice/lat-hal.h"
#include "math/hal/basicint.h"
#include "math/dftransform.h"
#include "utils/exception.h"
#include "utils/inttypes.h"
#include "utils/utilities.h"
#include <complex>
#include <cmath>
#include <limits>
#include <memory>
#include <string>
#include <utility>
#include <vector>
namespace lbcrypto {
std::vector<std::complex<double>> Conjugate(const std::vector<std::complex<double>>& vec) {
uint32_t n = vec.size();
std::vector<std::complex<double>> result(n);
for (uint32_t i = 1; i < n; ++i) {
result[i] = {-vec[n - i].imag(), -vec[n - i].real()};
}
result[0] = {vec[0].real(), -vec[0].imag()};
return result;
}
// Estimate standard deviation using the imaginary part of decoded vector z
// Compute m(X) - m(1/X) as a proxy for z - Conj(z) = 2*Im(z)
// vec is m(X) corresponding to z
// conjugate is m(1/X) corresponding to Conj(z)
double StdDev(const std::vector<std::complex<double>>& vec, const std::vector<std::complex<double>>& conjugate) {
uint32_t slots = vec.size();
if (1 == slots) {
return vec[0].imag();
}
// ring dimension
uint32_t dslots = slots * 2;
// extract the complex part using identity z - Conj(z) == 2*Im(z)
// here we actually compute m(X) - m(1/X) corresponding to 2*Im(z).
// we only need first Nh/2 + 1 components of the imaginary part
// as the remaining Nh/2 - 1 components have a symmetry
// w.r.t. components from 1 to Nh/2 - 1
std::vector<std::complex<double>> complexValues(slots / 2 + 1);
for (size_t i = 0; i < slots / 2 + 1; ++i) {
complexValues[i] = vec[i] - conjugate[i];
}
// Calculate the mean
auto mean_func = [](double accumulator, const std::complex<double>& val) {
return accumulator + (val.real() + val.imag());
};
// use the symmetry condition
double mean = 2 * std::accumulate(complexValues.begin() + 1, complexValues.begin() + slots / 2, 0.0, mean_func);
// and then add values at indices 0 and Nh/2
mean += complexValues[0].imag();
mean += 2 * complexValues[slots / 2].real();
// exclude the real part at index 0 as it is always 0
mean /= static_cast<double>(dslots) - 1.0;
// Now calculate the variance
auto variance_func = [&mean](double accumulator, const std::complex<double>& val) {
return accumulator + (val.real() - mean) * (val.real() - mean) + (val.imag() - mean) * (val.imag() - mean);
};
// use the symmetry condition
double variance = 2 * accumulate(complexValues.begin() + 1, complexValues.begin() + slots / 2, 0.0, variance_func);
// and then add values at indices 0 and Nh/2
variance += (complexValues[0].imag() - mean) * (complexValues[0].imag() - mean);
variance += 2 * (complexValues[slots / 2].real() - mean) * (complexValues[slots / 2].real() - mean);
// exclude the real part at index 0 as it is always 0
variance /= static_cast<double>(dslots) - 2.0;
// scale down by 2 as we have worked with 2*Im(z) up to this point
return 0.5 * std::sqrt(variance);
}
bool CKKSPackedEncoding::Encode() {
if (isEncoded)
return true;
if (typeFlag != IsDCRTPoly)
OPENFHE_THROW("Only DCRTPoly is supported for CKKS.");
if (slots < value.size()) {
std::string errMsg = std::string("The number of slots [") + std::to_string(slots) +
"] is less than the size of data [" + std::to_string(value.size()) + "]";
OPENFHE_THROW(errMsg);
}
auto inverse{value};
inverse.resize(slots);
uint32_t ringDim = GetElementRingDimension();
DiscreteFourierTransform::FFTSpecialInv(inverse, ringDim * 2);
#if NATIVEINT == 128
uint64_t pBits = encodingParams->GetPlaintextModulus();
uint32_t precision = 52;
double powP = std::pow(2, precision);
int32_t pCurrent = pBits - precision;
// the idea is to break down real and imaginary parts
// expressed as input_mantissa * 2^input_exponent
// into (input_mantissa * 2^52) * 2^(p - 52 + input_exponent)
// to preserve 52-bit precision of doubles
// when converting to 128-bit numbers
std::vector<int128_t> temp(2 * slots);
int128_t MaxBitValue = Max128BitValue();
for (uint32_t i = 0; i < slots; ++i) {
// Check for possible overflow in llround function
int32_t n1 = 0;
// extract the mantissa of real part and multiply it by 2^52
double dre = static_cast<double>(std::frexp(inverse[i].real(), &n1) * powP);
int32_t n2 = 0;
// extract the mantissa of imaginary part and multiply it by 2^52
double dim = static_cast<double>(std::frexp(inverse[i].imag(), &n2) * powP);
if (is128BitOverflow(dre) || is128BitOverflow(dim)) {
OPENFHE_THROW("Overflow, try to decrease scaling factor");
}
int64_t re64 = std::llround(dre);
int32_t pRemaining = pCurrent + n1;
int128_t re = 0;
if (pRemaining < 0) {
re = re64 >> (-pRemaining);
}
else {
int128_t pPowRemaining = ((int128_t)1) << pRemaining;
re = pPowRemaining * re64;
}
int64_t im64 = std::llround(dim);
pRemaining = pCurrent + n2;
int128_t im = 0;
if (pRemaining < 0) {
im = im64 >> (-pRemaining);
}
else {
int128_t pPowRemaining = (static_cast<int64_t>(1)) << pRemaining;
im = pPowRemaining * im64;
}
temp[i] = (re < 0) ? MaxBitValue + re : re;
temp[i + slots] = (im < 0) ? MaxBitValue + im : im;
if (is128BitOverflow(temp[i]) || is128BitOverflow(temp[i + slots])) {
OPENFHE_THROW("Overflow, try to decrease scaling factor");
}
}
DCRTPoly::Integer intPowP = NativeInteger(1) << pBits;
#else // NATIVEINT == 64
int32_t logc = std::numeric_limits<int32_t>::min();
for (uint32_t i = 0; i < slots; ++i) {
inverse[i] *= scalingFactor;
if (inverse[i].real() != 0.) {
auto logci = static_cast<int32_t>(std::ceil(std::log2(std::abs(inverse[i].real()))));
if (logc < logci)
logc = logci;
}
if (inverse[i].imag() != 0.) {
auto logci = static_cast<int32_t>(std::ceil(std::log2(std::abs(inverse[i].imag()))));
if (logc < logci)
logc = logci;
}
}
logc = (logc == std::numeric_limits<int32_t>::min()) ? 0 : logc;
if (logc < 0)
OPENFHE_THROW("Scaling factor too small");
// Compute approxFactor, a value to scale down by in case the value exceeds a 64-bit integer.
constexpr int32_t MAX_BITS_IN_WORD = LargeScalingFactorConstants::MAX_BITS_IN_WORD;
int32_t logValid = (logc <= MAX_BITS_IN_WORD) ? logc : MAX_BITS_IN_WORD;
int32_t logApprox = logc - logValid;
double approxFactor = std::pow(2, logApprox);
double invLen = static_cast<double>(slots);
std::vector<int64_t> temp(2 * slots);
int64_t MaxBitValue = Max64BitValue();
for (uint32_t i = 0; i < slots; ++i) {
// Scale down by approxFactor in case the value exceeds a 64-bit integer.
double dre = inverse[i].real() / approxFactor;
double dim = inverse[i].imag() / approxFactor;
// Check for possible overflow
if (is64BitOverflow(dre) || is64BitOverflow(dim)) {
// IFFT formula:
// x[n] = (1/N) * \Sum^(N-1)_(k=0) X[k] * exp( j*2*pi*n*k/N )
// n is i
// k is idx below
// N is inverse.size()
//
// In the following, we switch to original data domain,
// and we identify the component that has the maximum
// contribution to the values in the iFFT domain. We do
// this to report it to the user, so they can identify
// large inputs.
DiscreteFourierTransform::FFTSpecial(inverse, ringDim * 2);
double factor = 2 * M_PI * i;
double realMax = -1, imagMax = -1;
uint32_t realMaxIdx = -1, imagMaxIdx = -1;
for (uint32_t idx = 0; idx < slots; ++idx) {
// exp( j*2*pi*n*k/N )
std::complex<double> expFactor = {cos((factor * idx) / invLen), sin((factor * idx) / invLen)};
// X[k] * exp( j*2*pi*n*k/N )
std::complex<double> prodFactor = inverse[idx] * expFactor;
double realVal = prodFactor.real();
double imagVal = prodFactor.imag();
if (realVal > realMax) {
realMax = realVal;
realMaxIdx = idx;
}
if (imagVal > imagMax) {
imagMax = imagVal;
imagMaxIdx = idx;
}
}
auto scaledInputSize = std::ceil(std::log2(dre));
std::stringstream buffer;
buffer << std::endl
<< "Overflow in data encoding - scaled input is too large to fit "
"into a NativeInteger (60 bits). Try decreasing scaling factor."
<< std::endl;
buffer << "Overflow at slot number " << i << std::endl;
buffer << "- Max real part contribution from input[" << realMaxIdx << "]: " << realMax << std::endl;
buffer << "- Max imaginary part contribution from input[" << imagMaxIdx << "]: " << imagMax << std::endl;
buffer << "Scaling factor is " << std::ceil(std::log2(scalingFactor)) << " bits " << std::endl;
buffer << "Scaled input is " << scaledInputSize << " bits " << std::endl;
OPENFHE_THROW(buffer.str());
}
int64_t re = std::llround(dre);
int64_t im = std::llround(dim);
temp[i] = (re < 0) ? MaxBitValue + re : re;
temp[i + slots] = (im < 0) ? MaxBitValue + im : im;
}
DCRTPoly::Integer intPowP(static_cast<uint64_t>(std::llround(scalingFactor)));
#endif
auto nativeParams = encodedVectorDCRT.GetParams()->GetParams();
uint32_t numTowers = nativeParams.size();
std::vector<DCRTPoly::Integer> moduli(numTowers);
for (uint32_t i = 0; i < numTowers; i++) {
moduli[i] = nativeParams[i]->GetModulus();
NativeVector nativeVec(ringDim, nativeParams[i]->GetModulus());
FitToNativeVector(temp, MaxBitValue, &nativeVec);
NativePoly element = GetElement<DCRTPoly>().GetElementAtIndex(i);
element.SetValues(std::move(nativeVec), Format::COEFFICIENT); // output was in coefficient format
encodedVectorDCRT.SetElementAtIndex(i, std::move(element));
}
// We want to scale temp by 2^(pd), and the loop starts from j=2
// because temp is already scaled by 2^p in the re/im loop above,
// and currPowP already is 2^p.
std::vector<DCRTPoly::Integer> crtPowP(numTowers, intPowP);
auto currPowP = crtPowP;
for (size_t i = 2; i < noiseScaleDeg; ++i)
currPowP = CKKSPackedEncoding::CRTMult(currPowP, crtPowP, moduli);
if (noiseScaleDeg > 1)
encodedVectorDCRT = encodedVectorDCRT.Times(currPowP);
#if NATIVEINT == 64
// Scale back up by the approxFactor to get the correct encoding.
int32_t MAX_LOG_STEP = 60;
if (logApprox > 0) {
int32_t logStep = (logApprox <= MAX_LOG_STEP) ? logApprox : MAX_LOG_STEP;
DCRTPoly::Integer intStep = static_cast<uint64_t>(1) << logStep;
std::vector<DCRTPoly::Integer> crtApprox(numTowers, intStep);
logApprox -= logStep;
while (logApprox > 0) {
int32_t logStep = (logApprox <= MAX_LOG_STEP) ? logApprox : MAX_LOG_STEP;
DCRTPoly::Integer intStep = static_cast<uint64_t>(1) << logStep;
std::vector<DCRTPoly::Integer> crtSF(numTowers, intStep);
crtApprox = CRTMult(crtApprox, crtSF, moduli);
logApprox -= logStep;
}
encodedVectorDCRT = encodedVectorDCRT.Times(crtApprox);
}
#endif
GetElement<DCRTPoly>().SetFormat(Format::EVALUATION);
scalingFactor = std::pow(scalingFactor, noiseScaleDeg);
return isEncoded = true;
}
bool CKKSPackedEncoding::Decode(size_t noiseScaleDeg, double scalingFactor, ScalingTechnique scalTech,
ExecutionMode executionMode) {
double p = encodingParams->GetPlaintextModulus();
double powP = 0.0;
uint32_t Nh = GetElementRingDimension() / 2;
uint32_t gap = Nh / slots;
value.clear();
std::vector<std::complex<double>> curValues(slots);
if (typeFlag == IsNativePoly) {
if (scalTech == FLEXIBLEAUTO || scalTech == FLEXIBLEAUTOEXT || scalTech == COMPOSITESCALINGAUTO ||
scalTech == COMPOSITESCALINGMANUAL)
powP = std::pow(scalingFactor, -1);
else
powP = std::pow(2, -p);
NativeInteger q = GetElementModulus().ConvertToInt();
NativeInteger qHalf = q >> 1;
for (uint32_t i = 0, idx = 0; i < slots; ++i, idx += gap) {
std::complex<double> cur;
if (GetElement<NativePoly>()[idx] > qHalf)
cur.real(-((q - GetElement<NativePoly>()[idx])).ConvertToDouble());
else
cur.real((GetElement<NativePoly>()[idx]).ConvertToDouble());
if (GetElement<NativePoly>()[idx + Nh] > qHalf)
cur.imag(-((q - GetElement<NativePoly>()[idx + Nh])).ConvertToDouble());
else
cur.imag((GetElement<NativePoly>()[idx + Nh]).ConvertToDouble());
curValues[i] = cur;
}
// clears the values containing information about the noise
GetElement<NativePoly>().SetValuesToZero();
}
else {
powP = std::pow(2, -p);
// we will bring down the scaling factor to 2^p
double scalingFactorPre = 0.0;
if (scalTech == FLEXIBLEAUTO || scalTech == FLEXIBLEAUTOEXT || scalTech == COMPOSITESCALINGAUTO ||
scalTech == COMPOSITESCALINGMANUAL)
scalingFactorPre = std::pow(scalingFactor, -1) * std::pow(2, p);
else
scalingFactorPre = std::pow(2, -p * (noiseScaleDeg - 1));
const BigInteger& q = GetElementModulus();
BigInteger qHalf = q >> 1;
for (size_t i = 0, idx = 0; i < slots; ++i, idx += gap) {
std::complex<double> cur;
if (GetElement<Poly>()[idx] > qHalf)
cur.real(-((q - GetElement<Poly>()[idx])).ConvertToDouble() * scalingFactorPre);
else
cur.real((GetElement<Poly>()[idx]).ConvertToDouble() * scalingFactorPre);
if (GetElement<Poly>()[idx + Nh] > qHalf)
cur.imag(-((q - GetElement<Poly>()[idx + Nh])).ConvertToDouble() * scalingFactorPre);
else
cur.imag((GetElement<Poly>()[idx + Nh]).ConvertToDouble() * scalingFactorPre);
curValues[i] = cur;
}
// clears the values containing information about the noise
GetElement<Poly>().SetValuesToZero();
}
// the code below adds a Gaussian noise to the decrypted result
// to prevent key recovery attacks.
// The standard deviation of the Gaussian noise is sqrt(M+1)*stddev,
// where stddev is the standard deviation estimated using the imaginary
// component and M is the extra factor that increases the number of decryption
// attacks that is needed to average out the added Gaussian noise (after the
// noise is removed, the attacker still has to find the secret key using the
// real part only, which requires another attack). By default (M = 1), stddev
// requires at least 128 decryption queries (in practice the values are
// typically closer to 10,000 or so). Then M can be used to increase this
// number further by M^2 (as desired for a given application). By default we
// we set M to 1.
// compute m(1/X) corresponding to Conj(z), where z is the decoded vector
auto conjugate = Conjugate(curValues);
// Estimate standard deviation from 1/2 (m(X) - m(1/x)),
// which corresponds to Im(z)
double stddev = StdDev(curValues, conjugate);
double logstd = std::log2(stddev);
if (executionMode == EXEC_NOISE_ESTIMATION) {
m_logError = logstd;
}
else {
// if stddev < sqrt{N}/8 (minimum approximation error that can be achieved)
if (stddev < 0.125 * std::sqrt(GetElementRingDimension())) {
stddev = 0.125 * std::sqrt(GetElementRingDimension());
}
// if stddev < sqrt{N}/4 (minimum approximation error that can be achieved)
// if (stddev < 0.125 * std::sqrt(GetElementRingDimension())) {
// if (noiseScaleDeg <= 1) {
// OPENFHE_THROW(
// "The decryption failed because the approximation error is
// " "too small. Check the protocol used. ");
// } else { // noiseScaleDeg > 1 and no rescaling operations have been applied yet
// stddev = 0.125 * std::sqrt(GetElementRingDimension());
// }
// }
if (ckksDataType == REAL) {
// If less than 5 bits of precision is observed
if (logstd > p - 5.0)
OPENFHE_THROW(
"The decryption failed because the approximation error is "
"too high. Check the parameters. ");
}
// real values
std::vector<std::complex<double>> realValues(slots);
// CKKS_M_FACTOR is a compile-level parameter
// set to 1 by default
stddev = std::sqrt(CKKS_M_FACTOR + 1) * stddev;
double scale = (ckksDataType == REAL) ? 0.5 * powP : powP;
// TODO temporary removed errors
std::normal_distribution<> d(0, stddev);
PRNG& g = PseudoRandomNumberGenerator::GetPRNG();
// Alternative way to do Gaussian sampling
// DiscreteGaussianGenerator dgg;
// TODO we can sample Nh integers instead of 2*Nh
// We would add sampling only for even indices of i.
// This change should be done together with the one below.
for (size_t i = 0; i < slots; ++i) {
double real = scale * curValues[i].real();
double imag = scale * curValues[i].imag();
if (ckksDataType == REAL) {
real += scale * conjugate[i].real() + powP * d(g);
// real += powP * dgg.GenerateIntegerKarney(0.0, stddev);
imag += scale * conjugate[i].imag() + powP * d(g);
// imag += powP * dgg.GenerateIntegerKarney(0.0, stddev);
}
realValues[i].real(real);
realValues[i].imag(imag);
}
// TODO we can half the dimension for the FFT by decoding in
// Z[X + 1/X]/(X^n + 1). This would change the complexity from n*logn to
// roughly (n/2)*log(n/2). This change should be done together with the one
// above.
DiscreteFourierTransform::FFTSpecial(realValues, GetElementRingDimension() * 2);
if (ckksDataType == REAL) {
// clears all imaginary values for security reasons
for (auto& val : realValues) {
val.imag(0.0);
}
// sets an estimate of the approximation error
m_logError = std::round(std::log2(stddev * std::sqrt(2 * slots)));
}
else {
m_logError = 0;
}
value = realValues;
}
return true;
}
void CKKSPackedEncoding::Destroy() {}
void CKKSPackedEncoding::FitToNativeVector(const std::vector<int64_t>& vec, int64_t bigBound,
NativeVector* nativeVec) const {
NativeInteger bigValueHf(bigBound >> 1);
NativeInteger modulus(nativeVec->GetModulus());
NativeInteger diff = bigBound - modulus;
uint32_t ringDim = GetElementRingDimension();
uint32_t dslots = vec.size();
uint32_t gap = ringDim / dslots;
for (uint32_t i = 0; i < vec.size(); i++) {
NativeInteger n(vec[i]);
if (n > bigValueHf) {
(*nativeVec)[gap * i] = n.ModSub(diff, modulus);
}
else {
(*nativeVec)[gap * i] = n.Mod(modulus);
}
}
}
#if NATIVEINT == 128
void CKKSPackedEncoding::FitToNativeVector(const std::vector<int128_t>& vec, int128_t bigBound,
NativeVector* nativeVec) const {
NativeInteger bigValueHf((uint128_t)bigBound >> 1);
NativeInteger modulus(nativeVec->GetModulus());
NativeInteger diff = NativeInteger((uint128_t)bigBound) - modulus;
uint32_t ringDim = GetElementRingDimension();
uint32_t dslots = vec.size();
uint32_t gap = ringDim / dslots;
for (uint32_t i = 0; i < vec.size(); i++) {
NativeInteger n((uint128_t)vec[i]);
if (n > bigValueHf) {
(*nativeVec)[gap * i] = n.ModSub(diff, modulus);
}
else {
(*nativeVec)[gap * i] = n.Mod(modulus);
}
}
}
#endif
} // namespace lbcrypto