Skip to content

Commit

Permalink
Cleanup of source model function documentation.
Browse files Browse the repository at this point in the history
There were a few typos and edits that needed to be made to tidy this up.

PiperOrigin-RevId: 709374546
  • Loading branch information
Nush395 authored and Torax team committed Dec 24, 2024
1 parent ba3ae6d commit dcd279b
Showing 1 changed file with 43 additions and 31 deletions.
74 changes: 43 additions & 31 deletions docs/model_integration.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,21 @@ Adding a source model implementation
TORAX sources are located in |torax.sources|_ and described in
:ref:`structure-sources`.

TORAX sources can be run in 3 modes (see :ref:`configuration`) for more details:
TORAX sources can be run in 3 modes (see :ref:`configuration` for more details):

* ZERO
* PRESCRIBED
* MODEL

In model_based mode, the source uses a given model to generate profile data.
All TORAX sources come with a default model.
(see the |torax.sources|_ module or the API docs for the complete list)
In MODEL mode, the source uses a given model to generate profile data.
All TORAX sources come with a default model
(see the |torax.sources|_ module or the API docs for the complete list).

TORAX provides support for using a custom model for a given source. If you want
to use a different model, you can do so by registering a new model
to use a custom model, you can do so by registering a new model
implementation against one of the sources supported by TORAX.

Below we describe how to do this with an example. To do so, you must:
Below we describe how to do this with an example. In short, to do so you must:

* Create a model function that follows the ``SourceProfileFunction`` interface.

Expand All @@ -34,20 +34,20 @@ Below we describe how to do this with an example. To do so, you must:
class SourceProfileFunction(Protocol):
"""Sources implement these functions to be able to provide source profiles."""
def __call__(
self,
static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
geo: geometry.Geometry,
source_name: str,
core_profiles: state.CoreProfiles,
source_models: Optional['source_models.SourceModels'],
) -> chex.ArrayTree:
...
def __call__(
self,
static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
geo: geometry.Geometry,
source_name: str,
core_profiles: state.CoreProfiles,
source_models: Optional['source_models.SourceModels'],
) -> chex.ArrayTree:
...
* (Optionally create runtime parameter configuration for the model function.)

* Register the model function any new runtime parameter configuration.
* Register the model function (and any new runtime parameter configuration).

Once the above is done, you can use the new model in your TORAX run by
specifying the name of the model in the config dictionary alongside any new
Expand All @@ -56,9 +56,9 @@ runtime parameter configuration.
Example
=======

For example if we wanted to register a new model implementation for the
``ion_cyclotron_source`` in TORAX which has an additional dynamic runtime
parameter ``my_new_param``:
Here is an example if we wanted to register a new model implementation for the
``IonCyclotronSource`` in TORAX which requires an additional dynamic runtime
parameter ``my_new_param``.

Defining a new source model implementation
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Expand All @@ -69,32 +69,40 @@ parameters format and what is supported by TORAX.)

.. code-block:: python
import chex
import dataclasses
from torax import array_typing
from torax import interpolated_param
from torax import state
from torax.config import runtime_params_slice
from torax.geometry import geometry
from torax.sources import runtime_params as runtime_params_lib
# This inherits from the default source runtime parameters.
@dataclasses.dataclass
class CustomRuntimeParams(runtime_params_lib.RuntimeParams):
# Custom time interpolated parameter.
my_new_param: runtime_params_lib.TimeInterpolatedInput = 1.0
def make_provider(
self,
torax_mesh: geometry.Grid1D | None = None,
) -> RuntimeParamsProvider:
return RuntimeParamsProvider(**self.get_provider_kwargs(torax_mesh))
def make_provider(
self,
torax_mesh: geometry.Grid1D | None = None,
) -> RuntimeParamsProvider:
return RuntimeParamsProvider(**self.get_provider_kwargs(torax_mesh))
@chex.dataclass
class CustomRuntimeParamsProvider(runtime_params_lib.RuntimeParamsProvider):
"""Provides runtime parameters for a given time and geometry."""
runtime_params_config: GenericParticleSourceRuntimeParams
runtime_params_config: CustomRuntimeParams
my_new_param: interpolated_param.InterpolatedVarSingleAxis
def build_dynamic_params(
self,
t: chex.Numeric,
) -> DynamicParticleRuntimeParams:
return DynamicRuntimeParams(
) -> CustomDynamicRuntimeParams:
return CustomDynamicRuntimeParams(
my_new_param=self.my_new_param.get_value(t),
prescribed_values=self.prescribed_values.get_value(t),
)
Expand All @@ -104,6 +112,7 @@ parameters format and what is supported by TORAX.)
class CustomDynamicRuntimeParams(runtime_params_lib.DynamicRuntimeParams):
my_new_param: array_typing.ScalarFloat
# Define a custom model function.
def my_new_model(
self,
static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
Expand All @@ -130,18 +139,21 @@ Then we register the new model function and runtime parameters.
# This method must be called to register the source before starting your TORAX
# run so that the new model is discoverable to TORAX.
register_sources.register_model_function(
# Matches IonCyclotronSource.SOURCE_NAME.
source_name='ion_cyclotron_source',
# The model function name is arbitrary, but must be unique for a source.
# It is used to identify the model function for a given source by TORAX.
# We follow the convention of using the name of the model function as the
# model name but you can use any string here.
model_function_name='my_new_model',
# Reference to the model function we just defined.
model_function=my_new_model,
runtime_params_class=RuntimeParams,
# Reference to the runtime parameters class we just defined.
runtime_params_class=CustomRuntimeParams,
)
If you don't have any custom runtime parameters, you can simply omit the
`runtime_params_class` argument and then default source runtime parameters
``runtime_params_class`` argument and then default source runtime parameters
will be used.

Using a new source model implementation
Expand All @@ -163,7 +175,7 @@ it as you would the existing source model implementations.
Adding a transport model
------------------------
************************
TORAX transport models are located in |torax.transport_model|_ and described
in :ref:`structure-transport-model`. TORAX comes with several transport models
Expand Down

0 comments on commit dcd279b

Please sign in to comment.