From 5f42f05340b50fb659da05a8532b34df3a6684e7 Mon Sep 17 00:00:00 2001 From: Matt Dawson Date: Thu, 19 Dec 2024 12:18:51 -0800 Subject: [PATCH] update jit lu decomp --- .../solver/jit_lu_decomposition_doolittle.hpp | 7 ++++++- .../solver/jit_lu_decomposition_doolittle.inl | 20 +++++++++++++++---- .../solver/lu_decomposition_doolittle.hpp | 10 +++++----- 3 files changed, 27 insertions(+), 10 deletions(-) diff --git a/include/micm/jit/solver/jit_lu_decomposition_doolittle.hpp b/include/micm/jit/solver/jit_lu_decomposition_doolittle.hpp index b6f445d99..3045b7789 100644 --- a/include/micm/jit/solver/jit_lu_decomposition_doolittle.hpp +++ b/include/micm/jit/solver/jit_lu_decomposition_doolittle.hpp @@ -35,12 +35,17 @@ namespace micm ~JitLuDecompositionDoolittle(); + /// @brief Create an LU decomposition algorithm for a given sparse matrix policy + /// @param matrix Sparse matrix + template + requires(SparseMatrixConcept) static JitLuDecompositionDoolittle Create(const SparseMatrixPolicy& matrix); + /// @brief Create sparse L and U matrices for a given A matrix /// @param A Sparse matrix that will be decomposed /// @param lower The lower triangular matrix created by decomposition /// @param upper The upper triangular matrix created by decomposition template - void Decompose(const SparseMatrixPolicy &A, SparseMatrixPolicy &lower, SparseMatrixPolicy &upper) const; + void Decompose(const SparseMatrixPolicy &A, auto& lower, auto& upper) const; private: /// @brief Generates a function to perform the LU decomposition for a specific matrix sparsity structure diff --git a/include/micm/jit/solver/jit_lu_decomposition_doolittle.inl b/include/micm/jit/solver/jit_lu_decomposition_doolittle.inl index 542d60848..95cc27590 100644 --- a/include/micm/jit/solver/jit_lu_decomposition_doolittle.inl +++ b/include/micm/jit/solver/jit_lu_decomposition_doolittle.inl @@ -26,9 +26,8 @@ namespace micm template inline JitLuDecompositionDoolittle::JitLuDecompositionDoolittle( const SparseMatrix> &matrix) - : LuDecompositionDoolittle( - LuDecompositionDoolittle::Create>>(matrix)) { + using SparseMatrixPolicy = SparseMatrix>; decompose_function_ = NULL; if (matrix.NumberOfBlocks() > L) { @@ -39,6 +38,7 @@ namespace micm std::to_string(L); throw std::system_error(make_error_code(MicmJitErrc::InvalidMatrix), msg); } + Initialize(matrix, typename SparseMatrixPolicy::value_type()); GenerateDecomposeFunction(); } @@ -52,6 +52,18 @@ namespace micm } } + template + template + requires(SparseMatrixConcept) + inline JitLuDecompositionDoolittle JitLuDecompositionDoolittle::Create( + const SparseMatrixPolicy& matrix) + { + static_assert(std::is_same_v, "SparseMatrixPolicy must be the same as LMatrixPolicy for JIT LU decomposition"); + static_assert(std::is_same_v, "SparseMatrixPolicy must be the same as UMatrixPolicy for JIT LU decomposition"); + JitLuDecompositionDoolittle lu_decomp(matrix); + return lu_decomp; + } + template void JitLuDecompositionDoolittle::GenerateDecomposeFunction() { @@ -210,8 +222,8 @@ namespace micm template void JitLuDecompositionDoolittle::Decompose( const SparseMatrixPolicy &A, - SparseMatrixPolicy &lower, - SparseMatrixPolicy &upper) const + auto& lower, + auto& upper) const { decompose_function_(A.AsVector().data(), lower.AsVector().data(), upper.AsVector().data()); for (size_t block = 0; block < A.NumberOfBlocks(); ++block) diff --git a/include/micm/solver/lu_decomposition_doolittle.hpp b/include/micm/solver/lu_decomposition_doolittle.hpp index d1fe311e3..ef4d98b45 100644 --- a/include/micm/solver/lu_decomposition_doolittle.hpp +++ b/include/micm/solver/lu_decomposition_doolittle.hpp @@ -89,20 +89,20 @@ namespace micm /// @brief Construct an LU decomposition algorithm for a given sparse matrix /// @param matrix Sparse matrix - template + template requires(SparseMatrixConcept) LuDecompositionDoolittle(const SparseMatrixPolicy& matrix); ~LuDecompositionDoolittle() = default; /// @brief Create an LU decomposition algorithm for a given sparse matrix policy /// @param matrix Sparse matrix - template + template requires(SparseMatrixConcept) static LuDecompositionDoolittle Create(const SparseMatrixPolicy& matrix); /// @brief Create sparse L and U matrices for a given A matrix /// @param A Sparse matrix that will be decomposed /// @return L and U Sparse matrices - template + template requires(SparseMatrixConcept) static std::pair GetLUMatrices( const SparseMatrixPolicy& A, typename SparseMatrixPolicy::value_type initial_value); @@ -122,10 +122,10 @@ namespace micm auto& L, auto& U) const; - private: + protected: /// @brief Initialize arrays for the LU decomposition /// @param A Sparse matrix to decompose - template + template requires(SparseMatrixConcept) void Initialize(const SparseMatrixPolicy& matrix, auto initial_value); };