Program Listing for File matrix.h

Return to documentation for file (core/include/math/matrix.h)

//==================================================================================
// BSD 2-Clause License
//
// Copyright (c) 2014-2023, NJIT, Duality Technologies Inc. and other contributors
//
// All rights reserved.
//
// Author TPOC: contact@openfhe.org
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
//    list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
//    this list of conditions and the following disclaimer in the documentation
//    and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//==================================================================================

/*
  This code provides a templated matrix implementation.
 */

#ifndef LBCRYPTO_MATH_MATRIX_H
#define LBCRYPTO_MATH_MATRIX_H

#include "lattice/lat-hal.h"

#include "math/distrgen.h"
#include "math/math-hal.h"
#include "math/nbtheory.h"

#include "utils/inttypes.h"
#include "utils/memory.h"
#include "utils/parallel.h"
#include "utils/serializable.h"
#include "utils/utilities.h"

#include <cmath>
#include <functional>
// #include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

namespace lbcrypto {

// Forward declaration
class Field2n;

/**
 * @brief Dense matrix of Element values, stored row by row as data[row][col].
 *
 * Entries created by the sizing constructor and by SetSize() come from the
 * allocator function (allocZero) held by the matrix. A matrix created with the
 * default constructor, or restored by load(), has no usable allocator until
 * SetAllocator() is called.
 */
template <class Element>
class Matrix : public Serializable {
public:
    typedef std::vector<std::vector<Element>> data_t;
    typedef std::vector<Element> data_row_t;
    typedef std::function<Element(void)> alloc_func;

    /**
     * @brief Creates a rows x cols matrix, calling the allocator once per entry.
     * @param allocZero allocator for entries (expected to return a zero element)
     * @param rows number of rows
     * @param cols number of columns
     */
    Matrix(alloc_func allocZero, size_t rows, size_t cols) : data(), rows(rows), cols(cols), allocZero(allocZero) {
        appendAllocatedEntries(rows, cols);
    }

    // TODO: add Clear();

    /**
     * @brief Sizing constructor that also takes a second allocator, allocGen.
     *
     * Defined out of line. NOTE(review): presumably allocGen produces the initial
     * entries while allocZero is kept as the stored allocator — confirm in the definition.
     */
    Matrix(alloc_func allocZero, size_t rows, size_t cols, alloc_func allocGen);

    /**
     * @brief Creates an empty 0 x 0 matrix.
     * @param allocZero allocator used later (e.g. by SetSize); may be left empty and
     *        supplied afterwards through SetAllocator()
     */
    explicit Matrix(alloc_func allocZero = nullptr) : data(), rows(0), cols(0), allocZero(allocZero) {}

    /**
     * @brief Gives a 0 x 0 matrix the dimensions rows x cols, calling the allocator once per entry.
     * @throws (via OPENFHE_THROW) if the matrix already has a non-zero row or column count
     */
    void SetSize(size_t rows, size_t cols) {
        if (this->rows != 0 || this->cols != 0) {
            OPENFHE_THROW("You cannot SetSize on a non-empty matrix");
        }

        this->rows = rows;
        this->cols = cols;
        appendAllocatedEntries(rows, cols);
    }

    /// Replaces the allocator used for entries created from now on.
    void SetAllocator(alloc_func allocZero) {
        this->allocZero = allocZero;
    }

    /// Copy constructor: copies the dimensions, the allocator and every entry.
    Matrix(const Matrix<Element>& other)
        : data(other.data), rows(other.rows), cols(other.cols), allocZero(other.allocZero) {}

    /// Copy assignment (defined out of line).
    Matrix<Element>& operator=(const Matrix<Element>& other);

    /// Sets every entry to 1 and returns *this.
    Matrix<Element>& Ones() {
        for (size_t row = 0; row < rows; ++row) {
            for (size_t col = 0; col < cols; ++col) {
                data[row][col] = 1;
            }
        }
        return *this;
    }

    /// Defined out of line; presumably reduces every entry modulo `modulus` in place.
    Matrix<Element>& ModEq(const Element& modulus);

    /// Defined out of line; presumably subtracts `b` entry-wise modulo `modulus` in place.
    Matrix<Element>& ModSubEq(Matrix<Element> const& b, const Element& modulus);

    /// Defined out of line; presumably sets every entry to `val`.
    Matrix<Element>& Fill(const Element& val);

