Program Listing for File matrix.h
↰ Return to documentation for file (core/include/math/matrix.h)
//==================================================================================
// BSD 2-Clause License
//
// Copyright (c) 2014-2023, NJIT, Duality Technologies Inc. and other contributors
//
// All rights reserved.
//
// Author TPOC: contact@openfhe.org
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//==================================================================================
/*
This code provides a templated matrix implementation.
*/
#ifndef LBCRYPTO_MATH_MATRIX_H
#define LBCRYPTO_MATH_MATRIX_H
#include "lattice/lat-hal.h"
#include "math/distrgen.h"
#include "math/math-hal.h"
#include "math/nbtheory.h"
#include "utils/inttypes.h"
#include "utils/memory.h"
#include "utils/parallel.h"
#include "utils/serializable.h"
#include "utils/utilities.h"
#include <cmath>
#include <functional>
// #include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>
namespace lbcrypto {
// Forward declaration
class Field2n;
/**
 * @brief Dense, row-major matrix with an arbitrary element type.
 *
 * Entries are stored as a vector of rows (`data[row][col]`). New entries are
 * created through a user-supplied zero factory (`alloc_func`). The factory is
 * NOT serialized: after deserialization, call SetAllocator() before using any
 * operation that constructs new entries.
 */
template <class Element>
class Matrix : public Serializable {
public:
    /// Row-major storage: data[row][col].
    typedef std::vector<std::vector<Element>> data_t;
    /// Storage for a single row.
    typedef std::vector<Element> data_row_t;
    /// Factory returning a fresh "zero" element.
    typedef std::function<Element(void)> alloc_func;

    /**
     * @brief Creates a rows x cols matrix in which every entry is allocZero().
     * @param allocZero zero factory; kept for later allocations
     * @param rows number of rows
     * @param cols number of columns
     */
    Matrix(alloc_func allocZero, size_t rows, size_t cols) : data(), rows(rows), cols(cols), allocZero(allocZero) {
        data.resize(rows);
        for (auto row = data.begin(); row != data.end(); ++row) {
            row->reserve(cols);
            for (size_t col = 0; col < cols; ++col) {
                row->push_back(allocZero());
            }
        }
    }

    // TODO: add Clear();
    /**
     * @brief Creates a rows x cols matrix using a separate generator for its entries.
     * Defined out of line; presumably fills the entries with allocGen(). Confirm against the definition.
     */
    Matrix(alloc_func allocZero, size_t rows, size_t cols, alloc_func allocGen);

    /**
     * @brief Creates an empty 0 x 0 matrix.
     * @param allocZero zero factory. Defaults to an empty std::function; in that case
     *        SetAllocator() must be called before any entry is allocated, because
     *        invoking an empty std::function throws.
     */
    explicit Matrix(alloc_func allocZero = nullptr) : data(), rows(0), cols(0), allocZero(allocZero) {}

    /**
     * @brief Gives an empty matrix the shape rows x cols and fills it with allocZero().
     * @throws if the matrix already has a nonzero row or column count
     */
    void SetSize(size_t rows, size_t cols) {
        if (this->rows != 0 || this->cols != 0) {
            OPENFHE_THROW("You cannot SetSize on a non-empty matrix");
        }
        this->rows = rows;
        this->cols = cols;
        data.resize(rows);
        for (auto row = data.begin(); row != data.end(); ++row) {
            row->reserve(cols);
            for (size_t col = 0; col < cols; ++col) {
                row->push_back(allocZero());
            }
        }
    }

    /// Replaces the zero factory (required after deserialization).
    void SetAllocator(alloc_func allocZero) {
        this->allocZero = allocZero;
    }

    /// Deep copy: copies every entry and shares the zero factory.
    Matrix(const Matrix<Element>& other) : data(), rows(other.rows), cols(other.cols), allocZero(other.allocZero) {
        deepCopyData(other.data);
    }

    /// Copy assignment (defined out of line).
    Matrix<Element>& operator=(const Matrix<Element>& other);

    /// Sets every entry to 1 in place and returns *this.
    Matrix<Element>& Ones() {
        for (size_t row = 0; row < rows; ++row) {
            for (size_t col = 0; col < cols; ++col) {
                data[row][col] = 1;
            }
        }
        return *this;
    }

    /// In-place modular reduction by `modulus` (defined out of line).
    Matrix<Element>& ModEq(const Element& modulus);

