Program Listing for File ckkspackedencoding.h

Return to documentation for file (pke/include/encoding/ckkspackedencoding.h)

//==================================================================================
// BSD 2-Clause License
//
// Copyright (c) 2014-2022, NJIT, Duality Technologies Inc. and other contributors
//
// All rights reserved.
//
// Author TPOC: contact@openfhe.org
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
//    list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
//    this list of conditions and the following disclaimer in the documentation
//    and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//==================================================================================

#ifndef LBCRYPTO_UTILS_CKKSPACKEDEXTENCODING_H
#define LBCRYPTO_UTILS_CKKSPACKEDEXTENCODING_H

#include "constants.h"

#include "encoding/encodingparams.h"
#include "encoding/plaintext.h"

#include "math/hal/basicint.h"

#include <algorithm>
#include <complex>
#include <functional>
#include <initializer_list>
#include <iomanip>
#include <memory>
#include <numeric>
#include <ostream>
#include <sstream>
#include <string>
#include <type_traits>
#include <typeinfo>
#include <utility>
#include <vector>

namespace lbcrypto {

/**
 * @brief Plaintext encoding for the CKKS scheme holding a vector of complex slot values.
 *
 * For REAL data the value-taking constructors force every imaginary part to zero.
 * Encode() and Decode(depth, ...) are declared here and defined out of line.
 */
class CKKSPackedEncoding : public PlaintextImpl {
private:
    // Slot values; for REAL data the constructors keep the imaginary parts at zero.
    std::vector<std::complex<double>> value;
    // Estimated error in log scale; subtracted by GetLogPrecision(), which is reported in bits.
    // It is not assigned anywhere in this header (presumably set during Decode).
    double m_logError = 0.;

    // Sets the imaginary part of every element of vals to +0.0 (used for REAL data).
    static void ZeroImaginaryParts(std::vector<std::complex<double>>& vals) {
        for (auto& c : vals)
            c.imag(0.0);
    }

public:
    /**
     * @brief Constructs an empty CKKS plaintext over the given element parameters.
     *        Used inside of Decrypt.
     * @param vp     element parameters (Poly, NativePoly or DCRTPoly params only)
     * @param ep     encoding parameters
     * @param ckksdt whether the data is REAL or COMPLEX
     */
    template <typename T, typename std::enable_if<std::is_same<T, Poly::Params>::value ||
                                                      std::is_same<T, NativePoly::Params>::value ||
                                                      std::is_same<T, DCRTPoly::Params>::value,
                                                  bool>::type = true>
    CKKSPackedEncoding(std::shared_ptr<T> vp, EncodingParams ep, CKKSDataType ckksdt = REAL)
        : PlaintextImpl(vp, ep, CKKS_PACKED_ENCODING, CKKSRNS_SCHEME) {
        ckksDataType = ckksdt;
        slots        = GetDefaultSlotSize();
    }

    /**
     * @brief Constructs a CKKS plaintext holding the given slot values. Used inside of Decrypt.
     * @param vp     element parameters (Poly, NativePoly or DCRTPoly params only)
     * @param ep     encoding parameters
     * @param v      slot values; when ckksdt is REAL their imaginary parts are set to zero
     * @param nsdeg  degree of the scaling factor of the plaintext
     * @param lvl    level of the plaintext to create
     * @param scFact scaling factor of a plaintext of this level at depth 1
     * @param slts   requested number of slots; 0 selects the default (see GetDefaultSlotSize)
     * @param ckksdt whether the data is REAL or COMPLEX
     * @throws (via OPENFHE_THROW) if the slot count is not a power of two, exceeds half the
     *         ring dimension, or is smaller than v.size()
     */
    template <typename T, typename std::enable_if<std::is_same<T, Poly::Params>::value ||
                                                      std::is_same<T, NativePoly::Params>::value ||
                                                      std::is_same<T, DCRTPoly::Params>::value,
                                                  bool>::type = true>
    CKKSPackedEncoding(std::shared_ptr<T> vp, EncodingParams ep, const std::vector<std::complex<double>>& v,
                       size_t nsdeg, uint32_t lvl, double scFact, uint32_t slts, CKKSDataType ckksdt = REAL)
        : PlaintextImpl(vp, ep, CKKS_PACKED_ENCODING, CKKSRNS_SCHEME), value(v) {
        ckksDataType  = ckksdt;
        scalingFactor = scFact;
        level         = lvl;
        noiseScaleDeg = nsdeg;
        slots         = GetDefaultSlotSize(slts, v.size());

        if (ckksDataType == REAL)
            ZeroImaginaryParts(value);
    }

