Program Listing for File matrix-lattice-impl.h

Return to documentation for file (core/include/lattice/matrix-lattice-impl.h)

//==================================================================================
// BSD 2-Clause License
//
// Copyright (c) 2014-2023, NJIT, Duality Technologies Inc. and other contributors
//
// All rights reserved.
//
// Author TPOC: contact@openfhe.org
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
//    list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
//    this list of conditions and the following disclaimer in the documentation
//    and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//==================================================================================

/*
  matrix class implementations and type specific implementations
 */

#ifndef LBCRYPTO_INC_LATTICE_MATRIX_IMPL_H
#define LBCRYPTO_INC_LATTICE_MATRIX_IMPL_H

#include "math/matrix-impl.h"

#include "utils/parallel.h"

#include <memory>

// Implementations for matrices whose element types are defined in core
// and that require template specializations.

namespace lbcrypto {

/**
 * Expand a matrix of ring elements into a matrix of integers in which every
 * entry a(i, j) is replaced by its n x n negacyclic rotation block.
 *
 * Inside block (i, j), cell (r, c) holds coefficient (r - c) mod n of a(i, j).
 * Cells with c > r (strict upper-right triangle) are negated mod q, which
 * accounts for reduction modulo x^n + 1.
 *
 * n is taken from entry (0, 0)'s GetLength() and q from its GetModulus(). The
 * code therefore indexes (0, 0) unconditionally, so the input must be
 * non-empty. It also presumably expects every entry to share that length and
 * modulus (TODO confirm at call sites).
 *
 * @param inMat matrix of ring elements; it is copied, so the caller's matrix
 *              keeps its current format
 * @return (rows * n) x (cols * n) integer matrix
 */
template <typename Element>
Matrix<typename Element::Integer> Rotate(Matrix<Element> const& inMat) {
    // Work on a copy converted to coefficient representation.
    Matrix<Element> src(inMat);
    src.SetFormat(Format::COEFFICIENT);

    const size_t ringDim                = src(0, 0).GetLength();
    typename Element::Integer const& q = src(0, 0).GetModulus();

    Matrix<typename Element::Integer> out(Element::Integer::Allocator, src.GetRows() * ringDim,
                                          src.GetCols() * ringDim);

    for (size_t i = 0; i < src.GetRows(); ++i) {
        for (size_t j = 0; j < src.GetCols(); ++j) {
            Element& entry       = src(i, j);
            const size_t rowBase = i * ringDim;
            const size_t colBase = j * ringDim;
            for (size_t r = 0; r < ringDim; ++r) {
                for (size_t c = 0; c < ringDim; ++c) {
                    typename Element::Integer& cell = out(rowBase + r, colBase + c);
                    // (r + n - c) % n == (r - c) mod n without unsigned underflow.
                    cell = entry.GetValues().at((r + ringDim - c) % ringDim);
                    // Upper-right triangle wraps past x^n, so it picks up a factor of -1.
                    if (c > r)
                        cell = q.ModSub(cell, q);
                }
            }
        }
    }
    return out;
}

/**
 * Same expansion as Rotate(), except that every output cell is stored as a
 * length-1 vector modulo q instead of a bare integer.
 *
 * Inside block (i, j), cell (r, c) holds coefficient (r - c) mod n of entry
 * (i, j) in its component 0. Cells with c > r are negated mod q to account for
 * reduction modulo x^n + 1.
 *
 * n and q are read from entry (0, 0), so the input must be non-empty.
 *
 * @param inMat matrix of ring elements; it is copied, so the caller's matrix
 *              keeps its current format
 * @return (rows * n) x (cols * n) matrix of length-1 vectors
 */
template <typename Element>
Matrix<typename Element::Vector> RotateVecResult(Matrix<Element> const& inMat) {
    Matrix<Element> src(inMat);
    src.SetFormat(Format::COEFFICIENT);

    const size_t ringDim                = src(0, 0).GetLength();
    typename Element::Integer const& q = src(0, 0).GetModulus();
    // Length-1 vector mod q used as the minuend in (zero - x) to negate a cell.
    // This relies on Vector(1, q) being zero-initialized.
    typename Element::Vector zero(1, q);

    // Allocator for the output cells. q is captured by value, so the allocator
    // stays valid after src goes out of scope.
    auto oneSlotAllocator = [q]() {
        return typename Element::Vector(1, q);
    };
    Matrix<typename Element::Vector> out(oneSlotAllocator, src.GetRows() * ringDim, src.GetCols() * ringDim);

    for (size_t i = 0; i < src.GetRows(); ++i) {
        for (size_t j = 0; j < src.GetCols(); ++j) {
            Element& entry = src(i, j);
            for (size_t r = 0; r < ringDim; ++r) {
                for (size_t c = 0; c < ringDim; ++c) {
                    typename Element::Vector& cell = out(i * ringDim + r, j * ringDim + c);
                    cell.at(0)                     = entry.GetValues().at((r + ringDim - c) % ringDim);
                    // Upper-right triangle: negate (mod q) for x^n = -1.
                    if (c > r)
                        cell = zero.ModSub(cell);
                }
            }
        }
    }
    return out;
}

/**
 * Ensure the matrix is in the requested format.
 *
 * The format of element (0, 0) is treated as representative of the whole
 * matrix. If it differs from @p format, SwitchFormat() is called, which calls
 * SwitchFormat() on every element.
 *
 * @param format the desired format
 */
template <typename Element>
void Matrix<Element>::SetFormat(Format format) {
    // An empty matrix has nothing to convert, and data[0][0] would be an
    // out-of-bounds access.
    if (rows == 0 || cols == 0)
        return;
    if (data[0][0].GetFormat() != format)
        this->SwitchFormat();
}

/**
 * Call SwitchFormat() on every element, visiting the matrix row by row.
 *
 * Parallelization is currently disabled. The intended OpenMP split was over
 * columns for a single-row matrix and over rows otherwise:
 *   single row: #pragma omp parallel for num_threads(OpenFHEParallelControls.GetThreadLimit(cols))
 *   otherwise:  #pragma omp parallel for num_threads(OpenFHEParallelControls.GetThreadLimit(rows))
 * TODO: figure out why the single-row (column-parallel) variant causes a
 * segfault with GCC10.
 * With both pragmas off, one sequential traversal covers both shapes.
 */
template <typename Element>
void Matrix<Element>::SwitchFormat() {
    for (size_t r = 0; r < rows; ++r)
        for (size_t c = 0; c < cols; ++c)
            data[r][c].SwitchFormat();
}

/**
 * Convert a matrix over Z_q to int32_t using centered representatives in
 * [-q/2, q/2].
 *
 * Entries greater than floor(q/2) (computed as modulus / 2) are mapped to the
 * negative value -(q - entry). All other entries are kept as-is. Entries are
 * assumed to already lie in [0, q); TODO confirm at call sites.
 *
 * @param input   matrix of integers modulo @p modulus
 * @param modulus the modulus q
 * @return matrix of the same shape with centered int32_t entries. The narrowing
 *         to int32_t is not range-checked.
 */
template <typename T>
Matrix<int32_t> ConvertToInt32(const Matrix<T>& input, const T& modulus) {
    const size_t rows = input.GetRows();
    const size_t cols = input.GetCols();
    // The divisor is built as T rather than a hard-coded BigInteger, so the
    // template also works for integer types other than BigInteger.
    // For T = BigInteger this is the same value as before.
    const T negativeThreshold(modulus / T(2));
    Matrix<int32_t> result([]() { return 0; }, rows, cols);
    for (size_t i = 0; i < rows; ++i) {
        for (size_t j = 0; j < cols; ++j) {
            const T& value = input(i, j);
            if (value > negativeThreshold) {
                // Upper half of Z_q maps to negatives. The product is formed in
                // ConvertToInt()'s return type (presumably unsigned, so it
                // wraps) and the narrowing to int32_t yields -(q - value).
                result(i, j) = -1 * (modulus - value).ConvertToInt();
            }
            else {
                result(i, j) = value.ConvertToInt();
            }
        }
    }
    return result;
}

/**
 * Convert a matrix of vectors over Z_q to int32_t using centered
 * representatives in [-q/2, q/2].
 *
 * Only component 0 of each entry is read, which fits matrices of length-1
 * vectors such as the output of RotateVecResult(). Values greater than
 * floor(q/2) (computed as modulus / 2) become -(q - value); all other values
 * are kept as-is.
 *
 * @param input   matrix whose entries are vectors of V::Integer modulo @p modulus
 * @param modulus the modulus q
 * @return matrix of the same shape with centered int32_t entries. The narrowing
 *         to int32_t is not range-checked.
 */
template <typename V>
Matrix<int32_t> ConvertToInt32(const Matrix<V>& input, const typename V::Integer& modulus) {
    using IntType     = typename V::Integer;
    const size_t rows = input.GetRows();
    const size_t cols = input.GetCols();
    // Use V's own integer type for the divisor instead of a hard-coded
    // BigInteger, so the template also works for non-BigInteger vectors.
    const IntType negativeThreshold(modulus / IntType(2));
    Matrix<int32_t> result([]() { return 0; }, rows, cols);
    for (size_t i = 0; i < rows; ++i) {
        for (size_t j = 0; j < cols; ++j) {
            const IntType& elem = input(i, j).at(0);
            if (elem > negativeThreshold) {
                // Upper half of Z_q maps to negatives. The product wraps in
                // ConvertToInt()'s (presumably unsigned) return type, and the
                // narrowing to int32_t yields -(q - elem).
                result(i, j) = -1 * (modulus - elem).ConvertToInt();
            }
            else {
                result(i, j) = elem.ConvertToInt();
            }
        }
    }
    return result;
}

}  // namespace lbcrypto

#endif