Program Listing for File ckkspackedencoding.h

Return to documentation for file (pke/include/encoding/ckkspackedencoding.h)

//==================================================================================
// BSD 2-Clause License
//
// Copyright (c) 2014-2022, NJIT, Duality Technologies Inc. and other contributors
//
// All rights reserved.
//
// Author TPOC: contact@openfhe.org
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
//    list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
//    this list of conditions and the following disclaimer in the documentation
//    and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//==================================================================================

#ifndef LBCRYPTO_UTILS_CKKSPACKEDEXTENCODING_H
#define LBCRYPTO_UTILS_CKKSPACKEDEXTENCODING_H

#include "constants.h"

#include "encoding/encodingparams.h"
#include "encoding/plaintext.h"

#include "math/hal/basicint.h"

#include <algorithm>
#include <functional>
#include <initializer_list>
#include <memory>
#include <numeric>
#include <utility>
#include <vector>

namespace lbcrypto {

/**
 * CKKS plaintext that stores a vector of complex values together with the number of
 * slots it is packed into. Encode() and Decode(...) are declared here and defined out of line.
 */
class CKKSPackedEncoding : public PlaintextImpl {
public:
    // these two constructors are used inside of Decrypt

    /**
     * Creates an empty CKKS plaintext; the slot count is taken from GetDefaultSlotSize().
     * @param vp element parameters (Poly, NativePoly or DCRTPoly params).
     * @param ep encoding parameters.
     * @throws if the slot count is larger than half of the ring dimension.
     */
    template <typename T, typename std::enable_if<std::is_same<T, Poly::Params>::value ||
                                                      std::is_same<T, NativePoly::Params>::value ||
                                                      std::is_same<T, DCRTPoly::Params>::value,
                                                  bool>::type = true>
    CKKSPackedEncoding(std::shared_ptr<T> vp, EncodingParams ep) : PlaintextImpl(vp, ep, CKKSRNS_SCHEME) {
        InitSlots(0, 0);
    }

    /**
     * Creates a CKKS plaintext holding the given values.
     * @param vp element parameters (Poly, NativePoly or DCRTPoly params).
     * @param ep encoding parameters.
     * @param coeffs values to store; must not exceed the slot count.
     * @param noiseScaleDeg degree of the scaling factor of a plaintext.
     * @param level level of plaintext to create.
     * @param scFact scaling factor of a plaintext of this level at depth 1.
     * @param slots number of slots (power of two); 0 selects GetDefaultSlotSize().
     * @throws if slots is not a power of two, is smaller than coeffs.size(), or is larger than
     *         half of the ring dimension.
     */
    template <typename T, typename std::enable_if<std::is_same<T, Poly::Params>::value ||
                                                      std::is_same<T, NativePoly::Params>::value ||
                                                      std::is_same<T, DCRTPoly::Params>::value,
                                                  bool>::type = true>
    CKKSPackedEncoding(std::shared_ptr<T> vp, EncodingParams ep, const std::vector<std::complex<double>>& coeffs,
                       size_t noiseScaleDeg, uint32_t level, double scFact, size_t slots)
        : PlaintextImpl(vp, ep, CKKSRNS_SCHEME), value(coeffs) {
        InitSlots(slots, coeffs.size());

        this->noiseScaleDeg = noiseScaleDeg;
        this->level         = level;
        this->scalingFactor = scFact;
    }

    /**
     * Creates a CKKS plaintext holding rhs without element or encoding parameters.
     * @param rhs values to store; must not exceed the slot count.
     * @param slots number of slots (power of two); 0 selects GetDefaultSlotSize().
     * NOTE(review): encoding params are nullptr here, yet GetDefaultSlotSize() dereferences
     * GetEncodingParams() when slots == 0 — confirm callers never rely on that path.
     */
    explicit CKKSPackedEncoding(const std::vector<std::complex<double>>& rhs, size_t slots)
        : PlaintextImpl(std::shared_ptr<Poly::Params>(nullptr), nullptr, CKKSRNS_SCHEME), value(rhs) {
        InitSlots(slots, rhs.size());
    }

    /**
     * Creates an empty CKKS plaintext without element or encoding parameters.
     * NOTE(review): GetDefaultSlotSize() dereferences GetEncodingParams(), which is given nullptr
     * here — confirm whether this constructor is usable as written.
     */
    CKKSPackedEncoding() : PlaintextImpl(std::shared_ptr<Poly::Params>(nullptr), nullptr, CKKSRNS_SCHEME) {
        InitSlots(0, 0);
    }

    CKKSPackedEncoding(const CKKSPackedEncoding& rhs)
        : PlaintextImpl(rhs), value(rhs.value), m_logError(rhs.m_logError) {}

