Skip to content

Commit

Permalink
Add more logic for POINT_PROCESS registration function (#1106)
Browse files Browse the repository at this point in the history
* Added hoc_register_prop_size and hoc_register_dparam_semantics
* Implemented big part of the nrn_alloc to make POINT_PROCESS work
* Added test for POINT_PROCESS location
* Added test for parameters
  • Loading branch information
iomaganaris authored Feb 10, 2024
1 parent d1290e0 commit 545b7b3
Show file tree
Hide file tree
Showing 10 changed files with 317 additions and 23 deletions.
2 changes: 2 additions & 0 deletions src/codegen/codegen_cpp_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1137,6 +1137,8 @@ void CodegenCppVisitor::setup(const Program& node) {

update_index_semantics();
rename_function_arguments();

info.semantic_variable_count = int_variables_size();
}

std::string CodegenCppVisitor::compute_method_name(BlockType type) const {
Expand Down
3 changes: 3 additions & 0 deletions src/codegen/codegen_info.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,9 @@ struct CodegenInfo {
/// typically equal to number of primes
int num_equations = 0;

/// number of semantic variables
int semantic_variable_count;

/// True if we have to emit CVODE code
/// TODO: Figure out when this needs to be true
bool emit_cvode = false;
Expand Down
253 changes: 234 additions & 19 deletions src/codegen/codegen_neuron_cpp_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,99 @@ bool CodegenNeuronCppVisitor::optimize_ion_variable_copies() const {
/****************************************************************************************/


void CodegenNeuronCppVisitor::print_point_process_function_definitions() {
if (info.point_process) {
printer->add_line("/* Point Process specific functions */");
printer->add_multi_line(R"CODE(
static void* _hoc_create_pnt(Object* _ho) {
return create_point_process(_pointtype, _ho);
}
)CODE");
printer->push_block("static void _hoc_destroy_pnt(void* _vptr)");
if (info.is_watch_used() || info.for_netcon_used) {
printer->add_line("Prop* _prop = ((Point_process*)_vptr)->prop;");
}
if (info.is_watch_used()) {
printer->push_block("if (_prop)");
printer->fmt_line("_nrn_free_watch(_nrn_mechanism_access_dparam(_prop), {}, {});",
info.watch_count,
info.is_watch_used());
printer->pop_block();
}
if (info.for_netcon_used) {
printer->push_block("if (_prop)");
printer->fmt_line(
"_nrn_free_fornetcon(&(_nrn_mechanism_access_dparam(_prop)[_fnc_index].literal_"
"value<void*>()));");
printer->pop_block();
}
printer->add_line("destroy_point_process(_vptr);");
printer->pop_block();
printer->add_multi_line(R"CODE(
static double _hoc_loc_pnt(void* _vptr) {
return loc_point_process(_pointtype, _vptr);
}
)CODE");
printer->add_multi_line(R"CODE(
static double _hoc_has_loc(void* _vptr) {
return has_loc_point(_vptr);
}
)CODE");
printer->add_multi_line(R"CODE(
static double _hoc_get_loc_pnt(void* _vptr) {
return (get_loc_point_process(_vptr));
}
)CODE");
}
}


void CodegenNeuronCppVisitor::print_setdata_functions() {
printer->add_line("/* Neuron setdata functions */");
printer->add_line("extern void _nrn_setdata_reg(int, void(*)(Prop*));");
printer->push_block("static void _setdata(Prop* _prop)");
if (!info.point_process) {
printer->add_multi_line(R"CODE(
_extcall_prop = _prop;
_prop_id = _nrn_get_prop_id(_prop);
)CODE");
}
if (!info.vectorize) {
printer->add_multi_line(R"CODE(
neuron::legacy::set_globals_from_prop(_prop, _ml_real, _ml, _iml);
_ppvar = _nrn_mechanism_access_dparam(_prop);
)CODE");
}
printer->pop_block();

if (info.point_process) {
printer->push_block("static void _hoc_setdata(void* _vptr)");
printer->add_multi_line(R"CODE(
Prop* _prop;
_prop = ((Point_process*)_vptr)->prop;
_setdata(_prop);
)CODE");
} else {
printer->push_block("static void _hoc_setdata()");
printer->add_multi_line(R"CODE(
Prop *_prop, *hoc_getdata_range(int);
_prop = hoc_getdata_range(mech_type);
_setdata(_prop);
hoc_retpushx(1.);
)CODE");
}
printer->pop_block();
}


/// TODO: Edit for NEURON
void CodegenNeuronCppVisitor::print_function_prototypes() {
if (info.functions.empty() && info.procedures.empty()) {
return;
}
printer->add_newline(2);

print_point_process_function_definitions();
print_setdata_functions();

/// TODO: Add mechanism function and procedures declarations
}


Expand Down Expand Up @@ -439,8 +527,22 @@ void CodegenNeuronCppVisitor::print_mechanism_global_var_structure(bool print_in
printer->fmt_line("static neuron::container::field_index _slist1[{0}], _dlist1[{0}];",
info.primes_size);
}

for (const auto& ion: info.ions) {
printer->fmt_line("static Symbol* _{}_sym;", ion.name);
}

printer->add_line("static int mech_type;");

if (info.point_process) {
printer->add_line("static int _pointtype;");
} else {
printer->add_multi_line(R"CODE(
static Prop* _extcall_prop;
/* _prop_id kind of shadows _extcall_prop to allow validity checking. */
static _nrn_non_owning_id_without_container _prop_id{};)CODE");
}

printer->fmt_line("static int {} = {};",
naming::NRN_POINTERINDEX,
info.pointer_variables.size() > 0
Expand All @@ -457,10 +559,6 @@ void CodegenNeuronCppVisitor::print_mechanism_global_var_structure(bool print_in
// TODO implement these when needed.
}

if (info.point_process) {
throw std::runtime_error("Not implemented, global point process.");
}

if (!info.vectorize && !info.top_local_variables.empty()) {
throw std::runtime_error("Not implemented, global vectorize something.");
}
Expand Down Expand Up @@ -564,6 +662,30 @@ void CodegenNeuronCppVisitor::print_global_variables_for_hoc() {
printer->add_line("{nullptr, nullptr, 0}");
printer->decrease_indent();
printer->add_line("};");

printer->add_newline(2);
printer->add_line("/* connect user functions to hoc names */");
printer->add_line("static VoidFunc hoc_intfunc[] = {");
printer->increase_indent();
if (info.point_process) {
printer->add_line("{0, 0}");
printer->decrease_indent();
printer->add_line("};");
printer->add_line("static Member_func _member_func[] = {");
printer->increase_indent();
printer->add_multi_line(R"CODE(
{"loc", _hoc_loc_pnt},
{"has_loc", _hoc_has_loc},
{"get_loc", _hoc_get_loc_pnt},)CODE");
} else {
printer->fmt_line("{{\"setdata_{}\", _hoc_setdata}},", info.mod_suffix);
}

/// TODO: Add _hoc_procedures and _hoc_functions

printer->add_line("{0, 0}");
printer->decrease_indent();
printer->add_line("};");
}

void CodegenNeuronCppVisitor::print_make_instance() const {
Expand Down Expand Up @@ -616,6 +738,23 @@ void CodegenNeuronCppVisitor::print_mechanism_register() {
printer->add_line("/** register channel with the simulator */");
printer->fmt_push_block("extern \"C\" void _{}_reg()", info.mod_file);
printer->add_line("_initlists();");

printer->add_newline();

for (const auto& ion: info.ions) {
printer->fmt_line("ion_reg(\"{}\", {});", ion.name, "-10000.");
}
printer->add_newline();

if (info.diam_used) {
printer->add_line("_morphology_sym = hoc_lookup(\"morphology\");");
printer->add_newline();
}

for (const auto& ion: info.ions) {
printer->fmt_line("_{0}_sym = hoc_lookup(\"{0}_ion\");", ion.name);
}

printer->add_newline();

const auto compute_functions_parameters =
Expand All @@ -632,18 +771,34 @@ void CodegenNeuronCppVisitor::print_mechanism_register() {
method_name(naming::NRN_INIT_METHOD),
naming::NRN_POINTERINDEX,
1 + info.thread_data_index);
printer->fmt_line("register_mech({});", register_mech_args);
if (info.point_process) {
printer->fmt_line(
"_pointtype = point_register_mech({}, _hoc_create_pnt, _hoc_destroy_pnt, "
"_member_func);",
register_mech_args);
} else {
printer->fmt_line("register_mech({});", register_mech_args);
}

// type related information
/// type related information
printer->add_newline();
printer->fmt_line("mech_type = nrn_get_mechtype({}[1]);", get_channel_info_var_name());

// More things to add here
/// Call _nrn_mechanism_register_data_fields() with the correct arguments
/// Geenerated code follows the style underneath
///
/// _nrn_mechanism_register_data_fields(mech_type,
/// _nrn_mechanism_field<double>{"var_name"}, /* float var index 0 */
/// ...
/// );
///
/// TODO: More things to add here
printer->add_line("_nrn_mechanism_register_data_fields(mech_type,");
printer->increase_indent();
const auto codegen_float_variables_size = codegen_float_variables.size();

const auto codegen_float_variables_size = codegen_float_variables.size();
std::vector<std::string> mech_register_args;

for (int i = 0; i < codegen_float_variables_size; ++i) {
const auto& float_var = codegen_float_variables[i];
if (float_var->is_array()) {
Expand All @@ -666,8 +821,10 @@ void CodegenNeuronCppVisitor::print_mechanism_register() {
throw std::runtime_error("Broken logic.");
}

auto type = (name == naming::POINT_PROCESS_VARIABLE) ? "Point_process*" : "double*";
mech_register_args.push_back(
fmt::format("_nrn_mechanism_field<double*>{{\"{}\", \"{}\"}} /* {} */",
fmt::format("_nrn_mechanism_field<{}>{{\"{}\", \"{}\"}} /* {} */",
type,
name,
info.semantics[i].name,
i));
Expand Down Expand Up @@ -724,7 +881,9 @@ void CodegenNeuronCppVisitor::print_mechanism_range_var_structure(bool print_ini
}
for (auto& var: codegen_int_variables) {
const auto& name = var.symbol->get_name();
if (var.is_index || var.is_integer) {
if (name == naming::POINT_PROCESS_VARIABLE) {
continue;
} else if (var.is_index || var.is_integer) {
auto qualifier = var.is_constant ? "const " : "";
printer->fmt_line("{}{}* const* {}{};", qualifier, int_type, name, value_initialize);
} else {
Expand Down Expand Up @@ -817,15 +976,64 @@ void CodegenNeuronCppVisitor::print_nrn_destructor() {
/// TODO: Print the equivalent of `nrn_alloc_<mech_name>`
void CodegenNeuronCppVisitor::print_nrn_alloc() {
printer->add_newline(2);

auto method = method_name(naming::NRN_ALLOC_METHOD);
printer->fmt_push_block("static void {}(Prop* _prop)", method);
printer->add_multi_line(R"CODE(
Prop *prop_ion{};
Datum *_ppvar{};
)CODE");

const auto codegen_int_variables_size = codegen_int_variables.size();
if (info.point_process) {
printer->push_block("if (nrn_point_prop_)");
printer->add_multi_line(R"CODE(
_nrn_mechanism_access_alloc_seq(_prop) = _nrn_mechanism_access_alloc_seq(nrn_point_prop_);
_ppvar = _nrn_mechanism_access_dparam(nrn_point_prop_);
)CODE");
printer->chain_block("else");
}
if (info.semantic_variable_count) {
printer->fmt_line("_ppvar = nrn_prop_datum_alloc(mech_type, {}, _prop);",
info.semantic_variable_count);
printer->add_line("_nrn_mechanism_access_dparam(_prop) = _ppvar;");
}
printer->add_multi_line(R"CODE(
_nrn_mechanism_cache_instance _ml_real{_prop};
auto* const _ml = &_ml_real;
size_t const _iml{};
)CODE");
printer->fmt_line("assert(_nrn_mechanism_get_num_vars(_prop) == {});",
codegen_float_variables.size());
if (float_variables_size()) {
printer->add_line("/*initialize range parameters*/");
for (const auto& var: info.range_parameter_vars) {
if (var->is_array()) {
continue;
}
const auto& var_name = var->get_name();
printer->fmt_line("_ml->template fpfield<{}>(_iml) = {}; /* {} */",
position_of_float_var(var_name),
*var->get_value(),
var_name);
}
}
if (info.point_process) {
printer->pop_block();
}

// TODO number of datum is the number of integer vars.
printer->fmt_line("Datum *_ppvar = nrn_prop_datum_alloc(mech_type, {}, _prop);",
codegen_int_variables_size);
printer->fmt_line("_nrn_mechanism_access_dparam(_prop) = _ppvar;");
if (info.semantic_variable_count) {
printer->add_line("_nrn_mechanism_access_dparam(_prop) = _ppvar;");
}

if (info.diam_used) {
throw std::runtime_error("Diam allocation not implemented.");
}

if (info.area_used) {
throw std::runtime_error("Area allocation not implemented.");
}

const auto codegen_int_variables_size = codegen_int_variables.size();

for (const auto& ion: info.ions) {
printer->fmt_line("Symbol * {}_sym = hoc_lookup(\"{}_ion\");", ion.name, ion.name);
Expand All @@ -847,10 +1055,12 @@ void CodegenNeuronCppVisitor::print_nrn_alloc() {
ion.name,
ion.variable_index(ion_var_name));
}
// }
//}
}
}

/// TODO: CONSTRUCTOR call

printer->pop_block();
}

Expand Down Expand Up @@ -1014,6 +1224,10 @@ void CodegenNeuronCppVisitor::print_mechanism_variables_macros() {
}
} // namespace
)CODE");

if (info.point_process) {
printer->add_line("extern Prop* nrn_point_prop_;");
}
/// TODO: More prints here?
}

Expand Down Expand Up @@ -1076,6 +1290,7 @@ void CodegenNeuronCppVisitor::print_codegen_routines() {
print_mechanism_info();
print_data_structures(true);
print_nrn_alloc();
print_function_prototypes();
print_global_variables_for_hoc();
print_compute_functions(); // only nrn_cur and nrn_state
print_sdlists_init(true);
Expand Down
14 changes: 14 additions & 0 deletions src/codegen/codegen_neuron_cpp_visitor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,20 @@ class CodegenNeuronCppVisitor: public CodegenCppVisitor {
void print_net_event_call(const ast::FunctionCall& node) override;


/**
* Print POINT_PROCESS related functions
* Wrap external NEURON functions related to POINT_PROCESS mechanisms
*
*/
void print_point_process_function_definitions();

/**
* Print NEURON functions related to setting global variables of the mechanism
*
*/
void print_setdata_functions();


/**
* Print function and procedures prototype declaration
*/
Expand Down
Loading

0 comments on commit 545b7b3

Please sign in to comment.