Skip to content

Commit

Permalink
Emit CVODE callback only on demand. (#1551)
Browse files Browse the repository at this point in the history
Use `codegen --cvode` to enable.
  • Loading branch information
1uc authored Nov 7, 2024
1 parent 80d7b8f commit 25da002
Show file tree
Hide file tree
Showing 9 changed files with 41 additions and 10 deletions.
2 changes: 1 addition & 1 deletion src/codegen/codegen_cpp_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1465,7 +1465,7 @@ std::vector<IndexVariableInfo> CodegenCppVisitor::get_int_variables() {
void CodegenCppVisitor::setup(const Program& node) {
program_symtab = node.get_symbol_table();

CodegenHelperVisitor v;
CodegenHelperVisitor v(enable_cvode);
info = v.analyze(node);
info.mod_file = mod_filename;

Expand Down
17 changes: 15 additions & 2 deletions src/codegen/codegen_cpp_visitor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,20 @@ class CodegenCppVisitor: public visitor::ConstAstVisitor {
, float_type(std::move(float_type))
, optimize_ionvar_copies(optimize_ionvar_copies) {}

CodegenCppVisitor(std::string mod_filename,
std::ostream& stream,
std::string float_type,
const bool optimize_ionvar_copies,
const bool enable_cvode,
std::unique_ptr<nmodl::utils::Blame> blame = nullptr)
: printer(std::make_unique<CodePrinter>(stream, std::move(blame)))
, mod_filename(std::move(mod_filename))
, float_type(std::move(float_type))
, optimize_ionvar_copies(optimize_ionvar_copies)
, enable_cvode(enable_cvode) {}

private:
bool enable_cvode = false;

protected:
using SymbolType = std::shared_ptr<symtab::Symbol>;
Expand Down Expand Up @@ -1481,11 +1495,10 @@ class CodegenCppVisitor: public visitor::ConstAstVisitor {
std::string compute_method_name(BlockType type) const;

// Automatically called as part of `visit_program`.
void setup(const ast::Program& node);
virtual void setup(const ast::Program& node);

public:


/**
* Main and only member function to call after creating an instance of this class.
* \param program the AST to translate to C++ code
Expand Down
2 changes: 1 addition & 1 deletion src/codegen/codegen_helper_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ void CodegenHelperVisitor::check_cvode_codegen(const ast::Program& node) {
// PROCEDURE with `after_cvode` method
if (solve_nodes.size() == 1 && (kinetic_or_derivative_nodes.size() || using_cvode)) {
logger->debug("Will emit code for CVODE");
info.emit_cvode = true;
info.emit_cvode = enable_cvode;
}
}

Expand Down
7 changes: 6 additions & 1 deletion src/codegen/codegen_helper_visitor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ class CodegenHelperVisitor: public visitor::ConstAstVisitor {
/// holds all codegen related information
codegen::CodegenInfo info;

/// Config variable for enabling/disabling CVODE, see `emit_cvode`.
bool enable_cvode;

/// if visiting net receive block
bool under_net_receive_block = false;

Expand All @@ -76,8 +79,10 @@ class CodegenHelperVisitor: public visitor::ConstAstVisitor {
void find_neuron_global_variables();
static void sort_with_mod2c_symbol_order(std::vector<SymbolType>& symbols);
void check_cvode_codegen(const ast::Program& node);

public:
CodegenHelperVisitor() = default;
CodegenHelperVisitor(bool enable_cvode = false)
: enable_cvode(enable_cvode) {}

/// run visitor and return information for code generation
codegen::CodegenInfo analyze(const ast::Program& node);
Expand Down
1 change: 1 addition & 0 deletions src/codegen/codegen_neuron_cpp_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2809,6 +2809,7 @@ static void rename_net_receive_arguments(const ast::NetReceiveBlock& net_receive
}
}


CodegenNeuronCppVisitor::ParamVector CodegenNeuronCppVisitor::net_receive_args() {
return {{"", "Point_process*", "", "_pnt"},
{"", "double*", "", "_args"},
Expand Down
11 changes: 9 additions & 2 deletions src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,9 @@ int run_nmodl(int argc, const char* argv[]) {
/// true if top level local variables to be converted to range
bool nmodl_local_to_range(false);

/// true if CVODE should be emitted
bool codegen_cvode(false);

/// true if localize variables even if verbatim block is used
bool localize_verbatim(false);

Expand Down Expand Up @@ -282,6 +285,9 @@ int run_nmodl(int argc, const char* argv[]) {
codegen_opt->add_flag("--opt-ionvar-copy",
optimize_ionvar_copies_codegen,
fmt::format("Optimize copies of ion variables ({})", optimize_ionvar_copies_codegen))->ignore_case();
codegen_opt->add_flag("--cvode",
codegen_cvode,
fmt::format("Print code for CVODE ({})", codegen_cvode))->ignore_case();

#if NMODL_ENABLE_BACKWARD
auto blame_opt = app.add_subcommand("blame", "Blame NMODL code that generated some code.");
Expand Down Expand Up @@ -352,7 +358,7 @@ int run_nmodl(int argc, const char* argv[]) {
}

/// use cnexp instead of after_cvode solve method
{
if (codegen_cvode) {
logger->info("Running CVode to cnexp visitor");
AfterCVodeToCnexpVisitor().visit_program(*ast);
ast_to_nmodl(*ast, filepath("after_cvode_to_cnexp"));
Expand Down Expand Up @@ -531,7 +537,7 @@ int run_nmodl(int argc, const char* argv[]) {
.api()
.initialize_interpreter();

if (neuron_code) {
if (neuron_code && codegen_cvode) {
logger->info("Running CVODE visitor");
CvodeVisitor().visit_program(*ast);
SymtabVisitor(update_symtab).visit_program(*ast);
Expand Down Expand Up @@ -631,6 +637,7 @@ int run_nmodl(int argc, const char* argv[]) {
output_stream,
data_type,
optimize_ionvar_copies_codegen,
codegen_cvode,
utils::make_blame(blame_line, blame_level));
visitor.visit_program(*ast);
}
Expand Down
3 changes: 2 additions & 1 deletion test/unit/codegen/codegen_helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ CodegenInfo run_codegen_helper_get_info(const std::string& text) {
SolveBlockVisitor{}.visit_program(*ast);
SymtabVisitor{true}.visit_program(*ast);

CodegenHelperVisitor v;
bool enable_cvode = true;
CodegenHelperVisitor v(enable_cvode);
const auto info = v.analyze(*ast);

return info;
Expand Down
6 changes: 5 additions & 1 deletion test/unit/codegen/codegen_neuron_cpp_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,12 @@ std::shared_ptr<CodegenNeuronCppVisitor> create_neuron_cpp_visitor(
SolveBlockVisitor().visit_program(*ast);
FunctionCallpathVisitor().visit_program(*ast);

bool optimize_ion_vars = false;
bool enable_cvode = true;

/// create C code generation visitor
auto cv = std::make_shared<CodegenNeuronCppVisitor>("_test", ss, "double", false);
auto cv = std::make_shared<CodegenNeuronCppVisitor>(
"_test", ss, "double", optimize_ion_vars, enable_cvode);
return cv;
}

Expand Down
2 changes: 1 addition & 1 deletion test/usecases/run_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,5 @@ run_tests nocmodl
# NRN + NMODL
echo "-- Running NRN+NMODL --------"
rm -r "${output_dir}" tmp || true
nrnivmodl -nmodl "${nmodl}"
nrnivmodl -nmodl "${nmodl}" -nmodlflags "codegen --cvode"
run_tests nmodl

0 comments on commit 25da002

Please sign in to comment.