    CKKSPackedEncoding(CKKSPackedEncoding&& rhs)
        : PlaintextImpl(std::move(rhs)), value(std::move(rhs.value)), m_logError(rhs.m_logError) {}

    bool Encode();

    // The parameterless overload is intentionally unsupported; it always throws.
    bool Decode() {
        OPENFHE_THROW("CKKSPackedEncoding::Decode() is not implemented. Use CKKSPackedEncoding::Decode(...) instead.");
    }

    bool Decode(size_t depth, double scalingFactor, ScalingTechnique scalTech, ExecutionMode executionMode);

    /// @return the stored complex values.
    const std::vector<std::complex<double>>& GetCKKSPackedValue() const {
        return value;
    }

    /// @return a copy of the real parts of the stored values.
    const std::vector<double> GetRealPackedValue() const {
        std::vector<double> realValue(value.size());
        std::transform(value.begin(), value.end(), realValue.begin(),
                       [](std::complex<double> da) { return da.real(); });

        return realValue;
    }

    static std::vector<DCRTPoly::Integer> CRTMult(const std::vector<DCRTPoly::Integer>& a,
                                                  const std::vector<DCRTPoly::Integer>& b,
                                                  const std::vector<DCRTPoly::Integer>& mods);

    PlaintextEncodings GetEncodingType() const {
        return CKKS_PACKED_ENCODING;
    }

    /// @return the number of stored values.
    size_t GetLength() const {
        return value.size();
    }

    double GetLogError() const {
        return m_logError;
    }

    /// @return the plaintext modulus minus the stored log error.
    double GetLogPrecision() const {
        return encodingParams->GetPlaintextModulus() - m_logError;
    }

    /// Resizes the stored value vector to siz entries.
    void SetLength(size_t siz) {
        value.resize(siz);
    }

    // Compares stored values only; other must actually be a CKKSPackedEncoding.
    bool CompareTo(const PlaintextImpl& other) const {
        const auto& rv = static_cast<const CKKSPackedEncoding&>(other);
        return this->value == rv.value;
    }

    static void Destroy();

    /**
     * Writes the real parts of the stored values followed by the estimated precision.
     * For sanity's sake, trailing zeros get elided into "..."; the first entry is always
     * printed when the vector is non-empty.
     */
    void PrintValue(std::ostream& out) const {
        // out.precision(15);
        out << "(";
        if (!value.empty()) {
            // Scan down to the last non-zero entry without ever wrapping the unsigned index.
            size_t last = value.size() - 1;
            while (last > 0 && value[last] == std::complex<double>(0, 0))
                --last;

            for (size_t j = 0; j <= last; j++) {
                out << value[j].real() << ", ";
            }
        }

        out << " ... ); ";
        out << "Estimated precision: " << GetLogPrecision() << " bits" << std::endl;
    }

private:
    std::vector<std::complex<double>> value;

    double m_logError = 0;

    /**
     * Validates a requested slot count and stores the effective value in this->slots.
     * @param requestedSlots requested slot count; 0 selects GetDefaultSlotSize().
     * @param numValues number of values that must fit in the slots.
     * @throws if requestedSlots is not a power of two (0 is accepted), if the effective slot
     *         count is smaller than numValues, or if it is larger than half of the ring dimension.
     */
    void InitSlots(size_t requestedSlots, size_t numValues) {
        // validate the number of slots
        if ((requestedSlots & (requestedSlots - 1)) != 0) {
            OPENFHE_THROW("The number of slots should be a power of two");
        }

        this->slots = (requestedSlots) ? requestedSlots : GetDefaultSlotSize();

        if (this->slots < numValues) {
            OPENFHE_THROW("The number of slots cannot be smaller than value vector size");
        }
        if (this->slots > (GetElementRingDimension() / 2)) {
            OPENFHE_THROW("The number of slots cannot be larger than half of ring dimension");
        }
    }

protected:
    /// @return the encoding-params batch size, or half of the ring dimension when it is 0.
    usint GetDefaultSlotSize() {
        auto batchSize = GetEncodingParams()->GetBatchSize();
        return (0 == batchSize) ? GetElementRingDimension() / 2 : batchSize;
    }

    // Defined out of line. NOTE(review): the exact role of bigBound is not visible here.
    void FitToNativeVector(const std::vector<int64_t>& vec, int64_t bigBound, NativeVector* nativeVec) const;

#if NATIVEINT == 128 && !defined(__EMSCRIPTEN__)
    void FitToNativeVector(const std::vector<int128_t>& vec, int128_t bigBound, NativeVector* nativeVec) const;
#endif
};

}  // namespace lbcrypto

#endif