diff --git a/docs/model_integration.rst b/docs/model_integration.rst index 983de589..768e59f3 100644 --- a/docs/model_integration.rst +++ b/docs/model_integration.rst @@ -11,21 +11,21 @@ Adding a source model implementation TORAX sources are located in |torax.sources|_ and described in :ref:`structure-sources`. -TORAX sources can be run in 3 modes (see :ref:`configuration`) for more details: +TORAX sources can be run in 3 modes (see :ref:`configuration` for more details): * ZERO * PRESCRIBED * MODEL -In model_based mode, the source uses a given model to generate profile data. -All TORAX sources come with a default model. -(see the |torax.sources|_ module or the API docs for the complete list) +In MODEL mode, the source uses a given model to generate profile data. +All TORAX sources come with a default model +(see the |torax.sources|_ module or the API docs for the complete list). TORAX provides support for using a custom model for a given source. If you want -to use a different model, you can do so by registering a new model +to use a custom model, you can do so by registering a new model implementation against one of the sources supported by TORAX. -Below we describe how to do this with an example. To do so, you must: +Below we describe how to do this with an example. In short, to do so you must: * Create a model function that follows the ``SourceProfileFunction`` interface. @@ -34,20 +34,20 @@ Below we describe how to do this with an example. To do so, you must: class SourceProfileFunction(Protocol): """Sources implement these functions to be able to provide source profiles.""" - def __call__( - self, - static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, - dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, - geo: geometry.Geometry, - source_name: str, - core_profiles: state.CoreProfiles, - source_models: Optional['source_models.SourceModels'], - ) -> chex.ArrayTree: - ... + def __call__( + self, + static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, + geo: geometry.Geometry, + source_name: str, + core_profiles: state.CoreProfiles, + source_models: Optional['source_models.SourceModels'], + ) -> chex.ArrayTree: + ... * (Optionally create runtime parameter configuration for the model function.) -* Register the model function any new runtime parameter configuration. +* Register the model function (and any new runtime parameter configuration). Once the above is done, you can use the new model in your TORAX run by specifying the name of the model in the config dictionary alongside any new @@ -56,9 +56,9 @@ runtime parameter configuration. Example ======= -For example if we wanted to register a new model implementation for the -``ion_cyclotron_source`` in TORAX which has an additional dynamic runtime -parameter ``my_new_param``: +Here is an example if we wanted to register a new model implementation for the +``IonCyclotronSource`` in TORAX which requires an additional dynamic runtime +parameter ``my_new_param``. Defining a new source model implementation ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -69,32 +69,40 @@ parameters format and what is supported by TORAX.) .. code-block:: python + import chex + import dataclasses + from torax import array_typing + from torax import interpolated_param + from torax import state + from torax.config import runtime_params_slice + from torax.geometry import geometry from torax.sources import runtime_params as runtime_params_lib # This inherits from the default source runtime parameters. + @dataclasses.dataclass class CustomRuntimeParams(runtime_params_lib.RuntimeParams): # Custom time interpolated parameter. my_new_param: runtime_params_lib.TimeInterpolatedInput = 1.0 - def make_provider( - self, - torax_mesh: geometry.Grid1D | None = None, - ) -> RuntimeParamsProvider: - return RuntimeParamsProvider(**self.get_provider_kwargs(torax_mesh)) + def make_provider( + self, + torax_mesh: geometry.Grid1D | None = None, + ) -> RuntimeParamsProvider: + return RuntimeParamsProvider(**self.get_provider_kwargs(torax_mesh)) @chex.dataclass class CustomRuntimeParamsProvider(runtime_params_lib.RuntimeParamsProvider): """Provides runtime parameters for a given time and geometry.""" - runtime_params_config: GenericParticleSourceRuntimeParams + runtime_params_config: CustomRuntimeParams my_new_param: interpolated_param.InterpolatedVarSingleAxis def build_dynamic_params( self, t: chex.Numeric, - ) -> DynamicParticleRuntimeParams: - return DynamicRuntimeParams( + ) -> CustomDynamicRuntimeParams: + return CustomDynamicRuntimeParams( my_new_param=self.my_new_param.get_value(t), prescribed_values=self.prescribed_values.get_value(t), ) @@ -104,6 +112,7 @@ parameters format and what is supported by TORAX.) class CustomDynamicRuntimeParams(runtime_params_lib.DynamicRuntimeParams): my_new_param: array_typing.ScalarFloat + # Define a custom model function. def my_new_model( self, static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, @@ -130,18 +139,21 @@ Then we register the new model function and runtime parameters. # This method must be called to register the source before starting your TORAX # run so that the new model is discoverable to TORAX. register_sources.register_model_function( + # Matches IonCyclotronSource.SOURCE_NAME. source_name='ion_cyclotron_source', # The model function name is arbitrary, but must be unique for a source. # It is used to identify the model function for a given source by TORAX. # We follow the convention of using the name of the model function as the # model name but you can use any string here. model_function_name='my_new_model', + # Reference to the model function we just defined. model_function=my_new_model, - runtime_params_class=RuntimeParams, + # Reference to the runtime parameters class we just defined. + runtime_params_class=CustomRuntimeParams, ) If you don't have any custom runtime parameters, you can simply omit the -`runtime_params_class` argument and then default source runtime parameters +``runtime_params_class`` argument and then default source runtime parameters will be used. Using a new source model implementation @@ -163,7 +175,7 @@ it as you would the existing source model implementations. Adding a transport model ------------------------- +************************ TORAX transport models are located in |torax.transport_model|_ and described in :ref:`structure-transport-model`. TORAX comes with several transport models