Program Listing for File ckkspackedencoding.h
↰ Return to documentation for file (pke/include/encoding/ckkspackedencoding.h)
//==================================================================================
// BSD 2-Clause License
//
// Copyright (c) 2014-2022, NJIT, Duality Technologies Inc. and other contributors
//
// All rights reserved.
//
// Author TPOC: contact@openfhe.org
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//==================================================================================
#ifndef LBCRYPTO_UTILS_CKKSPACKEDEXTENCODING_H
#define LBCRYPTO_UTILS_CKKSPACKEDEXTENCODING_H
#include "constants.h"
#include "encoding/encodingparams.h"
#include "encoding/plaintext.h"
#include "math/hal/basicint.h"
#include <algorithm>
#include <complex>
#include <cstdint>
#include <functional>
#include <initializer_list>
#include <iomanip>
#include <memory>
#include <numeric>
#include <ostream>
#include <sstream>
#include <string>
#include <type_traits>
#include <typeinfo>
#include <utility>
#include <vector>
namespace lbcrypto {

/**
 * @brief Packed plaintext encoding used by the CKKS scheme.
 *
 * Holds a vector of complex slot values. When the data type is REAL, the imaginary parts
 * of the stored values are forced to zero at construction time.
 */
class CKKSPackedEncoding : public PlaintextImpl {
private:
    // Slot values held by this plaintext.
    std::vector<std::complex<double>> value;
    // Log of the estimated error; returned by GetLogError() and subtracted from the
    // plaintext modulus in GetLogPrecision().
    double m_logError = 0.;

    // Sets the imaginary part of every stored value to zero (used for REAL data).
    // Iterating over the elements avoids pointer arithmetic on value.data(), which may be
    // a null pointer when the vector is empty (null + offset is undefined behavior).
    void ClearImaginaryParts() {
        for (auto& z : value)
            z.imag(0.0);
    }

public:
    // These two constructors are used inside of Decrypt.

    /**
     * @brief Creates a plaintext without values for the given element parameters.
     * @param vp element parameters (Poly, NativePoly or DCRTPoly parameter type)
     * @param ep encoding parameters
     * @param ckksdt data type of the values (REAL by default)
     * @throws if the default number of slots is invalid (see GetDefaultSlotSize)
     */
    template <typename T, typename std::enable_if<std::is_same<T, Poly::Params>::value ||
                                                      std::is_same<T, NativePoly::Params>::value ||
                                                      std::is_same<T, DCRTPoly::Params>::value,
                                                  bool>::type = true>
    CKKSPackedEncoding(std::shared_ptr<T> vp, EncodingParams ep, CKKSDataType ckksdt = REAL)
        : PlaintextImpl(vp, ep, CKKS_PACKED_ENCODING, CKKSRNS_SCHEME) {
        ckksDataType = ckksdt;
        slots        = GetDefaultSlotSize();
    }

    /**
     * @brief Creates a plaintext holding the given values.
     * @param vp element parameters (Poly, NativePoly or DCRTPoly parameter type)
     * @param ep encoding parameters
     * @param v values to store; if ckksdt is REAL their imaginary parts are set to zero
     * @param nsdeg degree of the scaling factor of the plaintext (stored as noiseScaleDeg)
     * @param lvl level of the plaintext to create
     * @param scFact scaling factor of a plaintext of this level at depth 1
     * @param slts requested number of slots; 0 selects the default (see GetDefaultSlotSize)
     * @param ckksdt data type of the values (REAL by default)
     * @throws if the number of slots is invalid (see GetDefaultSlotSize)
     */
    template <typename T, typename std::enable_if<std::is_same<T, Poly::Params>::value ||
                                                      std::is_same<T, NativePoly::Params>::value ||
                                                      std::is_same<T, DCRTPoly::Params>::value,
                                                  bool>::type = true>
    CKKSPackedEncoding(std::shared_ptr<T> vp, EncodingParams ep, const std::vector<std::complex<double>>& v,
                       size_t nsdeg, uint32_t lvl, double scFact, uint32_t slts, CKKSDataType ckksdt = REAL)
        : PlaintextImpl(vp, ep, CKKS_PACKED_ENCODING, CKKSRNS_SCHEME), value(v) {
        ckksDataType  = ckksdt;
        scalingFactor = scFact;
        level         = lvl;
        noiseScaleDeg = nsdeg;
        slots         = GetDefaultSlotSize(slts, v.size());
        if (ckksDataType == REAL)
            ClearImaginaryParts();
    }

