Program Listing for File mubintvecnat.h
↰ Return to documentation for file (core/include/math/hal/intnat/mubintvecnat.h)
//==================================================================================
// BSD 2-Clause License
//
// Copyright (c) 2014-2022, NJIT, Duality Technologies Inc. and other contributors
//
// All rights reserved.
//
// Author TPOC: contact@openfhe.org
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//==================================================================================
/*
* This file contains the vector manipulation functionality for native integers
*/
#ifndef LBCRYPTO_INC_MATH_HAL_INTNAT_MUBINTVECNAT_H
#define LBCRYPTO_INC_MATH_HAL_INTNAT_MUBINTVECNAT_H
#include "math/hal/basicint.h"
#include "math/hal/intnat/ubintnat.h"
#include "math/hal/vector.h"
#include "utils/blockAllocator/xvector.h"
#include "utils/exception.h"
#include "utils/inttypes.h"
#include "utils/serializable.h"
#include <algorithm>
#include <initializer_list>
#include <iostream>
#include <string>
#include <utility>
#include <vector>
// Set the following to 1 to make the native vector use block allocations.
// Then choose between dynamic and static allocations by defining (or not
// defining) STATIC_POOLS on line 24 of xallocator.cpp.
#define BLOCK_VECTOR_ALLOCATION 0 // set to 1 to use block allocations
namespace intnat {
// Forward-declare the vector template so the alias below can name it before
// the full class definition appears.
template <class IntegerType>
class NativeVectorT;
// Default vector type: entries are NativeInteger values.
using NativeVector = NativeVectorT<NativeInteger>;
#if 0  // allocator that reports bytes used.
// Debugging aid (disabled): a minimal allocator usable with std::vector that
// prints the byte count of every allocation and deallocation to std::cout.
// Only one definition exists; a second copy would be a redefinition error
// if both were enabled.
template <class Tp>
struct NAlloc {
    typedef Tp value_type;
    NAlloc() = default;
    // Rebinding constructor required by the allocator requirements.
    template <class T>
    NAlloc(const NAlloc<T>&) {}
    Tp* allocate(std::size_t n) {
        n *= sizeof(Tp);
        std::cout << "allocating " << n << " bytes\n";
        return static_cast<Tp*>(::operator new(n));
    }
    void deallocate(Tp* p, std::size_t n) {
        std::cout << "deallocating " << n * sizeof *p << " bytes\n";
        ::operator delete(p);
    }
};
// Stateless allocator: all instances compare equal.
template <class T, class U>
bool operator==(const NAlloc<T>&, const NAlloc<U>&) {
    return true;
}
template <class T, class U>
bool operator!=(const NAlloc<T>&, const NAlloc<U>&) {
    return false;
}
#endif
/**
 * @brief Vector of native integers that carries its own modulus.
 *
 * Entries live in a std::vector<IntegerType>, or in an xvector<IntegerType>
 * when BLOCK_VECTOR_ALLOCATION is set to 1. Many modular-arithmetic members
 * are only declared here; their definitions are not part of this header view.
 *
 * @tparam IntegerType entry type (NativeVector uses NativeInteger).
 */
template <class IntegerType>
class NativeVectorT final : public lbcrypto::BigVectorInterface<NativeVectorT<IntegerType>, IntegerType>,
                            public lbcrypto::Serializable {
private:
    // m_modulus stores the internal modulus of the vector.
    IntegerType m_modulus{0};
#if BLOCK_VECTOR_ALLOCATION != 1
    std::vector<IntegerType> m_data{};
#else
    xvector<IntegerType> m_data{};
#endif

    // Returns true when `index` refers to an existing entry.
    bool IndexCheck(size_t index) const {
        return index < m_data.size();
    }

public:
    using BasicInt = typename IntegerType::Integer;

    // Empty vector with modulus 0.
    constexpr NativeVectorT() = default;

    // Returns a one-entry vector holding val reduced modulo `modulus`.
    static constexpr NativeVectorT Single(const IntegerType& val, const IntegerType& modulus) noexcept {
        return NativeVectorT(1, modulus, val);
    }

    // `length` default-constructed entries; the modulus stays 0.
    explicit constexpr NativeVectorT(usint length) noexcept : m_data(length) {}

    // `length` default-constructed entries with the given modulus.
    constexpr NativeVectorT(usint length, const IntegerType& modulus) noexcept : m_modulus{modulus}, m_data(length) {
        // TODO: better performance if this check is done at poly level
        // if (modulus.GetMSB() > MAX_MODULUS_SIZE)
        //     OPENFHE_THROW(std::to_string(modulus.GetMSB()) +
        //                   " bits larger than max modulus bits " + std::to_string(MAX_MODULUS_SIZE));
    }

    // `length` copies of val reduced modulo `modulus`.
    constexpr NativeVectorT(usint length, const IntegerType& modulus, const IntegerType& val) noexcept
        : m_modulus{modulus}, m_data(length, val.Mod(modulus)) {
        // TODO: better performance if this check is done at poly level
        // if (modulus.GetMSB() > MAX_MODULUS_SIZE)
        //     OPENFHE_THROW(std::to_string(modulus.GetMSB()) +
        //                   " bits larger than max modulus bits " + std::to_string(MAX_MODULUS_SIZE));
    }

    constexpr NativeVectorT(const NativeVectorT& v) noexcept : m_modulus{v.m_modulus}, m_data{v.m_data} {}

    constexpr NativeVectorT(NativeVectorT&& v) noexcept
        : m_modulus{std::move(v.m_modulus)}, m_data{std::move(v.m_data)} {}

    // Entry-wise initialization from string / integer lists (defined elsewhere).
    NativeVectorT(usint length, const IntegerType& modulus, std::initializer_list<std::string> rhs) noexcept;
    NativeVectorT(usint length, const IntegerType& modulus, std::initializer_list<uint64_t> rhs) noexcept;

    // Copy assignment. When this vector already holds at least as many
    // entries as rhs, the existing buffer is reused: copy, then shrink.
    NativeVectorT& operator=(const NativeVectorT& rhs) noexcept {
        m_modulus = rhs.m_modulus;
        if (m_data.size() >= rhs.m_data.size()) {
            std::copy(rhs.m_data.begin(), rhs.m_data.end(), m_data.begin());
            if (m_data.size() > rhs.m_data.size())
                m_data.resize(rhs.m_data.size());
            return *this;
        }
        m_data = rhs.m_data;
        return *this;
    }

    NativeVectorT& operator=(NativeVectorT&& rhs) noexcept {
        m_modulus = std::move(rhs.m_modulus);
        m_data = std::move(rhs.m_data);
        return *this;
    }

    NativeVectorT& operator=(std::initializer_list<std::string> rhs) noexcept;
    NativeVectorT& operator=(std::initializer_list<uint64_t> rhs) noexcept;

    // Zeroes every entry, then stores val (not reduced modulo m_modulus) in
    // entry 0.
    // NOTE(review): with std::vector storage, m_data.at(0) throws
    // std::out_of_range (not OPENFHE_THROW) when the vector is empty.
    constexpr NativeVectorT& operator=(uint64_t val) {
        std::fill(m_data.begin(), m_data.end(), 0);
        m_data.at(0) = val;
        return *this;
    }

    // Bounds-checked access; raises OPENFHE_THROW when i >= GetLength().
    IntegerType& at(size_t i) {
        if (!NativeVectorT::IndexCheck(i))
            OPENFHE_THROW("NativeVectorT index out of range");
        return m_data[i];
    }

    const IntegerType& at(size_t i) const {
        if (!NativeVectorT::IndexCheck(i))
            OPENFHE_THROW("NativeVectorT index out of range");
        return m_data[i];
    }

    // Unchecked access.
    IntegerType& operator[](size_t idx) {
        return m_data[idx];
    }

    const IntegerType& operator[](size_t idx) const {
        return m_data[idx];
    }

    // Replaces the modulus without touching the entries. Throws when the
    // modulus needs more than MAX_MODULUS_SIZE bits.
    void SetModulus(const IntegerType& value) {
        if (value.GetMSB() > MAX_MODULUS_SIZE) {
            std::string errMsg{"Requested modulus' size " + std::to_string(value.GetMSB()) + " is not supported."};
            errMsg += " NativeVectorT supports only modulus size <= " + std::to_string(MAX_MODULUS_SIZE);
            OPENFHE_THROW(errMsg);
        }
        m_modulus.m_value = value.m_value;
    }

    void SwitchModulus(const IntegerType& value);

    const IntegerType& GetModulus() const {
        return m_modulus;
    }

    size_t GetLength() const {
        return m_data.size();
    }

    // MODULAR ARITHMETIC OPERATIONS
    NativeVectorT Mod(const IntegerType& modulus) const;
    NativeVectorT& ModEq(const IntegerType& modulus);
    NativeVectorT ModAdd(const IntegerType& b) const;
    NativeVectorT& ModAddEq(const IntegerType& b);
    NativeVectorT ModAddAtIndex(size_t i, const IntegerType& b) const;
    NativeVectorT& ModAddAtIndexEq(size_t i, const IntegerType& b);
    NativeVectorT ModAdd(const NativeVectorT& b) const;
    NativeVectorT& ModAddEq(const NativeVectorT& b);

    // In-place entry-wise modular addition using ModAddFastEq. b's length and
    // modulus are NOT checked; b must have at least GetLength() entries.
    NativeVectorT& ModAddNoCheckEq(const NativeVectorT& b) {
        size_t size{m_data.size()};
        auto mv{m_modulus};
        for (size_t i = 0; i < size; ++i)
            m_data[i].ModAddFastEq(b[i], mv);
        return *this;
    }

    NativeVectorT ModSub(const IntegerType& b) const;
    NativeVectorT& ModSubEq(const IntegerType& b);
    NativeVectorT ModSub(const NativeVectorT& b) const;
    NativeVectorT& ModSubEq(const NativeVectorT& b);
    NativeVectorT ModMul(const IntegerType& b) const;
    NativeVectorT& ModMulEq(const IntegerType& b);
    NativeVectorT ModMul(const NativeVectorT& b) const;
    NativeVectorT& ModMulEq(const NativeVectorT& b);

    // In-place entry-wise modular multiplication using ModMulFastEq. b's
    // length and modulus are NOT checked. With NATIVEINT_BARRET_MOD defined,
    // the constant from ComputeMu() is computed once and reused per entry.
    NativeVectorT& ModMulNoCheckEq(const NativeVectorT& b) {
        size_t size{m_data.size()};
        auto mv{m_modulus};
#ifdef NATIVEINT_BARRET_MOD
        auto mu{m_modulus.ComputeMu()};
        for (size_t i = 0; i < size; ++i)
            m_data[i].ModMulFastEq(b[i], mv, mu);
#else
        for (size_t i = 0; i < size; ++i)
            m_data[i].ModMulFastEq(b[i], mv);
#endif
        return *this;
    }

    NativeVectorT MultWithOutMod(const NativeVectorT& b) const;
    NativeVectorT ModExp(const IntegerType& b) const;
    NativeVectorT& ModExpEq(const IntegerType& b);

    // Returns a new vector whose entries are the modular inverses of this one's.
    NativeVectorT ModInverse() const {
        size_t size{m_data.size()};
        auto mv{m_modulus};
        NativeVectorT ans(size, mv);
        for (size_t i{0}; i < size; ++i)
            ans[i] = m_data[i].ModInverse(mv);
        return ans;
    }

    // Replaces every entry with its modular inverse.
    NativeVectorT& ModInverseEq() {
        size_t size{m_data.size()};
        auto mv{m_modulus};
        for (size_t i{0}; i < size; ++i)
            m_data[i] = m_data[i].ModInverse(mv);
        return *this;
    }

    NativeVectorT ModByTwo() const;
    NativeVectorT& ModByTwoEq();
    NativeVectorT MultiplyAndRound(const IntegerType& p, const IntegerType& q) const;
    NativeVectorT& MultiplyAndRoundEq(const IntegerType& p, const IntegerType& q);
    NativeVectorT DivideAndRound(const IntegerType& q) const;
    NativeVectorT& DivideAndRoundEq(const IntegerType& q);

    // OTHER FUNCTIONS
    NativeVectorT GetDigitAtIndexForBase(usint index, usint base) const;

    // STRINGS & STREAMS
    // Prints "[e0 e1 ... en] modulus: m". An empty vector prints
    // "[] modulus: m"; the bracket is now closed in that case too.
    template <class IntegerType_c>
    friend std::ostream& operator<<(std::ostream& os, const NativeVectorT<IntegerType_c>& ptr_obj) {
        const size_t len = ptr_obj.m_data.size();
        os << "[";
        for (size_t i = 0; i < len; ++i) {
            if (i != 0)
                os << " ";
            os << ptr_obj.m_data[i];
        }
        os << "]";
        os << " modulus: " << ptr_obj.m_modulus;
        return os;
    }

    // Binary archives: entry count, then the raw entry bytes, then the modulus.
    template <class Archive>
    typename std::enable_if<!cereal::traits::is_text_archive<Archive>::value, void>::type save(
        Archive& ar, std::uint32_t const version) const {
        ::cereal::size_type size = m_data.size();
        ar(size);
        if (size > 0) {
            ar(::cereal::binary_data(m_data.data(), size * sizeof(IntegerType)));
        }
        ar(m_modulus);
    }

    // Text archives: named fields "v" (entries) and "m" (modulus).
    template <class Archive>
    typename std::enable_if<cereal::traits::is_text_archive<Archive>::value, void>::type save(
        Archive& ar, std::uint32_t const version) const {
        ar(::cereal::make_nvp("v", m_data));
        ar(::cereal::make_nvp("m", m_modulus));
    }

    // Binary archives: inverse of the binary save() above.
    template <class Archive>
    typename std::enable_if<!cereal::traits::is_text_archive<Archive>::value, void>::type load(
        Archive& ar, std::uint32_t const version) {
        if (version > SerializedVersion()) {
            OPENFHE_THROW("serialized object version " + std::to_string(version) +
                          " is from a later version of the library");
        }
        ::cereal::size_type size;
        ar(size);
        m_data.resize(size);
        if (size > 0) {
            // Read the raw entry bytes directly into the already-sized storage,
            // mirroring save(). This replaces an unchecked malloc'd scratch
            // buffer that leaked whenever the archive threw mid-read.
            ar(::cereal::binary_data(m_data.data(), size * sizeof(IntegerType)));
        }
        ar(m_modulus);
    }

    // Text archives: inverse of the text save() above.
    template <class Archive>
    typename std::enable_if<cereal::traits::is_text_archive<Archive>::value, void>::type load(
        Archive& ar, std::uint32_t const version) {
        if (version > SerializedVersion()) {
            OPENFHE_THROW("serialized object version " + std::to_string(version) +
                          " is from a later version of the library");
        }
        ar(::cereal::make_nvp("v", m_data));
        ar(::cereal::make_nvp("m", m_modulus));
    }

    std::string SerializedObjectName() const override {
        return "NativeVectorT";
    }

    static uint32_t SerializedVersion() {
        return 1;
    }
};
} // namespace intnat
namespace cereal {
template <class Archive, class A>
inline void CEREAL_SAVE_FUNCTION_NAME(Archive& ar, std::vector<intnat::NativeIntegerT<uint64_t>, A> const& vec) {
    // Element count first, so the loader can size its vector up front.
    ar(make_size_tag(static_cast<cereal::size_type>(vec.size())));
    // Then each element as its underlying 64-bit word.
    for (auto it = vec.cbegin(); it != vec.cend(); ++it)
        ar(it->ConvertToInt());
}
#if defined(HAVE_INT128)
// Saves a vector of 128-bit native integers: the element count, then each
// element as a pair of 64-bit words (least significant word first).
template <class Archive, class A>
inline void CEREAL_SAVE_FUNCTION_NAME(Archive& ar, std::vector<intnat::NativeIntegerT<uint128_t>, A> const& vec) {
    ar(make_size_tag(static_cast<cereal::size_type>(vec.size())));  // number of elements
    constexpr uint128_t mask = (static_cast<uint128_t>(1) << 64) - 1;
    for (const auto& v : vec) {
        // Named `words` rather than `vec` so it no longer shadows the parameter.
        uint64_t words[2];
        const uint128_t v128 = v.ConvertToInt();
        words[0] = static_cast<uint64_t>(v128 & mask);  // least significant word
        words[1] = static_cast<uint64_t>(v128 >> 64);   // most significant word
        ar(words);
    }
}
#endif
template <class Archive, class A>
inline void CEREAL_LOAD_FUNCTION_NAME(Archive& ar, std::vector<intnat::NativeIntegerT<uint64_t>, A>& vec) {
cereal::size_type size;
ar(make_size_tag(size));
vec.resize(static_cast<size_t>(size));
for (auto& v : vec) {
uint64_t b;
ar(b);
v = b;
}
}
#if defined(HAVE_INT128)
// Loads a vector of 128-bit native integers written by the matching save:
// the element count, then each element as two 64-bit words (least
// significant word first).
template <class Archive, class A>
inline void CEREAL_LOAD_FUNCTION_NAME(Archive& ar, std::vector<intnat::NativeIntegerT<uint128_t>, A>& vec) {
    cereal::size_type size;
    ar(make_size_tag(size));
    vec.resize(static_cast<size_t>(size));
    for (auto& v : vec) {
        // Named `words` rather than `vec` so it no longer shadows the parameter.
        uint64_t words[2];
        ar(words);
        v = words[1];  // most significant word
        v <<= 64;
        v += words[0];  // least significant word
    }
}
#endif
} // namespace cereal
#endif // LBCRYPTO_MATH_HAL_INTNAT_MUBINTVECNAT_H