diff --git a/docs/model_integration.rst b/docs/model_integration.rst index 6640e8ec..983de589 100644 --- a/docs/model_integration.rst +++ b/docs/model_integration.rst @@ -5,16 +5,162 @@ How to integrate new models This page shows how to extend TORAX with new models. -Adding a source model ---------------------- +Adding a source model implementation +************************************ TORAX sources are located in |torax.sources|_ and described in -:ref:`structure-sources`. While TORAX comes with several default sources -(see the |torax.sources|_ module or the API docs for the complete list), -users can both (a) configure the existing sources with custom models and -(b) add new source models to TORAX. +:ref:`structure-sources`. + +TORAX sources can be run in 3 modes (see :ref:`configuration`) for more details: + +* ZERO +* PRESCRIBED +* MODEL + +In model_based mode, the source uses a given model to generate profile data. +All TORAX sources come with a default model. +(see the |torax.sources|_ module or the API docs for the complete list) + +TORAX provides support for using a custom model for a given source. If you want +to use a different model, you can do so by registering a new model +implementation against one of the sources supported by TORAX. + +Below we describe how to do this with an example. To do so, you must: + +* Create a model function that follows the ``SourceProfileFunction`` interface. + +.. code-block:: python + + class SourceProfileFunction(Protocol): + """Sources implement these functions to be able to provide source profiles.""" + + def __call__( + self, + static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, + geo: geometry.Geometry, + source_name: str, + core_profiles: state.CoreProfiles, + source_models: Optional['source_models.SourceModels'], + ) -> chex.ArrayTree: + ... + +* (Optionally create runtime parameter configuration for the model function.) + +* Register the model function any new runtime parameter configuration. + +Once the above is done, you can use the new model in your TORAX run by +specifying the name of the model in the config dictionary alongside any new +runtime parameter configuration. + +Example +======= + +For example if we wanted to register a new model implementation for the +``ion_cyclotron_source`` in TORAX which has an additional dynamic runtime +parameter ``my_new_param``: + +Defining a new source model implementation +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +First, we define a new model implementation (and optionally a new +``RuntimeParams``). (See :ref:`configuration` for more details on the runtime +parameters format and what is supported by TORAX.) + +.. code-block:: python + + from torax.sources import runtime_params as runtime_params_lib + + # This inherits from the default source runtime parameters. + class CustomRuntimeParams(runtime_params_lib.RuntimeParams): + # Custom time interpolated parameter. + my_new_param: runtime_params_lib.TimeInterpolatedInput = 1.0 + + def make_provider( + self, + torax_mesh: geometry.Grid1D | None = None, + ) -> RuntimeParamsProvider: + return RuntimeParamsProvider(**self.get_provider_kwargs(torax_mesh)) + + + @chex.dataclass + class CustomRuntimeParamsProvider(runtime_params_lib.RuntimeParamsProvider): + """Provides runtime parameters for a given time and geometry.""" + + runtime_params_config: GenericParticleSourceRuntimeParams + my_new_param: interpolated_param.InterpolatedVarSingleAxis + + def build_dynamic_params( + self, + t: chex.Numeric, + ) -> DynamicParticleRuntimeParams: + return DynamicRuntimeParams( + my_new_param=self.my_new_param.get_value(t), + prescribed_values=self.prescribed_values.get_value(t), + ) + + + @chex.dataclass(frozen=True) + class CustomDynamicRuntimeParams(runtime_params_lib.DynamicRuntimeParams): + my_new_param: array_typing.ScalarFloat + + def my_new_model( + self, + static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, + geo: geometry.Geometry, + source_name: str, + core_profiles: state.CoreProfiles, + source_models: Optional['source_models.SourceModels'], + ) -> chex.ArrayTree: + # To access the new runtime parameter, we index into the dynamic runtime + # params slice. + dynamic_source_params = dynamic_runtime_params_slice.sources['ion_cyclotron_source'] + # Check the dynamic runtime params are the custom type we just defined. + assert isinstance(dynamic_source_params, CustomDynamicRuntimeParams) + my_new_param = dynamic_source_params.my_new_param + ... + +Then we register the new model function and runtime parameters. + +.. code-block:: python + + from torax.sources import register_sources + + # This method must be called to register the source before starting your TORAX + # run so that the new model is discoverable to TORAX. + register_sources.register_model_function( + source_name='ion_cyclotron_source', + # The model function name is arbitrary, but must be unique for a source. + # It is used to identify the model function for a given source by TORAX. + # We follow the convention of using the name of the model function as the + # model name but you can use any string here. + model_function_name='my_new_model', + model_function=my_new_model, + runtime_params_class=RuntimeParams, + ) + +If you don't have any custom runtime parameters, you can simply omit the +`runtime_params_class` argument and then default source runtime parameters +will be used. + +Using a new source model implementation +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Once you have created and registered the new model function you can then use +it as you would the existing source model implementations. + +.. code-block:: python + + CONFIG = { + 'sources': { + 'ion_cyclotron_source': { + 'mode': 'model', # use the source in model mode. + 'model_func': 'my_new_model', # matches name of registered model function. + 'my_new_param': 2.0, # must match name of the runtime parameter. + }, + ... + } -Instructions for how to do this are under construction. Adding a transport model ------------------------