Skip to content

Commit

Permalink
Add documentation on how to add new source model functions.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 708329584
  • Loading branch information
Nush395 authored and Torax team committed Dec 24, 2024
1 parent 22ce670 commit c569e6c
Showing 1 changed file with 153 additions and 7 deletions.
160 changes: 153 additions & 7 deletions docs/model_integration.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,162 @@ How to integrate new models

This page shows how to extend TORAX with new models.

Adding a source model
---------------------
Adding a source model implementation
************************************

TORAX sources are located in |torax.sources|_ and described in
:ref:`structure-sources`. While TORAX comes with several default sources
(see the |torax.sources|_ module or the API docs for the complete list),
users can both (a) configure the existing sources with custom models and
(b) add new source models to TORAX.
:ref:`structure-sources`.

TORAX sources can be run in 3 modes (see :ref:`configuration`) for more details:

* ZERO
* PRESCRIBED
* MODEL

In model_based mode, the source uses a given model to generate profile data.
All TORAX sources come with a default model.
(see the |torax.sources|_ module or the API docs for the complete list)

TORAX provides support for using a custom model for a given source. If you want
to use a different model, you can do so by registering a new model
implementation against one of the sources supported by TORAX.

Below we describe how to do this with an example. To do so, you must:

* Create a model function that follows the ``SourceProfileFunction`` interface.

.. code-block:: python
class SourceProfileFunction(Protocol):
"""Sources implement these functions to be able to provide source profiles."""
def __call__(
self,
static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
geo: geometry.Geometry,
source_name: str,
core_profiles: state.CoreProfiles,
source_models: Optional['source_models.SourceModels'],
) -> chex.ArrayTree:
...
* (Optionally create runtime parameter configuration for the model function.)

* Register the model function any new runtime parameter configuration.

Once the above is done, you can use the new model in your TORAX run by
specifying the name of the model in the config dictionary alongside any new
runtime parameter configuration.

Example
=======

For example if we wanted to register a new model implementation for the
``ion_cyclotron_source`` in TORAX which has an additional dynamic runtime
parameter ``my_new_param``:

Defining a new source model implementation
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

First, we define a new model implementation (and optionally a new
``RuntimeParams``). (See :ref:`configuration` for more details on the runtime
parameters format and what is supported by TORAX.)

.. code-block:: python
from torax.sources import runtime_params as runtime_params_lib
# This inherits from the default source runtime parameters.
class CustomRuntimeParams(runtime_params_lib.RuntimeParams):
# Custom time interpolated parameter.
my_new_param: runtime_params_lib.TimeInterpolatedInput = 1.0
def make_provider(
self,
torax_mesh: geometry.Grid1D | None = None,
) -> RuntimeParamsProvider:
return RuntimeParamsProvider(**self.get_provider_kwargs(torax_mesh))
@chex.dataclass
class CustomRuntimeParamsProvider(runtime_params_lib.RuntimeParamsProvider):
"""Provides runtime parameters for a given time and geometry."""
runtime_params_config: GenericParticleSourceRuntimeParams
my_new_param: interpolated_param.InterpolatedVarSingleAxis
def build_dynamic_params(
self,
t: chex.Numeric,
) -> DynamicParticleRuntimeParams:
return DynamicRuntimeParams(
my_new_param=self.my_new_param.get_value(t),
prescribed_values=self.prescribed_values.get_value(t),
)
@chex.dataclass(frozen=True)
class CustomDynamicRuntimeParams(runtime_params_lib.DynamicRuntimeParams):
my_new_param: array_typing.ScalarFloat
def my_new_model(
self,
static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
geo: geometry.Geometry,
source_name: str,
core_profiles: state.CoreProfiles,
source_models: Optional['source_models.SourceModels'],
) -> chex.ArrayTree:
# To access the new runtime parameter, we index into the dynamic runtime
# params slice.
dynamic_source_params = dynamic_runtime_params_slice.sources['ion_cyclotron_source']
# Check the dynamic runtime params are the custom type we just defined.
assert isinstance(dynamic_source_params, CustomDynamicRuntimeParams)
my_new_param = dynamic_source_params.my_new_param
...
Then we register the new model function and runtime parameters.

.. code-block:: python
from torax.sources import register_sources
# This method must be called to register the source before starting your TORAX
# run so that the new model is discoverable to TORAX.
register_sources.register_model_function(
source_name='ion_cyclotron_source',
# The model function name is arbitrary, but must be unique for a source.
# It is used to identify the model function for a given source by TORAX.
# We follow the convention of using the name of the model function as the
# model name but you can use any string here.
model_function_name='my_new_model',
model_function=my_new_model,
runtime_params_class=RuntimeParams,
)
If you don't have any custom runtime parameters, you can simply omit the
`runtime_params_class` argument and then default source runtime parameters
will be used.

Using a new source model implementation
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Once you have created and registered the new model function you can then use
it as you would the existing source model implementations.

.. code-block:: python
CONFIG = {
'sources': {
'ion_cyclotron_source': {
'mode': 'model', # use the source in model mode.
'model_func': 'my_new_model', # matches name of registered model function.
'my_new_param': 2.0, # must match name of the runtime parameter.
},
...
}
Instructions for how to do this are under construction.
Adding a transport model
------------------------
Expand Down

0 comments on commit c569e6c

Please sign in to comment.