diff --git a/src/codegen/codegen_naming.hpp b/src/codegen/codegen_naming.hpp index 2a0aa190f6..e8240b7df2 100644 --- a/src/codegen/codegen_naming.hpp +++ b/src/codegen/codegen_naming.hpp @@ -152,6 +152,9 @@ static constexpr char NRN_STATE_METHOD[] = "nrn_state"; /// nrn_cur method in generated code static constexpr char NRN_CUR_METHOD[] = "nrn_cur"; +/// nrn_jacob method in generated code +static constexpr char NRN_JACOB_METHOD[] = "nrn_jacob"; + /// nrn_watch_check method in generated c++ file static constexpr char NRN_WATCH_CHECK_METHOD[] = "nrn_watch_check"; @@ -164,6 +167,9 @@ static constexpr char THREAD_ARGS_PROTO[] = "_threadargsproto_"; /// prefix for ion variable static constexpr char ION_VARNAME_PREFIX[] = "ion_"; +/// hoc_nrnpointerindex name +static constexpr char NRN_POINTERINDEX[] = "hoc_nrnpointerindex"; + /// commonly used variables in verbatim block and how they /// should be mapped to new code generation backends diff --git a/src/codegen/codegen_neuron_cpp_visitor.cpp b/src/codegen/codegen_neuron_cpp_visitor.cpp index f4b3747ce0..5933c5a1d2 100644 --- a/src/codegen/codegen_neuron_cpp_visitor.cpp +++ b/src/codegen/codegen_neuron_cpp_visitor.cpp @@ -324,12 +324,19 @@ void CodegenNeuronCppVisitor::print_sdlists_init(bool print_initializers) { void CodegenNeuronCppVisitor::print_mechanism_global_var_structure(bool print_initializers) { /// TODO: Print only global variables printed in NEURON - printer->add_line(); + printer->add_newline(2); printer->add_line("/* NEURON global variables */"); if (info.primes_size != 0) { printer->fmt_line("static neuron::container::field_index _slist1[{0}], _dlist1[{0}];", info.primes_size); } + printer->add_line("static int mech_type;"); + + printer->fmt_line("static int {} = {};", + naming::NRN_POINTERINDEX, + info.pointer_variables.size() > 0 + ? static_cast(info.pointer_variables.size()) + : -1); } @@ -388,11 +395,29 @@ void CodegenNeuronCppVisitor::print_mechanism_register() { /// TODO: Write this according to NEURON printer->add_newline(2); printer->add_line("/** register channel with the simulator */"); - printer->fmt_push_block("void _{}_reg()", info.mod_file); + printer->fmt_push_block("extern \"C\" void _{}_reg()", info.mod_file); print_sdlists_init(true); + printer->add_newline(); + + const auto compute_functions_parameters = + breakpoint_exist() + ? fmt::format("{}, {}, {}", + nrn_cur_required() ? method_name(naming::NRN_CUR_METHOD) : "nullptr", + method_name(naming::NRN_JACOB_METHOD), + nrn_state_required() ? method_name(naming::NRN_STATE_METHOD) : "nullptr") + : "nullptr, nullptr, nullptr"; + const auto register_mech_args = fmt::format("{}, {}, {}, {}, {}, {}", + get_channel_info_var_name(), + method_name(naming::NRN_ALLOC_METHOD), + compute_functions_parameters, + method_name(naming::NRN_INIT_METHOD), + naming::NRN_POINTERINDEX, + 1 + info.thread_data_index); + printer->fmt_line("register_mech({});", register_mech_args); + // type related information printer->add_newline(); - printer->fmt_line("int mech_type = nrn_get_mechtype({}[1]);", get_channel_info_var_name()); + printer->fmt_line("mech_type = nrn_get_mechtype({}[1]);", get_channel_info_var_name()); // More things to add here printer->add_line("_nrn_mechanism_register_data_fields(mech_type,"); @@ -425,6 +450,7 @@ void CodegenNeuronCppVisitor::print_mechanism_register() { void CodegenNeuronCppVisitor::print_mechanism_range_var_structure(bool print_initializers) { + printer->add_newline(2); printer->add_line("/* NEURON RANGE variables macro definitions */"); for (auto i = 0; i < codegen_float_variables.size(); ++i) { const auto float_var = codegen_float_variables[i]; @@ -454,6 +480,34 @@ void CodegenNeuronCppVisitor::print_global_function_common_code(BlockType type, } +void CodegenNeuronCppVisitor::print_nrn_init(bool skip_init_check) { + codegen = true; + printer->add_newline(2); + printer->add_line("/** initialize channel */"); + + printer->fmt_line( + "static void {}(_nrn_model_sorted_token const& _sorted_token, NrnThread* _nt, Memb_list* " + "_ml_arg, int _type) {{}}", + method_name(naming::NRN_INIT_METHOD)); + + codegen = false; +} + + +void CodegenNeuronCppVisitor::print_nrn_jacob() { + codegen = true; + printer->add_newline(2); + printer->add_line("/** nrn_jacob function */"); + + printer->fmt_line( + "static void {}(_nrn_model_sorted_token const& _sorted_token, NrnThread* " + "_nt, Memb_list* _ml_arg, int _type) {{}}", + method_name(naming::NRN_JACOB_METHOD)); + + codegen = false; +} + + /// TODO: Edit for NEURON void CodegenNeuronCppVisitor::print_nrn_constructor() { return; @@ -468,7 +522,11 @@ void CodegenNeuronCppVisitor::print_nrn_destructor() { /// TODO: Print the equivalent of `nrn_alloc_` void CodegenNeuronCppVisitor::print_nrn_alloc() { - return; + printer->add_newline(2); + auto method = method_name(naming::NRN_ALLOC_METHOD); + printer->fmt_push_block("static void {}(Prop* _prop)", method); + printer->add_line("// do nothing"); + printer->pop_block(); } @@ -483,8 +541,13 @@ void CodegenNeuronCppVisitor::print_nrn_state() { return; } codegen = true; + printer->add_newline(2); + + printer->fmt_line( + "void {}(_nrn_model_sorted_token const& _sorted_token, NrnThread* _nt, Memb_list* " + "_ml_arg, int _type) {{}}", + method_name(naming::NRN_STATE_METHOD)); - printer->add_line("void nrn_state() {}"); /// TODO: Fill in codegen = false; @@ -533,8 +596,13 @@ void CodegenNeuronCppVisitor::print_nrn_cur() { } codegen = true; + printer->add_newline(2); + + printer->fmt_line( + "void {}(_nrn_model_sorted_token const& _sorted_token, NrnThread* _nt, Memb_list* _ml_arg, " + "int _type) {{}}", + method_name(naming::NRN_CUR_METHOD)); - printer->add_line("void nrn_cur() {}"); /// TODO: Fill in codegen = false; @@ -542,7 +610,7 @@ void CodegenNeuronCppVisitor::print_nrn_cur() { /****************************************************************************************/ -/* Main code printing entry points */ +/* Main code printing entry points */ /****************************************************************************************/ void CodegenNeuronCppVisitor::print_headers_include() { @@ -590,6 +658,7 @@ void CodegenNeuronCppVisitor::print_mechanism_variables_macros() { using _nrn_model_sorted_token = neuron::model_sorted_token; using _nrn_mechanism_cache_range = neuron::cache::MechanismRange; using _nrn_mechanism_cache_instance = neuron::cache::MechanismInstance; + using _nrn_non_owning_id_without_container = neuron::container::non_owning_identifier_without_container; template using _nrn_mechanism_field = neuron::mechanism::field; template @@ -641,8 +710,10 @@ void CodegenNeuronCppVisitor::print_g_unused() const { /// TODO: Edit for NEURON void CodegenNeuronCppVisitor::print_compute_functions() { + print_nrn_init(); print_nrn_cur(); print_nrn_state(); + print_nrn_jacob(); } @@ -657,6 +728,7 @@ void CodegenNeuronCppVisitor::print_codegen_routines() { print_prcellstate_macros(); print_mechanism_info(); print_data_structures(true); + print_nrn_alloc(); print_global_variables_for_hoc(); print_compute_functions(); // only nrn_cur and nrn_state print_mechanism_register(); diff --git a/src/codegen/codegen_neuron_cpp_visitor.hpp b/src/codegen/codegen_neuron_cpp_visitor.hpp index e8bf7e36af..374c0e817f 100644 --- a/src/codegen/codegen_neuron_cpp_visitor.hpp +++ b/src/codegen/codegen_neuron_cpp_visitor.hpp @@ -361,6 +361,13 @@ class CodegenNeuronCppVisitor: public CodegenCppVisitor { const std::string& function_name = "") override; + /** + * Print the \c nrn\_init function definition + * \param skip_init_check \c true to generate code executing the initialization conditionally + */ + void print_nrn_init(bool skip_init_check = true); + + /** * Print nrn_constructor function definition * @@ -382,6 +389,13 @@ class CodegenNeuronCppVisitor: public CodegenCppVisitor { void print_nrn_alloc() override; + /** + * Print nrn_jacob function definition + * + */ + void print_nrn_jacob(); + + /****************************************************************************************/ /* Print nrn_state routine */ /****************************************************************************************/ diff --git a/test/unit/codegen/codegen_neuron_cpp_visitor.cpp b/test/unit/codegen/codegen_neuron_cpp_visitor.cpp index d56141010b..20fc2c4cb8 100644 --- a/test/unit/codegen/codegen_neuron_cpp_visitor.cpp +++ b/test/unit/codegen/codegen_neuron_cpp_visitor.cpp @@ -53,7 +53,7 @@ std::string get_neuron_cpp_code(const std::string& nmodl_text, std::stringstream ss; auto cvisitor = create_neuron_cpp_visitor(ast, nmodl_text, ss); cvisitor->visit_program(*ast); - return reindent_text(ss.str()); + return ss.str(); } @@ -138,6 +138,7 @@ using _nrn_mechanism_std_vector = std::vector; using _nrn_model_sorted_token = neuron::model_sorted_token; using _nrn_mechanism_cache_range = neuron::cache::MechanismRange; using _nrn_mechanism_cache_instance = neuron::cache::MechanismInstance; +using _nrn_non_owning_id_without_container = neuron::container::non_owning_identifier_without_container; template using _nrn_mechanism_field = neuron::mechanism::field; template @@ -213,26 +214,30 @@ void _nrn_mechanism_register_data_fields(Args&&... args) { ContainsSubstring(reindent_and_trim_text(expected_hoc_global_variables))); } THEN("Placeholder nrn_cur function is printed") { - std::string expected_placeholder_nrn_cur = R"(void nrn_cur() {})"; + std::string expected_placeholder_nrn_cur = + R"(void nrn_cur_pas_test(_nrn_model_sorted_token const& _sorted_token, NrnThread* _nt, Memb_list* _ml_arg, int _type) {})"; REQUIRE_THAT(generated, ContainsSubstring(reindent_and_trim_text(expected_placeholder_nrn_cur))); } THEN("Placeholder nrn_state function is printed") { - std::string expected_placeholder_nrn_state = R"(void nrn_state() {})"; + std::string expected_placeholder_nrn_state = + R"(void nrn_state_pas_test(_nrn_model_sorted_token const& _sorted_token, NrnThread* _nt, Memb_list* _ml_arg, int _type) {})"; REQUIRE_THAT(generated, ContainsSubstring(reindent_and_trim_text(expected_placeholder_nrn_state))); } THEN("Placeholder registration function is printed") { - std::string expected_placeholder_reg = R"(/** register channel with the simulator */ - void __test_reg() { + std::string expected_placeholder_reg = R"CODE(/** register channel with the simulator */ + extern "C" void __test_reg() { /* s */ _slist1[0] = {4, 0}; /* Ds */ _dlist1[0] = {7, 0}; - int mech_type = nrn_get_mechtype(mechanism_info[1]); + register_mech(mechanism_info, nrn_alloc_pas_test, nrn_cur_pas_test, nrn_jacob_pas_test, nrn_state_pas_test, nrn_init_pas_test, hoc_nrnpointerindex, 1); + + mech_type = nrn_get_mechtype(mechanism_info[1]); _nrn_mechanism_register_data_fields(mech_type, _nrn_mechanism_field{"g"} /* 0 */, _nrn_mechanism_field{"e"} /* 1 */, @@ -246,7 +251,7 @@ void _nrn_mechanism_register_data_fields(Args&&... args) { _nrn_mechanism_field{"g_unused"} /* 9 */ ); - })"; + })CODE"; REQUIRE_THAT(generated, ContainsSubstring(reindent_and_trim_text(expected_placeholder_reg)));