Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

514 create a builder class that can make a backward euler solver #527

Merged
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
b52a86d
starting to piece together the builder
K20shores May 13, 2024
8002c87
hiding some template details behind a solver impl
K20shores May 14, 2024
fe22e55
removing a lot of template template parameters
K20shores May 15, 2024
ae78706
removing more template templates, compiling with gcc
K20shores May 16, 2024
eb2c6d8
chaning function name and adding a type alias
K20shores May 16, 2024
42de6d9
adding state policy in progress
K20shores May 16, 2024
100b3d6
update for new state policies
mattldawson May 16, 2024
59fc99a
getting a solver builder to work
K20shores May 17, 2024
82a32be
it compiles with clang on my machine
K20shores May 17, 2024
af79b76
making the dockerfile build
K20shores May 17, 2024
c0782ef
making sure profiling option compiles
K20shores May 17, 2024
bdfcf91
attempting to get cuda to compile
K20shores May 20, 2024
4d03171
i think nvidia all compiles now
K20shores May 21, 2024
2ac73fa
starting to make the jit compiler a singleton
K20shores May 23, 2024
8497e9c
trying to pull my old changes
K20shores May 23, 2024
e15bdd3
trying to define jit compiler as a singleton
K20shores May 23, 2024
a4e7d79
saving progress just in case since I've made some
K20shores May 23, 2024
2fb3917
it compiles
K20shores May 23, 2024
c93c87f
removing more lambdas
K20shores May 23, 2024
15bd1f8
suppresing openmp checks
K20shores May 23, 2024
b678291
updating suppresion file
K20shores May 23, 2024
b4358b7
addressing PR comments
K20shores May 24, 2024
4557b58
correcting implementation
K20shores May 24, 2024
49d9e5f
merging main
K20shores May 24, 2024
63368c0
reverting readme example
K20shores May 24, 2024
a0aa13e
addressing PR comments
K20shores May 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion docker/Dockerfile.intel
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ RUN apt update \
ca-certificates \
cmake \
cmake-curses-gui \
curl \
libcurl4-openssl-dev \
libhdf5-dev \
m4 \
Expand Down
2 changes: 2 additions & 0 deletions docker/Dockerfile.llvm
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ RUN dnf -y update \
make \
zlib-devel \
llvm-devel \
openmpi-devel \
valgrind \
&& dnf clean all

Expand All @@ -24,6 +25,7 @@ RUN mkdir /build \
-D MICM_ENABLE_CLANG_TIDY:BOOL=FALSE \
-D MICM_ENABLE_LLVM:BOOL=TRUE \
-D MICM_ENABLE_MEMCHECK:BOOL=TRUE \
-D MICM_ENABLE_OPENMP:BOOL=TRUE \
../micm \
&& make install -j 8

Expand Down
21 changes: 3 additions & 18 deletions examples/profile_example.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
namespace fs = std::filesystem;
using namespace micm;