    /**
     * @brief Sets entries with row == col to 1 and every other entry to 0.
     *
     * Works on non-square matrices as well; returns *this.
     */
    Matrix<Element>& Identity() {
        for (size_t row = 0; row < rows; ++row) {
            for (size_t col = 0; col < cols; ++col) {
                if (row == col) {
                    data[row][col] = 1;
                }
                else {
                    data[row][col] = 0;
                }
            }
        }
        return *this;
    }

    /**
     * @brief Builds a gadget matrix with this matrix's allocator and dimensions.
     *
     * With k = cols / rows, row r holds 1, base, base^2, ..., base^(k-1) in
     * columns r*k .. r*k + k - 1. Every other entry is what the allocator produced.
     * An empty matrix (rows == 0 or cols == 0) yields an unfilled result of the same shape.
     *
     * @param base gadget base (default 2)
     */
    template <typename T                          = Element,
              typename std::enable_if<!std::is_same<T, M2DCRTPoly>::value && !std::is_same<T, M4DCRTPoly>::value &&
                                          !std::is_same<T, M6DCRTPoly>::value,
                                      bool>::type = true>
    Matrix<T> GadgetVector(int64_t base = 2) const {
        Matrix<T> g(allocZero, rows, cols);
        // Without this guard, cols / rows would divide by zero and g(0, 0) could be out of range.
        if (rows == 0 || cols == 0) {
            return g;
        }
        auto base_matrix = allocZero();
        base_matrix      = base;
        const size_t k   = cols / rows;
        g(0, 0)          = 1;
        for (size_t i = 1; i < k; i++) {
            g(0, i) = g(0, i - 1) * base_matrix;
        }
        // The remaining rows repeat row 0's powers, shifted k columns per row.
        for (size_t row = 1; row < rows; row++) {
            for (size_t i = 0; i < k; i++) {
                g(row, i + row * k) = g(0, i);
            }
        }
        return g;
    }

    /**
     * @brief Gadget matrix for the M2/M4/M6 DCRTPoly element types.
     *
     * digitCount = ceil(log2(q0) / log2(base)), where q0 is the modulus of the first
     * tower. In row 0, column k + i*digitCount gets tower i set to base^k, for every
     * digit k and tower i. Rows 1.. repeat row 0's first cols/rows entries, shifted
     * cols/rows columns per row.
     * An empty matrix (rows == 0 or cols == 0) yields an unfilled result of the same shape.
     *
     * @param base gadget base (default 2)
     * @throws (via OPENFHE_THROW) if base < 2
     */
    template <typename T                          = Element,
              typename std::enable_if<std::is_same<T, M2DCRTPoly>::value || std::is_same<T, M4DCRTPoly>::value ||
                                          std::is_same<T, M6DCRTPoly>::value,
                                      bool>::type = true>
    Matrix<T> GadgetVector(int64_t base = 2) const {
        // digitCount divides by log2(base). For base < 2 the quotient is zero, NaN or infinite,
        // and converting that to an integer is undefined behavior.
        if (base < 2) {
            OPENFHE_THROW("GadgetVector base must be at least 2");
        }
        Matrix<T> g(allocZero, rows, cols);
        // Without this guard, g(0, 0) below would be out of range.
        if (rows == 0 || cols == 0) {
            return g;
        }

        const auto params = g(0, 0).GetParams()->GetParams();

        const auto digitCount = static_cast<uint64_t>(
            std::ceil(std::log2(params[0]->GetModulus().ConvertToDouble()) / std::log2(base)));

        size_t bk = 1;  // base^k for the current digit k
        for (size_t k = 0; k < digitCount; k++) {
            for (size_t i = 0; i < params.size(); i++) {
                NativePoly temp(params[i]);
                temp = bk;
                g(0, k + i * digitCount).SetElementAtIndex(i, std::move(temp));
            }
            bk *= base;
        }

        const size_t kCols = cols / rows;
        for (size_t row = 1; row < rows; row++) {
            for (size_t i = 0; i < kCols; i++) {
                g(row, i + row * kCols) = g(0, i);
            }
        }
        return g;
    }

    /// Not supported for double, int, int64_t or Field2n entries: always throws.
    template <typename T                          = Element,
              typename std::enable_if<std::is_same<T, double>::value || std::is_same<T, int>::value ||
                                          std::is_same<T, int64_t>::value || std::is_same<T, Field2n>::value,
                                      bool>::type = true>
    double Norm() const {
        OPENFHE_THROW("Norm not defined for this type");
    }

