diff --git a/src/codegen/codegen_cpp_visitor.cpp b/src/codegen/codegen_cpp_visitor.cpp index 76f20f335..c395bb7dc 100644 --- a/src/codegen/codegen_cpp_visitor.cpp +++ b/src/codegen/codegen_cpp_visitor.cpp @@ -1465,7 +1465,7 @@ std::vector CodegenCppVisitor::get_int_variables() { void CodegenCppVisitor::setup(const Program& node) { program_symtab = node.get_symbol_table(); - CodegenHelperVisitor v; + CodegenHelperVisitor v(enable_cvode); info = v.analyze(node); info.mod_file = mod_filename; diff --git a/src/codegen/codegen_cpp_visitor.hpp b/src/codegen/codegen_cpp_visitor.hpp index 5a958ad5b..1b2eda1e9 100644 --- a/src/codegen/codegen_cpp_visitor.hpp +++ b/src/codegen/codegen_cpp_visitor.hpp @@ -266,6 +266,20 @@ class CodegenCppVisitor: public visitor::ConstAstVisitor { , float_type(std::move(float_type)) , optimize_ionvar_copies(optimize_ionvar_copies) {} + CodegenCppVisitor(std::string mod_filename, + std::ostream& stream, + std::string float_type, + const bool optimize_ionvar_copies, + const bool enable_cvode, + std::unique_ptr blame = nullptr) + : printer(std::make_unique(stream, std::move(blame))) + , mod_filename(std::move(mod_filename)) + , float_type(std::move(float_type)) + , optimize_ionvar_copies(optimize_ionvar_copies) + , enable_cvode(enable_cvode) {} + + private: + bool enable_cvode = false; protected: using SymbolType = std::shared_ptr; @@ -1481,11 +1495,10 @@ class CodegenCppVisitor: public visitor::ConstAstVisitor { std::string compute_method_name(BlockType type) const; // Automatically called as part of `visit_program`. - void setup(const ast::Program& node); + virtual void setup(const ast::Program& node); public: - /** * Main and only member function to call after creating an instance of this class. * \param program the AST to translate to C++ code diff --git a/src/codegen/codegen_helper_visitor.cpp b/src/codegen/codegen_helper_visitor.cpp index 000966c9e..f2b5eb7ff 100644 --- a/src/codegen/codegen_helper_visitor.cpp +++ b/src/codegen/codegen_helper_visitor.cpp @@ -223,7 +223,7 @@ void CodegenHelperVisitor::check_cvode_codegen(const ast::Program& node) { // PROCEDURE with `after_cvode` method if (solve_nodes.size() == 1 && (kinetic_or_derivative_nodes.size() || using_cvode)) { logger->debug("Will emit code for CVODE"); - info.emit_cvode = true; + info.emit_cvode = enable_cvode; } } diff --git a/src/codegen/codegen_helper_visitor.hpp b/src/codegen/codegen_helper_visitor.hpp index 1e31fef46..20e5f88e1 100644 --- a/src/codegen/codegen_helper_visitor.hpp +++ b/src/codegen/codegen_helper_visitor.hpp @@ -51,6 +51,9 @@ class CodegenHelperVisitor: public visitor::ConstAstVisitor { /// holds all codegen related information codegen::CodegenInfo info; + /// Config variable for enabling/disabling CVODE, see `emit_cvode`. + bool enable_cvode; + /// if visiting net receive block bool under_net_receive_block = false; @@ -76,8 +79,10 @@ class CodegenHelperVisitor: public visitor::ConstAstVisitor { void find_neuron_global_variables(); static void sort_with_mod2c_symbol_order(std::vector& symbols); void check_cvode_codegen(const ast::Program& node); + public: - CodegenHelperVisitor() = default; + CodegenHelperVisitor(bool enable_cvode = false) + : enable_cvode(enable_cvode) {} /// run visitor and return information for code generation codegen::CodegenInfo analyze(const ast::Program& node); diff --git a/src/codegen/codegen_neuron_cpp_visitor.cpp b/src/codegen/codegen_neuron_cpp_visitor.cpp index 1113c7ad6..e4aec03b5 100644 --- a/src/codegen/codegen_neuron_cpp_visitor.cpp +++ b/src/codegen/codegen_neuron_cpp_visitor.cpp @@ -2809,6 +2809,7 @@ static void rename_net_receive_arguments(const ast::NetReceiveBlock& net_receive } } + CodegenNeuronCppVisitor::ParamVector CodegenNeuronCppVisitor::net_receive_args() { return {{"", "Point_process*", "", "_pnt"}, {"", "double*", "", "_args"}, diff --git a/src/main.cpp b/src/main.cpp index 393021e7b..865527068 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -117,6 +117,9 @@ int run_nmodl(int argc, const char* argv[]) { /// true if top level local variables to be converted to range bool nmodl_local_to_range(false); + /// true if CVODE should be emitted + bool codegen_cvode(false); + /// true if localize variables even if verbatim block is used bool localize_verbatim(false); @@ -282,6 +285,9 @@ int run_nmodl(int argc, const char* argv[]) { codegen_opt->add_flag("--opt-ionvar-copy", optimize_ionvar_copies_codegen, fmt::format("Optimize copies of ion variables ({})", optimize_ionvar_copies_codegen))->ignore_case(); + codegen_opt->add_flag("--cvode", + codegen_cvode, + fmt::format("Print code for CVODE ({})", codegen_cvode))->ignore_case(); #if NMODL_ENABLE_BACKWARD auto blame_opt = app.add_subcommand("blame", "Blame NMODL code that generated some code."); @@ -352,7 +358,7 @@ int run_nmodl(int argc, const char* argv[]) { } /// use cnexp instead of after_cvode solve method - { + if (codegen_cvode) { logger->info("Running CVode to cnexp visitor"); AfterCVodeToCnexpVisitor().visit_program(*ast); ast_to_nmodl(*ast, filepath("after_cvode_to_cnexp")); @@ -531,7 +537,7 @@ int run_nmodl(int argc, const char* argv[]) { .api() .initialize_interpreter(); - if (neuron_code) { + if (neuron_code && codegen_cvode) { logger->info("Running CVODE visitor"); CvodeVisitor().visit_program(*ast); SymtabVisitor(update_symtab).visit_program(*ast); @@ -631,6 +637,7 @@ int run_nmodl(int argc, const char* argv[]) { output_stream, data_type, optimize_ionvar_copies_codegen, + codegen_cvode, utils::make_blame(blame_line, blame_level)); visitor.visit_program(*ast); } diff --git a/test/unit/codegen/codegen_helper.cpp b/test/unit/codegen/codegen_helper.cpp index 5e5006f5d..1ea312061 100644 --- a/test/unit/codegen/codegen_helper.cpp +++ b/test/unit/codegen/codegen_helper.cpp @@ -69,7 +69,8 @@ CodegenInfo run_codegen_helper_get_info(const std::string& text) { SolveBlockVisitor{}.visit_program(*ast); SymtabVisitor{true}.visit_program(*ast); - CodegenHelperVisitor v; + bool enable_cvode = true; + CodegenHelperVisitor v(enable_cvode); const auto info = v.analyze(*ast); return info; diff --git a/test/unit/codegen/codegen_neuron_cpp_visitor.cpp b/test/unit/codegen/codegen_neuron_cpp_visitor.cpp index 860e23cae..463a20121 100644 --- a/test/unit/codegen/codegen_neuron_cpp_visitor.cpp +++ b/test/unit/codegen/codegen_neuron_cpp_visitor.cpp @@ -42,8 +42,12 @@ std::shared_ptr create_neuron_cpp_visitor( SolveBlockVisitor().visit_program(*ast); FunctionCallpathVisitor().visit_program(*ast); + bool optimize_ion_vars = false; + bool enable_cvode = true; + /// create C code generation visitor - auto cv = std::make_shared("_test", ss, "double", false); + auto cv = std::make_shared( + "_test", ss, "double", optimize_ion_vars, enable_cvode); return cv; } diff --git a/test/usecases/run_test.sh b/test/usecases/run_test.sh index 99dde6879..a90c916e9 100755 --- a/test/usecases/run_test.sh +++ b/test/usecases/run_test.sh @@ -38,5 +38,5 @@ run_tests nocmodl # NRN + NMODL echo "-- Running NRN+NMODL --------" rm -r "${output_dir}" tmp || true -nrnivmodl -nmodl "${nmodl}" +nrnivmodl -nmodl "${nmodl}" -nmodlflags "codegen --cvode" run_tests nmodl