    /// In-place modular subtraction of `b` with respect to `modulus` (defined out of line).
    Matrix<Element>& ModSubEq(Matrix<Element> const& b, const Element& modulus);

    /// Presumably sets every entry to `val` (defined out of line; confirm).
    Matrix<Element>& Fill(const Element& val);

    /// Sets entries with row == col to 1 and all others to 0. Works for non-square shapes too. Returns *this.
    Matrix<Element>& Identity() {
        for (size_t row = 0; row < rows; ++row) {
            for (size_t col = 0; col < cols; ++col) {
                if (row == col) {
                    data[row][col] = 1;
                }
                else {
                    data[row][col] = 0;
                }
            }
        }
        return *this;
    }

    /**
     * @brief Builds a block-diagonal gadget matrix with the same shape as *this.
     *
     * With k = cols / rows, entries (0, 0..k-1) of row 0 are base^0, base^1, ...,
     * base^(k-1). Row r holds a copy of those k values at columns r*k .. r*k+k-1.
     * Every other entry keeps the value produced by allocZero().
     *
     * @param base gadget base (default 2)
     * @throws if the matrix has no rows or no columns
     */
    template <typename T = Element,
              typename std::enable_if<!std::is_same<T, M2DCRTPoly>::value && !std::is_same<T, M4DCRTPoly>::value &&
                                          !std::is_same<T, M6DCRTPoly>::value,
                                      bool>::type = true>
    Matrix<T> GadgetVector(int64_t base = 2) const {
        // Guard: cols / rows and g(0, 0) below are undefined for an empty shape.
        if (rows == 0 || cols == 0) {
            OPENFHE_THROW("GadgetVector requires a matrix with at least one row and one column");
        }
        Matrix<T> g(allocZero, rows, cols);
        auto base_matrix = allocZero();
        size_t k = cols / rows;
        base_matrix = base;
        g(0, 0) = 1;
        for (size_t i = 1; i < k; i++) {
            g(0, i) = g(0, i - 1) * base_matrix;
        }
        for (size_t row = 1; row < rows; row++) {
            for (size_t i = 0; i < k; i++) {
                g(row, i + row * k) = g(0, i);
            }
        }
        return g;
    }

    /**
     * @brief Gadget matrix for M2/M4/M6 DCRT element types.
     *
     * digitCount = ceil(log2(q0) / log2(base)), where q0 is the modulus of the first
     * tower. For digit k and tower i, tower i of entry (0, k + i*digitCount) is set to
     * the constant base^k. The first cols/rows entries of row 0 are then replicated
     * block-diagonally into the remaining rows.
     *
     * @param base gadget base (default 2)
     * @throws if the matrix has no rows or no columns
     */
    template <typename T = Element,
              typename std::enable_if<std::is_same<T, M2DCRTPoly>::value || std::is_same<T, M4DCRTPoly>::value ||
                                          std::is_same<T, M6DCRTPoly>::value,
                                      bool>::type = true>
    Matrix<T> GadgetVector(int64_t base = 2) const {
        // Guard: g(0, 0) and cols / rows below are undefined for an empty shape.
        if (rows == 0 || cols == 0) {
            OPENFHE_THROW("GadgetVector requires a matrix with at least one row and one column");
        }
        Matrix<T> g(allocZero, rows, cols);
        auto base_matrix = allocZero();
        base_matrix = base;
        size_t bk = 1;
        auto params = g(0, 0).GetParams()->GetParams();
        uint64_t digitCount = (uint64_t)ceil(log2(params[0]->GetModulus().ConvertToDouble()) / log2(base));
        for (size_t k = 0; k < digitCount; k++) {
            for (size_t i = 0; i < params.size(); i++) {
                NativePoly temp(params[i]);
                temp = bk;
                g(0, k + i * digitCount).SetElementAtIndex(i, std::move(temp));
            }
            bk *= base;
        }
        size_t kCols = cols / rows;
        for (size_t row = 1; row < rows; row++) {
            for (size_t i = 0; i < kCols; i++) {
                g(row, i + row * kCols) = g(0, i);
            }
        }
        return g;
    }

    /// Norm is not defined for scalar and Field2n entries; always throws.
    template <typename T = Element,
              typename std::enable_if<std::is_same<T, double>::value || std::is_same<T, int>::value ||
                                          std::is_same<T, int64_t>::value || std::is_same<T, Field2n>::value,
                                      bool>::type = true>
    double Norm() const {
        OPENFHE_THROW("Norm not defined for this type");
    }