    /// Returns the largest entry Norm(); returns 0.0 for an empty matrix.
    template <typename T                          = Element,
              typename std::enable_if<!std::is_same<T, double>::value && !std::is_same<T, int>::value &&
                                          !std::is_same<T, int64_t>::value && !std::is_same<T, Field2n>::value,
                                      bool>::type = true>
    double Norm() const {
        double retVal = 0.0;
        for (size_t row = 0; row < rows; ++row) {
            for (size_t col = 0; col < cols; ++col) {
                const double locVal = data[row][col].Norm();
                if (locVal > retVal) {
                    retVal = locVal;
                }
            }
        }
        return retVal;
    }

    /// Defined out of line; presumably the matrix product *this x other.
    Matrix<Element> Mult(Matrix<Element> const& other) const;

    /// Same as Mult(other).
    Matrix<Element> operator*(Matrix<Element> const& other) const {
        return Mult(other);
    }

    /// Returns a copy of this matrix with each entry multiplied on the right by `other`.
    Matrix<Element> ScalarMult(Element const& other) const {
        Matrix<Element> result(*this);
#pragma omp parallel for
        for (size_t col = 0; col < result.cols; ++col) {
            for (size_t row = 0; row < result.rows; ++row) {
                result.data[row][col] = result.data[row][col] * other;
            }
        }

        return result;
    }

    /// Same as ScalarMult(other).
    Matrix<Element> operator*(Element const& other) const {
        return ScalarMult(other);
    }

    /// True when the dimensions match and no pair of corresponding entries compares unequal.
    bool Equal(Matrix<Element> const& other) const {
        if (rows != other.rows || cols != other.cols) {
            return false;
        }

        for (size_t i = 0; i < rows; ++i) {
            for (size_t j = 0; j < cols; ++j) {
                if (data[i][j] != other.data[i][j]) {
                    return false;
                }
            }
        }
        return true;
    }

    bool operator==(Matrix<Element> const& other) const {
        return Equal(other);
    }

    bool operator!=(Matrix<Element> const& other) const {
        return !Equal(other);
    }

    /// Read-only view of the row-major storage.
    const data_t& GetData() const {
        return data;
    }

    size_t GetRows() const {
        return rows;
    }

    size_t GetCols() const {
        return cols;
    }

    /// Returns a copy of the stored allocator (empty if none was set).
    alloc_func GetAllocator() const {
        return allocZero;
    }

    /// Defined out of line; presumably sets the format of every entry.
    void SetFormat(Format format);

    /**
     * @brief Entry-wise sum, returned as a new matrix.
     * @throws (via OPENFHE_THROW) if the dimensions differ
     */
    Matrix<Element> Add(Matrix<Element> const& other) const {
        if (rows != other.rows || cols != other.cols) {
            OPENFHE_THROW("Addition operands have incompatible dimensions");
        }
        Matrix<Element> result(*this);
#pragma omp parallel for
        for (size_t j = 0; j < cols; ++j) {
            for (size_t i = 0; i < rows; ++i) {
                result.data[i][j] += other.data[i][j];
            }
        }
        return result;
    }

    Matrix<Element> operator+(Matrix<Element> const& other) const {
        return this->Add(other);
    }

    /// In-place entry-wise sum (defined out of line).
    Matrix<Element>& operator+=(Matrix<Element> const& other);

    /**
     * @brief Entry-wise difference (*this - other), returned as a new matrix.
     * @throws (via OPENFHE_THROW) if the dimensions differ
     */
    Matrix<Element> Sub(Matrix<Element> const& other) const {
        if (rows != other.rows || cols != other.cols) {
            OPENFHE_THROW("Subtraction operands have incompatible dimensions");
        }
        // Start from a copy of *this, as Add does, so no allocator call is needed.
        // This keeps Sub usable on deserialized matrices whose allocator is unset.
        Matrix<Element> result(*this);
#pragma omp parallel for
        for (size_t j = 0; j < cols; ++j) {
            for (size_t i = 0; i < rows; ++i) {
                result.data[i][j] = data[i][j] - other.data[i][j];
            }
        }

        return result;
    }

    Matrix<Element> operator-(Matrix<Element> const& other) const {
        return this->Sub(other);
    }

    /// In-place entry-wise difference (defined out of line).
    Matrix<Element>& operator-=(Matrix<Element> const& other);

    /// Defined out of line; presumably returns the transpose.
    Matrix<Element> Transpose() const;