    /**
     * @brief Constructs a plaintext from slot values without element or encoding parameters.
     *        The data is treated as REAL: the imaginary parts of v are set to zero.
     * @param v slot values
     * @param s requested number of slots; 0 selects the default (see GetDefaultSlotSize)
     * NOTE(review): null element/encoding params are passed to PlaintextImpl, yet
     * GetDefaultSlotSize() calls GetElementRingDimension() (and GetEncodingParams() when s == 0).
     * Confirm PlaintextImpl tolerates this; otherwise this dereferences null.
     */
    explicit CKKSPackedEncoding(const std::vector<std::complex<double>>& v, uint32_t s)
        : PlaintextImpl(std::shared_ptr<Poly::Params>(), nullptr, CKKS_PACKED_ENCODING, CKKSRNS_SCHEME),
          value(v) {
        slots = GetDefaultSlotSize(s, v.size());

        // Assumes ckksDataType = REAL
        ZeroImaginaryParts(value);
    }

    /**
     * @brief Default constructor: no values, no element or encoding parameters.
     * NOTE(review): GetDefaultSlotSize() reads GetEncodingParams() and GetElementRingDimension()
     * although null params were passed to PlaintextImpl; confirm PlaintextImpl supplies defaults.
     */
    CKKSPackedEncoding()
        : PlaintextImpl(std::shared_ptr<Poly::Params>(), nullptr, CKKS_PACKED_ENCODING, CKKSRNS_SCHEME) {
        slots = GetDefaultSlotSize();
    }

    // Member-wise copy of the base, the slot values and the error estimate.
    CKKSPackedEncoding(const CKKSPackedEncoding& rhs)
        : PlaintextImpl(rhs), value(rhs.value), m_logError(rhs.m_logError) {}

    // Member-wise move; only the base subobject is moved before rhs.value is read, so this is safe.
    CKKSPackedEncoding(CKKSPackedEncoding&& rhs) noexcept
        : PlaintextImpl(std::move(rhs)), value(std::move(rhs.value)), m_logError(rhs.m_logError) {}

    bool Encode() override;

    // Not supported for CKKS: always throws; use Decode(depth, scalingFactor, scalTech, executionMode).
    bool Decode() override {
        OPENFHE_THROW("CKKSPackedEncoding::Decode() is not implemented. Use CKKSPackedEncoding::Decode(...) instead.");
    }

    bool Decode(size_t depth, double scalingFactor, ScalingTechnique scalTech, ExecutionMode executionMode) override;

    // Returns a reference to the stored complex slot values.
    const std::vector<std::complex<double>>& GetCKKSPackedValue() const override {
        return value;
    }

    // Returns a copy of the real parts of the slot values.
    std::vector<double> GetRealPackedValue() const override {
        std::vector<double> realValue(value.size());
        std::transform(value.cbegin(), value.cend(), realValue.begin(),
                       [](const std::complex<double>& c) { return c.real(); });
        return realValue;
    }

    /**
     * @brief Element-wise modular product of residue vectors, computed with
     *        a[i].ModMulFast(b[i], m[i]) (see ModMulFast for operand preconditions).
     * @param a first operand; its length determines the result length
     * @param b second operand; must have at least a.size() elements
     * @param m moduli; must have at least a.size() elements
     * @return vector of a.size() products
     * @throws (via OPENFHE_THROW) if b or m is shorter than a
     */
    static std::vector<DCRTPoly::Integer> CRTMult(const std::vector<DCRTPoly::Integer>& a,
                                                  const std::vector<DCRTPoly::Integer>& b,
                                                  const std::vector<DCRTPoly::Integer>& m) {
        // Guard against reading b[i] / m[i] out of bounds.
        if (b.size() < a.size() || m.size() < a.size())
            OPENFHE_THROW("CRTMult: b and m must have at least as many elements as a");
        std::vector<DCRTPoly::Integer> r;
        r.reserve(a.size());
        for (size_t i = 0; i < a.size(); ++i)
            r.emplace_back(a[i].ModMulFast(b[i], m[i]));
        return r;
    }

