Skip to content

Commit

Permalink
starting to make the jit compiler a singleton
Browse files Browse the repository at this point in the history
  • Loading branch information
K20shores committed May 23, 2024
1 parent 4d03171 commit 2ac73fa
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 43 deletions.
86 changes: 48 additions & 38 deletions include/micm/jit/jit_compiler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ inline std::error_code make_error_code(MicmJitErrc e)
namespace micm
{

// a singleton class
class JitCompiler
{
private:
Expand All @@ -99,23 +100,14 @@ namespace micm
llvm::orc::JITDylib &main_lib_;

public:
JitCompiler(
std::unique_ptr<llvm::orc::ExecutionSession> execution_session,
llvm::orc::JITTargetMachineBuilder machine_builder,
llvm::DataLayout data_layout)
: execution_session_(std::move(execution_session)),
data_layout_(std::move(data_layout)),
mangle_(*this->execution_session_, this->data_layout_),
object_layer_(*this->execution_session_, []() { return std::make_unique<llvm::SectionMemoryManager>(); }),
compile_layer_(
*this->execution_session_,
object_layer_,
std::make_unique<llvm::orc::ConcurrentIRCompiler>(std::move(machine_builder))),
optimize_layer_(*this->execution_session_, compile_layer_, OptimizeModule),
main_lib_(this->execution_session_->createBareJITDylib("<main>"))
// Delete the copy constructor and assignment operator
JitCompiler(const JitCompiler&) = delete;
JitCompiler& operator=(const JitCompiler&) = delete;

static JitCompiler& GetInstance()
{
main_lib_.addGenerator(
llvm::cantFail(llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(data_layout_.getGlobalPrefix())));
static JitCompiler instance = Create();
return instance;
}

~JitCompiler()
Expand All @@ -124,28 +116,6 @@ namespace micm
execution_session_->reportError(std::move(Err));
}

static llvm::Expected<std::shared_ptr<JitCompiler>> Create()
{
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
llvm::InitializeNativeTargetAsmParser();

auto EPC = llvm::orc::SelfExecutorProcessControl::Create();
if (!EPC)
return EPC.takeError();

auto execution_session = std::make_unique<llvm::orc::ExecutionSession>(std::move(*EPC));

llvm::orc::JITTargetMachineBuilder machine_builder(execution_session->getExecutorProcessControl().getTargetTriple());

auto data_layout = machine_builder.getDefaultDataLayoutForTarget();
if (!data_layout)
return data_layout.takeError();

return std::make_shared<JitCompiler>(
std::move(execution_session), std::move(machine_builder), std::move(*data_layout));
}

const llvm::DataLayout &GetDataLayout() const
{
return data_layout_;
Expand All @@ -171,6 +141,46 @@ namespace micm
}

private:
static llvm::Expected<JitCompiler> Create()
{
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
llvm::InitializeNativeTargetAsmParser();

auto EPC = llvm::orc::SelfExecutorProcessControl::Create();
if (!EPC)
return EPC.takeError();

auto execution_session = std::make_unique<llvm::orc::ExecutionSession>(std::move(*EPC));

llvm::orc::JITTargetMachineBuilder machine_builder(execution_session->getExecutorProcessControl().getTargetTriple());

auto data_layout = machine_builder.getDefaultDataLayoutForTarget();
if (!data_layout)
return data_layout.takeError();

return JitCompiler(std::move(execution_session), std::move(machine_builder), std::move(*data_layout));
}

JitCompiler(
std::unique_ptr<llvm::orc::ExecutionSession> execution_session,
llvm::orc::JITTargetMachineBuilder machine_builder,
llvm::DataLayout data_layout)
: execution_session_(std::move(execution_session)),
data_layout_(std::move(data_layout)),
mangle_(*this->execution_session_, this->data_layout_),
object_layer_(*this->execution_session_, []() { return std::make_unique<llvm::SectionMemoryManager>(); }),
compile_layer_(
*this->execution_session_,
object_layer_,
std::make_unique<llvm::orc::ConcurrentIRCompiler>(std::move(machine_builder))),
optimize_layer_(*this->execution_session_, compile_layer_, OptimizeModule),
main_lib_(this->execution_session_->createBareJITDylib("<main>"))
{
main_lib_.addGenerator(
llvm::cantFail(llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(data_layout_.getGlobalPrefix())));
}

static llvm::Expected<llvm::orc::ThreadSafeModule> OptimizeModule(
llvm::orc::ThreadSafeModule threadsafe_module,
const llvm::orc::MaterializationResponsibility &responsibility)
Expand Down
7 changes: 2 additions & 5 deletions include/micm/process/jit_process_set.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ namespace micm
template<std::size_t L = MICM_DEFAULT_VECTOR_SIZE>
class JitProcessSet : public ProcessSet
{
std::shared_ptr<JitCompiler> compiler_;
llvm::orc::ResourceTrackerSP forcing_function_resource_tracker_;
void (*forcing_function_)(const double *, const double *, double *);
llvm::orc::ResourceTrackerSP jacobian_function_resource_tracker_;
Expand Down Expand Up @@ -81,7 +80,6 @@ namespace micm
template<std::size_t L>
inline JitProcessSet<L>::JitProcessSet(JitProcessSet &&other)
: ProcessSet(std::move(other)),
compiler_(std::move(other.compiler_)),
forcing_function_resource_tracker_(std::move(other.forcing_function_resource_tracker_)),
forcing_function_(std::move(other.forcing_function_)),
jacobian_function_resource_tracker_(std::move(other.jacobian_function_resource_tracker_)),
Expand All @@ -95,7 +93,6 @@ namespace micm
inline JitProcessSet<L> &JitProcessSet<L>::operator=(JitProcessSet &&other)
{
ProcessSet::operator=(std::move(other));
compiler_ = std::move(other.compiler_);
forcing_function_resource_tracker_ = std::move(other.forcing_function_resource_tracker_);
forcing_function_ = std::move(other.forcing_function_);
jacobian_function_resource_tracker_ = std::move(other.jacobian_function_resource_tracker_);
Expand All @@ -110,8 +107,7 @@ namespace micm
std::shared_ptr<JitCompiler> compiler,
const std::vector<Process> &processes,
const std::map<std::string, std::size_t> &variable_map)
: ProcessSet(processes, variable_map),
compiler_(compiler)
: ProcessSet(processes, variable_map)
{
forcing_function_ = NULL;
jacobian_function_ = NULL;
Expand All @@ -122,6 +118,7 @@ namespace micm
void JitProcessSet<L>::GenerateForcingFunction()
{
std::string function_name = "add_forcing_terms_" + GenerateRandomString();
auto compiler_ = JitCompiler::GetInstance();
JitFunction func = JitFunction::Create(compiler_)
.SetName(function_name)
.SetArguments({ { "rate constants", JitType::DoublePtr },
Expand Down

0 comments on commit 2ac73fa

Please sign in to comment.