    /// Returns the maximum Element::Norm() over all entries (0.0 for an empty matrix).
    template <typename T = Element,
              typename std::enable_if<!std::is_same<T, double>::value && !std::is_same<T, int>::value &&
                                          !std::is_same<T, int64_t>::value && !std::is_same<T, Field2n>::value,
                                      bool>::type = true>
    double Norm() const {
        double retVal = 0.0;
        double locVal = 0.0;
        for (size_t row = 0; row < rows; ++row) {
            for (size_t col = 0; col < cols; ++col) {
                locVal = data[row][col].Norm();
                if (locVal > retVal) {
                    retVal = locVal;
                }
            }
        }
        return retVal;
    }

    /// Matrix product *this x other (defined out of line).
    Matrix<Element> Mult(Matrix<Element> const& other) const;

    /// Same as Mult(other).
    Matrix<Element> operator*(Matrix<Element> const& other) const {
        return Mult(other);
    }

    /**
     * @brief Returns a copy in which every entry is replaced by entry * other.
     * The column loop is annotated for OpenMP parallelization.
     */
    Matrix<Element> ScalarMult(Element const& other) const {
        Matrix<Element> result(*this);
#pragma omp parallel for
        for (size_t col = 0; col < result.cols; ++col) {
            for (size_t row = 0; row < result.rows; ++row) {
                result.data[row][col] = result.data[row][col] * other;
            }
        }
        return result;
    }

    /// Same as ScalarMult(other).
    Matrix<Element> operator*(Element const& other) const {
        return ScalarMult(other);
    }

    /// True when both shapes match and every pair of corresponding entries compares equal.
    bool Equal(Matrix<Element> const& other) const {
        if (rows != other.rows || cols != other.cols) {
            return false;
        }
        for (size_t i = 0; i < rows; ++i) {
            for (size_t j = 0; j < cols; ++j) {
                if (data[i][j] != other.data[i][j]) {
                    return false;
                }
            }
        }
        return true;
    }

    bool operator==(Matrix<Element> const& other) const {
        return Equal(other);
    }

    bool operator!=(Matrix<Element> const& other) const {
        return !Equal(other);
    }

    /// Read-only access to the row-major storage.
    const data_t& GetData() const {
        return data;
    }

    size_t GetRows() const {
        return rows;
    }

    size_t GetCols() const {
        return cols;
    }

    /// Returns the zero factory (may be empty after deserialization).
    alloc_func GetAllocator() const {
        return allocZero;
    }

    /// Sets the representation format of every entry (defined out of line).
    void SetFormat(Format format);

    /**
     * @brief Entry-wise sum.
     * @throws if the shapes differ
     */
    Matrix<Element> Add(Matrix<Element> const& other) const {
        if (rows != other.rows || cols != other.cols) {
            OPENFHE_THROW("Addition operands have incompatible dimensions");
        }
        Matrix<Element> result(*this);
#pragma omp parallel for
        for (size_t j = 0; j < cols; ++j) {
            for (size_t i = 0; i < rows; ++i) {
                result.data[i][j] += other.data[i][j];
            }
        }
        return result;
    }

    Matrix<Element> operator+(Matrix<Element> const& other) const {
        return this->Add(other);
    }

    Matrix<Element>& operator+=(Matrix<Element> const& other);

    /**
     * @brief Entry-wise difference *this - other.
     * Like Add, the result is built by copying *this, so a deserialized matrix
     * whose allocator has not been set can still be subtracted.
     * @throws if the shapes differ
     */
    Matrix<Element> Sub(Matrix<Element> const& other) const {
        if (rows != other.rows || cols != other.cols) {
            OPENFHE_THROW("Subtraction operands have incompatible dimensions");
        }
        Matrix<Element> result(*this);
#pragma omp parallel for
        for (size_t j = 0; j < cols; ++j) {
            for (size_t i = 0; i < rows; ++i) {
                result.data[i][j] = data[i][j] - other.data[i][j];
            }
        }
        return result;
    }

    Matrix<Element> operator-(Matrix<Element> const& other) const {
        return this->Sub(other);
    }

    Matrix<Element>& operator-=(Matrix<Element> const& other);

    /// Transposed copy (defined out of line).
    Matrix<Element> Transpose() const;

