diff --git a/src/codegen/CMakeLists.txt b/src/codegen/CMakeLists.txt index f77e34e34e..bbafaf44dd 100644 --- a/src/codegen/CMakeLists.txt +++ b/src/codegen/CMakeLists.txt @@ -5,6 +5,8 @@ add_library( codegen STATIC codegen_acc_visitor.cpp codegen_transform_visitor.cpp + codegen_coreneuron_cpp_visitor.cpp + codegen_neuron_cpp_visitor.cpp codegen_cpp_visitor.cpp codegen_compatibility_visitor.cpp codegen_helper_visitor.cpp diff --git a/src/codegen/codegen_acc_visitor.cpp b/src/codegen/codegen_acc_visitor.cpp index 03a7589ff7..68ff009471 100644 --- a/src/codegen/codegen_acc_visitor.cpp +++ b/src/codegen/codegen_acc_visitor.cpp @@ -84,7 +84,7 @@ std::string CodegenAccVisitor::backend_name() const { void CodegenAccVisitor::print_memory_allocation_routine() const { // memory for artificial cells should be allocated on CPU if (info.artificial_cell) { - CodegenCppVisitor::print_memory_allocation_routine(); + CodegenCoreneuronCppVisitor::print_memory_allocation_routine(); return; } printer->add_newline(2); diff --git a/src/codegen/codegen_acc_visitor.hpp b/src/codegen/codegen_acc_visitor.hpp index 1a3f60503f..80bcba9d68 100644 --- a/src/codegen/codegen_acc_visitor.hpp +++ b/src/codegen/codegen_acc_visitor.hpp @@ -12,7 +12,7 @@ * \brief \copybrief nmodl::codegen::CodegenAccVisitor */ -#include "codegen/codegen_cpp_visitor.hpp" +#include "codegen/codegen_coreneuron_cpp_visitor.hpp" namespace nmodl { @@ -27,7 +27,7 @@ namespace codegen { * \class CodegenAccVisitor * \brief %Visitor for printing C++ code with OpenACC backend */ -class CodegenAccVisitor: public CodegenCppVisitor { +class CodegenAccVisitor: public CodegenCoreneuronCppVisitor { protected: /// name of the code generation backend std::string backend_name() const override; @@ -140,13 +140,13 @@ class CodegenAccVisitor: public CodegenCppVisitor { const std::string& output_dir, const std::string& float_type, bool optimize_ionvar_copies) - : CodegenCppVisitor(mod_file, output_dir, float_type, optimize_ionvar_copies) {} + : CodegenCoreneuronCppVisitor(mod_file, output_dir, float_type, optimize_ionvar_copies) {} CodegenAccVisitor(const std::string& mod_file, std::ostream& stream, const std::string& float_type, bool optimize_ionvar_copies) - : CodegenCppVisitor(mod_file, stream, float_type, optimize_ionvar_copies) {} + : CodegenCoreneuronCppVisitor(mod_file, stream, float_type, optimize_ionvar_copies) {} }; /** \} */ // end of codegen_backends diff --git a/src/codegen/codegen_coreneuron_cpp_visitor.cpp b/src/codegen/codegen_coreneuron_cpp_visitor.cpp new file mode 100644 index 0000000000..177b0b50c1 --- /dev/null +++ b/src/codegen/codegen_coreneuron_cpp_visitor.cpp @@ -0,0 +1,3900 @@ +/* + * Copyright 2023 Blue Brain Project, EPFL. + * See the top-level LICENSE file for details. 
+ * + * SPDX-License-Identifier: Apache-2.0 + */ + +#include "codegen/codegen_coreneuron_cpp_visitor.hpp" + +#include <algorithm> +#include <chrono> +#include <cmath> +#include <ctime> +#include <regex> + +#include "ast/all.hpp" +#include "codegen/codegen_helper_visitor.hpp" +#include "codegen/codegen_naming.hpp" +#include "codegen/codegen_utils.hpp" +#include "config/config.h" +#include "lexer/token_mapping.hpp" +#include "parser/c11_driver.hpp" +#include "utils/logger.hpp" +#include "utils/string_utils.hpp" +#include "visitors/defuse_analyze_visitor.hpp" +#include "visitors/rename_visitor.hpp" +#include "visitors/symtab_visitor.hpp" +#include "visitors/var_usage_visitor.hpp" +#include "visitors/visitor_utils.hpp" + +namespace nmodl { +namespace codegen { + +using namespace ast; + +using visitor::DefUseAnalyzeVisitor; +using visitor::DUState; +using visitor::RenameVisitor; +using visitor::SymtabVisitor; +using visitor::VarUsageVisitor; + +using symtab::syminfo::NmodlType; + +extern const std::regex regex_special_chars; + +/****************************************************************************************/ +/* Generic information getters */ +/****************************************************************************************/ + + +std::string CodegenCoreneuronCppVisitor::backend_name() const { + return "C++ (api-compatibility)"; +} + + +std::string CodegenCoreneuronCppVisitor::simulator_name() { + return "CoreNEURON"; +} + + +/****************************************************************************************/ +/* Common helper routines across codegen functions */ +/****************************************************************************************/ + + +int CodegenCoreneuronCppVisitor::position_of_float_var(const std::string& name) const { + int index = 0; + for (const auto& var: codegen_float_variables) { + if (var->get_name() == name) { + return index; + } + index += var->get_length(); + } + throw std::logic_error(name + " variable not found"); +} + + +int CodegenCoreneuronCppVisitor::position_of_int_var(const std::string& name) const { + int index = 0; + for (const auto& var: codegen_int_variables) { + if (var.symbol->get_name() == name) { + return index; + } + index += var.symbol->get_length(); + } + throw std::logic_error(name + " variable not found"); +}
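The two position getters above implement a running-offset lookup over the variable layout: arrays advance the offset by their full length. A minimal, self-contained sketch of that semantics (hypothetical layout and names, not the visitor's real types):

```cpp
#include <cassert>
#include <stdexcept>
#include <string>
#include <vector>

// Sketch of position_of_float_var()'s running-offset lookup; `Var` is a
// stand-in for the symbol objects held in codegen_float_variables.
struct Var {
    std::string name;
    int length;  // an array variable occupies `length` consecutive slots
};

int position_of(const std::vector<Var>& vars, const std::string& name) {
    int index = 0;
    for (const auto& v: vars) {
        if (v.name == name) {
            return index;
        }
        index += v.length;
    }
    throw std::logic_error(name + " variable not found");
}

int main() {
    // hypothetical layout: the array z[3] shifts h's position by 3
    const std::vector<Var> layout{{"m", 1}, {"z", 3}, {"h", 1}};
    assert(position_of(layout, "m") == 0);
    assert(position_of(layout, "h") == 4);
    return 0;
}
```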
+ + +/** + * \details Current variable used in breakpoint block could be a local variable. + * In this case, neuron has already renamed the variable name by prepending + * "_l". In our implementation, the variable could have been renamed by + * one of the passes. Hence, we search all local variables and check if + * the variable is renamed. Note that we have to look into the symbol table + * of statement block and not breakpoint. + */ +std::string CodegenCoreneuronCppVisitor::breakpoint_current(std::string current) const { + auto breakpoint = info.breakpoint_node; + if (breakpoint == nullptr) { + return current; + } + auto symtab = breakpoint->get_statement_block()->get_symbol_table(); + auto variables = symtab->get_variables_with_properties(NmodlType::local_var); + for (const auto& var: variables) { + auto renamed_name = var->get_name(); + auto original_name = var->get_original_name(); + if (current == original_name) { + current = renamed_name; + break; + } + } + return current; +} + + +/** + * \details Depending upon the block type, we have to print read/write ion variables + * during code generation. Depending on the block/procedure being printed, this + * method returns statements as a vector. As different code backends could have + * different variable names, we rely on the backend-specific read_ion_variable_name + * and write_ion_variable_name methods, which can be overloaded. + */ +std::vector<std::string> CodegenCoreneuronCppVisitor::ion_read_statements(BlockType type) const { + if (optimize_ion_variable_copies()) { + return ion_read_statements_optimized(type); + } + std::vector<std::string> statements; + for (const auto& ion: info.ions) { + auto name = ion.name; + for (const auto& var: ion.reads) { + auto const iter = std::find(ion.implicit_reads.begin(), ion.implicit_reads.end(), var); + if (iter != ion.implicit_reads.end()) { + continue; + } + auto variable_names = read_ion_variable_name(var); + auto first = get_variable_name(variable_names.first); + auto second = get_variable_name(variable_names.second); + statements.push_back(fmt::format("{} = {};", first, second)); + } + for (const auto& var: ion.writes) { + if (ion.is_ionic_conc(var)) { + auto variables = read_ion_variable_name(var); + auto first = get_variable_name(variables.first); + auto second = get_variable_name(variables.second); + statements.push_back(fmt::format("{} = {};", first, second)); + } + } + } + return statements; +} + + +std::vector<std::string> CodegenCoreneuronCppVisitor::ion_read_statements_optimized( + BlockType type) const { + std::vector<std::string> statements; + for (const auto& ion: info.ions) { + for (const auto& var: ion.writes) { + if (ion.is_ionic_conc(var)) { + auto variables = read_ion_variable_name(var); + auto first = "ionvar." + variables.first; + const auto& second = get_variable_name(variables.second); + statements.push_back(fmt::format("{} = {};", first, second)); + } + } + } + return statements; +} + +// NOLINTNEXTLINE(readability-function-cognitive-complexity) +std::vector<ShadowUseStatement> CodegenCoreneuronCppVisitor::ion_write_statements(BlockType type) { + std::vector<ShadowUseStatement> statements; + for (const auto& ion: info.ions) { + std::string concentration; + auto name = ion.name; + for (const auto& var: ion.writes) { + auto variable_names = write_ion_variable_name(var); + if (ion.is_ionic_current(var)) { + if (type == BlockType::Equation) { + auto current = breakpoint_current(var); + auto lhs = variable_names.first; + auto op = "+="; + auto rhs = get_variable_name(current); + if (info.point_process) { + auto area = get_variable_name(naming::NODE_AREA_VARIABLE); + rhs += fmt::format("*(1.e2/{})", area); + } + statements.push_back(ShadowUseStatement{lhs, op, rhs}); + } + } else { + if (!ion.is_rev_potential(var)) { + concentration = var; + } + auto lhs = variable_names.first; + auto op = "="; + auto rhs = get_variable_name(variable_names.second); + statements.push_back(ShadowUseStatement{lhs, op, rhs}); + } + } + + if (type == BlockType::Initial && !concentration.empty()) { + int index = 0; + if (ion.is_intra_cell_conc(concentration)) { + index = 1; + } else if (ion.is_extra_cell_conc(concentration)) { + index = 2; + } else { + /// \todo Unhandled case in neuron implementation + throw std::logic_error(fmt::format("codegen error for {} ion", ion.name)); + } + auto ion_type_name = fmt::format("{}_type", ion.name); + auto lhs = fmt::format("int {}", ion_type_name); + auto op = "="; + auto rhs = get_variable_name(ion_type_name); + statements.push_back(ShadowUseStatement{lhs, op, rhs}); + auto statement = conc_write_statement(ion.name, concentration, index); + statements.push_back(ShadowUseStatement{statement, "", ""}); + } + } + return statements; +}
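To make the assembled strings concrete, here is a hedged sketch of what these helpers emit for a hypothetical mechanism that reads `ena` and writes `ina` (the `ion_` prefix is `naming::ION_VARNAME_PREFIX`; the exact index arithmetic is produced by `get_variable_name`, so the renderings below are illustrative, not verbatim output):

```cpp
// From ion_read_statements(): the pair {"ena", "ion_ena"} becomes
//     inst->ena[id] = inst->ion_ena[indexes[0*pnodecount + id]];
// From ion_write_statements() in a BlockType::Equation block, the ionic
// current is queued as ShadowUseStatement{lhs="ion_ina", op="+=", rhs="ina"}
// and later rendered by process_shadow_update_statement() as
//     inst->ion_ina[indexes[1*pnodecount + id]] += inst->ina[id];
// For a point process the rhs additionally carries the *(1.e2/node_area)
// scaling shown above.
```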
+ + +/** + * \details Often top level verbatim blocks use variables with old names. + * + * Here we handle the renaming when a verbatim block is processed at global scope. + */ +std::string CodegenCoreneuronCppVisitor::process_verbatim_token(const std::string& token) { + const std::string& name = token; + + /* + * If given token is procedure name and if it's defined + * in the current mod file then it must be replaced + */ + if (program_symtab->is_method_defined(token)) { + return method_name(token); + } + + /* + * Check if token is commonly used variable name in + * verbatim block like nt, \c \_threadargs etc. If so, replace + * it and return. + */ + auto new_name = replace_if_verbatim_variable(name); + if (new_name != name) { + return get_variable_name(new_name, false); + } + + /* + * For top level verbatim blocks we shouldn't replace variable + * names with Instance because arguments are provided from coreneuron + * and they are missing inst. + */ + auto use_instance = !printing_top_verbatim_blocks; + return get_variable_name(token, use_instance); +} + + +bool CodegenCoreneuronCppVisitor::ion_variable_struct_required() const { + return optimize_ion_variable_copies() && info.ion_has_write_variable(); +} + + +/** + * \details This can be overridden in the backend. For example, parameters can be constant + * except in INITIAL block where they are set to 0. As initial block is/can be + * executed on c++/cpu backend, gpu backend can mark the parameter as constant. + */ +bool CodegenCoreneuronCppVisitor::is_constant_variable(const std::string& name) const { + auto symbol = program_symtab->lookup_in_scope(name); + bool is_constant = false; + if (symbol != nullptr) { + // per mechanism ion variables need to be updated from neuron/coreneuron values + if (info.is_ion_variable(name)) { + is_constant = false; + } + // for a parameter variable to be const, make sure its write count is 0 + // and it is not used in the verbatim block + else if (symbol->has_any_property(NmodlType::param_assign) && + info.variables_in_verbatim.find(name) == info.variables_in_verbatim.end() && + symbol->get_write_count() == 0) { + is_constant = true; + } + } + return is_constant; +} + + +/****************************************************************************************/ +/* Backend specific routines */ +/****************************************************************************************/ + +std::string CodegenCoreneuronCppVisitor::get_parameter_str(const ParamVector& params) { + std::string str; + bool is_first = true; + for (const auto& param: params) { + if (is_first) { + is_first = false; + } else { + str += ", "; + } + str += fmt::format("{}{} {}{}", + std::get<0>(param), + std::get<1>(param), + std::get<2>(param), + std::get<3>(param)); + } + return str; +} + + +void CodegenCoreneuronCppVisitor::print_deriv_advance_flag_transfer_to_device() const { + // backend specific, do nothing +} + + +void CodegenCoreneuronCppVisitor::print_device_atomic_capture_annotation() const { + // backend specific, do nothing +} + + +void CodegenCoreneuronCppVisitor::print_net_send_buf_count_update_to_host() const { + // backend specific, do nothing +} + + +void CodegenCoreneuronCppVisitor::print_net_send_buf_update_to_host() const { + // backend specific, do nothing +} + + +void CodegenCoreneuronCppVisitor::print_net_send_buf_count_update_to_device() const { + // backend specific, do nothing +} + + +void CodegenCoreneuronCppVisitor::print_dt_update_to_device() const { + // backend specific, do nothing +} + + +void CodegenCoreneuronCppVisitor::print_device_stream_wait() const { + // backend specific, do nothing +} + + +/** + 
* \details Each kernel such as \c nrn\_init, \c nrn\_state and \c nrn\_cur could be offloaded + * to an accelerator. In this case, at the very top level, we print the pragma + * for data present. For example: + * + * \code{.cpp} + * void nrn_state(...) { + * #pragma acc data present (nt, ml...) + * { + * + * } + * } + * \endcode + */ +void CodegenCoreneuronCppVisitor::print_kernel_data_present_annotation_block_begin() { + // backend specific, do nothing +} + + +void CodegenCoreneuronCppVisitor::print_kernel_data_present_annotation_block_end() { + // backend specific, do nothing +} + + +void CodegenCoreneuronCppVisitor::print_net_init_acc_serial_annotation_block_begin() { + // backend specific, do nothing +} + + +void CodegenCoreneuronCppVisitor::print_net_init_acc_serial_annotation_block_end() { + // backend specific, do nothing +} + + +/** + * \details Depending on the programming model and compiler, we print a compiler hint + * for parallelization. For example: + * + * \code + * #pragma omp simd + * for(int id = 0; id < nodecount; id++) { + * + * #pragma acc parallel loop + * for(int id = 0; id < nodecount; id++) { + * \endcode + */ +void CodegenCoreneuronCppVisitor::print_channel_iteration_block_parallel_hint( + BlockType /* type */, + const ast::Block* block) { + // ivdep allows SIMD parallelisation of a block/loop but doesn't provide + // a standard mechanism for atomics. Also, even with openmp 5.0, openmp + // atomics do not enable vectorisation under "omp simd" (gives compiler + // error with gcc < 9 if atomic and simd pragmas are nested). So, emit + // ivdep/simd pragma when no MUTEXLOCK/MUTEXUNLOCK/PROTECT statements + // are used in the given block. + std::vector<std::shared_ptr<const ast::Ast>> nodes; + if (block) { + nodes = collect_nodes(*block, + {ast::AstNodeType::PROTECT_STATEMENT, + ast::AstNodeType::MUTEX_LOCK, + ast::AstNodeType::MUTEX_UNLOCK}); + } + if (nodes.empty()) { + printer->add_line("#pragma omp simd"); + printer->add_line("#pragma ivdep"); + } +} + + +bool CodegenCoreneuronCppVisitor::nrn_cur_reduction_loop_required() { + return info.point_process; +} + + +void CodegenCoreneuronCppVisitor::print_rhs_d_shadow_variables() { + if (info.point_process) { + printer->fmt_line("double* shadow_rhs = nt->{};", naming::NTHREAD_RHS_SHADOW); + printer->fmt_line("double* shadow_d = nt->{};", naming::NTHREAD_D_SHADOW); + } +} + + +void CodegenCoreneuronCppVisitor::print_nrn_cur_matrix_shadow_update() { + if (info.point_process) { + printer->add_line("shadow_rhs[id] = rhs;"); + printer->add_line("shadow_d[id] = g;"); + } else { + auto rhs_op = operator_for_rhs(); + auto d_op = operator_for_d(); + printer->fmt_line("vec_rhs[node_id] {} rhs;", rhs_op); + printer->fmt_line("vec_d[node_id] {} g;", d_op); + } +} + + +void CodegenCoreneuronCppVisitor::print_nrn_cur_matrix_shadow_reduction() { + auto rhs_op = operator_for_rhs(); + auto d_op = operator_for_d(); + if (info.point_process) { + printer->add_line("int node_id = node_index[id];"); + printer->fmt_line("vec_rhs[node_id] {} shadow_rhs[id];", rhs_op); + printer->fmt_line("vec_d[node_id] {} shadow_d[id];", d_op); + } +}
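Put together, these hooks give the point-process `nrn_cur` roughly the following shape; this is a hedged reconstruction (the actual `-=`/`+=` operators come from `operator_for_rhs`/`operator_for_d` and flip for electrode currents):

```cpp
// Hedged sketch of the emitted nrn_cur structure for a point process.
double* shadow_rhs = nt->_shadow_rhs;   // print_rhs_d_shadow_variables()
double* shadow_d = nt->_shadow_d;
#pragma omp simd                        // parallel hint (no PROTECT/MUTEX)
#pragma ivdep
for (int id = 0; id < nodecount; id++) {
    // ... current computation producing rhs and g ...
    shadow_rhs[id] = rhs;               // race-free per-instance store
    shadow_d[id] = g;
}
for (int id = 0; id < nodecount; id++) {
    int node_id = node_index[id];       // serial reduction loop
    vec_rhs[node_id] -= shadow_rhs[id]; // operator_for_rhs()
    vec_d[node_id] += shadow_d[id];     // operator_for_d()
}
```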
+ */ +void CodegenCoreneuronCppVisitor::print_atomic_reduction_pragma() { + printer->add_line("#pragma omp atomic update"); +} + + +void CodegenCoreneuronCppVisitor::print_device_method_annotation() { + // backend specific, nothing for cpu +} + + +void CodegenCoreneuronCppVisitor::print_global_method_annotation() { + // backend specific, nothing for cpu +} + + +void CodegenCoreneuronCppVisitor::print_backend_namespace_start() { + // no separate namespace for C++ (cpu) backend +} + + +void CodegenCoreneuronCppVisitor::print_backend_namespace_stop() { + // no separate namespace for C++ (cpu) backend +} + + +void CodegenCoreneuronCppVisitor::print_backend_includes() { + // backend specific, nothing for cpu +} + + +bool CodegenCoreneuronCppVisitor::optimize_ion_variable_copies() const { + return optimize_ionvar_copies; +} + + +void CodegenCoreneuronCppVisitor::print_memory_allocation_routine() const { + printer->add_newline(2); + auto args = "size_t num, size_t size, size_t alignment = 16"; + printer->fmt_push_block("static inline void* mem_alloc({})", args); + printer->add_line("void* ptr;"); + printer->add_line("posix_memalign(&ptr, alignment, num*size);"); + printer->add_line("memset(ptr, 0, size);"); + printer->add_line("return ptr;"); + printer->pop_block(); + + printer->add_newline(2); + printer->push_block("static inline void mem_free(void* ptr)"); + printer->add_line("free(ptr);"); + printer->pop_block(); +} + + +void CodegenCoreneuronCppVisitor::print_abort_routine() const { + printer->add_newline(2); + printer->push_block("static inline void coreneuron_abort()"); + printer->add_line("abort();"); + printer->pop_block(); +} + + +std::string CodegenCoreneuronCppVisitor::compute_method_name(BlockType type) const { + if (type == BlockType::Initial) { + return method_name(naming::NRN_INIT_METHOD); + } + if (type == BlockType::Constructor) { + return method_name(naming::NRN_CONSTRUCTOR_METHOD); + } + if (type == BlockType::Destructor) { + return method_name(naming::NRN_DESTRUCTOR_METHOD); + } + if (type == BlockType::State) { + return method_name(naming::NRN_STATE_METHOD); + } + if (type == BlockType::Equation) { + return method_name(naming::NRN_CUR_METHOD); + } + if (type == BlockType::Watch) { + return method_name(naming::NRN_WATCH_CHECK_METHOD); + } + throw std::logic_error("compute_method_name not implemented"); +} + + +void CodegenCoreneuronCppVisitor::print_global_var_struct_decl() { + printer->add_line(global_struct(), ' ', global_struct_instance(), ';'); +} + + +/****************************************************************************************/ +/* Printing routines for code generation */ +/****************************************************************************************/ + + +void CodegenCoreneuronCppVisitor::print_function_call(const FunctionCall& node) { + const auto& name = node.get_node_name(); + auto function_name = name; + if (defined_method(name)) { + function_name = method_name(name); + } + + if (is_net_send(name)) { + print_net_send_call(node); + return; + } + + if (is_net_move(name)) { + print_net_move_call(node); + return; + } + + if (is_net_event(name)) { + print_net_event_call(node); + return; + } + + const auto& arguments = node.get_arguments(); + printer->add_text(function_name, '('); + + if (defined_method(name)) { + printer->add_text(internal_method_arguments()); + if (!arguments.empty()) { + printer->add_text(", "); + } + } + + print_vector_elements(arguments, ", "); + printer->add_text(')'); +} + + +void 
CodegenCoreneuronCppVisitor::print_top_verbatim_blocks() { + if (info.top_verbatim_blocks.empty()) { + return; + } + print_namespace_stop(); + + printer->add_newline(2); + printer->add_line("using namespace coreneuron;"); + codegen = true; + printing_top_verbatim_blocks = true; + + for (const auto& block: info.top_blocks) { + if (block->is_verbatim()) { + printer->add_newline(2); + block->accept(*this); + } + } + + printing_top_verbatim_blocks = false; + codegen = false; + print_namespace_start(); +} + + +void CodegenCoreneuronCppVisitor::print_function_prototypes() { + if (info.functions.empty() && info.procedures.empty()) { + return; + } + codegen = true; + printer->add_newline(2); + for (const auto& node: info.functions) { + print_function_declaration(*node, node->get_node_name()); + printer->add_text(';'); + printer->add_newline(); + } + for (const auto& node: info.procedures) { + print_function_declaration(*node, node->get_node_name()); + printer->add_text(';'); + printer->add_newline(); + } + codegen = false; +} + + +static const TableStatement* get_table_statement(const ast::Block& node) { + // TableStatementVisitor v; + + const auto& table_statements = collect_nodes(node, {AstNodeType::TABLE_STATEMENT}); + + if (table_statements.size() != 1) { + auto message = fmt::format("One table statement expected in {} found {}", + node.get_node_name(), + table_statements.size()); + throw std::runtime_error(message); + } + return dynamic_cast<const TableStatement*>(table_statements.front().get()); +} + + +std::tuple<bool, int> CodegenCoreneuronCppVisitor::check_if_var_is_array(const std::string& name) { + auto symbol = program_symtab->lookup_in_scope(name); + if (!symbol) { + throw std::runtime_error( + fmt::format("CodegenCoreneuronCppVisitor:: {} not found in symbol table!", name)); + } + if (symbol->is_array()) { + return {true, symbol->get_length()}; + } else { + return {false, 0}; + } +} + + +void CodegenCoreneuronCppVisitor::print_table_check_function(const Block& node) { + auto statement = get_table_statement(node); + auto table_variables = statement->get_table_vars(); + auto depend_variables = statement->get_depend_vars(); + const auto& from = statement->get_from(); + const auto& to = statement->get_to(); + auto name = node.get_node_name(); + auto internal_params = internal_method_parameters(); + auto with = statement->get_with()->eval(); + auto use_table_var = get_variable_name(naming::USE_TABLE_VARIABLE); + auto tmin_name = get_variable_name("tmin_" + name); + auto mfac_name = get_variable_name("mfac_" + name); + auto float_type = default_float_data_type(); + + printer->add_newline(2); + print_device_method_annotation(); + printer->fmt_push_block("void check_{}({})", + method_name(name), + get_parameter_str(internal_params)); + { + printer->fmt_push_block("if ({} == 0)", use_table_var); + printer->add_line("return;"); + printer->pop_block(); + + printer->add_line("static bool make_table = true;"); + for (const auto& variable: depend_variables) { + printer->fmt_line("static {} save_{};", float_type, variable->get_node_name()); + } + + for (const auto& variable: depend_variables) { + const auto& var_name = variable->get_node_name(); + const auto& instance_name = get_variable_name(var_name); + printer->fmt_push_block("if (save_{} != {})", var_name, instance_name); + printer->add_line("make_table = true;"); + printer->pop_block(); + } + + printer->push_block("if (make_table)"); + { + printer->add_line("make_table = false;"); + + printer->add_indent(); + printer->add_text(tmin_name, " = "); + from->accept(*this); + 
printer->add_text(';'); + printer->add_newline(); + + printer->add_indent(); + printer->add_text("double tmax = "); + to->accept(*this); + printer->add_text(';'); + printer->add_newline(); + + + printer->fmt_line("double dx = (tmax-{}) / {}.;", tmin_name, with); + printer->fmt_line("{} = 1./dx;", mfac_name); + + printer->fmt_line("double x = {};", tmin_name); + printer->fmt_push_block("for (std::size_t i = 0; i < {}; x += dx, i++)", with + 1); + auto function = method_name("f_" + name); + if (node.is_procedure_block()) { + printer->fmt_line("{}({}, x);", function, internal_method_arguments()); + for (const auto& variable: table_variables) { + auto var_name = variable->get_node_name(); + auto instance_name = get_variable_name(var_name); + auto table_name = get_variable_name("t_" + var_name); + auto [is_array, array_length] = check_if_var_is_array(var_name); + if (is_array) { + for (int j = 0; j < array_length; j++) { + printer->fmt_line( + "{}[{}][i] = {}[{}];", table_name, j, instance_name, j); + } + } else { + printer->fmt_line("{}[i] = {};", table_name, instance_name); + } + } + } else { + auto table_name = get_variable_name("t_" + name); + printer->fmt_line("{}[i] = {}({}, x);", + table_name, + function, + internal_method_arguments()); + } + printer->pop_block(); + + for (const auto& variable: depend_variables) { + auto var_name = variable->get_node_name(); + auto instance_name = get_variable_name(var_name); + printer->fmt_line("save_{} = {};", var_name, instance_name); + } + } + printer->pop_block(); + } + printer->pop_block(); +} + + +void CodegenCoreneuronCppVisitor::print_table_replacement_function(const ast::Block& node) { + auto name = node.get_node_name(); + auto statement = get_table_statement(node); + auto table_variables = statement->get_table_vars(); + auto with = statement->get_with()->eval(); + auto use_table_var = get_variable_name(naming::USE_TABLE_VARIABLE); + auto tmin_name = get_variable_name("tmin_" + name); + auto mfac_name = get_variable_name("mfac_" + name); + auto function_name = method_name("f_" + name); + + printer->add_newline(2); + print_function_declaration(node, name); + printer->push_block(); + { + const auto& params = node.get_parameters(); + printer->fmt_push_block("if ({} == 0)", use_table_var); + if (node.is_procedure_block()) { + printer->fmt_line("{}({}, {});", + function_name, + internal_method_arguments(), + params[0].get()->get_node_name()); + printer->add_line("return 0;"); + } else { + printer->fmt_line("return {}({}, {});", + function_name, + internal_method_arguments(), + params[0].get()->get_node_name()); + } + printer->pop_block(); + + printer->fmt_line("double xi = {} * ({} - {});", + mfac_name, + params[0].get()->get_node_name(), + tmin_name); + printer->push_block("if (isnan(xi))"); + if (node.is_procedure_block()) { + for (const auto& var: table_variables) { + auto var_name = get_variable_name(var->get_node_name()); + auto [is_array, array_length] = check_if_var_is_array(var->get_node_name()); + if (is_array) { + for (int j = 0; j < array_length; j++) { + printer->fmt_line("{}[{}] = xi;", var_name, j); + } + } else { + printer->fmt_line("{} = xi;", var_name); + } + } + printer->add_line("return 0;"); + } else { + printer->add_line("return xi;"); + } + printer->pop_block(); + + printer->fmt_push_block("if (xi <= 0. || xi >= {}.)", with); + printer->fmt_line("int index = (xi <= 0.) ? 
0 : {};", with); + if (node.is_procedure_block()) { + for (const auto& variable: table_variables) { + auto var_name = variable->get_node_name(); + auto instance_name = get_variable_name(var_name); + auto table_name = get_variable_name("t_" + var_name); + auto [is_array, array_length] = check_if_var_is_array(var_name); + if (is_array) { + for (int j = 0; j < array_length; j++) { + printer->fmt_line( + "{}[{}] = {}[{}][index];", instance_name, j, table_name, j); + } + } else { + printer->fmt_line("{} = {}[index];", instance_name, table_name); + } + } + printer->add_line("return 0;"); + } else { + auto table_name = get_variable_name("t_" + name); + printer->fmt_line("return {}[index];", table_name); + } + printer->pop_block(); + + printer->add_line("int i = int(xi);"); + printer->add_line("double theta = xi - double(i);"); + if (node.is_procedure_block()) { + for (const auto& var: table_variables) { + auto var_name = var->get_node_name(); + auto instance_name = get_variable_name(var_name); + auto table_name = get_variable_name("t_" + var_name); + auto [is_array, array_length] = check_if_var_is_array(var->get_node_name()); + if (is_array) { + for (size_t j = 0; j < array_length; j++) { + printer->fmt_line( + "{0}[{1}] = {2}[{1}][i] + theta*({2}[{1}][i+1]-{2}[{1}][i]);", + instance_name, + j, + table_name); + } + } else { + printer->fmt_line("{0} = {1}[i] + theta*({1}[i+1]-{1}[i]);", + instance_name, + table_name); + } + } + printer->add_line("return 0;"); + } else { + auto table_name = get_variable_name("t_" + name); + printer->fmt_line("return {0}[i] + theta * ({0}[i+1] - {0}[i]);", table_name); + } + } + printer->pop_block(); +} + + +void CodegenCoreneuronCppVisitor::print_check_table_thread_function() { + if (info.table_count == 0) { + return; + } + + printer->add_newline(2); + auto name = method_name("check_table_thread"); + auto parameters = external_method_parameters(true); + + printer->fmt_push_block("static void {} ({})", name, parameters); + printer->add_line("setup_instance(nt, ml);"); + printer->fmt_line("auto* const inst = static_cast<{0}*>(ml->instance);", instance_struct()); + printer->add_line("double v = 0;"); + + for (const auto& function: info.functions_with_table) { + auto method_name_str = method_name("check_" + function->get_node_name()); + auto arguments = internal_method_arguments(); + printer->fmt_line("{}({});", method_name_str, arguments); + } + + printer->pop_block(); +} + + +void CodegenCoreneuronCppVisitor::print_function_or_procedure(const ast::Block& node, + const std::string& name) { + printer->add_newline(2); + print_function_declaration(node, name); + printer->add_text(" "); + printer->push_block(); + + // function requires return variable declaration + if (node.is_function_block()) { + auto type = default_float_data_type(); + printer->fmt_line("{} ret_{} = 0.0;", type, name); + } else { + printer->fmt_line("int ret_{} = 0;", name); + } + + print_statement_block(*node.get_statement_block(), false, false); + printer->fmt_line("return ret_{};", name); + printer->pop_block(); +} + + +void CodegenCoreneuronCppVisitor::print_function_procedure_helper(const ast::Block& node) { + codegen = true; + auto name = node.get_node_name(); + + if (info.function_uses_table(name)) { + auto new_name = "f_" + name; + print_function_or_procedure(node, new_name); + print_table_check_function(node); + print_table_replacement_function(node); + } else { + print_function_or_procedure(node, name); + } + + codegen = false; +} + + +void CodegenCoreneuronCppVisitor::print_procedure(const 
ast::ProcedureBlock& node) { + print_function_procedure_helper(node); +} + + +void CodegenCoreneuronCppVisitor::print_function(const ast::FunctionBlock& node) { + auto name = node.get_node_name(); + + // name of return variable + std::string return_var; + if (info.function_uses_table(name)) { + return_var = "ret_f_" + name; + } else { + return_var = "ret_" + name; + } + + // first rename return variable name + auto block = node.get_statement_block().get(); + RenameVisitor v(name, return_var); + block->accept(v); + + print_function_procedure_helper(node); +} + + +void CodegenCoreneuronCppVisitor::print_function_tables(const ast::FunctionTableBlock& node) { + auto name = node.get_node_name(); + const auto& p = node.get_parameters(); + auto params = internal_method_parameters(); + for (const auto& i: p) { + params.emplace_back("", "double", "", i->get_node_name()); + } + printer->fmt_line("double {}({})", method_name(name), get_parameter_str(params)); + printer->push_block(); + printer->fmt_line("double _arg[{}];", p.size()); + for (size_t i = 0; i < p.size(); ++i) { + printer->fmt_line("_arg[{}] = {};", i, p[i]->get_node_name()); + } + printer->fmt_line("return hoc_func_table({}, {}, _arg);", + get_variable_name(std::string("_ptable_" + name), true), + p.size()); + printer->pop_block(); + + printer->fmt_push_block("double table_{}()", method_name(name)); + printer->fmt_line("hoc_spec_table(&{}, {});", + get_variable_name(std::string("_ptable_" + name)), + p.size()); + printer->add_line("return 0.;"); + printer->pop_block(); +} + + +/** + * @brief Checks whether the functor_block generated by sympy solver modifies any variable outside + * its scope. If it does then return false, so that the operator() of the struct functor of the + * Eigen Newton solver doesn't have const qualifier. 
+ * + * @param variable_block Statement Block of the variables declarations used in the functor struct of + * the solver + * @param functor_block Actual code being printed in the operator() of the functor struct of the + * solver + * @return True if operator() is const else False + */ +bool CodegenCoreneuronCppVisitor::is_functor_const(const ast::StatementBlock& variable_block, + const ast::StatementBlock& functor_block) { + // Create complete_block with both variable declarations (done in variable_block) and solver + // part (done in functor_block) to be able to run the SymtabVisitor and DefUseAnalyzeVisitor + // then and get the proper DUChains for the variables defined in the variable_block + ast::StatementBlock complete_block(functor_block); + // Typically variable_block has only one statement, a statement containing the declaration + // of the local variables + for (const auto& statement: variable_block.get_statements()) { + complete_block.insert_statement(complete_block.get_statements().begin(), statement); + } + + // Create Symbol Table for complete_block + auto model_symbol_table = std::make_shared<symtab::ModelSymbolTable>(); + SymtabVisitor(model_symbol_table.get()).visit_statement_block(complete_block); + // Initialize DefUseAnalyzeVisitor to generate the DUChains for the variables defined in the + // variable_block + DefUseAnalyzeVisitor v(*complete_block.get_symbol_table()); + + // Check the DUChains for all the variables in the variable_block + // If variable is defined in complete_block don't add const qualifier in operator() + auto is_functor_const = true; + const auto& variables = collect_nodes(variable_block, {ast::AstNodeType::LOCAL_VAR}); + for (const auto& variable: variables) { + const auto& chain = v.analyze(complete_block, variable->get_node_name()); + is_functor_const = !(chain.eval() == DUState::D || chain.eval() == DUState::LD || + chain.eval() == DUState::CD); + if (!is_functor_const) { + break; + } + } + + return is_functor_const; +} + + +void CodegenCoreneuronCppVisitor::print_functor_definition( + const ast::EigenNewtonSolverBlock& node) { + // functor that evaluates F(X) and J(X) for + // Newton solver + auto float_type = default_float_data_type(); + int N = node.get_n_state_vars()->get_value(); + + const auto functor_name = info.functor_names[&node]; + printer->fmt_push_block("struct {0}", functor_name); + printer->add_line("NrnThread* nt;"); + printer->add_line(instance_struct(), "* inst;"); + printer->add_line("int id, pnodecount;"); + printer->add_line("double v;"); + printer->add_line("const Datum* indexes;"); + printer->add_line("double* data;"); + printer->add_line("ThreadDatum* thread;"); + + if (ion_variable_struct_required()) { + print_ion_variable(); + } + + print_statement_block(*node.get_variable_block(), false, false); + printer->add_newline(); + + printer->push_block("void initialize()"); + print_statement_block(*node.get_initialize_block(), false, false); + printer->pop_block(); + printer->add_newline(); + + printer->fmt_line( + "{0}(NrnThread* nt, {1}* inst, int id, int pnodecount, double v, const Datum* indexes, " + "double* data, ThreadDatum* thread) : " + "nt{{nt}}, inst{{inst}}, id{{id}}, pnodecount{{pnodecount}}, v{{v}}, indexes{{indexes}}, " + "data{{data}}, thread{{thread}} " + "{{}}", + functor_name, + instance_struct()); + + printer->add_indent(); + + const auto& variable_block = *node.get_variable_block(); + const auto& functor_block = *node.get_functor_block(); + + printer->fmt_text( + "void operator()(const Eigen::Matrix<{0}, {1}, 1>& nmodl_eigen_xm, 
Eigen::Matrix<{0}, {1}, " + "1>& nmodl_eigen_fm, " + "Eigen::Matrix<{0}, {1}, {1}>& nmodl_eigen_jm) {2}", + float_type, + N, + is_functor_const(variable_block, functor_block) ? "const " : ""); + printer->push_block(); + printer->fmt_line("const {}* nmodl_eigen_x = nmodl_eigen_xm.data();", float_type); + printer->fmt_line("{}* nmodl_eigen_j = nmodl_eigen_jm.data();", float_type); + printer->fmt_line("{}* nmodl_eigen_f = nmodl_eigen_fm.data();", float_type); + print_statement_block(functor_block, false, false); + printer->pop_block(); + printer->add_newline(); + + // assign newton solver results in matrix X to state vars + printer->push_block("void finalize()"); + print_statement_block(*node.get_finalize_block(), false, false); + printer->pop_block(); + + printer->pop_block(";"); +} + + +void CodegenCoreneuronCppVisitor::print_eigen_linear_solver(const std::string& float_type, int N) { + if (N <= 4) { + // Faster compared to LU, given the template specialization in Eigen. + printer->add_multi_line(R"CODE( + bool invertible; + nmodl_eigen_jm.computeInverseWithCheck(nmodl_eigen_jm_inv,invertible); + nmodl_eigen_xm = nmodl_eigen_jm_inv*nmodl_eigen_fm; + if (!invertible) assert(false && "Singular or ill-conditioned matrix (Eigen::inverse)!"); + )CODE"); + } else { + // In Eigen the default storage order is ColMajor. + // Crout's implementation requires matrices stored in RowMajor order (C++-style arrays). + // Therefore, the transposeInPlace call is critical so that the data() method gives the + // rows instead of the columns. + printer->add_line("if (!nmodl_eigen_jm.IsRowMajor) nmodl_eigen_jm.transposeInPlace();"); + + // pivot vector + printer->fmt_line("Eigen::Matrix<int, {}, 1> pivot;", N); + printer->fmt_line("Eigen::Matrix<{0}, {1}, 1> rowmax;", float_type, N); + + // In-place LU-Decomposition (Crout Algo) : Jm is replaced by its LU-decomposition + printer->fmt_line( + "if (nmodl::crout::Crout<{0}>({1}, nmodl_eigen_jm.data(), pivot.data(), rowmax.data()) " + "< 0) assert(false && \"Singular or ill-conditioned matrix (nmodl::crout)!\");", + float_type, + N); + + // Solve the linear system : Forward/Backward substitution part + printer->fmt_line( + "nmodl::crout::solveCrout<{0}>({1}, nmodl_eigen_jm.data(), nmodl_eigen_fm.data(), " + "nmodl_eigen_xm.data(), pivot.data());", + float_type, + N); + } +} + + +/****************************************************************************************/ +/* Code-specific helper routines */ +/****************************************************************************************/ + + +std::string CodegenCoreneuronCppVisitor::internal_method_arguments() { + if (ion_variable_struct_required()) { + return "id, pnodecount, inst, ionvar, data, indexes, thread, nt, v"; + } + return "id, pnodecount, inst, data, indexes, thread, nt, v"; +} + + +/** + * @todo: figure out how to correctly handle qualifiers + */ +CodegenCoreneuronCppVisitor::ParamVector CodegenCoreneuronCppVisitor::internal_method_parameters() { + ParamVector params; + params.emplace_back("", "int", "", "id"); + params.emplace_back("", "int", "", "pnodecount"); + params.emplace_back("", fmt::format("{}*", instance_struct()), "", "inst"); + if (ion_variable_struct_required()) { + params.emplace_back("", "IonCurVar&", "", "ionvar"); + } + params.emplace_back("", "double*", "", "data"); + params.emplace_back("const ", "Datum*", "", "indexes"); + params.emplace_back("", "ThreadDatum*", "", "thread"); + params.emplace_back("", "NrnThread*", "", "nt"); + params.emplace_back("", "double", "", "v"); + return params; +}
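As a concrete rendering, feeding `internal_method_parameters()` through `get_parameter_str()` produces the usual internal signature. A standalone re-creation (plain strings stand in for the real types; `hh_Instance` is a hypothetical struct name):

```cpp
#include <fmt/format.h>
#include <iostream>
#include <string>
#include <tuple>
#include <vector>

// Re-creation of get_parameter_str(): tuple fields are {qualifier, type,
// name-qualifier, name}, matching the "{}{} {}{}" format used above.
using Param = std::tuple<std::string, std::string, std::string, std::string>;

int main() {
    const std::vector<Param> params{{"", "int", "", "id"},
                                    {"", "int", "", "pnodecount"},
                                    {"", "hh_Instance*", "", "inst"},
                                    {"", "double*", "", "data"},
                                    {"const ", "Datum*", "", "indexes"},
                                    {"", "ThreadDatum*", "", "thread"},
                                    {"", "NrnThread*", "", "nt"},
                                    {"", "double", "", "v"}};
    std::string str;
    for (const auto& p: params) {
        if (!str.empty()) {
            str += ", ";
        }
        str += fmt::format("{}{} {}{}", std::get<0>(p), std::get<1>(p),
                           std::get<2>(p), std::get<3>(p));
    }
    // prints: int id, int pnodecount, hh_Instance* inst, double* data,
    //         const Datum* indexes, ThreadDatum* thread, NrnThread* nt, double v
    std::cout << str << '\n';
    return 0;
}
```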
+ + +const char* CodegenCoreneuronCppVisitor::external_method_arguments() noexcept { + return "id, pnodecount, data, indexes, thread, nt, ml, v"; +} + + +const char* CodegenCoreneuronCppVisitor::external_method_parameters(bool table) noexcept { + if (table) { + return "int id, int pnodecount, double* data, Datum* indexes, " + "ThreadDatum* thread, NrnThread* nt, Memb_list* ml, int tml_id"; + } + return "int id, int pnodecount, double* data, Datum* indexes, " + "ThreadDatum* thread, NrnThread* nt, Memb_list* ml, double v"; +} + + +std::string CodegenCoreneuronCppVisitor::nrn_thread_arguments() const { + if (ion_variable_struct_required()) { + return "id, pnodecount, ionvar, data, indexes, thread, nt, ml, v"; + } + return "id, pnodecount, data, indexes, thread, nt, ml, v"; +} + + +/** + * Function call arguments when function or procedure is defined in the + * same mod file itself + */ +std::string CodegenCoreneuronCppVisitor::nrn_thread_internal_arguments() { + if (ion_variable_struct_required()) { + return "id, pnodecount, inst, ionvar, data, indexes, thread, nt, v"; + } + return "id, pnodecount, inst, data, indexes, thread, nt, v"; +} + + +/** + * Replace commonly used variables in the verbatim blocks with their corresponding + * variable names in the new code generation backend. + */ +std::string CodegenCoreneuronCppVisitor::replace_if_verbatim_variable(std::string name) { + if (naming::VERBATIM_VARIABLES_MAPPING.find(name) != naming::VERBATIM_VARIABLES_MAPPING.end()) { + name = naming::VERBATIM_VARIABLES_MAPPING.at(name); + } + + /** + * if the function is defined in the same mod file then the arguments must + * contain the mechanism instance as well. + */ + if (name == naming::THREAD_ARGS) { + if (internal_method_call_encountered) { + name = nrn_thread_internal_arguments(); + internal_method_call_encountered = false; + } else { + name = nrn_thread_arguments(); + } + } + if (name == naming::THREAD_ARGS_PROTO) { + name = external_method_parameters(); + } + return name; +} + + +/** + * Processing commonly used constructs in the verbatim blocks. + * @todo : this is still ad-hoc and requires re-implementation to + * handle it more elegantly. + */ +std::string CodegenCoreneuronCppVisitor::process_verbatim_text(std::string const& text) { + parser::CDriver driver; + driver.scan_string(text); + auto tokens = driver.all_tokens(); + std::string result; + for (size_t i = 0; i < tokens.size(); i++) { + auto token = tokens[i]; + + // check if we have function call in the verbatim block where + // function is defined in the same mod file + if (program_symtab->is_method_defined(token) && tokens[i + 1] == "(") { + internal_method_call_encountered = true; + } + auto name = process_verbatim_token(token); + + if (token == (std::string("_") + naming::TQITEM_VARIABLE)) { + name.insert(0, 1, '&'); + } + if (token == "_STRIDE") { + name = "pnodecount+id"; + } + result += name; + } + return result; +} + + +std::string CodegenCoreneuronCppVisitor::register_mechanism_arguments() const { + auto nrn_cur = nrn_cur_required() ? method_name(naming::NRN_CUR_METHOD) : "nullptr"; + auto nrn_state = nrn_state_required() ? 
method_name(naming::NRN_STATE_METHOD) : "nullptr"; + auto nrn_alloc = method_name(naming::NRN_ALLOC_METHOD); + auto nrn_init = method_name(naming::NRN_INIT_METHOD); + auto const nrn_private_constructor = method_name(naming::NRN_PRIVATE_CONSTRUCTOR_METHOD); + auto const nrn_private_destructor = method_name(naming::NRN_PRIVATE_DESTRUCTOR_METHOD); + return fmt::format("mechanism, {}, {}, nullptr, {}, {}, {}, {}, first_pointer_var_index()", + nrn_alloc, + nrn_cur, + nrn_state, + nrn_init, + nrn_private_constructor, + nrn_private_destructor); +} + + +std::pair<std::string, std::string> CodegenCoreneuronCppVisitor::read_ion_variable_name( + const std::string& name) { + return {name, naming::ION_VARNAME_PREFIX + name}; +} + + +std::pair<std::string, std::string> CodegenCoreneuronCppVisitor::write_ion_variable_name( + const std::string& name) { + return {naming::ION_VARNAME_PREFIX + name, name}; +} + + +std::string CodegenCoreneuronCppVisitor::conc_write_statement(const std::string& ion_name, + const std::string& concentration, + int index) { + auto conc_var_name = get_variable_name(naming::ION_VARNAME_PREFIX + concentration); + auto style_var_name = get_variable_name("style_" + ion_name); + return fmt::format( + "nrn_wrote_conc({}_type," + " &({})," + " {}," + " {}," + " nrn_ion_global_map," + " {}," + " nt->_ml_list[{}_type]->_nodecount_padded)", + ion_name, + conc_var_name, + index, + style_var_name, + get_variable_name(naming::CELSIUS_VARIABLE), + ion_name); +} + + +/** + * If mechanisms dependency level execution is enabled then certain updates + * like ionic current contributions need to be atomically updated. In this + * case we first update the current mechanism's shadow vector and then add a + * statement to the queue that will be used in the reduction. + */ +std::string CodegenCoreneuronCppVisitor::process_shadow_update_statement( + const ShadowUseStatement& statement, + BlockType /* type */) { + // when there is no operator or rhs then that statement doesn't need shadow update + if (statement.op.empty() && statement.rhs.empty()) { + auto text = statement.lhs + ";"; + return text; + } + + // return regular statement + auto lhs = get_variable_name(statement.lhs); + auto text = fmt::format("{} {} {};", lhs, statement.op, statement.rhs); + return text; +} + + +/****************************************************************************************/ +/* Code-specific printing routines for code generation */ +/****************************************************************************************/ + + +void CodegenCoreneuronCppVisitor::print_first_pointer_var_index_getter() { + printer->add_newline(2); + print_device_method_annotation(); + printer->push_block("static inline int first_pointer_var_index()"); + printer->fmt_line("return {};", info.first_pointer_var_index); + printer->pop_block(); +} + + +void CodegenCoreneuronCppVisitor::print_num_variable_getter() { + printer->add_newline(2); + print_device_method_annotation(); + printer->push_block("static inline int float_variables_size()"); + printer->fmt_line("return {};", float_variables_size()); + printer->pop_block(); + + printer->add_newline(2); + print_device_method_annotation(); + printer->push_block("static inline int int_variables_size()"); + printer->fmt_line("return {};", int_variables_size()); + printer->pop_block(); +}
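For a hypothetical mechanism with suffix `hh`, `register_mechanism_arguments()` above renders roughly as follows (illustrative only; method names come from `method_name()` and the `naming::` constants, and the lone `nullptr` matches the fixed slot in the format string):

```cpp
// mechanism, nrn_alloc_hh, nrn_cur_hh, nullptr, nrn_state_hh, nrn_init_hh,
// nrn_private_constructor_hh, nrn_private_destructor_hh,
// first_pointer_var_index()
```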
+ + +void CodegenCoreneuronCppVisitor::print_net_receive_arg_size_getter() { + if (!net_receive_exist()) { + return; + } + printer->add_newline(2); + print_device_method_annotation(); + printer->push_block("static inline int num_net_receive_args()"); + printer->fmt_line("return {};", info.num_net_receive_parameters); + printer->pop_block(); +} + + +void CodegenCoreneuronCppVisitor::print_mech_type_getter() { + printer->add_newline(2); + print_device_method_annotation(); + printer->push_block("static inline int get_mech_type()"); + // false => get it from the host-only global struct, not the instance structure + printer->fmt_line("return {};", get_variable_name("mech_type", false)); + printer->pop_block(); +} + + +void CodegenCoreneuronCppVisitor::print_memb_list_getter() { + printer->add_newline(2); + print_device_method_annotation(); + printer->push_block("static inline Memb_list* get_memb_list(NrnThread* nt)"); + printer->push_block("if (!nt->_ml_list)"); + printer->add_line("return nullptr;"); + printer->pop_block(); + printer->add_line("return nt->_ml_list[get_mech_type()];"); + printer->pop_block(); +} + + +void CodegenCoreneuronCppVisitor::print_namespace_start() { + printer->add_newline(2); + printer->push_block("namespace coreneuron"); +} + + +void CodegenCoreneuronCppVisitor::print_namespace_stop() { + printer->pop_block(); +} + + +/** + * \details There are three types of thread variables currently considered: + * - top local thread variables + * - thread variables in the mod file + * - thread variables for solver + * + * These variables are allocated into different thread structures and have + * corresponding thread ids. Thread ids start from 0. In the mod2c implementation, + * thread_data_index is increased at various places and it is used to + * decide the index of the thread. + */ +void CodegenCoreneuronCppVisitor::print_thread_getters() { + if (info.vectorize && info.derivimplicit_used()) { + int tid = info.derivimplicit_var_thread_id; + int list = info.derivimplicit_list_num; + + // clang-format off + printer->add_newline(2); + printer->add_line("/** thread specific helper routines for derivimplicit */"); + + printer->add_newline(1); + printer->fmt_push_block("static inline int* deriv{}_advance(ThreadDatum* thread)", list); + printer->fmt_line("return &(thread[{}].i);", tid); + printer->pop_block(); + printer->add_newline(); + + printer->fmt_push_block("static inline int dith{}()", list); + printer->fmt_line("return {};", tid+1); + printer->pop_block(); + printer->add_newline(); + + printer->fmt_push_block("static inline void** newtonspace{}(ThreadDatum* thread)", list); + printer->fmt_line("return &(thread[{}]._pvoid);", tid+2); + printer->pop_block(); + } + + if (info.vectorize && !info.thread_variables.empty()) { + printer->add_newline(2); + printer->add_line("/** tid for thread variables */"); + printer->push_block("static inline int thread_var_tid()"); + printer->fmt_line("return {};", info.thread_var_thread_id); + printer->pop_block(); + } + + if (info.vectorize && !info.top_local_variables.empty()) { + printer->add_newline(2); + printer->add_line("/** tid for top local thread variables */"); + printer->push_block("static inline int top_local_var_tid()"); + printer->fmt_line("return {};", info.top_local_thread_id); + printer->pop_block(); + } + // clang-format on +} + + +/****************************************************************************************/ +/* Routines for returning variable name */ +/****************************************************************************************/ + + +std::string CodegenCoreneuronCppVisitor::float_variable_name(const SymbolType& symbol, + bool use_instance) const { + auto name = symbol->get_name(); + auto dimension = symbol->get_length(); + auto position = 
position_of_float_var(name); + // clang-format off + if (symbol->is_array()) { + if (use_instance) { + return fmt::format("(inst->{}+id*{})", name, dimension); + } + return fmt::format("(data + {}*pnodecount + id*{})", position, dimension); + } + if (use_instance) { + return fmt::format("inst->{}[id]", name); + } + return fmt::format("data[{}*pnodecount + id]", position); + // clang-format on +} + + +std::string CodegenCoreneuronCppVisitor::int_variable_name(const IndexVariableInfo& symbol, + const std::string& name, + bool use_instance) const { + auto position = position_of_int_var(name); + // clang-format off + if (symbol.is_index) { + if (use_instance) { + return fmt::format("inst->{}[{}]", name, position); + } + return fmt::format("indexes[{}]", position); + } + if (symbol.is_integer) { + if (use_instance) { + return fmt::format("inst->{}[{}*pnodecount+id]", name, position); + } + return fmt::format("indexes[{}*pnodecount+id]", position); + } + if (use_instance) { + return fmt::format("inst->{}[indexes[{}*pnodecount + id]]", name, position); + } + auto data = symbol.is_vdata ? "_vdata" : "_data"; + return fmt::format("nt->{}[indexes[{}*pnodecount + id]]", data, position); + // clang-format on +} + + +std::string CodegenCoreneuronCppVisitor::global_variable_name(const SymbolType& symbol, + bool use_instance) const { + if (use_instance) { + return fmt::format("inst->{}->{}", naming::INST_GLOBAL_MEMBER, symbol->get_name()); + } else { + return fmt::format("{}.{}", global_struct_instance(), symbol->get_name()); + } +} + + +std::string CodegenCoreneuronCppVisitor::update_if_ion_variable_name( + const std::string& name) const { + std::string result(name); + if (ion_variable_struct_required()) { + if (info.is_ion_read_variable(name)) { + result = naming::ION_VARNAME_PREFIX + name; + } + if (info.is_ion_write_variable(name)) { + result = "ionvar." + name; + } + if (info.is_current(name)) { + result = "ionvar." 
+ name; + } + } + return result; +} + + +std::string CodegenCoreneuronCppVisitor::get_variable_name(const std::string& name, + bool use_instance) const { + const std::string& varname = update_if_ion_variable_name(name); + + // clang-format off + auto symbol_comparator = [&varname](const SymbolType& sym) { + return varname == sym->get_name(); + }; + + auto index_comparator = [&varname](const IndexVariableInfo& var) { + return varname == var.symbol->get_name(); + }; + // clang-format on + + // float variable + auto f = std::find_if(codegen_float_variables.begin(), + codegen_float_variables.end(), + symbol_comparator); + if (f != codegen_float_variables.end()) { + return float_variable_name(*f, use_instance); + } + + // integer variable + auto i = + std::find_if(codegen_int_variables.begin(), codegen_int_variables.end(), index_comparator); + if (i != codegen_int_variables.end()) { + return int_variable_name(*i, varname, use_instance); + } + + // global variable + auto g = std::find_if(codegen_global_variables.begin(), + codegen_global_variables.end(), + symbol_comparator); + if (g != codegen_global_variables.end()) { + return global_variable_name(*g, use_instance); + } + + if (varname == naming::NTHREAD_DT_VARIABLE) { + return std::string("nt->_") + naming::NTHREAD_DT_VARIABLE; + } + + // t in net_receive method is an argument to function and hence it should + // be used instead of nt->_t which is current time of thread + if (varname == naming::NTHREAD_T_VARIABLE && !printing_net_receive) { + return std::string("nt->_") + naming::NTHREAD_T_VARIABLE; + } + + auto const iter = + std::find_if(info.neuron_global_variables.begin(), + info.neuron_global_variables.end(), + [&varname](auto const& entry) { return entry.first->get_name() == varname; }); + if (iter != info.neuron_global_variables.end()) { + std::string ret; + if (use_instance) { + ret = "*(inst->"; + } + ret.append(varname); + if (use_instance) { + ret.append(")"); + } + return ret; + } + + // otherwise return original name + return varname; +} + + +/****************************************************************************************/ +/* Main printing routines for code generation */ +/****************************************************************************************/ + + +void CodegenCoreneuronCppVisitor::print_backend_info() { + time_t current_time{}; + time(&current_time); + std::string data_time_str{std::ctime(&current_time)}; + auto version = nmodl::Version::NMODL_VERSION + " [" + nmodl::Version::GIT_REVISION + "]"; + + printer->add_line("/*********************************************************"); + printer->add_line("Model Name : ", info.mod_suffix); + printer->add_line("Filename : ", info.mod_file, ".mod"); + printer->add_line("NMODL Version : ", nmodl_version()); + printer->fmt_line("Vectorized : {}", info.vectorize); + printer->fmt_line("Threadsafe : {}", info.thread_safe); + printer->add_line("Created : ", stringutils::trim(data_time_str)); + printer->add_line("Simulator : ", simulator_name()); + printer->add_line("Backend : ", backend_name()); + printer->add_line("NMODL Compiler : ", version); + printer->add_line("*********************************************************/"); +} + + +void CodegenCoreneuronCppVisitor::print_standard_includes() { + printer->add_newline(); + printer->add_multi_line(R"CODE( + #include <math.h> + #include <stdio.h> + #include <stdlib.h> + #include <string.h> + )CODE"); +}
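Stepping back, `get_variable_name()` above resolves a name by probing, in order: float variables, int variables, globals, `dt`/`t`, and NEURON-level globals, finally falling back to the name itself. A hedged sketch of typical resolutions (variable and struct names are hypothetical):

```cpp
// get_variable_name("m")       -> "inst->m[id]"          (float variable)
// get_variable_name("mech_type", /*use_instance=*/false)
//                              -> "<global instance>.mech_type" (global struct)
// get_variable_name("dt")      -> "nt->_dt"              (NrnThread field)
// get_variable_name("celsius") -> "*(inst->celsius)"     (NEURON global)
// get_variable_name("foo")     -> "foo"                  (fallback)
```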
+ + +void CodegenCoreneuronCppVisitor::print_coreneuron_includes() { + printer->add_newline(); + printer->add_multi_line(R"CODE( + #include <coreneuron/gpu/nrn_acc_manager.hpp> + #include <coreneuron/mechanism/mech/mod2c_core_thread.hpp> + #include <coreneuron/mechanism/register_mech.hpp> + #include <coreneuron/nrnconf.h> + #include <coreneuron/nrniv/nrniv_decl.h> + #include <coreneuron/sim/multicore.hpp> + #include <coreneuron/sim/scopmath/newton_thread.hpp> + #include <coreneuron/utils/ivocvect.hpp> + #include <coreneuron/utils/nrnoc_aux.hpp> + #include <coreneuron/utils/randoms/nrnran123.h> + )CODE"); + if (info.eigen_newton_solver_exist) { + printer->add_line("#include <newton/newton.hpp>"); + } + if (info.eigen_linear_solver_exist) { + if (std::accumulate(info.state_vars.begin(), + info.state_vars.end(), + 0, + [](int l, const SymbolType& variable) { + return l += variable->get_length(); + }) > 4) { + printer->add_line("#include <crout/crout.hpp>"); + } else { + printer->add_line("#include <Eigen/Dense>"); + printer->add_line("#include <Eigen/LU>"); + } + } +} + + +void CodegenCoreneuronCppVisitor::print_sdlists_init(bool print_initializers) { + if (info.primes_size == 0) { + return; + } + const auto count_prime_variables = [](auto size, const SymbolType& symbol) { + return size += symbol->get_length(); + }; + const auto prime_variables_by_order_size = + std::accumulate(info.prime_variables_by_order.begin(), + info.prime_variables_by_order.end(), + 0, + count_prime_variables); + if (info.primes_size != prime_variables_by_order_size) { + throw std::runtime_error{ + fmt::format("primes_size = {} differs from prime_variables_by_order.size() = {}, " + "this should not happen.", + info.primes_size, + info.prime_variables_by_order.size())}; + } + auto const initializer_list = [&](auto const& primes, const char* prefix) -> std::string { + if (!print_initializers) { + return {}; + } + std::string list{"{"}; + for (auto iter = primes.begin(); iter != primes.end(); ++iter) { + auto const& prime = *iter; + list.append(std::to_string(position_of_float_var(prefix + prime->get_name()))); + if (std::next(iter) != primes.end()) { + list.append(", "); + } + } + list.append("}"); + return list; + }; + printer->fmt_line("int slist1[{}]{};", + info.primes_size, + initializer_list(info.prime_variables_by_order, "")); + printer->fmt_line("int dlist1[{}]{};", + info.primes_size, + initializer_list(info.prime_variables_by_order, "D")); + codegen_global_variables.push_back(make_symbol("slist1")); + codegen_global_variables.push_back(make_symbol("dlist1")); + // additional list for derivimplicit method + if (info.derivimplicit_used()) { + auto primes = program_symtab->get_variables_with_properties(NmodlType::prime_name); + printer->fmt_line("int slist2[{}]{};", info.primes_size, initializer_list(primes, "")); + codegen_global_variables.push_back(make_symbol("slist2")); + } +} + + +/** + * \details Variables required for type of ion, type of point process etc. are + * of static int type. For the C++ backend type, it's ok to have + * these variables as file scoped static variables. + * + * Initial values of state variables (h0) are also defined as static + * variables. Note that the state could be an ion variable and it could + * also be a range variable. Hence lookup into symbol table before. + * + * When the model is not vectorized (shouldn't be the case in coreneuron) + * the top local variables become static variables. + * + * Note that static variables are already initialized to 0. We do the + * same for some variables to keep the same code as neuron. + */ +// NOLINTNEXTLINE(readability-function-cognitive-complexity) +void CodegenCoreneuronCppVisitor::print_mechanism_global_var_structure(bool print_initializers) { + const auto value_initialize = print_initializers ? "{}" : ""; + + auto float_type = default_float_data_type(); + printer->add_newline(2); + printer->add_line("/** all global variables */"); + printer->fmt_push_block("struct {}", global_struct()); + + for (const auto& ion: info.ions) { + auto name = fmt::format("{}_type", ion.name); + printer->fmt_line("int {}{};", name, value_initialize); + codegen_global_variables.push_back(make_symbol(name)); + } + + if (info.point_process) { + printer->fmt_line("int point_type{};", value_initialize); + codegen_global_variables.push_back(make_symbol("point_type")); + } + + for (const auto& var: info.state_vars) { + auto name = var->get_name() + "0"; + auto symbol = program_symtab->lookup(name); + if (symbol == nullptr) { + printer->fmt_line("{} {}{};", float_type, name, value_initialize); + codegen_global_variables.push_back(make_symbol(name)); + } + } + + // Neuron and CoreNEURON add "v" to global variables when vectorize + // is false. But as v is always a local variable and passed as an argument, + // we don't need to use the global variable v + + auto& top_locals = info.top_local_variables; + if (!info.vectorize && !top_locals.empty()) { + for (const auto& var: top_locals) { + auto name = var->get_name(); + auto length = var->get_length(); + if (var->is_array()) { + printer->fmt_line("{} {}[{}] /* TODO init top-local-array */;", + float_type, + name, + length); + } else { + printer->fmt_line("{} {} /* TODO init top-local */;", float_type, name); + } + codegen_global_variables.push_back(var); + } + } + + if (!info.thread_variables.empty()) { + printer->fmt_line("int thread_data_in_use{};", value_initialize); + printer->fmt_line("{} thread_data[{}] /* TODO init thread_data */;", + float_type, + info.thread_var_data_size); + codegen_global_variables.push_back(make_symbol("thread_data_in_use")); + auto symbol = make_symbol("thread_data"); + symbol->set_as_array(info.thread_var_data_size); + codegen_global_variables.push_back(symbol); + } + + // TODO: remove this entirely? + printer->fmt_line("int reset{};", value_initialize); + codegen_global_variables.push_back(make_symbol("reset")); + + printer->fmt_line("int mech_type{};", value_initialize); + codegen_global_variables.push_back(make_symbol("mech_type")); + + for (const auto& var: info.global_variables) { + auto name = var->get_name(); + auto length = var->get_length(); + if (var->is_array()) { + printer->fmt_line("{} {}[{}] /* TODO init const-array */;", float_type, name, length); + } else { + double value{}; + if (auto const& value_ptr = var->get_value()) { + value = *value_ptr; + } + printer->fmt_line("{} {}{};", + float_type, + name, + print_initializers ? fmt::format("{{{:g}}}", value) : std::string{}); + } + codegen_global_variables.push_back(var); + } + + for (const auto& var: info.constant_variables) { + auto const name = var->get_name(); + auto* const value_ptr = var->get_value().get(); + double const value{value_ptr ? *value_ptr : 0}; + printer->fmt_line("{} {}{};", + float_type, + name, + print_initializers ? fmt::format("{{{:g}}}", value) : std::string{}); + codegen_global_variables.push_back(var); + } + + print_sdlists_init(print_initializers); + + if (info.table_count > 0) { + printer->fmt_line("double usetable{};", print_initializers ? 
"{1}" : ""); + codegen_global_variables.push_back(make_symbol(naming::USE_TABLE_VARIABLE)); + + for (const auto& block: info.functions_with_table) { + const auto& name = block->get_node_name(); + printer->fmt_line("{} tmin_{}{};", float_type, name, value_initialize); + printer->fmt_line("{} mfac_{}{};", float_type, name, value_initialize); + codegen_global_variables.push_back(make_symbol("tmin_" + name)); + codegen_global_variables.push_back(make_symbol("mfac_" + name)); + } + + for (const auto& variable: info.table_statement_variables) { + auto const name = "t_" + variable->get_name(); + auto const num_values = variable->get_num_values(); + if (variable->is_array()) { + int array_len = variable->get_length(); + printer->fmt_line( + "{} {}[{}][{}]{};", float_type, name, array_len, num_values, value_initialize); + } else { + printer->fmt_line("{} {}[{}]{};", float_type, name, num_values, value_initialize); + } + codegen_global_variables.push_back(make_symbol(name)); + } + } + + for (const auto& f: info.function_tables) { + printer->fmt_line("void* _ptable_{}{{}};", f->get_node_name()); + codegen_global_variables.push_back(make_symbol("_ptable_" + f->get_node_name())); + } + + if (info.vectorize && info.thread_data_index) { + printer->fmt_line("ThreadDatum ext_call_thread[{}]{};", + info.thread_data_index, + value_initialize); + codegen_global_variables.push_back(make_symbol("ext_call_thread")); + } + + printer->pop_block(";"); + + print_global_var_struct_assertions(); + print_global_var_struct_decl(); +} + + +void CodegenCoreneuronCppVisitor::print_global_var_struct_assertions() const { + // Assert some things that we assume when copying instances of this struct + // to the GPU and so on. + printer->fmt_line("static_assert(std::is_trivially_copy_constructible_v<{}>);", + global_struct()); + printer->fmt_line("static_assert(std::is_trivially_move_constructible_v<{}>);", + global_struct()); + printer->fmt_line("static_assert(std::is_trivially_copy_assignable_v<{}>);", global_struct()); + printer->fmt_line("static_assert(std::is_trivially_move_assignable_v<{}>);", global_struct()); + printer->fmt_line("static_assert(std::is_trivially_destructible_v<{}>);", global_struct()); +} + + +/** + * Print structs that encapsulate information about scalar and + * vector elements of type global and thread variables. 
+ */
+void CodegenCoreneuronCppVisitor::print_global_variables_for_hoc() {
+    auto variable_printer =
+        [&](const std::vector<SymbolType>& variables, bool if_array, bool if_vector) {
+            for (const auto& variable: variables) {
+                if (variable->is_array() == if_array) {
+                    // false => do not use the instance struct, which is not
+                    // defined in the global declaration that we are printing
+                    auto name = get_variable_name(variable->get_name(), false);
+                    auto ename = add_escape_quote(variable->get_name() + "_" + info.mod_suffix);
+                    auto length = variable->get_length();
+                    if (if_vector) {
+                        printer->fmt_line("{{{}, {}, {}}},", ename, name, length);
+                    } else {
+                        printer->fmt_line("{{{}, &{}}},", ename, name);
+                    }
+                }
+            }
+        };
+
+    auto globals = info.global_variables;
+    auto thread_vars = info.thread_variables;
+
+    if (info.table_count > 0) {
+        globals.push_back(make_symbol(naming::USE_TABLE_VARIABLE));
+    }
+
+    printer->add_newline(2);
+    printer->add_line("/** connect global (scalar) variables to hoc -- */");
+    printer->add_line("static DoubScal hoc_scalar_double[] = {");
+    printer->increase_indent();
+    variable_printer(globals, false, false);
+    variable_printer(thread_vars, false, false);
+    printer->add_line("{nullptr, nullptr}");
+    printer->decrease_indent();
+    printer->add_line("};");
+
+    printer->add_newline(2);
+    printer->add_line("/** connect global (array) variables to hoc -- */");
+    printer->add_line("static DoubVec hoc_vector_double[] = {");
+    printer->increase_indent();
+    variable_printer(globals, true, true);
+    variable_printer(thread_vars, true, true);
+    printer->add_line("{nullptr, nullptr, 0}");
+    printer->decrease_indent();
+    printer->add_line("};");
+}
+
+
+/**
+ * Return registration type for a given BEFORE/AFTER block
+ * \param block A BEFORE/AFTER block being registered
+ *
+ * Depending on the block type, i.e. BEFORE or AFTER, and the type
+ * of its associated block, i.e. BREAKPOINT, INITIAL, SOLVE or
+ * STEP, the registration type (as an integer) is calculated.
+ * These values are then interpreted by CoreNEURON internally.
+ */
+static std::string get_register_type_for_ba_block(const ast::Block* block) {
+    std::string register_type{};
+    BAType ba_type{};
+    /// BEFORE blocks have value 10 and AFTER blocks 20
+    if (block->is_before_block()) {
+        // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers)
+        register_type = "BAType::Before";
+        ba_type =
+            dynamic_cast<const BeforeBlock*>(block)->get_bablock()->get_type()->get_value();
+    } else {
+        // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers)
+        register_type = "BAType::After";
+        ba_type =
+            dynamic_cast<const AfterBlock*>(block)->get_bablock()->get_type()->get_value();
+    }
+
+    /// associated blocks have different values (1 to 4) based on type.
+    /// These values are based on neuron/coreneuron implementation details.
+    if (ba_type == BATYPE_BREAKPOINT) {
+        register_type += " + BAType::Breakpoint";
+    } else if (ba_type == BATYPE_SOLVE) {
+        register_type += " + BAType::Solve";
+    } else if (ba_type == BATYPE_INITIAL) {
+        register_type += " + BAType::Initial";
+    } else if (ba_type == BATYPE_STEP) {
+        register_type += " + BAType::Step";
+    } else {
+        throw std::runtime_error("Unhandled Before/After type encountered during code generation");
+    }
+    return register_type;
+}
+
+
+/**
+ * \details Every mod file has a register function to connect with the simulator.
+ * Various information about mechanism and callbacks get registered with
+ * the simulator using the suffix_reg() function.
+ *
+ * Here are details:
+ * - Some callbacks are excluded, depending on the solver used and the
+ *   WATCH statements present.
+ * - If nrn_get_mechtype returns a value < -1, the mechanism is not used in the
+ *   context of neuron execution and hence can be ignored in coreneuron
+ *   execution.
+ * - Ions are internally defined and their types can be queried similar to
+ *   other mechanisms.
+ * - hoc_register_var may not be needed in the context of coreneuron
+ * - We assume net receive buffer is on. This is because generated code is
+ *   compatible for cpu as well as gpu target.
+ */
+// NOLINTNEXTLINE(readability-function-cognitive-complexity)
+void CodegenCoreneuronCppVisitor::print_mechanism_register() {
+    printer->add_newline(2);
+    printer->add_line("/** register channel with the simulator */");
+    printer->fmt_push_block("void _{}_reg()", info.mod_file);
+
+    // type related information
+    auto suffix = add_escape_quote(info.mod_suffix);
+    printer->add_newline();
+    printer->fmt_line("int mech_type = nrn_get_mechtype({});", suffix);
+    printer->fmt_line("{} = mech_type;", get_variable_name("mech_type", false));
+    printer->push_block("if (mech_type == -1)");
+    printer->add_line("return;");
+    printer->pop_block();
+
+    printer->add_newline();
+    printer->add_line("_nrn_layout_reg(mech_type, 0);");  // 0 for SoA
+
+    // register mechanism
+    const auto mech_arguments = register_mechanism_arguments();
+    const auto number_of_thread_objects = num_thread_objects();
+    if (info.point_process) {
+        printer->fmt_line("point_register_mech({}, {}, {}, {});",
+                          mech_arguments,
+                          info.constructor_node ? method_name(naming::NRN_CONSTRUCTOR_METHOD)
+                                                : "nullptr",
+                          info.destructor_node ? method_name(naming::NRN_DESTRUCTOR_METHOD)
+                                               : "nullptr",
+                          number_of_thread_objects);
+    } else {
+        printer->fmt_line("register_mech({}, {});", mech_arguments, number_of_thread_objects);
+        if (info.constructor_node) {
+            printer->fmt_line("register_constructor({});",
+                              method_name(naming::NRN_CONSTRUCTOR_METHOD));
+        }
+    }
+
+    // types for ion
+    for (const auto& ion: info.ions) {
+        printer->fmt_line("{} = nrn_get_mechtype({});",
+                          get_variable_name(ion.name + "_type", false),
+                          add_escape_quote(ion.name + "_ion"));
+    }
+    printer->add_newline();
+
+    /*
+     * Register callbacks for thread allocation and cleanup. Note that thread_data_index
+     * represents the total number of threads used minus 1 (i.e. the index of the last thread).
+ */ + if (info.vectorize && (info.thread_data_index != 0)) { + // false to avoid getting the copy from the instance structure + printer->fmt_line("thread_mem_init({});", get_variable_name("ext_call_thread", false)); + } + + if (!info.thread_variables.empty()) { + printer->fmt_line("{} = 0;", get_variable_name("thread_data_in_use")); + } + + if (info.thread_callback_register) { + printer->add_line("_nrn_thread_reg0(mech_type, thread_mem_cleanup);"); + printer->add_line("_nrn_thread_reg1(mech_type, thread_mem_init);"); + } + + if (info.emit_table_thread()) { + auto name = method_name("check_table_thread"); + printer->fmt_line("_nrn_thread_table_reg(mech_type, {});", name); + } + + // register read/write callbacks for pointers + if (info.bbcore_pointer_used) { + printer->add_line("hoc_reg_bbcore_read(mech_type, bbcore_read);"); + printer->add_line("hoc_reg_bbcore_write(mech_type, bbcore_write);"); + } + + // register size of double and int elements + // clang-format off + printer->add_line("hoc_register_prop_size(mech_type, float_variables_size(), int_variables_size());"); + // clang-format on + + // register semantics for index variables + for (auto& semantic: info.semantics) { + auto args = + fmt::format("mech_type, {}, {}", semantic.index, add_escape_quote(semantic.name)); + printer->fmt_line("hoc_register_dparam_semantics({});", args); + } + + if (info.is_watch_used()) { + auto watch_fun = compute_method_name(BlockType::Watch); + printer->fmt_line("hoc_register_watch_check({}, mech_type);", watch_fun); + } + + if (info.write_concentration) { + printer->add_line("nrn_writes_conc(mech_type, 0);"); + } + + // register various information for point process type + if (info.net_event_used) { + printer->add_line("add_nrn_has_net_event(mech_type);"); + } + if (info.artificial_cell) { + printer->fmt_line("add_nrn_artcell(mech_type, {});", info.tqitem_index); + } + if (net_receive_buffering_required()) { + printer->fmt_line("hoc_register_net_receive_buffering({}, mech_type);", + method_name("net_buf_receive")); + } + if (info.num_net_receive_parameters != 0) { + auto net_recv_init_arg = "nullptr"; + if (info.net_receive_initial_node != nullptr) { + net_recv_init_arg = "net_init"; + } + printer->fmt_line("set_pnt_receive(mech_type, {}, {}, num_net_receive_args());", + method_name("net_receive"), + net_recv_init_arg); + } + if (info.for_netcon_used) { + // index where information about FOR_NETCON is stored in the integer array + const auto index = + std::find_if(info.semantics.begin(), info.semantics.end(), [](const IndexSemantics& a) { + return a.name == naming::FOR_NETCON_SEMANTIC; + })->index; + printer->fmt_line("add_nrn_fornetcons(mech_type, {});", index); + } + + if (info.net_event_used || info.net_send_used) { + printer->add_line("hoc_register_net_send_buffering(mech_type);"); + } + + /// register all before/after blocks + for (size_t i = 0; i < info.before_after_blocks.size(); i++) { + // register type and associated function name for the block + const auto& block = info.before_after_blocks[i]; + std::string register_type = get_register_type_for_ba_block(block); + std::string function_name = method_name(fmt::format("nrn_before_after_{}", i)); + printer->fmt_line("hoc_reg_ba(mech_type, {}, {});", function_name, register_type); + } + + // register variables for hoc + printer->add_line("hoc_register_var(hoc_scalar_double, hoc_vector_double, NULL);"); + printer->pop_block(); +} + + +void CodegenCoreneuronCppVisitor::print_thread_memory_callbacks() { + if (!info.thread_callback_register) { + 
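// nothing to generate: this mod file registered no thread callbacks
+        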
return;
+    }
+
+    // thread_mem_init callback
+    printer->add_newline(2);
+    printer->add_line("/** thread memory allocation callback */");
+    printer->push_block("static void thread_mem_init(ThreadDatum* thread) ");
+
+    if (info.vectorize && info.derivimplicit_used()) {
+        printer->fmt_line("thread[dith{}()].pval = nullptr;", info.derivimplicit_list_num);
+    }
+    if (info.vectorize && (info.top_local_thread_size != 0)) {
+        auto length = info.top_local_thread_size;
+        auto allocation = fmt::format("(double*)mem_alloc({}, sizeof(double))", length);
+        printer->fmt_line("thread[top_local_var_tid()].pval = {};", allocation);
+    }
+    if (info.thread_var_data_size != 0) {
+        auto length = info.thread_var_data_size;
+        auto thread_data = get_variable_name("thread_data");
+        auto thread_data_in_use = get_variable_name("thread_data_in_use");
+        auto allocation = fmt::format("(double*)mem_alloc({}, sizeof(double))", length);
+        printer->fmt_push_block("if ({})", thread_data_in_use);
+        printer->fmt_line("thread[thread_var_tid()].pval = {};", allocation);
+        printer->chain_block("else");
+        printer->fmt_line("thread[thread_var_tid()].pval = {};", thread_data);
+        printer->fmt_line("{} = 1;", thread_data_in_use);
+        printer->pop_block();
+    }
+    printer->pop_block();
+    printer->add_newline(2);
+
+
+    // thread_mem_cleanup callback
+    printer->add_line("/** thread memory cleanup callback */");
+    printer->push_block("static void thread_mem_cleanup(ThreadDatum* thread) ");
+
+    // clang-format off
+    if (info.vectorize && info.derivimplicit_used()) {
+        int n = info.derivimplicit_list_num;
+        printer->fmt_line("free(thread[dith{}()].pval);", n);
+        printer->fmt_line("nrn_destroy_newtonspace(static_cast<NewtonSpace*>(*newtonspace{}(thread)));", n);
+    }
+    // clang-format on
+
+    if (info.top_local_thread_size != 0) {
+        auto line = "free(thread[top_local_var_tid()].pval);";
+        printer->add_line(line);
+    }
+    if (info.thread_var_data_size != 0) {
+        auto thread_data = get_variable_name("thread_data");
+        auto thread_data_in_use = get_variable_name("thread_data_in_use");
+        printer->fmt_push_block("if (thread[thread_var_tid()].pval == {})", thread_data);
+        printer->fmt_line("{} = 0;", thread_data_in_use);
+        printer->chain_block("else");
+        printer->add_line("free(thread[thread_var_tid()].pval);");
+        printer->pop_block();
+    }
+    printer->pop_block();
+}
+
+
+void CodegenCoreneuronCppVisitor::print_mechanism_range_var_structure(bool print_initializers) {
+    auto const value_initialize = print_initializers ? "{}" : "";
+    auto int_type = default_int_data_type();
+    printer->add_newline(2);
+    printer->add_line("/** all mechanism instance variables and global variables */");
+    printer->fmt_push_block("struct {} ", instance_struct());
+
+    for (auto const& [var, type]: info.neuron_global_variables) {
+        auto const name = var->get_name();
+        printer->fmt_line("{}* {}{};",
+                          type,
+                          name,
+                          print_initializers ? fmt::format("{{&coreneuron::{}}}", name)
+                                             : std::string{});
+    }
+    for (auto& var: codegen_float_variables) {
+        const auto& name = var->get_name();
+        auto type = get_range_var_float_type(var);
+        auto qualifier = is_constant_variable(name) ? "const " : "";
+        printer->fmt_line("{}{}* {}{};", qualifier, type, name, value_initialize);
+    }
+    for (auto& var: codegen_int_variables) {
+        const auto& name = var.symbol->get_name();
+        if (var.is_index || var.is_integer) {
+            auto qualifier = var.is_constant ? "const " : "";
+            printer->fmt_line("{}{}* {}{};", qualifier, int_type, name, value_initialize);
+        } else {
+            auto qualifier = var.is_constant ? "const " : "";
+            auto type = var.is_vdata ? "void*" : default_float_data_type();
+            printer->fmt_line("{}{}* {}{};", qualifier, type, name, value_initialize);
+        }
+    }
+
+    printer->fmt_line("{}* {}{};",
+                      global_struct(),
+                      naming::INST_GLOBAL_MEMBER,
+                      print_initializers ? fmt::format("{{&{}}}", global_struct_instance())
+                                         : std::string{});
+    printer->pop_block(";");
+}
+
+
+void CodegenCoreneuronCppVisitor::print_ion_var_structure() {
+    if (!ion_variable_struct_required()) {
+        return;
+    }
+    printer->add_newline(2);
+    printer->add_line("/** ion write variables */");
+    printer->push_block("struct IonCurVar");
+
+    std::string float_type = default_float_data_type();
+    std::vector<std::string> members;
+
+    for (auto& ion: info.ions) {
+        for (auto& var: ion.writes) {
+            printer->fmt_line("{} {};", float_type, var);
+            members.push_back(var);
+        }
+    }
+    for (auto& var: info.currents) {
+        if (!info.is_ion_variable(var)) {
+            printer->fmt_line("{} {};", float_type, var);
+            members.push_back(var);
+        }
+    }
+
+    print_ion_var_constructor(members);
+
+    printer->pop_block(";");
+}
+
+
+void CodegenCoreneuronCppVisitor::print_ion_var_constructor(
+    const std::vector<std::string>& members) {
+    // constructor
+    printer->add_newline();
+    printer->add_indent();
+    printer->add_text("IonCurVar() : ");
+    for (int i = 0; i < members.size(); i++) {
+        printer->fmt_text("{}(0)", members[i]);
+        if (i + 1 < members.size()) {
+            printer->add_text(", ");
+        }
+    }
+    printer->add_text(" {}");
+    printer->add_newline();
+}
+
+
+void CodegenCoreneuronCppVisitor::print_ion_variable() {
+    printer->add_line("IonCurVar ionvar;");
+}
+
+
+void CodegenCoreneuronCppVisitor::print_global_variable_device_update_annotation() {
+    // nothing for cpu
+}
+
+
+void CodegenCoreneuronCppVisitor::print_setup_range_variable() {
+    auto type = float_data_type();
+    printer->add_newline(2);
+    printer->add_line("/** allocate and setup array for range variable */");
+    printer->fmt_push_block("static inline {}* setup_range_variable(double* variable, int n)",
+                            type);
+    printer->fmt_line("{0}* data = ({0}*) mem_alloc(n, sizeof({0}));", type);
+    printer->push_block("for(size_t i = 0; i < n; i++)");
+    printer->add_line("data[i] = variable[i];");
+    printer->pop_block();
+    printer->add_line("return data;");
+    printer->pop_block();
+}
+
+
+/**
+ * \details If a floating point type like "float" is specified on the command line,
+ * we can't turn all variables into that type. This is because certain variables
+ * are pointers to internal variables (e.g. ions). Hence, we check if a given
+ * variable can be safely converted to the new type. If so, return the new type.
+ */
+std::string CodegenCoreneuronCppVisitor::get_range_var_float_type(const SymbolType& symbol) {
+    // clang-format off
+    auto with = NmodlType::read_ion_var
+                | NmodlType::write_ion_var
+                | NmodlType::pointer_var
+                | NmodlType::bbcore_pointer_var
+                | NmodlType::extern_neuron_variable;
+    // clang-format on
+    bool need_default_type = symbol->has_any_property(with);
+    if (need_default_type) {
+        return default_float_data_type();
+    }
+    return float_data_type();
+}
+
+
+void CodegenCoreneuronCppVisitor::print_instance_variable_setup() {
+    if (range_variable_setup_required()) {
+        print_setup_range_variable();
+    }
+
+    printer->add_newline();
+    printer->add_line("// Allocate instance structure");
+    printer->fmt_push_block("static void {}(NrnThread* nt, Memb_list* ml, int type)",
+                            method_name(naming::NRN_PRIVATE_CONSTRUCTOR_METHOD));
+    printer->add_line("assert(!ml->instance);");
+    printer->add_line("assert(!ml->global_variables);");
+    printer->add_line("assert(ml->global_variables_size == 0);");
+    printer->fmt_line("auto* const inst = new {}{{}};", instance_struct());
+    printer->fmt_line("assert(inst->{} == &{});",
+                      naming::INST_GLOBAL_MEMBER,
+                      global_struct_instance());
+    printer->add_line("ml->instance = inst;");
+    printer->fmt_line("ml->global_variables = inst->{};", naming::INST_GLOBAL_MEMBER);
+    printer->fmt_line("ml->global_variables_size = sizeof({});", global_struct());
+    printer->pop_block();
+    printer->add_newline();
+
+    auto const cast_inst_and_assert_validity = [&]() {
+        printer->fmt_line("auto* const inst = static_cast<{}*>(ml->instance);", instance_struct());
+        printer->add_line("assert(inst);");
+        printer->fmt_line("assert(inst->{});", naming::INST_GLOBAL_MEMBER);
+        printer->fmt_line("assert(inst->{} == &{});",
+                          naming::INST_GLOBAL_MEMBER,
+                          global_struct_instance());
+        printer->fmt_line("assert(inst->{} == ml->global_variables);", naming::INST_GLOBAL_MEMBER);
+        printer->fmt_line("assert(ml->global_variables_size == sizeof({}));", global_struct());
+    };
+
+    // Must come before print_instance_struct_copy_to_device and
+    // print_instance_struct_delete_from_device
+    print_instance_struct_transfer_routine_declarations();
+
+    printer->add_line("// Deallocate the instance structure");
+    printer->fmt_push_block("static void {}(NrnThread* nt, Memb_list* ml, int type)",
+                            method_name(naming::NRN_PRIVATE_DESTRUCTOR_METHOD));
+    cast_inst_and_assert_validity();
+    print_instance_struct_delete_from_device();
+    printer->add_multi_line(R"CODE(
+        delete inst;
+        ml->instance = nullptr;
+        ml->global_variables = nullptr;
+        ml->global_variables_size = 0;
+    )CODE");
+    printer->pop_block();
+    printer->add_newline();
+
+
+    printer->add_line("/** initialize mechanism instance variables */");
+    printer->push_block("static inline void setup_instance(NrnThread* nt, Memb_list* ml)");
+    cast_inst_and_assert_validity();
+
+    std::string stride;
+    printer->add_line("int pnodecount = ml->_nodecount_padded;");
+    stride = "*pnodecount";
+
+    printer->add_line("Datum* indexes = ml->pdata;");
+
+    auto const float_type = default_float_data_type();
+
+    int id = 0;
+    std::vector<std::string> ptr_members{naming::INST_GLOBAL_MEMBER};
+    for (auto const& [var, type]: info.neuron_global_variables) {
+        ptr_members.push_back(var->get_name());
+    }
+    ptr_members.reserve(ptr_members.size() + codegen_float_variables.size() +
+                        codegen_int_variables.size());
+    for (auto& var: codegen_float_variables) {
+        auto name = var->get_name();
+        auto range_var_type = get_range_var_float_type(var);
+        if (float_type == range_var_type) {
+            auto const 
variable = fmt::format("ml->data+{}{}", id, stride); + printer->fmt_line("inst->{} = {};", name, variable); + } else { + // TODO what MOD file exercises this? + printer->fmt_line("inst->{} = setup_range_variable(ml->data+{}{}, pnodecount);", + name, + id, + stride); + } + ptr_members.push_back(std::move(name)); + id += var->get_length(); + } + + for (auto& var: codegen_int_variables) { + auto name = var.symbol->get_name(); + auto const variable = [&var]() { + if (var.is_index || var.is_integer) { + return "ml->pdata"; + } else if (var.is_vdata) { + return "nt->_vdata"; + } else { + return "nt->_data"; + } + }(); + printer->fmt_line("inst->{} = {};", name, variable); + ptr_members.push_back(std::move(name)); + } + print_instance_struct_copy_to_device(); + printer->pop_block(); // setup_instance + printer->add_newline(); + + print_instance_struct_transfer_routines(ptr_members); +} + + +void CodegenCoreneuronCppVisitor::print_initial_block(const InitialBlock* node) { + if (info.artificial_cell) { + printer->add_line("double v = 0.0;"); + } else { + printer->add_line("int node_id = node_index[id];"); + printer->add_line("double v = voltage[node_id];"); + print_v_unused(); + } + + if (ion_variable_struct_required()) { + printer->add_line("IonCurVar ionvar;"); + } + + // read ion statements + auto read_statements = ion_read_statements(BlockType::Initial); + for (auto& statement: read_statements) { + printer->add_line(statement); + } + + // initialize state variables (excluding ion state) + for (auto& var: info.state_vars) { + auto name = var->get_name(); + if (!info.is_ionic_conc(name)) { + auto lhs = get_variable_name(name); + auto rhs = get_variable_name(name + "0"); + if (var->is_array()) { + for (int i = 0; i < var->get_length(); ++i) { + printer->fmt_line("{}[{}] = {};", lhs, i, rhs); + } + } else { + printer->fmt_line("{} = {};", lhs, rhs); + } + } + } + + // initial block + if (node != nullptr) { + const auto& block = node->get_statement_block(); + print_statement_block(*block, false, false); + } + + // write ion statements + auto write_statements = ion_write_statements(BlockType::Initial); + for (auto& statement: write_statements) { + auto text = process_shadow_update_statement(statement, BlockType::Initial); + printer->add_line(text); + } +} + + +void CodegenCoreneuronCppVisitor::print_global_function_common_code( + BlockType type, + const std::string& function_name) { + std::string method; + if (function_name.empty()) { + method = compute_method_name(type); + } else { + method = function_name; + } + auto args = "NrnThread* nt, Memb_list* ml, int type"; + + // watch statement function doesn't have type argument + if (type == BlockType::Watch) { + args = "NrnThread* nt, Memb_list* ml"; + } + + print_global_method_annotation(); + printer->fmt_push_block("void {}({})", method, args); + if (type != BlockType::Destructor && type != BlockType::Constructor) { + // We do not (currently) support DESTRUCTOR and CONSTRUCTOR blocks + // running anything on the GPU. 
+        print_kernel_data_present_annotation_block_begin();
+    } else {
+        /// TODO: Remove this when the code generation is properly done
+        /// Related to https://github.com/BlueBrain/nmodl/issues/692
+        printer->add_line("#ifndef CORENEURON_BUILD");
+    }
+    printer->add_multi_line(R"CODE(
+        int nodecount = ml->nodecount;
+        int pnodecount = ml->_nodecount_padded;
+        const int* node_index = ml->nodeindices;
+        double* data = ml->data;
+        const double* voltage = nt->_actual_v;
+    )CODE");
+
+    if (type == BlockType::Equation) {
+        printer->add_line("double* vec_rhs = nt->_actual_rhs;");
+        printer->add_line("double* vec_d = nt->_actual_d;");
+        print_rhs_d_shadow_variables();
+    }
+    printer->add_line("Datum* indexes = ml->pdata;");
+    printer->add_line("ThreadDatum* thread = ml->_thread;");
+
+    if (type == BlockType::Initial) {
+        printer->add_newline();
+        printer->add_line("setup_instance(nt, ml);");
+    }
+    printer->fmt_line("auto* const inst = static_cast<{}*>(ml->instance);", instance_struct());
+    printer->add_newline(1);
+}
+
+void CodegenCoreneuronCppVisitor::print_nrn_init(bool skip_init_check) {
+    codegen = true;
+    printer->add_newline(2);
+    printer->add_line("/** initialize channel */");
+
+    print_global_function_common_code(BlockType::Initial);
+    if (info.derivimplicit_used()) {
+        printer->add_newline();
+        int nequation = info.num_equations;
+        int list_num = info.derivimplicit_list_num;
+        // clang-format off
+        printer->fmt_line("int& deriv_advance_flag = *deriv{}_advance(thread);", list_num);
+        printer->add_line("deriv_advance_flag = 0;");
+        print_deriv_advance_flag_transfer_to_device();
+        printer->fmt_line("auto ns = newtonspace{}(thread);", list_num);
+        printer->fmt_line("auto& th = thread[dith{}()];", list_num);
+        printer->push_block("if (*ns == nullptr)");
+        printer->fmt_line("int vec_size = 2*{}*pnodecount*sizeof(double);", nequation);
+        printer->add_line("double* vec = makevector(vec_size);");
+        printer->add_line("th.pval = vec;");
+        printer->fmt_line("*ns = nrn_cons_newtonspace({}, pnodecount);", nequation);
+        print_newtonspace_transfer_to_device();
+        printer->pop_block();
+        // clang-format on
+    }
+
+    // update global variables as those might be updated via python/hoc API
+    // NOTE: CoreNEURON has enough information to do this on its own, which
+    // would be neater.
+    print_global_variable_device_update_annotation();
+
+    if (skip_init_check) {
+        printer->push_block("if (_nrn_skip_initmodel == 0)");
+    }
+
+    if (!info.changed_dt.empty()) {
+        printer->fmt_line("double _save_prev_dt = {};",
+                          get_variable_name(naming::NTHREAD_DT_VARIABLE));
+        printer->fmt_line("{} = {};",
+                          get_variable_name(naming::NTHREAD_DT_VARIABLE),
+                          info.changed_dt);
+        print_dt_update_to_device();
+    }
+
+    print_channel_iteration_block_parallel_hint(BlockType::Initial, info.initial_node);
+    printer->push_block("for (int id = 0; id < nodecount; id++)");
+
+    if (info.net_receive_node != nullptr) {
+        printer->fmt_line("{} = -1e20;", get_variable_name("tsave"));
+    }
+
+    print_initial_block(info.initial_node);
+    printer->pop_block();
+
+    if (!info.changed_dt.empty()) {
+        printer->fmt_line("{} = _save_prev_dt;", get_variable_name(naming::NTHREAD_DT_VARIABLE));
+        print_dt_update_to_device();
+    }
+
+    printer->pop_block();
+
+    if (info.derivimplicit_used()) {
+        printer->add_line("deriv_advance_flag = 1;");
+        print_deriv_advance_flag_transfer_to_device();
+    }
+
+    if (info.net_send_used && !info.artificial_cell) {
+        print_send_event_move();
+    }
+
+    print_kernel_data_present_annotation_block_end();
+    if (skip_init_check) {
+        printer->pop_block();
+    }
+    codegen = false;
+}
+
+void CodegenCoreneuronCppVisitor::print_before_after_block(const ast::Block* node,
+                                                           size_t block_id) {
+    codegen = true;
+
+    std::string ba_type;
+    std::shared_ptr<ast::BABlock> ba_block;
+
+    if (node->is_before_block()) {
+        ba_block = dynamic_cast<const BeforeBlock*>(node)->get_bablock();
+        ba_type = "BEFORE";
+    } else {
+        ba_block = dynamic_cast<const AfterBlock*>(node)->get_bablock();
+        ba_type = "AFTER";
+    }
+
+    std::string ba_block_type = ba_block->get_type()->eval();
+
+    /// name of the before/after function
+    std::string function_name = method_name(fmt::format("nrn_before_after_{}", block_id));
+
+    /// print common function code like init/state/current
+    printer->add_newline(2);
+    printer->fmt_line("/** {} of block type {} # {} */", ba_type, ba_block_type, block_id);
+    print_global_function_common_code(BlockType::BeforeAfter, function_name);
+
+    print_channel_iteration_block_parallel_hint(BlockType::BeforeAfter, node);
+    printer->push_block("for (int id = 0; id < nodecount; id++)");
+
+    printer->add_line("int node_id = node_index[id];");
+    printer->add_line("double v = voltage[node_id];");
+    print_v_unused();
+
+    // read ion statements
+    const auto& read_statements = ion_read_statements(BlockType::Equation);
+    for (auto& statement: read_statements) {
+        printer->add_line(statement);
+    }
+
+    /// print main body
+    printer->add_indent();
+    print_statement_block(*ba_block->get_statement_block());
+    printer->add_newline();
+
+    // write ion statements
+    const auto& write_statements = ion_write_statements(BlockType::Equation);
+    for (auto& statement: write_statements) {
+        auto text = process_shadow_update_statement(statement, BlockType::Equation);
+        printer->add_line(text);
+    }
+
+    /// loop end including data annotation block
+    printer->pop_block();
+    printer->pop_block();
+    print_kernel_data_present_annotation_block_end();
+
+    codegen = false;
+}
+
+void CodegenCoreneuronCppVisitor::print_nrn_constructor() {
+    printer->add_newline(2);
+    print_global_function_common_code(BlockType::Constructor);
+    if (info.constructor_node != nullptr) {
+        const auto& block = info.constructor_node->get_statement_block();
+        print_statement_block(*block, false, false);
+    }
+    printer->add_line("#endif");
+    printer->pop_block();
+}
+
+
+void CodegenCoreneuronCppVisitor::print_nrn_destructor() {
+    
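// The '#endif' printed at the end of this function closes the
+    // '#ifndef CORENEURON_BUILD' guard that print_global_function_common_code
+    // opens for constructor/destructor blocks.
+    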
printer->add_newline(2); + print_global_function_common_code(BlockType::Destructor); + if (info.destructor_node != nullptr) { + const auto& block = info.destructor_node->get_statement_block(); + print_statement_block(*block, false, false); + } + printer->add_line("#endif"); + printer->pop_block(); +} + + +void CodegenCoreneuronCppVisitor::print_functors_definitions() { + codegen = true; + for (const auto& functor_name: info.functor_names) { + printer->add_newline(2); + print_functor_definition(*functor_name.first); + } + codegen = false; +} + + +void CodegenCoreneuronCppVisitor::print_nrn_alloc() { + printer->add_newline(2); + auto method = method_name(naming::NRN_ALLOC_METHOD); + printer->fmt_push_block("static void {}(double* data, Datum* indexes, int type)", method); + printer->add_line("// do nothing"); + printer->pop_block(); +} + +/** + * \todo Number of watch could be more than number of statements + * according to grammar. Check if this is correctly handled in neuron + * and coreneuron. + */ +void CodegenCoreneuronCppVisitor::print_watch_activate() { + if (info.watch_statements.empty()) { + return; + } + codegen = true; + printer->add_newline(2); + auto inst = fmt::format("{}* inst", instance_struct()); + + printer->fmt_push_block( + "static void nrn_watch_activate({}, int id, int pnodecount, int watch_id, " + "double v, bool &watch_remove)", + inst); + + // initialize all variables only during first watch statement + printer->push_block("if (watch_remove == false)"); + for (int i = 0; i < info.watch_count; i++) { + auto name = get_variable_name(fmt::format("watch{}", i + 1)); + printer->fmt_line("{} = 0;", name); + } + printer->add_line("watch_remove = true;"); + printer->pop_block(); + + /** + * \todo Similar to neuron/coreneuron we are using + * first watch and ignoring rest. + */ + for (int i = 0; i < info.watch_statements.size(); i++) { + auto statement = info.watch_statements[i]; + printer->fmt_push_block("if (watch_id == {})", i); + + auto varname = get_variable_name(fmt::format("watch{}", i + 1)); + printer->add_indent(); + printer->fmt_text("{} = 2 + (", varname); + auto watch = statement->get_statements().front(); + watch->get_expression()->visit_children(*this); + printer->add_text(");"); + printer->add_newline(); + + printer->pop_block(); + } + printer->pop_block(); + codegen = false; +} + + +/** + * \todo Similar to print_watch_activate, we are using only + * first watch. need to verify with neuron/coreneuron about rest. + */ +void CodegenCoreneuronCppVisitor::print_watch_check() { + if (info.watch_statements.empty()) { + return; + } + codegen = true; + printer->add_newline(2); + printer->add_line("/** routine to check watch activation */"); + print_global_function_common_code(BlockType::Watch); + + // WATCH statements appears in NET_RECEIVE block and while printing + // net_receive function we already check if it contains any MUTEX/PROTECT + // constructs. 
As WATCH is not a top-level block but a list of statements,
+    // we don't need the ivdep pragma related check
+    print_channel_iteration_block_parallel_hint(BlockType::Watch, nullptr);
+
+    printer->push_block("for (int id = 0; id < nodecount; id++)");
+
+    if (info.is_voltage_used_by_watch_statements()) {
+        printer->add_line("int node_id = node_index[id];");
+        printer->add_line("double v = voltage[node_id];");
+        print_v_unused();
+    }
+
+    // flag to make sure only one WATCH statement can be triggered at a time
+    printer->add_line("bool watch_untriggered = true;");
+
+    for (int i = 0; i < info.watch_statements.size(); i++) {
+        auto statement = info.watch_statements[i];
+        const auto& watch = statement->get_statements().front();
+        const auto& varname = get_variable_name(fmt::format("watch{}", i + 1));
+
+        // start block 1
+        printer->fmt_push_block("if ({}&2 && watch_untriggered)", varname);
+
+        // start block 2
+        printer->add_indent();
+        printer->add_text("if (");
+        watch->get_expression()->accept(*this);
+        printer->add_text(") {");
+        printer->add_newline();
+        printer->increase_indent();
+
+        // start block 3
+        printer->fmt_push_block("if (({}&1) == 0)", varname);
+
+        printer->add_line("watch_untriggered = false;");
+
+        const auto& tqitem = get_variable_name("tqitem");
+        const auto& point_process = get_variable_name("point_process");
+        printer->add_indent();
+        printer->add_text("net_send_buffering(");
+        const auto& t = get_variable_name("t");
+        printer->fmt_text("nt, ml->_net_send_buffer, 0, {}, -1, {}, {}+0.0, ",
+                          tqitem,
+                          point_process,
+                          t);
+        watch->get_value()->accept(*this);
+        printer->add_text(");");
+        printer->add_newline();
+        printer->pop_block();
+        // end block 3
+
+        printer->add_line(varname, " = 3;");
+
+        // else part of block 2
+        printer->decrease_indent();
+        printer->push_block("} else");
+        printer->add_line(varname, " = 2;");
+        printer->pop_block();
+        // end block 2
+
+        printer->pop_block();
+        // end block 1
+    }
+
+    printer->pop_block();
+    print_send_event_move();
+    print_kernel_data_present_annotation_block_end();
+    printer->pop_block();
+    codegen = false;
+}
+
+
+void CodegenCoreneuronCppVisitor::print_net_receive_common_code(const Block& node,
+                                                                bool need_mech_inst) {
+    printer->add_multi_line(R"CODE(
+        int tid = pnt->_tid;
+        int id = pnt->_i_instance;
+        double v = 0;
+    )CODE");
+
+    if (info.artificial_cell || node.is_initial_block()) {
+        printer->add_line("NrnThread* nt = nrn_threads + tid;");
+        printer->add_line("Memb_list* ml = nt->_ml_list[pnt->_type];");
+    }
+    if (node.is_initial_block()) {
+        print_kernel_data_present_annotation_block_begin();
+    }
+
+    printer->add_multi_line(R"CODE(
+        int nodecount = ml->nodecount;
+        int pnodecount = ml->_nodecount_padded;
+        double* data = ml->data;
+        double* weights = nt->weights;
+        Datum* indexes = ml->pdata;
+        ThreadDatum* thread = ml->_thread;
+    )CODE");
+    if (need_mech_inst) {
+        printer->fmt_line("auto* const inst = static_cast<{0}*>(ml->instance);", instance_struct());
+    }
+
+    if (node.is_initial_block()) {
+        print_net_init_acc_serial_annotation_block_begin();
+    }
+
+    // rename variables but need to see if they are actually used
+    auto parameters = info.net_receive_node->get_parameters();
+    if (!parameters.empty()) {
+        int i = 0;
+        printer->add_newline();
+        for (auto& parameter: parameters) {
+            auto name = parameter->get_node_name();
+            bool var_used = VarUsageVisitor().variable_used(node, "(*" + name + ")");
+            if (var_used) {
+                printer->fmt_line("double* {} = weights + weight_index + {};", name, i);
+                RenameVisitor vr(name, "*" + name);
+                node.visit_children(vr);
+            }
+            i++;
+        }
+    }
+}
+
+
+void CodegenCoreneuronCppVisitor::print_net_send_call(const FunctionCall& node) {
+    auto const& arguments = node.get_arguments();
+    const auto& tqitem = get_variable_name("tqitem");
+    std::string weight_index = "weight_index";
+    std::string pnt = "pnt";
+
+    // for functions not generated from NET_RECEIVE blocks (i.e. top level INITIAL block)
+    // the weight_index argument is 0.
+    if (!printing_net_receive && !printing_net_init) {
+        weight_index = "0";
+        auto var = get_variable_name("point_process");
+        if (info.artificial_cell) {
+            pnt = "(Point_process*)" + var;
+        }
+    }
+
+    // artificial cells don't use spike buffering
+    // clang-format off
+    if (info.artificial_cell) {
+        printer->fmt_text("artcell_net_send(&{}, {}, {}, nt->_t+", tqitem, weight_index, pnt);
+    } else {
+        const auto& point_process = get_variable_name("point_process");
+        const auto& t = get_variable_name("t");
+        printer->add_text("net_send_buffering(");
+        printer->fmt_text("nt, ml->_net_send_buffer, 0, {}, {}, {}, {}+", tqitem, weight_index, point_process, t);
+    }
+    // clang-format on
+    print_vector_elements(arguments, ", ");
+    printer->add_text(')');
+}
+
+
+void CodegenCoreneuronCppVisitor::print_net_move_call(const FunctionCall& node) {
+    if (!printing_net_receive && !printing_net_init) {
+        throw std::runtime_error("Error : net_move only allowed in NET_RECEIVE block");
+    }
+
+    auto const& arguments = node.get_arguments();
+    const auto& tqitem = get_variable_name("tqitem");
+    std::string weight_index = "-1";
+    std::string pnt = "pnt";
+
+    // artificial cells don't use spike buffering
+    // clang-format off
+    if (info.artificial_cell) {
+        printer->fmt_text("artcell_net_move(&{}, {}, ", tqitem, pnt);
+        print_vector_elements(arguments, ", ");
+        printer->add_text(")");
+    } else {
+        const auto& point_process = get_variable_name("point_process");
+        printer->add_text("net_send_buffering(");
+        printer->fmt_text("nt, ml->_net_send_buffer, 2, {}, {}, {}, ", tqitem, weight_index, point_process);
+        print_vector_elements(arguments, ", ");
+        printer->add_text(", 0.0");
+        printer->add_text(")");
+    }
+    // clang-format on
+}
+
+
+void CodegenCoreneuronCppVisitor::print_net_event_call(const FunctionCall& node) {
+    const auto& arguments = node.get_arguments();
+    if (info.artificial_cell) {
+        printer->add_text("net_event(pnt, ");
+        print_vector_elements(arguments, ", ");
+    } else {
+        const auto& point_process = get_variable_name("point_process");
+        printer->add_text("net_send_buffering(");
+        printer->fmt_text("nt, ml->_net_send_buffer, 1, -1, -1, {}, ", point_process);
+        print_vector_elements(arguments, ", ");
+        printer->add_text(", 0.0");
+    }
+    printer->add_text(")");
+}
+
+/**
+ * Rename arguments to NET_RECEIVE block with corresponding pointer variable
+ *
+ * Arguments to NET_RECEIVE block are packed and passed via weight vector. These
+ * variables need to be replaced with corresponding pointer variable. For example,
+ * if mod file is like
+ *
+ * \code{.mod}
+ *      NET_RECEIVE (weight, R){
+ *          INITIAL {
+ *              R=1
+ *          }
+ *      }
+ * \endcode
+ *
+ * then generated code for initial block should be:
+ *
+ * \code{.cpp}
+ *      double* R = weights + weight_index + 0;
+ *      (*R) = 1.0;
+ * \endcode
+ *
+ * So, the `R` in AST needs to be renamed with `(*R)`.
+ */ +static void rename_net_receive_arguments(const ast::NetReceiveBlock& net_receive_node, const ast::Node& node) { + const auto& parameters = net_receive_node.get_parameters(); + for (auto& parameter: parameters) { + const auto& name = parameter->get_node_name(); + auto var_used = VarUsageVisitor().variable_used(node, name); + if (var_used) { + RenameVisitor vr(name, "(*" + name + ")"); + node.get_statement_block()->visit_children(vr); + } + } +} + + +void CodegenCoreneuronCppVisitor::print_net_init() { + const auto node = info.net_receive_initial_node; + if (node == nullptr) { + return; + } + + // rename net_receive arguments used in the initial block of net_receive + rename_net_receive_arguments(*info.net_receive_node, *node); + + codegen = true; + printing_net_init = true; + auto args = "Point_process* pnt, int weight_index, double flag"; + printer->add_newline(2); + printer->add_line("/** initialize block for net receive */"); + printer->fmt_push_block("static void net_init({})", args); + auto block = node->get_statement_block().get(); + if (block->get_statements().empty()) { + printer->add_line("// do nothing"); + } else { + print_net_receive_common_code(*node); + print_statement_block(*block, false, false); + if (node->is_initial_block()) { + print_net_init_acc_serial_annotation_block_end(); + print_kernel_data_present_annotation_block_end(); + printer->add_line("auto& nsb = ml->_net_send_buffer;"); + print_net_send_buf_update_to_host(); + } + } + printer->pop_block(); + codegen = false; + printing_net_init = false; +} + + +void CodegenCoreneuronCppVisitor::print_send_event_move() { + printer->add_newline(); + printer->add_line("NetSendBuffer_t* nsb = ml->_net_send_buffer;"); + print_net_send_buf_update_to_host(); + printer->push_block("for (int i=0; i < nsb->_cnt; i++)"); + printer->add_multi_line(R"CODE( + int type = nsb->_sendtype[i]; + int tid = nt->id; + double t = nsb->_nsb_t[i]; + double flag = nsb->_nsb_flag[i]; + int vdata_index = nsb->_vdata_index[i]; + int weight_index = nsb->_weight_index[i]; + int point_index = nsb->_pnt_index[i]; + net_sem_from_gpu(type, vdata_index, weight_index, tid, point_index, t, flag); + )CODE"); + printer->pop_block(); + printer->add_line("nsb->_cnt = 0;"); + print_net_send_buf_count_update_to_device(); +} + + +std::string CodegenCoreneuronCppVisitor::net_receive_buffering_declaration() { + return fmt::format("void {}(NrnThread* nt)", method_name("net_buf_receive")); +} + + +void CodegenCoreneuronCppVisitor::print_get_memb_list() { + printer->add_line("Memb_list* ml = get_memb_list(nt);"); + printer->push_block("if (!ml)"); + printer->add_line("return;"); + printer->pop_block(); + printer->add_newline(); +} + + +void CodegenCoreneuronCppVisitor::print_net_receive_loop_begin() { + printer->add_line("int count = nrb->_displ_cnt;"); + print_channel_iteration_block_parallel_hint(BlockType::NetReceive, info.net_receive_node); + printer->push_block("for (int i = 0; i < count; i++)"); +} + + +void CodegenCoreneuronCppVisitor::print_net_receive_loop_end() { + printer->pop_block(); +} + + +void CodegenCoreneuronCppVisitor::print_net_receive_buffering(bool need_mech_inst) { + if (!net_receive_required() || info.artificial_cell) { + return; + } + printer->add_newline(2); + printer->push_block(net_receive_buffering_declaration()); + + print_get_memb_list(); + + const auto& net_receive = method_name("net_receive_kernel"); + + print_kernel_data_present_annotation_block_begin(); + + printer->add_line("NetReceiveBuffer_t* nrb = ml->_net_receive_buffer;"); + 
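// each entry i of the receive buffer groups buffered events; indices j in
+    // [nrb->_displ[i], nrb->_displ[i+1]) are replayed through the net_receive
+    // kernel in the loop below
+    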
if (need_mech_inst) { + printer->fmt_line("auto* const inst = static_cast<{0}*>(ml->instance);", instance_struct()); + } + print_net_receive_loop_begin(); + printer->add_line("int start = nrb->_displ[i];"); + printer->add_line("int end = nrb->_displ[i+1];"); + printer->push_block("for (int j = start; j < end; j++)"); + printer->add_multi_line(R"CODE( + int index = nrb->_nrb_index[j]; + int offset = nrb->_pnt_index[index]; + double t = nrb->_nrb_t[index]; + int weight_index = nrb->_weight_index[index]; + double flag = nrb->_nrb_flag[index]; + Point_process* point_process = nt->pntprocs + offset; + )CODE"); + printer->add_line(net_receive, "(t, point_process, inst, nt, ml, weight_index, flag);"); + printer->pop_block(); + print_net_receive_loop_end(); + + print_device_stream_wait(); + printer->add_line("nrb->_displ_cnt = 0;"); + printer->add_line("nrb->_cnt = 0;"); + + if (info.net_send_used || info.net_event_used) { + print_send_event_move(); + } + + print_kernel_data_present_annotation_block_end(); + printer->pop_block(); +} + + +void CodegenCoreneuronCppVisitor::print_net_send_buffering_cnt_update() const { + printer->add_line("i = nsb->_cnt++;"); +} + + +void CodegenCoreneuronCppVisitor::print_net_send_buffering_grow() { + printer->push_block("if (i >= nsb->_size)"); + printer->add_line("nsb->grow();"); + printer->pop_block(); +} + + +void CodegenCoreneuronCppVisitor::print_net_send_buffering() { + if (!net_send_buffer_required()) { + return; + } + + printer->add_newline(2); + print_device_method_annotation(); + auto args = + "const NrnThread* nt, NetSendBuffer_t* nsb, int type, int vdata_index, " + "int weight_index, int point_index, double t, double flag"; + printer->fmt_push_block("static inline void net_send_buffering({})", args); + printer->add_line("int i = 0;"); + print_net_send_buffering_cnt_update(); + print_net_send_buffering_grow(); + printer->push_block("if (i < nsb->_size)"); + printer->add_multi_line(R"CODE( + nsb->_sendtype[i] = type; + nsb->_vdata_index[i] = vdata_index; + nsb->_weight_index[i] = weight_index; + nsb->_pnt_index[i] = point_index; + nsb->_nsb_t[i] = t; + nsb->_nsb_flag[i] = flag; + )CODE"); + printer->pop_block(); + printer->pop_block(); +} + + +void CodegenCoreneuronCppVisitor::print_net_receive_kernel() { + if (!net_receive_required()) { + return; + } + codegen = true; + printing_net_receive = true; + const auto node = info.net_receive_node; + + // rename net_receive arguments used in the block itself + rename_net_receive_arguments(*info.net_receive_node, *node); + + std::string name; + ParamVector params; + if (!info.artificial_cell) { + name = method_name("net_receive_kernel"); + params.emplace_back("", "double", "", "t"); + params.emplace_back("", "Point_process*", "", "pnt"); + params.emplace_back("", fmt::format("{}*", instance_struct()), + "", "inst"); + params.emplace_back("", "NrnThread*", "", "nt"); + params.emplace_back("", "Memb_list*", "", "ml"); + params.emplace_back("", "int", "", "weight_index"); + params.emplace_back("", "double", "", "flag"); + } else { + name = method_name("net_receive"); + params.emplace_back("", "Point_process*", "", "pnt"); + params.emplace_back("", "int", "", "weight_index"); + params.emplace_back("", "double", "", "flag"); + } + + printer->add_newline(2); + printer->fmt_push_block("static inline void {}({})", name, get_parameter_str(params)); + print_net_receive_common_code(*node, info.artificial_cell); + if (info.artificial_cell) { + printer->add_line("double t = nt->_t;"); + } + + // set voltage variable if it is 
used in the block (e.g. for WATCH statement) + auto v_used = VarUsageVisitor().variable_used(*node->get_statement_block(), "v"); + if (v_used) { + printer->add_line("int node_id = ml->nodeindices[id];"); + printer->add_line("v = nt->_actual_v[node_id];"); + } + + printer->fmt_line("{} = t;", get_variable_name("tsave")); + + if (info.is_watch_used()) { + printer->add_line("bool watch_remove = false;"); + } + + printer->add_indent(); + node->get_statement_block()->accept(*this); + printer->add_newline(); + printer->pop_block(); + + printing_net_receive = false; + codegen = false; +} + + +void CodegenCoreneuronCppVisitor::print_net_receive() { + if (!net_receive_required()) { + return; + } + codegen = true; + printing_net_receive = true; + if (!info.artificial_cell) { + const auto& name = method_name("net_receive"); + ParamVector params; + params.emplace_back("", "Point_process*", "", "pnt"); + params.emplace_back("", "int", "", "weight_index"); + params.emplace_back("", "double", "", "flag"); + printer->add_newline(2); + printer->fmt_push_block("static void {}({})", name, get_parameter_str(params)); + printer->add_line("NrnThread* nt = nrn_threads + pnt->_tid;"); + printer->add_line("Memb_list* ml = get_memb_list(nt);"); + printer->add_line("NetReceiveBuffer_t* nrb = ml->_net_receive_buffer;"); + printer->push_block("if (nrb->_cnt >= nrb->_size)"); + printer->add_line("realloc_net_receive_buffer(nt, ml);"); + printer->pop_block(); + printer->add_multi_line(R"CODE( + int id = nrb->_cnt; + nrb->_pnt_index[id] = pnt-nt->pntprocs; + nrb->_weight_index[id] = weight_index; + nrb->_nrb_t[id] = nt->_t; + nrb->_nrb_flag[id] = flag; + nrb->_cnt++; + )CODE"); + printer->pop_block(); + } + printing_net_receive = false; + codegen = false; +} + + +/** + * \todo Data is not derived. Need to add instance into instance struct? + * data used here is wrong in AoS because as in original implementation, + * data is not incremented every iteration for AoS. May be better to derive + * actual variable names? [resolved now?] 
+ * slist needs to be added as a local variable
+ */
+void CodegenCoreneuronCppVisitor::print_derivimplicit_kernel(const Block& block) {
+    auto ext_args = external_method_arguments();
+    auto ext_params = external_method_parameters();
+    auto suffix = info.mod_suffix;
+    auto list_num = info.derivimplicit_list_num;
+    auto block_name = block.get_node_name();
+    auto primes_size = info.primes_size;
+    auto stride = "*pnodecount+id";
+
+    printer->add_newline(2);
+
+    printer->push_block("namespace");
+    printer->fmt_push_block("struct _newton_{}_{}", block_name, info.mod_suffix);
+    printer->fmt_push_block("int operator()({}) const", external_method_parameters());
+    auto const instance = fmt::format("auto* const inst = static_cast<{0}*>(ml->instance);",
+                                      instance_struct());
+    auto const slist1 = fmt::format("auto const& slist{} = {};",
+                                    list_num,
+                                    get_variable_name(fmt::format("slist{}", list_num)));
+    auto const slist2 = fmt::format("auto& slist{} = {};",
+                                    list_num + 1,
+                                    get_variable_name(fmt::format("slist{}", list_num + 1)));
+    auto const dlist1 = fmt::format("auto const& dlist{} = {};",
+                                    list_num,
+                                    get_variable_name(fmt::format("dlist{}", list_num)));
+    auto const dlist2 = fmt::format(
+        "double* dlist{} = static_cast<double*>(thread[dith{}()].pval) + ({}*pnodecount);",
+        list_num + 1,
+        list_num,
+        info.primes_size);
+    printer->add_line(instance);
+    if (ion_variable_struct_required()) {
+        print_ion_variable();
+    }
+    printer->fmt_line("double* savstate{} = static_cast<double*>(thread[dith{}()].pval);",
+                      list_num,
+                      list_num);
+    printer->add_line(slist1);
+    printer->add_line(dlist1);
+    printer->add_line(dlist2);
+    codegen = true;
+    print_statement_block(*block.get_statement_block(), false, false);
+    codegen = false;
+    printer->add_line("int counter = -1;");
+    printer->fmt_push_block("for (int i=0; i<{}; i++)", info.num_primes);
+    printer->fmt_push_block("if (*deriv{}_advance(thread))", list_num);
+    printer->fmt_line(
+        "dlist{0}[(++counter){1}] = "
+        "data[dlist{2}[i]{1}]-(data[slist{2}[i]{1}]-savstate{2}[i{1}])/nt->_dt;",
+        list_num + 1,
+        stride,
+        list_num);
+    printer->chain_block("else");
+    printer->fmt_line("dlist{0}[(++counter){1}] = data[slist{2}[i]{1}]-savstate{2}[i{1}];",
+                      list_num + 1,
+                      stride,
+                      list_num);
+    printer->pop_block();
+    printer->pop_block();
+    printer->add_line("return 0;");
+    printer->pop_block();  // operator()
+    printer->pop_block(";");  // struct
+    printer->pop_block();  // namespace
+    printer->add_newline();
+    printer->fmt_push_block("int {}_{}({})", block_name, suffix, ext_params);
+    printer->add_line(instance);
+    printer->fmt_line("double* savstate{} = (double*) thread[dith{}()].pval;", list_num, list_num);
+    printer->add_line(slist1);
+    printer->add_line(slist2);
+    printer->add_line(dlist2);
+    printer->fmt_push_block("for (int i=0; i<{}; i++)", info.num_primes);
+    printer->fmt_line("savstate{}[i{}] = data[slist{}[i]{}];", list_num, stride, list_num, stride);
+    printer->pop_block();
+    printer->fmt_line(
+        "int reset = nrn_newton_thread(static_cast<NewtonSpace*>(*newtonspace{}(thread)), {}, "
+        "slist{}, _newton_{}_{}{{}}, dlist{}, {});",
+        list_num,
+        primes_size,
+        list_num + 1,
+        block_name,
+        suffix,
+        list_num + 1,
+        ext_args);
+    printer->add_line("return reset;");
+    printer->pop_block();
+    printer->add_newline(2);
+}
+
+
+void CodegenCoreneuronCppVisitor::print_newtonspace_transfer_to_device() const {
+    // nothing to do on cpu
+}
+
+
+/****************************************************************************************/
+/*                              Print nrn_state routine                                 */
+/****************************************************************************************/
+
+
+void CodegenCoreneuronCppVisitor::print_nrn_state() {
+    if (!nrn_state_required()) {
+        return;
+    }
+    codegen = true;
+
+    printer->add_newline(2);
+    printer->add_line("/** update state */");
+    print_global_function_common_code(BlockType::State);
+    print_channel_iteration_block_parallel_hint(BlockType::State, info.nrn_state_block);
+    printer->push_block("for (int id = 0; id < nodecount; id++)");
+
+    printer->add_line("int node_id = node_index[id];");
+    printer->add_line("double v = voltage[node_id];");
+    print_v_unused();
+
+    /**
+     * \todo Eigen solver node also emits IonCurVar variable in the functor
+     * but that shouldn't update ions in derivative block
+     */
+    if (ion_variable_struct_required()) {
+        print_ion_variable();
+    }
+
+    auto read_statements = ion_read_statements(BlockType::State);
+    for (auto& statement: read_statements) {
+        printer->add_line(statement);
+    }
+
+    if (info.nrn_state_block) {
+        info.nrn_state_block->visit_children(*this);
+    }
+
+    if (info.currents.empty() && info.breakpoint_node != nullptr) {
+        auto block = info.breakpoint_node->get_statement_block();
+        print_statement_block(*block, false, false);
+    }
+
+    const auto& write_statements = ion_write_statements(BlockType::State);
+    for (auto& statement: write_statements) {
+        const auto& text = process_shadow_update_statement(statement, BlockType::State);
+        printer->add_line(text);
+    }
+    printer->pop_block();
+
+    print_kernel_data_present_annotation_block_end();
+
+    printer->pop_block();
+    codegen = false;
+}
+
+
+/****************************************************************************************/
+/*                            Print nrn_cur related routines                            */
+/****************************************************************************************/
+
+
+void CodegenCoreneuronCppVisitor::print_nrn_current(const BreakpointBlock& node) {
+    const auto& args = internal_method_parameters();
+    const auto& block = node.get_statement_block();
+    printer->add_newline(2);
+    print_device_method_annotation();
+    printer->fmt_push_block("inline double nrn_current_{}({})",
+                            info.mod_suffix,
+                            get_parameter_str(args));
+    printer->add_line("double current = 0.0;");
+    print_statement_block(*block, false, false);
+    for (auto& current: info.currents) {
+        const auto& name = get_variable_name(current);
+        printer->fmt_line("current += {};", name);
+    }
+    printer->add_line("return current;");
+    printer->pop_block();
+}
+
+
+void CodegenCoreneuronCppVisitor::print_nrn_cur_conductance_kernel(const BreakpointBlock& node) {
+    const auto& block = node.get_statement_block();
+    print_statement_block(*block, false, false);
+    if (!info.currents.empty()) {
+        std::string sum;
+        for (const auto& current: info.currents) {
+            auto var = breakpoint_current(current);
+            sum += get_variable_name(var);
+            if (&current != &info.currents.back()) {
+                sum += "+";
+            }
+        }
+        printer->fmt_line("double rhs = {};", sum);
+    }
+
+    std::string sum;
+    for (const auto& conductance: info.conductances) {
+        auto var = breakpoint_current(conductance.variable);
+        sum += get_variable_name(var);
+        if (&conductance != &info.conductances.back()) {
+            sum += "+";
+        }
+    }
+    printer->fmt_line("double g = {};", sum);
+
+    for (const auto& conductance: info.conductances) {
+        if (!conductance.ion.empty()) {
+            const auto& lhs = std::string(naming::ION_VARNAME_PREFIX) + "di" + conductance.ion + "dv";
+            const auto& rhs = get_variable_name(conductance.variable);
+            const ShadowUseStatement statement{lhs, "+=", rhs};
+            const auto& 
+
+
+void CodegenCoreneuronCppVisitor::print_nrn_cur_conductance_kernel(const BreakpointBlock& node) {
+    const auto& block = node.get_statement_block();
+    print_statement_block(*block, false, false);
+    if (!info.currents.empty()) {
+        std::string sum;
+        for (const auto& current: info.currents) {
+            auto var = breakpoint_current(current);
+            sum += get_variable_name(var);
+            if (&current != &info.currents.back()) {
+                sum += "+";
+            }
+        }
+        printer->fmt_line("double rhs = {};", sum);
+    }
+
+    std::string sum;
+    for (const auto& conductance: info.conductances) {
+        auto var = breakpoint_current(conductance.variable);
+        sum += get_variable_name(var);
+        if (&conductance != &info.conductances.back()) {
+            sum += "+";
+        }
+    }
+    printer->fmt_line("double g = {};", sum);
+
+    for (const auto& conductance: info.conductances) {
+        if (!conductance.ion.empty()) {
+            const auto& lhs = std::string(naming::ION_VARNAME_PREFIX) + "di" + conductance.ion +
+                              "dv";
+            const auto& rhs = get_variable_name(conductance.variable);
+            const ShadowUseStatement statement{lhs, "+=", rhs};
+            const auto& text = process_shadow_update_statement(statement, BlockType::Equation);
+            printer->add_line(text);
+        }
+    }
+}
+
+
+void CodegenCoreneuronCppVisitor::print_nrn_cur_non_conductance_kernel() {
+    printer->fmt_line("double g = nrn_current_{}({}+0.001);",
+                      info.mod_suffix,
+                      internal_method_arguments());
+    for (auto& ion: info.ions) {
+        for (auto& var: ion.writes) {
+            if (ion.is_ionic_current(var)) {
+                const auto& name = get_variable_name(var);
+                printer->fmt_line("double di{} = {};", ion.name, name);
+            }
+        }
+    }
+    printer->fmt_line("double rhs = nrn_current_{}({});",
+                      info.mod_suffix,
+                      internal_method_arguments());
+    printer->add_line("g = (g-rhs)/0.001;");
+    for (auto& ion: info.ions) {
+        for (auto& var: ion.writes) {
+            if (ion.is_ionic_current(var)) {
+                const auto& lhs = std::string(naming::ION_VARNAME_PREFIX) + "di" + ion.name + "dv";
+                auto rhs = fmt::format("(di{}-{})/0.001", ion.name, get_variable_name(var));
+                if (info.point_process) {
+                    auto area = get_variable_name(naming::NODE_AREA_VARIABLE);
+                    rhs += fmt::format("*1.e2/{}", area);
+                }
+                const ShadowUseStatement statement{lhs, "+=", rhs};
+                const auto& text = process_shadow_update_statement(statement, BlockType::Equation);
+                printer->add_line(text);
+            }
+        }
+    }
+}
+
+
+void CodegenCoreneuronCppVisitor::print_nrn_cur_kernel(const BreakpointBlock& node) {
+    printer->add_line("int node_id = node_index[id];");
+    printer->add_line("double v = voltage[node_id];");
+    print_v_unused();
+    if (ion_variable_struct_required()) {
+        print_ion_variable();
+    }
+
+    const auto& read_statements = ion_read_statements(BlockType::Equation);
+    for (auto& statement: read_statements) {
+        printer->add_line(statement);
+    }
+
+    if (info.conductances.empty()) {
+        print_nrn_cur_non_conductance_kernel();
+    } else {
+        print_nrn_cur_conductance_kernel(node);
+    }
+
+    const auto& write_statements = ion_write_statements(BlockType::Equation);
+    for (auto& statement: write_statements) {
+        auto text = process_shadow_update_statement(statement, BlockType::Equation);
+        printer->add_line(text);
+    }
+
+    if (info.point_process) {
+        const auto& area = get_variable_name(naming::NODE_AREA_VARIABLE);
+        printer->fmt_line("double mfactor = 1.e2/{};", area);
+        printer->add_line("g = g*mfactor;");
+        printer->add_line("rhs = rhs*mfactor;");
+    }
+
+    print_g_unused();
+}
+
+
+void CodegenCoreneuronCppVisitor::print_fast_imem_calculation() {
+    if (!info.electrode_current) {
+        return;
+    }
+    std::string rhs, d;
+    auto rhs_op = operator_for_rhs();
+    auto d_op = operator_for_d();
+    if (info.point_process) {
+        rhs = "shadow_rhs[id]";
+        d = "shadow_d[id]";
+    } else {
+        rhs = "rhs";
+        d = "g";
+    }
+
+    printer->push_block("if (nt->nrn_fast_imem)");
+    if (nrn_cur_reduction_loop_required()) {
+        printer->push_block("for (int id = 0; id < nodecount; id++)");
+        printer->add_line("int node_id = node_index[id];");
+    }
+    printer->fmt_line("nt->nrn_fast_imem->nrn_sav_rhs[node_id] {} {};", rhs_op, rhs);
+    printer->fmt_line("nt->nrn_fast_imem->nrn_sav_d[node_id] {} {};", d_op, d);
+    if (nrn_cur_reduction_loop_required()) {
+        printer->pop_block();
+    }
+    printer->pop_block();
+}
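Without CONDUCTANCE hints the kernel printed by print_nrn_cur_non_conductance_kernel has to recover g numerically from two current evaluations 0.001 mV apart. Inside the generated per-node loop this reads roughly as follows (names hypothetical, continuing the hh sketch above):

    double g = nrn_current_hh(id, pnodecount, inst, data, indexes, thread, nt, v + 0.001);
    double dina = inst->ina[id];
    double rhs = nrn_current_hh(id, pnodecount, inst, data, indexes, thread, nt, v);
    g = (g - rhs) / 0.001;
    // and, for each written ionic current, a dI/dV update by the same finite
    // difference, e.g. ion_dinadv += (dina - ina) / 0.001;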
printer->push_block("for (int id = 0; id < nodecount; id++)"); + print_nrn_cur_kernel(*info.breakpoint_node); + print_nrn_cur_matrix_shadow_update(); + if (!nrn_cur_reduction_loop_required()) { + print_fast_imem_calculation(); + } + printer->pop_block(); + + if (nrn_cur_reduction_loop_required()) { + printer->push_block("for (int id = 0; id < nodecount; id++)"); + print_nrn_cur_matrix_shadow_reduction(); + printer->pop_block(); + print_fast_imem_calculation(); + } + + print_kernel_data_present_annotation_block_end(); + printer->pop_block(); + codegen = false; +} + + +/****************************************************************************************/ +/* Main code printing entry points */ +/****************************************************************************************/ + +void CodegenCoreneuronCppVisitor::print_headers_include() { + print_standard_includes(); + print_backend_includes(); + print_coreneuron_includes(); +} + + +void CodegenCoreneuronCppVisitor::print_namespace_begin() { + print_namespace_start(); + print_backend_namespace_start(); +} + + +void CodegenCoreneuronCppVisitor::print_namespace_end() { + print_backend_namespace_stop(); + print_namespace_stop(); +} + + +void CodegenCoreneuronCppVisitor::print_common_getters() { + print_first_pointer_var_index_getter(); + print_net_receive_arg_size_getter(); + print_thread_getters(); + print_num_variable_getter(); + print_mech_type_getter(); + print_memb_list_getter(); +} + + +void CodegenCoreneuronCppVisitor::print_data_structures(bool print_initializers) { + print_mechanism_global_var_structure(print_initializers); + print_mechanism_range_var_structure(print_initializers); + print_ion_var_structure(); +} + + +void CodegenCoreneuronCppVisitor::print_v_unused() const { + if (!info.vectorize) { + return; + } + printer->add_multi_line(R"CODE( + #if NRN_PRCELLSTATE + inst->v_unused[id] = v; + #endif + )CODE"); +} + + +void CodegenCoreneuronCppVisitor::print_g_unused() const { + printer->add_multi_line(R"CODE( + #if NRN_PRCELLSTATE + inst->g_unused[id] = g; + #endif + )CODE"); +} + + +void CodegenCoreneuronCppVisitor::print_compute_functions() { + print_top_verbatim_blocks(); + for (const auto& procedure: info.procedures) { + print_procedure(*procedure); + } + for (const auto& function: info.functions) { + print_function(*function); + } + for (const auto& function: info.function_tables) { + print_function_tables(*function); + } + for (size_t i = 0; i < info.before_after_blocks.size(); i++) { + print_before_after_block(info.before_after_blocks[i], i); + } + for (const auto& callback: info.derivimplicit_callbacks) { + const auto& block = *callback->get_node_to_solve(); + print_derivimplicit_kernel(block); + } + print_net_send_buffering(); + print_net_init(); + print_watch_activate(); + print_watch_check(); + print_net_receive_kernel(); + print_net_receive(); + print_net_receive_buffering(); + print_nrn_init(); + print_nrn_cur(); + print_nrn_state(); +} + + +void CodegenCoreneuronCppVisitor::print_codegen_routines() { + codegen = true; + print_backend_info(); + print_headers_include(); + print_namespace_begin(); + print_nmodl_constants(); + print_prcellstate_macros(); + print_mechanism_info(); + print_data_structures(true); + print_global_variables_for_hoc(); + print_common_getters(); + print_memory_allocation_routine(); + print_abort_routine(); + print_thread_memory_callbacks(); + print_instance_variable_setup(); + print_nrn_alloc(); + print_nrn_constructor(); + print_nrn_destructor(); + print_function_prototypes(); + 
+
+
+void CodegenCoreneuronCppVisitor::print_codegen_routines() {
+    codegen = true;
+    print_backend_info();
+    print_headers_include();
+    print_namespace_begin();
+    print_nmodl_constants();
+    print_prcellstate_macros();
+    print_mechanism_info();
+    print_data_structures(true);
+    print_global_variables_for_hoc();
+    print_common_getters();
+    print_memory_allocation_routine();
+    print_abort_routine();
+    print_thread_memory_callbacks();
+    print_instance_variable_setup();
+    print_nrn_alloc();
+    print_nrn_constructor();
+    print_nrn_destructor();
+    print_function_prototypes();
+    print_functors_definitions();
+    print_compute_functions();
+    print_check_table_thread_function();
+    print_mechanism_register();
+    print_namespace_end();
+    codegen = false;
+}
+
+
+/****************************************************************************************/
+/*                            Overloaded visitor routines                               */
+/****************************************************************************************/
+
+
+void CodegenCoreneuronCppVisitor::visit_derivimplicit_callback(
+    const ast::DerivimplicitCallback& node) {
+    if (!codegen) {
+        return;
+    }
+    printer->fmt_line("{}_{}({});",
+                      node.get_node_to_solve()->get_node_name(),
+                      info.mod_suffix,
+                      external_method_arguments());
+}
+
+
+void CodegenCoreneuronCppVisitor::visit_eigen_newton_solver_block(
+    const ast::EigenNewtonSolverBlock& node) {
+    // solution vector to store copy of state vars for Newton solver
+    printer->add_newline();
+
+    auto float_type = default_float_data_type();
+    int N = node.get_n_state_vars()->get_value();
+    printer->fmt_line("Eigen::Matrix<{}, {}, 1> nmodl_eigen_xm;", float_type, N);
+    printer->fmt_line("{}* nmodl_eigen_x = nmodl_eigen_xm.data();", float_type);
+
+    print_statement_block(*node.get_setup_x_block(), false, false);
+
+    // call newton solver with functor and X matrix that contains state vars
+    printer->add_line("// call newton solver");
+    printer->fmt_line("{} newton_functor(nt, inst, id, pnodecount, v, indexes, data, thread);",
+                      info.functor_names[&node]);
+    printer->add_line("newton_functor.initialize();");
+    printer->add_line(
+        "int newton_iterations = nmodl::newton::newton_solver(nmodl_eigen_xm, newton_functor);");
+    printer->add_line(
+        "if (newton_iterations < 0) assert(false && \"Newton solver did not converge!\");");
+
+    // assign newton solver results in matrix X to state vars
+    print_statement_block(*node.get_update_states_block(), false, false);
+    printer->add_line("newton_functor.finalize();");
+}
+
+
+void CodegenCoreneuronCppVisitor::visit_eigen_linear_solver_block(
+    const ast::EigenLinearSolverBlock& node) {
+    printer->add_newline();
+
+    const std::string float_type = default_float_data_type();
+    int N = node.get_n_state_vars()->get_value();
+    printer->fmt_line("Eigen::Matrix<{0}, {1}, 1> nmodl_eigen_xm, nmodl_eigen_fm;", float_type, N);
+    printer->fmt_line("Eigen::Matrix<{0}, {1}, {1}> nmodl_eigen_jm;", float_type, N);
+    if (N <= 4)
+        printer->fmt_line("Eigen::Matrix<{0}, {1}, {1}> nmodl_eigen_jm_inv;", float_type, N);
+    printer->fmt_line("{}* nmodl_eigen_x = nmodl_eigen_xm.data();", float_type);
+    printer->fmt_line("{}* nmodl_eigen_j = nmodl_eigen_jm.data();", float_type);
+    printer->fmt_line("{}* nmodl_eigen_f = nmodl_eigen_fm.data();", float_type);
+    print_statement_block(*node.get_variable_block(), false, false);
+    print_statement_block(*node.get_initialize_block(), false, false);
+    print_statement_block(*node.get_setup_x_block(), false, false);
+
+    printer->add_newline();
+    print_eigen_linear_solver(float_type, N);
+    printer->add_newline();
+
+    print_statement_block(*node.get_update_states_block(), false, false);
+    print_statement_block(*node.get_finalize_block(), false, false);
+}
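Both solver blocks above lean on Eigen's fixed-size types; note the extra nmodl_eigen_jm_inv matrix declared only for N <= 4, where an explicit inverse is used instead of a factorisation. The two Eigen idioms involved can be illustrated standalone; which one the generated code actually uses is decided in print_eigen_linear_solver, so this sketch only demonstrates the library calls:

    #include <Eigen/Dense>
    #include <iostream>

    int main() {
        constexpr int N = 3;
        Eigen::Matrix<double, N, N> J;  // Jacobian, filled by the generated setup code
        Eigen::Matrix<double, N, 1> F;  // right-hand side
        J << 2, -1, 0, -1, 2, -1, 0, -1, 2;
        F << 1, 0, 1;

        // small systems: explicit inverse is cheap
        Eigen::Matrix<double, N, 1> x_inv = J.inverse() * F;
        // larger systems: an LU factorisation is numerically preferable
        Eigen::Matrix<double, N, 1> x_lu =
            Eigen::PartialPivLU<Eigen::Matrix<double, N, N>>(J).solve(F);

        std::cout << (x_inv - x_lu).norm() << "\n";  // ~0
        return 0;
    }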
+
+
+void CodegenCoreneuronCppVisitor::visit_for_netcon(const ast::ForNetcon& node) {
+    // For_netcon should take the same arguments as net_receive and apply the operations
+    // in the block to the weights of the netcons. Since all the weights are on the same vector,
+    // weights, we have a mask of operations that we apply iteratively, advancing the offset
+    // to the next netcon.
+    const auto& args = node.get_parameters();
+    RenameVisitor v;
+    const auto& statement_block = node.get_statement_block();
+    for (size_t i_arg = 0; i_arg < args.size(); ++i_arg) {
+        // sanitize node_name since we want to substitute names like (*w) as they are
+        auto old_name =
+            std::regex_replace(args[i_arg]->get_node_name(), regex_special_chars, R"(\$&)");
+        const auto& new_name = fmt::format("weights[{} + nt->_fornetcon_weight_perm[i]]", i_arg);
+        v.set(old_name, new_name);
+        statement_block->accept(v);
+    }
+
+    const auto index =
+        std::find_if(info.semantics.begin(), info.semantics.end(), [](const IndexSemantics& a) {
+            return a.name == naming::FOR_NETCON_SEMANTIC;
+        })->index;
+
+    printer->fmt_text("const size_t offset = {}*pnodecount + id;", index);
+    printer->add_newline();
+    printer->add_line(
+        "const size_t for_netcon_start = nt->_fornetcon_perm_indices[indexes[offset]];");
+    printer->add_line(
+        "const size_t for_netcon_end = nt->_fornetcon_perm_indices[indexes[offset] + 1];");
+
+    printer->add_line("for (auto i = for_netcon_start; i < for_netcon_end; ++i) {");
+    printer->increase_indent();
+    print_statement_block(*statement_block, false, false);
+    printer->decrease_indent();
+
+    printer->add_line("}");
+}
+
+
+void CodegenCoreneuronCppVisitor::visit_solution_expression(const SolutionExpression& node) {
+    auto block = node.get_node_to_solve().get();
+    if (block->is_statement_block()) {
+        auto statement_block = dynamic_cast<ast::StatementBlock*>(block);
+        print_statement_block(*statement_block, false, false);
+    } else {
+        block->accept(*this);
+    }
+}
+
+
+void CodegenCoreneuronCppVisitor::visit_watch_statement(const ast::WatchStatement& /* node */) {
+    printer->add_text(fmt::format("nrn_watch_activate(inst, id, pnodecount, {}, v, watch_remove)",
+                                  current_watch_statement++));
+}
+
+}  // namespace codegen
+}  // namespace nmodl
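A small standalone illustration of the escaping done in visit_for_netcon above, using the same special-character pattern as regex_special_chars (the argument name is illustrative):

    #include <iostream>
    #include <regex>
    #include <string>

    int main() {
        // same character class as the visitor's regex_special_chars
        const std::regex special{R"([-[\]{}()*+?.,\^$|#\s])"};
        std::string arg = "(*w)";  // a parameter name as it may appear in verbatim-style code
        // prefix every special character with a backslash so the name can be
        // matched literally during renaming
        std::string escaped = std::regex_replace(arg, special, R"(\$&)");
        std::cout << escaped << "\n";  // prints \(\*w\)
        return 0;
    }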
diff --git a/src/codegen/codegen_coreneuron_cpp_visitor.hpp b/src/codegen/codegen_coreneuron_cpp_visitor.hpp
new file mode 100644
index 0000000000..cc59a2a77e
--- /dev/null
+++ b/src/codegen/codegen_coreneuron_cpp_visitor.hpp
@@ -0,0 +1,1345 @@
+/*
+ * Copyright 2023 Blue Brain Project, EPFL.
+ * See the top-level LICENSE file for details.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+#pragma once
+
+/**
+ * \dir
+ * \brief Code generation backend implementations for CoreNEURON
+ *
+ * \file
+ * \brief \copybrief nmodl::codegen::CodegenCoreneuronCppVisitor
+ */
+
+#include <algorithm>
+#include <cmath>
+#include <ctime>
+#include <numeric>
+#include <ostream>
+#include <string>
+#include <string_view>
+#include <utility>
+
+#include "codegen/codegen_cpp_visitor.hpp"
+#include "codegen/codegen_info.hpp"
+#include "codegen/codegen_naming.hpp"
+#include "printer/code_printer.hpp"
+#include "symtab/symbol_table.hpp"
+#include "utils/logger.hpp"
+#include "visitors/ast_visitor.hpp"
+
+
+namespace nmodl {
+
+namespace codegen {
+
+
+using printer::CodePrinter;
+
+
+/**
+ * \defgroup codegen_backends Codegen Backends
+ * \ingroup codegen
+ * \brief Code generation backends for CoreNEURON
+ * \{
+ */
+
+/**
+ * \class CodegenCoreneuronCppVisitor
+ * \brief %Visitor for printing C++ code compatible with legacy api of CoreNEURON
+ *
+ * \todo
+ *  - Handle define statement (i.e. macros)
+ *  - If there is a return statement in the verbatim block
+ *    of an inlined function then it will be an error. Need better
+ *    error checking. For example, see netstim.mod where we
+ *    have removed the return from the verbatim block.
+ */
+class CodegenCoreneuronCppVisitor: public CodegenCppVisitor {
+  protected:
+    /****************************************************************************************/
+    /*                                  Member variables                                    */
+    /****************************************************************************************/
+
+
+    /****************************************************************************************/
+    /*                            Generic information getters                               */
+    /****************************************************************************************/
+
+
+    /**
+     * Name of the simulator the code was generated for
+     */
+    std::string simulator_name() override;
+
+
+    /**
+     * Name of the code generation backend
+     */
+    virtual std::string backend_name() const override;
+
+
+    /**
+     * Name of structure that wraps range variables
+     */
+    std::string instance_struct() const {
+        return fmt::format("{}_Instance", info.mod_suffix);
+    }
+
+
+    /**
+     * Name of structure that wraps global variables
+     */
+    std::string global_struct() const {
+        return fmt::format("{}_Store", info.mod_suffix);
+    }
+
+
+    /**
+     * Name of the (host-only) global instance of `global_struct`
+     */
+    std::string global_struct_instance() const {
+        return info.mod_suffix + "_global";
+    }
+
+
+    /**
+     * Determine the number of threads to allocate
+     */
+    int num_thread_objects() const noexcept {
+        return info.vectorize ? (info.thread_data_index + 1) : 0;
+    }
+
+
+    /****************************************************************************************/
+    /*                  Common helper routines across codegen functions                     */
+    /****************************************************************************************/
+
+
+    /**
+     * Determine the position in the data array for a given float variable
+     * \param name The name of a float variable
+     * \return The position index in the data array
+     */
+    int position_of_float_var(const std::string& name) const override;
+
+
+    /**
+     * Determine the position in the data array for a given int variable
+     * \param name The name of an int variable
+     * \return The position index in the data array
+     */
+    int position_of_int_var(const std::string& name) const override;
+
+
+    /**
+     * Determine the variable name for the "current" used in breakpoint block taking into account
+     * intermediate code transformations.
+     * \param current The variable name for the current used in the model
+     * \return The name for the current to be printed in C++
+     */
+    std::string breakpoint_current(std::string current) const;
+
+
+    /**
+     * For a given output block type, return statements for all read ion variables
+     *
+     * \param type The type of code block being generated
+     * \return A \c vector of strings representing the reading of ion variables
+     */
+    std::vector<std::string> ion_read_statements(BlockType type) const;
+
+
+    /**
+     * For a given output block type, return minimal statements for all read ion variables
+     *
+     * \param type The type of code block being generated
+     * \return A \c vector of strings representing the reading of ion variables
+     */
+    std::vector<std::string> ion_read_statements_optimized(BlockType type) const;
+
+
+    /**
+     * For a given output block type, return statements for writing back ion variables
+     *
+     * \param type The type of code block being generated
+     * \return A \c vector of strings representing the write-back of ion variables
+     */
+    std::vector<std::string> ion_write_statements(BlockType type);
+
+
+    /**
+     * Process a token in a verbatim block for possible variable renaming
+     * \param token The verbatim token to be processed
+     * \return The code after variable renaming
+     */
+    std::string process_verbatim_token(const std::string& token);
+
+
+    /**
+     * Check if a structure for ion variables is required
+     * \return \c true if a structure for ion variables must be generated
+     */
+    bool ion_variable_struct_required() const;
+
+
+    /**
+     * Check if variable is qualified as constant
+     * \param name The name of variable
+     * \return \c true if it is constant
+     */
+    virtual bool is_constant_variable(const std::string& name) const;
+
+
+    /****************************************************************************************/
+    /*                              Backend specific routines                               */
+    /****************************************************************************************/
+
+
+    /**
+     * Generate the string representing the procedure parameter declaration
+     *
+     * The procedure parameters are stored in a vector of 4-tuples, each representing a parameter.
+     *
+     * \param params The parameters that should be concatenated into the function parameter
+     *               declaration
+     * \return The string representing the declaration of function parameters
+     */
+    static std::string get_parameter_str(const ParamVector& params);
+
+
+    /**
+     * Print the code to copy derivative advance flag to device
+     */
+    virtual void print_deriv_advance_flag_transfer_to_device() const;
+
+
+    /**
+     * Print pragma annotation for increase and capture of variable in automatic way
+     */
+    virtual void print_device_atomic_capture_annotation() const;
+
+
+    /**
+     * Print the code to update NetSendBuffer_t count from device to host
+     */
+    virtual void print_net_send_buf_count_update_to_host() const;
+
+
+    /**
+     * Print the code to update NetSendBuffer_t from device to host
+     */
+    virtual void print_net_send_buf_update_to_host() const;
+
+
+    /**
+     * Print the code to update NetSendBuffer_t count from host to device
+     */
+    virtual void print_net_send_buf_count_update_to_device() const;
+
+    /**
+     * Print the code to update dt from host to device
+     */
+    virtual void print_dt_update_to_device() const;
+
+    /**
+     * Print the code to synchronise/wait on stream specific to NrnThread
+     */
+    virtual void print_device_stream_wait() const;
+
+
+    /**
+     * Print accelerator annotations indicating data presence on device
+     */
+    virtual void print_kernel_data_present_annotation_block_begin();
+
+
+    /**
+     * Print matching block end of accelerator annotations for data presence on device
+     */
+    virtual void print_kernel_data_present_annotation_block_end();
+
+
+    /**
+     * Print accelerator kernels begin annotation for net_init kernel
+     */
+    virtual void print_net_init_acc_serial_annotation_block_begin();
+
+
+    /**
+     * Print accelerator kernels end annotation for net_init kernel
+     */
+    virtual void print_net_init_acc_serial_annotation_block_end();
+
+
+    /**
+     * Print pragma annotations for channel iterations
+     *
+     * This can be overridden by backends to provide additional annotations or pragmas to enable,
+     * for example, SIMD code generation (e.g. through \c ivdep).
+     * The default implementation prints
+     *
+     * \code
+     * #pragma ivdep
+     * \endcode
+     *
+     * \param type The block type
+     */
+    virtual void print_channel_iteration_block_parallel_hint(BlockType type,
+                                                             const ast::Block* block);
+
+
+    /**
+     * Check if a reduction block is required in \c nrn\_cur
+     */
+    virtual bool nrn_cur_reduction_loop_required();
+
+
+    /**
+     * Print the setup method for setting matrix shadow vectors
+     */
+    virtual void print_rhs_d_shadow_variables();
+
+
+    /**
+     * Print the update to matrix elements with/without shadow vectors
+     */
+    virtual void print_nrn_cur_matrix_shadow_update();
+
+
+    /**
+     * Print the reduction to matrix elements from shadow vectors
+     */
+    virtual void print_nrn_cur_matrix_shadow_reduction();
+
+
+    /**
+     * Print atomic update pragma for reduction statements
+     */
+    virtual void print_atomic_reduction_pragma() override;
+
+
+    /**
+     * Print the backend specific device method annotation
+     *
+     * \note This is not used for the C++ backend
+     */
+    virtual void print_device_method_annotation();
+
+
+    /**
+     * Print backend specific global method annotation
+     *
+     * \note This is not used for the C++ backend
+     */
+    virtual void print_global_method_annotation();
+
+
+    /**
+     * Prints the start of namespace for the backend-specific code
+     *
+     * For the C++ backend no additional namespace is required
+     */
+    virtual void print_backend_namespace_start();
+
+
+    /**
+     * Prints the end of namespace for the backend-specific code
+     *
+     * For the C++ backend no additional namespace is required
+     */
+    virtual void print_backend_namespace_stop();
+
+
+    /**
+     * Print backend specific includes (none needed for C++ backend)
+     */
+    virtual void print_backend_includes();
+
+
+    /**
+     * Check if ion variable copies should be avoided
+     */
+    bool optimize_ion_variable_copies() const;
+
+
+    /**
+     * Print memory allocation routine
+     */
+    virtual void print_memory_allocation_routine() const;
+
+
+    /**
+     * Print backend specific abort routine
+     */
+    virtual void print_abort_routine() const;
+
+
+    /**
+     * Return the name of main compute kernels
+     * \param type A block type
+     */
+    virtual std::string compute_method_name(BlockType type) const;
+
+
+    /**
+     * Instantiate global var instance
+     *
+     * For C++ code generation this is empty
+     * \return ""
+     */
+    virtual void print_global_var_struct_decl();
+
+
+    /**
+     * Print declarations of the functions used by \ref
+     * print_instance_struct_copy_to_device and \ref
+     * print_instance_struct_delete_from_device.
+     */
+    virtual void print_instance_struct_transfer_routine_declarations() {}
+
+    /**
+     * Print the definitions of the functions used by \ref
+     * print_instance_struct_copy_to_device and \ref
+     * print_instance_struct_delete_from_device. Declarations of these functions
+     * are printed by \ref print_instance_struct_transfer_routine_declarations.
+     *
+     * This updates the (pointer) member variables in the device copy of the
+     * instance struct to contain device pointers, which is why you must pass a
+     * list of names of those member variables.
+     *
+     * \param ptr_members List of instance struct member names.
+     */
+    virtual void print_instance_struct_transfer_routines(
+        std::vector<std::string> const& /* ptr_members */) {}
+
+
+    /**
+     * Transfer the instance struct to the device. This calls a function
+     * declared by \ref print_instance_struct_transfer_routine_declarations.
+     */
+    virtual void print_instance_struct_copy_to_device() {}
+
+
+    /**
+     * Delete the instance struct from the device. This calls a function
+     * declared by \ref print_instance_struct_transfer_routine_declarations.
+     */
+    virtual void print_instance_struct_delete_from_device() {}
+
+
+    /****************************************************************************************/
+    /*                        Printing routines for code generation                         */
+    /****************************************************************************************/
+
+
+    /**
+     * Print call to internal or external function
+     * \param node The AST node representing a function call
+     */
+    void print_function_call(const ast::FunctionCall& node) override;
+
+
+    /**
+     * Print top level (global scope) verbatim blocks
+     */
+    void print_top_verbatim_blocks();
+
+
+    /**
+     * Print function and procedures prototype declaration
+     */
+    void print_function_prototypes() override;
+
+
+    /**
+     * Check if the given name exists in the symbol table
+     * \return A tuple \c (is_array, array_length); \c array_length is only
+     *         meaningful when the variable is an array
+     */
+    std::tuple<bool, int> check_if_var_is_array(const std::string& name);
+
+
+    /**
+     * Print \c check\_function() for functions or procedure using table
+     * \param node The AST node representing a function or procedure block
+     */
+    void print_table_check_function(const ast::Block& node);
+
+
+    /**
+     * Print replacement function for function or procedure using table
+     * \param node The AST node representing a function or procedure block
+     */
+    void print_table_replacement_function(const ast::Block& node);
+
+
+    /**
+     * Print check_table functions
+     */
+    void print_check_table_thread_function();
+
+
+    /**
+     * Print nmodl function or procedure (common code)
+     * \param node the AST node representing the function or procedure in NMODL
+     * \param name the name of the function or procedure
+     */
+    void print_function_or_procedure(const ast::Block& node, const std::string& name) override;
+
+
+    /**
+     * Common helper function to help printing function or procedure blocks
+     * \param node the AST node representing the function or procedure in NMODL
+     */
+    void print_function_procedure_helper(const ast::Block& node) override;
+
+
+    /**
+     * Print NMODL procedure in target backend code
+     * \param node
+     */
+    virtual void print_procedure(const ast::ProcedureBlock& node) override;
+
+
+    /**
+     * Print NMODL function in target backend code
+     * \param node
+     */
+    void print_function(const ast::FunctionBlock& node) override;
+
+
+    /**
+     * Print NMODL function_table in target backend code
+     * \param node
+     */
+    void print_function_tables(const ast::FunctionTableBlock& node);
+
+
+    bool is_functor_const(const ast::StatementBlock& variable_block,
+                          const ast::StatementBlock& functor_block);
+
+
+    /**
+     * \brief Based on the \c EigenNewtonSolverBlock passed print the definition needed for its
+     *        functor
+     *
+     * \param node \c EigenNewtonSolverBlock for which to print the functor
+     */
+    void print_functor_definition(const ast::EigenNewtonSolverBlock& node);
+
+
+    virtual void print_eigen_linear_solver(const std::string& float_type, int N);
+
+
+    /****************************************************************************************/
+    /*                           Code-specific helper routines                              */
+    /****************************************************************************************/
+
+
+    /**
+     * Arguments for functions that are defined and used internally.
+     * \return the method arguments
+     */
+    std::string internal_method_arguments() override;
+
+
+    /**
+     * Parameters for internally defined functions
+     * \return the method parameters
+     */
+    ParamVector internal_method_parameters() override;
+
+
+    /**
+     * Arguments for external functions called from generated code
+     * \return A string representing the arguments passed to an external function
+     */
+    const char* external_method_arguments() noexcept override;
+
+
+    /**
+     * Parameters for functions in generated code that are called back from external code
+     *
+     * Functions registered in NEURON during initialization for callback must adhere to a
+     * prescribed calling convention. This method generates the string representing the
+     * function parameters for these externally called functions.
+     * \param table
+     * \return A string representing the parameters of the function
+     */
+    const char* external_method_parameters(bool table = false) noexcept override;
+
+
+    /**
+     * Arguments for "_threadargs_" macro in neuron implementation
+     */
+    std::string nrn_thread_arguments() const override;
+
+
+    /**
+     * Arguments for "_threadargs_" macro in neuron implementation
+     */
+    std::string nrn_thread_internal_arguments() override;
+
+
+    /**
+     * Replace commonly used verbatim variables
+     * \param name A variable name to be checked and possibly updated
+     * \return The possibly replaced variable name
+     */
+    std::string replace_if_verbatim_variable(std::string name);
+
+
+    /**
+     * Process a verbatim block for possible variable renaming
+     * \param text The verbatim code to be processed
+     * \return The code with all variables renamed as needed
+     */
+    std::string process_verbatim_text(std::string const& text) override;
+
+
+    /**
+     * Arguments for register_mech or point_register_mech function
+     */
+    std::string register_mechanism_arguments() const override;
+
+
+    /**
+     * Return ion variable name and corresponding ion read variable name
+     * \param name The ion variable name
+     * \return The ion read variable name
+     */
+    static std::pair<std::string, std::string> read_ion_variable_name(const std::string& name);
+
+
+    /**
+     * Return ion variable name and corresponding ion write variable name
+     * \param name The ion variable name
+     * \return The ion write variable name
+     */
+    static std::pair<std::string, std::string> write_ion_variable_name(const std::string& name);
+
+
+    /**
+     * Generate Function call statement for nrn_wrote_conc
+     * \param ion_name The name of the ion variable
+     * \param concentration The name of the concentration variable
+     * \param index
+     * \return The string representing the function call
+     */
+    std::string conc_write_statement(const std::string& ion_name,
+                                     const std::string& concentration,
+                                     int index);
+
+    /**
+     * Process shadow update statement
+     *
+     * If the statement requires reduction then add it to vector of reduction statement and return
+     * statement using shadow update
+     *
+     * \param statement The statement that might require shadow updates
+     * \param type The target backend code block type
+     * \return The generated target backend code
+     */
+    std::string process_shadow_update_statement(const ShadowUseStatement& statement,
+                                                BlockType type);
+
+
+    /****************************************************************************************/
+    /*                 Code-specific printing routines for code generation                  */
+    /****************************************************************************************/
+
+
+    /**
+     * Print the getter method for index position of first pointer variable
+     */
+    void print_first_pointer_var_index_getter();
+
+
+    /**
+     *
Print the getter methods for float and integer variables count + * + */ + void print_num_variable_getter(); + + + /** + * Print the getter method for getting number of arguments for net_receive + * + */ + void print_net_receive_arg_size_getter(); + + + /** + * Print the getter method for returning mechtype + * + */ + void print_mech_type_getter(); + + + /** + * Print the getter method for returning membrane list from NrnThread + * + */ + void print_memb_list_getter(); + + + /** + * Prints the start of the \c coreneuron namespace + */ + void print_namespace_start() override; + + + /** + * Prints the end of the \c coreneuron namespace + */ + void print_namespace_stop() override; + + + /** + * Print the getter method for thread variables and ids + * + */ + void print_thread_getters(); + + + /****************************************************************************************/ + /* Routines for returning variable name */ + /****************************************************************************************/ + + + /** + * Determine the updated name if the ion variable has been optimized + * \param name The ion variable name + * \return The updated name of the variable has been optimized (e.g. \c ena --> \c ion_ena) + */ + std::string update_if_ion_variable_name(const std::string& name) const; + + + /** + * Determine the name of a \c float variable given its symbol + * + * This function typically returns the accessor expression in backend code for the given symbol. + * Since the model variables are stored in data arrays and accessed by offset, this function + * will return the C++ string representing the array access at the correct offset + * + * \param symbol The symbol of a variable for which we want to obtain its name + * \param use_instance Should the variable be accessed via instance or data array + * \return The backend code string representing the access to the given variable + * symbol + */ + std::string float_variable_name(const SymbolType& symbol, bool use_instance) const override; + + + /** + * Determine the name of an \c int variable given its symbol + * + * This function typically returns the accessor expression in backend code for the given symbol. + * Since the model variables are stored in data arrays and accessed by offset, this function + * will return the C++ string representing the array access at the correct offset + * + * \param symbol The symbol of a variable for which we want to obtain its name + * \param name The name of the index variable + * \param use_instance Should the variable be accessed via instance or data array + * \return The backend code string representing the access to the given variable + * symbol + */ + std::string int_variable_name(const IndexVariableInfo& symbol, + const std::string& name, + bool use_instance) const override; + + + /** + * Determine the variable name for a global variable given its symbol + * \param symbol The symbol of a variable for which we want to obtain its name + * \param use_instance Should the variable be accessed via the (host-only) + * global variable or the instance-specific copy (also available on GPU). 
+     * \return The C++ string representing the access to the global variable
+     */
+    std::string global_variable_name(const SymbolType& symbol,
+                                     bool use_instance = true) const override;
+
+
+    /**
+     * Determine variable name in the structure of mechanism properties
+     *
+     * \param name Variable name that is being printed
+     * \param use_instance Should the variable be accessed via instance or data array
+     * \return The C++ string representing the access to the variable in the neuron
+     *         thread structure
+     */
+    std::string get_variable_name(const std::string& name,
+                                  bool use_instance = true) const override;
+
+
+    /****************************************************************************************/
+    /*                     Main printing routines for code generation                       */
+    /****************************************************************************************/
+
+
+    /**
+     * Print top file header printed in generated code
+     */
+    void print_backend_info() override;
+
+
+    /**
+     * Print standard C/C++ includes
+     */
+    void print_standard_includes() override;
+
+
+    /**
+     * Print includes from coreneuron
+     */
+    void print_coreneuron_includes();
+
+
+    void print_sdlists_init(bool print_initializers) override;
+
+
+    /**
+     * Print the structure that wraps all global variables used in the NMODL
+     *
+     * \param print_initializers Whether to include default values in the struct
+     *                           definition (true: int foo{42}; false: int foo;)
+     */
+    void print_mechanism_global_var_structure(bool print_initializers) override;
+
+
+    /**
+     * Print static assertions about the global variable struct.
+     */
+    virtual void print_global_var_struct_assertions() const;
+
+
+    /**
+     * Print byte arrays that register scalar and vector variables for hoc interface
+     */
+    void print_global_variables_for_hoc() override;
+
+
+    /**
+     * Print the mechanism registration function
+     */
+    void print_mechanism_register() override;
+
+
+    /**
+     * Print thread related memory allocation and deallocation callbacks
+     */
+    void print_thread_memory_callbacks();
+
+
+    /**
+     * Print structure of ion variables used for local copies
+     */
+    void print_ion_var_structure();
+
+
+    /**
+     * Print constructor of ion variables
+     * \param members The ion variable names
+     */
+    virtual void print_ion_var_constructor(const std::vector<std::string>& members);
+
+
+    /**
+     * Print the ion variable struct
+     */
+    virtual void print_ion_variable();
+
+
+    /**
+     * Print the pragma annotation to update global variables from host to the device
+     *
+     * \note This is not used for the C++ backend
+     */
+    virtual void print_global_variable_device_update_annotation();
+
+
+    /**
+     * Print the function that initialize range variable with different data type
+     */
+    void print_setup_range_variable();
+
+
+    /**
+     * Returns floating point type for given range variable symbol
+     * \param symbol A range variable symbol
+     */
+    std::string get_range_var_float_type(const SymbolType& symbol);
+
+
+    /**
+     * Print initial block statements
+     *
+     * Generate the target backend code corresponding to the NMODL initial block statements
+     *
+     * \param node The AST Node representing a NMODL initial block
+     */
+    void print_initial_block(const ast::InitialBlock* node);
+
+
+    /**
+     * Print common code for global functions like nrn_init, nrn_cur and nrn_state
+     * \param type The target backend code block type
+     */
+    virtual void print_global_function_common_code(BlockType type,
+                                                   const std::string& function_name = "") override;
+
+
+    /**
+     * Print the \c nrn\_init function definition
+     * \param skip_init_check \c true to generate code executing the initialization conditionally
+     */
+    void print_nrn_init(bool skip_init_check = true);
+
+
+    /**
+     * Print NMODL before / after block in target backend code
+     * \param node AST node of type before/after type being printed
+     * \param block_id Index of the before/after block
+     */
+    virtual void print_before_after_block(const ast::Block* node, size_t block_id);
+
+
+    /**
+     * Print nrn_constructor function definition
+     */
+    void print_nrn_constructor() override;
+
+
+    /**
+     * Print nrn_destructor function definition
+     */
+    void print_nrn_destructor() override;
+
+
+    /**
+     * Go through the map of \c EigenNewtonSolverBlock s and their corresponding functor names
+     * and print the functor definitions before the definitions of the functions of the generated
+     * file
+     */
+    void print_functors_definitions();
+
+
+    /**
+     * Print nrn_alloc function definition
+     */
+    void print_nrn_alloc() override;
+
+
+    /**
+     * Print watch activate function
+     */
+    void print_watch_activate();
+
+
+    /**
+     * Print watch check function
+     */
+    void print_watch_check();
+
+
+    /**
+     * Print the common code section for net receive related methods
+     *
+     * \param node The AST node representing the corresponding NMODL block
+     * \param need_mech_inst \c true if a local \c inst variable needs to be defined in generated
+     *                       code
+     */
+    void print_net_receive_common_code(const ast::Block& node, bool need_mech_inst = true);
+
+
+    /**
+     * Print call to \c net\_send
+     * \param node The AST node representing the function call
+     */
+    void print_net_send_call(const ast::FunctionCall& node);
+
+
+    /**
+     * Print call to net\_move
+     * \param node The AST node representing the function call
+     */
+    void print_net_move_call(const ast::FunctionCall& node);
+
+
+    /**
+     * Print call to net\_event
+     * \param node The AST node representing the function call
+     */
+    void print_net_event_call(const ast::FunctionCall& node);
+
+
+    /**
+     * Print initial block in the net receive block
+     */
+    void print_net_init();
+
+
+    /**
+     * Print send event move block used in net receive as well as watch
+     */
+    void print_send_event_move();
+
+
+    /**
+     * Generate the target backend code for the \c net\_receive\_buffering function declaration
+     * \return The target code string
+     */
+    virtual std::string net_receive_buffering_declaration();
+
+
+    /**
+     * Print the target backend code for defining and checking a local \c Memb\_list variable
+     */
+    virtual void print_get_memb_list();
+
+
+    /**
+     * Print the code for the main \c net\_receive loop
+     */
+    virtual void print_net_receive_loop_begin();
+
+
+    /**
+     * Print the code for closing the main \c net\_receive loop
+     */
+    virtual void print_net_receive_loop_end();
+
+
+    /**
+     * Print kernel for buffering net_receive events
+     *
+     * This kernel is only needed for accelerator backends where \c net\_receive needs to be
+     * executed in two stages as the actual communication must be done in the host code.
+     * \param need_mech_inst \c true if the generated code needs a local inst variable to be
+     *                       defined
+     */
+    void print_net_receive_buffering(bool need_mech_inst = true);
+
+
+    /**
+     * Print the code related to the update of NetSendBuffer_t cnt. For GPU this needs to be done
+     * with atomic operation, on CPU it's not needed.
+     */
+    virtual void print_net_send_buffering_cnt_update() const;
+
+
+    /**
+     * Print statement that grows NetSendBuffering_t structure if needed.
+ * This function should be overridden for backends that cannot dynamically reallocate the buffer + */ + virtual void print_net_send_buffering_grow(); + + + /** + * Print kernel for buffering net_send events + * + * This kernel is only needed for accelerator backends where \c net\_send needs to be executed + * in two stages as the actual communication must be done in the host code. + */ + void print_net_send_buffering(); + + + /** + * Print \c net\_receive kernel function definition + */ + void print_net_receive_kernel(); + + + /** + * Print \c net\_receive function definition + */ + void print_net_receive(); + + + /** + * Print derivative kernel when \c derivimplicit method is used + * + * \param block The corresponding AST node representing an NMODL \c derivimplicit block + */ + void print_derivimplicit_kernel(const ast::Block& block); + + + /** + * Print code block to transfer newtonspace structure to device + */ + virtual void print_newtonspace_transfer_to_device() const; + + + /****************************************************************************************/ + /* Print nrn_state routine */ + /****************************************************************************************/ + + + /** + * Print nrn_state / state update function definition + */ + void print_nrn_state() override; + + + /****************************************************************************************/ + /* Print nrn_cur related routines */ + /****************************************************************************************/ + + + /** + * Print the \c nrn_current kernel + * + * \note nrn_cur_kernel will have two calls to nrn_current if no conductance keywords specified + * \param node the AST node representing the NMODL breakpoint block + */ + void print_nrn_current(const ast::BreakpointBlock& node) override; + + + /** + * Print the \c nrn\_cur kernel with NMODL \c conductance keyword provisions + * + * If the NMODL \c conductance keyword is used in the \c breakpoint block, then + * CodegenCoreneuronCppVisitor::print_nrn_cur_kernel will use this printer + * + * \param node the AST node representing the NMODL breakpoint block + */ + void print_nrn_cur_conductance_kernel(const ast::BreakpointBlock& node) override; + + + /** + * Print the \c nrn\_cur kernel without NMODL \c conductance keyword provisions + * + * If the NMODL \c conductance keyword is \b not used in the \c breakpoint block, then + * CodegenCoreneuronCppVisitor::print_nrn_cur_kernel will use this printer + */ + void print_nrn_cur_non_conductance_kernel() override; + + + /** + * Print main body of nrn_cur function + * \param node the AST node representing the NMODL breakpoint block + */ + void print_nrn_cur_kernel(const ast::BreakpointBlock& node) override; + + + /** + * Print fast membrane current calculation code + */ + virtual void print_fast_imem_calculation() override; + + + /** + * Print nrn_cur / current update function definition + */ + void print_nrn_cur() override; + + + /****************************************************************************************/ + /* Main code printing entry points */ + /****************************************************************************************/ + + + /** + * Print all includes + * + */ + void print_headers_include() override; + + + /** + * Print start of namespaces + * + */ + void print_namespace_begin() override; + + + /** + * Print end of namespaces + * + */ + void print_namespace_end() override; + + + /** + * Print common getters + * + */ + void print_common_getters(); + + 
+    /**
+     * Print all classes
+     * \param print_initializers Whether to include default values.
+     */
+    void print_data_structures(bool print_initializers) override;
+
+
+    /**
+     * Set v_unused (voltage) for NRN_PRCELLSTATE feature
+     */
+    void print_v_unused() const override;
+
+
+    /**
+     * Set g_unused (conductance) for NRN_PRCELLSTATE feature
+     */
+    void print_g_unused() const override;
+
+
+    /**
+     * Print all compute functions for every backend
+     */
+    virtual void print_compute_functions() override;
+
+
+    /**
+     * Print entry point to code generation
+     */
+    virtual void print_codegen_routines() override;
+
+
+    /****************************************************************************************/
+    /*                          Overloaded visitor routines                                 */
+    /****************************************************************************************/
+
+
+    void visit_derivimplicit_callback(const ast::DerivimplicitCallback& node) override;
+    void visit_eigen_newton_solver_block(const ast::EigenNewtonSolverBlock& node) override;
+    void visit_eigen_linear_solver_block(const ast::EigenLinearSolverBlock& node) override;
+    void visit_for_netcon(const ast::ForNetcon& node) override;
+    virtual void visit_solution_expression(const ast::SolutionExpression& node) override;
+    virtual void visit_watch_statement(const ast::WatchStatement& node) override;
+
+
+    /**
+     * Print prototype declarations of functions or procedures
+     * \tparam T The AST node type of the node (must be of nmodl::ast::Ast or subclass)
+     * \param node The AST node representing the function or procedure block
+     * \param name A user defined name for the function
+     */
+    template <typename T>
+    void print_function_declaration(const T& node, const std::string& name);
+
+
+  public:
+    /**
+     * \brief Constructs the C++ code generator visitor
+     *
+     * This constructor instantiates an NMODL C++ code generator and allows writing generated code
+     * directly to a file in \c [output_dir]/[mod_filename].cpp.
+     *
+     * \note No code generation is performed at this stage. Since the code
+     *       generator classes are all based on \c AstVisitor the AST must be visited using e.g.
+     *       \c visit_program in order to generate the C++ code corresponding to the AST.
+     *
+     * \param mod_filename The name of the model for which code should be generated.
+     *                     It is used for constructing an output filename.
+     * \param output_dir   The directory where target C++ file should be generated.
+     * \param float_type   The float type to use in the generated code. The string will be used
+     *                     as-is in the target code. This defaults to \c double.
+     */
+    CodegenCoreneuronCppVisitor(std::string mod_filename,
+                                const std::string& output_dir,
+                                std::string float_type,
+                                const bool optimize_ionvar_copies)
+        : CodegenCppVisitor(mod_filename, output_dir, float_type, optimize_ionvar_copies) {}
+
+    /**
+     * \copybrief nmodl::codegen::CodegenCoreneuronCppVisitor
+     *
+     * This constructor instantiates an NMODL C++ code generator and allows writing generated code
+     * into an output stream.
+     *
+     * \note No code generation is performed at this stage. Since the code
+     *       generator classes are all based on \c AstVisitor the AST must be visited using e.g.
+     *       \c visit_program in order to generate the C++ code corresponding to the AST.
+     *
+     * \param mod_filename The name of the model for which code should be generated.
+     *                     It is used for constructing an output filename.
+     * \param stream       The output stream onto which to write the generated code
+     * \param float_type   The float type to use in the generated code. The string will be used
+     *                     as-is in the target code. This defaults to \c double.
+     */
+    CodegenCoreneuronCppVisitor(std::string mod_filename,
+                                std::ostream& stream,
+                                std::string float_type,
+                                const bool optimize_ionvar_copies)
+        : CodegenCppVisitor(mod_filename, stream, float_type, optimize_ionvar_copies) {}
+
+
+    /****************************************************************************************/
+    /*            Public printing routines for code generation for use in unit tests       */
+    /****************************************************************************************/
+
+
+    /**
+     * Print the function that initialize instance structure
+     */
+    void print_instance_variable_setup();
+
+
+    /**
+     * Print the structure that wraps all range and int variables required for the NMODL
+     *
+     * \param print_initializers Whether or not default values for variables should be
+     *                           included in the struct declaration.
+     */
+    void print_mechanism_range_var_structure(bool print_initializers) override;
+};
+
+
+/**
+ * \details If there is an argument with name (say alpha) same as range variable (say alpha),
+ * we want to avoid it being printed as instance->alpha. And hence we disable variable
+ * name lookup during prototype declaration. Note that the name of procedure can be
+ * different in case of table statement.
+ */
+template <typename T>
+void CodegenCoreneuronCppVisitor::print_function_declaration(const T& node,
+                                                             const std::string& name) {
+    enable_variable_name_lookup = false;
+    auto type = default_float_data_type();
+
+    // internal and user provided arguments
+    auto internal_params = internal_method_parameters();
+    const auto& params = node.get_parameters();
+    for (const auto& param: params) {
+        internal_params.emplace_back("", type, "", param.get()->get_node_name());
+    }
+
+    // procedures have "int" return type by default
+    const char* return_type = "int";
+    if (node.is_function_block()) {
+        return_type = default_float_data_type();
+    }
+
+    print_device_method_annotation();
+    printer->add_indent();
+    printer->fmt_text("inline {} {}({})",
+                      return_type,
+                      method_name(name),
+                      get_parameter_str(internal_params));
+
+    enable_variable_name_lookup = true;
+}
+
+/** \} */  // end of codegen_backends
+
+}  // namespace codegen
+}  // namespace nmodl
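Applied to a PROCEDURE rates(x) in a mod file with suffix hh, the template above would print a prototype along these lines (an illustrative sketch; the internal parameters come from internal_method_parameters() and the user's argument x is appended with the default float type):

    inline int rates_hh(int id, int pnodecount, hh_Instance* inst, double* data,
                        const Datum* indexes, ThreadDatum* thread, NrnThread* nt, double v,
                        double x);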
diff --git a/src/codegen/codegen_cpp_visitor.cpp b/src/codegen/codegen_cpp_visitor.cpp
index a3e9111acb..21909a8b90 100644
--- a/src/codegen/codegen_cpp_visitor.cpp
+++ b/src/codegen/codegen_cpp_visitor.cpp
@@ -4,723 +4,628 @@
  *
  * SPDX-License-Identifier: Apache-2.0
  */
-
 #include "codegen/codegen_cpp_visitor.hpp"
-#include
-#include
-#include
-#include
-#include
-
 #include "ast/all.hpp"
 #include "codegen/codegen_helper_visitor.hpp"
-#include "codegen/codegen_naming.hpp"
 #include "codegen/codegen_utils.hpp"
-#include "config/config.h"
-#include "lexer/token_mapping.hpp"
-#include "parser/c11_driver.hpp"
-#include "utils/logger.hpp"
-#include "utils/string_utils.hpp"
-#include "visitors/defuse_analyze_visitor.hpp"
 #include "visitors/rename_visitor.hpp"
-#include "visitors/symtab_visitor.hpp"
-#include "visitors/var_usage_visitor.hpp"
-#include "visitors/visitor_utils.hpp"

 namespace nmodl {
 namespace codegen {

 using namespace ast;

-using visitor::DefUseAnalyzeVisitor;
-using visitor::DUState;
 using visitor::RenameVisitor;
-using visitor::SymtabVisitor;
-using visitor::VarUsageVisitor;

 using symtab::syminfo::NmodlType;

+
 /****************************************************************************************/
-/*                            Overloaded visitor routines                               */
+/*                  Common helper routines across codegen functions                     */
 /****************************************************************************************/

-static const std::regex regex_special_chars{R"([-[\]{}()*+?.,\^$|#\s])"};

-void CodegenCppVisitor::visit_string(const String& node) {
-    if (!codegen) {
-        return;
-    }
-    std::string name = node.eval();
-    if (enable_variable_name_lookup) {
-        name = get_variable_name(name);
-    }
-    printer->add_text(name);
+template <typename T>
+bool CodegenCppVisitor::has_parameter_of_name(const T& node, const std::string& name) {
+    auto parameters = node->get_parameters();
+    return std::any_of(parameters.begin(),
+                       parameters.end(),
+                       [&name](const decltype(*parameters.begin()) arg) {
+                           return arg->get_node_name() == name;
+                       });
 }


-void CodegenCppVisitor::visit_integer(const Integer& node) {
-    if (!codegen) {
-        return;
+/**
+ * \details Certain statements like unit, comment, solve can/need to be skipped
+ * during code generation. Note that solve block is wrapped in expression
+ * statement and hence we have to check inner expression. It's also true
+ * for the initial block defined inside net receive block.
+ */
+bool CodegenCppVisitor::statement_to_skip(const Statement& node) {
+    // clang-format off
+    if (node.is_unit_state()
+        || node.is_line_comment()
+        || node.is_block_comment()
+        || node.is_solve_block()
+        || node.is_conductance_hint()
+        || node.is_table_statement()) {
+        return true;
     }
-    const auto& value = node.get_value();
-    printer->add_text(std::to_string(value));
+    // clang-format on
+    if (node.is_expression_statement()) {
+        auto expression = dynamic_cast<const ExpressionStatement*>(&node)->get_expression();
+        if (expression->is_solve_block()) {
+            return true;
+        }
+        if (expression->is_initial_block()) {
+            return true;
+        }
+    }
+    return false;
 }


-void CodegenCppVisitor::visit_float(const Float& node) {
-    if (!codegen) {
-        return;
+bool CodegenCppVisitor::net_send_buffer_required() const noexcept {
+    if (net_receive_required() && !info.artificial_cell) {
+        if (info.net_event_used || info.net_send_used || info.is_watch_used()) {
+            return true;
+        }
     }
-    printer->add_text(format_float_string(node.get_value()));
+    return false;
 }


-void CodegenCppVisitor::visit_double(const Double& node) {
-    if (!codegen) {
-        return;
-    }
-    printer->add_text(format_double_string(node.get_value()));
+bool CodegenCppVisitor::net_receive_buffering_required() const noexcept {
+    return info.point_process && !info.artificial_cell && info.net_receive_node != nullptr;
 }


-void CodegenCppVisitor::visit_boolean(const Boolean& node) {
-    if (!codegen) {
-        return;
+bool CodegenCppVisitor::nrn_state_required() const noexcept {
+    if (info.artificial_cell) {
+        return false;
     }
-    printer->add_text(std::to_string(static_cast<int>(node.eval())));
+    return info.nrn_state_block != nullptr || breakpoint_exist();
 }


-void CodegenCppVisitor::visit_name(const Name& node) {
-    if (!codegen) {
-        return;
-    }
-    node.visit_children(*this);
+bool CodegenCppVisitor::nrn_cur_required() const noexcept {
+    return info.breakpoint_node != nullptr && !info.currents.empty();
 }


-void CodegenCppVisitor::visit_unit(const ast::Unit& node) {
-    // do not print units
+bool CodegenCppVisitor::net_receive_exist() const noexcept {
+    return info.net_receive_node != nullptr;
 }


-void CodegenCppVisitor::visit_prime_name(const PrimeName& /* node */) {
-    throw std::runtime_error("PRIME encountered during code generation, ODEs not solved?");
+bool CodegenCppVisitor::breakpoint_exist() const noexcept {
+    return info.breakpoint_node != nullptr;
 }


-/**
- * \todo : Validate how @ is being handled in neuron implementation
- */
-void CodegenCppVisitor::visit_var_name(const VarName& node) {
-    if (!codegen) {
-        return;
-    }
-    const auto& name = node.get_name();
-    const auto& at_index = node.get_at();
-    const auto& index = node.get_index();
-    name->accept(*this);
-    if (at_index) {
-        printer->add_text("@");
-        at_index->accept(*this);
-    }
-    if (index) {
-        printer->add_text("[");
-        printer->add_text("static_cast<int>(");
-        index->accept(*this);
-        printer->add_text(")");
-        printer->add_text("]");
-    }
+bool CodegenCppVisitor::net_receive_required() const noexcept {
+    return net_receive_exist();
 }


-void CodegenCppVisitor::visit_indexed_name(const IndexedName& node) {
-    if (!codegen) {
-        return;
-    }
-    node.get_name()->accept(*this);
-    printer->add_text("[");
-    printer->add_text("static_cast<int>(");
-    node.get_length()->accept(*this);
-    printer->add_text(")");
-    printer->add_text("]");
+/**
+ * \details When floating point data type is not default (i.e. double) then we
+ * have to copy old array to new type (for range variables).
+ */
+bool CodegenCppVisitor::range_variable_setup_required() const noexcept {
+    return codegen::naming::DEFAULT_FLOAT_TYPE != float_data_type();
 }


-void CodegenCppVisitor::visit_local_list_statement(const LocalListStatement& node) {
-    if (!codegen) {
-        return;
-    }
-    printer->add_text(local_var_type(), ' ');
-    print_vector_elements(node.get_variables(), ", ");
+// check if there is a function or procedure defined with given name
+bool CodegenCppVisitor::defined_method(const std::string& name) const {
+    const auto& function = program_symtab->lookup(name);
+    auto properties = NmodlType::function_block | NmodlType::procedure_block;
+    return function && function->has_any_property(properties);
 }


-void CodegenCppVisitor::visit_if_statement(const IfStatement& node) {
-    if (!codegen) {
-        return;
-    }
-    printer->add_text("if (");
-    node.get_condition()->accept(*this);
-    printer->add_text(") ");
-    node.get_statement_block()->accept(*this);
-    print_vector_elements(node.get_elseifs(), "");
-    const auto& elses = node.get_elses();
-    if (elses) {
-        elses->accept(*this);
-    }
+int CodegenCppVisitor::float_variables_size() const {
+    return codegen_float_variables.size();
 }


-void CodegenCppVisitor::visit_else_if_statement(const ElseIfStatement& node) {
-    if (!codegen) {
-        return;
-    }
-    printer->add_text(" else if (");
-    node.get_condition()->accept(*this);
-    printer->add_text(") ");
-    node.get_statement_block()->accept(*this);
+int CodegenCppVisitor::int_variables_size() const {
+    const auto count_semantics = [](int sum, const IndexSemantics& sem) { return sum += sem.size; };
+    return std::accumulate(info.semantics.begin(), info.semantics.end(), 0, count_semantics);
 }
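int_variables_size above folds the sizes of all index (int) variables recorded in info.semantics. A self-contained sketch of the same computation, with a simplified IndexSemantics and hypothetical semantic entries:

    #include <numeric>
    #include <string>
    #include <vector>

    struct IndexSemantics {
        int index;
        std::string name;
        int size;
    };

    int main() {
        // e.g. an area pointer, a point-process slot and a 3-wide ion semantic
        std::vector<IndexSemantics> semantics{{0, "area", 1}, {1, "pntproc", 1}, {2, "na_ion", 3}};
        const auto count = [](int sum, const IndexSemantics& sem) { return sum + sem.size; };
        int total = std::accumulate(semantics.begin(), semantics.end(), 0, count);
        return total == 5 ? 0 : 1;  // five int variables per mechanism instance
    }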
+/**
+ * \details We could print the value directly, but if the user specified it as an
+ * integer it would be printed as an integer. To avoid this, we use the wrapper
+ * below. If the user provided an integer it gets printed as 1.0 (similar to mod2c
+ * and NEURON, where ".0" is appended). Otherwise we print double variables as
+ * they are represented in the mod file by the user. If the value is in scientific
+ * representation (1e+20, 1E-15) it is kept as is.
+ */
+std::string CodegenCppVisitor::format_double_string(const std::string& s_value) {
+    return utils::format_double_string(s_value);
+}


-void CodegenCppVisitor::visit_while_statement(const WhileStatement& node) {
-    printer->add_text("while (");
-    node.get_condition()->accept(*this);
-    printer->add_text(") ");
-    node.get_statement_block()->accept(*this);
+std::string CodegenCppVisitor::format_float_string(const std::string& s_value) {
+    return utils::format_float_string(s_value);
 }


-void CodegenCppVisitor::visit_from_statement(const ast::FromStatement& node) {
-    if (!codegen) {
-        return;
-    }
-    auto name = node.get_node_name();
-    const auto& from = node.get_from();
-    const auto& to = node.get_to();
-    const auto& inc = node.get_increment();
-    const auto& block = node.get_statement_block();
-    printer->fmt_text("for (int {} = ", name);
-    from->accept(*this);
-    printer->fmt_text("; {} <= ", name);
-    to->accept(*this);
-    if (inc) {
-        printer->fmt_text("; {} += ", name);
-        inc->accept(*this);
-    } else {
-        printer->fmt_text("; {}++", name);
-    }
-    printer->add_text(") ");
-    block->accept(*this);
+/**
+ * \details Statements like if, else etc. don't need a semicolon at the end.
+ * (Note that it's valid to have an "extraneous" semicolon.) Also, a statement
+ * block can appear as a statement via an expression statement, which needs to
+ * be inspected.
+ */
+bool CodegenCppVisitor::need_semicolon(const Statement& node) {
+    // clang-format off
+    if (node.is_if_statement()
+        || node.is_else_if_statement()
+        || node.is_else_statement()
+        || node.is_from_statement()
+        || node.is_verbatim()
+        || node.is_conductance_hint()
+        || node.is_while_statement()
+        || node.is_protect_statement()
+        || node.is_mutex_lock()
+        || node.is_mutex_unlock()) {
+        return false;
+    }
+    if (node.is_expression_statement()) {
+        auto expression = dynamic_cast<const ExpressionStatement&>(node).get_expression();
+        if (expression->is_statement_block()
+            || expression->is_eigen_newton_solver_block()
+            || expression->is_eigen_linear_solver_block()
+            || expression->is_solution_expression()
+            || expression->is_for_netcon()) {
+            return false;
+        }
+    }
+    // clang-format on
+    return true;
 }


-void CodegenCppVisitor::visit_paren_expression(const ParenExpression& node) {
-    if (!codegen) {
-        return;
-    }
-    printer->add_text("(");
-    node.get_expression()->accept(*this);
-    printer->add_text(")");
+/****************************************************************************************/
+/*                    Main printing routines for code generation                        */
+/****************************************************************************************/


+void CodegenCppVisitor::print_prcellstate_macros() const {
+    printer->add_line("#ifndef NRN_PRCELLSTATE");
+    printer->add_line("#define NRN_PRCELLSTATE 0");
+    printer->add_line("#endif");
 }


-void CodegenCppVisitor::visit_binary_expression(const BinaryExpression& node) {
-    if (!codegen) {
-        return;
-    }
-    auto op = node.get_op().eval();
-    const auto& lhs = node.get_lhs();
-    const auto& rhs = node.get_rhs();
-    if (op == "^") {
-        printer->add_text("pow(");
-        lhs->accept(*this);
-        printer->add_text(", ");
-        rhs->accept(*this);
-        printer->add_text(")");
-    } else {
-        lhs->accept(*this);
-        printer->add_text(" " + op + " ");
-        rhs->accept(*this);
-    }
-}


-void CodegenCppVisitor::visit_binary_operator(const BinaryOperator& node) {
-    if (!codegen) {
-        return;
-    }
-    printer->add_text(node.eval());
-}


-void CodegenCppVisitor::visit_unary_operator(const UnaryOperator& node) {
-    if (!codegen) {
-        return;
-    }
-    printer->add_text(" " + node.eval());
-}
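The two predicates above are consumed together by `print_statement_block` (below); schematically, per statement, the emission step amounts to (a sketch, assuming `stmt` is an `ast::Statement&`):

    if (!statement_to_skip(stmt)) {
        stmt.accept(*this);           // prints e.g. "if (x > 0) { ... }" or "y = 2 * x"
        if (need_semicolon(stmt)) {   // false for if/while/verbatim/solver blocks
            printer->add_text(';');
        }
        printer->add_newline();
    }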
+void CodegenCppVisitor::print_mechanism_info() {
+    auto variable_printer = [&](const std::vector<SymbolType>& variables) {
+        for (const auto& v: variables) {
+            auto name = v->get_name();
+            if (!info.point_process) {
+                name += "_" + info.mod_suffix;
+            }
+            if (v->is_array()) {
+                name += fmt::format("[{}]", v->get_length());
+            }
+            printer->add_line(add_escape_quote(name), ",");
+        }
+    };

+    printer->add_newline(2);
+    printer->add_line("/** channel information */");
+    printer->add_line("static const char *mechanism[] = {");
+    printer->increase_indent();
+    printer->add_line(add_escape_quote(nmodl_version()), ",");
+    printer->add_line(add_escape_quote(info.mod_suffix), ",");
+    variable_printer(info.range_parameter_vars);
+    printer->add_line("0,");
+    variable_printer(info.range_assigned_vars);
+    printer->add_line("0,");
+    variable_printer(info.range_state_vars);
+    printer->add_line("0,");
+    variable_printer(info.pointer_variables);
+    printer->add_line("0");
+    printer->decrease_indent();
+    printer->add_line("};");
+}


+/****************************************************************************************/
+/*                      Printing routines for code generation                           */
+/****************************************************************************************/


-/**
- * \details Statement block is top level construct (for every nmodl block).
- * Sometime we want to analyse ast nodes even if code generation is
- * false. Hence we visit children even if code generation is false.
- */
-void CodegenCppVisitor::visit_statement_block(const StatementBlock& node) {
-    if (!codegen) {
-        node.visit_children(*this);
-        return;
-    }
-    print_statement_block(node);
-}


-void CodegenCppVisitor::visit_function_call(const FunctionCall& node) {
-    if (!codegen) {
-        return;
-    }
-    print_function_call(node);
-}
+void CodegenCppVisitor::print_statement_block(const ast::StatementBlock& node,
+                                              bool open_brace,
+                                              bool close_brace) {
+    if (open_brace) {
+        printer->push_block();
+    }

-void CodegenCppVisitor::visit_verbatim(const Verbatim& node) {
-    if (!codegen) {
-        return;
-    }
-    const auto& text = node.get_statement()->eval();
-    const auto& result = process_verbatim_text(text);
-    const auto& statements = stringutils::split_string(result, '\n');
-    for (const auto& statement: statements) {
-        const auto& trimed_stmt = stringutils::trim_newline(statement);
-        if (trimed_stmt.find_first_not_of(' ') != std::string::npos) {
-            printer->add_line(trimed_stmt);
-        }
-    }
-}
+    const auto& statements = node.get_statements();
+    for (const auto& statement: statements) {
+        if (statement_to_skip(*statement)) {
+            continue;
+        }
+        /// not necessary to add indent for verbatim block (pretty-printing)
+        if (!statement->is_verbatim() && !statement->is_mutex_lock() &&
+            !statement->is_mutex_unlock() && !statement->is_protect_statement()) {
+            printer->add_indent();
+        }
+        statement->accept(*this);
+        if (need_semicolon(*statement)) {
+            printer->add_text(';');
+        }
+        if (!statement->is_mutex_lock() && !statement->is_mutex_unlock()) {
+            printer->add_newline();
+        }
+    }

-void CodegenCppVisitor::visit_update_dt(const ast::UpdateDt& node) {
-    // dt change statement should be pulled outside already
-}
+    if (close_brace) {
+        printer->pop_block_nl(0);
+    }
+}


-void CodegenCppVisitor::visit_protect_statement(const ast::ProtectStatement& node) {
-    print_atomic_reduction_pragma();
-    printer->add_indent();
-    node.get_expression()->accept(*this);
-    printer->add_text(";");
-}


-void CodegenCppVisitor::visit_mutex_lock(const ast::MutexLock& node) {
-    printer->fmt_line("#pragma omp critical ({})", info.mod_suffix);
-    printer->add_indent();
-    printer->push_block();
-}


+/**
+ * \todo Issue
with verbatim renaming. e.g. pattern.mod has info struct with + * index variable. If we use "index" instead of "indexes" as default argument + * then during verbatim replacement we don't know the index is which one. This + * is because verbatim renaming pass has already stripped out prefixes from + * the text. + */ +void CodegenCppVisitor::rename_function_arguments() { + const auto& default_arguments = stringutils::split_string(nrn_thread_arguments(), ','); + for (const auto& dirty_arg: default_arguments) { + const auto& arg = stringutils::trim(dirty_arg); + RenameVisitor v(arg, "arg_" + arg); + for (const auto& function: info.functions) { + if (has_parameter_of_name(function, arg)) { + function->accept(v); + } + } + for (const auto& function: info.procedures) { + if (has_parameter_of_name(function, arg)) { + function->accept(v); + } + } + } } -void CodegenCppVisitor::visit_mutex_unlock(const ast::MutexUnlock& node) { - printer->pop_block(); -} /****************************************************************************************/ -/* Common helper routines */ +/* Main code printing entry points */ /****************************************************************************************/ /** - * \details Certain statements like unit, comment, solve can/need to be skipped - * during code generation. Note that solve block is wrapped in expression - * statement and hence we have to check inner expression. It's also true - * for the initial block defined inside net receive block. + * NMODL constants from unit database + * */ -bool CodegenCppVisitor::statement_to_skip(const Statement& node) { - // clang-format off - if (node.is_unit_state() - || node.is_line_comment() - || node.is_block_comment() - || node.is_solve_block() - || node.is_conductance_hint() - || node.is_table_statement()) { - return true; - } - // clang-format on - if (node.is_expression_statement()) { - auto expression = dynamic_cast(&node)->get_expression(); - if (expression->is_solve_block()) { - return true; - } - if (expression->is_initial_block()) { - return true; +void CodegenCppVisitor::print_nmodl_constants() { + if (!info.factor_definitions.empty()) { + printer->add_newline(2); + printer->add_line("/** constants used in nmodl from UNITS */"); + for (const auto& it: info.factor_definitions) { + const std::string format_string = "static const double {} = {};"; + printer->fmt_line(format_string, it->get_node_name(), it->get_value()->get_value()); } } - return false; } -bool CodegenCppVisitor::net_send_buffer_required() const noexcept { - if (net_receive_required() && !info.artificial_cell) { - if (info.net_event_used || info.net_send_used || info.is_watch_used()) { - return true; - } - } - return false; -} +/****************************************************************************************/ +/* Overloaded visitor routines */ +/****************************************************************************************/ -bool CodegenCppVisitor::net_receive_buffering_required() const noexcept { - return info.point_process && !info.artificial_cell && info.net_receive_node != nullptr; -} +extern const std::regex regex_special_chars{R"([-[\]{}()*+?.,\^$|#\s])"}; -bool CodegenCppVisitor::nrn_state_required() const noexcept { - if (info.artificial_cell) { - return false; +void CodegenCppVisitor::visit_string(const String& node) { + if (!codegen) { + return; } - return info.nrn_state_block != nullptr || breakpoint_exist(); + std::string name = node.eval(); + if (enable_variable_name_lookup) { + name = get_variable_name(name); 
+    }
+    printer->add_text(name);
+}


-bool CodegenCppVisitor::nrn_cur_required() const noexcept {
-    return info.breakpoint_node != nullptr && !info.currents.empty();
+void CodegenCppVisitor::visit_integer(const Integer& node) {
+    if (!codegen) {
+        return;
+    }
+    const auto& value = node.get_value();
+    printer->add_text(std::to_string(value));
 }


-bool CodegenCppVisitor::net_receive_exist() const noexcept {
-    return info.net_receive_node != nullptr;
+void CodegenCppVisitor::visit_float(const Float& node) {
+    if (!codegen) {
+        return;
+    }
+    printer->add_text(format_float_string(node.get_value()));
 }


-bool CodegenCppVisitor::breakpoint_exist() const noexcept {
-    return info.breakpoint_node != nullptr;
+void CodegenCppVisitor::visit_double(const Double& node) {
+    if (!codegen) {
+        return;
+    }
+    printer->add_text(format_double_string(node.get_value()));
 }


-bool CodegenCppVisitor::net_receive_required() const noexcept {
-    return net_receive_exist();
+void CodegenCppVisitor::visit_boolean(const Boolean& node) {
+    if (!codegen) {
+        return;
+    }
+    printer->add_text(std::to_string(static_cast<int>(node.eval())));
 }


-/**
- * \details When floating point data type is not default (i.e. double) then we
- * have to copy old array to new type (for range variables).
- */
-bool CodegenCppVisitor::range_variable_setup_required() const noexcept {
-    return codegen::naming::DEFAULT_FLOAT_TYPE != float_data_type();
+void CodegenCppVisitor::visit_name(const Name& node) {
+    if (!codegen) {
+        return;
+    }
+    node.visit_children(*this);
 }


-int CodegenCppVisitor::position_of_float_var(const std::string& name) const {
-    int index = 0;
-    for (const auto& var: codegen_float_variables) {
-        if (var->get_name() == name) {
-            return index;
-        }
-        index += var->get_length();
-    }
-    throw std::logic_error(name + " variable not found");
+void CodegenCppVisitor::visit_unit(const ast::Unit& node) {
+    // do not print units
 }


-int CodegenCppVisitor::position_of_int_var(const std::string& name) const {
-    int index = 0;
-    for (const auto& var: codegen_int_variables) {
-        if (var.symbol->get_name() == name) {
-            return index;
-        }
-        index += var.symbol->get_length();
-    }
-    throw std::logic_error(name + " variable not found");
+void CodegenCppVisitor::visit_prime_name(const PrimeName& /* node */) {
+    throw std::runtime_error("PRIME encountered during code generation, ODEs not solved?");
 }


-/**
- * \details We could print the value directly, but if the user specified it as an
- * integer it would be printed as an integer. To avoid this, we use the wrapper
- * below. If the user provided an integer it gets printed as 1.0 (similar to mod2c
- * and NEURON, where ".0" is appended). Otherwise we print double variables as
- * they are represented in the mod file by the user. If the value is in scientific
- * representation (1e+20, 1E-15) it is kept as is.
- */
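Concretely, the formatting behaviour documented above boils down to (illustrative values):

    format_double_string("1")     -> "1.0"    // ".0" appended, as in mod2c/NEURON
    format_double_string("2.5")   -> "2.5"    // kept as written in the mod file
    format_double_string("1e+20") -> "1e+20"  // scientific notation kept as is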
+/**
+ * \todo : Validate how @ is being handled in the NEURON implementation
+ */
+void CodegenCppVisitor::visit_var_name(const VarName& node) {
+    if (!codegen) {
+        return;
+    }
+    const auto& name = node.get_name();
+    const auto& at_index = node.get_at();
+    const auto& index = node.get_index();
+    name->accept(*this);
+    if (at_index) {
+        printer->add_text("@");
+        at_index->accept(*this);
+    }
+    if (index) {
+        printer->add_text("[");
+        printer->add_text("static_cast<int>(");
+        index->accept(*this);
+        printer->add_text(")");
+        printer->add_text("]");
+    }
+}


-std::string CodegenCppVisitor::format_double_string(const std::string& s_value) {
-    return utils::format_double_string(s_value);
-}


-std::string CodegenCppVisitor::format_float_string(const std::string& s_value) {
-    return utils::format_float_string(s_value);
+void CodegenCppVisitor::visit_indexed_name(const IndexedName& node) {
+    if (!codegen) {
+        return;
+    }
+    node.get_name()->accept(*this);
+    printer->add_text("[");
+    printer->add_text("static_cast<int>(");
+    node.get_length()->accept(*this);
+    printer->add_text(")");
+    printer->add_text("]");
 }


-/**
- * \details Statements like if, else etc. don't need a semicolon at the end.
- * (Note that it's valid to have an "extraneous" semicolon.) Also, a statement
- * block can appear as a statement via an expression statement, which needs to
- * be inspected.
- */
-bool CodegenCppVisitor::need_semicolon(const Statement& node) {
-    // clang-format off
-    if (node.is_if_statement()
-        || node.is_else_if_statement()
-        || node.is_else_statement()
-        || node.is_from_statement()
-        || node.is_verbatim()
-        || node.is_conductance_hint()
-        || node.is_while_statement()
-        || node.is_protect_statement()
-        || node.is_mutex_lock()
-        || node.is_mutex_unlock()) {
-        return false;
-    }
-    if (node.is_expression_statement()) {
-        auto expression = dynamic_cast<const ExpressionStatement&>(node).get_expression();
-        if (expression->is_statement_block()
-            || expression->is_eigen_newton_solver_block()
-            || expression->is_eigen_linear_solver_block()
-            || expression->is_solution_expression()
-            || expression->is_for_netcon()) {
-            return false;
-        }
-    }
-    // clang-format on
-    return true;
+void CodegenCppVisitor::visit_local_list_statement(const LocalListStatement& node) {
+    if (!codegen) {
+        return;
+    }
+    printer->add_text(local_var_type(), ' ');
+    print_vector_elements(node.get_variables(), ", ");
 }


-// check if there is a function or procedure defined with the given name
-bool CodegenCppVisitor::defined_method(const std::string& name) const {
-    const auto& function = program_symtab->lookup(name);
-    auto properties = NmodlType::function_block | NmodlType::procedure_block;
-    return function && function->has_any_property(properties);
+void CodegenCppVisitor::visit_if_statement(const IfStatement& node) {
+    if (!codegen) {
+        return;
+    }
+    printer->add_text("if (");
+    node.get_condition()->accept(*this);
+    printer->add_text(") ");
+    node.get_statement_block()->accept(*this);
+    print_vector_elements(node.get_elseifs(), "");
+    const auto& elses = node.get_elses();
+    if (elses) {
+        elses->accept(*this);
+    }
 }


-/**
- * \details Current variable used in breakpoint block could be local variable.
- * In this case, neuron has already renamed the variable name by prepending
- * "_l". In our implementation, the variable could have been renamed by
- * one of the pass. And hence, we search all local variables and check if
- * the variable is renamed. Note that we have to look into the symbol table
- * of statement block and not breakpoint.
- */ -std::string CodegenCppVisitor::breakpoint_current(std::string current) const { - auto breakpoint = info.breakpoint_node; - if (breakpoint == nullptr) { - return current; +void CodegenCppVisitor::visit_else_if_statement(const ElseIfStatement& node) { + if (!codegen) { + return; } - auto symtab = breakpoint->get_statement_block()->get_symbol_table(); - auto variables = symtab->get_variables_with_properties(NmodlType::local_var); - for (const auto& var: variables) { - auto renamed_name = var->get_name(); - auto original_name = var->get_original_name(); - if (current == original_name) { - current = renamed_name; - break; - } + printer->add_text(" else if ("); + node.get_condition()->accept(*this); + printer->add_text(") "); + node.get_statement_block()->accept(*this); +} + + +void CodegenCppVisitor::visit_else_statement(const ElseStatement& node) { + if (!codegen) { + return; } - return current; + printer->add_text(" else "); + node.visit_children(*this); } -int CodegenCppVisitor::float_variables_size() const { - return codegen_float_variables.size(); +void CodegenCppVisitor::visit_while_statement(const WhileStatement& node) { + printer->add_text("while ("); + node.get_condition()->accept(*this); + printer->add_text(") "); + node.get_statement_block()->accept(*this); } -int CodegenCppVisitor::int_variables_size() const { - const auto count_semantics = [](int sum, const IndexSemantics& sem) { return sum += sem.size; }; - return std::accumulate(info.semantics.begin(), info.semantics.end(), 0, count_semantics); +void CodegenCppVisitor::visit_from_statement(const ast::FromStatement& node) { + if (!codegen) { + return; + } + auto name = node.get_node_name(); + const auto& from = node.get_from(); + const auto& to = node.get_to(); + const auto& inc = node.get_increment(); + const auto& block = node.get_statement_block(); + printer->fmt_text("for (int {} = ", name); + from->accept(*this); + printer->fmt_text("; {} <= ", name); + to->accept(*this); + if (inc) { + printer->fmt_text("; {} += ", name); + inc->accept(*this); + } else { + printer->fmt_text("; {}++", name); + } + printer->add_text(") "); + block->accept(*this); } -/** - * \details Depending upon the block type, we have to print read/write ion variables - * during code generation. Depending on block/procedure being printed, this - * method return statements as vector. As different code backends could have - * different variable names, we rely on backend-specific read_ion_variable_name - * and write_ion_variable_name method which will be overloaded. 
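- *
- * For example, for a mechanism that reads `ena`, this typically produces a
- * statement of the form `ena = ion_ena[id];` (illustrative; the exact names
- * come from the backend-specific read_ion_variable_name).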
- */ -std::vector CodegenCppVisitor::ion_read_statements(BlockType type) const { - if (optimize_ion_variable_copies()) { - return ion_read_statements_optimized(type); +void CodegenCppVisitor::visit_paren_expression(const ParenExpression& node) { + if (!codegen) { + return; } - std::vector statements; - for (const auto& ion: info.ions) { - auto name = ion.name; - for (const auto& var: ion.reads) { - auto const iter = std::find(ion.implicit_reads.begin(), ion.implicit_reads.end(), var); - if (iter != ion.implicit_reads.end()) { - continue; - } - auto variable_names = read_ion_variable_name(var); - auto first = get_variable_name(variable_names.first); - auto second = get_variable_name(variable_names.second); - statements.push_back(fmt::format("{} = {};", first, second)); - } - for (const auto& var: ion.writes) { - if (ion.is_ionic_conc(var)) { - auto variables = read_ion_variable_name(var); - auto first = get_variable_name(variables.first); - auto second = get_variable_name(variables.second); - statements.push_back(fmt::format("{} = {};", first, second)); - } - } + printer->add_text("("); + node.get_expression()->accept(*this); + printer->add_text(")"); +} + + +void CodegenCppVisitor::visit_binary_expression(const BinaryExpression& node) { + if (!codegen) { + return; + } + auto op = node.get_op().eval(); + const auto& lhs = node.get_lhs(); + const auto& rhs = node.get_rhs(); + if (op == "^") { + printer->add_text("pow("); + lhs->accept(*this); + printer->add_text(", "); + rhs->accept(*this); + printer->add_text(")"); + } else { + lhs->accept(*this); + printer->add_text(" " + op + " "); + rhs->accept(*this); } - return statements; } -std::vector CodegenCppVisitor::ion_read_statements_optimized(BlockType type) const { - std::vector statements; - for (const auto& ion: info.ions) { - for (const auto& var: ion.writes) { - if (ion.is_ionic_conc(var)) { - auto variables = read_ion_variable_name(var); - auto first = "ionvar." 
+ variables.first; - const auto& second = get_variable_name(variables.second); - statements.push_back(fmt::format("{} = {};", first, second)); - } - } +void CodegenCppVisitor::visit_binary_operator(const BinaryOperator& node) { + if (!codegen) { + return; } - return statements; + printer->add_text(node.eval()); } -// NOLINTNEXTLINE(readability-function-cognitive-complexity) -std::vector CodegenCppVisitor::ion_write_statements(BlockType type) { - std::vector statements; - for (const auto& ion: info.ions) { - std::string concentration; - auto name = ion.name; - for (const auto& var: ion.writes) { - auto variable_names = write_ion_variable_name(var); - if (ion.is_ionic_current(var)) { - if (type == BlockType::Equation) { - auto current = breakpoint_current(var); - auto lhs = variable_names.first; - auto op = "+="; - auto rhs = get_variable_name(current); - if (info.point_process) { - auto area = get_variable_name(naming::NODE_AREA_VARIABLE); - rhs += fmt::format("*(1.e2/{})", area); - } - statements.push_back(ShadowUseStatement{lhs, op, rhs}); - } - } else { - if (!ion.is_rev_potential(var)) { - concentration = var; - } - auto lhs = variable_names.first; - auto op = "="; - auto rhs = get_variable_name(variable_names.second); - statements.push_back(ShadowUseStatement{lhs, op, rhs}); - } - } - if (type == BlockType::Initial && !concentration.empty()) { - int index = 0; - if (ion.is_intra_cell_conc(concentration)) { - index = 1; - } else if (ion.is_extra_cell_conc(concentration)) { - index = 2; - } else { - /// \todo Unhandled case in neuron implementation - throw std::logic_error(fmt::format("codegen error for {} ion", ion.name)); - } - auto ion_type_name = fmt::format("{}_type", ion.name); - auto lhs = fmt::format("int {}", ion_type_name); - auto op = "="; - auto rhs = get_variable_name(ion_type_name); - statements.push_back(ShadowUseStatement{lhs, op, rhs}); - auto statement = conc_write_statement(ion.name, concentration, index); - statements.push_back(ShadowUseStatement{statement, "", ""}); - } +void CodegenCppVisitor::visit_unary_operator(const UnaryOperator& node) { + if (!codegen) { + return; } - return statements; + printer->add_text(" " + node.eval()); } /** - * \details Often top level verbatim blocks use variables with old names. - * Here we process if we are processing verbatim block at global scope. + * \details Statement block is top level construct (for every nmodl block). + * Sometime we want to analyse ast nodes even if code generation is + * false. Hence we visit children even if code generation is false. */ -std::string CodegenCppVisitor::process_verbatim_token(const std::string& token) { - const std::string& name = token; +void CodegenCppVisitor::visit_statement_block(const StatementBlock& node) { + if (!codegen) { + node.visit_children(*this); + return; + } + print_statement_block(node); +} - /* - * If given token is procedure name and if it's defined - * in the current mod file then it must be replaced - */ - if (program_symtab->is_method_defined(token)) { - return method_name(token); + +void CodegenCppVisitor::visit_function_call(const FunctionCall& node) { + if (!codegen) { + return; } + print_function_call(node); +} - /* - * Check if token is commongly used variable name in - * verbatim block like nt, \c \_threadargs etc. If so, replace - * it and return. 
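- *
- * (Illustrative: a verbatim occurrence of `nt` is mapped onto the kernel's
- * NrnThread argument rather than being looked up in the instance struct.)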
- */ - auto new_name = replace_if_verbatim_variable(name); - if (new_name != name) { - return get_variable_name(new_name, false); + +void CodegenCppVisitor::visit_verbatim(const Verbatim& node) { + if (!codegen) { + return; } + const auto& text = node.get_statement()->eval(); + const auto& result = process_verbatim_text(text); - /* - * For top level verbatim blocks we shouldn't replace variable - * names with Instance because arguments are provided from coreneuron - * and they are missing inst. - */ - auto use_instance = !printing_top_verbatim_blocks; - return get_variable_name(token, use_instance); + const auto& statements = stringutils::split_string(result, '\n'); + for (const auto& statement: statements) { + const auto& trimed_stmt = stringutils::trim_newline(statement); + if (trimed_stmt.find_first_not_of(' ') != std::string::npos) { + printer->add_line(trimed_stmt); + } + } } -bool CodegenCppVisitor::ion_variable_struct_required() const { - return optimize_ion_variable_copies() && info.ion_has_write_variable(); +void CodegenCppVisitor::visit_update_dt(const ast::UpdateDt& node) { + // dt change statement should be pulled outside already } -/** - * \details This can be override in the backend. For example, parameters can be constant - * except in INITIAL block where they are set to 0. As initial block is/can be - * executed on c++/cpu backend, gpu backend can mark the parameter as constant. - */ -bool CodegenCppVisitor::is_constant_variable(const std::string& name) const { - auto symbol = program_symtab->lookup_in_scope(name); - bool is_constant = false; - if (symbol != nullptr) { - // per mechanism ion variables needs to be updated from neuron/coreneuron values - if (info.is_ion_variable(name)) { - is_constant = false; - } - // for parameter variable to be const, make sure it's write count is 0 - // and it's not used in the verbatim block - else if (symbol->has_any_property(NmodlType::param_assign) && - info.variables_in_verbatim.find(name) == info.variables_in_verbatim.end() && - symbol->get_write_count() == 0) { - is_constant = true; - } - } - return is_constant; +void CodegenCppVisitor::visit_protect_statement(const ast::ProtectStatement& node) { + print_atomic_reduction_pragma(); + printer->add_indent(); + node.get_expression()->accept(*this); + printer->add_text(";"); } -/** - * \details Once variables are populated, update index semantics to register with coreneuron - */ -// NOLINTNEXTLINE(readability-function-cognitive-complexity) -void CodegenCppVisitor::update_index_semantics() { - int index = 0; - info.semantics.clear(); - - if (info.point_process) { +void CodegenCppVisitor::visit_mutex_lock(const ast::MutexLock& node) { + printer->fmt_line("#pragma omp critical ({})", info.mod_suffix); + printer->add_indent(); + printer->push_block(); +} + + +void CodegenCppVisitor::visit_mutex_unlock(const ast::MutexUnlock& node) { + printer->pop_block(); +} + + +/** + * \details Once variables are populated, update index semantics to register with coreneuron + */ +// NOLINTNEXTLINE(readability-function-cognitive-complexity) +void CodegenCppVisitor::update_index_semantics() { + int index = 0; + info.semantics.clear(); + + if (info.point_process) { info.semantics.emplace_back(index++, naming::AREA_SEMANTIC, 1); info.semantics.emplace_back(index++, naming::POINT_PROCESS_SEMANTIC, 1); } @@ -968,3715 +873,6 @@ std::vector CodegenCppVisitor::get_int_variables() { } -/****************************************************************************************/ -/* Routines must be overloaded in 
backend */ -/****************************************************************************************/ - -std::string CodegenCppVisitor::get_parameter_str(const ParamVector& params) { - std::string str; - bool is_first = true; - for (const auto& param: params) { - if (is_first) { - is_first = false; - } else { - str += ", "; - } - str += fmt::format("{}{} {}{}", - std::get<0>(param), - std::get<1>(param), - std::get<2>(param), - std::get<3>(param)); - } - return str; -} - - -void CodegenCppVisitor::print_deriv_advance_flag_transfer_to_device() const { - // backend specific, do nothing -} - -void CodegenCppVisitor::print_device_atomic_capture_annotation() const { - // backend specific, do nothing -} - -void CodegenCppVisitor::print_net_send_buf_count_update_to_host() const { - // backend specific, do nothing -} - -void CodegenCppVisitor::print_net_send_buf_update_to_host() const { - // backend specific, do nothing -} - -void CodegenCppVisitor::print_net_send_buf_count_update_to_device() const { - // backend specific, do nothing -} - -void CodegenCppVisitor::print_dt_update_to_device() const { - // backend specific, do nothing -} - -void CodegenCppVisitor::print_device_stream_wait() const { - // backend specific, do nothing -} - -/** - * \details Each kernel such as \c nrn\_init, \c nrn\_state and \c nrn\_cur could be offloaded - * to accelerator. In this case, at very top level, we print pragma - * for data present. For example: - * - * \code{.cpp} - * void nrn_state(...) { - * #pragma acc data present (nt, ml...) - * { - * - * } - * } - * \endcode - */ -void CodegenCppVisitor::print_kernel_data_present_annotation_block_begin() { - // backend specific, do nothing -} - - -void CodegenCppVisitor::print_kernel_data_present_annotation_block_end() { - // backend specific, do nothing -} - -void CodegenCppVisitor::print_net_init_acc_serial_annotation_block_begin() { - // backend specific, do nothing -} - -void CodegenCppVisitor::print_net_init_acc_serial_annotation_block_end() { - // backend specific, do nothing -} - -/** - * \details Depending programming model and compiler, we print compiler hint - * for parallelization. For example: - * - * \code - * #pragma omp simd - * for(int id = 0; id < nodecount; id++) { - * - * #pragma acc parallel loop - * for(int id = 0; id < nodecount; id++) { - * \endcode - */ -void CodegenCppVisitor::print_channel_iteration_block_parallel_hint(BlockType /* type */, - const ast::Block* block) { - // ivdep allows SIMD parallelisation of a block/loop but doesn't provide - // a standard mechanism for atomics. Also, even with openmp 5.0, openmp - // atomics do not enable vectorisation under "omp simd" (gives compiler - // error with gcc < 9 if atomic and simd pragmas are nested). So, emit - // ivdep/simd pragma when no MUTEXLOCK/MUTEXUNLOCK/PROTECT statements - // are used in the given block. 
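-    // For a block free of PROTECT/MUTEX statements this therefore emits,
-    // right before the channel loop:
-    //
-    //     #pragma omp simd
-    //     #pragma ivdep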
-    std::vector<std::shared_ptr<const ast::Ast>> nodes;
-    if (block) {
-        nodes = collect_nodes(*block,
-                              {ast::AstNodeType::PROTECT_STATEMENT,
-                               ast::AstNodeType::MUTEX_LOCK,
-                               ast::AstNodeType::MUTEX_UNLOCK});
-    }
-    if (nodes.empty()) {
-        printer->add_line("#pragma omp simd");
-        printer->add_line("#pragma ivdep");
-    }
-}


-bool CodegenCppVisitor::nrn_cur_reduction_loop_required() {
-    return info.point_process;
-}


-void CodegenCppVisitor::print_rhs_d_shadow_variables() {
-    if (info.point_process) {
-        printer->fmt_line("double* shadow_rhs = nt->{};", naming::NTHREAD_RHS_SHADOW);
-        printer->fmt_line("double* shadow_d = nt->{};", naming::NTHREAD_D_SHADOW);
-    }
-}


-void CodegenCppVisitor::print_nrn_cur_matrix_shadow_update() {
-    if (info.point_process) {
-        printer->add_line("shadow_rhs[id] = rhs;");
-        printer->add_line("shadow_d[id] = g;");
-    } else {
-        auto rhs_op = operator_for_rhs();
-        auto d_op = operator_for_d();
-        printer->fmt_line("vec_rhs[node_id] {} rhs;", rhs_op);
-        printer->fmt_line("vec_d[node_id] {} g;", d_op);
-    }
-}


-void CodegenCppVisitor::print_nrn_cur_matrix_shadow_reduction() {
-    auto rhs_op = operator_for_rhs();
-    auto d_op = operator_for_d();
-    if (info.point_process) {
-        printer->add_line("int node_id = node_index[id];");
-        printer->fmt_line("vec_rhs[node_id] {} shadow_rhs[id];", rhs_op);
-        printer->fmt_line("vec_d[node_id] {} shadow_d[id];", d_op);
-    }
-}


-/**
- * In the current implementation of the CPU/C++ backend we need to emit the atomic
- * pragma only with the PROTECT construct (the atomic reduction requirement for
- * other cases on CPU is handled via separate shadow vectors).
- */
-void CodegenCppVisitor::print_atomic_reduction_pragma() {
-    printer->add_line("#pragma omp atomic update");
-}


-void CodegenCppVisitor::print_device_method_annotation() {
-    // backend specific, nothing for cpu
-}


-void CodegenCppVisitor::print_global_method_annotation() {
-    // backend specific, nothing for cpu
-}


-void CodegenCppVisitor::print_backend_namespace_start() {
-    // no separate namespace for C++ (cpu) backend
-}


-void CodegenCppVisitor::print_backend_namespace_stop() {
-    // no separate namespace for C++ (cpu) backend
-}


-void CodegenCppVisitor::print_backend_includes() {
-    // backend specific, nothing for cpu
-}


-std::string CodegenCppVisitor::backend_name() const {
-    return "C++ (api-compatibility)";
-}


-bool CodegenCppVisitor::optimize_ion_variable_copies() const {
-    return optimize_ionvar_copies;
-}


-void CodegenCppVisitor::print_memory_allocation_routine() const {
-    printer->add_newline(2);
-    auto args = "size_t num, size_t size, size_t alignment = 16";
-    printer->fmt_push_block("static inline void* mem_alloc({})", args);
-    printer->add_line("void* ptr;");
-    printer->add_line("posix_memalign(&ptr, alignment, num*size);");
-    printer->add_line("memset(ptr, 0, size);");
-    printer->add_line("return ptr;");
-    printer->pop_block();
-
-    printer->add_newline(2);
-    printer->push_block("static inline void mem_free(void* ptr)");
-    printer->add_line("free(ptr);");
-    printer->pop_block();
-}


-void CodegenCppVisitor::print_abort_routine() const {
-    printer->add_newline(2);
-    printer->push_block("static inline void coreneuron_abort()");
-    printer->add_line("abort();");
-    printer->pop_block();
-}


-std::string CodegenCppVisitor::compute_method_name(BlockType type) const {
-    if (type == BlockType::Initial) {
-        return method_name(naming::NRN_INIT_METHOD);
-    }
-    if (type == BlockType::Constructor) {
-        return method_name(naming::NRN_CONSTRUCTOR_METHOD);
-    }
-    if (type ==
BlockType::Destructor) { - return method_name(naming::NRN_DESTRUCTOR_METHOD); - } - if (type == BlockType::State) { - return method_name(naming::NRN_STATE_METHOD); - } - if (type == BlockType::Equation) { - return method_name(naming::NRN_CUR_METHOD); - } - if (type == BlockType::Watch) { - return method_name(naming::NRN_WATCH_CHECK_METHOD); - } - throw std::logic_error("compute_method_name not implemented"); -} - - -std::string CodegenCppVisitor::global_var_struct_type_qualifier() { - return ""; -} - -void CodegenCppVisitor::print_global_var_struct_decl() { - printer->add_line(global_struct(), ' ', global_struct_instance(), ';'); -} - -/****************************************************************************************/ -/* printing routines for code generation */ -/****************************************************************************************/ - - -void CodegenCppVisitor::visit_watch_statement(const ast::WatchStatement& /* node */) { - printer->add_text(fmt::format("nrn_watch_activate(inst, id, pnodecount, {}, v, watch_remove)", - current_watch_statement++)); -} - - -void CodegenCppVisitor::print_statement_block(const ast::StatementBlock& node, - bool open_brace, - bool close_brace) { - if (open_brace) { - printer->push_block(); - } - - const auto& statements = node.get_statements(); - for (const auto& statement: statements) { - if (statement_to_skip(*statement)) { - continue; - } - /// not necessary to add indent for verbatim block (pretty-printing) - if (!statement->is_verbatim() && !statement->is_mutex_lock() && - !statement->is_mutex_unlock() && !statement->is_protect_statement()) { - printer->add_indent(); - } - statement->accept(*this); - if (need_semicolon(*statement)) { - printer->add_text(';'); - } - if (!statement->is_mutex_lock() && !statement->is_mutex_unlock()) { - printer->add_newline(); - } - } - - if (close_brace) { - printer->pop_block_nl(0); - } -} - - -void CodegenCppVisitor::print_function_call(const FunctionCall& node) { - const auto& name = node.get_node_name(); - auto function_name = name; - if (defined_method(name)) { - function_name = method_name(name); - } - - if (is_net_send(name)) { - print_net_send_call(node); - return; - } - - if (is_net_move(name)) { - print_net_move_call(node); - return; - } - - if (is_net_event(name)) { - print_net_event_call(node); - return; - } - - const auto& arguments = node.get_arguments(); - printer->add_text(function_name, '('); - - if (defined_method(name)) { - printer->add_text(internal_method_arguments()); - if (!arguments.empty()) { - printer->add_text(", "); - } - } - - print_vector_elements(arguments, ", "); - printer->add_text(')'); -} - - -void CodegenCppVisitor::print_top_verbatim_blocks() { - if (info.top_verbatim_blocks.empty()) { - return; - } - print_namespace_stop(); - - printer->add_newline(2); - printer->add_line("using namespace coreneuron;"); - codegen = true; - printing_top_verbatim_blocks = true; - - for (const auto& block: info.top_blocks) { - if (block->is_verbatim()) { - printer->add_newline(2); - block->accept(*this); - } - } - - printing_top_verbatim_blocks = false; - codegen = false; - print_namespace_start(); -} - - -/** - * \todo Issue with verbatim renaming. e.g. pattern.mod has info struct with - * index variable. If we use "index" instead of "indexes" as default argument - * then during verbatim replacement we don't know the index is which one. This - * is because verbatim renaming pass has already stripped out prefixes from - * the text. 
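- *
- * (Illustrative: a FUNCTION parameter named `nt` is renamed to `arg_nt` so it
- * cannot shadow the NrnThread* argument of the generated kernel.)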
- */ -void CodegenCppVisitor::rename_function_arguments() { - const auto& default_arguments = stringutils::split_string(nrn_thread_arguments(), ','); - for (const auto& dirty_arg: default_arguments) { - const auto& arg = stringutils::trim(dirty_arg); - RenameVisitor v(arg, "arg_" + arg); - for (const auto& function: info.functions) { - if (has_parameter_of_name(function, arg)) { - function->accept(v); - } - } - for (const auto& function: info.procedures) { - if (has_parameter_of_name(function, arg)) { - function->accept(v); - } - } - } -} - - -void CodegenCppVisitor::print_function_prototypes() { - if (info.functions.empty() && info.procedures.empty()) { - return; - } - codegen = true; - printer->add_newline(2); - for (const auto& node: info.functions) { - print_function_declaration(*node, node->get_node_name()); - printer->add_text(';'); - printer->add_newline(); - } - for (const auto& node: info.procedures) { - print_function_declaration(*node, node->get_node_name()); - printer->add_text(';'); - printer->add_newline(); - } - codegen = false; -} - - -static const TableStatement* get_table_statement(const ast::Block& node) { - // TableStatementVisitor v; - - const auto& table_statements = collect_nodes(node, {AstNodeType::TABLE_STATEMENT}); - - if (table_statements.size() != 1) { - auto message = fmt::format("One table statement expected in {} found {}", - node.get_node_name(), - table_statements.size()); - throw std::runtime_error(message); - } - return dynamic_cast(table_statements.front().get()); -} - - -std::tuple CodegenCppVisitor::check_if_var_is_array(const std::string& name) { - auto symbol = program_symtab->lookup_in_scope(name); - if (!symbol) { - throw std::runtime_error( - fmt::format("CodegenCppVisitor:: {} not found in symbol table!", name)); - } - if (symbol->is_array()) { - return {true, symbol->get_length()}; - } else { - return {false, 0}; - } -} - - -void CodegenCppVisitor::print_table_check_function(const Block& node) { - auto statement = get_table_statement(node); - auto table_variables = statement->get_table_vars(); - auto depend_variables = statement->get_depend_vars(); - const auto& from = statement->get_from(); - const auto& to = statement->get_to(); - auto name = node.get_node_name(); - auto internal_params = internal_method_parameters(); - auto with = statement->get_with()->eval(); - auto use_table_var = get_variable_name(naming::USE_TABLE_VARIABLE); - auto tmin_name = get_variable_name("tmin_" + name); - auto mfac_name = get_variable_name("mfac_" + name); - auto float_type = default_float_data_type(); - - printer->add_newline(2); - print_device_method_annotation(); - printer->fmt_push_block("void check_{}({})", - method_name(name), - get_parameter_str(internal_params)); - { - printer->fmt_push_block("if ({} == 0)", use_table_var); - printer->add_line("return;"); - printer->pop_block(); - - printer->add_line("static bool make_table = true;"); - for (const auto& variable: depend_variables) { - printer->fmt_line("static {} save_{};", float_type, variable->get_node_name()); - } - - for (const auto& variable: depend_variables) { - const auto& var_name = variable->get_node_name(); - const auto& instance_name = get_variable_name(var_name); - printer->fmt_push_block("if (save_{} != {})", var_name, instance_name); - printer->add_line("make_table = true;"); - printer->pop_block(); - } - - printer->push_block("if (make_table)"); - { - printer->add_line("make_table = false;"); - - printer->add_indent(); - printer->add_text(tmin_name, " = "); - from->accept(*this); - 
printer->add_text(';'); - printer->add_newline(); - - printer->add_indent(); - printer->add_text("double tmax = "); - to->accept(*this); - printer->add_text(';'); - printer->add_newline(); - - - printer->fmt_line("double dx = (tmax-{}) / {}.;", tmin_name, with); - printer->fmt_line("{} = 1./dx;", mfac_name); - - printer->fmt_line("double x = {};", tmin_name); - printer->fmt_push_block("for (std::size_t i = 0; i < {}; x += dx, i++)", with + 1); - auto function = method_name("f_" + name); - if (node.is_procedure_block()) { - printer->fmt_line("{}({}, x);", function, internal_method_arguments()); - for (const auto& variable: table_variables) { - auto var_name = variable->get_node_name(); - auto instance_name = get_variable_name(var_name); - auto table_name = get_variable_name("t_" + var_name); - auto [is_array, array_length] = check_if_var_is_array(var_name); - if (is_array) { - for (int j = 0; j < array_length; j++) { - printer->fmt_line( - "{}[{}][i] = {}[{}];", table_name, j, instance_name, j); - } - } else { - printer->fmt_line("{}[i] = {};", table_name, instance_name); - } - } - } else { - auto table_name = get_variable_name("t_" + name); - printer->fmt_line("{}[i] = {}({}, x);", - table_name, - function, - internal_method_arguments()); - } - printer->pop_block(); - - for (const auto& variable: depend_variables) { - auto var_name = variable->get_node_name(); - auto instance_name = get_variable_name(var_name); - printer->fmt_line("save_{} = {};", var_name, instance_name); - } - } - printer->pop_block(); - } - printer->pop_block(); -} - - -void CodegenCppVisitor::print_table_replacement_function(const ast::Block& node) { - auto name = node.get_node_name(); - auto statement = get_table_statement(node); - auto table_variables = statement->get_table_vars(); - auto with = statement->get_with()->eval(); - auto use_table_var = get_variable_name(naming::USE_TABLE_VARIABLE); - auto tmin_name = get_variable_name("tmin_" + name); - auto mfac_name = get_variable_name("mfac_" + name); - auto function_name = method_name("f_" + name); - - printer->add_newline(2); - print_function_declaration(node, name); - printer->push_block(); - { - const auto& params = node.get_parameters(); - printer->fmt_push_block("if ({} == 0)", use_table_var); - if (node.is_procedure_block()) { - printer->fmt_line("{}({}, {});", - function_name, - internal_method_arguments(), - params[0].get()->get_node_name()); - printer->add_line("return 0;"); - } else { - printer->fmt_line("return {}({}, {});", - function_name, - internal_method_arguments(), - params[0].get()->get_node_name()); - } - printer->pop_block(); - - printer->fmt_line("double xi = {} * ({} - {});", - mfac_name, - params[0].get()->get_node_name(), - tmin_name); - printer->push_block("if (isnan(xi))"); - if (node.is_procedure_block()) { - for (const auto& var: table_variables) { - auto var_name = get_variable_name(var->get_node_name()); - auto [is_array, array_length] = check_if_var_is_array(var->get_node_name()); - if (is_array) { - for (int j = 0; j < array_length; j++) { - printer->fmt_line("{}[{}] = xi;", var_name, j); - } - } else { - printer->fmt_line("{} = xi;", var_name); - } - } - printer->add_line("return 0;"); - } else { - printer->add_line("return xi;"); - } - printer->pop_block(); - - printer->fmt_push_block("if (xi <= 0. || xi >= {}.)", with); - printer->fmt_line("int index = (xi <= 0.) ? 
0 : {};", with); - if (node.is_procedure_block()) { - for (const auto& variable: table_variables) { - auto var_name = variable->get_node_name(); - auto instance_name = get_variable_name(var_name); - auto table_name = get_variable_name("t_" + var_name); - auto [is_array, array_length] = check_if_var_is_array(var_name); - if (is_array) { - for (int j = 0; j < array_length; j++) { - printer->fmt_line( - "{}[{}] = {}[{}][index];", instance_name, j, table_name, j); - } - } else { - printer->fmt_line("{} = {}[index];", instance_name, table_name); - } - } - printer->add_line("return 0;"); - } else { - auto table_name = get_variable_name("t_" + name); - printer->fmt_line("return {}[index];", table_name); - } - printer->pop_block(); - - printer->add_line("int i = int(xi);"); - printer->add_line("double theta = xi - double(i);"); - if (node.is_procedure_block()) { - for (const auto& var: table_variables) { - auto var_name = var->get_node_name(); - auto instance_name = get_variable_name(var_name); - auto table_name = get_variable_name("t_" + var_name); - auto [is_array, array_length] = check_if_var_is_array(var->get_node_name()); - if (is_array) { - for (size_t j = 0; j < array_length; j++) { - printer->fmt_line( - "{0}[{1}] = {2}[{1}][i] + theta*({2}[{1}][i+1]-{2}[{1}][i]);", - instance_name, - j, - table_name); - } - } else { - printer->fmt_line("{0} = {1}[i] + theta*({1}[i+1]-{1}[i]);", - instance_name, - table_name); - } - } - printer->add_line("return 0;"); - } else { - auto table_name = get_variable_name("t_" + name); - printer->fmt_line("return {0}[i] + theta * ({0}[i+1] - {0}[i]);", table_name); - } - } - printer->pop_block(); -} - - -void CodegenCppVisitor::print_check_table_thread_function() { - if (info.table_count == 0) { - return; - } - - printer->add_newline(2); - auto name = method_name("check_table_thread"); - auto parameters = external_method_parameters(true); - - printer->fmt_push_block("static void {} ({})", name, parameters); - printer->add_line("setup_instance(nt, ml);"); - printer->fmt_line("auto* const inst = static_cast<{0}*>(ml->instance);", instance_struct()); - printer->add_line("double v = 0;"); - - for (const auto& function: info.functions_with_table) { - auto method_name_str = method_name("check_" + function->get_node_name()); - auto arguments = internal_method_arguments(); - printer->fmt_line("{}({});", method_name_str, arguments); - } - - printer->pop_block(); -} - - -void CodegenCppVisitor::print_function_or_procedure(const ast::Block& node, - const std::string& name) { - printer->add_newline(2); - print_function_declaration(node, name); - printer->add_text(" "); - printer->push_block(); - - // function requires return variable declaration - if (node.is_function_block()) { - auto type = default_float_data_type(); - printer->fmt_line("{} ret_{} = 0.0;", type, name); - } else { - printer->fmt_line("int ret_{} = 0;", name); - } - - print_statement_block(*node.get_statement_block(), false, false); - printer->fmt_line("return ret_{};", name); - printer->pop_block(); -} - - -void CodegenCppVisitor::print_function_procedure_helper(const ast::Block& node) { - codegen = true; - auto name = node.get_node_name(); - - if (info.function_uses_table(name)) { - auto new_name = "f_" + name; - print_function_or_procedure(node, new_name); - print_table_check_function(node); - print_table_replacement_function(node); - } else { - print_function_or_procedure(node, name); - } - - codegen = false; -} - - -void CodegenCppVisitor::print_procedure(const ast::ProcedureBlock& node) { - 
print_function_procedure_helper(node); -} - - -void CodegenCppVisitor::print_function(const ast::FunctionBlock& node) { - auto name = node.get_node_name(); - - // name of return variable - std::string return_var; - if (info.function_uses_table(name)) { - return_var = "ret_f_" + name; - } else { - return_var = "ret_" + name; - } - - // first rename return variable name - auto block = node.get_statement_block().get(); - RenameVisitor v(name, return_var); - block->accept(v); - - print_function_procedure_helper(node); -} - - -void CodegenCppVisitor::print_function_tables(const ast::FunctionTableBlock& node) { - auto name = node.get_node_name(); - const auto& p = node.get_parameters(); - auto params = internal_method_parameters(); - for (const auto& i: p) { - params.emplace_back("", "double", "", i->get_node_name()); - } - printer->fmt_line("double {}({})", method_name(name), get_parameter_str(params)); - printer->push_block(); - printer->fmt_line("double _arg[{}];", p.size()); - for (size_t i = 0; i < p.size(); ++i) { - printer->fmt_line("_arg[{}] = {};", i, p[i]->get_node_name()); - } - printer->fmt_line("return hoc_func_table({}, {}, _arg);", - get_variable_name(std::string("_ptable_" + name), true), - p.size()); - printer->pop_block(); - - printer->fmt_push_block("double table_{}()", method_name(name)); - printer->fmt_line("hoc_spec_table(&{}, {});", - get_variable_name(std::string("_ptable_" + name)), - p.size()); - printer->add_line("return 0.;"); - printer->pop_block(); -} - -/** - * @brief Checks whether the functor_block generated by sympy solver modifies any variable outside - * its scope. If it does then return false, so that the operator() of the struct functor of the - * Eigen Newton solver doesn't have const qualifier. - * - * @param variable_block Statement Block of the variables declarations used in the functor struct of - * the solver - * @param functor_block Actual code being printed in the operator() of the functor struct of the - * solver - * @return True if operator() is const else False - */ -bool is_functor_const(const ast::StatementBlock& variable_block, - const ast::StatementBlock& functor_block) { - // Create complete_block with both variable declarations (done in variable_block) and solver - // part (done in functor_block) to be able to run the SymtabVisitor and DefUseAnalyzeVisitor - // then and get the proper DUChains for the variables defined in the variable_block - ast::StatementBlock complete_block(functor_block); - // Typically variable_block has only one statement, a statement containing the declaration - // of the local variables - for (const auto& statement: variable_block.get_statements()) { - complete_block.insert_statement(complete_block.get_statements().begin(), statement); - } - - // Create Symbol Table for complete_block - auto model_symbol_table = std::make_shared(); - SymtabVisitor(model_symbol_table.get()).visit_statement_block(complete_block); - // Initialize DefUseAnalyzeVisitor to generate the DUChains for the variables defined in the - // variable_block - DefUseAnalyzeVisitor v(*complete_block.get_symbol_table()); - - // Check the DUChains for all the variables in the variable_block - // If variable is defined in complete_block don't add const quilifier in operator() - auto is_functor_const = true; - const auto& variables = collect_nodes(variable_block, {ast::AstNodeType::LOCAL_VAR}); - for (const auto& variable: variables) { - const auto& chain = v.analyze(complete_block, variable->get_node_name()); - is_functor_const = !(chain.eval() == 
DUState::D || chain.eval() == DUState::LD || - chain.eval() == DUState::CD); - if (!is_functor_const) { - break; - } - } - - return is_functor_const; -} - -void CodegenCppVisitor::print_functor_definition(const ast::EigenNewtonSolverBlock& node) { - // functor that evaluates F(X) and J(X) for - // Newton solver - auto float_type = default_float_data_type(); - int N = node.get_n_state_vars()->get_value(); - - const auto functor_name = info.functor_names[&node]; - printer->fmt_push_block("struct {0}", functor_name); - printer->add_line("NrnThread* nt;"); - printer->add_line(instance_struct(), "* inst;"); - printer->add_line("int id, pnodecount;"); - printer->add_line("double v;"); - printer->add_line("const Datum* indexes;"); - printer->add_line("double* data;"); - printer->add_line("ThreadDatum* thread;"); - - if (ion_variable_struct_required()) { - print_ion_variable(); - } - - print_statement_block(*node.get_variable_block(), false, false); - printer->add_newline(); - - printer->push_block("void initialize()"); - print_statement_block(*node.get_initialize_block(), false, false); - printer->pop_block(); - printer->add_newline(); - - printer->fmt_line( - "{0}(NrnThread* nt, {1}* inst, int id, int pnodecount, double v, const Datum* indexes, " - "double* data, ThreadDatum* thread) : " - "nt{{nt}}, inst{{inst}}, id{{id}}, pnodecount{{pnodecount}}, v{{v}}, indexes{{indexes}}, " - "data{{data}}, thread{{thread}} " - "{{}}", - functor_name, - instance_struct()); - - printer->add_indent(); - - const auto& variable_block = *node.get_variable_block(); - const auto& functor_block = *node.get_functor_block(); - - printer->fmt_text( - "void operator()(const Eigen::Matrix<{0}, {1}, 1>& nmodl_eigen_xm, Eigen::Matrix<{0}, {1}, " - "1>& nmodl_eigen_fm, " - "Eigen::Matrix<{0}, {1}, {1}>& nmodl_eigen_jm) {2}", - float_type, - N, - is_functor_const(variable_block, functor_block) ? 
"const " : ""); - printer->push_block(); - printer->fmt_line("const {}* nmodl_eigen_x = nmodl_eigen_xm.data();", float_type); - printer->fmt_line("{}* nmodl_eigen_j = nmodl_eigen_jm.data();", float_type); - printer->fmt_line("{}* nmodl_eigen_f = nmodl_eigen_fm.data();", float_type); - print_statement_block(functor_block, false, false); - printer->pop_block(); - printer->add_newline(); - - // assign newton solver results in matrix X to state vars - printer->push_block("void finalize()"); - print_statement_block(*node.get_finalize_block(), false, false); - printer->pop_block(); - - printer->pop_block(";"); -} - -void CodegenCppVisitor::visit_eigen_newton_solver_block(const ast::EigenNewtonSolverBlock& node) { - // solution vector to store copy of state vars for Newton solver - printer->add_newline(); - - auto float_type = default_float_data_type(); - int N = node.get_n_state_vars()->get_value(); - printer->fmt_line("Eigen::Matrix<{}, {}, 1> nmodl_eigen_xm;", float_type, N); - printer->fmt_line("{}* nmodl_eigen_x = nmodl_eigen_xm.data();", float_type); - - print_statement_block(*node.get_setup_x_block(), false, false); - - // call newton solver with functor and X matrix that contains state vars - printer->add_line("// call newton solver"); - printer->fmt_line("{} newton_functor(nt, inst, id, pnodecount, v, indexes, data, thread);", - info.functor_names[&node]); - printer->add_line("newton_functor.initialize();"); - printer->add_line( - "int newton_iterations = nmodl::newton::newton_solver(nmodl_eigen_xm, newton_functor);"); - printer->add_line( - "if (newton_iterations < 0) assert(false && \"Newton solver did not converge!\");"); - - // assign newton solver results in matrix X to state vars - print_statement_block(*node.get_update_states_block(), false, false); - printer->add_line("newton_functor.finalize();"); -} - -void CodegenCppVisitor::visit_eigen_linear_solver_block(const ast::EigenLinearSolverBlock& node) { - printer->add_newline(); - - const std::string float_type = default_float_data_type(); - int N = node.get_n_state_vars()->get_value(); - printer->fmt_line("Eigen::Matrix<{0}, {1}, 1> nmodl_eigen_xm, nmodl_eigen_fm;", float_type, N); - printer->fmt_line("Eigen::Matrix<{0}, {1}, {1}> nmodl_eigen_jm;", float_type, N); - if (N <= 4) - printer->fmt_line("Eigen::Matrix<{0}, {1}, {1}> nmodl_eigen_jm_inv;", float_type, N); - printer->fmt_line("{}* nmodl_eigen_x = nmodl_eigen_xm.data();", float_type); - printer->fmt_line("{}* nmodl_eigen_j = nmodl_eigen_jm.data();", float_type); - printer->fmt_line("{}* nmodl_eigen_f = nmodl_eigen_fm.data();", float_type); - print_statement_block(*node.get_variable_block(), false, false); - print_statement_block(*node.get_initialize_block(), false, false); - print_statement_block(*node.get_setup_x_block(), false, false); - - printer->add_newline(); - print_eigen_linear_solver(float_type, N); - printer->add_newline(); - - print_statement_block(*node.get_update_states_block(), false, false); - print_statement_block(*node.get_finalize_block(), false, false); -} - -void CodegenCppVisitor::print_eigen_linear_solver(const std::string& float_type, int N) { - if (N <= 4) { - // Faster compared to LU, given the template specialization in Eigen. 
- printer->add_multi_line(R"CODE( - bool invertible; - nmodl_eigen_jm.computeInverseWithCheck(nmodl_eigen_jm_inv,invertible); - nmodl_eigen_xm = nmodl_eigen_jm_inv*nmodl_eigen_fm; - if (!invertible) assert(false && "Singular or ill-conditioned matrix (Eigen::inverse)!"); - )CODE"); - } else { - // In Eigen the default storage order is ColMajor. - // Crout's implementation requires matrices stored in RowMajor order (C++-style arrays). - // Therefore, the transposeInPlace is critical such that the data() method to give the rows - // instead of the columns. - printer->add_line("if (!nmodl_eigen_jm.IsRowMajor) nmodl_eigen_jm.transposeInPlace();"); - - // pivot vector - printer->fmt_line("Eigen::Matrix pivot;", N); - printer->fmt_line("Eigen::Matrix<{0}, {1}, 1> rowmax;", float_type, N); - - // In-place LU-Decomposition (Crout Algo) : Jm is replaced by its LU-decomposition - printer->fmt_line( - "if (nmodl::crout::Crout<{0}>({1}, nmodl_eigen_jm.data(), pivot.data(), rowmax.data()) " - "< 0) assert(false && \"Singular or ill-conditioned matrix (nmodl::crout)!\");", - float_type, - N); - - // Solve the linear system : Forward/Backward substitution part - printer->fmt_line( - "nmodl::crout::solveCrout<{0}>({1}, nmodl_eigen_jm.data(), nmodl_eigen_fm.data(), " - "nmodl_eigen_xm.data(), pivot.data());", - float_type, - N); - } -} - -/****************************************************************************************/ -/* Code-specific helper routines */ -/****************************************************************************************/ - - -std::string CodegenCppVisitor::internal_method_arguments() { - if (ion_variable_struct_required()) { - return "id, pnodecount, inst, ionvar, data, indexes, thread, nt, v"; - } - return "id, pnodecount, inst, data, indexes, thread, nt, v"; -} - - -/** - * @todo: figure out how to correctly handle qualifiers - */ -CodegenCppVisitor::ParamVector CodegenCppVisitor::internal_method_parameters() { - ParamVector params; - params.emplace_back("", "int", "", "id"); - params.emplace_back("", "int", "", "pnodecount"); - params.emplace_back("", fmt::format("{}*", instance_struct()), "", "inst"); - if (ion_variable_struct_required()) { - params.emplace_back("", "IonCurVar&", "", "ionvar"); - } - params.emplace_back("", "double*", "", "data"); - params.emplace_back("const ", "Datum*", "", "indexes"); - params.emplace_back("", "ThreadDatum*", "", "thread"); - params.emplace_back("", "NrnThread*", "", "nt"); - params.emplace_back("", "double", "", "v"); - return params; -} - - -const char* CodegenCppVisitor::external_method_arguments() noexcept { - return "id, pnodecount, data, indexes, thread, nt, ml, v"; -} - - -const char* CodegenCppVisitor::external_method_parameters(bool table) noexcept { - if (table) { - return "int id, int pnodecount, double* data, Datum* indexes, " - "ThreadDatum* thread, NrnThread* nt, Memb_list* ml, int tml_id"; - } - return "int id, int pnodecount, double* data, Datum* indexes, " - "ThreadDatum* thread, NrnThread* nt, Memb_list* ml, double v"; -} - - -std::string CodegenCppVisitor::nrn_thread_arguments() const { - if (ion_variable_struct_required()) { - return "id, pnodecount, ionvar, data, indexes, thread, nt, ml, v"; - } - return "id, pnodecount, data, indexes, thread, nt, ml, v"; -} - - -/** - * Function call arguments when function or procedure is defined in the - * same mod file itself - */ -std::string CodegenCppVisitor::nrn_thread_internal_arguments() { - if (ion_variable_struct_required()) { - return "id, pnodecount, inst, ionvar, 
-
-
-/****************************************************************************************/
-/*                           Code-specific helper routines                              */
-/****************************************************************************************/
-
-
-std::string CodegenCppVisitor::internal_method_arguments() {
-    if (ion_variable_struct_required()) {
-        return "id, pnodecount, inst, ionvar, data, indexes, thread, nt, v";
-    }
-    return "id, pnodecount, inst, data, indexes, thread, nt, v";
-}
-
-
-/**
- * @todo: figure out how to correctly handle qualifiers
- */
-CodegenCppVisitor::ParamVector CodegenCppVisitor::internal_method_parameters() {
-    ParamVector params;
-    params.emplace_back("", "int", "", "id");
-    params.emplace_back("", "int", "", "pnodecount");
-    params.emplace_back("", fmt::format("{}*", instance_struct()), "", "inst");
-    if (ion_variable_struct_required()) {
-        params.emplace_back("", "IonCurVar&", "", "ionvar");
-    }
-    params.emplace_back("", "double*", "", "data");
-    params.emplace_back("const ", "Datum*", "", "indexes");
-    params.emplace_back("", "ThreadDatum*", "", "thread");
-    params.emplace_back("", "NrnThread*", "", "nt");
-    params.emplace_back("", "double", "", "v");
-    return params;
-}
-
-
-const char* CodegenCppVisitor::external_method_arguments() noexcept {
-    return "id, pnodecount, data, indexes, thread, nt, ml, v";
-}
-
-
-const char* CodegenCppVisitor::external_method_parameters(bool table) noexcept {
-    if (table) {
-        return "int id, int pnodecount, double* data, Datum* indexes, "
-               "ThreadDatum* thread, NrnThread* nt, Memb_list* ml, int tml_id";
-    }
-    return "int id, int pnodecount, double* data, Datum* indexes, "
-           "ThreadDatum* thread, NrnThread* nt, Memb_list* ml, double v";
-}
-
-
-std::string CodegenCppVisitor::nrn_thread_arguments() const {
-    if (ion_variable_struct_required()) {
-        return "id, pnodecount, ionvar, data, indexes, thread, nt, ml, v";
-    }
-    return "id, pnodecount, data, indexes, thread, nt, ml, v";
-}
-
-
-/**
- * Function call arguments when the function or procedure is defined in the
- * same mod file itself
- */
-std::string CodegenCppVisitor::nrn_thread_internal_arguments() {
-    if (ion_variable_struct_required()) {
-        return "id, pnodecount, inst, ionvar, data, indexes, thread, nt, v";
-    }
-    return "id, pnodecount, inst, data, indexes, thread, nt, v";
-}
-
-
-/**
- * Replace commonly used variables in the verbatim blocks with their corresponding
- * variable names in the new code generation backend.
- */
-std::string CodegenCppVisitor::replace_if_verbatim_variable(std::string name) {
-    if (naming::VERBATIM_VARIABLES_MAPPING.find(name) != naming::VERBATIM_VARIABLES_MAPPING.end()) {
-        name = naming::VERBATIM_VARIABLES_MAPPING.at(name);
-    }
-
-    /**
-     * if the function is defined in the same mod file then the arguments must
-     * contain the mechanism instance as well.
-     */
-    if (name == naming::THREAD_ARGS) {
-        if (internal_method_call_encountered) {
-            name = nrn_thread_internal_arguments();
-            internal_method_call_encountered = false;
-        } else {
-            name = nrn_thread_arguments();
-        }
-    }
-    if (name == naming::THREAD_ARGS_PROTO) {
-        name = external_method_parameters();
-    }
-    return name;
-}
-
-
-/**
- * Processing commonly used constructs in the verbatim blocks.
- * @todo : this is still ad-hoc and requires re-implementation to
- * handle it more elegantly.
- */
-std::string CodegenCppVisitor::process_verbatim_text(std::string const& text) {
-    parser::CDriver driver;
-    driver.scan_string(text);
-    auto tokens = driver.all_tokens();
-    std::string result;
-    for (size_t i = 0; i < tokens.size(); i++) {
-        auto token = tokens[i];
-
-        // check if we have a function call in the verbatim block where the
-        // function is defined in the same mod file
-        if (program_symtab->is_method_defined(token) && tokens[i + 1] == "(") {
-            internal_method_call_encountered = true;
-        }
-        auto name = process_verbatim_token(token);
-
-        if (token == (std::string("_") + naming::TQITEM_VARIABLE)) {
-            name.insert(0, 1, '&');
-        }
-        if (token == "_STRIDE") {
-            name = "pnodecount+id";
-        }
-        result += name;
-    }
-    return result;
-}
-
-
-std::string CodegenCppVisitor::register_mechanism_arguments() const {
-    auto nrn_cur = nrn_cur_required() ? method_name(naming::NRN_CUR_METHOD) : "nullptr";
-    auto nrn_state = nrn_state_required() ? method_name(naming::NRN_STATE_METHOD) : "nullptr";
-    auto nrn_alloc = method_name(naming::NRN_ALLOC_METHOD);
-    auto nrn_init = method_name(naming::NRN_INIT_METHOD);
-    auto const nrn_private_constructor = method_name(naming::NRN_PRIVATE_CONSTRUCTOR_METHOD);
-    auto const nrn_private_destructor = method_name(naming::NRN_PRIVATE_DESTRUCTOR_METHOD);
-    return fmt::format("mechanism, {}, {}, nullptr, {}, {}, {}, {}, first_pointer_var_index()",
-                       nrn_alloc,
-                       nrn_cur,
-                       nrn_state,
-                       nrn_init,
-                       nrn_private_constructor,
-                       nrn_private_destructor);
-}
-
-
-std::pair<std::string, std::string> CodegenCppVisitor::read_ion_variable_name(
-    const std::string& name) {
-    return {name, naming::ION_VARNAME_PREFIX + name};
-}
-
-
-std::pair<std::string, std::string> CodegenCppVisitor::write_ion_variable_name(
-    const std::string& name) {
-    return {naming::ION_VARNAME_PREFIX + name, name};
-}
-
-
-std::string CodegenCppVisitor::conc_write_statement(const std::string& ion_name,
-                                                    const std::string& concentration,
-                                                    int index) {
-    auto conc_var_name = get_variable_name(naming::ION_VARNAME_PREFIX + concentration);
-    auto style_var_name = get_variable_name("style_" + ion_name);
-    return fmt::format(
-        "nrn_wrote_conc({}_type,"
-        " &({}),"
-        " {},"
-        " {},"
-        " nrn_ion_global_map,"
-        " {},"
-        " nt->_ml_list[{}_type]->_nodecount_padded)",
-        ion_name,
-        conc_var_name,
-        index,
-        style_var_name,
-        get_variable_name(naming::CELSIUS_VARIABLE),
-        ion_name);
-}
-
-
-/**
- * If mechanism dependency-level execution is enabled then certain updates,
- * like ionic current contributions, need to be updated atomically. In this
- * case we first update the current mechanism's shadow vector and then add a
- * statement to the queue that will be used in the reduction queue.
- */
-std::string CodegenCppVisitor::process_shadow_update_statement(const ShadowUseStatement& statement,
-                                                               BlockType /* type */) {
-    // when there is no operator or rhs then that statement doesn't need shadow update
-    if (statement.op.empty() && statement.rhs.empty()) {
-        auto text = statement.lhs + ";";
-        return text;
-    }
-
-    // return regular statement
-    auto lhs = get_variable_name(statement.lhs);
-    auto text = fmt::format("{} {} {};", lhs, statement.op, statement.rhs);
-    return text;
-}
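The two ion-name helpers above differ only in which side of the assignment receives the prefixed, simulator-owned name. A small sketch of the convention, assuming the "ion_" prefix that naming::ION_VARNAME_PREFIX resolves to in the generated code:

    #include <iostream>
    #include <string>
    #include <utility>

    // read: local variable <- simulator variable; write: simulator variable <- local
    std::pair<std::string, std::string> read_ion_variable_name(const std::string& name) {
        return {name, "ion_" + name};
    }

    std::pair<std::string, std::string> write_ion_variable_name(const std::string& name) {
        return {"ion_" + name, name};
    }

    int main() {
        auto r = read_ion_variable_name("ena");
        auto w = write_ion_variable_name("ena");
        std::cout << r.first << " = " << r.second << ";\n";   // ena = ion_ena;
        std::cout << w.first << " = " << w.second << ";\n";   // ion_ena = ena;
    }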
printer->push_block("static inline int int_variables_size()"); - printer->fmt_line("return {};", int_variables_size()); - printer->pop_block(); -} - - -void CodegenCppVisitor::print_net_receive_arg_size_getter() { - if (!net_receive_exist()) { - return; - } - printer->add_newline(2); - print_device_method_annotation(); - printer->push_block("static inline int num_net_receive_args()"); - printer->fmt_line("return {};", info.num_net_receive_parameters); - printer->pop_block(); -} - - -void CodegenCppVisitor::print_mech_type_getter() { - printer->add_newline(2); - print_device_method_annotation(); - printer->push_block("static inline int get_mech_type()"); - // false => get it from the host-only global struct, not the instance structure - printer->fmt_line("return {};", get_variable_name("mech_type", false)); - printer->pop_block(); -} - - -void CodegenCppVisitor::print_memb_list_getter() { - printer->add_newline(2); - print_device_method_annotation(); - printer->push_block("static inline Memb_list* get_memb_list(NrnThread* nt)"); - printer->push_block("if (!nt->_ml_list)"); - printer->add_line("return nullptr;"); - printer->pop_block(); - printer->add_line("return nt->_ml_list[get_mech_type()];"); - printer->pop_block(); -} - - -void CodegenCppVisitor::print_namespace_start() { - printer->add_newline(2); - printer->push_block("namespace coreneuron"); -} - - -void CodegenCppVisitor::print_namespace_stop() { - printer->pop_block(); -} - - -/** - * \details There are three types of thread variables currently considered: - * - top local thread variables - * - thread variables in the mod file - * - thread variables for solver - * - * These variables are allocated into different thread structures and have - * corresponding thread ids. Thread id start from 0. In mod2c implementation, - * thread_data_index is increased at various places and it is used to - * decide the index of thread. 
-
-
-/**
- * \details There are three types of thread variables currently considered:
- *      - top local thread variables
- *      - thread variables in the mod file
- *      - thread variables for solver
- *
- * These variables are allocated into different thread structures and have
- * corresponding thread ids. Thread ids start from 0. In the mod2c implementation,
- * thread_data_index is increased at various places and it is used to
- * decide the index of the thread.
- */
-void CodegenCppVisitor::print_thread_getters() {
-    if (info.vectorize && info.derivimplicit_used()) {
-        int tid = info.derivimplicit_var_thread_id;
-        int list = info.derivimplicit_list_num;
-
-        // clang-format off
-        printer->add_newline(2);
-        printer->add_line("/** thread specific helper routines for derivimplicit */");
-
-        printer->add_newline(1);
-        printer->fmt_push_block("static inline int* deriv{}_advance(ThreadDatum* thread)", list);
-        printer->fmt_line("return &(thread[{}].i);", tid);
-        printer->pop_block();
-        printer->add_newline();
-
-        printer->fmt_push_block("static inline int dith{}()", list);
-        printer->fmt_line("return {};", tid+1);
-        printer->pop_block();
-        printer->add_newline();
-
-        printer->fmt_push_block("static inline void** newtonspace{}(ThreadDatum* thread)", list);
-        printer->fmt_line("return &(thread[{}]._pvoid);", tid+2);
-        printer->pop_block();
-    }
-
-    if (info.vectorize && !info.thread_variables.empty()) {
-        printer->add_newline(2);
-        printer->add_line("/** tid for thread variables */");
-        printer->push_block("static inline int thread_var_tid()");
-        printer->fmt_line("return {};", info.thread_var_thread_id);
-        printer->pop_block();
-    }
-
-    if (info.vectorize && !info.top_local_variables.empty()) {
-        printer->add_newline(2);
-        printer->add_line("/** tid for top local thread variables */");
-        printer->push_block("static inline int top_local_var_tid()");
-        printer->fmt_line("return {};", info.top_local_thread_id);
-        printer->pop_block();
-    }
-    // clang-format on
-}
-
-
-/****************************************************************************************/
-/*                         Routines for returning variable name                         */
-/****************************************************************************************/
-
-
-std::string CodegenCppVisitor::float_variable_name(const SymbolType& symbol,
-                                                   bool use_instance) const {
-    auto name = symbol->get_name();
-    auto dimension = symbol->get_length();
-    auto position = position_of_float_var(name);
-    // clang-format off
-    if (symbol->is_array()) {
-        if (use_instance) {
-            return fmt::format("(inst->{}+id*{})", name, dimension);
-        }
-        return fmt::format("(data + {}*pnodecount + id*{})", position, dimension);
-    }
-    if (use_instance) {
-        return fmt::format("inst->{}[id]", name);
-    }
-    return fmt::format("data[{}*pnodecount + id]", position);
-    // clang-format on
-}
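float_variable_name encodes CoreNEURON's struct-of-arrays layout: a variable's position selects a block of pnodecount (padded instance count) entries, and id selects the instance within it. A runnable toy version of that indexing, with made-up variable names and sizes:

    #include <cstdio>
    #include <vector>

    int main() {
        // Two scalar RANGE variables, "m" at position 0 and "h" at position 1,
        // for 4 instances padded to pnodecount = 8, in one flat SoA buffer.
        const int pnodecount = 8;
        std::vector<double> data(2 * pnodecount, 0.0);
        const int m_pos = 0, h_pos = 1;

        for (int id = 0; id < 4; ++id) {
            data[m_pos * pnodecount + id] = 0.1 * id;  // data[0*pnodecount + id]
            data[h_pos * pnodecount + id] = 0.9;       // data[1*pnodecount + id]
        }

        // Array variables of length L instead start at data + position*pnodecount
        // and advance by id*L, matching the "(data + {}*pnodecount + id*{})" form.
        std::printf("m[2]=%g h[2]=%g\n",
                    data[m_pos * pnodecount + 2],
                    data[h_pos * pnodecount + 2]);
    }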
"_vdata" : "_data"; - return fmt::format("nt->{}[indexes[{}*pnodecount + id]]", data, position); - // clang-format on -} - - -std::string CodegenCppVisitor::global_variable_name(const SymbolType& symbol, - bool use_instance) const { - if (use_instance) { - return fmt::format("inst->{}->{}", naming::INST_GLOBAL_MEMBER, symbol->get_name()); - } else { - return fmt::format("{}.{}", global_struct_instance(), symbol->get_name()); - } -} - - -std::string CodegenCppVisitor::update_if_ion_variable_name(const std::string& name) const { - std::string result(name); - if (ion_variable_struct_required()) { - if (info.is_ion_read_variable(name)) { - result = naming::ION_VARNAME_PREFIX + name; - } - if (info.is_ion_write_variable(name)) { - result = "ionvar." + name; - } - if (info.is_current(name)) { - result = "ionvar." + name; - } - } - return result; -} - - -std::string CodegenCppVisitor::get_variable_name(const std::string& name, bool use_instance) const { - const std::string& varname = update_if_ion_variable_name(name); - - // clang-format off - auto symbol_comparator = [&varname](const SymbolType& sym) { - return varname == sym->get_name(); - }; - - auto index_comparator = [&varname](const IndexVariableInfo& var) { - return varname == var.symbol->get_name(); - }; - // clang-format on - - // float variable - auto f = std::find_if(codegen_float_variables.begin(), - codegen_float_variables.end(), - symbol_comparator); - if (f != codegen_float_variables.end()) { - return float_variable_name(*f, use_instance); - } - - // integer variable - auto i = - std::find_if(codegen_int_variables.begin(), codegen_int_variables.end(), index_comparator); - if (i != codegen_int_variables.end()) { - return int_variable_name(*i, varname, use_instance); - } - - // global variable - auto g = std::find_if(codegen_global_variables.begin(), - codegen_global_variables.end(), - symbol_comparator); - if (g != codegen_global_variables.end()) { - return global_variable_name(*g, use_instance); - } - - if (varname == naming::NTHREAD_DT_VARIABLE) { - return std::string("nt->_") + naming::NTHREAD_DT_VARIABLE; - } - - // t in net_receive method is an argument to function and hence it should - // ne used instead of nt->_t which is current time of thread - if (varname == naming::NTHREAD_T_VARIABLE && !printing_net_receive) { - return std::string("nt->_") + naming::NTHREAD_T_VARIABLE; - } - - auto const iter = - std::find_if(info.neuron_global_variables.begin(), - info.neuron_global_variables.end(), - [&varname](auto const& entry) { return entry.first->get_name() == varname; }); - if (iter != info.neuron_global_variables.end()) { - std::string ret; - if (use_instance) { - ret = "*(inst->"; - } - ret.append(varname); - if (use_instance) { - ret.append(")"); - } - return ret; - } - - // otherwise return original name - return varname; -} - - -/****************************************************************************************/ -/* Main printing routines for code generation */ -/****************************************************************************************/ - - -void CodegenCppVisitor::print_backend_info() { - time_t current_time{}; - time(¤t_time); - std::string data_time_str{std::ctime(¤t_time)}; - auto version = nmodl::Version::NMODL_VERSION + " [" + nmodl::Version::GIT_REVISION + "]"; - - printer->add_line("/*********************************************************"); - printer->add_line("Model Name : ", info.mod_suffix); - printer->add_line("Filename : ", info.mod_file, ".mod"); - printer->add_line("NMODL Version : ", 
-
-
-void CodegenCppVisitor::print_standard_includes() {
-    printer->add_newline();
-    printer->add_multi_line(R"CODE(
-        #include <math.h>
-        #include <stdio.h>
-        #include <stdlib.h>
-        #include <string.h>
-    )CODE");
-}
-
-
-void CodegenCppVisitor::print_coreneuron_includes() {
-    printer->add_newline();
-    printer->add_multi_line(R"CODE(
-        #include <coreneuron/gpu/nrn_acc_manager.hpp>
-        #include <coreneuron/mechanism/mech/mod2c_core_thread.hpp>
-        #include <coreneuron/mechanism/register_mech.hpp>
-        #include <coreneuron/nrnconf.h>
-        #include <coreneuron/nrniv/nrniv_decl.h>
-        #include <coreneuron/sim/multicore.hpp>
-        #include <coreneuron/sim/scopmath/newton_thread.hpp>
-        #include <coreneuron/utils/ivocvect.hpp>
-        #include <coreneuron/utils/nrnoc_aux.hpp>
-        #include <coreneuron/utils/randoms/nrnran123.h>
-    )CODE");
-    if (info.eigen_newton_solver_exist) {
-        printer->add_line("#include <newton/newton.hpp>");
-    }
-    if (info.eigen_linear_solver_exist) {
-        if (std::accumulate(info.state_vars.begin(),
-                            info.state_vars.end(),
-                            0,
-                            [](int l, const SymbolType& variable) {
-                                return l += variable->get_length();
-                            }) > 4) {
-            printer->add_line("#include <crout/crout.hpp>");
-        } else {
-            printer->add_line("#include <Eigen/Dense>");
-            printer->add_line("#include <Eigen/LU>");
-        }
-    }
-}
-
-
-/**
- * \details Variables required for the type of ion, type of point process etc. are
- * of static int type. For the C++ backend type, it's ok to have
- * these variables as file scoped static variables.
- *
- * Initial values of state variables (h0) are also defined as static
- * variables. Note that the state could be an ion variable and it could
- * also be a range variable. Hence, look into the symbol table before.
- *
- * When the model is not vectorized (shouldn't be the case in coreneuron)
- * the top local variables become static variables.
- *
- * Note that static variables are already initialized to 0. We do the
- * same for some variables to keep the same code as neuron.
- */
-// NOLINTNEXTLINE(readability-function-cognitive-complexity)
-void CodegenCppVisitor::print_mechanism_global_var_structure(bool print_initializers) {
-    const auto value_initialize = print_initializers ? "{}" : "";
-    const auto qualifier = global_var_struct_type_qualifier();
-
-    auto float_type = default_float_data_type();
-    printer->add_newline(2);
-    printer->add_line("/** all global variables */");
-    printer->fmt_push_block("struct {}", global_struct());
-
-    for (const auto& ion: info.ions) {
-        auto name = fmt::format("{}_type", ion.name);
-        printer->fmt_line("{}int {}{};", qualifier, name, value_initialize);
-        codegen_global_variables.push_back(make_symbol(name));
-    }
-
-    if (info.point_process) {
-        printer->fmt_line("{}int point_type{};", qualifier, value_initialize);
-        codegen_global_variables.push_back(make_symbol("point_type"));
-    }
-
-    for (const auto& var: info.state_vars) {
-        auto name = var->get_name() + "0";
-        auto symbol = program_symtab->lookup(name);
-        if (symbol == nullptr) {
-            printer->fmt_line("{}{} {}{};", qualifier, float_type, name, value_initialize);
-            codegen_global_variables.push_back(make_symbol(name));
-        }
-    }
-
-    // Neuron and Coreneuron add "v" to global variables when vectorize
But as v is always local variable and passed as argument, - // we don't need to use global variable v - - auto& top_locals = info.top_local_variables; - if (!info.vectorize && !top_locals.empty()) { - for (const auto& var: top_locals) { - auto name = var->get_name(); - auto length = var->get_length(); - if (var->is_array()) { - printer->fmt_line("{}{} {}[{}] /* TODO init top-local-array */;", - qualifier, - float_type, - name, - length); - } else { - printer->fmt_line("{}{} {} /* TODO init top-local */;", - qualifier, - float_type, - name); - } - codegen_global_variables.push_back(var); - } - } - - if (!info.thread_variables.empty()) { - printer->fmt_line("{}int thread_data_in_use{};", qualifier, value_initialize); - printer->fmt_line("{}{} thread_data[{}] /* TODO init thread_data */;", - qualifier, - float_type, - info.thread_var_data_size); - codegen_global_variables.push_back(make_symbol("thread_data_in_use")); - auto symbol = make_symbol("thread_data"); - symbol->set_as_array(info.thread_var_data_size); - codegen_global_variables.push_back(symbol); - } - - // TODO: remove this entirely? - printer->fmt_line("{}int reset{};", qualifier, value_initialize); - codegen_global_variables.push_back(make_symbol("reset")); - - printer->fmt_line("{}int mech_type{};", qualifier, value_initialize); - codegen_global_variables.push_back(make_symbol("mech_type")); - - for (const auto& var: info.global_variables) { - auto name = var->get_name(); - auto length = var->get_length(); - if (var->is_array()) { - printer->fmt_line( - "{}{} {}[{}] /* TODO init const-array */;", qualifier, float_type, name, length); - } else { - double value{}; - if (auto const& value_ptr = var->get_value()) { - value = *value_ptr; - } - printer->fmt_line("{}{} {}{};", - qualifier, - float_type, - name, - print_initializers ? fmt::format("{{{:g}}}", value) : std::string{}); - } - codegen_global_variables.push_back(var); - } - - for (const auto& var: info.constant_variables) { - auto const name = var->get_name(); - auto* const value_ptr = var->get_value().get(); - double const value{value_ptr ? *value_ptr : 0}; - printer->fmt_line("{}{} {}{};", - qualifier, - float_type, - name, - print_initializers ? 
fmt::format("{{{:g}}}", value) : std::string{}); - codegen_global_variables.push_back(var); - } - - if (info.primes_size != 0) { - const auto count_prime_variables = [](auto size, const SymbolType& symbol) { - return size += symbol->get_length(); - }; - const auto prime_variables_by_order_size = - std::accumulate(info.prime_variables_by_order.begin(), - info.prime_variables_by_order.end(), - 0, - count_prime_variables); - if (info.primes_size != prime_variables_by_order_size) { - throw std::runtime_error{ - fmt::format("primes_size = {} differs from prime_variables_by_order.size() = {}, " - "this should not happen.", - info.primes_size, - info.prime_variables_by_order.size())}; - } - auto const initializer_list = [&](auto const& primes, const char* prefix) -> std::string { - if (!print_initializers) { - return {}; - } - std::string list{"{"}; - for (auto iter = primes.begin(); iter != primes.end(); ++iter) { - auto const& prime = *iter; - list.append(std::to_string(position_of_float_var(prefix + prime->get_name()))); - if (std::next(iter) != primes.end()) { - list.append(", "); - } - } - list.append("}"); - return list; - }; - printer->fmt_line("{}int slist1[{}]{};", - qualifier, - info.primes_size, - initializer_list(info.prime_variables_by_order, "")); - printer->fmt_line("{}int dlist1[{}]{};", - qualifier, - info.primes_size, - initializer_list(info.prime_variables_by_order, "D")); - codegen_global_variables.push_back(make_symbol("slist1")); - codegen_global_variables.push_back(make_symbol("dlist1")); - // additional list for derivimplicit method - if (info.derivimplicit_used()) { - auto primes = program_symtab->get_variables_with_properties(NmodlType::prime_name); - printer->fmt_line("{}int slist2[{}]{};", - qualifier, - info.primes_size, - initializer_list(primes, "")); - codegen_global_variables.push_back(make_symbol("slist2")); - } - } - - if (info.table_count > 0) { - printer->fmt_line("{}double usetable{};", qualifier, print_initializers ? 
"{1}" : ""); - codegen_global_variables.push_back(make_symbol(naming::USE_TABLE_VARIABLE)); - - for (const auto& block: info.functions_with_table) { - const auto& name = block->get_node_name(); - printer->fmt_line("{}{} tmin_{}{};", qualifier, float_type, name, value_initialize); - printer->fmt_line("{}{} mfac_{}{};", qualifier, float_type, name, value_initialize); - codegen_global_variables.push_back(make_symbol("tmin_" + name)); - codegen_global_variables.push_back(make_symbol("mfac_" + name)); - } - - for (const auto& variable: info.table_statement_variables) { - auto const name = "t_" + variable->get_name(); - auto const num_values = variable->get_num_values(); - if (variable->is_array()) { - int array_len = variable->get_length(); - printer->fmt_line("{}{} {}[{}][{}]{};", - qualifier, - float_type, - name, - array_len, - num_values, - value_initialize); - } else { - printer->fmt_line( - "{}{} {}[{}]{};", qualifier, float_type, name, num_values, value_initialize); - } - codegen_global_variables.push_back(make_symbol(name)); - } - } - - for (const auto& f: info.function_tables) { - printer->fmt_line("void* _ptable_{}{{}};", f->get_node_name()); - codegen_global_variables.push_back(make_symbol("_ptable_" + f->get_node_name())); - } - - if (info.vectorize && info.thread_data_index) { - printer->fmt_line("{}ThreadDatum ext_call_thread[{}]{};", - qualifier, - info.thread_data_index, - value_initialize); - codegen_global_variables.push_back(make_symbol("ext_call_thread")); - } - - printer->pop_block(";"); - - print_global_var_struct_assertions(); - print_global_var_struct_decl(); -} - -void CodegenCppVisitor::print_global_var_struct_assertions() const { - // Assert some things that we assume when copying instances of this struct - // to the GPU and so on. 
- printer->fmt_line("static_assert(std::is_trivially_copy_constructible_v<{}>);", - global_struct()); - printer->fmt_line("static_assert(std::is_trivially_move_constructible_v<{}>);", - global_struct()); - printer->fmt_line("static_assert(std::is_trivially_copy_assignable_v<{}>);", global_struct()); - printer->fmt_line("static_assert(std::is_trivially_move_assignable_v<{}>);", global_struct()); - printer->fmt_line("static_assert(std::is_trivially_destructible_v<{}>);", global_struct()); -} - - -void CodegenCppVisitor::print_prcellstate_macros() const { - printer->add_line("#ifndef NRN_PRCELLSTATE"); - printer->add_line("#define NRN_PRCELLSTATE 0"); - printer->add_line("#endif"); -} - - -void CodegenCppVisitor::print_mechanism_info() { - auto variable_printer = [&](const std::vector& variables) { - for (const auto& v: variables) { - auto name = v->get_name(); - if (!info.point_process) { - name += "_" + info.mod_suffix; - } - if (v->is_array()) { - name += fmt::format("[{}]", v->get_length()); - } - printer->add_line(add_escape_quote(name), ","); - } - }; - - printer->add_newline(2); - printer->add_line("/** channel information */"); - printer->add_line("static const char *mechanism[] = {"); - printer->increase_indent(); - printer->add_line(add_escape_quote(nmodl_version()), ","); - printer->add_line(add_escape_quote(info.mod_suffix), ","); - variable_printer(info.range_parameter_vars); - printer->add_line("0,"); - variable_printer(info.range_assigned_vars); - printer->add_line("0,"); - variable_printer(info.range_state_vars); - printer->add_line("0,"); - variable_printer(info.pointer_variables); - printer->add_line("0"); - printer->decrease_indent(); - printer->add_line("};"); -} - - -/** - * Print structs that encapsulate information about scalar and - * vector elements of type global and thread variables. 
-
-
-/**
- * Print structs that encapsulate information about scalar and
- * vector elements of type global and thread variables.
- */
-void CodegenCppVisitor::print_global_variables_for_hoc() {
-    auto variable_printer =
-        [&](const std::vector<SymbolType>& variables, bool if_array, bool if_vector) {
-            for (const auto& variable: variables) {
-                if (variable->is_array() == if_array) {
-                    // false => do not use the instance struct, which is not
-                    // defined in the global declaration that we are printing
-                    auto name = get_variable_name(variable->get_name(), false);
-                    auto ename = add_escape_quote(variable->get_name() + "_" + info.mod_suffix);
-                    auto length = variable->get_length();
-                    if (if_vector) {
-                        printer->fmt_line("{{{}, {}, {}}},", ename, name, length);
-                    } else {
-                        printer->fmt_line("{{{}, &{}}},", ename, name);
-                    }
-                }
-            }
-        };
-
-    auto globals = info.global_variables;
-    auto thread_vars = info.thread_variables;
-
-    if (info.table_count > 0) {
-        globals.push_back(make_symbol(naming::USE_TABLE_VARIABLE));
-    }
-
-    printer->add_newline(2);
-    printer->add_line("/** connect global (scalar) variables to hoc -- */");
-    printer->add_line("static DoubScal hoc_scalar_double[] = {");
-    printer->increase_indent();
-    variable_printer(globals, false, false);
-    variable_printer(thread_vars, false, false);
-    printer->add_line("{nullptr, nullptr}");
-    printer->decrease_indent();
-    printer->add_line("};");
-
-    printer->add_newline(2);
-    printer->add_line("/** connect global (array) variables to hoc -- */");
-    printer->add_line("static DoubVec hoc_vector_double[] = {");
-    printer->increase_indent();
-    variable_printer(globals, true, true);
-    variable_printer(thread_vars, true, true);
-    printer->add_line("{nullptr, nullptr, 0}");
-    printer->decrease_indent();
-    printer->add_line("};");
-}
-
-/**
- * Return the registration type for a given BEFORE/AFTER block
- * \param block A BEFORE/AFTER block being registered
- *
- * Depending on the block type, i.e. BEFORE or AFTER, and also the type
- * of its associated block, i.e. BREAKPOINT, INITIAL, SOLVE or
- * STEP, the registration type (as an integer) is calculated.
- * These values are then interpreted by CoreNEURON internally.
- */
-static std::string get_register_type_for_ba_block(const ast::Block* block) {
-    std::string register_type{};
-    BAType ba_type{};
-    /// before blocks have value 10 and after blocks 20
-    if (block->is_before_block()) {
-        // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers)
-        register_type = "BAType::Before";
-        ba_type =
-            dynamic_cast<const ast::BeforeBlock*>(block)->get_bablock()->get_type()->get_value();
-    } else {
-        // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers)
-        register_type = "BAType::After";
-        ba_type =
-            dynamic_cast<const ast::AfterBlock*>(block)->get_bablock()->get_type()->get_value();
-    }
-
-    /// associated blocks have different values (1 to 4) based on type.
-    /// These values are based on neuron/coreneuron implementation details.
-    if (ba_type == BATYPE_BREAKPOINT) {
-        register_type += " + BAType::Breakpoint";
-    } else if (ba_type == BATYPE_SOLVE) {
-        register_type += " + BAType::Solve";
-    } else if (ba_type == BATYPE_INITIAL) {
-        register_type += " + BAType::Initial";
-    } else if (ba_type == BATYPE_STEP) {
-        register_type += " + BAType::Step";
-    } else {
-        throw std::runtime_error("Unhandled Before/After type encountered during code generation");
-    }
-    return register_type;
-}
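Per the comments above, the registration value is a simple sum: 10 or 20 for BEFORE/AFTER plus 1..4 for the associated block. A toy restatement of that arithmetic; the enumerator names mirror the BAType identifiers printed above, and the numeric values are the ones the comments describe, not values read from CoreNEURON headers:

    #include <cstdio>

    enum BAType { Before = 10, After = 20, Breakpoint = 1, Solve = 2, Initial = 3, Step = 4 };

    int main() {
        std::printf("BEFORE BREAKPOINT -> %d\n", Before + Breakpoint);  // 11
        std::printf("AFTER SOLVE       -> %d\n", After + Solve);        // 22
    }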
-
-
-/**
- * \details Every mod file has a register function to connect with the simulator.
- * Various information about the mechanism and callbacks get registered with
- * the simulator using the suffix_reg() function.
- *
- * Here are the details:
- *  - We should exclude that callback based on the solver, watch statements.
- *  - If nrn_get_mechtype returns < -1 it means that the mechanism is not used in the
- *    context of neuron execution and hence could be ignored in coreneuron
- *    execution.
- *  - Ions are internally defined and their types can be queried similar to
- *    other mechanisms.
- *  - hoc_register_var may not be needed in the context of coreneuron
- *  - We assume the net receive buffer is on. This is because the generated code is
- *    compatible for cpu as well as gpu targets.
- */
-// NOLINTNEXTLINE(readability-function-cognitive-complexity)
-void CodegenCppVisitor::print_mechanism_register() {
-    printer->add_newline(2);
-    printer->add_line("/** register channel with the simulator */");
-    printer->fmt_push_block("void _{}_reg()", info.mod_file);
-
-    // type related information
-    auto suffix = add_escape_quote(info.mod_suffix);
-    printer->add_newline();
-    printer->fmt_line("int mech_type = nrn_get_mechtype({});", suffix);
-    printer->fmt_line("{} = mech_type;", get_variable_name("mech_type", false));
-    printer->push_block("if (mech_type == -1)");
-    printer->add_line("return;");
-    printer->pop_block();
-
-    printer->add_newline();
-    printer->add_line("_nrn_layout_reg(mech_type, 0);");  // 0 for SoA
-
-    // register mechanism
-    const auto mech_arguments = register_mechanism_arguments();
-    const auto number_of_thread_objects = num_thread_objects();
-    if (info.point_process) {
-        printer->fmt_line("point_register_mech({}, {}, {}, {});",
-                          mech_arguments,
-                          info.constructor_node ? method_name(naming::NRN_CONSTRUCTOR_METHOD)
-                                                : "nullptr",
-                          info.destructor_node ? method_name(naming::NRN_DESTRUCTOR_METHOD)
-                                               : "nullptr",
-                          number_of_thread_objects);
-    } else {
-        printer->fmt_line("register_mech({}, {});", mech_arguments, number_of_thread_objects);
-        if (info.constructor_node) {
-            printer->fmt_line("register_constructor({});",
-                              method_name(naming::NRN_CONSTRUCTOR_METHOD));
-        }
-    }
-
-    // types for ion
-    for (const auto& ion: info.ions) {
-        printer->fmt_line("{} = nrn_get_mechtype({});",
-                          get_variable_name(ion.name + "_type", false),
-                          add_escape_quote(ion.name + "_ion"));
-    }
-    printer->add_newline();
-
-    /*
-     * Register callbacks for thread allocation and cleanup. Note that thread_data_index
-     * represents the total number of threads used minus 1 (i.e. the index of the last thread).
- */ - if (info.vectorize && (info.thread_data_index != 0)) { - // false to avoid getting the copy from the instance structure - printer->fmt_line("thread_mem_init({});", get_variable_name("ext_call_thread", false)); - } - - if (!info.thread_variables.empty()) { - printer->fmt_line("{} = 0;", get_variable_name("thread_data_in_use")); - } - - if (info.thread_callback_register) { - printer->add_line("_nrn_thread_reg0(mech_type, thread_mem_cleanup);"); - printer->add_line("_nrn_thread_reg1(mech_type, thread_mem_init);"); - } - - if (info.emit_table_thread()) { - auto name = method_name("check_table_thread"); - printer->fmt_line("_nrn_thread_table_reg(mech_type, {});", name); - } - - // register read/write callbacks for pointers - if (info.bbcore_pointer_used) { - printer->add_line("hoc_reg_bbcore_read(mech_type, bbcore_read);"); - printer->add_line("hoc_reg_bbcore_write(mech_type, bbcore_write);"); - } - - // register size of double and int elements - // clang-format off - printer->add_line("hoc_register_prop_size(mech_type, float_variables_size(), int_variables_size());"); - // clang-format on - - // register semantics for index variables - for (auto& semantic: info.semantics) { - auto args = - fmt::format("mech_type, {}, {}", semantic.index, add_escape_quote(semantic.name)); - printer->fmt_line("hoc_register_dparam_semantics({});", args); - } - - if (info.is_watch_used()) { - auto watch_fun = compute_method_name(BlockType::Watch); - printer->fmt_line("hoc_register_watch_check({}, mech_type);", watch_fun); - } - - if (info.write_concentration) { - printer->add_line("nrn_writes_conc(mech_type, 0);"); - } - - // register various information for point process type - if (info.net_event_used) { - printer->add_line("add_nrn_has_net_event(mech_type);"); - } - if (info.artificial_cell) { - printer->fmt_line("add_nrn_artcell(mech_type, {});", info.tqitem_index); - } - if (net_receive_buffering_required()) { - printer->fmt_line("hoc_register_net_receive_buffering({}, mech_type);", - method_name("net_buf_receive")); - } - if (info.num_net_receive_parameters != 0) { - auto net_recv_init_arg = "nullptr"; - if (info.net_receive_initial_node != nullptr) { - net_recv_init_arg = "net_init"; - } - printer->fmt_line("set_pnt_receive(mech_type, {}, {}, num_net_receive_args());", - method_name("net_receive"), - net_recv_init_arg); - } - if (info.for_netcon_used) { - // index where information about FOR_NETCON is stored in the integer array - const auto index = - std::find_if(info.semantics.begin(), info.semantics.end(), [](const IndexSemantics& a) { - return a.name == naming::FOR_NETCON_SEMANTIC; - })->index; - printer->fmt_line("add_nrn_fornetcons(mech_type, {});", index); - } - - if (info.net_event_used || info.net_send_used) { - printer->add_line("hoc_register_net_send_buffering(mech_type);"); - } - - /// register all before/after blocks - for (size_t i = 0; i < info.before_after_blocks.size(); i++) { - // register type and associated function name for the block - const auto& block = info.before_after_blocks[i]; - std::string register_type = get_register_type_for_ba_block(block); - std::string function_name = method_name(fmt::format("nrn_before_after_{}", i)); - printer->fmt_line("hoc_reg_ba(mech_type, {}, {});", function_name, register_type); - } - - // register variables for hoc - printer->add_line("hoc_register_var(hoc_scalar_double, hoc_vector_double, NULL);"); - printer->pop_block(); -} - - -void CodegenCppVisitor::print_thread_memory_callbacks() { - if (!info.thread_callback_register) { - return; - } - 
-    // thread_mem_init callback
-    printer->add_newline(2);
-    printer->add_line("/** thread memory allocation callback */");
-    printer->push_block("static void thread_mem_init(ThreadDatum* thread) ");
-
-    if (info.vectorize && info.derivimplicit_used()) {
-        printer->fmt_line("thread[dith{}()].pval = nullptr;", info.derivimplicit_list_num);
-    }
-    if (info.vectorize && (info.top_local_thread_size != 0)) {
-        auto length = info.top_local_thread_size;
-        auto allocation = fmt::format("(double*)mem_alloc({}, sizeof(double))", length);
-        printer->fmt_line("thread[top_local_var_tid()].pval = {};", allocation);
-    }
-    if (info.thread_var_data_size != 0) {
-        auto length = info.thread_var_data_size;
-        auto thread_data = get_variable_name("thread_data");
-        auto thread_data_in_use = get_variable_name("thread_data_in_use");
-        auto allocation = fmt::format("(double*)mem_alloc({}, sizeof(double))", length);
-        printer->fmt_push_block("if ({})", thread_data_in_use);
-        printer->fmt_line("thread[thread_var_tid()].pval = {};", allocation);
-        printer->chain_block("else");
-        printer->fmt_line("thread[thread_var_tid()].pval = {};", thread_data);
-        printer->fmt_line("{} = 1;", thread_data_in_use);
-        printer->pop_block();
-    }
-    printer->pop_block();
-    printer->add_newline(2);
-
-
-    // thread_mem_cleanup callback
-    printer->add_line("/** thread memory cleanup callback */");
-    printer->push_block("static void thread_mem_cleanup(ThreadDatum* thread) ");
-
-    // clang-format off
-    if (info.vectorize && info.derivimplicit_used()) {
-        int n = info.derivimplicit_list_num;
-        printer->fmt_line("free(thread[dith{}()].pval);", n);
-        printer->fmt_line("nrn_destroy_newtonspace(static_cast<NewtonSpace*>(*newtonspace{}(thread)));", n);
-    }
-    // clang-format on
-
-    if (info.top_local_thread_size != 0) {
-        auto line = "free(thread[top_local_var_tid()].pval);";
-        printer->add_line(line);
-    }
-    if (info.thread_var_data_size != 0) {
-        auto thread_data = get_variable_name("thread_data");
-        auto thread_data_in_use = get_variable_name("thread_data_in_use");
-        printer->fmt_push_block("if (thread[thread_var_tid()].pval == {})", thread_data);
-        printer->fmt_line("{} = 0;", thread_data_in_use);
-        printer->chain_block("else");
-        printer->add_line("free(thread[thread_var_tid()].pval);");
-        printer->pop_block();
-    }
-    printer->pop_block();
-}
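The generated init/cleanup pair implements a small ownership scheme: the first thread borrows the statically allocated thread_data buffer and sets thread_data_in_use, later threads allocate private copies, and cleanup frees only the heap copies. A compilable sketch of the same idea in plain C++ (CoreNEURON's ThreadDatum and mem_alloc are replaced with simple stand-ins):

    #include <cstdlib>

    namespace {
    double thread_data[8];       // shared static buffer from the global struct
    int thread_data_in_use = 0;  // set once a thread holds the static buffer
    }

    double* demo_thread_mem_init() {
        if (thread_data_in_use) {
            // buffer taken: this thread gets its own heap allocation
            return static_cast<double*>(std::calloc(8, sizeof(double)));
        }
        thread_data_in_use = 1;
        return thread_data;
    }

    void demo_thread_mem_cleanup(double* pval) {
        if (pval == thread_data) {
            thread_data_in_use = 0;  // hand the static buffer back
        } else {
            std::free(pval);
        }
    }

    int main() {
        double* a = demo_thread_mem_init();  // static buffer
        double* b = demo_thread_mem_init();  // heap copy
        demo_thread_mem_cleanup(b);
        demo_thread_mem_cleanup(a);
    }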
"void*" : default_float_data_type(); - printer->fmt_line("{}{}* {}{};", qualifier, type, name, value_initialize); - } - } - - printer->fmt_line("{}* {}{};", - global_struct(), - naming::INST_GLOBAL_MEMBER, - print_initializers ? fmt::format("{{&{}}}", global_struct_instance()) - : std::string{}); - printer->pop_block(";"); -} - - -void CodegenCppVisitor::print_ion_var_structure() { - if (!ion_variable_struct_required()) { - return; - } - printer->add_newline(2); - printer->add_line("/** ion write variables */"); - printer->push_block("struct IonCurVar"); - - std::string float_type = default_float_data_type(); - std::vector members; - - for (auto& ion: info.ions) { - for (auto& var: ion.writes) { - printer->fmt_line("{} {};", float_type, var); - members.push_back(var); - } - } - for (auto& var: info.currents) { - if (!info.is_ion_variable(var)) { - printer->fmt_line("{} {};", float_type, var); - members.push_back(var); - } - } - - print_ion_var_constructor(members); - - printer->pop_block(";"); -} - - -void CodegenCppVisitor::print_ion_var_constructor(const std::vector& members) { - // constructor - printer->add_newline(); - printer->add_indent(); - printer->add_text("IonCurVar() : "); - for (int i = 0; i < members.size(); i++) { - printer->fmt_text("{}(0)", members[i]); - if (i + 1 < members.size()) { - printer->add_text(", "); - } - } - printer->add_text(" {}"); - printer->add_newline(); -} - - -void CodegenCppVisitor::print_ion_variable() { - printer->add_line("IonCurVar ionvar;"); -} - - -void CodegenCppVisitor::print_global_variable_device_update_annotation() { - // nothing for cpu -} - - -void CodegenCppVisitor::print_setup_range_variable() { - auto type = float_data_type(); - printer->add_newline(2); - printer->add_line("/** allocate and setup array for range variable */"); - printer->fmt_push_block("static inline {}* setup_range_variable(double* variable, int n)", - type); - printer->fmt_line("{0}* data = ({0}*) mem_alloc(n, sizeof({0}));", type); - printer->push_block("for(size_t i = 0; i < n; i++)"); - printer->add_line("data[i] = variable[i];"); - printer->pop_block(); - printer->add_line("return data;"); - printer->pop_block(); -} - - -/** - * \details If floating point type like "float" is specified on command line then - * we can't turn all variables to new type. This is because certain variables - * are pointers to internal variables (e.g. ions). Hence, we check if given - * variable can be safely converted to new type. If so, return new type. 
-
-
-/**
- * \details If a floating point type like "float" is specified on the command line then
- * we can't turn all variables to the new type. This is because certain variables
- * are pointers to internal variables (e.g. ions). Hence, we check if a given
- * variable can be safely converted to the new type. If so, return the new type.
- */
-std::string CodegenCppVisitor::get_range_var_float_type(const SymbolType& symbol) {
-    // clang-format off
-    auto with = NmodlType::read_ion_var
-                | NmodlType::write_ion_var
-                | NmodlType::pointer_var
-                | NmodlType::bbcore_pointer_var
-                | NmodlType::extern_neuron_variable;
-    // clang-format on
-    bool need_default_type = symbol->has_any_property(with);
-    if (need_default_type) {
-        return default_float_data_type();
-    }
-    return float_data_type();
-}
-
-
-void CodegenCppVisitor::print_instance_variable_setup() {
-    if (range_variable_setup_required()) {
-        print_setup_range_variable();
-    }
-
-    printer->add_newline();
-    printer->add_line("// Allocate instance structure");
-    printer->fmt_push_block("static void {}(NrnThread* nt, Memb_list* ml, int type)",
-                            method_name(naming::NRN_PRIVATE_CONSTRUCTOR_METHOD));
-    printer->add_line("assert(!ml->instance);");
-    printer->add_line("assert(!ml->global_variables);");
-    printer->add_line("assert(ml->global_variables_size == 0);");
-    printer->fmt_line("auto* const inst = new {}{{}};", instance_struct());
-    printer->fmt_line("assert(inst->{} == &{});",
-                      naming::INST_GLOBAL_MEMBER,
-                      global_struct_instance());
-    printer->add_line("ml->instance = inst;");
-    printer->fmt_line("ml->global_variables = inst->{};", naming::INST_GLOBAL_MEMBER);
-    printer->fmt_line("ml->global_variables_size = sizeof({});", global_struct());
-    printer->pop_block();
-    printer->add_newline();
-
-    auto const cast_inst_and_assert_validity = [&]() {
-        printer->fmt_line("auto* const inst = static_cast<{}*>(ml->instance);", instance_struct());
-        printer->add_line("assert(inst);");
-        printer->fmt_line("assert(inst->{});", naming::INST_GLOBAL_MEMBER);
-        printer->fmt_line("assert(inst->{} == &{});",
-                          naming::INST_GLOBAL_MEMBER,
-                          global_struct_instance());
-        printer->fmt_line("assert(inst->{} == ml->global_variables);", naming::INST_GLOBAL_MEMBER);
-        printer->fmt_line("assert(ml->global_variables_size == sizeof({}));", global_struct());
-    };
-
-    // Must come before print_instance_struct_copy_to_device and
-    // print_instance_struct_delete_from_device
-    print_instance_struct_transfer_routine_declarations();
-
-    printer->add_line("// Deallocate the instance structure");
-    printer->fmt_push_block("static void {}(NrnThread* nt, Memb_list* ml, int type)",
-                            method_name(naming::NRN_PRIVATE_DESTRUCTOR_METHOD));
-    cast_inst_and_assert_validity();
-    print_instance_struct_delete_from_device();
-    printer->add_multi_line(R"CODE(
-        delete inst;
-        ml->instance = nullptr;
-        ml->global_variables = nullptr;
-        ml->global_variables_size = 0;
-    )CODE");
-    printer->pop_block();
-    printer->add_newline();
-
-
-    printer->add_line("/** initialize mechanism instance variables */");
-    printer->push_block("static inline void setup_instance(NrnThread* nt, Memb_list* ml)");
-    cast_inst_and_assert_validity();
-
-    std::string stride;
-    printer->add_line("int pnodecount = ml->_nodecount_padded;");
-    stride = "*pnodecount";
-
-    printer->add_line("Datum* indexes = ml->pdata;");
-
-    auto const float_type = default_float_data_type();
-
-    int id = 0;
-    std::vector<std::string> ptr_members{naming::INST_GLOBAL_MEMBER};
-    for (auto const& [var, type]: info.neuron_global_variables) {
-        ptr_members.push_back(var->get_name());
-    }
-    ptr_members.reserve(ptr_members.size() + codegen_float_variables.size() +
-                        codegen_int_variables.size());
-    for (auto& var: codegen_float_variables) {
-        auto name = var->get_name();
-        auto range_var_type = get_range_var_float_type(var);
-        if (float_type == range_var_type) {
-            auto const variable =
fmt::format("ml->data+{}{}", id, stride); - printer->fmt_line("inst->{} = {};", name, variable); - } else { - // TODO what MOD file exercises this? - printer->fmt_line("inst->{} = setup_range_variable(ml->data+{}{}, pnodecount);", - name, - id, - stride); - } - ptr_members.push_back(std::move(name)); - id += var->get_length(); - } - - for (auto& var: codegen_int_variables) { - auto name = var.symbol->get_name(); - auto const variable = [&var]() { - if (var.is_index || var.is_integer) { - return "ml->pdata"; - } else if (var.is_vdata) { - return "nt->_vdata"; - } else { - return "nt->_data"; - } - }(); - printer->fmt_line("inst->{} = {};", name, variable); - ptr_members.push_back(std::move(name)); - } - print_instance_struct_copy_to_device(); - printer->pop_block(); // setup_instance - printer->add_newline(); - - print_instance_struct_transfer_routines(ptr_members); -} - - -void CodegenCppVisitor::print_initial_block(const InitialBlock* node) { - if (info.artificial_cell) { - printer->add_line("double v = 0.0;"); - } else { - printer->add_line("int node_id = node_index[id];"); - printer->add_line("double v = voltage[node_id];"); - print_v_unused(); - } - - if (ion_variable_struct_required()) { - printer->add_line("IonCurVar ionvar;"); - } - - // read ion statements - auto read_statements = ion_read_statements(BlockType::Initial); - for (auto& statement: read_statements) { - printer->add_line(statement); - } - - // initialize state variables (excluding ion state) - for (auto& var: info.state_vars) { - auto name = var->get_name(); - if (!info.is_ionic_conc(name)) { - auto lhs = get_variable_name(name); - auto rhs = get_variable_name(name + "0"); - if (var->is_array()) { - for (int i = 0; i < var->get_length(); ++i) { - printer->fmt_line("{}[{}] = {};", lhs, i, rhs); - } - } else { - printer->fmt_line("{} = {};", lhs, rhs); - } - } - } - - // initial block - if (node != nullptr) { - const auto& block = node->get_statement_block(); - print_statement_block(*block, false, false); - } - - // write ion statements - auto write_statements = ion_write_statements(BlockType::Initial); - for (auto& statement: write_statements) { - auto text = process_shadow_update_statement(statement, BlockType::Initial); - printer->add_line(text); - } -} - - -void CodegenCppVisitor::print_global_function_common_code(BlockType type, - const std::string& function_name) { - std::string method; - if (function_name.empty()) { - method = compute_method_name(type); - } else { - method = function_name; - } - auto args = "NrnThread* nt, Memb_list* ml, int type"; - - // watch statement function doesn't have type argument - if (type == BlockType::Watch) { - args = "NrnThread* nt, Memb_list* ml"; - } - - print_global_method_annotation(); - printer->fmt_push_block("void {}({})", method, args); - if (type != BlockType::Destructor && type != BlockType::Constructor) { - // We do not (currently) support DESTRUCTOR and CONSTRUCTOR blocks - // running anything on the GPU. 
-        print_kernel_data_present_annotation_block_begin();
-    } else {
-        /// TODO: Remove this when the code generation is properly done
-        /// Related to https://github.com/BlueBrain/nmodl/issues/692
-        printer->add_line("#ifndef CORENEURON_BUILD");
-    }
-    printer->add_multi_line(R"CODE(
-        int nodecount = ml->nodecount;
-        int pnodecount = ml->_nodecount_padded;
-        const int* node_index = ml->nodeindices;
-        double* data = ml->data;
-        const double* voltage = nt->_actual_v;
-    )CODE");
-
-    if (type == BlockType::Equation) {
-        printer->add_line("double* vec_rhs = nt->_actual_rhs;");
-        printer->add_line("double* vec_d = nt->_actual_d;");
-        print_rhs_d_shadow_variables();
-    }
-    printer->add_line("Datum* indexes = ml->pdata;");
-    printer->add_line("ThreadDatum* thread = ml->_thread;");
-
-    if (type == BlockType::Initial) {
-        printer->add_newline();
-        printer->add_line("setup_instance(nt, ml);");
-    }
-    printer->fmt_line("auto* const inst = static_cast<{}*>(ml->instance);", instance_struct());
-    printer->add_newline(1);
-}
-
-
-void CodegenCppVisitor::print_nrn_init(bool skip_init_check) {
-    codegen = true;
-    printer->add_newline(2);
-    printer->add_line("/** initialize channel */");
-
-    print_global_function_common_code(BlockType::Initial);
-    if (info.derivimplicit_used()) {
-        printer->add_newline();
-        int nequation = info.num_equations;
-        int list_num = info.derivimplicit_list_num;
-        // clang-format off
-        printer->fmt_line("int& deriv_advance_flag = *deriv{}_advance(thread);", list_num);
-        printer->add_line("deriv_advance_flag = 0;");
-        print_deriv_advance_flag_transfer_to_device();
-        printer->fmt_line("auto ns = newtonspace{}(thread);", list_num);
-        printer->fmt_line("auto& th = thread[dith{}()];", list_num);
-        printer->push_block("if (*ns == nullptr)");
-        printer->fmt_line("int vec_size = 2*{}*pnodecount*sizeof(double);", nequation);
-        printer->fmt_line("double* vec = makevector(vec_size);", nequation);
-        printer->fmt_line("th.pval = vec;", list_num);
-        printer->fmt_line("*ns = nrn_cons_newtonspace({}, pnodecount);", nequation);
-        print_newtonspace_transfer_to_device();
-        printer->pop_block();
-        // clang-format on
-    }
-
-    // update global variables as those might be updated via python/hoc API
-    // NOTE: CoreNEURON has enough information to do this on its own, which
-    // would be neater.
-    print_global_variable_device_update_annotation();
-
-    if (skip_init_check) {
-        printer->push_block("if (_nrn_skip_initmodel == 0)");
-    }
-
-    if (!info.changed_dt.empty()) {
-        printer->fmt_line("double _save_prev_dt = {};",
-                          get_variable_name(naming::NTHREAD_DT_VARIABLE));
-        printer->fmt_line("{} = {};",
-                          get_variable_name(naming::NTHREAD_DT_VARIABLE),
-                          info.changed_dt);
-        print_dt_update_to_device();
-    }
-
-    print_channel_iteration_block_parallel_hint(BlockType::Initial, info.initial_node);
-    printer->push_block("for (int id = 0; id < nodecount; id++)");
-
-    if (info.net_receive_node != nullptr) {
-        printer->fmt_line("{} = -1e20;", get_variable_name("tsave"));
-    }
-
-    print_initial_block(info.initial_node);
-    printer->pop_block();
-
-    if (!info.changed_dt.empty()) {
-        printer->fmt_line("{} = _save_prev_dt;", get_variable_name(naming::NTHREAD_DT_VARIABLE));
-        print_dt_update_to_device();
-    }
-
-    printer->pop_block();
-
-    if (info.derivimplicit_used()) {
-        printer->add_line("deriv_advance_flag = 1;");
-        print_deriv_advance_flag_transfer_to_device();
-    }
-
-    if (info.net_send_used && !info.artificial_cell) {
-        print_send_event_move();
-    }
-
-    print_kernel_data_present_annotation_block_end();
-    if (skip_init_check) {
-        printer->pop_block();
-    }
-    codegen = false;
-}
-
-
-void CodegenCppVisitor::print_before_after_block(const ast::Block* node, size_t block_id) {
-    codegen = true;
-
-    std::string ba_type;
-    std::shared_ptr<ast::BABlock> ba_block;
-
-    if (node->is_before_block()) {
-        ba_block = dynamic_cast<const ast::BeforeBlock*>(node)->get_bablock();
-        ba_type = "BEFORE";
-    } else {
-        ba_block = dynamic_cast<const ast::AfterBlock*>(node)->get_bablock();
-        ba_type = "AFTER";
-    }
-
-    std::string ba_block_type = ba_block->get_type()->eval();
-
-    /// name of the before/after function
-    std::string function_name = method_name(fmt::format("nrn_before_after_{}", block_id));
-
-    /// print common function code like init/state/current
-    printer->add_newline(2);
-    printer->fmt_line("/** {} of block type {} # {} */", ba_type, ba_block_type, block_id);
-    print_global_function_common_code(BlockType::BeforeAfter, function_name);
-
-    print_channel_iteration_block_parallel_hint(BlockType::BeforeAfter, node);
-    printer->push_block("for (int id = 0; id < nodecount; id++)");
-
-    printer->add_line("int node_id = node_index[id];");
-    printer->add_line("double v = voltage[node_id];");
-    print_v_unused();
-
-    // read ion statements
-    const auto& read_statements = ion_read_statements(BlockType::Equation);
-    for (auto& statement: read_statements) {
-        printer->add_line(statement);
-    }
-
-    /// print main body
-    printer->add_indent();
-    print_statement_block(*ba_block->get_statement_block());
-    printer->add_newline();
-
-    // write ion statements
-    const auto& write_statements = ion_write_statements(BlockType::Equation);
-    for (auto& statement: write_statements) {
-        auto text = process_shadow_update_statement(statement, BlockType::Equation);
-        printer->add_line(text);
-    }
-
-    /// loop end including data annotation block
-    printer->pop_block();
-    printer->pop_block();
-    print_kernel_data_present_annotation_block_end();
-
-    codegen = false;
-}
-
-
-void CodegenCppVisitor::print_nrn_constructor() {
-    printer->add_newline(2);
-    print_global_function_common_code(BlockType::Constructor);
-    if (info.constructor_node != nullptr) {
-        const auto& block = info.constructor_node->get_statement_block();
-        print_statement_block(*block, false, false);
-    }
-    printer->add_line("#endif");
-    printer->pop_block();
-}
-
-
-void CodegenCppVisitor::print_nrn_destructor() {
-    printer->add_newline(2);
-    print_global_function_common_code(BlockType::Destructor);
-    if (info.destructor_node != nullptr) {
-        const auto& block = info.destructor_node->get_statement_block();
-        print_statement_block(*block, false, false);
-    }
-    printer->add_line("#endif");
-    printer->pop_block();
-}
-
-
-void CodegenCppVisitor::print_functors_definitions() {
-    codegen = true;
-    for (const auto& functor_name: info.functor_names) {
-        printer->add_newline(2);
-        print_functor_definition(*functor_name.first);
-    }
-    codegen = false;
-}
-
-
-void CodegenCppVisitor::print_nrn_alloc() {
-    printer->add_newline(2);
-    auto method = method_name(naming::NRN_ALLOC_METHOD);
-    printer->fmt_push_block("static void {}(double* data, Datum* indexes, int type)", method);
-    printer->add_line("// do nothing");
-    printer->pop_block();
-}
-
-
-/**
- * \todo The number of watches could be more than the number of statements
- * according to the grammar. Check if this is correctly handled in neuron
- * and coreneuron.
- */
-void CodegenCppVisitor::print_watch_activate() {
-    if (info.watch_statements.empty()) {
-        return;
-    }
-    codegen = true;
-    printer->add_newline(2);
-    auto inst = fmt::format("{}* inst", instance_struct());
-
-    printer->fmt_push_block(
-        "static void nrn_watch_activate({}, int id, int pnodecount, int watch_id, "
-        "double v, bool &watch_remove)",
-        inst);
-
-    // initialize all variables only during first watch statement
-    printer->push_block("if (watch_remove == false)");
-    for (int i = 0; i < info.watch_count; i++) {
-        auto name = get_variable_name(fmt::format("watch{}", i + 1));
-        printer->fmt_line("{} = 0;", name);
-    }
-    printer->add_line("watch_remove = true;");
-    printer->pop_block();
-
-    /**
-     * \todo Similar to neuron/coreneuron we are using the
-     * first watch and ignoring the rest.
-     */
-    for (int i = 0; i < info.watch_statements.size(); i++) {
-        auto statement = info.watch_statements[i];
-        printer->fmt_push_block("if (watch_id == {})", i);
-
-        auto varname = get_variable_name(fmt::format("watch{}", i + 1));
-        printer->add_indent();
-        printer->fmt_text("{} = 2 + (", varname);
-        auto watch = statement->get_statements().front();
-        watch->get_expression()->visit_children(*this);
-        printer->add_text(");");
-        printer->add_newline();
-
-        printer->pop_block();
-    }
-    printer->pop_block();
-    codegen = false;
-}
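nrn_watch_activate stores 2 + (condition) into each watch flag, and the check routine below tests bit 1 (armed) and bit 0 (already fired), moving a flag between the values 2 and 3 so that each crossing fires exactly once. A tiny standalone model of that two-bit state machine (the function and variable names are illustrative only):

    #include <cassert>

    // 0 = inactive, 2 = armed, 3 = armed and already triggered
    int check(int flag, bool condition, bool& fire_event) {
        fire_event = false;
        if (flag & 2) {
            if (condition) {
                fire_event = (flag & 1) == 0;  // fire only on the 2 -> 3 edge
                return 3;
            }
            return 2;  // condition false again: re-arm for the next crossing
        }
        return flag;
    }

    int main() {
        bool fire = false;
        int flag = 2 + 0;                // as stored by nrn_watch_activate
        flag = check(flag, true, fire);  // first crossing: event fires
        assert(fire && flag == 3);
        flag = check(flag, true, fire);  // still true: no duplicate event
        assert(!fire && flag == 3);
        flag = check(flag, false, fire); // drops back to armed
        assert(flag == 2);
    }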
As WATCH is not a top level block but list of statements,
-    // we don't need to have ivdep pragma related check
-    print_channel_iteration_block_parallel_hint(BlockType::Watch, nullptr);
-
-    printer->push_block("for (int id = 0; id < nodecount; id++)");
-
-    if (info.is_voltage_used_by_watch_statements()) {
-        printer->add_line("int node_id = node_index[id];");
-        printer->add_line("double v = voltage[node_id];");
-        print_v_unused();
-    }
-
-    // flag to make sure only one WATCH statement can be triggered at a time
-    printer->add_line("bool watch_untriggered = true;");
-
-    for (int i = 0; i < info.watch_statements.size(); i++) {
-        auto statement = info.watch_statements[i];
-        const auto& watch = statement->get_statements().front();
-        const auto& varname = get_variable_name(fmt::format("watch{}", i + 1));
-
-        // start block 1
-        printer->fmt_push_block("if ({}&2 && watch_untriggered)", varname);
-
-        // start block 2
-        printer->add_indent();
-        printer->add_text("if (");
-        watch->get_expression()->accept(*this);
-        printer->add_text(") {");
-        printer->add_newline();
-        printer->increase_indent();
-
-        // start block 3
-        printer->fmt_push_block("if (({}&1) == 0)", varname);
-
-        printer->add_line("watch_untriggered = false;");
-
-        const auto& tqitem = get_variable_name("tqitem");
-        const auto& point_process = get_variable_name("point_process");
-        printer->add_indent();
-        printer->add_text("net_send_buffering(");
-        const auto& t = get_variable_name("t");
-        printer->fmt_text("nt, ml->_net_send_buffer, 0, {}, -1, {}, {}+0.0, ",
-                          tqitem,
-                          point_process,
-                          t);
-        watch->get_value()->accept(*this);
-        printer->add_text(");");
-        printer->add_newline();
-        printer->pop_block();
-
-        printer->add_line(varname, " = 3;");
-        // end block 3
-
-        // start block 3
-        printer->decrease_indent();
-        printer->push_block("} else");
-        printer->add_line(varname, " = 2;");
-        printer->pop_block();
-        // end block 3
-
-        printer->pop_block();
-        // end block 1
-    }
-
-    printer->pop_block();
-    print_send_event_move();
-    print_kernel_data_present_annotation_block_end();
-    printer->pop_block();
-    codegen = false;
-}
-
-
-void CodegenCppVisitor::print_net_receive_common_code(const Block& node, bool need_mech_inst) {
-    printer->add_multi_line(R"CODE(
-        int tid = pnt->_tid;
-        int id = pnt->_i_instance;
-        double v = 0;
-    )CODE");
-
-    if (info.artificial_cell || node.is_initial_block()) {
-        printer->add_line("NrnThread* nt = nrn_threads + tid;");
-        printer->add_line("Memb_list* ml = nt->_ml_list[pnt->_type];");
-    }
-    if (node.is_initial_block()) {
-        print_kernel_data_present_annotation_block_begin();
-    }
-
-    printer->add_multi_line(R"CODE(
-        int nodecount = ml->nodecount;
-        int pnodecount = ml->_nodecount_padded;
-        double* data = ml->data;
-        double* weights = nt->weights;
-        Datum* indexes = ml->pdata;
-        ThreadDatum* thread = ml->_thread;
-    )CODE");
-    if (need_mech_inst) {
-        printer->fmt_line("auto* const inst = static_cast<{0}*>(ml->instance);", instance_struct());
-    }
-
-    if (node.is_initial_block()) {
-        print_net_init_acc_serial_annotation_block_begin();
-    }
-
-    // rename variables but need to see if they are actually used
-    auto parameters = info.net_receive_node->get_parameters();
-    if (!parameters.empty()) {
-        int i = 0;
-        printer->add_newline();
-        for (auto& parameter: parameters) {
-            auto name = parameter->get_node_name();
-            bool var_used = VarUsageVisitor().variable_used(node, "(*" + name + ")");
-            if (var_used) {
-                printer->fmt_line("double* {} = weights + weight_index + {};", name, i);
-                RenameVisitor vr(name, "*" + name);
-                node.visit_children(vr);
-            }
-            i++;
-        }
-    }
-}
-
-
-void CodegenCppVisitor::print_net_send_call(const FunctionCall& node) {
-    auto const& arguments = node.get_arguments();
-    const auto& tqitem = get_variable_name("tqitem");
-    std::string weight_index = "weight_index";
-    std::string pnt = "pnt";
-
-    // for functions not generated from NET_RECEIVE blocks (i.e. top level INITIAL block)
-    // the weight_index argument is 0.
-    if (!printing_net_receive && !printing_net_init) {
-        weight_index = "0";
-        auto var = get_variable_name("point_process");
-        if (info.artificial_cell) {
-            pnt = "(Point_process*)" + var;
-        }
-    }
-
-    // artificial cells don't use spike buffering
-    // clang-format off
-    if (info.artificial_cell) {
-        printer->fmt_text("artcell_net_send(&{}, {}, {}, nt->_t+", tqitem, weight_index, pnt);
-    } else {
-        const auto& point_process = get_variable_name("point_process");
-        const auto& t = get_variable_name("t");
-        printer->add_text("net_send_buffering(");
-        printer->fmt_text("nt, ml->_net_send_buffer, 0, {}, {}, {}, {}+", tqitem, weight_index, point_process, t);
-    }
-    // clang-format on
-    print_vector_elements(arguments, ", ");
-    printer->add_text(')');
-}
-
-
-void CodegenCppVisitor::print_net_move_call(const FunctionCall& node) {
-    if (!printing_net_receive && !printing_net_init) {
-        throw std::runtime_error("Error : net_move only allowed in NET_RECEIVE block");
-    }
-
-    auto const& arguments = node.get_arguments();
-    const auto& tqitem = get_variable_name("tqitem");
-    std::string weight_index = "-1";
-    std::string pnt = "pnt";
-
-    // artificial cells don't use spike buffering
-    // clang-format off
-    if (info.artificial_cell) {
-        printer->fmt_text("artcell_net_move(&{}, {}, ", tqitem, pnt);
-        print_vector_elements(arguments, ", ");
-        printer->add_text(")");
-    } else {
-        const auto& point_process = get_variable_name("point_process");
-        printer->add_text("net_send_buffering(");
-        printer->fmt_text("nt, ml->_net_send_buffer, 2, {}, {}, {}, ", tqitem, weight_index, point_process);
-        print_vector_elements(arguments, ", ");
-        printer->add_text(", 0.0");
-        printer->add_text(")");
-    }
-}
-
-
-void CodegenCppVisitor::print_net_event_call(const FunctionCall& node) {
-    const auto& arguments = node.get_arguments();
-    if (info.artificial_cell) {
-        printer->add_text("net_event(pnt, ");
-        print_vector_elements(arguments, ", ");
-    } else {
-        const auto& point_process = get_variable_name("point_process");
-        printer->add_text("net_send_buffering(");
-        printer->fmt_text("nt, ml->_net_send_buffer, 1, -1, -1, {}, ", point_process);
-        print_vector_elements(arguments, ", ");
-        printer->add_text(", 0.0");
-    }
-    printer->add_text(")");
-}
-
-/**
- * Rename arguments to NET_RECEIVE block with corresponding pointer variable
- *
- * Arguments to NET_RECEIVE block are packed and passed via weight vector. These
- * variables need to be replaced with corresponding pointer variable. For example,
- * if mod file is like
- *
- * \code{.mod}
- *      NET_RECEIVE (weight, R){
- *          INITIAL {
- *              R=1
- *          }
- *      }
- * \endcode
- *
- * then generated code for initial block should be:
- *
- * \code{.cpp}
- *      double* R = weights + weight_index + 0;
- *      (*R) = 1.0;
- * \endcode
- *
- * So, the `R` in AST needs to be renamed with `(*R)`.
- */ -static void rename_net_receive_arguments(const ast::NetReceiveBlock& net_receive_node, const ast::Node& node) { - const auto& parameters = net_receive_node.get_parameters(); - for (auto& parameter: parameters) { - const auto& name = parameter->get_node_name(); - auto var_used = VarUsageVisitor().variable_used(node, name); - if (var_used) { - RenameVisitor vr(name, "(*" + name + ")"); - node.get_statement_block()->visit_children(vr); - } - } -} - - -void CodegenCppVisitor::print_net_init() { - const auto node = info.net_receive_initial_node; - if (node == nullptr) { - return; - } - - // rename net_receive arguments used in the initial block of net_receive - rename_net_receive_arguments(*info.net_receive_node, *node); - - codegen = true; - printing_net_init = true; - auto args = "Point_process* pnt, int weight_index, double flag"; - printer->add_newline(2); - printer->add_line("/** initialize block for net receive */"); - printer->fmt_push_block("static void net_init({})", args); - auto block = node->get_statement_block().get(); - if (block->get_statements().empty()) { - printer->add_line("// do nothing"); - } else { - print_net_receive_common_code(*node); - print_statement_block(*block, false, false); - if (node->is_initial_block()) { - print_net_init_acc_serial_annotation_block_end(); - print_kernel_data_present_annotation_block_end(); - printer->add_line("auto& nsb = ml->_net_send_buffer;"); - print_net_send_buf_update_to_host(); - } - } - printer->pop_block(); - codegen = false; - printing_net_init = false; -} - - -void CodegenCppVisitor::print_send_event_move() { - printer->add_newline(); - printer->add_line("NetSendBuffer_t* nsb = ml->_net_send_buffer;"); - print_net_send_buf_update_to_host(); - printer->push_block("for (int i=0; i < nsb->_cnt; i++)"); - printer->add_multi_line(R"CODE( - int type = nsb->_sendtype[i]; - int tid = nt->id; - double t = nsb->_nsb_t[i]; - double flag = nsb->_nsb_flag[i]; - int vdata_index = nsb->_vdata_index[i]; - int weight_index = nsb->_weight_index[i]; - int point_index = nsb->_pnt_index[i]; - net_sem_from_gpu(type, vdata_index, weight_index, tid, point_index, t, flag); - )CODE"); - printer->pop_block(); - printer->add_line("nsb->_cnt = 0;"); - print_net_send_buf_count_update_to_device(); -} - - -std::string CodegenCppVisitor::net_receive_buffering_declaration() { - return fmt::format("void {}(NrnThread* nt)", method_name("net_buf_receive")); -} - - -void CodegenCppVisitor::print_get_memb_list() { - printer->add_line("Memb_list* ml = get_memb_list(nt);"); - printer->push_block("if (!ml)"); - printer->add_line("return;"); - printer->pop_block(); - printer->add_newline(); -} - - -void CodegenCppVisitor::print_net_receive_loop_begin() { - printer->add_line("int count = nrb->_displ_cnt;"); - print_channel_iteration_block_parallel_hint(BlockType::NetReceive, info.net_receive_node); - printer->push_block("for (int i = 0; i < count; i++)"); -} - - -void CodegenCppVisitor::print_net_receive_loop_end() { - printer->pop_block(); -} - - -void CodegenCppVisitor::print_net_receive_buffering(bool need_mech_inst) { - if (!net_receive_required() || info.artificial_cell) { - return; - } - printer->add_newline(2); - printer->push_block(net_receive_buffering_declaration()); - - print_get_memb_list(); - - const auto& net_receive = method_name("net_receive_kernel"); - - print_kernel_data_present_annotation_block_begin(); - - printer->add_line("NetReceiveBuffer_t* nrb = ml->_net_receive_buffer;"); - if (need_mech_inst) { - printer->fmt_line("auto* const inst = 
static_cast<{0}*>(ml->instance);", instance_struct()); - } - print_net_receive_loop_begin(); - printer->add_line("int start = nrb->_displ[i];"); - printer->add_line("int end = nrb->_displ[i+1];"); - printer->push_block("for (int j = start; j < end; j++)"); - printer->add_multi_line(R"CODE( - int index = nrb->_nrb_index[j]; - int offset = nrb->_pnt_index[index]; - double t = nrb->_nrb_t[index]; - int weight_index = nrb->_weight_index[index]; - double flag = nrb->_nrb_flag[index]; - Point_process* point_process = nt->pntprocs + offset; - )CODE"); - printer->add_line(net_receive, "(t, point_process, inst, nt, ml, weight_index, flag);"); - printer->pop_block(); - print_net_receive_loop_end(); - - print_device_stream_wait(); - printer->add_line("nrb->_displ_cnt = 0;"); - printer->add_line("nrb->_cnt = 0;"); - - if (info.net_send_used || info.net_event_used) { - print_send_event_move(); - } - - print_kernel_data_present_annotation_block_end(); - printer->pop_block(); -} - -void CodegenCppVisitor::print_net_send_buffering_cnt_update() const { - printer->add_line("i = nsb->_cnt++;"); -} - -void CodegenCppVisitor::print_net_send_buffering_grow() { - printer->push_block("if (i >= nsb->_size)"); - printer->add_line("nsb->grow();"); - printer->pop_block(); -} - -void CodegenCppVisitor::print_net_send_buffering() { - if (!net_send_buffer_required()) { - return; - } - - printer->add_newline(2); - print_device_method_annotation(); - auto args = - "const NrnThread* nt, NetSendBuffer_t* nsb, int type, int vdata_index, " - "int weight_index, int point_index, double t, double flag"; - printer->fmt_push_block("static inline void net_send_buffering({})", args); - printer->add_line("int i = 0;"); - print_net_send_buffering_cnt_update(); - print_net_send_buffering_grow(); - printer->push_block("if (i < nsb->_size)"); - printer->add_multi_line(R"CODE( - nsb->_sendtype[i] = type; - nsb->_vdata_index[i] = vdata_index; - nsb->_weight_index[i] = weight_index; - nsb->_pnt_index[i] = point_index; - nsb->_nsb_t[i] = t; - nsb->_nsb_flag[i] = flag; - )CODE"); - printer->pop_block(); - printer->pop_block(); -} - - -void CodegenCppVisitor::visit_for_netcon(const ast::ForNetcon& node) { - // For_netcon should take the same arguments as net_receive and apply the operations - // in the block to the weights of the netcons. Since all the weights are on the same vector, - // weights, we have a mask of operations that we apply iteratively, advancing the offset - // to the next netcon. 
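// Sketch of the generated output (hypothetical MOD input and hypothetical semantic
// index): for a block FOR_NETCON(w) { w = w + 1 } whose FOR_NETCON semantic sits at
// index 3, the printer calls below would emit roughly the following C++:
//
//     const size_t offset = 3*pnodecount + id;
//     const size_t for_netcon_start = nt->_fornetcon_perm_indices[indexes[offset]];
//     const size_t for_netcon_end = nt->_fornetcon_perm_indices[indexes[offset] + 1];
//     for (auto i = for_netcon_start; i < for_netcon_end; ++i) {
//         weights[0 + nt->_fornetcon_weight_perm[i]] = weights[0 + nt->_fornetcon_weight_perm[i]] + 1.0;
//     }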
- const auto& args = node.get_parameters(); - RenameVisitor v; - const auto& statement_block = node.get_statement_block(); - for (size_t i_arg = 0; i_arg < args.size(); ++i_arg) { - // sanitize node_name since we want to substitute names like (*w) as they are - auto old_name = - std::regex_replace(args[i_arg]->get_node_name(), regex_special_chars, R"(\$&)"); - const auto& new_name = fmt::format("weights[{} + nt->_fornetcon_weight_perm[i]]", i_arg); - v.set(old_name, new_name); - statement_block->accept(v); - } - - const auto index = - std::find_if(info.semantics.begin(), info.semantics.end(), [](const IndexSemantics& a) { - return a.name == naming::FOR_NETCON_SEMANTIC; - })->index; - - printer->fmt_text("const size_t offset = {}*pnodecount + id;", index); - printer->add_newline(); - printer->add_line( - "const size_t for_netcon_start = nt->_fornetcon_perm_indices[indexes[offset]];"); - printer->add_line( - "const size_t for_netcon_end = nt->_fornetcon_perm_indices[indexes[offset] + 1];"); - - printer->add_line("for (auto i = for_netcon_start; i < for_netcon_end; ++i) {"); - printer->increase_indent(); - print_statement_block(*statement_block, false, false); - printer->decrease_indent(); - - printer->add_line("}"); -} - -void CodegenCppVisitor::print_net_receive_kernel() { - if (!net_receive_required()) { - return; - } - codegen = true; - printing_net_receive = true; - const auto node = info.net_receive_node; - - // rename net_receive arguments used in the block itself - rename_net_receive_arguments(*info.net_receive_node, *node); - - std::string name; - ParamVector params; - if (!info.artificial_cell) { - name = method_name("net_receive_kernel"); - params.emplace_back("", "double", "", "t"); - params.emplace_back("", "Point_process*", "", "pnt"); - params.emplace_back("", fmt::format("{}*", instance_struct()), - "", "inst"); - params.emplace_back("", "NrnThread*", "", "nt"); - params.emplace_back("", "Memb_list*", "", "ml"); - params.emplace_back("", "int", "", "weight_index"); - params.emplace_back("", "double", "", "flag"); - } else { - name = method_name("net_receive"); - params.emplace_back("", "Point_process*", "", "pnt"); - params.emplace_back("", "int", "", "weight_index"); - params.emplace_back("", "double", "", "flag"); - } - - printer->add_newline(2); - printer->fmt_push_block("static inline void {}({})", name, get_parameter_str(params)); - print_net_receive_common_code(*node, info.artificial_cell); - if (info.artificial_cell) { - printer->add_line("double t = nt->_t;"); - } - - // set voltage variable if it is used in the block (e.g. 
for WATCH statement) - auto v_used = VarUsageVisitor().variable_used(*node->get_statement_block(), "v"); - if (v_used) { - printer->add_line("int node_id = ml->nodeindices[id];"); - printer->add_line("v = nt->_actual_v[node_id];"); - } - - printer->fmt_line("{} = t;", get_variable_name("tsave")); - - if (info.is_watch_used()) { - printer->add_line("bool watch_remove = false;"); - } - - printer->add_indent(); - node->get_statement_block()->accept(*this); - printer->add_newline(); - printer->pop_block(); - - printing_net_receive = false; - codegen = false; -} - - -void CodegenCppVisitor::print_net_receive() { - if (!net_receive_required()) { - return; - } - codegen = true; - printing_net_receive = true; - if (!info.artificial_cell) { - const auto& name = method_name("net_receive"); - ParamVector params; - params.emplace_back("", "Point_process*", "", "pnt"); - params.emplace_back("", "int", "", "weight_index"); - params.emplace_back("", "double", "", "flag"); - printer->add_newline(2); - printer->fmt_push_block("static void {}({})", name, get_parameter_str(params)); - printer->add_line("NrnThread* nt = nrn_threads + pnt->_tid;"); - printer->add_line("Memb_list* ml = get_memb_list(nt);"); - printer->add_line("NetReceiveBuffer_t* nrb = ml->_net_receive_buffer;"); - printer->push_block("if (nrb->_cnt >= nrb->_size)"); - printer->add_line("realloc_net_receive_buffer(nt, ml);"); - printer->pop_block(); - printer->add_multi_line(R"CODE( - int id = nrb->_cnt; - nrb->_pnt_index[id] = pnt-nt->pntprocs; - nrb->_weight_index[id] = weight_index; - nrb->_nrb_t[id] = nt->_t; - nrb->_nrb_flag[id] = flag; - nrb->_cnt++; - )CODE"); - printer->pop_block(); - } - printing_net_receive = false; - codegen = false; -} - - -/** - * \todo Data is not derived. Need to add instance into instance struct? - * data used here is wrong in AoS because as in original implementation, - * data is not incremented every iteration for AoS. May be better to derive - * actual variable names? [resolved now?] 
* slist needs to be added as local variable
- */
-void CodegenCppVisitor::print_derivimplicit_kernel(const Block& block) {
-    auto ext_args = external_method_arguments();
-    auto ext_params = external_method_parameters();
-    auto suffix = info.mod_suffix;
-    auto list_num = info.derivimplicit_list_num;
-    auto block_name = block.get_node_name();
-    auto primes_size = info.primes_size;
-    auto stride = "*pnodecount+id";
-
-    printer->add_newline(2);
-
-    printer->push_block("namespace");
-    printer->fmt_push_block("struct _newton_{}_{}", block_name, info.mod_suffix);
-    printer->fmt_push_block("int operator()({}) const", external_method_parameters());
-    auto const instance = fmt::format("auto* const inst = static_cast<{0}*>(ml->instance);",
-                                      instance_struct());
-    auto const slist1 = fmt::format("auto const& slist{} = {};",
-                                    list_num,
-                                    get_variable_name(fmt::format("slist{}", list_num)));
-    auto const slist2 = fmt::format("auto& slist{} = {};",
-                                    list_num + 1,
-                                    get_variable_name(fmt::format("slist{}", list_num + 1)));
-    auto const dlist1 = fmt::format("auto const& dlist{} = {};",
-                                    list_num,
-                                    get_variable_name(fmt::format("dlist{}", list_num)));
-    auto const dlist2 = fmt::format(
-        "double* dlist{} = static_cast<double*>(thread[dith{}()].pval) + ({}*pnodecount);",
-        list_num + 1,
-        list_num,
-        info.primes_size);
-    printer->add_line(instance);
-    if (ion_variable_struct_required()) {
-        print_ion_variable();
-    }
-    printer->fmt_line("double* savstate{} = static_cast<double*>(thread[dith{}()].pval);",
-                      list_num,
-                      list_num);
-    printer->add_line(slist1);
-    printer->add_line(dlist1);
-    printer->add_line(dlist2);
-    codegen = true;
-    print_statement_block(*block.get_statement_block(), false, false);
-    codegen = false;
-    printer->add_line("int counter = -1;");
-    printer->fmt_push_block("for (int i=0; i<{}; i++)", info.num_primes);
-    printer->fmt_push_block("if (*deriv{}_advance(thread))", list_num);
-    printer->fmt_line(
-        "dlist{0}[(++counter){1}] = "
-        "data[dlist{2}[i]{1}]-(data[slist{2}[i]{1}]-savstate{2}[i{1}])/nt->_dt;",
-        list_num + 1,
-        stride,
-        list_num);
-    printer->chain_block("else");
-    printer->fmt_line("dlist{0}[(++counter){1}] = data[slist{2}[i]{1}]-savstate{2}[i{1}];",
-                      list_num + 1,
-                      stride,
-                      list_num);
-    printer->pop_block();
-    printer->pop_block();
-    printer->add_line("return 0;");
-    printer->pop_block();  // operator()
-    printer->pop_block(";");  // struct
-    printer->pop_block();  // namespace
-    printer->add_newline();
-    printer->fmt_push_block("int {}_{}({})", block_name, suffix, ext_params);
-    printer->add_line(instance);
-    printer->fmt_line("double* savstate{} = (double*) thread[dith{}()].pval;", list_num, list_num);
-    printer->add_line(slist1);
-    printer->add_line(slist2);
-    printer->add_line(dlist2);
-    printer->fmt_push_block("for (int i=0; i<{}; i++)", info.num_primes);
-    printer->fmt_line("savstate{}[i{}] = data[slist{}[i]{}];", list_num, stride, list_num, stride);
-    printer->pop_block();
-    printer->fmt_line(
-        "int reset = nrn_newton_thread(static_cast<NewtonSpace*>(*newtonspace{}(thread)), {}, "
-        "slist{}, _newton_{}_{}{{}}, dlist{}, {});",
-        list_num,
-        primes_size,
-        list_num + 1,
-        block_name,
-        suffix,
-        list_num + 1,
-        ext_args);
-    printer->add_line("return reset;");
-    printer->pop_block();
-    printer->add_newline(2);
-}
-
-
-void CodegenCppVisitor::print_newtonspace_transfer_to_device() const {
-    // nothing to do on cpu
-}
-
-
-void CodegenCppVisitor::visit_derivimplicit_callback(const ast::DerivimplicitCallback& node) {
-    if (!codegen) {
-        return;
-    }
-    printer->fmt_line("{}_{}({});",
-
node.get_node_to_solve()->get_node_name(),
-                      info.mod_suffix,
-                      external_method_arguments());
-}
-
-void CodegenCppVisitor::visit_solution_expression(const SolutionExpression& node) {
-    auto block = node.get_node_to_solve().get();
-    if (block->is_statement_block()) {
-        auto statement_block = dynamic_cast<ast::StatementBlock*>(block);
-        print_statement_block(*statement_block, false, false);
-    } else {
-        block->accept(*this);
-    }
-}
-
-
-/****************************************************************************************/
-/*                              Print nrn_state routine                                 */
-/****************************************************************************************/
-
-
-void CodegenCppVisitor::print_nrn_state() {
-    if (!nrn_state_required()) {
-        return;
-    }
-    codegen = true;
-
-    printer->add_newline(2);
-    printer->add_line("/** update state */");
-    print_global_function_common_code(BlockType::State);
-    print_channel_iteration_block_parallel_hint(BlockType::State, info.nrn_state_block);
-    printer->push_block("for (int id = 0; id < nodecount; id++)");
-
-    printer->add_line("int node_id = node_index[id];");
-    printer->add_line("double v = voltage[node_id];");
-    print_v_unused();
-
-    /**
-     * \todo Eigen solver node also emits IonCurVar variable in the functor
-     * but that shouldn't update ions in derivative block
-     */
-    if (ion_variable_struct_required()) {
-        print_ion_variable();
-    }
-
-    auto read_statements = ion_read_statements(BlockType::State);
-    for (auto& statement: read_statements) {
-        printer->add_line(statement);
-    }
-
-    if (info.nrn_state_block) {
-        info.nrn_state_block->visit_children(*this);
-    }
-
-    if (info.currents.empty() && info.breakpoint_node != nullptr) {
-        auto block = info.breakpoint_node->get_statement_block();
-        print_statement_block(*block, false, false);
-    }
-
-    const auto& write_statements = ion_write_statements(BlockType::State);
-    for (auto& statement: write_statements) {
-        const auto& text = process_shadow_update_statement(statement, BlockType::State);
-        printer->add_line(text);
-    }
-    printer->pop_block();
-
-    print_kernel_data_present_annotation_block_end();
-
-    printer->pop_block();
-    codegen = false;
-}
-
-
-/****************************************************************************************/
-/*                            Print nrn_cur related routines                            */
-/****************************************************************************************/
-
-
-void CodegenCppVisitor::print_nrn_current(const BreakpointBlock& node) {
-    const auto& args = internal_method_parameters();
-    const auto& block = node.get_statement_block();
-    printer->add_newline(2);
-    print_device_method_annotation();
-    printer->fmt_push_block("inline double nrn_current_{}({})",
-                            info.mod_suffix,
-                            get_parameter_str(args));
-    printer->add_line("double current = 0.0;");
-    print_statement_block(*block, false, false);
-    for (auto& current: info.currents) {
-        const auto& name = get_variable_name(current);
-        printer->fmt_line("current += {};", name);
-    }
-    printer->add_line("return current;");
-    printer->pop_block();
-}
-
-
-void CodegenCppVisitor::print_nrn_cur_conductance_kernel(const BreakpointBlock& node) {
-    const auto& block = node.get_statement_block();
-    print_statement_block(*block, false, false);
-    if (!info.currents.empty()) {
-        std::string sum;
-        for (const auto& current: info.currents) {
-            auto var = breakpoint_current(current);
-            sum += get_variable_name(var);
-            if (&current != &info.currents.back()) {
-                sum += "+";
-            }
-        }
-        printer->fmt_line("double rhs = {};", sum);
-    }
-
-    std::string sum;
-    for (const auto& conductance: info.conductances) {
-        auto
var = breakpoint_current(conductance.variable); - sum += get_variable_name(var); - if (&conductance != &info.conductances.back()) { - sum += "+"; - } - } - printer->fmt_line("double g = {};", sum); - - for (const auto& conductance: info.conductances) { - if (!conductance.ion.empty()) { - const auto& lhs = std::string(naming::ION_VARNAME_PREFIX) + "di" + conductance.ion + "dv"; - const auto& rhs = get_variable_name(conductance.variable); - const ShadowUseStatement statement{lhs, "+=", rhs}; - const auto& text = process_shadow_update_statement(statement, BlockType::Equation); - printer->add_line(text); - } - } -} - - -void CodegenCppVisitor::print_nrn_cur_non_conductance_kernel() { - printer->fmt_line("double g = nrn_current_{}({}+0.001);", - info.mod_suffix, - internal_method_arguments()); - for (auto& ion: info.ions) { - for (auto& var: ion.writes) { - if (ion.is_ionic_current(var)) { - const auto& name = get_variable_name(var); - printer->fmt_line("double di{} = {};", ion.name, name); - } - } - } - printer->fmt_line("double rhs = nrn_current_{}({});", - info.mod_suffix, - internal_method_arguments()); - printer->add_line("g = (g-rhs)/0.001;"); - for (auto& ion: info.ions) { - for (auto& var: ion.writes) { - if (ion.is_ionic_current(var)) { - const auto& lhs = std::string(naming::ION_VARNAME_PREFIX) + "di" + ion.name + "dv"; - auto rhs = fmt::format("(di{}-{})/0.001", ion.name, get_variable_name(var)); - if (info.point_process) { - auto area = get_variable_name(naming::NODE_AREA_VARIABLE); - rhs += fmt::format("*1.e2/{}", area); - } - const ShadowUseStatement statement{lhs, "+=", rhs}; - const auto& text = process_shadow_update_statement(statement, BlockType::Equation); - printer->add_line(text); - } - } - } -} - - -void CodegenCppVisitor::print_nrn_cur_kernel(const BreakpointBlock& node) { - printer->add_line("int node_id = node_index[id];"); - printer->add_line("double v = voltage[node_id];"); - print_v_unused(); - if (ion_variable_struct_required()) { - print_ion_variable(); - } - - const auto& read_statements = ion_read_statements(BlockType::Equation); - for (auto& statement: read_statements) { - printer->add_line(statement); - } - - if (info.conductances.empty()) { - print_nrn_cur_non_conductance_kernel(); - } else { - print_nrn_cur_conductance_kernel(node); - } - - const auto& write_statements = ion_write_statements(BlockType::Equation); - for (auto& statement: write_statements) { - auto text = process_shadow_update_statement(statement, BlockType::Equation); - printer->add_line(text); - } - - if (info.point_process) { - const auto& area = get_variable_name(naming::NODE_AREA_VARIABLE); - printer->fmt_line("double mfactor = 1.e2/{};", area); - printer->add_line("g = g*mfactor;"); - printer->add_line("rhs = rhs*mfactor;"); - } - - print_g_unused(); -} - -void CodegenCppVisitor::print_fast_imem_calculation() { - if (!info.electrode_current) { - return; - } - std::string rhs, d; - auto rhs_op = operator_for_rhs(); - auto d_op = operator_for_d(); - if (info.point_process) { - rhs = "shadow_rhs[id]"; - d = "shadow_d[id]"; - } else { - rhs = "rhs"; - d = "g"; - } - - printer->push_block("if (nt->nrn_fast_imem)"); - if (nrn_cur_reduction_loop_required()) { - printer->push_block("for (int id = 0; id < nodecount; id++)"); - printer->add_line("int node_id = node_index[id];"); - } - printer->fmt_line("nt->nrn_fast_imem->nrn_sav_rhs[node_id] {} {};", rhs_op, rhs); - printer->fmt_line("nt->nrn_fast_imem->nrn_sav_d[node_id] {} {};", d_op, d); - if (nrn_cur_reduction_loop_required()) { - 
printer->pop_block(); - } - printer->pop_block(); -} - -void CodegenCppVisitor::print_nrn_cur() { - if (!nrn_cur_required()) { - return; - } - - codegen = true; - if (info.conductances.empty()) { - print_nrn_current(*info.breakpoint_node); - } - - printer->add_newline(2); - printer->add_line("/** update current */"); - print_global_function_common_code(BlockType::Equation); - print_channel_iteration_block_parallel_hint(BlockType::Equation, info.breakpoint_node); - printer->push_block("for (int id = 0; id < nodecount; id++)"); - print_nrn_cur_kernel(*info.breakpoint_node); - print_nrn_cur_matrix_shadow_update(); - if (!nrn_cur_reduction_loop_required()) { - print_fast_imem_calculation(); - } - printer->pop_block(); - - if (nrn_cur_reduction_loop_required()) { - printer->push_block("for (int id = 0; id < nodecount; id++)"); - print_nrn_cur_matrix_shadow_reduction(); - printer->pop_block(); - print_fast_imem_calculation(); - } - - print_kernel_data_present_annotation_block_end(); - printer->pop_block(); - codegen = false; -} - - -/****************************************************************************************/ -/* Main code printing entry points */ -/****************************************************************************************/ - -void CodegenCppVisitor::print_headers_include() { - print_standard_includes(); - print_backend_includes(); - print_coreneuron_includes(); -} - - -void CodegenCppVisitor::print_namespace_begin() { - print_namespace_start(); - print_backend_namespace_start(); -} - - -void CodegenCppVisitor::print_namespace_end() { - print_backend_namespace_stop(); - print_namespace_stop(); -} - - -void CodegenCppVisitor::print_common_getters() { - print_first_pointer_var_index_getter(); - print_net_receive_arg_size_getter(); - print_thread_getters(); - print_num_variable_getter(); - print_mech_type_getter(); - print_memb_list_getter(); -} - - -void CodegenCppVisitor::print_data_structures(bool print_initializers) { - print_mechanism_global_var_structure(print_initializers); - print_mechanism_range_var_structure(print_initializers); - print_ion_var_structure(); -} - -void CodegenCppVisitor::print_v_unused() const { - if (!info.vectorize) { - return; - } - printer->add_multi_line(R"CODE( - #if NRN_PRCELLSTATE - inst->v_unused[id] = v; - #endif - )CODE"); -} - -void CodegenCppVisitor::print_g_unused() const { - printer->add_multi_line(R"CODE( - #if NRN_PRCELLSTATE - inst->g_unused[id] = g; - #endif - )CODE"); -} - -void CodegenCppVisitor::print_compute_functions() { - print_top_verbatim_blocks(); - for (const auto& procedure: info.procedures) { - print_procedure(*procedure); - } - for (const auto& function: info.functions) { - print_function(*function); - } - for (const auto& function: info.function_tables) { - print_function_tables(*function); - } - for (size_t i = 0; i < info.before_after_blocks.size(); i++) { - print_before_after_block(info.before_after_blocks[i], i); - } - for (const auto& callback: info.derivimplicit_callbacks) { - const auto& block = *callback->get_node_to_solve(); - print_derivimplicit_kernel(block); - } - print_net_send_buffering(); - print_net_init(); - print_watch_activate(); - print_watch_check(); - print_net_receive_kernel(); - print_net_receive(); - print_net_receive_buffering(); - print_nrn_init(); - print_nrn_cur(); - print_nrn_state(); -} - - -void CodegenCppVisitor::print_codegen_routines() { - codegen = true; - print_backend_info(); - print_headers_include(); - print_namespace_begin(); - print_nmodl_constants(); - 
print_prcellstate_macros();
-    print_mechanism_info();
-    print_data_structures(true);
-    print_global_variables_for_hoc();
-    print_common_getters();
-    print_memory_allocation_routine();
-    print_abort_routine();
-    print_thread_memory_callbacks();
-    print_instance_variable_setup();
-    print_nrn_alloc();
-    print_nrn_constructor();
-    print_nrn_destructor();
-    print_function_prototypes();
-    print_functors_definitions();
-    print_compute_functions();
-    print_check_table_thread_function();
-    print_mechanism_register();
-    print_namespace_end();
-    codegen = false;
-}
-
-
-void CodegenCppVisitor::set_codegen_global_variables(const std::vector<SymbolType>& global_vars) {
-    codegen_global_variables = global_vars;
-}
-
-
 void CodegenCppVisitor::setup(const Program& node) {
     program_symtab = node.get_symbol_table();
@@ -4685,7 +881,8 @@ void CodegenCppVisitor::setup(const Program& node) {
     info.mod_file = mod_filename;
 
     if (!info.vectorize) {
-        logger->warn("CodegenCppVisitor : MOD file uses non-thread safe constructs of NMODL");
+        logger->warn(
+            "CodegenCoreneuronCppVisitor : MOD file uses non-thread safe constructs of NMODL");
     }
 
     codegen_float_variables = get_float_variables();
@@ -4702,4 +899,4 @@ void CodegenCppVisitor::visit_program(const Program& node) {
 }
 
 }  // namespace codegen
-}  // namespace nmodl
+}  // namespace nmodl
\ No newline at end of file
diff --git a/src/codegen/codegen_cpp_visitor.hpp b/src/codegen/codegen_cpp_visitor.hpp
index ee0344058c..66016af917 100644
--- a/src/codegen/codegen_cpp_visitor.hpp
+++ b/src/codegen/codegen_cpp_visitor.hpp
@@ -31,7 +31,6 @@
 #include "utils/logger.hpp"
 #include "visitors/ast_visitor.hpp"
 
-
 /// encapsulates code generation backend implementations
 namespace nmodl {
 
@@ -182,6 +181,7 @@ class CodegenCppVisitor: public visitor::ConstAstVisitor {
   protected:
     using SymbolType = std::shared_ptr<symtab::Symbol>;
 
+
    /**
     * A vector of parameters represented by a 4-tuple of strings:
    *
    * ...
    */
    using ParamVector = std::vector<std::tuple<std::string, std::string, std::string, std::string>>;
+
+
+    /****************************************************************************************/
+    /*                                  Member variables                                    */
+    /****************************************************************************************/
+
     /**
      * Code printer object for target (C++)
      */
     std::unique_ptr<CodePrinter> printer;
 
+
     /**
      * Name of mod file (without .mod suffix)
      */
     std::string mod_filename;
 
+
     /**
      * Data type of floating point variables
      */
     std::string float_type = codegen::naming::DEFAULT_FLOAT_TYPE;
 
+
     /**
      * Flag to indicate if visitor should avoid ion variable copies
      */
     bool optimize_ionvar_copies = true;
 
-    /**
-     * Flag to indicate if visitor should print the visited nodes
-     */
-    bool codegen = false;
 
     /**
-     * Variable name should be converted to instance name (but not for function arguments)
+     * All ast information for code generation
      */
-    bool enable_variable_name_lookup = true;
+    codegen::CodegenInfo info;
 
+
     /**
      * Symbol table for the program
      */
     symtab::SymbolTable* program_symtab = nullptr;
 
+
     /**
      * All float variables for the model
      */
     std::vector<SymbolType> codegen_float_variables;
 
+
     /**
      * All int variables for the model
      */
     std::vector<IndexVariableInfo> codegen_int_variables;
 
+
     /**
      * All global variables for the model
      * \todo: this has become different than CodegenInfo
      */
     std::vector<SymbolType> codegen_global_variables;
 
+
+    /**
+     * Flag to indicate if visitor should print the visited nodes
+     */
+    bool codegen = false;
+
+
+    /**
+     * Variable name should be converted to instance name (but not for function arguments)
+     */
+    bool enable_variable_name_lookup = true;
+
+
     /**
      *
\c true if currently net_receive block being printed */ bool printing_net_receive = false; + /** * \c true if currently initial block of net_receive being printed */ @@ -260,20 +282,22 @@ class CodegenCppVisitor: public visitor::ConstAstVisitor { */ bool printing_top_verbatim_blocks = false; + /** * \c true if internal method call was encountered while processing verbatim block */ bool internal_method_call_encountered = false; + /** * Index of watch statement being printed */ int current_watch_statement = 0; - /** - * All ast information for code generation - */ - codegen::CodegenInfo info; + + /****************************************************************************************/ + /* Generic information getters */ + /****************************************************************************************/ /** * Return Nmodl language version @@ -283,31 +307,17 @@ class CodegenCppVisitor: public visitor::ConstAstVisitor { return codegen::naming::NMODL_VERSION; } - /** - * Add quotes to string to be output - * - * \param text The string to be quoted - * \return The same string with double-quotes pre- and postfixed - */ - std::string add_escape_quote(const std::string& text) const { - return "\"" + text + "\""; - } - /** - * Operator for rhs vector update (matrix update) + * Name of the simulator the code was generated for */ - const char* operator_for_rhs() const noexcept { - return info.electrode_current ? "+=" : "-="; - } + virtual std::string simulator_name() = 0; /** - * Operator for diagonal vector update (matrix update) + * Name of the code generation backend */ - const char* operator_for_d() const noexcept { - return info.electrode_current ? "-=" : "+="; - } + virtual std::string backend_name() const = 0; /** @@ -318,6 +328,14 @@ class CodegenCppVisitor: public visitor::ConstAstVisitor { } + /** + * Check if a semicolon is required at the end of given statement + * \param node The AST Statement node to check + * \return \c true if this Statement requires a semicolon + */ + static bool need_semicolon(const ast::Statement& node); + + /** * Default data type for floating point elements */ @@ -343,75 +361,50 @@ class CodegenCppVisitor: public visitor::ConstAstVisitor { /** - * Checks if given function name is \c net_send - * \param name The function name to check - * \return \c true if the function is net_send - */ - bool is_net_send(const std::string& name) const noexcept { - return name == codegen::naming::NET_SEND_METHOD; - } - - /** - * Checks if given function name is \c net_move - * \param name The function name to check - * \return \c true if the function is net_move - */ - bool is_net_move(const std::string& name) const noexcept { - return name == codegen::naming::NET_MOVE_METHOD; - } - - /** - * Checks if given function name is \c net_event - * \param name The function name to check - * \return \c true if the function is net_event + * Operator for rhs vector update (matrix update) */ - bool is_net_event(const std::string& name) const noexcept { - return name == codegen::naming::NET_EVENT_METHOD; + const char* operator_for_rhs() const noexcept { + return info.electrode_current ? "+=" : "-="; } /** - * Name of structure that wraps range variables + * Operator for diagonal vector update (matrix update) */ - std::string instance_struct() const { - return fmt::format("{}_Instance", info.mod_suffix); + const char* operator_for_d() const noexcept { + return info.electrode_current ? 
"-=" : "+="; } - /** - * Name of structure that wraps global variables - */ - std::string global_struct() const { - return fmt::format("{}_Store", info.mod_suffix); - } + /****************************************************************************************/ + /* Common helper routines accross codegen functions */ + /****************************************************************************************/ /** - * Name of the (host-only) global instance of `global_struct` + * Check if function or procedure node has parameter with given name + * + * \tparam T Node type (either procedure or function) + * \param node AST node (either procedure or function) + * \param name Name of parameter + * \return True if argument with name exist */ - std::string global_struct_instance() const { - return info.mod_suffix + "_global"; - } + template + bool has_parameter_of_name(const T& node, const std::string& name); /** - * Constructs the name of a function or procedure - * \param name The name of the function or procedure - * \return The name of the function or procedure postfixed with the model name + * Check if given statement should be skipped during code generation + * \param node The AST Statement node to check + * \return \c true if this Statement is to be skipped */ - std::string method_name(const std::string& name) const { - return name + "_" + info.mod_suffix; - } + static bool statement_to_skip(const ast::Statement& node); /** - * Creates a temporary symbol - * \param name The name of the symbol - * \return A symbol based on the given name + * Check if net_send_buffer is required */ - SymbolType make_symbol(const std::string& name) const { - return std::make_shared(name, ModToken()); - } + bool net_send_buffer_required() const noexcept; /** @@ -438,12 +431,6 @@ class CodegenCppVisitor: public visitor::ConstAstVisitor { bool net_receive_required() const noexcept; - /** - * Check if net_send_buffer is required - */ - bool net_send_buffer_required() const noexcept; - - /** * Check if setup_range_variable function is required * \return @@ -472,39 +459,33 @@ class CodegenCppVisitor: public visitor::ConstAstVisitor { /** - * Check if given statement should be skipped during code generation - * \param node The AST Statement node to check - * \return \c true if this Statement is to be skipped - */ - static bool statement_to_skip(const ast::Statement& node); - - - /** - * Check if a semicolon is required at the end of given statement - * \param node The AST Statement node to check - * \return \c true if this Statement requires a semicolon - */ - static bool need_semicolon(const ast::Statement& node); - - - /** - * Determine the number of threads to allocate + * Checks if given function name is \c net_send + * \param name The function name to check + * \return \c true if the function is net_send */ - int num_thread_objects() const noexcept { - return info.vectorize ? 
(info.thread_data_index + 1) : 0; + bool is_net_send(const std::string& name) const noexcept { + return name == codegen::naming::NET_SEND_METHOD; } /** - * Number of float variables in the model + * Checks if given function name is \c net_move + * \param name The function name to check + * \return \c true if the function is net_move */ - int float_variables_size() const; + bool is_net_move(const std::string& name) const noexcept { + return name == codegen::naming::NET_MOVE_METHOD; + } /** - * Number of integer variables in the model + * Checks if given function name is \c net_event + * \param name The function name to check + * \return \c true if the function is net_event */ - int int_variables_size() const; + bool is_net_event(const std::string& name) const noexcept { + return name == codegen::naming::NET_EVENT_METHOD; + } /** @@ -512,7 +493,7 @@ class CodegenCppVisitor: public visitor::ConstAstVisitor { * \param name The name of a float variable * \return The position index in the data array */ - int position_of_float_var(const std::string& name) const; + virtual int position_of_float_var(const std::string& name) const = 0; /** @@ -520,21 +501,19 @@ class CodegenCppVisitor: public visitor::ConstAstVisitor { * \param name The name of an int variable * \return The position index in the data array */ - int position_of_int_var(const std::string& name) const; + virtual int position_of_int_var(const std::string& name) const = 0; /** - * Determine the updated name if the ion variable has been optimized - * \param name The ion variable name - * \return The updated name of the variable has been optimized (e.g. \c ena --> \c ion_ena) + * Number of float variables in the model */ - std::string update_if_ion_variable_name(const std::string& name) const; + int float_variables_size() const; /** - * Name of the code generation backend + * Number of integer variables in the model */ - virtual std::string backend_name() const; + int int_variables_size() const; /** @@ -542,7 +521,7 @@ class CodegenCppVisitor: public visitor::ConstAstVisitor { * \param value The number to convert given as string as it is parsed by the modfile * \return Its string representation */ - virtual std::string format_double_string(const std::string& value); + std::string format_double_string(const std::string& value); /** @@ -550,70 +529,7 @@ class CodegenCppVisitor: public visitor::ConstAstVisitor { * \param value The number to convert given as string as it is parsed by the modfile * \return Its string representation */ - virtual std::string format_float_string(const std::string& value); - - - /** - * Determine the name of a \c float variable given its symbol - * - * This function typically returns the accessor expression in backend code for the given symbol. - * Since the model variables are stored in data arrays and accessed by offset, this function - * will return the C++ string representing the array access at the correct offset - * - * \param symbol The symbol of a variable for which we want to obtain its name - * \param use_instance Should the variable be accessed via instance or data array - * \return The backend code string representing the access to the given variable - * symbol - */ - std::string float_variable_name(const SymbolType& symbol, bool use_instance) const; - - - /** - * Determine the name of an \c int variable given its symbol - * - * This function typically returns the accessor expression in backend code for the given symbol. 
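// Illustration with assumed names: for a RANGE variable m this accessor is
// typically an instance form such as inst->m[id] when use_instance is true,
// or a raw offset form such as data[k*pnodecount+id] otherwise, where k
// stands for the variable's position in the float variable layout.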
- * Since the model variables are stored in data arrays and accessed by offset, this function
-     * will return the C++ string representing the array access at the correct offset
-     *
-     * \param symbol The symbol of a variable for which we want to obtain its name
-     * \param name The name of the index variable
-     * \param use_instance Should the variable be accessed via instance or data array
-     * \return The backend code string representing the access to the given variable
-     *         symbol
-     */
-    std::string int_variable_name(const IndexVariableInfo& symbol,
-                                  const std::string& name,
-                                  bool use_instance) const;
-
-
-    /**
-     * Determine the variable name for a global variable given its symbol
-     * \param symbol The symbol of a variable for which we want to obtain its name
-     * \param use_instance Should the variable be accessed via the (host-only)
-     *                     global variable or the instance-specific copy (also available on GPU).
-     * \return The C++ string representing the access to the global variable
-     */
-    std::string global_variable_name(const SymbolType& symbol, bool use_instance = true) const;
-
-
-    /**
-     * Determine variable name in the structure of mechanism properties
-     *
-     * \param name Variable name that is being printed
-     * \param use_instance Should the variable be accessed via instance or data array
-     * \return The C++ string representing the access to the variable in the neuron
-     *         thread structure
-     */
-    std::string get_variable_name(const std::string& name, bool use_instance = true) const;
-
-
-    /**
-     * Determine the variable name for the "current" used in breakpoint block taking into account
-     * intermediate code transformations.
-     * \param current The variable name for the current used in the model
-     * \return The name for the current to be printed in C++
-     */
-    std::string breakpoint_current(std::string current) const;
+    std::string format_float_string(const std::string& value);
 
 
     /**
@@ -636,33 +552,20 @@ class CodegenCppVisitor: public visitor::ConstAstVisitor {
     std::vector<IndexVariableInfo> get_int_variables();
 
 
-    /**
-     * Print the items in a vector as a list
-     *
-     * This function prints a given vector of elements as a list with given separator onto the
-     * current printer. Elements are expected to be of type nmodl::ast::Ast and are printed by being
-     * visited. Care is taken to omit the separator after the last element.
-     *
-     * \tparam T The element type in the vector, which must be of type nmodl::ast::Ast
-     * \param elements The vector of elements to be printed
-     * \param separator The separator string to print between all elements
-     * \param prefix A prefix string to print before each element
-     */
-    template <typename T>
-    void print_vector_elements(const std::vector<T>& elements,
-                               const std::string& separator,
-                               const std::string& prefix = "");
+    /****************************************************************************************/
+    /*                              Backend specific routines                               */
+    /****************************************************************************************/
+
 
     /**
-     * Generate the string representing the procedure parameter declaration
-     *
-     * The procedure parameters are stored in a vector of 4-tuples each representing a parameter.
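// Example drawn from how the net_receive parameters are built earlier in this
// patch: an entry ("", "Point_process*", "", "pnt") contributes
// "Point_process* pnt" to the printed declaration; the four fields are assumed
// to mean qualifier, type, pointer qualifier, and parameter name.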
- *
-     * \param params The parameters that should be concatenated into the function parameter
-     *               declaration
-     * \return The string representing the declaration of function parameters
+     * Print atomic update pragma for reduction statements
      */
-    static std::string get_parameter_str(const ParamVector& params);
+    virtual void print_atomic_reduction_pragma() = 0;
+
+
+    /****************************************************************************************/
+    /*                         Printing routines for code generation                        */
+    /****************************************************************************************/
 
 
     /**
@@ -681,108 +584,95 @@ class CodegenCppVisitor: public visitor::ConstAstVisitor {
 
 
     /**
-     * Check if a structure for ion variables is required
-     * \return \c true if a structure for ion variables must be generated
+     * Print call to internal or external function
+     * \param node The AST node representing a function call
      */
-    bool ion_variable_struct_required() const;
+    virtual void print_function_call(const ast::FunctionCall& node) = 0;
 
 
     /**
-     * Process a verbatim block for possible variable renaming
-     * \param text The verbatim code to be processed
-     * \return The code with all variables renamed as needed
+     * Print function and procedures prototype declaration
      */
-    std::string process_verbatim_text(std::string const& text);
+    virtual void print_function_prototypes() = 0;
 
 
     /**
-     * Process a token in a verbatim block for possible variable renaming
-     * \param token The verbatim token to be processed
-     * \return The code after variable renaming
+     * Print nmodl function or procedure (common code)
+     * \param node the AST node representing the function or procedure in NMODL
+     * \param name the name of the function or procedure
      */
-    std::string process_verbatim_token(const std::string& token);
+    virtual void print_function_or_procedure(const ast::Block& node, const std::string& name) = 0;
 
 
     /**
-     * Rename function/procedure arguments that conflict with default arguments
+     * Common helper function to help printing function or procedure blocks
+     * \param node the AST node representing the function or procedure in NMODL
      */
-    void rename_function_arguments();
+    virtual void print_function_procedure_helper(const ast::Block& node) = 0;
 
 
     /**
-     * For a given output block type, return statements for all read ion variables
-     *
-     * \param type The type of code block being generated
-     * \return A \c vector of strings representing the reading of ion variables
-     */
-    std::vector<std::string> ion_read_statements(BlockType type) const;
-
-
-    /**
-     * For a given output block type, return minimal statements for all read ion variables
-     *
-     * \param type The type of code block being generated
-     * \return A \c vector of strings representing the reading of ion variables
+     * Print NMODL procedure in target backend code
+     * \param node
      */
-    std::vector<std::string> ion_read_statements_optimized(BlockType type) const;
+    virtual void print_procedure(const ast::ProcedureBlock& node) = 0;
 
 
     /**
-     * For a given output block type, return statements for writing back ion variables
-     *
-     * \param type The type of code block being generated
-     * \return A \c vector of strings representing the write-back of ion variables
+     * Print NMODL function in target backend code
+     * \param node
      */
-    std::vector<std::string> ion_write_statements(BlockType type);
+    virtual void print_function(const ast::FunctionBlock& node) = 0;
 
 
     /**
-     * Return ion variable name and corresponding ion read variable name
-     * \param name The ion variable name
-     * \return The ion read variable name
+     * Rename function/procedure arguments that conflict with default arguments
      */
-    static std::pair<std::string, std::string> read_ion_variable_name(const std::string& name);
+    void rename_function_arguments();
 
 
     /**
-     * Return ion variable name and corresponding ion write variable name
-     * \param name The ion variable name
-     * \return The ion write variable name
+     * Print the items in a vector as a list
+     *
+     * This function prints a given vector of elements as a list with given separator onto the
+     * current printer. Elements are expected to be of type nmodl::ast::Ast and are printed by being
+     * visited. Care is taken to omit the separator after the last element.
+     *
+     * \tparam T The element type in the vector, which must be of type nmodl::ast::Ast
+     * \param elements The vector of elements to be printed
+     * \param separator The separator string to print between all elements
+     * \param prefix A prefix string to print before each element
      */
-    static std::pair<std::string, std::string> write_ion_variable_name(const std::string& name);
+    template <typename T>
+    void print_vector_elements(const std::vector<T>& elements,
+                               const std::string& separator,
+                               const std::string& prefix = "");
 
 
-    /**
-     * Generate Function call statement for nrn_wrote_conc
-     * \param ion_name The name of the ion variable
-     * \param concentration The name of the concentration variable
-     * \param index
-     * \return The string representing the function call
-     */
-    std::string conc_write_statement(const std::string& ion_name,
-                                     const std::string& concentration,
-                                     int index);
+    /****************************************************************************************/
+    /*                            Code-specific helper routines                             */
+    /****************************************************************************************/
 
 
     /**
      * Arguments for functions that are defined and used internally.
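// In the CoreNEURON backend this is typically a fixed argument list along the
// lines of "id, pnodecount, inst, data, indexes, thread, nt, v"; shown here
// only as an illustration of what an override may return.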
* \return the method arguments */ - std::string internal_method_arguments(); + virtual std::string internal_method_arguments() = 0; /** * Parameters for internally defined functions * \return the method parameters */ - ParamVector internal_method_parameters(); + virtual ParamVector internal_method_parameters() = 0; /** * Arguments for external functions called from generated code * \return A string representing the arguments passed to an external function */ - static const char* external_method_arguments() noexcept; + virtual const char* external_method_arguments() noexcept = 0; /** @@ -794,180 +684,161 @@ class CodegenCppVisitor: public visitor::ConstAstVisitor { * \param table * \return A string representing the parameters of the function */ - static const char* external_method_parameters(bool table = false) noexcept; - - - /** - * Arguments for register_mech or point_register_mech function - */ - std::string register_mechanism_arguments() const; + virtual const char* external_method_parameters(bool table = false) noexcept = 0; /** * Arguments for "_threadargs_" macro in neuron implementation */ - std::string nrn_thread_arguments() const; + virtual std::string nrn_thread_arguments() const = 0; /** * Arguments for "_threadargs_" macro in neuron implementation */ - std::string nrn_thread_internal_arguments(); - + virtual std::string nrn_thread_internal_arguments() = 0; /** - * Replace commonly used verbatim variables - * \param name A variable name to be checked and possibly updated - * \return The possibly replace variable name + * Process a verbatim block for possible variable renaming + * \param text The verbatim code to be processed + * \return The code with all variables renamed as needed */ - std::string replace_if_verbatim_variable(std::string name); + virtual std::string process_verbatim_text(std::string const& text) = 0; /** - * Return the name of main compute kernels - * \param type A block type + * Arguments for register_mech or point_register_mech function */ - virtual std::string compute_method_name(BlockType type) const; + virtual std::string register_mechanism_arguments() const = 0; /** - * The used global type qualifier - * - * For C++ code generation this is empty - * \return "" - * - * \return "uniform " - */ - virtual std::string global_var_struct_type_qualifier(); - - /** - * Instantiate global var instance + * Add quotes to string to be output * - * For C++ code generation this is empty - * \return "" - */ - virtual void print_global_var_struct_decl(); - - /** - * Print static assertions about the global variable struct. 
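// Illustrative only (assumed struct name): the printed checks are expected to
// look something like
//     static_assert(std::is_trivially_copy_constructible<hh_Store>::value);
// so that the generated <suffix>_Store struct stays safe to copy byte-wise,
// for example to GPU memory.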
-     */
-    virtual void print_global_var_struct_assertions() const;
-
-    /**
-     * Prints the start of the \c coreneuron namespace
+     * \param text The string to be quoted
+     * \return The same string with double-quotes pre- and postfixed
      */
-    void print_namespace_start();
+    std::string add_escape_quote(const std::string& text) const {
+        return "\"" + text + "\"";
+    }
 
 
     /**
-     * Prints the end of the \c coreneuron namespace
+     * Constructs the name of a function or procedure
+     * \param name The name of the function or procedure
+     * \return The name of the function or procedure postfixed with the model name
      */
-    void print_namespace_stop();
+    std::string method_name(const std::string& name) const {
+        return name + "_" + info.mod_suffix;
+    }
 
 
     /**
-     * Prints the start of namespace for the backend-specific code
-     *
-     * For the C++ backend no additional namespace is required
+     * Creates a temporary symbol
+     * \param name The name of the symbol
+     * \return A symbol based on the given name
      */
-    virtual void print_backend_namespace_start();
+    SymbolType make_symbol(const std::string& name) const {
+        return std::make_shared<symtab::Symbol>(name, ModToken());
+    }
 
 
-    /**
-     * Prints the end of namespace for the backend-specific code
-     *
-     * For the C++ backend no additional namespace is required
-     */
-    virtual void print_backend_namespace_stop();
+    /****************************************************************************************/
+    /*                 Code-specific printing routines for code generation                  */
+    /****************************************************************************************/
 
 
     /**
-     * Print the nmodl constants used in backend code
-     *
-     * Currently we define three basic constants, which are assumed to be present in NMODL, directly
-     * in the backend code:
-     *
-     * \code
-     * static const double FARADAY = 96485.3;
-     * static const double PI = 3.14159;
-     * static const double R = 8.3145;
-     * \endcode
+     * Prints the start of the simulator namespace
      */
-    virtual void print_nmodl_constants();
+    virtual void print_namespace_start() = 0;
 
 
     /**
-     * Print top file header printed in generated code
+     * Prints the end of the simulator namespace
      */
-    void print_backend_info();
+    virtual void print_namespace_stop() = 0;
 
 
-    /**
-     * Print memory allocation routine
-     */
-    virtual void print_memory_allocation_routine() const;
+    /****************************************************************************************/
+    /*                         Routines for returning variable name                         */
+    /****************************************************************************************/
 
 
     /**
-     * Print backend specific abort routine
+     * Determine the name of a \c float variable given its symbol
+     *
+     * This function typically returns the accessor expression in backend code for the given symbol.
+     * Since the model variables are stored in data arrays and accessed by offset, this function
+     * will return the C++ string representing the array access at the correct offset
+     *
+     * \param symbol The symbol of a variable for which we want to obtain its name
+     * \param use_instance Should the variable be accessed via instance or data array
+     * \return The backend code string representing the access to the given variable
+     *         symbol
      */
-    virtual void print_abort_routine() const;
+    virtual std::string float_variable_name(const SymbolType& symbol, bool use_instance) const = 0;
 
 
     /**
-     * Print standard C/C++ includes
+     * Determine the name of an \c int variable given its symbol
+     *
+     * This function typically returns the accessor expression in backend code for the given symbol.
+ * Since the model variables are stored in data arrays and accessed by offset, this function + * will return the C++ string representing the array access at the correct offset + * + * \param symbol The symbol of a variable for which we want to obtain its name + * \param name The name of the index variable + * \param use_instance Should the variable be accessed via instance or data array + * \return The backend code string representing the access to the given variable + * symbol */ - void print_standard_includes(); + virtual std::string int_variable_name(const IndexVariableInfo& symbol, + const std::string& name, + bool use_instance) const = 0; /** - * Print includes from coreneuron + * Determine the variable name for a global variable given its symbol + * \param symbol The symbol of a variable for which we want to obtain its name + * \param use_instance Should the variable be accessed via the (host-only) + * global variable or the instance-specific copy (also available on GPU). + * \return The C++ string representing the access to the global variable */ - void print_coreneuron_includes(); + virtual std::string global_variable_name(const SymbolType& symbol, + bool use_instance = true) const = 0; /** - * Print backend specific includes (none needed for C++ backend) + * Determine variable name in the structure of mechanism properties + * + * \param name Variable name that is being printed + * \param use_instance Should the variable be accessed via instance or data array + * \return The C++ string representing the access to the variable in the neuron + * thread structure */ - virtual void print_backend_includes(); + virtual std::string get_variable_name(const std::string& name, + bool use_instance = true) const = 0; - /** - * Check if ion variable copies should be avoided - */ - bool optimize_ion_variable_copies() const; + /****************************************************************************************/ + /* Main printing routines for code generation */ + /****************************************************************************************/ /** - * Check if reduction block in \c nrn\_cur required + * Print top file header printed in generated code */ - virtual bool nrn_cur_reduction_loop_required(); + virtual void print_backend_info() = 0; /** - * Check if variable is qualified as constant - * \param name The name of variable - * \return \c true if it is constant - */ - virtual bool is_constant_variable(const std::string& name) const; - - /** - * Check if the given name exist in the symbol - * \return \c return a tuple if variable - * is an array otherwise + * Print standard C/C++ includes */ - std::tuple check_if_var_is_array(const std::string& name); + virtual void print_standard_includes() = 0; - /** - * Print declaration of macro NRN_PRCELLSTATE for debugging - */ - void print_prcellstate_macros() const; - /** - * Print backend code for byte array that has mechanism information (to be registered - * with coreneuron) - */ - void print_mechanism_info(); + virtual void print_sdlists_init(bool print_initializers) = 0; /** @@ -976,788 +847,233 @@ class CodegenCppVisitor: public visitor::ConstAstVisitor { * \param print_initializers Whether to include default values in the struct * definition (true: int foo{42}; false: int foo;) */ - void print_mechanism_global_var_structure(bool print_initializers); + virtual void print_mechanism_global_var_structure(bool print_initializers) = 0; /** - * Print structure of ion variables used for local copies + * Print declaration of macro 
NRN_PRCELLSTATE for debugging */ - void print_ion_var_structure(); + void print_prcellstate_macros() const; /** - * Print constructor of ion variables - * \param members The ion variable names + * Print backend code for byte array that has mechanism information (to be registered + * with NEURON/CoreNEURON) */ - virtual void print_ion_var_constructor(const std::vector& members); + void print_mechanism_info(); /** - * Print the ion variable struct + * Print byte arrays that register scalar and vector variables for hoc interface + * */ - virtual void print_ion_variable(); + virtual void print_global_variables_for_hoc() = 0; /** - * Returns floating point type for given range variable symbol - * \param symbol A range variable symbol + * Print the mechanism registration function + * */ - std::string get_range_var_float_type(const SymbolType& symbol); + virtual void print_mechanism_register() = 0; /** - * Print the function that initialize range variable with different data type + * Print common code for global functions like nrn_init, nrn_cur and nrn_state + * \param type The target backend code block type */ - void print_setup_range_variable(); - + virtual void print_global_function_common_code(BlockType type, + const std::string& function_name = "") = 0; - /** - * Print declarations of the functions used by \ref - * print_instance_struct_copy_to_device and \ref - * print_instance_struct_delete_from_device. - */ - virtual void print_instance_struct_transfer_routine_declarations() {} /** - * Print the definitions of the functions used by \ref - * print_instance_struct_copy_to_device and \ref - * print_instance_struct_delete_from_device. Declarations of these functions - * are printed by \ref print_instance_struct_transfer_routine_declarations. - * - * This updates the (pointer) member variables in the device copy of the - * instance struct to contain device pointers, which is why you must pass a - * list of names of those member variables. + * Print nrn_constructor function definition * - * \param ptr_members List of instance struct member names. */ - virtual void print_instance_struct_transfer_routines( - std::vector const& /* ptr_members */) {} + virtual void print_nrn_constructor() = 0; /** - * Transfer the instance struct to the device. This calls a function - * declared by \ref print_instance_struct_transfer_routine_declarations. - */ - virtual void print_instance_struct_copy_to_device() {} - - /** - * Delete the instance struct from the device. This calls a function - * declared by \ref print_instance_struct_transfer_routine_declarations. 
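To illustrate what one of these now pure-virtual hooks emits, a minimal backend override of print_nrn_constructor might print an empty constructor stub (a sketch under assumed names; the emitted signature is not specified by this patch):

    void MyBackendCppVisitor::print_nrn_constructor() {
        printer->add_newline(2);
        printer->push_block("static void nrn_constructor(Prop* _prop)");
        printer->add_line("// nothing to construct for this mechanism");
        printer->pop_block();
    }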
+ * Print nrn_destructor function definition + * */ - virtual void print_instance_struct_delete_from_device() {} + virtual void print_nrn_destructor() = 0; /** - * Print the code to copy derivative advance flag to device + * Print nrn_alloc function definition + * */ - virtual void print_deriv_advance_flag_transfer_to_device() const; - + virtual void print_nrn_alloc() = 0; - /** - * Print the code to update NetSendBuffer_t count from device to host - */ - virtual void print_net_send_buf_count_update_to_host() const; - /** - * Print the code to update NetSendBuffer_t from device to host - */ - virtual void print_net_send_buf_update_to_host() const; + /****************************************************************************************/ + /* Print nrn_state routine */ + /****************************************************************************************/ /** - * Print the code to update NetSendBuffer_t count from host to device + * Print nrn_state / state update function definition */ - virtual void print_net_send_buf_count_update_to_device() const; + virtual void print_nrn_state() = 0; - /** - * Print the code to update dt from host to device - */ - virtual void print_dt_update_to_device() const; - /** - * Print the code to synchronise/wait on stream specific to NrnThread - */ - virtual void print_device_stream_wait() const; + /****************************************************************************************/ + /* Print nrn_cur related routines */ + /****************************************************************************************/ /** - * Print byte arrays that register scalar and vector variables for hoc interface + * Print the \c nrn_current kernel * + * \note nrn_cur_kernel will have two calls to nrn_current if no conductance keywords specified + * \param node the AST node representing the NMODL breakpoint block */ - void print_global_variables_for_hoc(); + virtual void print_nrn_current(const ast::BreakpointBlock& node) = 0; /** - * Print the getter method for thread variables and ids + * Print the \c nrn\_cur kernel with NMODL \c conductance keyword provisions + * + * If the NMODL \c conductance keyword is used in the \c breakpoint block, then + * CodegenCoreneuronCppVisitor::print_nrn_cur_kernel will use this printer * + * \param node the AST node representing the NMODL breakpoint block */ - void print_thread_getters(); + virtual void print_nrn_cur_conductance_kernel(const ast::BreakpointBlock& node) = 0; /** - * Print the getter method for index position of first pointer variable + * Print the \c nrn\_cur kernel without NMODL \c conductance keyword provisions * + * If the NMODL \c conductance keyword is \b not used in the \c breakpoint block, then + * CodegenCoreneuronCppVisitor::print_nrn_cur_kernel will use this printer */ - void print_first_pointer_var_index_getter(); + virtual void print_nrn_cur_non_conductance_kernel() = 0; /** - * Print the getter methods for float and integer variables count - * + * Print main body of nrn_cur function + * \param node the AST node representing the NMODL breakpoint block */ - void print_num_variable_getter(); + virtual void print_nrn_cur_kernel(const ast::BreakpointBlock& node) = 0; /** - * Print the getter method for getting number of arguments for net_receive - * + * Print fast membrane current calculation code */ - void print_net_receive_arg_size_getter(); + virtual void print_fast_imem_calculation() = 0; /** - * Print the getter method for returning membrane list from NrnThread - * + * Print nrn_cur / current update 
function definition */ - void print_memb_list_getter(); + virtual void print_nrn_cur() = 0; - /** - * Print the getter method for returning mechtype - * - */ - void print_mech_type_getter(); + /****************************************************************************************/ + /* Main code printing entry points */ + /****************************************************************************************/ /** - * Print the pragma annotation to update global variables from host to the device + * Print all includes * - * \note This is not used for the C++ backend */ - virtual void print_global_variable_device_update_annotation(); + virtual void print_headers_include() = 0; /** - * Print the setup method for setting matrix shadow vectors + * Print start of namespaces * */ - virtual void print_rhs_d_shadow_variables(); + virtual void print_namespace_begin() = 0; /** - * Print the backend specific device method annotation + * Print end of namespaces * - * \note This is not used for the C++ backend */ - virtual void print_device_method_annotation(); + virtual void print_namespace_end() = 0; /** - * Print backend specific global method annotation - * - * \note This is not used for the C++ backend + * Print all classes + * \param print_initializers Whether to include default values. */ - virtual void print_global_method_annotation(); + virtual void print_data_structures(bool print_initializers) = 0; /** - * Print call to internal or external function - * \param node The AST node representing a function call + * Set v_unused (voltage) for NRN_PRCELLSTATE feature */ - void print_function_call(const ast::FunctionCall& node); + virtual void print_v_unused() const = 0; /** - * Print call to \c net\_send - * \param node The AST node representing the function call + * Set g_unused (conductance) for NRN_PRCELLSTATE feature */ - void print_net_send_call(const ast::FunctionCall& node); + virtual void print_g_unused() const = 0; /** - * Print call to net\_move - * \param node The AST node representing the function call + * Print all compute functions for every backend + * */ - void print_net_move_call(const ast::FunctionCall& node); + virtual void print_compute_functions() = 0; /** - * Print call to net\_event - * \param node The AST node representing the function call + * Print entry point to code generation + * */ - void print_net_event_call(const ast::FunctionCall& node); + virtual void print_codegen_routines() = 0; /** - * Print pragma annotations for channel iterations + * Print the nmodl constants used in backend code * - * This can be overriden by backends to provide additonal annotations or pragmas to enable - * for example SIMD code generation (e.g. 
through \c ivdep) - * The default implementation prints + * Currently we define three basic constants, which are assumed to be present in NMODL, directly + * in the backend code: * * \code - * #pragma ivdep + * static const double FARADAY = 96485.3; + * static const double PI = 3.14159; + * static const double R = 8.3145; * \endcode - * - * \param type The block type - */ - virtual void print_channel_iteration_block_parallel_hint(BlockType type, - const ast::Block* block); - - - /** - * Print accelerator annotations indicating data presence on device - */ - virtual void print_kernel_data_present_annotation_block_begin(); - - - /** - * Print matching block end of accelerator annotations for data presence on device - */ - virtual void print_kernel_data_present_annotation_block_end(); - - - /** - * Print accelerator kernels begin annotation for net_init kernel - */ - virtual void print_net_init_acc_serial_annotation_block_begin(); - - - /** - * Print accelerator kernels end annotation for net_init kernel - */ - virtual void print_net_init_acc_serial_annotation_block_end(); - - - /** - * Print function and procedures prototype declaration */ - void print_function_prototypes(); + void print_nmodl_constants(); - /** - * Print check_table functions - */ - void print_check_table_thread_function(); + /****************************************************************************************/ + /* Protected constructors */ + /****************************************************************************************/ - /** - * Print nmodl function or procedure (common code) - * \param node the AST node representing the function or procedure in NMODL - * \param name the name of the function or procedure - */ - void print_function_or_procedure(const ast::Block& node, const std::string& name); + /// This constructor is private, only the derived classes' public constructors are public + CodegenCppVisitor(std::string mod_filename, + const std::string& output_dir, + std::string float_type, + const bool optimize_ionvar_copies) + : printer(std::make_unique(output_dir + "/" + mod_filename + ".cpp")) + , mod_filename(std::move(mod_filename)) + , float_type(std::move(float_type)) + , optimize_ionvar_copies(optimize_ionvar_copies) {} - /** - * Common helper function to help printing function or procedure blocks - * \param node the AST node representing the function or procedure in NMODL - */ - void print_function_procedure_helper(const ast::Block& node); + /// This constructor is private, only the derived classes' public constructors are public + CodegenCppVisitor(std::string mod_filename, + std::ostream& stream, + std::string float_type, + const bool optimize_ionvar_copies) + : printer(std::make_unique(stream)) + , mod_filename(std::move(mod_filename)) + , float_type(std::move(float_type)) + , optimize_ionvar_copies(optimize_ionvar_copies) {} - /** - * Print thread related memory allocation and deallocation callbacks - */ - void print_thread_memory_callbacks(); - - /** - * Print top level (global scope) verbatim blocks - */ - void print_top_verbatim_blocks(); - - - /** - * Print prototype declarations of functions or procedures - * \tparam T The AST node type of the node (must be of nmodl::ast::Ast or subclass) - * \param node The AST node representing the function or procedure block - * \param name A user defined name for the function - */ - template - void print_function_declaration(const T& node, const std::string& name); - - - /** - * Print initial block statements - * - * Generate the target backend code 
corresponding to the NMODL initial block statements - * - * \param node The AST Node representing a NMODL initial block - */ - void print_initial_block(const ast::InitialBlock* node); - - - /** - * Print initial block in the net receive block - */ - void print_net_init(); - - - /** - * Print the common code section for net receive related methods - * - * \param node The AST node representing the corresponding NMODL block - * \param need_mech_inst \c true if a local \c inst variable needs to be defined in generated - * code - */ - void print_net_receive_common_code(const ast::Block& node, bool need_mech_inst = true); - - - /** - * Print the code related to the update of NetSendBuffer_t cnt. For GPU this needs to be done - * with atomic operation, on CPU it's not needed. - * - */ - virtual void print_net_send_buffering_cnt_update() const; - - - /** - * Print statement that grows NetSendBuffering_t structure if needed. - * This function should be overridden for backends that cannot dynamically reallocate the buffer - */ - virtual void print_net_send_buffering_grow(); - - - /** - * Print kernel for buffering net_send events - * - * This kernel is only needed for accelerator backends where \c net\_send needs to be executed - * in two stages as the actual communication must be done in the host code. - */ - void print_net_send_buffering(); - - - /** - * Print send event move block used in net receive as well as watch - */ - void print_send_event_move(); - - - /** - * Generate the target backend code for the \c net\_receive\_buffering function delcaration - * \return The target code string - */ - virtual std::string net_receive_buffering_declaration(); - - - /** - * Print the target backend code for defining and checking a local \c Memb\_list variable - */ - virtual void print_get_memb_list(); - - - /** - * Print the code for the main \c net\_receive loop - */ - virtual void print_net_receive_loop_begin(); - - - /** - * Print the code for closing the main \c net\_receive loop - */ - virtual void print_net_receive_loop_end(); - - - /** - * Print \c net\_receive function definition - */ - void print_net_receive(); - - - /** - * Print derivative kernel when \c derivimplicit method is used - * - * \param block The corresponding AST node representing an NMODL \c derivimplicit block - */ - void print_derivimplicit_kernel(const ast::Block& block); - - - /** - * Print code block to transfer newtonspace structure to device - */ - virtual void print_newtonspace_transfer_to_device() const; - - - /** - * Print pragma annotation for increase and capture of variable in automatic way - */ - virtual void print_device_atomic_capture_annotation() const; - - /** - * Print atomic update pragma for reduction statements - */ - virtual void print_atomic_reduction_pragma(); - - - /** - * Print all reduction statements - * - */ - void print_shadow_reduction_statements(); - - - /** - * Process shadow update statement - * - * If the statement requires reduction then add it to vector of reduction statement and return - * statement using shadow update - * - * \param statement The statement that might require shadow updates - * \param type The target backend code block type - * \return The generated target backend code - */ - std::string process_shadow_update_statement(const ShadowUseStatement& statement, - BlockType type); - - - /** - * Print main body of nrn_cur function - * \param node the AST node representing the NMODL breakpoint block - */ - void print_nrn_cur_kernel(const ast::BreakpointBlock& node); - - - /** - * Print 
the \c nrn\_cur kernel with NMODL \c conductance keyword provisions - * - * If the NMODL \c conductance keyword is used in the \c breakpoint block, then - * CodegenCppVisitor::print_nrn_cur_kernel will use this printer - * - * \param node the AST node representing the NMODL breakpoint block - */ - void print_nrn_cur_conductance_kernel(const ast::BreakpointBlock& node); - - - /** - * Print the \c nrn\_cur kernel without NMODL \c conductance keyword provisions - * - * If the NMODL \c conductance keyword is \b not used in the \c breakpoint block, then - * CodegenCppVisitor::print_nrn_cur_kernel will use this printer - */ - void print_nrn_cur_non_conductance_kernel(); - - - /** - * Print the \c nrn_current kernel - * - * \note nrn_cur_kernel will have two calls to nrn_current if no conductance keywords specified - * \param node the AST node representing the NMODL breakpoint block - */ - void print_nrn_current(const ast::BreakpointBlock& node); - - - /** - * Print the update to matrix elements with/without shadow vectors - * - */ - virtual void print_nrn_cur_matrix_shadow_update(); - - - /** - * Print the reduction to matrix elements from shadow vectors - * - */ - virtual void print_nrn_cur_matrix_shadow_reduction(); - - - /** - * Print nrn_constructor function definition - * - */ - void print_nrn_constructor(); - - - /** - * Print nrn_destructor function definition - * - */ - void print_nrn_destructor(); - - - /** - * Print nrn_alloc function definition - * - */ - void print_nrn_alloc(); - - - /** - * Print common code for global functions like nrn_init, nrn_cur and nrn_state - * \param type The target backend code block type - */ - virtual void print_global_function_common_code(BlockType type, - const std::string& function_name = ""); - - - /** - * Print the mechanism registration function - * - */ - void print_mechanism_register(); - - - /** - * Print watch activate function - * - */ - void print_watch_activate(); - - - /** - * Print all includes - * - */ - virtual void print_headers_include(); - - - /** - * Print start of namespaces - * - */ - void print_namespace_begin(); - - - /** - * Print end of namespaces - * - */ - void print_namespace_end(); - - - /** - * Print common getters - * - */ - void print_common_getters(); - - - /** - * Print all classes - * \param print_initializers Whether to include default values. - */ - void print_data_structures(bool print_initializers); - - - /** - * Set v_unused (voltage) for NRN_PRCELLSTATE feature - */ - void print_v_unused() const; - - - /** - * Set g_unused (conductance) for NRN_PRCELLSTATE feature - */ - void print_g_unused() const; - - - /** - * Print all compute functions for every backend - * - */ - virtual void print_compute_functions(); - - - /** - * Print entry point to code generation - * - */ - virtual void print_codegen_routines(); - - - public: - /** - * \brief Constructs the C++ code generator visitor - * - * This constructor instantiates an NMODL C++ code generator and allows writing generated code - * directly to a file in \c [output_dir]/[mod_filename].[extension]. - * - * \note No code generation is performed at this stage. Since the code - * generator classes are all based on \c AstVisitor the AST must be visited using e.g. \c - * visit_program in order to generate the C++ code corresponding to the AST. - * - * \param mod_filename The name of the model for which code should be generated. - * It is used for constructing an output filename. - * \param output_dir The directory where target C++ file should be generated. 
- * \param float_type The float type to use in the generated code. The string will be used - * as-is in the target code. This defaults to \c double. - */ - CodegenCppVisitor(std::string mod_filename, - const std::string& output_dir, - std::string float_type, - const bool optimize_ionvar_copies) - : printer(std::make_unique(output_dir + "/" + mod_filename + ".cpp")) - , mod_filename(std::move(mod_filename)) - , float_type(std::move(float_type)) - , optimize_ionvar_copies(optimize_ionvar_copies) {} - - /** - * \copybrief nmodl::codegen::CodegenCppVisitor - * - * This constructor instantiates an NMODL C++ code generator and allows writing generated code - * into an output stream. - * - * \note No code generation is performed at this stage. Since the code - * generator classes are all based on \c AstVisitor the AST must be visited using e.g. \c - * visit_program in order to generate the C++ code corresponding to the AST. - * - * \param mod_filename The name of the model for which code should be generated. - * It is used for constructing an output filename. - * \param stream The output stream onto which to write the generated code - * \param float_type The float type to use in the generated code. The string will be used - * as-is in the target code. This defaults to \c double. - */ - CodegenCppVisitor(std::string mod_filename, - std::ostream& stream, - std::string float_type, - const bool optimize_ionvar_copies) - : printer(std::make_unique(stream)) - , mod_filename(std::move(mod_filename)) - , float_type(std::move(float_type)) - , optimize_ionvar_copies(optimize_ionvar_copies) {} - - /** - * Main and only member function to call after creating an instance of this class. - * \param program the AST to translate to C++ code - */ - void visit_program(const ast::Program& program) override; - - /** - * Print the \c nrn\_init function definition - * \param skip_init_check \c true to generate code executing the initialization conditionally - */ - void print_nrn_init(bool skip_init_check = true); - - - /** - * Print nrn_state / state update function definition - */ - void print_nrn_state(); - - - /** - * Print nrn_cur / current update function definition - */ - void print_nrn_cur(); - - /** - * Print fast membrane current calculation code - */ - virtual void print_fast_imem_calculation(); - - - /** - * Print kernel for buffering net_receive events - * - * This kernel is only needed for accelerator backends where \c net\_receive needs to be - * executed in two stages as the actual communication must be done in the host code. 
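The two-stage pattern can be sketched as follows (simplified; the NetSendBuffer_t field names follow CoreNEURON conventions but are assumptions here, not taken from this patch):

    // stage 1, inside the device kernel: record the event instead of sending it
    void buffer_net_send(NetSendBuffer_t* nsb, int type, int pnt_index, double t) {
        const int i = nsb->_cnt++;  // captured atomically on GPU backends
        nsb->_sendtype[i] = type;
        nsb->_pnt_index[i] = pnt_index;
        nsb->_t[i] = t;
    }
    // stage 2, on the host after the kernel: iterate the buffer and deliver
    // each recorded event via net_send()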
\param - * need_mech_inst \c true if the generated code needs a local inst variable to be defined - */ - void print_net_receive_buffering(bool need_mech_inst = true); - - - /** - * Print \c net\_receive kernel function definition - */ - void print_net_receive_kernel(); - - - /** - * Print watch activate function - */ - void print_watch_check(); - - - /** - * Print \c check\_function() for functions or procedure using table - * \param node The AST node representing a function or procedure block - */ - void print_table_check_function(const ast::Block& node); - - - /** - * Print replacement function for function or procedure using table - * \param node The AST node representing a function or procedure block - */ - void print_table_replacement_function(const ast::Block& node); - - - /** - * Print NMODL function in target backend code - * \param node - */ - void print_function(const ast::FunctionBlock& node); - - - /** - * Print NMODL function_table in target backend code - * \param node - */ - void print_function_tables(const ast::FunctionTableBlock& node); - - - /** - * Print NMODL procedure in target backend code - * \param node - */ - virtual void print_procedure(const ast::ProcedureBlock& node); - - /** - * Print NMODL before / after block in target backend code - * \param node AST node of type before/after type being printed - * \param block_id Index of the before/after block - */ - virtual void print_before_after_block(const ast::Block* node, size_t block_id); - - /** Setup the target backend code generator - * - * Typically called from within \c visit\_program but may be called from - * specialized targets to setup this Code generator as fallback. - */ - void setup(const ast::Program& node); - - - /** - * Set the global variables to be generated in target backend code - * \param global_vars - */ - void set_codegen_global_variables(const std::vector& global_vars); - - /** - * Find unique variable name defined in nmodl::utils::SingletonRandomString by the - * nmodl::visitor::SympySolverVisitor - * \param original_name Original name of variable to change - * \return std::string Unique name produced as [original_name]_[random_string] - */ - std::string find_var_unique_name(const std::string& original_name) const; - - /** - * Print the structure that wraps all range and int variables required for the NMODL - * - * \param print_initializers Whether or not default values for variables - * be included in the struct declaration. 
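For orientation, the wrapper struct printed here comes out roughly like the following for a hypothetical mod file with RANGE gnabar and STATE m (member names depend entirely on the mod file; this is a sketch, not output of this patch):

    struct hh_Instance {
        const double* gnabar{};  // RANGE parameter
        double* m{};             // STATE variable
        double* Dm{};            // derivative of m
        double* v_unused{};      // only used for NRN_PRCELLSTATE debugging
    };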
- */ - void print_mechanism_range_var_structure(bool print_initializers); - - /** - * Print the function that initialize instance structure - */ - void print_instance_variable_setup(); - - /** - * Go through the map of \c EigenNewtonSolverBlock s and their corresponding functor names - * and print the functor definitions before the definitions of the functions of the generated - * file - * - */ - void print_functors_definitions(); - - /** - * \brief Based on the \c EigenNewtonSolverBlock passed print the definition needed for its - * functor - * - * \param node \c EigenNewtonSolverBlock for which to print the functor - */ - void print_functor_definition(const ast::EigenNewtonSolverBlock& node); + /****************************************************************************************/ + /* Overloaded visitor routines */ + /****************************************************************************************/ void visit_binary_expression(const ast::BinaryExpression& node) override; void visit_binary_operator(const ast::BinaryOperator& node) override; @@ -1768,9 +1084,6 @@ class CodegenCppVisitor: public visitor::ConstAstVisitor { void visit_float(const ast::Float& node) override; void visit_from_statement(const ast::FromStatement& node) override; void visit_function_call(const ast::FunctionCall& node) override; - void visit_eigen_newton_solver_block(const ast::EigenNewtonSolverBlock& node) override; - void visit_eigen_linear_solver_block(const ast::EigenLinearSolverBlock& node) override; - virtual void print_eigen_linear_solver(const std::string& float_type, int N); void visit_if_statement(const ast::IfStatement& node) override; void visit_indexed_name(const ast::IndexedName& node) override; void visit_integer(const ast::Integer& node) override; @@ -1780,22 +1093,48 @@ class CodegenCppVisitor: public visitor::ConstAstVisitor { void visit_prime_name(const ast::PrimeName& node) override; void visit_statement_block(const ast::StatementBlock& node) override; void visit_string(const ast::String& node) override; - void visit_solution_expression(const ast::SolutionExpression& node) override; void visit_unary_operator(const ast::UnaryOperator& node) override; void visit_unit(const ast::Unit& node) override; void visit_var_name(const ast::VarName& node) override; void visit_verbatim(const ast::Verbatim& node) override; - void visit_watch_statement(const ast::WatchStatement& node) override; void visit_while_statement(const ast::WhileStatement& node) override; - void visit_derivimplicit_callback(const ast::DerivimplicitCallback& node) override; - void visit_for_netcon(const ast::ForNetcon& node) override; void visit_update_dt(const ast::UpdateDt& node) override; void visit_protect_statement(const ast::ProtectStatement& node) override; void visit_mutex_lock(const ast::MutexLock& node) override; void visit_mutex_unlock(const ast::MutexUnlock& node) override; -}; + public: + /** Setup the target backend code generator + * + * Typically called from within \c visit\_program but may be called from + * specialized targets to setup this Code generator as fallback. + */ + void setup(const ast::Program& node); + + + /** + * Main and only member function to call after creating an instance of this class. 
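A typical driver sequence, sketched assuming the derived classes forward the constructor arguments shown above and that program holds an already-parsed AST:

    CodegenNeuronCppVisitor visitor("hh", output_dir, "double",
                                    /*optimize_ionvar_copies=*/false);
    visitor.visit_program(*program);  // performs setup(), then prints the full C++ file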
+ * \param program the AST to translate to C++ code + */ + void visit_program(const ast::Program& program) override; + + + /****************************************************************************************/ + /* Public printing routines for code generation for use in unit tests */ + /****************************************************************************************/ + + + /** + * Print the structure that wraps all range and int variables required for the NMODL + * + * \param print_initializers Whether or not default values for variables + * be included in the struct declaration. + */ + virtual void print_mechanism_range_var_structure(bool print_initializers) = 0; +}; + +/* Templated functions need to be defined in header file */ template void CodegenCppVisitor::print_vector_elements(const std::vector& elements, const std::string& separator, @@ -1810,60 +1149,7 @@ void CodegenCppVisitor::print_vector_elements(const std::vector& elements, } -/** - * Check if function or procedure node has parameter with given name - * - * \tparam T Node type (either procedure or function) - * \param node AST node (either procedure or function) - * \param name Name of parameter - * \return True if argument with name exist - */ -template -bool has_parameter_of_name(const T& node, const std::string& name) { - auto parameters = node->get_parameters(); - return std::any_of(parameters.begin(), - parameters.end(), - [&name](const decltype(*parameters.begin()) arg) { - return arg->get_node_name() == name; - }); -} - - -/** - * \details If there is an argument with name (say alpha) same as range variable (say alpha), - * we want to avoid it being printed as instance->alpha. And hence we disable variable - * name lookup during prototype declaration. Note that the name of procedure can be - * different in case of table statement. 
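Concretely, for PROCEDURE rates(x) in a file hh.mod this helper prints a prototype along these lines (illustrative; the internal parameter list is backend-specific and abbreviated here):

    inline int rates_hh(int id, int pnodecount, hh_Instance* inst, double* data,
                        Datum* indexes, ThreadDatum* thread, NrnThread* nt,
                        double v, double x);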
- */ -template -void CodegenCppVisitor::print_function_declaration(const T& node, const std::string& name) { - enable_variable_name_lookup = false; - auto type = default_float_data_type(); - - // internal and user provided arguments - auto internal_params = internal_method_parameters(); - const auto& params = node.get_parameters(); - for (const auto& param: params) { - internal_params.emplace_back("", type, "", param.get()->get_node_name()); - } - - // procedures have "int" return type by default - const char* return_type = "int"; - if (node.is_function_block()) { - return_type = default_float_data_type(); - } - - print_device_method_annotation(); - printer->add_indent(); - printer->fmt_text("inline {} {}({})", - return_type, - method_name(name), - get_parameter_str(internal_params)); - - enable_variable_name_lookup = true; -} - /** \} */ // end of codegen_backends } // namespace codegen -} // namespace nmodl +} // namespace nmodl \ No newline at end of file diff --git a/src/codegen/codegen_info.hpp b/src/codegen/codegen_info.hpp index 67cf6c9bb4..6627c05776 100644 --- a/src/codegen/codegen_info.hpp +++ b/src/codegen/codegen_info.hpp @@ -253,6 +253,10 @@ struct CodegenInfo { /// typically equal to number of primes int num_equations = 0; + /// True if we have to emit CVODE code + /// TODO: Figure out when this needs to be true + bool emit_cvode = false; + /// derivative block const ast::BreakpointBlock* breakpoint_node = nullptr; diff --git a/src/codegen/codegen_neuron_cpp_visitor.cpp b/src/codegen/codegen_neuron_cpp_visitor.cpp new file mode 100644 index 0000000000..54abb3b87e --- /dev/null +++ b/src/codegen/codegen_neuron_cpp_visitor.cpp @@ -0,0 +1,687 @@ +/* + * Copyright 2023 Blue Brain Project, EPFL. + * See the top-level LICENSE file for details. 
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+#include "codegen/codegen_neuron_cpp_visitor.hpp"
+
+#include
+#include
+#include
+#include
+#include
+
+#include "ast/all.hpp"
+#include "codegen/codegen_utils.hpp"
+#include "config/config.h"
+#include "utils/string_utils.hpp"
+#include "visitors/visitor_utils.hpp"
+
+namespace nmodl {
+namespace codegen {
+
+using namespace ast;
+
+using symtab::syminfo::NmodlType;
+
+
+/****************************************************************************************/
+/* Generic information getters */
+/****************************************************************************************/
+
+
+std::string CodegenNeuronCppVisitor::simulator_name() {
+    return "NEURON";
+}
+
+
+std::string CodegenNeuronCppVisitor::backend_name() const {
+    return "C++ (api-compatibility)";
+}
+
+
+/****************************************************************************************/
+/* Common helper routines across codegen functions */
+/****************************************************************************************/
+
+
+int CodegenNeuronCppVisitor::position_of_float_var(const std::string& name) const {
+    const auto has_name = [&name](const SymbolType& symbol) { return symbol->get_name() == name; };
+    const auto var_iter =
+        std::find_if(codegen_float_variables.begin(), codegen_float_variables.end(), has_name);
+    if (var_iter != codegen_float_variables.end()) {
+        return var_iter - codegen_float_variables.begin();
+    } else {
+        throw std::logic_error(name + " variable not found");
+    }
+}
+
+
+int CodegenNeuronCppVisitor::position_of_int_var(const std::string& name) const {
+    const auto has_name = [&name](const IndexVariableInfo& index_var_symbol) {
+        return index_var_symbol.symbol->get_name() == name;
+    };
+    const auto var_iter =
+        std::find_if(codegen_int_variables.begin(), codegen_int_variables.end(), has_name);
+    if (var_iter != codegen_int_variables.end()) {
+        return var_iter - codegen_int_variables.begin();
+    } else {
+        throw std::logic_error(name + " variable not found");
+    }
+}
+
+
+/****************************************************************************************/
+/* Backend specific routines */
+/****************************************************************************************/
+
+
+/// TODO: Edit for NEURON
+void CodegenNeuronCppVisitor::print_atomic_reduction_pragma() {
+    return;
+}
+
+
+/****************************************************************************************/
+/* Printing routines for code generation */
+/****************************************************************************************/
+
+
+/// TODO: Edit for NEURON
+void CodegenNeuronCppVisitor::print_function_call(const FunctionCall& node) {
+    return;
+}
+
+
+/// TODO: Edit for NEURON
+void CodegenNeuronCppVisitor::print_function_prototypes() {
+    if (info.functions.empty() && info.procedures.empty()) {
+        return;
+    }
+    codegen = true;
+    /// TODO: Fill in
+    codegen = false;
+}
+
+
+/// TODO: Edit for NEURON
+void CodegenNeuronCppVisitor::print_function_or_procedure(const ast::Block& node,
+                                                          const std::string& name) {
+    return;
+}
+
+
+/// TODO: Edit for NEURON
+void CodegenNeuronCppVisitor::print_function_procedure_helper(const ast::Block& node) {
+    codegen = true;
+    /// TODO: Fill in
+    codegen = false;
+}
+
+
+/// TODO: Edit for NEURON
+void CodegenNeuronCppVisitor::print_procedure(const ast::ProcedureBlock& node) {
+    return;
+}
+
+
+/// TODO: Edit for NEURON
+void CodegenNeuronCppVisitor::print_function(const ast::FunctionBlock& node) {
+    return;
+}
+
+
+/****************************************************************************************/
+/* Code-specific helper routines */
+/****************************************************************************************/
+
+
+/// TODO: Edit for NEURON
+std::string CodegenNeuronCppVisitor::internal_method_arguments() {
+    return {};
+}
+
+
+/// TODO: Edit for NEURON
+CodegenNeuronCppVisitor::ParamVector CodegenNeuronCppVisitor::internal_method_parameters() {
+    return {};
+}
+
+
+/// TODO: Edit for NEURON
+const char* CodegenNeuronCppVisitor::external_method_arguments() noexcept {
+    return {};
+}
+
+
+/// TODO: Edit for NEURON
+const char* CodegenNeuronCppVisitor::external_method_parameters(bool table) noexcept {
+    return {};
+}
+
+
+/// TODO: Edit for NEURON
+std::string CodegenNeuronCppVisitor::nrn_thread_arguments() const {
+    return {};
+}
+
+
+/// TODO: Edit for NEURON
+std::string CodegenNeuronCppVisitor::nrn_thread_internal_arguments() {
+    return {};
+}
+
+
+/// TODO: Write for NEURON
+std::string CodegenNeuronCppVisitor::process_verbatim_text(std::string const& text) {
+    return {};
+}
+
+
+/// TODO: Write for NEURON
+std::string CodegenNeuronCppVisitor::register_mechanism_arguments() const {
+    return {};
+};
+
+
+/****************************************************************************************/
+/* Code-specific printing routines for code generation */
+/****************************************************************************************/
+
+
+void CodegenNeuronCppVisitor::print_namespace_start() {
+    printer->add_newline(2);
+    printer->push_block("namespace neuron");
+}
+
+
+void CodegenNeuronCppVisitor::print_namespace_stop() {
+    printer->pop_block();
+}
+
+
+/****************************************************************************************/
+/* Routines for returning variable name */
+/****************************************************************************************/
+
+
+/// TODO: Edit for NEURON
+std::string CodegenNeuronCppVisitor::float_variable_name(const SymbolType& symbol,
+                                                         bool use_instance) const {
+    return symbol->get_name();
+}
+
+
+/// TODO: Edit for NEURON
+std::string CodegenNeuronCppVisitor::int_variable_name(const IndexVariableInfo& symbol,
+                                                       const std::string& name,
+                                                       bool use_instance) const {
+    return name;
+}
+
+
+/// TODO: Edit for NEURON
+std::string CodegenNeuronCppVisitor::global_variable_name(const SymbolType& symbol,
+                                                          bool use_instance) const {
+    return symbol->get_name();
+}
+
+
+/// TODO: Edit for NEURON
+std::string CodegenNeuronCppVisitor::get_variable_name(const std::string& name,
+                                                       bool use_instance) const {
+    return name;
+}
+
+
+/****************************************************************************************/
+/* Main printing routines for code generation */
+/****************************************************************************************/
+
+
+void CodegenNeuronCppVisitor::print_backend_info() {
+    time_t current_time{};
+    time(&current_time);
+    std::string data_time_str{std::ctime(&current_time)};
+    auto version = nmodl::Version::NMODL_VERSION + " [" + nmodl::Version::GIT_REVISION + "]";
+
+    printer->add_line("/*********************************************************");
+    printer->add_line("Model Name : ", info.mod_suffix);
+    printer->add_line("Filename : ", info.mod_file, ".mod");
+    printer->add_line("NMODL Version : ", nmodl_version());
+    printer->fmt_line("Vectorized : {}", info.vectorize);
+    printer->fmt_line("Threadsafe : {}", info.thread_safe);
+    printer->add_line("Created : ",
stringutils::trim(data_time_str)); + printer->add_line("Simulator : ", simulator_name()); + printer->add_line("Backend : ", backend_name()); + printer->add_line("NMODL Compiler : ", version); + printer->add_line("*********************************************************/"); +} + + +void CodegenNeuronCppVisitor::print_standard_includes() { + printer->add_newline(); + printer->add_multi_line(R"CODE( + #include + #include + #include + )CODE"); + if (!info.vectorize) { + printer->add_line("#include "); + } +} + + +void CodegenNeuronCppVisitor::print_neuron_includes() { + printer->add_newline(); + printer->add_multi_line(R"CODE( + #include "mech_api.h" + #include "neuron/cache/mechanism_range.hpp" + #include "nrniv_mf.h" + #include "section_fwd.hpp" + )CODE"); +} + + +void CodegenNeuronCppVisitor::print_sdlists_init(bool print_initializers) { + for (auto i = 0; i < info.prime_variables_by_order.size(); ++i) { + const auto& prime_var = info.prime_variables_by_order[i]; + /// TODO: Something similar needs to happen for slist/dlist2 but I don't know their usage at + // the moment + /// TODO: We have to do checks and add errors similar to nocmodl in the + // SemanticAnalysisVisitor + if (prime_var->is_array()) { + /// TODO: Needs a for loop here. Look at + // https://github.com/neuronsimulator/nrn/blob/df001a436bcb4e23d698afe66c2a513819a6bfe8/src/nmodl/deriv.cpp#L524 + /// TODO: Also needs a test + printer->fmt_push_block("for (int _i = 0; _i < {}; ++_i)", prime_var->get_length()); + printer->fmt_line("/* {}[{}] */", prime_var->get_name(), prime_var->get_length()); + printer->fmt_line("_slist1[{}+_i] = {{{}, _i}}", + i, + position_of_float_var(prime_var->get_name())); + const auto prime_var_deriv_name = "D" + prime_var->get_name(); + printer->fmt_line("/* {}[{}] */", prime_var_deriv_name, prime_var->get_length()); + printer->fmt_line("_dlist1[{}+_i] = {{{}, _i}}", + i, + position_of_float_var(prime_var_deriv_name)); + printer->pop_block(); + } else { + printer->fmt_line("/* {} */", prime_var->get_name()); + printer->fmt_line("_slist1[{}] = {{{}, 0}}", + i, + position_of_float_var(prime_var->get_name())); + const auto prime_var_deriv_name = "D" + prime_var->get_name(); + printer->fmt_line("/* {} */", prime_var_deriv_name); + printer->fmt_line("_dlist1[{}] = {{{}, 0}}", + i, + position_of_float_var(prime_var_deriv_name)); + } + } +} + + +void CodegenNeuronCppVisitor::print_mechanism_global_var_structure(bool print_initializers) { + /// TODO: Print only global variables printed in NEURON + printer->add_line(); + printer->add_line("/* NEURON global variables */"); + if (info.primes_size != 0) { + printer->fmt_line("static neuron::container::field_index _slist1[{0}], _dlist1[{0}];", + info.primes_size); + } +} + + +/// TODO: Same as CoreNEURON? 
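To make the intent of the following routine concrete, the emitted tables look roughly like this for a hypothetical mechanism hh with a GLOBAL scalar gmax and a GLOBAL array tau[2] (sketch of generated output; symbol names invented):

    static DoubScal hoc_scalar_double[] = {
        {"gmax_hh", &gmax_hh},
        {nullptr, nullptr}
    };
    static DoubVec hoc_vector_double[] = {
        {"tau_hh", tau_hh, 2},
        {nullptr, nullptr, 0}
    };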
+void CodegenNeuronCppVisitor::print_global_variables_for_hoc() {
+    /// TODO: Write HocParmLimits and other HOC global variables (delta_t)
+    // Probably needs more changes
+    auto variable_printer =
+        [&](const std::vector<SymbolType>& variables, bool if_array, bool if_vector) {
+            for (const auto& variable: variables) {
+                if (variable->is_array() == if_array) {
+                    // false => do not use the instance struct, which is not
+                    // defined in the global declaration that we are printing
+                    auto name = get_variable_name(variable->get_name(), false);
+                    auto ename = add_escape_quote(variable->get_name() + "_" + info.mod_suffix);
+                    auto length = variable->get_length();
+                    if (if_vector) {
+                        printer->fmt_line("{{{}, {}, {}}},", ename, name, length);
+                    } else {
+                        printer->fmt_line("{{{}, &{}}},", ename, name);
+                    }
+                }
+            }
+        };
+
+    auto globals = info.global_variables;
+    auto thread_vars = info.thread_variables;
+
+    if (info.table_count > 0) {
+        globals.push_back(make_symbol(naming::USE_TABLE_VARIABLE));
+    }
+
+    printer->add_newline(2);
+    printer->add_line("/** connect global (scalar) variables to hoc -- */");
+    printer->add_line("static DoubScal hoc_scalar_double[] = {");
+    printer->increase_indent();
+    variable_printer(globals, false, false);
+    variable_printer(thread_vars, false, false);
+    printer->add_line("{nullptr, nullptr}");
+    printer->decrease_indent();
+    printer->add_line("};");
+
+    printer->add_newline(2);
+    printer->add_line("/** connect global (array) variables to hoc -- */");
+    printer->add_line("static DoubVec hoc_vector_double[] = {");
+    printer->increase_indent();
+    variable_printer(globals, true, true);
+    variable_printer(thread_vars, true, true);
+    printer->add_line("{nullptr, nullptr, 0}");
+    printer->decrease_indent();
+    printer->add_line("};");
+}
+
+
+void CodegenNeuronCppVisitor::print_mechanism_register() {
+    /// TODO: Write this according to NEURON
+    printer->add_newline(2);
+    printer->add_line("/** register channel with the simulator */");
+    printer->fmt_push_block("void _{}_reg()", info.mod_file);
+    print_sdlists_init(true);
+    // type related information
+    auto suffix = add_escape_quote(info.mod_suffix);
+    printer->add_newline();
+    printer->fmt_line("int mech_type = nrn_get_mechtype({});", suffix);
+
+    // More things to add here
+    printer->add_line("_nrn_mechanism_register_data_fields(mech_type,");
+    printer->increase_indent();
+    const auto codegen_float_variables_size = codegen_float_variables.size();
+    for (int i = 0; i < codegen_float_variables_size; ++i) {
+        const auto& float_var = codegen_float_variables[i];
+        const auto print_comma = i < codegen_float_variables_size - 1 || info.emit_cvode;
+        if (float_var->is_array()) {
+            printer->fmt_line("_nrn_mechanism_field{{\"{}\", {}}} /* {} */{}",
+                              float_var->get_name(),
+                              float_var->get_length(),
+                              i,
+                              print_comma ? "," : "");
+        } else {
+            printer->fmt_line("_nrn_mechanism_field{{\"{}\"}} /* {} */{}",
+                              float_var->get_name(),
+                              i,
+                              print_comma ?
"," : ""); + } + } + if (info.emit_cvode) { + printer->add_line("_nrn_mechanism_field{\"_cvode_ieq\", \"cvodeieq\"} /* 0 */"); + } + printer->decrease_indent(); + printer->add_line(");"); + printer->add_newline(); + printer->pop_block(); +} + + +void CodegenNeuronCppVisitor::print_mechanism_range_var_structure(bool print_initializers) { + printer->add_line("/* NEURON RANGE variables macro definitions */"); + for (auto i = 0; i < codegen_float_variables.size(); ++i) { + const auto float_var = codegen_float_variables[i]; + if (float_var->is_array()) { + printer->add_line("#define ", + float_var->get_name(), + "(id) _ml->template data_array<", + std::to_string(i), + ", ", + std::to_string(float_var->get_length()), + ">(id)"); + } else { + printer->add_line("#define ", + float_var->get_name(), + "(id) _ml->template fpfield<", + std::to_string(i), + ">(id)"); + } + } +} + + +/// TODO: Edit for NEURON +void CodegenNeuronCppVisitor::print_global_function_common_code(BlockType type, + const std::string& function_name) { + return; +} + + +/// TODO: Edit for NEURON +void CodegenNeuronCppVisitor::print_nrn_constructor() { + return; +} + + +/// TODO: Edit for NEURON +void CodegenNeuronCppVisitor::print_nrn_destructor() { + return; +} + + +/// TODO: Print the equivalent of `nrn_alloc_` +void CodegenNeuronCppVisitor::print_nrn_alloc() { + return; +} + + +/****************************************************************************************/ +/* Print nrn_state routine */ +/****************************************************************************************/ + + +/// TODO: Edit for NEURON +void CodegenNeuronCppVisitor::print_nrn_state() { + if (!nrn_state_required()) { + return; + } + codegen = true; + + printer->add_line("void nrn_state() {}"); + /// TODO: Fill in + + codegen = false; +} + + +/****************************************************************************************/ +/* Print nrn_cur related routines */ +/****************************************************************************************/ + + +/// TODO: Edit for NEURON +void CodegenNeuronCppVisitor::print_nrn_current(const BreakpointBlock& node) { + return; +} + + +/// TODO: Edit for NEURON +void CodegenNeuronCppVisitor::print_nrn_cur_conductance_kernel(const BreakpointBlock& node) { + return; +} + + +/// TODO: Edit for NEURON +void CodegenNeuronCppVisitor::print_nrn_cur_non_conductance_kernel() { + return; +} + + +/// TODO: Edit for NEURON +void CodegenNeuronCppVisitor::print_nrn_cur_kernel(const BreakpointBlock& node) { + return; +} + + +/// TODO: Edit for NEURON +void CodegenNeuronCppVisitor::print_fast_imem_calculation() { + return; +} + + +/// TODO: Edit for NEURON +void CodegenNeuronCppVisitor::print_nrn_cur() { + if (!nrn_cur_required()) { + return; + } + + codegen = true; + + printer->add_line("void nrn_cur() {}"); + /// TODO: Fill in + + codegen = false; +} + + +/****************************************************************************************/ +/* Main code printing entry points */ +/****************************************************************************************/ + +void CodegenNeuronCppVisitor::print_headers_include() { + print_standard_includes(); + print_neuron_includes(); +} + + +void CodegenNeuronCppVisitor::print_macro_definitions() { + print_global_macros(); + print_mechanism_variables_macros(); +} + + +void CodegenNeuronCppVisitor::print_global_macros() { + printer->add_newline(); + printer->add_line("/* NEURON global macro definitions */"); + if (info.vectorize) { + 
printer->add_multi_line(R"CODE( + /* VECTORIZED */ + #define NRN_VECTORIZED 1 + )CODE"); + } else { + printer->add_multi_line(R"CODE( + /* NOT VECTORIZED */ + #define NRN_VECTORIZED 0 + )CODE"); + } +} + + +void CodegenNeuronCppVisitor::print_mechanism_variables_macros() { + printer->add_newline(); + printer->add_line("static constexpr auto number_of_datum_variables = ", + std::to_string(int_variables_size()), + ";"); + printer->add_line("static constexpr auto number_of_floating_point_variables = ", + std::to_string(float_variables_size()), + ";"); + printer->add_newline(); + printer->add_multi_line(R"CODE( + namespace { + template + using _nrn_mechanism_std_vector = std::vector; + using _nrn_model_sorted_token = neuron::model_sorted_token; + using _nrn_mechanism_cache_range = neuron::cache::MechanismRange; + using _nrn_mechanism_cache_instance = neuron::cache::MechanismInstance; + template + using _nrn_mechanism_field = neuron::mechanism::field; + template + void _nrn_mechanism_register_data_fields(Args&&... args) { + neuron::mechanism::register_data_fields(std::forward(args)...); + } + } // namespace + )CODE"); + /// TODO: More prints here? +} + + +void CodegenNeuronCppVisitor::print_namespace_begin() { + print_namespace_start(); +} + + +void CodegenNeuronCppVisitor::print_namespace_end() { + print_namespace_stop(); +} + + +void CodegenNeuronCppVisitor::print_data_structures(bool print_initializers) { + print_mechanism_global_var_structure(print_initializers); + print_mechanism_range_var_structure(print_initializers); +} + + +void CodegenNeuronCppVisitor::print_v_unused() const { + if (!info.vectorize) { + return; + } + printer->add_multi_line(R"CODE( + #if NRN_PRCELLSTATE + inst->v_unused[id] = v; + #endif + )CODE"); +} + + +void CodegenNeuronCppVisitor::print_g_unused() const { + printer->add_multi_line(R"CODE( + #if NRN_PRCELLSTATE + inst->g_unused[id] = g; + #endif + )CODE"); +} + + +/// TODO: Edit for NEURON +void CodegenNeuronCppVisitor::print_compute_functions() { + print_nrn_cur(); + print_nrn_state(); +} + + +/// TODO: Edit for NEURON +void CodegenNeuronCppVisitor::print_codegen_routines() { + codegen = true; + print_backend_info(); + print_headers_include(); + print_macro_definitions(); + print_namespace_begin(); + print_nmodl_constants(); + print_prcellstate_macros(); + print_mechanism_info(); + print_data_structures(true); + print_global_variables_for_hoc(); + print_compute_functions(); // only nrn_cur and nrn_state + print_mechanism_register(); + print_namespace_end(); + codegen = false; +} + + +/****************************************************************************************/ +/* Overloaded visitor routines */ +/****************************************************************************************/ + + +/// TODO: Edit for NEURON +void CodegenNeuronCppVisitor::visit_solution_expression(const SolutionExpression& node) { + return; +} + + +/// TODO: Edit for NEURON +void CodegenNeuronCppVisitor::visit_watch_statement(const ast::WatchStatement& /* node */) { + return; +} + + +} // namespace codegen +} // namespace nmodl diff --git a/src/codegen/codegen_neuron_cpp_visitor.hpp b/src/codegen/codegen_neuron_cpp_visitor.hpp new file mode 100644 index 0000000000..e8bf7e36af --- /dev/null +++ b/src/codegen/codegen_neuron_cpp_visitor.hpp @@ -0,0 +1,643 @@ +/* + * Copyright 2023 Blue Brain Project, EPFL. + * See the top-level LICENSE file for details. 
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+#pragma once
+
+/**
+ * \dir
+ * \brief Code generation backend implementations for NEURON
+ *
+ * \file
+ * \brief \copybrief nmodl::codegen::CodegenNeuronCppVisitor
+ */
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "codegen/codegen_info.hpp"
+#include "codegen/codegen_naming.hpp"
+#include "printer/code_printer.hpp"
+#include "symtab/symbol_table.hpp"
+#include "utils/logger.hpp"
+#include "visitors/ast_visitor.hpp"
+#include
+
+
+/// encapsulates code generation backend implementations
+namespace nmodl {
+
+namespace codegen {
+
+
+using printer::CodePrinter;
+
+
+/**
+ * \defgroup codegen_backends Codegen Backends
+ * \ingroup codegen
+ * \brief Code generation backends for NEURON
+ * \{
+ */
+
+/**
+ * \class CodegenNeuronCppVisitor
+ * \brief %Visitor for printing C++ code compatible with legacy api of NEURON
+ *
+ * \todo
+ *  - Handle define statement (i.e. macros)
+ *  - If there is a return statement in the verbatim block
+ *    of an inlined function then it will be an error. Need better
+ *    error checking. For example, see netstim.mod where we
+ *    have removed return from verbatim block.
+ */
+class CodegenNeuronCppVisitor: public CodegenCppVisitor {
+  protected:
+    /****************************************************************************************/
+    /* Member variables */
+    /****************************************************************************************/
+
+
+    /****************************************************************************************/
+    /* Generic information getters */
+    /****************************************************************************************/
+
+
+    /**
+     * Name of the simulator the code was generated for
+     */
+    std::string simulator_name() override;
+
+
+    /**
+     * Name of the code generation backend
+     */
+    virtual std::string backend_name() const override;
+
+
+    /****************************************************************************************/
+    /* Common helper routines across codegen functions */
+    /****************************************************************************************/
+
+
+    /**
+     * Determine the position in the data array for a given float variable
+     * \param name The name of a float variable
+     * \return The position index in the data array
+     */
+    int position_of_float_var(const std::string& name) const override;
+
+
+    /**
+     * Determine the position in the data array for a given int variable
+     * \param name The name of an int variable
+     * \return The position index in the data array
+     */
+    int position_of_int_var(const std::string& name) const override;
+
+
+    /****************************************************************************************/
+    /* Backend specific routines */
+    /****************************************************************************************/
+
+
+    /**
+     * Print atomic update pragma for reduction statements
+     */
+    virtual void print_atomic_reduction_pragma() override;
+
+
+    /****************************************************************************************/
+    /* Printing routines for code generation */
+    /****************************************************************************************/
+
+
+    /**
+     * Print call to internal or external function
+     * \param node The AST node representing a function call
+     */
+    void print_function_call(const ast::FunctionCall& node) override;
+
+
+    /**
+     * Print function and procedures prototype declaration
+     */
+    void
print_function_prototypes() override; + + + /** + * Print nmodl function or procedure (common code) + * \param node the AST node representing the function or procedure in NMODL + * \param name the name of the function or procedure + */ + void print_function_or_procedure(const ast::Block& node, const std::string& name) override; + + + /** + * Common helper function to help printing function or procedure blocks + * \param node the AST node representing the function or procedure in NMODL + */ + void print_function_procedure_helper(const ast::Block& node) override; + + + /** + * Print NMODL procedure in target backend code + * \param node + */ + virtual void print_procedure(const ast::ProcedureBlock& node) override; + + + /** + * Print NMODL function in target backend code + * \param node + */ + void print_function(const ast::FunctionBlock& node) override; + + + /****************************************************************************************/ + /* Code-specific helper routines */ + /****************************************************************************************/ + + + /** + * Arguments for functions that are defined and used internally. + * \return the method arguments + */ + std::string internal_method_arguments() override; + + + /** + * Parameters for internally defined functions + * \return the method parameters + */ + ParamVector internal_method_parameters() override; + + + /** + * Arguments for external functions called from generated code + * \return A string representing the arguments passed to an external function + */ + const char* external_method_arguments() noexcept override; + + + /** + * Parameters for functions in generated code that are called back from external code + * + * Functions registered in NEURON during initialization for callback must adhere to a prescribed + * calling convention. This method generates the string representing the function parameters for + * these externally called functions. 
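+     *
+     * As a rough, illustrative sketch (the callback name is invented here and the exact
+     * parameter tokens come from the implementation), the returned string is meant to be
+     * spliced into the definition of such a callback:
+     * \code{.cpp}
+     * printer->fmt_line("static double hoc_example({}) {{", external_method_parameters());
+     * \endcode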
+     * \param table Whether the parameters are for the table variant of the function
+     * \return      A string representing the parameters of the function
+     */
+    const char* external_method_parameters(bool table = false) noexcept override;
+
+
+    /**
+     * Arguments for the "_threadargs_" macro in the NEURON implementation
+     */
+    std::string nrn_thread_arguments() const override;
+
+
+    /**
+     * Arguments for the "_threadargs_" macro when used from functions internal to the
+     * generated mechanism
+     */
+    std::string nrn_thread_internal_arguments() override;
+
+
+    /**
+     * Process a verbatim block for possible variable renaming
+     * \param text The verbatim code to be processed
+     * \return     The code with all variables renamed as needed
+     */
+    std::string process_verbatim_text(std::string const& text) override;
+
+
+    /**
+     * Arguments for the register_mech or point_register_mech function
+     */
+    std::string register_mechanism_arguments() const override;
+
+
+    /****************************************************************************************/
+    /*                 Code-specific printing routines for code generation                  */
+    /****************************************************************************************/
+
+
+    /**
+     * Prints the start of the \c neuron namespace
+     */
+    void print_namespace_start() override;
+
+
+    /**
+     * Prints the end of the \c neuron namespace
+     */
+    void print_namespace_stop() override;
+
+
+    /****************************************************************************************/
+    /*                        Routines for returning variable name                          */
+    /****************************************************************************************/
+
+
+    /**
+     * Determine the name of a \c float variable given its symbol
+     *
+     * This function typically returns the accessor expression in backend code for the given
+     * symbol. Since the model variables are stored in data arrays and accessed by offset,
+     * this function will return the C++ string representing the array access at the correct
+     * offset.
+     *
+     * \param symbol       The symbol of a variable for which we want to obtain its name
+     * \param use_instance Should the variable be accessed via instance or data array
+     * \return             The backend code string representing the access to the given
+     *                     variable symbol
+     */
+    std::string float_variable_name(const SymbolType& symbol, bool use_instance) const override;
+
+
+    /**
+     * Determine the name of an \c int variable given its symbol
+     *
+     * This function typically returns the accessor expression in backend code for the given
+     * symbol. Since the model variables are stored in data arrays and accessed by offset,
+     * this function will return the C++ string representing the array access at the correct
+     * offset.
+     *
+     * \param symbol       The symbol of a variable for which we want to obtain its name
+     * \param name         The name of the index variable
+     * \param use_instance Should the variable be accessed via instance or data array
+     * \return             The backend code string representing the access to the given
+     *                     variable symbol
+     */
+    std::string int_variable_name(const IndexVariableInfo& symbol,
+                                  const std::string& name,
+                                  bool use_instance) const override;
+
+
+    /**
+     * Determine the variable name for a global variable given its symbol
+     * \param symbol       The symbol of a variable for which we want to obtain its name
+     * \param use_instance Should the variable be accessed via the (host-only)
+     *                     global variable or the instance-specific copy (also available on
+     *                     GPU).
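+     *
+     * A minimal sketch of the intended use, assuming \c sym is the symbol of a GLOBAL
+     * variable named \c gmax (both names invented for this comment):
+     * \code{.cpp}
+     * // might print an accessor such as "inst->global->gmax" or a direct access to the
+     * // global struct, depending on use_instance and the backend
+     * printer->fmt_line("double x = {};", global_variable_name(sym, use_instance));
+     * \endcode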
+ * \return The C++ string representing the access to the global variable + */ + std::string global_variable_name(const SymbolType& symbol, + bool use_instance = true) const override; + + + /** + * Determine variable name in the structure of mechanism properties + * + * \param name Variable name that is being printed + * \param use_instance Should the variable be accessed via instance or data array + * \return The C++ string representing the access to the variable in the neuron + * thread structure + */ + std::string get_variable_name(const std::string& name, bool use_instance = true) const override; + + + /****************************************************************************************/ + /* Main printing routines for code generation */ + /****************************************************************************************/ + + + /** + * Print top file header printed in generated code + */ + void print_backend_info() override; + + + /** + * Print standard C/C++ includes + */ + void print_standard_includes() override; + + + /** + * Print includes from NEURON + */ + void print_neuron_includes(); + + + void print_sdlists_init(bool print_initializers) override; + + + /** + * Print the structure that wraps all global variables used in the NMODL + * + * \param print_initializers Whether to include default values in the struct + * definition (true: int foo{42}; false: int foo;) + */ + void print_mechanism_global_var_structure(bool print_initializers) override; + + + /** + * Print byte arrays that register scalar and vector variables for hoc interface + * + */ + void print_global_variables_for_hoc() override; + + + /** + * Print the mechanism registration function + * + */ + void print_mechanism_register() override; + + + /** + * Print common code for global functions like nrn_init, nrn_cur and nrn_state + * \param type The target backend code block type + */ + virtual void print_global_function_common_code(BlockType type, + const std::string& function_name = "") override; + + + /** + * Print nrn_constructor function definition + * + */ + void print_nrn_constructor() override; + + + /** + * Print nrn_destructor function definition + * + */ + void print_nrn_destructor() override; + + + /** + * Print nrn_alloc function definition + * + */ + void print_nrn_alloc() override; + + + /****************************************************************************************/ + /* Print nrn_state routine */ + /****************************************************************************************/ + + + /** + * Print nrn_state / state update function definition + */ + void print_nrn_state() override; + + + /****************************************************************************************/ + /* Print nrn_cur related routines */ + /****************************************************************************************/ + + + /** + * Print the \c nrn_current kernel + * + * \note nrn_cur_kernel will have two calls to nrn_current if no conductance keywords specified + * \param node the AST node representing the NMODL breakpoint block + */ + void print_nrn_current(const ast::BreakpointBlock& node) override; + + + /** + * Print the \c nrn\_cur kernel with NMODL \c conductance keyword provisions + * + * If the NMODL \c conductance keyword is used in the \c breakpoint block, then + * CodegenCoreneuronCppVisitor::print_nrn_cur_kernel will use this printer + * + * \param node the AST node representing the NMODL breakpoint block + */ + void print_nrn_cur_conductance_kernel(const 
ast::BreakpointBlock& node) override; + + + /** + * Print the \c nrn\_cur kernel without NMODL \c conductance keyword provisions + * + * If the NMODL \c conductance keyword is \b not used in the \c breakpoint block, then + * CodegenCoreneuronCppVisitor::print_nrn_cur_kernel will use this printer + */ + void print_nrn_cur_non_conductance_kernel() override; + + + /** + * Print main body of nrn_cur function + * \param node the AST node representing the NMODL breakpoint block + */ + void print_nrn_cur_kernel(const ast::BreakpointBlock& node) override; + + + /** + * Print fast membrane current calculation code + */ + virtual void print_fast_imem_calculation() override; + + + /** + * Print nrn_cur / current update function definition + */ + void print_nrn_cur() override; + + + /****************************************************************************************/ + /* Main code printing entry points */ + /****************************************************************************************/ + + + /** + * Print all includes + * + */ + void print_headers_include() override; + + + /** + * Print all NEURON macros + * + */ + void print_macro_definitions(); + + + /** + * Print NEURON global variable macros + * + */ + void print_global_macros(); + + + /** + * Print mechanism variables' related macros + * + */ + void print_mechanism_variables_macros(); + + + /** + * Print start of namespaces + * + */ + void print_namespace_begin() override; + + + /** + * Print end of namespaces + * + */ + void print_namespace_end() override; + + + /** + * Print all classes + * \param print_initializers Whether to include default values. + */ + void print_data_structures(bool print_initializers) override; + + + /** + * Set v_unused (voltage) for NRN_PRCELLSTATE feature + */ + void print_v_unused() const override; + + + /** + * Set g_unused (conductance) for NRN_PRCELLSTATE feature + */ + void print_g_unused() const override; + + + /** + * Print all compute functions for every backend + * + */ + virtual void print_compute_functions() override; + + + /** + * Print entry point to code generation + * + */ + void print_codegen_routines() override; + + + /****************************************************************************************/ + /* Overloaded visitor routines */ + /****************************************************************************************/ + + + virtual void visit_solution_expression(const ast::SolutionExpression& node) override; + virtual void visit_watch_statement(const ast::WatchStatement& node) override; + + + /** + * Print prototype declarations of functions or procedures + * \tparam T The AST node type of the node (must be of nmodl::ast::Ast or subclass) + * \param node The AST node representing the function or procedure block + * \param name A user defined name for the function + */ + template + void print_function_declaration(const T& node, const std::string& name); + + + public: + /** + * \brief Constructs the C++ code generator visitor + * + * This constructor instantiates an NMODL C++ code generator and allows writing generated code + * directly to a file in \c [output_dir]/[mod_filename].cpp. + * + * \note No code generation is performed at this stage. Since the code + * generator classes are all based on \c AstVisitor the AST must be visited using e.g. \c + * visit_program in order to generate the C++ code corresponding to the AST. + * + * \param mod_filename The name of the model for which code should be generated. + * It is used for constructing an output filename. 
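+     *                     (For example, a hypothetical \c mod_filename of \c hh would
+     *                     yield \c [output_dir]/hh.cpp.)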
+     * \param output_dir   The directory where the target C++ file should be generated.
+     * \param float_type   The float type to use in the generated code. The string will be
+     *                     used as-is in the target code. This defaults to \c double.
+     */
+    CodegenNeuronCppVisitor(std::string mod_filename,
+                            const std::string& output_dir,
+                            std::string float_type,
+                            const bool optimize_ionvar_copies)
+        : CodegenCppVisitor(mod_filename, output_dir, float_type, optimize_ionvar_copies) {}
+
+    /**
+     * \copybrief nmodl::codegen::CodegenNeuronCppVisitor
+     *
+     * This constructor instantiates an NMODL C++ code generator and allows writing generated
+     * code into an output stream.
+     *
+     * \note No code generation is performed at this stage. Since the code generator classes
+     * are all based on \c AstVisitor the AST must be visited using e.g. \c visit_program in
+     * order to generate the C++ code corresponding to the AST.
+     *
+     * \param mod_filename The name of the model for which code should be generated.
+     *                     It is used for constructing an output filename.
+     * \param stream       The output stream onto which to write the generated code
+     * \param float_type   The float type to use in the generated code. The string will be
+     *                     used as-is in the target code. This defaults to \c double.
+     */
+    CodegenNeuronCppVisitor(std::string mod_filename,
+                            std::ostream& stream,
+                            std::string float_type,
+                            const bool optimize_ionvar_copies)
+        : CodegenCppVisitor(mod_filename, stream, float_type, optimize_ionvar_copies) {}
+
+
+    /****************************************************************************************/
+    /*              Public printing routines for code generation for use in unit tests      */
+    /****************************************************************************************/
+
+
+    /**
+     * Print the structure that wraps all range and int variables required by the NMODL
+     * mechanism
+     *
+     * \param print_initializers Whether default values for variables should be included in
+     *                           the struct declaration.
+     */
+    void print_mechanism_range_var_structure(bool print_initializers) override;
+};
+
+
+/**
+ * \details If an argument has the same name (say \c alpha) as a RANGE variable, we want to
+ * avoid it being printed as \c instance->alpha. Hence we disable variable name lookup while
+ * the prototype declaration is printed. Note that the name of a procedure can be different
+ * when a TABLE statement is used.
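+ *
+ * For illustration only (the mod file and names are invented): given
+ * \code
+ * PROCEDURE rates(alpha) { ... }
+ * \endcode
+ * where \c alpha is also a RANGE variable, the printed prototype should take \c alpha as a
+ * plain \c double parameter rather than resolving it to an instance member access.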
+ */ +template +void CodegenNeuronCppVisitor::print_function_declaration(const T& node, const std::string& name) { + enable_variable_name_lookup = false; + auto type = default_float_data_type(); + + // internal and user provided arguments + auto internal_params = internal_method_parameters(); + const auto& params = node.get_parameters(); + for (const auto& param: params) { + internal_params.emplace_back("", type, "", param.get()->get_node_name()); + } + + // procedures have "int" return type by default + const char* return_type = "int"; + if (node.is_function_block()) { + return_type = default_float_data_type(); + } + + /// TODO: Edit for NEURON + printer->add_indent(); + printer->fmt_text("inline {} {}({})", return_type, method_name(name), "params"); + + enable_variable_name_lookup = true; +} + +/** \} */ // end of codegen_backends + +} // namespace codegen +} // namespace nmodl diff --git a/src/main.cpp b/src/main.cpp index 87a682b2c0..cbd1e62c37 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -14,7 +14,8 @@ #include "ast/program.hpp" #include "codegen/codegen_acc_visitor.hpp" #include "codegen/codegen_compatibility_visitor.hpp" -#include "codegen/codegen_cpp_visitor.hpp" +#include "codegen/codegen_coreneuron_cpp_visitor.hpp" +#include "codegen/codegen_neuron_cpp_visitor.hpp" #include "codegen/codegen_transform_visitor.hpp" #include "config/config.h" #include "parser/nmodl_driver.hpp" @@ -69,8 +70,14 @@ int main(int argc, const char* argv[]) { /// true if debug logger statements should be shown std::string verbose("info"); - /// true if serial c code to be generated - bool c_backend(true); + /// true if code is to be generated for NEURON + bool neuron_code(false); + + /// true if code is to be generated for CoreNEURON + bool coreneuron_code(true); + + /// true if serial C++ code to be generated + bool cpp_backend(true); /// true if c code with openacc to be generated bool oacc_backend(false); @@ -174,16 +181,18 @@ int main(int argc, const char* argv[]) { app.add_option("--units", units_dir, "Directory of units lib file") ->capture_default_str() ->ignore_case(); + app.add_flag("--neuron", neuron_code, "Generate C++ code for NEURON"); + app.add_flag("--coreneuron", coreneuron_code, "Generate C++ code for CoreNEURON (Default)"); auto host_opt = app.add_subcommand("host", "HOST/CPU code backends")->ignore_case(); - host_opt->add_flag("--c", c_backend, fmt::format("C/C++ backend ({})", c_backend)) + host_opt->add_flag("--c,--cpp", cpp_backend, fmt::format("C++ backend ({})", cpp_backend)) ->ignore_case(); auto acc_opt = app.add_subcommand("acc", "Accelerator code backends")->ignore_case(); acc_opt ->add_flag("--oacc", oacc_backend, - fmt::format("C/C++ backend with OpenACC ({})", oacc_backend)) + fmt::format("C++ backend with OpenACC ({})", oacc_backend)) ->ignore_case(); // clang-format off @@ -520,8 +529,8 @@ int main(int argc, const char* argv[]) { } { - if (oacc_backend) { - logger->info("Running OpenACC backend code generator"); + if (coreneuron_code && oacc_backend) { + logger->info("Running OpenACC backend code generator for CoreNEURON"); CodegenAccVisitor visitor(modfile, output_dir, data_type, @@ -529,14 +538,30 @@ int main(int argc, const char* argv[]) { visitor.visit_program(*ast); } - else if (c_backend) { - logger->info("Running C++ backend code generator"); - CodegenCppVisitor visitor(modfile, - output_dir, - data_type, - optimize_ionvar_copies_codegen); + else if (coreneuron_code && !neuron_code && cpp_backend) { + logger->info("Running C++ backend code generator for 
CoreNEURON"); + CodegenCoreneuronCppVisitor visitor(modfile, + output_dir, + data_type, + optimize_ionvar_copies_codegen); visitor.visit_program(*ast); } + + else if (neuron_code && cpp_backend) { + logger->info("Running C++ backend code generator for NEURON"); + CodegenNeuronCppVisitor visitor(modfile, + output_dir, + data_type, + optimize_ionvar_copies_codegen); + visitor.visit_program(*ast); + } + + else { + throw std::runtime_error( + "Non valid code generation configuration. Code generation with NMODL is " + "supported for NEURON with C++ backend or CoreNEURON with C++/OpenACC " + "backends"); + } } } diff --git a/src/utils/common_utils.hpp b/src/utils/common_utils.hpp index 2508990917..147d46099e 100644 --- a/src/utils/common_utils.hpp +++ b/src/utils/common_utils.hpp @@ -63,7 +63,7 @@ std::string generate_random_string(int len, UseNumbersInString use_numbers); * Singleton class for random strings that are appended to the * Eigen matrices names that are used in the solutions of * nmodl::visitor::SympySolverVisitor and need to be the same to - * be printed by the nmodl::codegen::CodegenCppVisitor + * be printed by the nmodl::codegen::CodegenCoreneuronCppVisitor */ template class SingletonRandomString { diff --git a/test/unit/CMakeLists.txt b/test/unit/CMakeLists.txt index 43da4cfa2c..7d09cb9b45 100644 --- a/test/unit/CMakeLists.txt +++ b/test/unit/CMakeLists.txt @@ -68,8 +68,13 @@ add_executable(testunitlexer units/lexer.cpp) add_executable(testunitparser units/parser.cpp) add_executable( testcodegen - codegen/main.cpp codegen/codegen_helper.cpp codegen/codegen_utils.cpp - codegen/codegen_cpp_visitor.cpp codegen/transform.cpp codegen/codegen_compatibility_visitor.cpp) + codegen/main.cpp + codegen/codegen_helper.cpp + codegen/codegen_utils.cpp + codegen/codegen_coreneuron_cpp_visitor.cpp + codegen/codegen_neuron_cpp_visitor.cpp + codegen/transform.cpp + codegen/codegen_compatibility_visitor.cpp) target_link_libraries(testmodtoken PRIVATE lexer util) target_link_libraries(testlexer PRIVATE lexer util) diff --git a/test/unit/codegen/codegen_cpp_visitor.cpp b/test/unit/codegen/codegen_coreneuron_cpp_visitor.cpp similarity index 95% rename from test/unit/codegen/codegen_cpp_visitor.cpp rename to test/unit/codegen/codegen_coreneuron_cpp_visitor.cpp index be30ce24d6..ec98b5b076 100644 --- a/test/unit/codegen/codegen_cpp_visitor.cpp +++ b/test/unit/codegen/codegen_coreneuron_cpp_visitor.cpp @@ -10,7 +10,7 @@ #include "ast/program.hpp" #include "codegen/codegen_acc_visitor.hpp" -#include "codegen/codegen_cpp_visitor.hpp" +#include "codegen/codegen_coreneuron_cpp_visitor.hpp" #include "codegen/codegen_helper_visitor.hpp" #include "parser/nmodl_driver.hpp" #include "test/unit/utils/test_utils.hpp" @@ -32,9 +32,10 @@ using nmodl::parser::NmodlDriver; using nmodl::test_utils::reindent_text; /// Helper for creating C codegen visitor -std::shared_ptr create_c_visitor(const std::shared_ptr& ast, - const std::string& /* text */, - std::stringstream& ss) { +std::shared_ptr create_coreneuron_cpp_visitor( + const std::shared_ptr& ast, + const std::string& /* text */, + std::stringstream& ss) { /// construct symbol table SymtabVisitor().visit_program(*ast); @@ -44,7 +45,7 @@ std::shared_ptr create_c_visitor(const std::shared_ptr("temp.mod", ss, "double", false); + auto cv = std::make_shared("temp.mod", ss, "double", false); cv->setup(*ast); return cv; } @@ -71,20 +72,21 @@ std::shared_ptr create_acc_visitor(const std::shared_ptrprint_instance_variable_setup(); return reindent_text(ss.str()); } /// print 
entire code -std::string get_cpp_code(const std::string& nmodl_text, const bool generate_gpu_code = false) { +std::string get_coreneuron_cpp_code(const std::string& nmodl_text, + const bool generate_gpu_code = false) { const auto& ast = NmodlDriver().parse_string(nmodl_text); std::stringstream ss; if (generate_gpu_code) { auto accvisitor = create_acc_visitor(ast, nmodl_text, ss); accvisitor->visit_program(*ast); } else { - auto cvisitor = create_c_visitor(ast, nmodl_text, ss); + auto cvisitor = create_coreneuron_cpp_visitor(ast, nmodl_text, ss); cvisitor->visit_program(*ast); } return reindent_text(ss.str()); @@ -311,7 +313,7 @@ std::string get_instance_structure(std::string nmodl_text) { PerfVisitor{}.visit_program(*ast); // setup codegen std::stringstream ss{}; - CodegenCppVisitor cv{"temp.mod", ss, "double", false}; + CodegenCoreneuronCppVisitor cv{"temp.mod", ss, "double", false}; cv.setup(*ast); cv.print_mechanism_range_var_structure(true); return ss.str(); @@ -442,7 +444,7 @@ SCENARIO("Check code generation for TABLE statements", "[codegen][array_variable } )"; THEN("Array and global variables should be correctly generated") { - auto const generated = get_cpp_code(nmodl_text); + auto const generated = get_coreneuron_cpp_code(nmodl_text); REQUIRE_THAT(generated, ContainsSubstring("double t_inf[2][201]{};")); REQUIRE_THAT(generated, ContainsSubstring("double t_tau[201]{};")); @@ -479,7 +481,7 @@ SCENARIO("Check code generation for TABLE statements", "[codegen][array_variable } )"; THEN("It should throw") { - REQUIRE_THROWS(get_cpp_code(nmodl_text)); + REQUIRE_THROWS(get_coreneuron_cpp_code(nmodl_text)); } } } @@ -514,7 +516,7 @@ SCENARIO("Check that BEFORE/AFTER block are well generated", "[codegen][before/a } )"; THEN("They should be well registered") { - auto const generated = get_cpp_code(nmodl_text); + auto const generated = get_coreneuron_cpp_code(nmodl_text); // BEFORE BREAKPOINT { REQUIRE_THAT(generated, @@ -643,7 +645,7 @@ SCENARIO("Check that BEFORE/AFTER block are well generated", "[codegen][before/a AFTER SOLVE {} )"; THEN("They should be all registered") { - auto const generated = get_cpp_code(nmodl_text); + auto const generated = get_coreneuron_cpp_code(nmodl_text); REQUIRE_THAT(generated, ContainsSubstring("hoc_reg_ba(mech_type, nrn_before_after_0_ba1, " "BAType::Before + BAType::Step);")); @@ -678,7 +680,7 @@ SCENARIO("Check CONSTANT variables are added to global variable structure", } )"; THEN("The global struct should contain these variables") { - auto const generated = get_cpp_code(nmodl_text); + auto const generated = get_coreneuron_cpp_code(nmodl_text); std::string expected_code = R"( struct CONST_Store { int reset{}; @@ -702,7 +704,7 @@ SCENARIO("Check code generation for FUNCTION_TABLE block", "[codegen][function_t FUNCTION_TABLE uuu(l, k) )"; THEN("Code should be generated correctly") { - auto const generated = get_cpp_code(nmodl_text); + auto const generated = get_coreneuron_cpp_code(nmodl_text); REQUIRE_THAT(generated, ContainsSubstring("double ttt_glia(")); REQUIRE_THAT(generated, ContainsSubstring("double table_ttt_glia(")); REQUIRE_THAT(generated, @@ -737,7 +739,7 @@ SCENARIO("Check that loops are well generated", "[codegen][loops]") { })"; THEN("Correct code is generated") { - auto const generated = get_cpp_code(nmodl_text); + auto const generated = get_coreneuron_cpp_code(nmodl_text); std::string expected_code = R"(double a, b; if (a == 1.0) { b = 5.0; @@ -774,7 +776,7 @@ SCENARIO("Check that top verbatim blocks are well generated", "[codegen][top ver )"; 
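 
 // The THEN sections below follow the pattern used throughout this file (a
 // sketch; both helpers are defined earlier in this file):
 //   auto const generated = get_coreneuron_cpp_code(nmodl_text);
 //   REQUIRE_THAT(generated, ContainsSubstring(expected_code));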
THEN("Correct code is generated") { - auto const generated = get_cpp_code(nmodl_text); + auto const generated = get_coreneuron_cpp_code(nmodl_text); std::string expected_code = R"(using namespace coreneuron; @@ -803,7 +805,7 @@ SCENARIO("Check that codegen generate event functions well", "[codegen][net_even )"; THEN("Correct code is generated") { - auto const generated = get_cpp_code(nmodl_text); + auto const generated = get_coreneuron_cpp_code(nmodl_text); std::string cpu_net_send_expected_code = R"(static inline void net_send_buffering(const NrnThread* nt, NetSendBuffer_t* nsb, int type, int vdata_index, int weight_index, int point_index, double t, double flag) { int i = 0; @@ -821,7 +823,7 @@ SCENARIO("Check that codegen generate event functions well", "[codegen][net_even } })"; REQUIRE_THAT(generated, ContainsSubstring(cpu_net_send_expected_code)); - auto const gpu_generated = get_cpp_code(nmodl_text, true); + auto const gpu_generated = get_coreneuron_cpp_code(nmodl_text, true); std::string gpu_net_send_expected_code = R"(static inline void net_send_buffering(const NrnThread* nt, NetSendBuffer_t* nsb, int type, int vdata_index, int weight_index, int point_index, double t, double flag) { int i = 0; @@ -941,7 +943,7 @@ SCENARIO("Check that codegen generate event functions well", "[codegen][net_even } )"; THEN("It should generate a net_init") { - auto const generated = get_cpp_code(nmodl_text); + auto const generated = get_coreneuron_cpp_code(nmodl_text); std::string expected_code = R"(static void net_init(Point_process* pnt, int weight_index, double flag) { int tid = pnt->_tid; @@ -973,7 +975,7 @@ SCENARIO("Check that codegen generate event functions well", "[codegen][net_even } )"; THEN("It should generate a net_send_buffering with weight_index as parameter variable") { - auto const generated = get_cpp_code(nmodl_text); + auto const generated = get_coreneuron_cpp_code(nmodl_text); std::string expected_code( "net_send_buffering(nt, ml->_net_send_buffer, 0, inst->tqitem[0*pnodecount+id], " "weight_index, point_process, nt->_t+5.0, 1.0);"); @@ -988,7 +990,7 @@ SCENARIO("Check that codegen generate event functions well", "[codegen][net_even } )"; THEN("It should generate a net_send_buffering with weight_index parameter as 0") { - auto const generated = get_cpp_code(nmodl_text); + auto const generated = get_coreneuron_cpp_code(nmodl_text); std::string expected_code( "net_send_buffering(nt, ml->_net_send_buffer, 0, inst->tqitem[0*pnodecount+id], 0, " "point_process, nt->_t+5.0, 1.0);"); @@ -1005,7 +1007,7 @@ SCENARIO("Check that codegen generate event functions well", "[codegen][net_even } )"; THEN("New code is generated for for_netcons") { - auto const generated = get_cpp_code(nmodl_text); + auto const generated = get_coreneuron_cpp_code(nmodl_text); std::string net_receive_kernel_expected_code = R"(static inline void net_receive_kernel_(double t, Point_process* pnt, _Instance* inst, NrnThread* nt, Memb_list* ml, int weight_index, double flag) { int tid = pnt->_tid; @@ -1043,7 +1045,7 @@ SCENARIO("Check that codegen generate event functions well", "[codegen][net_even } )"; THEN("It should throw") { - REQUIRE_THROWS(get_cpp_code(nmodl_text)); + REQUIRE_THROWS(get_coreneuron_cpp_code(nmodl_text)); } } } @@ -1063,7 +1065,7 @@ SCENARIO("Some tests on derivimplicit", "[codegen][derivimplicit_solver]") { } )"; THEN("Correct code is generated") { - auto const generated = get_cpp_code(nmodl_text); + auto const generated = get_coreneuron_cpp_code(nmodl_text); std::string newton_state_expected_code = 
R"(namespace { struct _newton_state_ { int operator()(int id, int pnodecount, double* data, Datum* indexes, ThreadDatum* thread, NrnThread* nt, Memb_list* ml, double v) const { @@ -1128,7 +1130,7 @@ SCENARIO("Some tests on euler solver", "[codegen][euler_solver]") { } )"; THEN("Correct code is generated") { - auto const generated = get_cpp_code(nmodl_text); + auto const generated = get_coreneuron_cpp_code(nmodl_text); std::string nrn_state_expected_code = R"(inst->Dm[id] = 2.0 * inst->m[id]; inf = inf * 3.0; inst->Dn[id] = (2.0 + inst->m[id] - inf) * inst->n[id]; @@ -1163,7 +1165,7 @@ SCENARIO("Check codegen for MUTEX and PROTECT", "[codegen][mutex_protect]") { )"; THEN("Code with OpenMP critical sections is generated") { - auto const generated = get_cpp_code(nmodl_text); + auto const generated = get_coreneuron_cpp_code(nmodl_text); // critical section for the mutex block std::string expected_code_initial = R"(#pragma omp critical (TEST) { @@ -1198,7 +1200,7 @@ SCENARIO("Array STATE variable", "[codegen][array_state]") { )"; THEN("nrn_init is printed with proper initialization of the whole array") { - auto const generated = get_cpp_code(nmodl_text); + auto const generated = get_coreneuron_cpp_code(nmodl_text); std::string expected_code_init = R"(/** initialize channel */ void nrn_init_ca_test(NrnThread* nt, Memb_list* ml, int type) { diff --git a/test/unit/codegen/codegen_neuron_cpp_visitor.cpp b/test/unit/codegen/codegen_neuron_cpp_visitor.cpp new file mode 100644 index 0000000000..b1d2a08a23 --- /dev/null +++ b/test/unit/codegen/codegen_neuron_cpp_visitor.cpp @@ -0,0 +1,255 @@ +/* + * Copyright 2023 Blue Brain Project, EPFL. + * See the top-level LICENSE file for details. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include + +#include "ast/program.hpp" +#include "codegen/codegen_neuron_cpp_visitor.hpp" +#include "parser/nmodl_driver.hpp" +#include "test/unit/utils/test_utils.hpp" +#include "visitors/inline_visitor.hpp" +#include "visitors/neuron_solve_visitor.hpp" +#include "visitors/solve_block_visitor.hpp" +#include "visitors/symtab_visitor.hpp" + +using Catch::Matchers::ContainsSubstring; + +using namespace nmodl; +using namespace visitor; +using namespace codegen; + +using nmodl::parser::NmodlDriver; +using nmodl::test_utils::reindent_text; + +/// Helper for creating C codegen visitor +std::shared_ptr create_neuron_cpp_visitor( + const std::shared_ptr& ast, + const std::string& /* text */, + std::stringstream& ss) { + /// construct symbol table + SymtabVisitor().visit_program(*ast); + + /// run all necessary pass + InlineVisitor().visit_program(*ast); + NeuronSolveVisitor().visit_program(*ast); + SolveBlockVisitor().visit_program(*ast); + + /// create C code generation visitor + auto cv = std::make_shared("_test", ss, "double", false); + cv->setup(*ast); + return cv; +} + + +/// print entire code +std::string get_neuron_cpp_code(const std::string& nmodl_text, + const bool generate_gpu_code = false) { + const auto& ast = NmodlDriver().parse_string(nmodl_text); + std::stringstream ss; + auto cvisitor = create_neuron_cpp_visitor(ast, nmodl_text, ss); + cvisitor->visit_program(*ast); + return reindent_text(ss.str()); +} + + +SCENARIO("Check NEURON codegen for simple MOD file", "[codegen][neuron_boilerplate]") { + GIVEN("A simple mod file with RANGE, ARRAY and ION variables") { + std::string const nmodl_text = R"( + TITLE unit test based on passive membrane channel + + UNITS { + (mV) = (millivolt) + (mA) = (milliamp) + (S) = (siemens) + } + + NEURON { + SUFFIX pas_test 
+ USEION na READ ena WRITE ina + NONSPECIFIC_CURRENT i + RANGE g, e + RANGE ar + } + + PARAMETER { + g = .001 (S/cm2) <0,1e9> + e = -70 (mV) + } + + ASSIGNED { + v (mV) + i (mA/cm2) + ena (mV) + ina (mA/cm2) + ar[2] + } + + INITIAL { + ar[0] = 1 + } + + BREAKPOINT { + SOLVE states METHOD cnexp + i = g*(v - e) + ina = g*(v - ena) + } + + STATE { + s + } + + DERIVATIVE states { + s' = ar[0] + } + )"; + auto const reindent_and_trim_text = [](const auto& text) { + return reindent_text(stringutils::trim(text)); + }; + auto const generated = reindent_and_trim_text(get_neuron_cpp_code(nmodl_text)); + THEN("Correct includes are printed") { + std::string expected_includes = R"(#include +#include +#include + +#include "mech_api.h" +#include "neuron/cache/mechanism_range.hpp" +#include "nrniv_mf.h" +#include "section_fwd.hpp")"; + + REQUIRE_THAT(generated, ContainsSubstring(reindent_and_trim_text(expected_includes))); + } + THEN("Correct number of variables are printed") { + std::string expected_num_variables = + R"(static constexpr auto number_of_datum_variables = 3; +static constexpr auto number_of_floating_point_variables = 10;)"; + + REQUIRE_THAT(generated, + ContainsSubstring(reindent_and_trim_text(expected_num_variables))); + } + THEN("Correct using-directives are printed ") { + std::string expected_using_directives = R"(namespace { +template +using _nrn_mechanism_std_vector = std::vector; +using _nrn_model_sorted_token = neuron::model_sorted_token; +using _nrn_mechanism_cache_range = neuron::cache::MechanismRange; +using _nrn_mechanism_cache_instance = neuron::cache::MechanismInstance; +template +using _nrn_mechanism_field = neuron::mechanism::field; +template +void _nrn_mechanism_register_data_fields(Args&&... args) { + neuron::mechanism::register_data_fields(std::forward(args)...); +} +} // namespace)"; + + REQUIRE_THAT(generated, + ContainsSubstring(reindent_and_trim_text(expected_using_directives))); + } + THEN("Correct namespace is printed") { + std::string expected_namespace = R"(namespace neuron {)"; + + REQUIRE_THAT(generated, ContainsSubstring(reindent_and_trim_text(expected_namespace))); + } + THEN("Correct channel information are printed") { + std::string expected_channel_info = R"(/** channel information */ + static const char *mechanism[] = { + "6.2.0", + "pas_test", + "g_pas_test", + "e_pas_test", + 0, + "i_pas_test", + "ar_pas_test[2]", + 0, + "s_pas_test", + 0, + 0 + };)"; + + REQUIRE_THAT(generated, + ContainsSubstring(reindent_and_trim_text(expected_channel_info))); + } + THEN("Correct global variables are printed") { + std::string expected_global_variables = + R"(static neuron::container::field_index _slist1[1], _dlist1[1];)"; + + REQUIRE_THAT(generated, + ContainsSubstring(reindent_and_trim_text(expected_global_variables))); + } + THEN("Correct range variables' macros are printed") { + std::string expected_range_macros = R"(/* NEURON RANGE variables macro definitions */ + #define g(id) _ml->template fpfield<0>(id) + #define e(id) _ml->template fpfield<1>(id) + #define i(id) _ml->template fpfield<2>(id) + #define ar(id) _ml->template data_array<3, 2>(id) + #define s(id) _ml->template fpfield<4>(id) + #define ena(id) _ml->template fpfield<5>(id) + #define ina(id) _ml->template fpfield<6>(id) + #define Ds(id) _ml->template fpfield<7>(id) + #define v_unused(id) _ml->template fpfield<8>(id) + #define g_unused(id) _ml->template fpfield<9>(id))"; + + REQUIRE_THAT(generated, + ContainsSubstring(reindent_and_trim_text(expected_range_macros))); + } + THEN("Correct HOC global 
variables are printed") { + std::string expected_hoc_global_variables = + R"(/** connect global (scalar) variables to hoc -- */ + static DoubScal hoc_scalar_double[] = { + {nullptr, nullptr} + }; + + + /** connect global (array) variables to hoc -- */ + static DoubVec hoc_vector_double[] = { + {nullptr, nullptr, 0} + };)"; + + REQUIRE_THAT(generated, + ContainsSubstring(reindent_and_trim_text(expected_hoc_global_variables))); + } + THEN("Placeholder nrn_cur function is printed") { + std::string expected_placeholder_nrn_cur = R"(void nrn_cur() {})"; + + REQUIRE_THAT(generated, + ContainsSubstring(reindent_and_trim_text(expected_placeholder_nrn_cur))); + } + THEN("Placeholder nrn_state function is printed") { + std::string expected_placeholder_nrn_state = R"(void nrn_state() {})"; + + REQUIRE_THAT(generated, + ContainsSubstring(reindent_and_trim_text(expected_placeholder_nrn_state))); + } + THEN("Placeholder registration function is printed") { + std::string expected_placeholder_reg = R"(/** register channel with the simulator */ + void __test_reg() { + /* s */ + _slist1[0] = {4, 0} + /* Ds */ + _dlist1[0] = {7, 0} + + int mech_type = nrn_get_mechtype("pas_test"); + _nrn_mechanism_register_data_fields(_mechtype, + _nrn_mechanism_field{"g"} /* 0 */, + _nrn_mechanism_field{"e"} /* 1 */, + _nrn_mechanism_field{"i"} /* 2 */, + _nrn_mechanism_field{"ar", 2} /* 3 */, + _nrn_mechanism_field{"s"} /* 4 */, + _nrn_mechanism_field{"ena"} /* 5 */, + _nrn_mechanism_field{"ina"} /* 6 */, + _nrn_mechanism_field{"Ds"} /* 7 */, + _nrn_mechanism_field{"v_unused"} /* 8 */, + _nrn_mechanism_field{"g_unused"} /* 9 */ + ); + + })"; + + REQUIRE_THAT(generated, + ContainsSubstring(reindent_and_trim_text(expected_placeholder_reg))); + } + } +} diff --git a/test/unit/visitor/sympy_solver.cpp b/test/unit/visitor/sympy_solver.cpp index 18fb307b39..456f897564 100644 --- a/test/unit/visitor/sympy_solver.cpp +++ b/test/unit/visitor/sympy_solver.cpp @@ -9,7 +9,7 @@ #include #include "ast/program.hpp" -#include "codegen/codegen_cpp_visitor.hpp" +#include "codegen/codegen_coreneuron_cpp_visitor.hpp" #include "parser/nmodl_driver.hpp" #include "test/unit/utils/test_utils.hpp" #include "visitors/checkparent_visitor.hpp" @@ -2237,12 +2237,13 @@ SCENARIO("Solve KINETIC block using SympySolver Visitor", "[visitor][solver][sym } /// Helper for creating C codegen visitor -std::shared_ptr create_c_visitor(const std::shared_ptr& ast, - const std::string& /* text */, - std::stringstream& ss, - bool inline_visitor = true, - bool pade = false, - bool cse = false) { +std::shared_ptr create_coreneuron_cpp_visitor( + const std::shared_ptr& ast, + const std::string& /* text */, + std::stringstream& ss, + bool inline_visitor = true, + bool pade = false, + bool cse = false) { /// construct symbol table SymtabVisitor().visit_program(*ast); @@ -2264,14 +2265,14 @@ std::shared_ptr create_c_visitor(const std::shared_ptr("temp.mod", ss, "double", false); + auto cv = std::make_shared("temp.mod", ss, "double", false); cv->setup(*ast); return cv; } @@ -2280,7 +2281,7 @@ std::shared_ptr create_c_visitor(const std::shared_ptrvisit_program(*ast); auto generated_string = ss.str(); return reindent_text(generated_string);