    // YSP The signature of this method needs to be changed in the future
    void Determinant(Element* result) const;

    /// Defined out of line; presumably returns the cofactor matrix.
    Matrix<Element> CofactorMatrix() const;

    /// Defined out of line; presumably appends other's rows below this matrix.
    Matrix<Element>& VStack(Matrix<Element> const& other);

    /// Defined out of line; presumably appends other's columns to the right of this matrix.
    Matrix<Element>& HStack(Matrix<Element> const& other);

    /// Unchecked entry access.
    Element& operator()(size_t row, size_t col) {
        return data[row][col];
    }

    /// Unchecked read-only entry access.
    Element const& operator()(size_t row, size_t col) const {
        return data[row][col];
    }

    /**
     * @brief Returns row `row` as a 1 x cols matrix.
     * @throws (via OPENFHE_THROW) if row >= GetRows()
     */
    Matrix<Element> ExtractRow(size_t row) const {
        if (row >= rows) {
            OPENFHE_THROW("ExtractRow: row index out of range");
        }
        Matrix<Element> result(this->allocZero, 1, this->cols);
        size_t col = 0;
        for (const auto& elem : data[row]) {
            result(0, col++) = elem;
        }
        return result;
    }

    /**
     * @brief Returns column `col` as a rows x 1 matrix.
     * @throws (via OPENFHE_THROW) if col >= GetCols()
     */
    Matrix<Element> ExtractCol(size_t col) const {
        if (col >= cols) {
            OPENFHE_THROW("ExtractCol: column index out of range");
        }
        Matrix<Element> result(this->allocZero, this->rows, 1);
        for (size_t i = 0; i < this->rows; i++) {
            result(i, 0) = data[i][col];
        }
        return result;
    }

    /**
     * @brief Returns rows row_start..row_end (both inclusive) as a new matrix.
     * @throws (via OPENFHE_THROW) if row_start > row_end or row_end >= GetRows()
     */
    inline Matrix<Element> ExtractRows(size_t row_start, size_t row_end) const {
        // Without this check, row_end < row_start would wrap the row count around.
        if (row_start > row_end || row_end >= rows) {
            OPENFHE_THROW("ExtractRows: invalid row range");
        }
        Matrix<Element> result(this->allocZero, row_end - row_start + 1, this->cols);

        for (size_t row = row_start; row <= row_end; ++row) {
            size_t col = 0;
            for (const auto& elem : data[row]) {
                result(row - row_start, col++) = elem;
            }
        }

        return result;
    }

    /// Writes "[ " then each row as "[ e0 e1 ... ]" on its own line, then " ]".
    friend std::ostream& operator<<(std::ostream& os, const Matrix<Element>& m) {
        os << "[ ";
        for (size_t row = 0; row < m.GetRows(); ++row) {
            os << "[ ";
            for (size_t col = 0; col < m.GetCols(); ++col) {
                os << m(row, col) << " ";
            }
            os << "]\n";
        }
        os << " ]\n";
        return os;
    }

    /// Defined out of line; presumably switches the format of every entry.
    void SwitchFormat();
    // Expands to a non-inline explicit specialization of SwitchFormat that throws.
    // Expand it once, at namespace scope, in a single source file.
#define NOT_AN_ELEMENT_MATRIX(T)                   \
    template <>                                    \
    void Matrix<T>::SwitchFormat() {               \
        OPENFHE_THROW("Not a matrix of Elements"); \
    }

    /*
     * Multiply the matrix by a vector whose elements are all 1's.  This causes
     * the elements of each row of the matrix to be added and placed into the
     * corresponding position in the output vector.
     */
    Matrix<Element> MultByUnityVector() const;

    /*
     * Multiply the matrix by a vector of random 1's and 0's, which is the same as
     * adding select elements in each row together. Return a vector that is a rows
     * x 1 matrix.
     */
    Matrix<Element> MultByRandomVector(std::vector<int> ranvec) const;

    /// Serializes the entries ("d"), then the row count ("r") and column count ("c").
    template <class Archive>
    void save(Archive& ar, std::uint32_t const version) const {
        ar(::cereal::make_nvp("d", data));
        ar(::cereal::make_nvp("r", rows));
        ar(::cereal::make_nvp("c", cols));
    }