    // YSP The signature of this method needs to be changed in the future
    /// Writes the determinant into *result (defined out of line).
    void Determinant(Element* result) const;
    // Element Determinant() const;

    /// Cofactor matrix (defined out of line).
    Matrix<Element> CofactorMatrix() const;

    /// Stacks `other` below this matrix; returns *this (defined out of line; presumably in place).
    Matrix<Element>& VStack(Matrix<Element> const& other);

    /// Stacks `other` to the right of this matrix; returns *this (defined out of line; presumably in place).
    Matrix<Element>& HStack(Matrix<Element> const& other);

    /// Unchecked mutable access to entry (row, col).
    Element& operator()(size_t row, size_t col) {
        return data[row][col];
    }

    /// Unchecked read-only access to entry (row, col).
    Element const& operator()(size_t row, size_t col) const {
        return data[row][col];
    }

    /**
     * @brief Returns row `row` as a 1 x cols matrix.
     * @throws if row >= GetRows()
     */
    Matrix<Element> ExtractRow(size_t row) const {
        if (row >= rows) {
            OPENFHE_THROW("ExtractRow: row index out of range");
        }
        Matrix<Element> result(this->allocZero, 1, this->cols);
        for (size_t col = 0; col < cols; ++col) {
            result(0, col) = data[row][col];
        }
        return result;
    }

    /**
     * @brief Returns column `col` as a rows x 1 matrix.
     * @throws if col >= GetCols()
     */
    Matrix<Element> ExtractCol(size_t col) const {
        if (col >= cols) {
            OPENFHE_THROW("ExtractCol: column index out of range");
        }
        Matrix<Element> result(this->allocZero, this->rows, 1);
        for (size_t i = 0; i < this->rows; i++) {
            result(i, 0) = data[i][col];
        }
        return result;
    }

    /**
     * @brief Returns rows row_start..row_end (both inclusive) as a new matrix.
     * @throws if row_start > row_end or row_end >= GetRows()
     */
    inline Matrix<Element> ExtractRows(size_t row_start, size_t row_end) const {
        // Without this check, row_end - row_start + 1 underflows when row_end < row_start.
        if (row_start > row_end || row_end >= rows) {
            OPENFHE_THROW("ExtractRows: invalid row range");
        }
        Matrix<Element> result(this->allocZero, row_end - row_start + 1, this->cols);
        for (size_t row = row_start; row <= row_end; ++row) {
            for (size_t col = 0; col < cols; ++col) {
                result(row - row_start, col) = data[row][col];
            }
        }
        return result;
    }

    /// Prints the matrix as nested "[ ... ]" rows, one row per line.
    friend std::ostream& operator<<(std::ostream& os, const Matrix<Element>& m) {
        os << "[ ";
        for (size_t row = 0; row < m.GetRows(); ++row) {
            os << "[ ";
            for (size_t col = 0; col < m.GetCols(); ++col) {
                os << m(row, col) << " ";
            }
            os << "]\n";
        }
        os << " ]\n";
        return os;
    }

    /// Switches the representation format of every entry (defined out of line).
    void SwitchFormat();

// Defines a SwitchFormat specialization that throws, for element types that are not ring elements.
#define NOT_AN_ELEMENT_MATRIX(T) \
    template <> \
    void Matrix<T>::SwitchFormat() { \
        OPENFHE_THROW("Not a matrix of Elements"); \
    }

    /*
     * Multiply the matrix by a vector whose elements are all 1's. This causes
     * the elements of each row of the matrix to be added and placed into the
     * corresponding position in the output vector.
     */
    Matrix<Element> MultByUnityVector() const;

    /*
     * Multiply the matrix by a vector of random 1's and 0's, which is the same as
     * adding select elements in each row together. Return a vector that is a rows
     * x 1 matrix.
     */
    Matrix<Element> MultByRandomVector(std::vector<int> ranvec) const;

    /// Serializes the entries ("d"), the row count ("r") and the column count ("c"). The allocator is not saved.
    template <class Archive>
    void save(Archive& ar, std::uint32_t const version) const {
        ar(::cereal::make_nvp("d", data));
        ar(::cereal::make_nvp("r", rows));
        ar(::cereal::make_nvp("c", cols));
    }