template<template<class> class MatrixType, template<class> class SparseMatrixType>
template<template<class> class MatrixType, class SparseMatrixType>
int Run(const char* filepath, const char* initial_conditions, const std::string& matrix_ordering_type)
{
using SolverType = RosenbrockSolver<MatrixType, SparseMatrixType>;
Expand Down Expand Up @@ -114,24 +114,9 @@ int Run(const char* filepath, const char* initial_conditions, const std::string&
return 0;
}

template<class T>
using SparseMatrixParam = micm::SparseMatrix<T>;
template<class T>
using Vector1MatrixParam = micm::VectorMatrix<T, 1>;
template<class T>
using Vector10MatrixParam = micm::VectorMatrix<T, 10>;
template<class T>
using Vector100MatrixParam = micm::VectorMatrix<T, 100>;
template<class T>
template<typename T>
using Vector1000MatrixParam = micm::VectorMatrix<T, 1000>;
template<class T>
using Vector1SparseMatrixParam = micm::SparseMatrix<T, micm::SparseMatrixVectorOrdering<1>>;
template<class T>
using Vector10SparseMatrixParam = micm::SparseMatrix<T, micm::SparseMatrixVectorOrdering<10>>;
template<class T>
using Vector100SparseMatrixParam = micm::SparseMatrix<T, micm::SparseMatrixVectorOrdering<100>>;
template<class T>
using Vector1000SparseMatrixParam = micm::SparseMatrix<T, micm::SparseMatrixVectorOrdering<1000>>;
using Vector1000SparseMatrixParam = micm::SparseMatrix<double, micm::SparseMatrixVectorOrdering<1000>>;

int main(const int argc, const char* argv[])
{
Expand Down
99 changes: 60 additions & 39 deletions include/micm/jit/jit_compiler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@
enum class MicmJitErrc
{
InvalidMatrix = MICM_JIT_ERROR_CODE_INVALID_MATRIX,
MissingJitFunction = MICM_JIT_ERROR_CODE_MISSING_JIT_FUNCTION
MissingJitFunction = MICM_JIT_ERROR_CODE_MISSING_JIT_FUNCTION,
FailedToBuild = MICM_JIT_ERROR_CODE_FAILED_TO_BUILD
};

namespace std
Expand Down Expand Up @@ -84,6 +85,7 @@ inline std::error_code make_error_code(MicmJitErrc e)
namespace micm
{

// a singleton class
class JitCompiler
{
private:
Expand All @@ -99,23 +101,23 @@ namespace micm
llvm::orc::JITDylib &main_lib_;

public:
JitCompiler(
std::unique_ptr<llvm::orc::ExecutionSession> execution_session,
llvm::orc::JITTargetMachineBuilder machine_builder,
llvm::DataLayout data_layout)
: execution_session_(std::move(execution_session)),
data_layout_(std::move(data_layout)),
mangle_(*this->execution_session_, this->data_layout_),
object_layer_(*this->execution_session_, []() { return std::make_unique<llvm::SectionMemoryManager>(); }),
compile_layer_(
*this->execution_session_,
object_layer_,
std::make_unique<llvm::orc::ConcurrentIRCompiler>(std::move(machine_builder))),
optimize_layer_(*this->execution_session_, compile_layer_, OptimizeModule),
main_lib_(this->execution_session_->createBareJITDylib("<main>"))
// Delete the copy constructor and assignment operator
JitCompiler(const JitCompiler &) = delete;
JitCompiler &operator=(const JitCompiler &) = delete;

static JitCompiler &GetInstance()
{
main_lib_.addGenerator(
llvm::cantFail(llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(data_layout_.getGlobalPrefix())));
static std::unique_ptr<JitCompiler> instance;
if (!instance)
{
auto expectedInstance = Create();
if (!expectedInstance)
{
throw std::system_error(make_error_code(MicmJitErrc::FailedToBuild));
}
instance = std::move(*expectedInstance);
}
return *instance;
}

~JitCompiler()
Expand All @@ -124,28 +126,6 @@ namespace micm
execution_session_->reportError(std::move(Err));
}

static llvm::Expected<std::shared_ptr<JitCompiler>> Create()
{
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
llvm::InitializeNativeTargetAsmParser();

auto EPC = llvm::orc::SelfExecutorProcessControl::Create();
if (!EPC)
return EPC.takeError();

auto execution_session = std::make_unique<llvm::orc::ExecutionSession>(std::move(*EPC));

llvm::orc::JITTargetMachineBuilder machine_builder(execution_session->getExecutorProcessControl().getTargetTriple());

auto data_layout = machine_builder.getDefaultDataLayoutForTarget();
if (!data_layout)
return data_layout.takeError();

return std::make_shared<JitCompiler>(
std::move(execution_session), std::move(machine_builder), std::move(*data_layout));
}

const llvm::DataLayout &GetDataLayout() const
{
return data_layout_;
Expand All @@ -171,6 +151,47 @@ namespace micm
}

private:
static llvm::Expected<std::unique_ptr<JitCompiler>> Create()
{
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
llvm::InitializeNativeTargetAsmParser();

auto EPC = llvm::orc::SelfExecutorProcessControl::Create();
if (!EPC)
return EPC.takeError();

auto execution_session = std::make_unique<llvm::orc::ExecutionSession>(std::move(*EPC));

llvm::orc::JITTargetMachineBuilder machine_builder(execution_session->getExecutorProcessControl().getTargetTriple());

auto data_layout = machine_builder.getDefaultDataLayoutForTarget();
if (!data_layout)
return data_layout.takeError();

return llvm::Expected<std::unique_ptr<JitCompiler>>(std::unique_ptr<JitCompiler>(
new JitCompiler(std::move(execution_session), std::move(machine_builder), std::move(*data_layout))));
}

JitCompiler(
std::unique_ptr<llvm::orc::ExecutionSession> execution_session,
llvm::orc::JITTargetMachineBuilder machine_builder,
llvm::DataLayout data_layout)
: execution_session_(std::move(execution_session)),
data_layout_(std::move(data_layout)),
mangle_(*this->execution_session_, this->data_layout_),
object_layer_(*this->execution_session_, []() { return std::make_unique<llvm::SectionMemoryManager>(); }),
compile_layer_(
*this->execution_session_,
object_layer_,
std::make_unique<llvm::orc::ConcurrentIRCompiler>(std::move(machine_builder))),
optimize_layer_(*this->execution_session_, compile_layer_, OptimizeModule),
main_lib_(this->execution_session_->createBareJITDylib("<main>"))
{
main_lib_.addGenerator(
llvm::cantFail(llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(data_layout_.getGlobalPrefix())));
}

static llvm::Expected<llvm::orc::ThreadSafeModule> OptimizeModule(
llvm::orc::ThreadSafeModule threadsafe_module,
const llvm::orc::MaterializationResponsibility &responsibility)
Expand Down
16 changes: 8 additions & 8 deletions include/micm/jit/jit_function.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ namespace micm
{
bool generated_ = false;
std::string name_;
std::shared_ptr<JitCompiler> compiler_;
JitCompiler* compiler_;
mattldawson marked this conversation as resolved.
Show resolved Hide resolved

public:
std::unique_ptr<llvm::LLVMContext> context_;
Expand All @@ -91,7 +91,7 @@ namespace micm
JitFunction() = delete;

friend class JitFunctionBuilder;
static JitFunctionBuilder Create(std::shared_ptr<JitCompiler> compiler);
static JitFunctionBuilder Create();
JitFunction(JitFunctionBuilder& function_builder);

/// @brief Generates the function
Expand Down Expand Up @@ -138,23 +138,23 @@ namespace micm

class JitFunctionBuilder
{
std::shared_ptr<JitCompiler> compiler_;
JitCompiler* compiler_;
std::string name_;
std::vector<std::pair<std::string, JitType>> arguments_;
JitType return_type_{ JitType::Void };
friend class JitFunction;

public:
JitFunctionBuilder() = delete;
JitFunctionBuilder(std::shared_ptr<JitCompiler> compiler);
JitFunctionBuilder(JitCompiler& compiler);
JitFunctionBuilder& SetName(const std::string& name);
JitFunctionBuilder& SetArguments(const std::vector<std::pair<std::string, JitType>>& arguments);
JitFunctionBuilder& SetReturnType(JitType type);
};

inline JitFunctionBuilder JitFunction::Create(std::shared_ptr<JitCompiler> compiler)
inline JitFunctionBuilder JitFunction::Create()
{
return JitFunctionBuilder{ compiler };
return JitFunctionBuilder{ JitCompiler::GetInstance() };
}

JitFunction::JitFunction(JitFunctionBuilder& function_builder)
Expand Down Expand Up @@ -296,8 +296,8 @@ namespace micm
return TmpB.CreateAlloca(type, 0, var_name.c_str());
}

inline JitFunctionBuilder::JitFunctionBuilder(std::shared_ptr<JitCompiler> compiler)
: compiler_(compiler){};
inline JitFunctionBuilder::JitFunctionBuilder(JitCompiler& compiler)
: compiler_(&compiler){};

inline JitFunctionBuilder& JitFunctionBuilder::SetName(const std::string& name)
{
Expand Down
42 changes: 21 additions & 21 deletions include/micm/process/cuda_process_set.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,17 @@ namespace micm
template<typename OrderingPolicy>
void SetJacobianFlatIds(const SparseMatrix<double, OrderingPolicy>& matrix);

template<template<class> typename MatrixPolicy>
requires(CudaMatrix<MatrixPolicy<double>>&& VectorizableDense<MatrixPolicy<double>>) void AddForcingTerms(
const MatrixPolicy<double>& rate_constants,
const MatrixPolicy<double>& state_variables,
MatrixPolicy<double>& forcing) const;

template<template<class> typename MatrixPolicy>
requires(!CudaMatrix<MatrixPolicy<double>>) void AddForcingTerms(
const MatrixPolicy<double>& rate_constants,
const MatrixPolicy<double>& state_variables,
MatrixPolicy<double>& forcing) const;
template<typename MatrixPolicy>
requires(CudaMatrix<MatrixPolicy>&& VectorizableDense<MatrixPolicy>) void AddForcingTerms(
const MatrixPolicy& rate_constants,
const MatrixPolicy& state_variables,
MatrixPolicy& forcing) const;

template<typename MatrixPolicy>
requires(!CudaMatrix<MatrixPolicy>) void AddForcingTerms(
const MatrixPolicy& rate_constants,
const MatrixPolicy& state_variables,
MatrixPolicy& forcing) const;

template<class MatrixPolicy, class SparseMatrixPolicy>
requires(
Expand Down Expand Up @@ -101,24 +101,24 @@ namespace micm
micm::cuda::CopyJacobiFlatId(hoststruct, this->devstruct_);
}

template<template<class> class MatrixPolicy>
requires(CudaMatrix<MatrixPolicy<double>>&& VectorizableDense<MatrixPolicy<double>>) inline void CudaProcessSet::
template<class MatrixPolicy>
requires(CudaMatrix<MatrixPolicy>&& VectorizableDense<MatrixPolicy>) inline void CudaProcessSet::
AddForcingTerms(
const MatrixPolicy<double>& rate_constants,
const MatrixPolicy<double>& state_variables,
MatrixPolicy<double>& forcing) const
const MatrixPolicy& rate_constants,
const MatrixPolicy& state_variables,
MatrixPolicy& forcing) const
{
auto forcing_param = forcing.AsDeviceParam(); // we need to update forcing so it can't be constant and must be an lvalue
micm::cuda::AddForcingTermsKernelDriver(
rate_constants.AsDeviceParam(), state_variables.AsDeviceParam(), forcing_param, this->devstruct_);
}

// call the function from the base class
template<template<class> class MatrixPolicy>
requires(!CudaMatrix<MatrixPolicy<double>>) inline void CudaProcessSet::AddForcingTerms(
const MatrixPolicy<double>& rate_constants,
const MatrixPolicy<double>& state_variables,
MatrixPolicy<double>& forcing) const
template<class MatrixPolicy>
requires(!CudaMatrix<MatrixPolicy>) inline void CudaProcessSet::AddForcingTerms(
const MatrixPolicy& rate_constants,
const MatrixPolicy& state_variables,
MatrixPolicy& forcing) const
{
AddForcingTerms(rate_constants, state_variables, forcing);
}
Expand Down
Loading
Loading