    /**
     * @brief Creates a plaintext holding the given values, without element or encoding parameters.
     * @param v values to store; their imaginary parts are set to zero
     * @param s requested number of slots; 0 selects the default (see GetDefaultSlotSize)
     */
    explicit CKKSPackedEncoding(const std::vector<std::complex<double>>& v, uint32_t s)
        : PlaintextImpl(std::shared_ptr<Poly::Params>(nullptr), nullptr, CKKS_PACKED_ENCODING, CKKSRNS_SCHEME),
          value(v) {
        slots = GetDefaultSlotSize(s, v.size());
        // Assumes ckksDataType = REAL
        ClearImaginaryParts();
    }

    // Creates an empty plaintext without element or encoding parameters.
    CKKSPackedEncoding()
        : PlaintextImpl(std::shared_ptr<Poly::Params>(nullptr), nullptr, CKKS_PACKED_ENCODING, CKKSRNS_SCHEME) {
        slots = GetDefaultSlotSize();
    }

    CKKSPackedEncoding(const CKKSPackedEncoding& rhs)
        : PlaintextImpl(rhs), value(rhs.value), m_logError(rhs.m_logError) {}

    CKKSPackedEncoding(CKKSPackedEncoding&& rhs) noexcept
        : PlaintextImpl(std::move(rhs)), value(std::move(rhs.value)), m_logError(rhs.m_logError) {}

    // Defined out of line.
    bool Encode() override;

    // Not supported for CKKS: always throws; use the Decode(...) overload below.
    bool Decode() override {
        OPENFHE_THROW("CKKSPackedEncoding::Decode() is not implemented. Use CKKSPackedEncoding::Decode(...) instead.");
    }

    // Defined out of line.
    bool Decode(size_t depth, double scalingFactor, ScalingTechnique scalTech, ExecutionMode executionMode) override;

    // Returns the stored slot values.
    const std::vector<std::complex<double>>& GetCKKSPackedValue() const override {
        return value;
    }

    // Returns a copy of the real parts of the stored slot values.
    std::vector<double> GetRealPackedValue() const override {
        std::vector<double> realValue(value.size());
        std::transform(value.cbegin(), value.cend(), realValue.begin(),
                       [](const std::complex<double>& z) { return z.real(); });
        return realValue;
    }

    /**
     * @brief Element-wise modular product: r[i] = a[i] * b[i] mod m[i] (via ModMulFast).
     * @param a first operand; its length determines the length of the result
     * @param b second operand; must be at least as long as a
     * @param m moduli; must be at least as long as a
     * @return vector of a.size() products
     * @throws if b or m is shorter than a (reading past their ends would be undefined behavior)
     */
    static std::vector<DCRTPoly::Integer> CRTMult(const std::vector<DCRTPoly::Integer>& a,
                                                  const std::vector<DCRTPoly::Integer>& b,
                                                  const std::vector<DCRTPoly::Integer>& m) {
        if (b.size() < a.size() || m.size() < a.size())
            OPENFHE_THROW("CRTMult: vectors b and m must be at least as long as vector a");
        std::vector<DCRTPoly::Integer> r;
        r.reserve(a.size());
        for (size_t i = 0; i < a.size(); ++i)
            r.emplace_back(a[i].ModMulFast(b[i], m[i]));
        return r;
    }

    // Returns the number of stored values.
    size_t GetLength() const override {
        return value.size();
    }

    // Returns the stored log error (m_logError).
    double GetLogError() const override {
        return m_logError;
    }

    /**
     * @brief Returns the plaintext modulus of the encoding parameters minus the log error.
     * NOTE(review): for CKKS the plaintext modulus presumably holds a bit count — confirm.
     * @throws for COMPLEX data (not implemented)
     */
    double GetLogPrecision() const override {
        if (ckksDataType == COMPLEX)
            OPENFHE_THROW("GetLogPrecision for complex numbers is not implemented.");
        return encodingParams->GetPlaintextModulus() - m_logError;
    }

