Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix mechanism registration function for non POINT_PROCESSes #1111

Merged
merged 5 commits into from
Dec 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/codegen/codegen_naming.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,9 @@ static constexpr char NRN_STATE_METHOD[] = "nrn_state";
/// nrn_cur method in generated code
static constexpr char NRN_CUR_METHOD[] = "nrn_cur";

/// nrn_jacob method in generated code
static constexpr char NRN_JACOB_METHOD[] = "nrn_jacob";

/// nrn_watch_check method in generated c++ file
static constexpr char NRN_WATCH_CHECK_METHOD[] = "nrn_watch_check";

Expand All @@ -164,6 +167,9 @@ static constexpr char THREAD_ARGS_PROTO[] = "_threadargsproto_";
/// prefix for ion variable
static constexpr char ION_VARNAME_PREFIX[] = "ion_";

/// hoc_nrnpointerindex name
static constexpr char NRN_POINTERINDEX[] = "hoc_nrnpointerindex";


/// commonly used variables in verbatim block and how they
/// should be mapped to new code generation backends
Expand Down
86 changes: 79 additions & 7 deletions src/codegen/codegen_neuron_cpp_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -324,12 +324,19 @@ void CodegenNeuronCppVisitor::print_sdlists_init(bool print_initializers) {

void CodegenNeuronCppVisitor::print_mechanism_global_var_structure(bool print_initializers) {
/// TODO: Print only global variables printed in NEURON
printer->add_line();
printer->add_newline(2);
printer->add_line("/* NEURON global variables */");
if (info.primes_size != 0) {
printer->fmt_line("static neuron::container::field_index _slist1[{0}], _dlist1[{0}];",
info.primes_size);
}
printer->add_line("static int mech_type;");

printer->fmt_line("static int {} = {};",
naming::NRN_POINTERINDEX,
info.pointer_variables.size() > 0
? static_cast<int>(info.pointer_variables.size())
: -1);
}


Expand Down Expand Up @@ -388,11 +395,29 @@ void CodegenNeuronCppVisitor::print_mechanism_register() {
/// TODO: Write this according to NEURON
printer->add_newline(2);
printer->add_line("/** register channel with the simulator */");
printer->fmt_push_block("void _{}_reg()", info.mod_file);
printer->fmt_push_block("extern \"C\" void _{}_reg()", info.mod_file);
print_sdlists_init(true);
printer->add_newline();

const auto compute_functions_parameters =
breakpoint_exist()
? fmt::format("{}, {}, {}",
nrn_cur_required() ? method_name(naming::NRN_CUR_METHOD) : "nullptr",
method_name(naming::NRN_JACOB_METHOD),
nrn_state_required() ? method_name(naming::NRN_STATE_METHOD) : "nullptr")
: "nullptr, nullptr, nullptr";
const auto register_mech_args = fmt::format("{}, {}, {}, {}, {}, {}",
get_channel_info_var_name(),
method_name(naming::NRN_ALLOC_METHOD),
compute_functions_parameters,
method_name(naming::NRN_INIT_METHOD),
naming::NRN_POINTERINDEX,
1 + info.thread_data_index);
printer->fmt_line("register_mech({});", register_mech_args);

// type related information
printer->add_newline();
printer->fmt_line("int mech_type = nrn_get_mechtype({}[1]);", get_channel_info_var_name());
printer->fmt_line("mech_type = nrn_get_mechtype({}[1]);", get_channel_info_var_name());

// More things to add here
printer->add_line("_nrn_mechanism_register_data_fields(mech_type,");
Expand Down Expand Up @@ -425,6 +450,7 @@ void CodegenNeuronCppVisitor::print_mechanism_register() {


void CodegenNeuronCppVisitor::print_mechanism_range_var_structure(bool print_initializers) {
printer->add_newline(2);
printer->add_line("/* NEURON RANGE variables macro definitions */");
for (auto i = 0; i < codegen_float_variables.size(); ++i) {
const auto float_var = codegen_float_variables[i];
Expand Down Expand Up @@ -454,6 +480,34 @@ void CodegenNeuronCppVisitor::print_global_function_common_code(BlockType type,
}


void CodegenNeuronCppVisitor::print_nrn_init(bool skip_init_check) {
codegen = true;
printer->add_newline(2);
printer->add_line("/** initialize channel */");

printer->fmt_line(
"static void {}(_nrn_model_sorted_token const& _sorted_token, NrnThread* _nt, Memb_list* "
"_ml_arg, int _type) {{}}",
method_name(naming::NRN_INIT_METHOD));

codegen = false;
}


void CodegenNeuronCppVisitor::print_nrn_jacob() {
codegen = true;
printer->add_newline(2);
printer->add_line("/** nrn_jacob function */");

printer->fmt_line(
"static void {}(_nrn_model_sorted_token const& _sorted_token, NrnThread* "
"_nt, Memb_list* _ml_arg, int _type) {{}}",
method_name(naming::NRN_JACOB_METHOD));

codegen = false;
}


/// TODO: Edit for NEURON
void CodegenNeuronCppVisitor::print_nrn_constructor() {
return;
Expand All @@ -468,7 +522,11 @@ void CodegenNeuronCppVisitor::print_nrn_destructor() {

/// TODO: Print the equivalent of `nrn_alloc_<mech_name>`
void CodegenNeuronCppVisitor::print_nrn_alloc() {
return;
printer->add_newline(2);
auto method = method_name(naming::NRN_ALLOC_METHOD);
printer->fmt_push_block("static void {}(Prop* _prop)", method);
printer->add_line("// do nothing");
printer->pop_block();
}


Expand All @@ -483,8 +541,13 @@ void CodegenNeuronCppVisitor::print_nrn_state() {
return;
}
codegen = true;
printer->add_newline(2);

printer->fmt_line(
"void {}(_nrn_model_sorted_token const& _sorted_token, NrnThread* _nt, Memb_list* "
"_ml_arg, int _type) {{}}",
method_name(naming::NRN_STATE_METHOD));

printer->add_line("void nrn_state() {}");
/// TODO: Fill in

codegen = false;
Expand Down Expand Up @@ -533,16 +596,21 @@ void CodegenNeuronCppVisitor::print_nrn_cur() {
}

codegen = true;
printer->add_newline(2);

printer->fmt_line(
"void {}(_nrn_model_sorted_token const& _sorted_token, NrnThread* _nt, Memb_list* _ml_arg, "
"int _type) {{}}",
method_name(naming::NRN_CUR_METHOD));

printer->add_line("void nrn_cur() {}");
/// TODO: Fill in

codegen = false;
}


/****************************************************************************************/
/* Main code printing entry points */
/* Main code printing entry points */
/****************************************************************************************/

void CodegenNeuronCppVisitor::print_headers_include() {
Expand Down Expand Up @@ -590,6 +658,7 @@ void CodegenNeuronCppVisitor::print_mechanism_variables_macros() {
using _nrn_model_sorted_token = neuron::model_sorted_token;
using _nrn_mechanism_cache_range = neuron::cache::MechanismRange<number_of_floating_point_variables, number_of_datum_variables>;
using _nrn_mechanism_cache_instance = neuron::cache::MechanismInstance<number_of_floating_point_variables, number_of_datum_variables>;
using _nrn_non_owning_id_without_container = neuron::container::non_owning_identifier_without_container;
template <typename T>
using _nrn_mechanism_field = neuron::mechanism::field<T>;
template <typename... Args>
Expand Down Expand Up @@ -641,8 +710,10 @@ void CodegenNeuronCppVisitor::print_g_unused() const {

/// TODO: Edit for NEURON
void CodegenNeuronCppVisitor::print_compute_functions() {
print_nrn_init();
print_nrn_cur();
print_nrn_state();
print_nrn_jacob();
}


Expand All @@ -657,6 +728,7 @@ void CodegenNeuronCppVisitor::print_codegen_routines() {
print_prcellstate_macros();
print_mechanism_info();
print_data_structures(true);
print_nrn_alloc();
print_global_variables_for_hoc();
print_compute_functions(); // only nrn_cur and nrn_state
print_mechanism_register();
Expand Down
14 changes: 14 additions & 0 deletions src/codegen/codegen_neuron_cpp_visitor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,13 @@ class CodegenNeuronCppVisitor: public CodegenCppVisitor {
const std::string& function_name = "") override;


/**
* Print the \c nrn\_init function definition
* \param skip_init_check \c true to generate code executing the initialization conditionally
*/
void print_nrn_init(bool skip_init_check = true);


/**
* Print nrn_constructor function definition
*
Expand All @@ -382,6 +389,13 @@ class CodegenNeuronCppVisitor: public CodegenCppVisitor {
void print_nrn_alloc() override;


/**
* Print nrn_jacob function definition
*
*/
void print_nrn_jacob();


/****************************************************************************************/
/* Print nrn_state routine */
/****************************************************************************************/
Expand Down
19 changes: 12 additions & 7 deletions test/unit/codegen/codegen_neuron_cpp_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ std::string get_neuron_cpp_code(const std::string& nmodl_text,
std::stringstream ss;
auto cvisitor = create_neuron_cpp_visitor(ast, nmodl_text, ss);
cvisitor->visit_program(*ast);
return reindent_text(ss.str());
return ss.str();
}


Expand Down Expand Up @@ -138,6 +138,7 @@ using _nrn_mechanism_std_vector = std::vector<T>;
using _nrn_model_sorted_token = neuron::model_sorted_token;
using _nrn_mechanism_cache_range = neuron::cache::MechanismRange<number_of_floating_point_variables, number_of_datum_variables>;
using _nrn_mechanism_cache_instance = neuron::cache::MechanismInstance<number_of_floating_point_variables, number_of_datum_variables>;
using _nrn_non_owning_id_without_container = neuron::container::non_owning_identifier_without_container;
template <typename T>
using _nrn_mechanism_field = neuron::mechanism::field<T>;
template <typename... Args>
Expand Down Expand Up @@ -213,26 +214,30 @@ void _nrn_mechanism_register_data_fields(Args&&... args) {
ContainsSubstring(reindent_and_trim_text(expected_hoc_global_variables)));
}
THEN("Placeholder nrn_cur function is printed") {
std::string expected_placeholder_nrn_cur = R"(void nrn_cur() {})";
std::string expected_placeholder_nrn_cur =
R"(void nrn_cur_pas_test(_nrn_model_sorted_token const& _sorted_token, NrnThread* _nt, Memb_list* _ml_arg, int _type) {})";

REQUIRE_THAT(generated,
ContainsSubstring(reindent_and_trim_text(expected_placeholder_nrn_cur)));
}
THEN("Placeholder nrn_state function is printed") {
std::string expected_placeholder_nrn_state = R"(void nrn_state() {})";
std::string expected_placeholder_nrn_state =
R"(void nrn_state_pas_test(_nrn_model_sorted_token const& _sorted_token, NrnThread* _nt, Memb_list* _ml_arg, int _type) {})";

REQUIRE_THAT(generated,
ContainsSubstring(reindent_and_trim_text(expected_placeholder_nrn_state)));
}
THEN("Placeholder registration function is printed") {
std::string expected_placeholder_reg = R"(/** register channel with the simulator */
void __test_reg() {
std::string expected_placeholder_reg = R"CODE(/** register channel with the simulator */
extern "C" void __test_reg() {
/* s */
_slist1[0] = {4, 0};
/* Ds */
_dlist1[0] = {7, 0};

int mech_type = nrn_get_mechtype(mechanism_info[1]);
register_mech(mechanism_info, nrn_alloc_pas_test, nrn_cur_pas_test, nrn_jacob_pas_test, nrn_state_pas_test, nrn_init_pas_test, hoc_nrnpointerindex, 1);

mech_type = nrn_get_mechtype(mechanism_info[1]);
_nrn_mechanism_register_data_fields(mech_type,
_nrn_mechanism_field<double>{"g"} /* 0 */,
_nrn_mechanism_field<double>{"e"} /* 1 */,
Expand All @@ -246,7 +251,7 @@ void _nrn_mechanism_register_data_fields(Args&&... args) {
_nrn_mechanism_field<double>{"g_unused"} /* 9 */
);

})";
})CODE";

REQUIRE_THAT(generated,
ContainsSubstring(reindent_and_trim_text(expected_placeholder_reg)));
Expand Down
Loading