Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add documentation on how to add new source model functions. #627

Merged
merged 1 commit into from
Dec 24, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 153 additions & 7 deletions docs/model_integration.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,162 @@ How to integrate new models

This page shows how to extend TORAX with new models.

Adding a source model
---------------------
Adding a source model implementation
************************************

TORAX sources are located in |torax.sources|_ and described in
:ref:`structure-sources`. While TORAX comes with several default sources
(see the |torax.sources|_ module or the API docs for the complete list),
users can both (a) configure the existing sources with custom models and
(b) add new source models to TORAX.
:ref:`structure-sources`.

TORAX sources can be run in 3 modes (see :ref:`configuration`) for more details:

* ZERO
* PRESCRIBED
* MODEL

In model_based mode, the source uses a given model to generate profile data.
All TORAX sources come with a default model.
(see the |torax.sources|_ module or the API docs for the complete list)

TORAX provides support for using a custom model for a given source. If you want
to use a different model, you can do so by registering a new model
implementation against one of the sources supported by TORAX.

Below we describe how to do this with an example. To do so, you must:

* Create a model function that follows the ``SourceProfileFunction`` interface.

.. code-block:: python

class SourceProfileFunction(Protocol):
"""Sources implement these functions to be able to provide source profiles."""

def __call__(
self,
static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
geo: geometry.Geometry,
source_name: str,
core_profiles: state.CoreProfiles,
source_models: Optional['source_models.SourceModels'],
) -> chex.ArrayTree:
...

* (Optionally create runtime parameter configuration for the model function.)

* Register the model function any new runtime parameter configuration.

Once the above is done, you can use the new model in your TORAX run by
specifying the name of the model in the config dictionary alongside any new
runtime parameter configuration.

Example
=======

For example if we wanted to register a new model implementation for the
``ion_cyclotron_source`` in TORAX which has an additional dynamic runtime
parameter ``my_new_param``:

Defining a new source model implementation
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

First, we define a new model implementation (and optionally a new
``RuntimeParams``). (See :ref:`configuration` for more details on the runtime
parameters format and what is supported by TORAX.)

.. code-block:: python

from torax.sources import runtime_params as runtime_params_lib

# This inherits from the default source runtime parameters.
class CustomRuntimeParams(runtime_params_lib.RuntimeParams):
# Custom time interpolated parameter.
my_new_param: runtime_params_lib.TimeInterpolatedInput = 1.0

def make_provider(
self,
torax_mesh: geometry.Grid1D | None = None,
) -> RuntimeParamsProvider:
return RuntimeParamsProvider(**self.get_provider_kwargs(torax_mesh))


@chex.dataclass
class CustomRuntimeParamsProvider(runtime_params_lib.RuntimeParamsProvider):
"""Provides runtime parameters for a given time and geometry."""

runtime_params_config: GenericParticleSourceRuntimeParams
my_new_param: interpolated_param.InterpolatedVarSingleAxis

def build_dynamic_params(
self,
t: chex.Numeric,
) -> DynamicParticleRuntimeParams:
return DynamicRuntimeParams(
my_new_param=self.my_new_param.get_value(t),
prescribed_values=self.prescribed_values.get_value(t),
)


@chex.dataclass(frozen=True)
class CustomDynamicRuntimeParams(runtime_params_lib.DynamicRuntimeParams):
my_new_param: array_typing.ScalarFloat

def my_new_model(
self,
static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
geo: geometry.Geometry,
source_name: str,
core_profiles: state.CoreProfiles,
source_models: Optional['source_models.SourceModels'],
) -> chex.ArrayTree:
# To access the new runtime parameter, we index into the dynamic runtime
# params slice.
dynamic_source_params = dynamic_runtime_params_slice.sources['ion_cyclotron_source']
# Check the dynamic runtime params are the custom type we just defined.
assert isinstance(dynamic_source_params, CustomDynamicRuntimeParams)
my_new_param = dynamic_source_params.my_new_param
...

Then we register the new model function and runtime parameters.

.. code-block:: python

from torax.sources import register_sources

# This method must be called to register the source before starting your TORAX
# run so that the new model is discoverable to TORAX.
register_sources.register_model_function(
source_name='ion_cyclotron_source',
# The model function name is arbitrary, but must be unique for a source.
# It is used to identify the model function for a given source by TORAX.
# We follow the convention of using the name of the model function as the
# model name but you can use any string here.
model_function_name='my_new_model',
model_function=my_new_model,
runtime_params_class=RuntimeParams,
)

If you don't have any custom runtime parameters, you can simply omit the
`runtime_params_class` argument and then default source runtime parameters
will be used.

Using a new source model implementation
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Once you have created and registered the new model function you can then use
it as you would the existing source model implementations.

.. code-block:: python

CONFIG = {
'sources': {
'ion_cyclotron_source': {
'mode': 'model', # use the source in model mode.
'model_func': 'my_new_model', # matches name of registered model function.
'my_new_param': 2.0, # must match name of the runtime parameter.
},
...
}

Instructions for how to do this are under construction.

Adding a transport model
------------------------
Expand Down
Loading