    // Number of stored slot values.
    size_t GetLength() const override {
        return value.size();
    }

    double GetLogError() const override {
        return m_logError;
    }

    /**
     * @brief Estimated precision in bits: the encoding parameters' plaintext modulus
     *        (interpreted here as a bit count) minus m_logError.
     * @throws (via OPENFHE_THROW) for COMPLEX data, which is not supported.
     */
    double GetLogPrecision() const override {
        if (ckksDataType == COMPLEX)
            OPENFHE_THROW("GetLogPrecision for complex numbers is not implemented.");
        return encodingParams->GetPlaintextModulus() - m_logError;
    }

    // Resizes the value vector; new entries are value-initialized to (0, 0).
    void SetLength(size_t siz) override {
        value.resize(siz);
    }

    static void Destroy();

    /**
     * @brief Formats the slot values as text with the given number of significant digits.
     *
     * Trailing zero values are not printed; "..." stands in for them. REAL data prints real
     * parts followed by the estimated precision; COMPLEX data prints (real, imag) pairs.
     * An empty or all-zero vector yields "()".
     */
    std::string GetFormattedValues(int64_t precision) const override {
        // One past the index of the last non-zero entry (0 when there is none).
        size_t last = value.size();
        while (last > 0 && value[last - 1] == std::complex<double>(0, 0))
            --last;

        std::stringstream ss;
        ss << "(";
        if (last == 0) {
            // Empty or all-zero vector: close the parenthesis so the output stays balanced.
            ss << ")";
            return ss.str();
        }

        const int prec = static_cast<int>(precision);
        if (ckksDataType == REAL) {
            for (size_t j = 0; j < last; ++j)
                ss << std::setprecision(prec) << value[j].real() << ", ";
            ss << "... ); Estimated precision: " << GetLogPrecision() << " bits";
        }
        else {
            for (size_t j = 0; j < last; ++j)
                ss << std::setprecision(prec) << " (" << value[j].real() << ", " << value[j].imag() << "), ";
            ss << "... )";
        }
        return ss.str();
    }

protected:
    void PrintValue(std::ostream& out) const override {
        out << GetFormattedValues(8) << std::endl;
    }

    /**
     * @brief Validates the slot count and, when it is 0, chooses one.
     * @param numSlots requested slot count; 0 means the encoding parameters' batch size, or half
     *                 the ring dimension when that batch size is 0
     * @param vlen     number of values that must fit into the slots
     * @return the validated slot count
     * @throws (via OPENFHE_THROW) if the count is not a power of two, exceeds half the ring
     *         dimension, or is smaller than vlen
     */
    uint32_t GetDefaultSlotSize(uint32_t numSlots = 0, size_t vlen = 0) {
        if (numSlots == 0) {
            uint32_t batchSize = GetEncodingParams()->GetBatchSize();
            numSlots           = (batchSize == 0) ? GetElementRingDimension() >> 1 : batchSize;
        }
        // Power-of-two test (note: 0 also passes it).
        if ((numSlots & (numSlots - 1)) != 0)
            OPENFHE_THROW("The number of slots should be a power of two");
        if (numSlots > GetElementRingDimension() >> 1)
            OPENFHE_THROW("The number of slots cannot be larger than half of ring dimension");
        if (numSlots < vlen)
            OPENFHE_THROW("The number of slots cannot be smaller than value vector size");
        return numSlots;
    }

    // Equal only if rhs is exactly a CKKSPackedEncoding with identical slot values
    // (exact floating-point comparison; slots, level and scaling factor are not compared).
    bool CompareTo(const PlaintextImpl& rhs) const override {
        if (typeid(rhs) != typeid(CKKSPackedEncoding))
            return false;

        const auto& el = static_cast<const CKKSPackedEncoding&>(rhs);
        return value == el.value;
    }

    // Defined out of line; presumably reduces vec into nativeVec given bigBound — TODO confirm.
    void FitToNativeVector(const std::vector<int64_t>& vec, int64_t bigBound, NativeVector* nativeVec) const;

#if NATIVEINT == 128
    void FitToNativeVector(const std::vector<int128_t>& vec, int128_t bigBound, NativeVector* nativeVec) const;
#endif
};

}  // namespace lbcrypto

#endif