    /**
     * @brief Restores the entries and dimensions written by save().
     * @throws (via OPENFHE_THROW) if the archive version is newer than SerializedVersion(),
     *         or if the stored dimensions do not match the stored entries
     */
    template <class Archive>
    void load(Archive& ar, std::uint32_t const version) {
        if (version > SerializedVersion()) {
            OPENFHE_THROW("serialized object version " + std::to_string(version) +
                          " is from a later version of the library");
        }
        ar(::cereal::make_nvp("d", data));
        ar(::cereal::make_nvp("r", rows));
        ar(::cereal::make_nvp("c", cols));

        // operator() does no bounds checking, so reject archives whose r/c disagree with d.
        if (data.size() != rows) {
            OPENFHE_THROW("Deserialized matrix data does not match its row count");
        }
        for (const auto& row : data) {
            if (row.size() != cols) {
                OPENFHE_THROW("Deserialized matrix data does not match its column count");
            }
        }

        // users will need to SetAllocator for any newly deserialized matrix
    }

    std::string SerializedObjectName() const override {
        return "Matrix";
    }
    static uint32_t SerializedVersion() {
        return 1;
    }

private:
    data_t data;
    uint32_t rows;
    uint32_t cols;
    alloc_func allocZero;

    // Replaces data with a copy of src. Unlike a clear-then-copy loop, this also
    // works when src aliases data (self-assignment).
    void deepCopyData(data_t const& src) {
        data = src;
    }

