Skip to content

Commit

Permalink
update jit lu decomp
Browse files Browse the repository at this point in the history
  • Loading branch information
mattldawson committed Dec 19, 2024
1 parent 9b9dc80 commit 5f42f05
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 10 deletions.
7 changes: 6 additions & 1 deletion include/micm/jit/solver/jit_lu_decomposition_doolittle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,17 @@ namespace micm

~JitLuDecompositionDoolittle();

/// @brief Create an LU decomposition algorithm for a given sparse matrix policy
/// @param matrix Sparse matrix
template<class SparseMatrixPolicy, class LMatrixPolicy = SparseMatrixPolicy, class UMatrixPolicy = SparseMatrixPolicy>
requires(SparseMatrixConcept<SparseMatrixPolicy>) static JitLuDecompositionDoolittle<L> Create(const SparseMatrixPolicy& matrix);

/// @brief Create sparse L and U matrices for a given A matrix
/// @param A Sparse matrix that will be decomposed
/// @param lower The lower triangular matrix created by decomposition
/// @param upper The upper triangular matrix created by decomposition
template<class SparseMatrixPolicy>
void Decompose(const SparseMatrixPolicy &A, SparseMatrixPolicy &lower, SparseMatrixPolicy &upper) const;
void Decompose(const SparseMatrixPolicy &A, auto& lower, auto& upper) const;

private:
/// @brief Generates a function to perform the LU decomposition for a specific matrix sparsity structure
Expand Down
20 changes: 16 additions & 4 deletions include/micm/jit/solver/jit_lu_decomposition_doolittle.inl
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,8 @@ namespace micm
template<std::size_t L>
inline JitLuDecompositionDoolittle<L>::JitLuDecompositionDoolittle(
const SparseMatrix<double, SparseMatrixVectorOrdering<L>> &matrix)
: LuDecompositionDoolittle(
LuDecompositionDoolittle::Create<SparseMatrix<double, SparseMatrixVectorOrdering<L>>>(matrix))
{
using SparseMatrixPolicy = SparseMatrix<double, SparseMatrixVectorOrdering<L>>;
decompose_function_ = NULL;
if (matrix.NumberOfBlocks() > L)
{
Expand All @@ -39,6 +38,7 @@ namespace micm
std::to_string(L);
throw std::system_error(make_error_code(MicmJitErrc::InvalidMatrix), msg);
}
Initialize<SparseMatrixPolicy, SparseMatrixPolicy, SparseMatrixPolicy>(matrix, typename SparseMatrixPolicy::value_type());
GenerateDecomposeFunction();
}

Expand All @@ -52,6 +52,18 @@ namespace micm
}
}

template<std::size_t L>
template<class SparseMatrixPolicy, class LMatrixPolicy, class UMatrixPolicy>
requires(SparseMatrixConcept<SparseMatrixPolicy>)
inline JitLuDecompositionDoolittle<L> JitLuDecompositionDoolittle<L>::Create(
const SparseMatrixPolicy& matrix)
{
static_assert(std::is_same_v<SparseMatrixPolicy, LMatrixPolicy>, "SparseMatrixPolicy must be the same as LMatrixPolicy for JIT LU decomposition");
static_assert(std::is_same_v<SparseMatrixPolicy, UMatrixPolicy>, "SparseMatrixPolicy must be the same as UMatrixPolicy for JIT LU decomposition");
JitLuDecompositionDoolittle<L> lu_decomp(matrix);
return lu_decomp;
}

template<std::size_t L>
void JitLuDecompositionDoolittle<L>::GenerateDecomposeFunction()
{
Expand Down Expand Up @@ -210,8 +222,8 @@ namespace micm
template<class SparseMatrixPolicy>
void JitLuDecompositionDoolittle<L>::Decompose(
const SparseMatrixPolicy &A,
SparseMatrixPolicy &lower,
SparseMatrixPolicy &upper) const
auto& lower,
auto& upper) const
{
decompose_function_(A.AsVector().data(), lower.AsVector().data(), upper.AsVector().data());
for (size_t block = 0; block < A.NumberOfBlocks(); ++block)
Expand Down
10 changes: 5 additions & 5 deletions include/micm/solver/lu_decomposition_doolittle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,20 +89,20 @@ namespace micm

/// @brief Construct an LU decomposition algorithm for a given sparse matrix
/// @param matrix Sparse matrix
template<class SparseMatrixPolicy, class LMatrixPolicy, class UMatrixPolicy>
template<class SparseMatrixPolicy, class LMatrixPolicy = SparseMatrixPolicy, class UMatrixPolicy = SparseMatrixPolicy>
requires(SparseMatrixConcept<SparseMatrixPolicy>) LuDecompositionDoolittle(const SparseMatrixPolicy& matrix);

~LuDecompositionDoolittle() = default;

/// @brief Create an LU decomposition algorithm for a given sparse matrix policy
/// @param matrix Sparse matrix
template<class SparseMatrixPolicy, class LMatrixPolicy, class UMatrixPolicy>
template<class SparseMatrixPolicy, class LMatrixPolicy = SparseMatrixPolicy, class UMatrixPolicy = SparseMatrixPolicy>
requires(SparseMatrixConcept<SparseMatrixPolicy>) static LuDecompositionDoolittle Create(const SparseMatrixPolicy& matrix);

/// @brief Create sparse L and U matrices for a given A matrix
/// @param A Sparse matrix that will be decomposed
/// @return L and U Sparse matrices
template<class SparseMatrixPolicy, class LMatrixPolicy, class UMatrixPolicy>
template<class SparseMatrixPolicy, class LMatrixPolicy = SparseMatrixPolicy, class UMatrixPolicy = SparseMatrixPolicy>
requires(SparseMatrixConcept<SparseMatrixPolicy>) static std::pair<LMatrixPolicy, UMatrixPolicy> GetLUMatrices(
const SparseMatrixPolicy& A,
typename SparseMatrixPolicy::value_type initial_value);
Expand All @@ -122,10 +122,10 @@ namespace micm
auto& L,
auto& U) const;

private:
protected:
/// @brief Initialize arrays for the LU decomposition
/// @param A Sparse matrix to decompose
template<class SparseMatrixPolicy, class LMatrixPolicy, class UMatrixPolicy>
template<class SparseMatrixPolicy, class LMatrixPolicy = SparseMatrixPolicy, class UMatrixPolicy = SparseMatrixPolicy>
requires(SparseMatrixConcept<SparseMatrixPolicy>) void Initialize(const SparseMatrixPolicy& matrix, auto initial_value);
};

Expand Down

0 comments on commit 5f42f05

Please sign in to comment.