    // Resizes the value vector; new entries are value-initialized to (0, 0).
    void SetLength(size_t siz) override {
        value.resize(siz);
    }

    // Defined out of line.
    static void Destroy();

    /**
     * @brief Formats the values as "(v0, v1, ..., ... )" with the given stream precision.
     *
     * Trailing zero values are omitted and replaced by "...". REAL data prints only the real
     * parts followed by the estimated precision; COMPLEX data prints "(re, im)" pairs.
     * @param precision precision passed to std::setprecision
     */
    std::string GetFormattedValues(int64_t precision) const override {
        std::stringstream ss;
        ss << "(";
        // For sanity's sake: get rid of all trailing zeroes and print "..." instead.
        // After the scan, `last` is the index of the last non-zero value (if any).
        size_t last    = value.size();
        bool allZeroes = true;
        while (last > 0) {
            if (value[--last] != std::complex<double>(0, 0)) {
                allZeroes = false;
                break;
            }
        }
        if (allZeroes) {
            // Close the parenthesis so an all-zero (or empty) vector is not printed as a bare "(".
            ss << "... )";
            return ss.str();
        }
        // setprecision is sticky, so setting it once applies to every value printed below.
        ss << std::setprecision(static_cast<int>(precision));
        if (ckksDataType == REAL) {
            for (size_t j = 0; j <= last; ++j)
                ss << value[j].real() << ", ";
            ss << "... ); Estimated precision: " << GetLogPrecision() << " bits";
        }
        else {
            for (size_t j = 0; j <= last; ++j)
                ss << " (" << value[j].real() << ", " << value[j].imag() << "), ";
            ss << "... )";
        }
        return ss.str();
    }

protected:
    // Prints the formatted values with precision 8.
    void PrintValue(std::ostream& out) const override {
        out << GetFormattedValues(8) << std::endl;
    }

    /**
     * @brief Validates (and, if requested, chooses) the number of slots.
     * @param numSlots requested number of slots; 0 means: use the batch size of the encoding
     *        parameters, or half the ring dimension if the batch size is 0 (or there are no
     *        encoding parameters)
     * @param vlen number of values that must fit into the slots
     * @return the validated number of slots
     * @throws if the number of slots is not a power of two, exceeds half the ring dimension,
     *         or is smaller than vlen
     */
    uint32_t GetDefaultSlotSize(uint32_t numSlots = 0, size_t vlen = 0) {
        if (numSlots == 0) {
            // The parameterless constructors pass nullptr encoding params; guard before
            // dereferencing (NOTE(review): assumes GetEncodingParams() returns them unchanged).
            const auto& encParams    = GetEncodingParams();
            const uint32_t batchSize = encParams ? encParams->GetBatchSize() : 0;
            numSlots                 = (batchSize == 0) ? GetElementRingDimension() >> 1 : batchSize;
        }
        if ((numSlots & (numSlots - 1)) != 0)
            OPENFHE_THROW("The number of slots should be a power of two");
        if (numSlots > GetElementRingDimension() >> 1)
            OPENFHE_THROW("The number of slots cannot be larger than half of ring dimension");
        if (numSlots < vlen)
            OPENFHE_THROW("The number of slots cannot be smaller than value vector size");
        return numSlots;
    }

    // Two plaintexts compare equal when both are CKKSPackedEncoding and hold equal values.
    bool CompareTo(const PlaintextImpl& rhs) const override {
        if (typeid(rhs) != typeid(CKKSPackedEncoding))
            return false;
        const auto& el = static_cast<const CKKSPackedEncoding&>(rhs);
        return value == el.value;
    }

    // Defined out of line.
    void FitToNativeVector(const std::vector<int64_t>& vec, int64_t bigBound, NativeVector* nativeVec) const;
#if NATIVEINT == 128
    void FitToNativeVector(const std::vector<int128_t>& vec, int128_t bigBound, NativeVector* nativeVec) const;
#endif
};

}  // namespace lbcrypto
#endif