    /**
     * @brief Restores the fields written by save().
     * @throws if the archive version is newer than SerializedVersion(), or if the
     *         loaded data does not have the loaded shape (this would break the
     *         data[row][col] invariant every accessor relies on)
     */
    template <class Archive>
    void load(Archive& ar, std::uint32_t const version) {
        if (version > SerializedVersion()) {
            OPENFHE_THROW("serialized object version " + std::to_string(version) +
                          " is from a later version of the library");
        }
        ar(::cereal::make_nvp("d", data));
        ar(::cereal::make_nvp("r", rows));
        ar(::cereal::make_nvp("c", cols));
        if (data.size() != rows) {
            OPENFHE_THROW("deserialized Matrix has " + std::to_string(data.size()) +
                          " rows of data but a row count of " + std::to_string(rows));
        }
        for (const auto& dataRow : data) {
            if (dataRow.size() != cols) {
                OPENFHE_THROW("deserialized Matrix has a row whose length does not match its column count");
            }
        }
        // users will need to SetAllocator for any newly deserialized matrix
    }

    std::string SerializedObjectName() const override {
        return "Matrix";
    }

    static uint32_t SerializedVersion() {
        return 1;
    }

private:
    data_t data;
    uint32_t rows;
    uint32_t cols;
    alloc_func allocZero;
    // mutable int NUM_THREADS = 1;