    // Resizes data to numRows rows, then appends numCols allocZero() results to each row.
    void appendAllocatedEntries(size_t numRows, size_t numCols) {
        data.resize(numRows);
        for (auto& row : data) {
            row.reserve(numCols);
            for (size_t col = 0; col < numCols; ++col) {
                row.push_back(allocZero());
            }
        }
    }
};

/**
 * @brief Scalar-on-the-left product e * M.
 *
 * Evaluated through M.ScalarMult(e): every entry of M is multiplied on the right by e.
 */
template <class Element>
Matrix<Element> operator*(Element const& e, Matrix<Element> const& M) {
    Matrix<Element> scaled = M.ScalarMult(e);
    return scaled;
}

// The free functions below are only declared in this header; their definitions
// live elsewhere. The comments marked NOTE(review) are inferred from the signatures
// alone and should be confirmed against those definitions.

// NOTE(review): by its return type, this maps a matrix of Elements to a matrix of
// Element::Integer values — confirm the exact mapping in the definition.
template <typename Element>
Matrix<typename Element::Integer> Rotate(Matrix<Element> const& inMat);

// NOTE(review): counterpart of Rotate whose result entries are Element::Vector values.
template <typename Element>
Matrix<typename Element::Vector> RotateVecResult(Matrix<Element> const& inMat);

// Namespace-scope stream operator template. For Matrix arguments, overload resolution
// prefers the non-template friend operator<< defined inside class Matrix.
template <class Element>
std::ostream& operator<<(std::ostream& os, const Matrix<Element>& m);

// NOTE(review): presumably returns the Cholesky factor of `input` — confirm.
Matrix<double> Cholesky(const Matrix<int32_t>& input);

// Out-parameter variant of Cholesky: the result is written into `result`.
void Cholesky(const Matrix<int32_t>& input, Matrix<double>& result);

// NOTE(review): converts entries to int32_t relative to `modulus` — presumably a
// centered (signed) representation; confirm in the definition.
Matrix<int32_t> ConvertToInt32(const Matrix<BigInteger>& input, const BigInteger& modulus);

// Overload of ConvertToInt32 for matrices whose entries are BigVector values.
Matrix<int32_t> ConvertToInt32(const Matrix<BigVector>& input, const BigInteger& modulus);

/**
 * @brief Packs column 0 of an int64_t matrix into elements, n values per element.
 *
 * Primary template; concrete element types get their specialization by expanding
 * SPLIT64_FOR_TYPE(T) once, in a single source file.
 */
template <typename Element>
Matrix<Element> SplitInt64IntoElements(Matrix<int64_t> const& other, size_t n,
                                       const std::shared_ptr<typename Element::Params> params);

// Defines SplitInt64IntoElements<T>. The result is (other.GetRows() / n) x 1. Entry
// `row` is assigned the n consecutive values other(row*n .. row*n + n - 1, 0).
// Trailing rows that do not fill a whole group are ignored. Entries are created with
// T::Allocator(params, Format::COEFFICIENT). Throws if n == 0, which would otherwise
// divide by zero.
#define SPLIT64_FOR_TYPE(T)                                                              \
    template <>                                                                          \
    Matrix<T> SplitInt64IntoElements(Matrix<int64_t> const& other, size_t n,             \
                                     const std::shared_ptr<typename T::Params> params) { \
        if (n == 0) {                                                                    \
            OPENFHE_THROW("SplitInt64IntoElements: n must be positive");                 \
        }                                                                                \
        auto zero_alloc = T::Allocator(params, Format::COEFFICIENT);                     \
        size_t rows     = other.GetRows() / n;                                           \
        Matrix<T> result(zero_alloc, rows, 1);                                           \
        std::vector<int64_t> values(n);                                                  \
        for (size_t row = 0; row < rows; ++row) {                                        \
            for (size_t i = 0; i < n; ++i) {                                             \
                values[i] = other(row * n + i, 0);                                       \
            }                                                                            \
            result(row, 0) = values;                                                     \
        }                                                                                \
        return result;                                                                   \
    }

/**
 * @brief Packs each row of an int32_t matrix into one element (first n columns).
 *
 * Primary template; concrete element types get their specialization by expanding
 * SPLIT32ALT_FOR_TYPE(T) once, in a single source file.
 */
template <typename Element>
Matrix<Element> SplitInt32AltIntoElements(Matrix<int32_t> const& other, size_t n,
                                          const std::shared_ptr<typename Element::Params> params);

// Defines SplitInt32AltIntoElements<T>. The result is other.GetRows() x 1, and entry
// `row` is assigned other(row, 0 .. n-1). Entries are created with
// T::Allocator(params, Format::COEFFICIENT). Throws when a non-empty `other` has fewer
// than n columns, which would otherwise read out of range.
#define SPLIT32ALT_FOR_TYPE(T)                                                              \
    template <>                                                                             \
    Matrix<T> SplitInt32AltIntoElements(Matrix<int32_t> const& other, size_t n,             \
                                        const std::shared_ptr<typename T::Params> params) { \
        if (other.GetRows() > 0 && n > other.GetCols()) {                                   \
            OPENFHE_THROW("SplitInt32AltIntoElements: n exceeds the number of columns");    \
        }                                                                                   \
        auto zero_alloc = T::Allocator(params, Format::COEFFICIENT);                        \
        size_t rows     = other.GetRows();                                                  \
        Matrix<T> result(zero_alloc, rows, 1);                                              \
        std::vector<int32_t> values(n);                                                     \
        for (size_t row = 0; row < rows; ++row) {                                           \
            for (size_t i = 0; i < n; ++i) {                                                \
                values[i] = other(row, i);                                                  \
            }                                                                               \
            result(row, 0) = values;                                                        \
        }                                                                                   \
        return result;                                                                      \
    }

/**
 * @brief Packs each row of an int64_t matrix into one element (first n columns).
 *
 * Primary template; concrete element types get their specialization by expanding
 * SPLIT64ALT_FOR_TYPE(T) once, in a single source file.
 */
template <typename Element>
Matrix<Element> SplitInt64AltIntoElements(Matrix<int64_t> const& other, size_t n,
                                          const std::shared_ptr<typename Element::Params> params);

// Defines SplitInt64AltIntoElements<T>. The result is other.GetRows() x 1, and entry
// `row` is assigned other(row, 0 .. n-1). Entries are created with
// T::Allocator(params, Format::COEFFICIENT). Throws when a non-empty `other` has fewer
// than n columns, which would otherwise read out of range.
#define SPLIT64ALT_FOR_TYPE(T)                                                              \
    template <>                                                                             \
    Matrix<T> SplitInt64AltIntoElements(Matrix<int64_t> const& other, size_t n,             \
                                        const std::shared_ptr<typename T::Params> params) { \
        if (other.GetRows() > 0 && n > other.GetCols()) {                                   \
            OPENFHE_THROW("SplitInt64AltIntoElements: n exceeds the number of columns");    \
        }                                                                                   \
        auto zero_alloc = T::Allocator(params, Format::COEFFICIENT);                        \
        size_t rows     = other.GetRows();                                                  \
        Matrix<T> result(zero_alloc, rows, 1);                                              \
        std::vector<int64_t> values(n);                                                     \
        for (size_t row = 0; row < rows; ++row) {                                           \
            for (size_t i = 0; i < n; ++i) {                                                \
                values[i] = other(row, i);                                                  \
            }                                                                               \
            result(row, 0) = values;                                                        \
        }                                                                                   \
        return result;                                                                      \
    }

}  // namespace lbcrypto
#endif  // LBCRYPTO_MATH_MATRIX_H