    // Deep copy of `src` into `data` (used by the copy constructor). Vector copy
    // assignment copies every element and stays correct even if `src` aliases `data`.
    void deepCopyData(data_t const& src) {
        data = src;
    }
};
// Scalar-on-the-left product e * M. Delegates to M.ScalarMult(e), so each
// entry of the result is M(i, j) * e (the scalar is the right-hand operand).
template <class Element>
Matrix<Element> operator*(Element const& e, Matrix<Element> const& M) {
    auto scaled = M.ScalarMult(e);
    return scaled;
}
// Declared here, defined elsewhere. Maps a matrix of Elements to a matrix of
// Element::Integer entries. From the name, presumably each element is expanded
// into its rotation matrix. NOTE(review): confirm against the definition.
template <typename Element>
Matrix<typename Element::Integer> Rotate(Matrix<Element> const& inMat);
// Like Rotate, but the result entries are Element::Vector (definition not visible; confirm layout).
template <typename Element>
Matrix<typename Element::Vector> RotateVecResult(Matrix<Element> const& inMat);
// Template stream-output declaration. Matrix also defines a non-template friend
// operator<< inside the class, which overload resolution prefers for Matrix<Element> arguments.
template <class Element>
std::ostream& operator<<(std::ostream& os, const Matrix<Element>& m);
// Declared here, defined elsewhere. From the name, presumably computes a Cholesky
// factor of `input` (which is then assumed symmetric positive-definite) as doubles.
// TODO confirm the input requirements and which triangle is returned.
Matrix<double> Cholesky(const Matrix<int32_t>& input);
// Out-parameter overload of Cholesky; writes its output into `result`.
void Cholesky(const Matrix<int32_t>& input, Matrix<double>& result);
// Converts big-integer entries to int32_t with respect to `modulus`. The exact mapping
// (e.g. centered vs. non-negative representatives) is not visible here. NOTE(review): confirm.
Matrix<int32_t> ConvertToInt32(const Matrix<BigInteger>& input, const BigInteger& modulus);
// Same conversion for a matrix whose entries are BigVector.
Matrix<int32_t> ConvertToInt32(const Matrix<BigVector>& input, const BigInteger& modulus);
// Packs a column of int64_t values into ring elements (specialized per type below).
template <typename Element>
Matrix<Element> SplitInt64IntoElements(Matrix<int64_t> const& other, size_t n,
                                       const std::shared_ptr<typename Element::Params> params);
// SPLIT64_FOR_TYPE(T) expands to the explicit specialization of
// SplitInt64IntoElements for element type T. Column 0 of `other` is read in
// consecutive groups of n values. Result entry (row, 0) is assigned the vector
// {other(row*n, 0), ..., other(row*n + n - 1, 0)}. Trailing values that do not
// fill a complete group are ignored. The result's zero allocator is
// T::Allocator(params, Format::COEFFICIENT). Throws if n == 0, which previously
// caused a division by zero. The macro body uses only block comments, because a
// line comment would swallow the following spliced continuation lines.
#define SPLIT64_FOR_TYPE(T) \
    template <> \
    Matrix<T> SplitInt64IntoElements(Matrix<int64_t> const& other, size_t n, \
                                     const std::shared_ptr<typename T::Params> params) { \
        if (n == 0) { \
            OPENFHE_THROW("SplitInt64IntoElements: n must be positive"); \
        } \
        auto zero_alloc = T::Allocator(params, Format::COEFFICIENT); \
        size_t rows = other.GetRows() / n; \
        Matrix<T> result(zero_alloc, rows, 1); \
        /* Reused buffer: every iteration overwrites all n slots. */ \
        std::vector<int64_t> values(n); \
        for (size_t row = 0; row < rows; ++row) { \
            for (size_t i = 0; i < n; ++i) \
                values[i] = other(row * n + i, 0); \
            result(row, 0) = values; \
        } \
        return result; \
    }
// Packs each row of an int32_t matrix into one ring element (specialized per type below).
template <typename Element>
Matrix<Element> SplitInt32AltIntoElements(Matrix<int32_t> const& other, size_t n,
                                          const std::shared_ptr<typename Element::Params> params);
// SPLIT32ALT_FOR_TYPE(T) expands to the explicit specialization of
// SplitInt32AltIntoElements for element type T. The result is a GetRows() x 1
// matrix. Entry (row, 0) is assigned the vector {other(row, 0), ..., other(row, n - 1)},
// i.e. the first n values of that row. The result's zero allocator is
// T::Allocator(params, Format::COEFFICIENT). Throws when `other` has rows and
// n exceeds its column count, which would otherwise read past the end of each row.
// The macro body uses only block comments (line comments break backslash continuations).
#define SPLIT32ALT_FOR_TYPE(T) \
    template <> \
    Matrix<T> SplitInt32AltIntoElements(Matrix<int32_t> const& other, size_t n, \
                                        const std::shared_ptr<typename T::Params> params) { \
        if (other.GetRows() > 0 && n > other.GetCols()) { \
            OPENFHE_THROW("SplitInt32AltIntoElements: n exceeds the number of columns"); \
        } \
        auto zero_alloc = T::Allocator(params, Format::COEFFICIENT); \
        size_t rows = other.GetRows(); \
        Matrix<T> result(zero_alloc, rows, 1); \
        /* Reused buffer: every iteration overwrites all n slots. */ \
        std::vector<int32_t> values(n); \
        for (size_t row = 0; row < rows; ++row) { \
            for (size_t i = 0; i < n; ++i) \
                values[i] = other(row, i); \
            result(row, 0) = values; \
        } \
        return result; \
    }
// Packs each row of an int64_t matrix into one ring element (specialized per type below).
template <typename Element>
Matrix<Element> SplitInt64AltIntoElements(Matrix<int64_t> const& other, size_t n,
                                          const std::shared_ptr<typename Element::Params> params);
// SPLIT64ALT_FOR_TYPE(T) expands to the explicit specialization of
// SplitInt64AltIntoElements for element type T. The result is a GetRows() x 1
// matrix. Entry (row, 0) is assigned the vector {other(row, 0), ..., other(row, n - 1)},
// i.e. the first n values of that row. The result's zero allocator is
// T::Allocator(params, Format::COEFFICIENT). Throws when `other` has rows and
// n exceeds its column count, which would otherwise read past the end of each row.
// The macro body uses only block comments (line comments break backslash continuations).
#define SPLIT64ALT_FOR_TYPE(T) \
    template <> \
    Matrix<T> SplitInt64AltIntoElements(Matrix<int64_t> const& other, size_t n, \
                                        const std::shared_ptr<typename T::Params> params) { \
        if (other.GetRows() > 0 && n > other.GetCols()) { \
            OPENFHE_THROW("SplitInt64AltIntoElements: n exceeds the number of columns"); \
        } \
        auto zero_alloc = T::Allocator(params, Format::COEFFICIENT); \
        size_t rows = other.GetRows(); \
        Matrix<T> result(zero_alloc, rows, 1); \
        /* Reused buffer: every iteration overwrites all n slots. */ \
        std::vector<int64_t> values(n); \
        for (size_t row = 0; row < rows; ++row) { \
            for (size_t i = 0; i < n; ++i) \
                values[i] = other(row, i); \
            result(row, 0) = values; \
        } \
        return result; \
    }
} // namespace lbcrypto
#endif // LBCRYPTO_MATH_MATRIX_H