Skip to content

Commit

Permalink
Allow Sources to accept multiple model functions.
Browse files Browse the repository at this point in the history
This modifies the sources registry to register a dict of model functions. These will be tied to runtime params and either autocreate a builder (or a given source builder class can be provided that will be used)

Follow up changes to:
- move all `function`s to be used through `model_func`.
- remove `function` and `affected_core_profiles`.
- add a `Protocol` for `SourceProfileFunctions` that will define the interface.
- open the registry to accept new model functions.

PiperOrigin-RevId: 706769008
  • Loading branch information
Nush395 authored and Torax team committed Dec 24, 2024
1 parent c835d07 commit 18f5d0a
Show file tree
Hide file tree
Showing 32 changed files with 309 additions and 209 deletions.
1 change: 1 addition & 0 deletions torax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def set_jax_precision():
'source_profiles',
'source_profile',
'explicit_source_profiles',
'model_func',
'source_models',
'pedestal_model',
'time_step_calculator',
Expand Down
25 changes: 22 additions & 3 deletions torax/config/build_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,8 +410,18 @@ def _build_single_source_builder_from_config(
source_config: dict[str, Any],
) -> source_lib.SourceBuilderProtocol:
"""Builds a source builder from the input config."""
registered_source = register_source.get_registered_source(source_name)
runtime_params = registered_source.default_runtime_params_class()
supported_source = register_source.get_supported_source(source_name)
if 'model_func' in source_config:
# If the user has specified a model function, try to retrive that from the
# registered source model functions.
model_func = source_config.pop('model_func')
model_function = supported_source.model_functions[model_func]
else:
# Otherwise, use the default model function.
model_function = supported_source.model_functions[
supported_source.source_class.DEFAULT_MODEL_FUNCTION_NAME
]
runtime_params = model_function.runtime_params_class()
# Update the defaults with the config provided.
source_config = copy.copy(source_config)
if 'mode' in source_config:
Expand Down Expand Up @@ -445,7 +455,16 @@ def _build_single_source_builder_from_config(
if formula is not None:
kwargs['formula'] = formula

return registered_source.source_builder_class(**kwargs)
source_builder_class = model_function.source_builder_class
if source_builder_class is None:
source_builder_class = source_lib.make_source_builder(
supported_source.source_class,
runtime_params_type=model_function.runtime_params_class,
links_back=model_function.links_back,
model_func=model_function.source_profile_function,
)

return source_builder_class(**kwargs)


def build_transport_model_builder_from_config(
Expand Down
1 change: 1 addition & 0 deletions torax/sources/bootstrap_current_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ class BootstrapCurrentSource(source.Source):
"""

SOURCE_NAME: ClassVar[str] = 'j_bootstrap'
DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'calc_neoclassical'

@property
def source_name(self) -> str:
Expand Down
1 change: 1 addition & 0 deletions torax/sources/bremsstrahlung_heat_sink.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ class BremsstrahlungHeatSink(source.Source):
"""Brehmsstrahlung heat sink for electron heat equation."""

SOURCE_NAME: ClassVar[str] = 'bremsstrahlung_heat_sink'
DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'bremsstrahlung_model_func'
model_func: source.SourceProfileFunction = bremsstrahlung_model_func

@property
Expand Down
9 changes: 5 additions & 4 deletions torax/sources/electron_cyclotron_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,13 @@ class DynamicRuntimeParams(runtime_params_lib.DynamicRuntimeParams):
gaussian_ec_total_power: array_typing.ScalarFloat


def _calc_heating_and_current(
def calc_heating_and_current(
static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
geo: geometry.Geometry,
source_name: str,
core_profiles: state.CoreProfiles,
unused_model_func: source_models.SourceModels,
unused_source_models: source_models.SourceModels | None = None,
) -> jax.Array:
"""Model function for the electron-cyclotron source.
Expand All @@ -128,7 +128,7 @@ def _calc_heating_and_current(
2D array of electron cyclotron heating power density and current density.
"""
del (
unused_model_func,
unused_source_models,
static_runtime_params_slice,
) # Unused.
dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[
Expand Down Expand Up @@ -187,7 +187,8 @@ class ElectronCyclotronSource(source.Source):
"""Electron cyclotron source for the Te and Psi equations."""

SOURCE_NAME: ClassVar[str] = "electron_cyclotron_source"
model_func: source.SourceProfileFunction = _calc_heating_and_current
DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = "calc_heating_and_current"
model_func: source.SourceProfileFunction = calc_heating_and_current

@property
def source_name(self) -> str:
Expand Down
18 changes: 12 additions & 6 deletions torax/sources/electron_density_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class DynamicGasPuffRuntimeParams(runtime_params_lib.DynamicRuntimeParams):


# Default formula: exponential with nref normalization.
def _calc_puff_source(
def calc_puff_source(
static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
geo: geometry.Geometry,
Expand Down Expand Up @@ -109,7 +109,9 @@ class GasPuffSource(source.Source):
"""Gas puff source for the ne equation."""

SOURCE_NAME: ClassVar[str] = 'gas_puff_source'
formula: source.SourceProfileFunction = _calc_puff_source
DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'calc_puff_source'
formula: source.SourceProfileFunction = calc_puff_source
model_func: source.SourceProfileFunction = calc_puff_source

@property
def source_name(self) -> str:
Expand Down Expand Up @@ -172,7 +174,7 @@ class DynamicParticleRuntimeParams(runtime_params_lib.DynamicRuntimeParams):
S_tot: array_typing.ScalarFloat


def _calc_generic_particle_source(
def calc_generic_particle_source(
static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
geo: geometry.Geometry,
Expand Down Expand Up @@ -205,7 +207,9 @@ class GenericParticleSource(source.Source):
"""Neutral-beam injection source for the ne equation."""

SOURCE_NAME: ClassVar[str] = 'generic_particle_source'
formula: source.SourceProfileFunction = _calc_generic_particle_source
DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'calc_generic_particle_source'
formula: source.SourceProfileFunction = calc_generic_particle_source
model_func: source.SourceProfileFunction = calc_generic_particle_source

@property
def source_name(self) -> str:
Expand Down Expand Up @@ -260,7 +264,7 @@ class DynamicPelletRuntimeParams(runtime_params_lib.DynamicRuntimeParams):
S_pellet_tot: array_typing.ScalarFloat


def _calc_pellet_source(
def calc_pellet_source(
static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
geo: geometry.Geometry,
Expand Down Expand Up @@ -293,7 +297,9 @@ class PelletSource(source.Source):
"""Pellet source for the ne equation."""

SOURCE_NAME: ClassVar[str] = 'pellet_source'
formula: source.SourceProfileFunction = _calc_pellet_source
DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'calc_pellet_source'
formula: source.SourceProfileFunction = calc_pellet_source
model_func: source.SourceProfileFunction = calc_pellet_source

@property
def source_name(self) -> str:
Expand Down
1 change: 1 addition & 0 deletions torax/sources/fusion_heat_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ class FusionHeatSource(source.Source):
"""Fusion heat source for both ion and electron heat."""

SOURCE_NAME: ClassVar[str] = 'fusion_heat_source'
DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'fusion_heat_model_func'
model_func: source.SourceProfileFunction = fusion_heat_model_func

@property
Expand Down
6 changes: 4 additions & 2 deletions torax/sources/generic_current_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def __post_init__(self):

# pytype bug: does not treat 'source_models.SourceModels' as a forward reference
# pytype: disable=name-error
def _calculate_generic_current_face(
def calculate_generic_current_face(
static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
geo: geometry.Geometry,
Expand Down Expand Up @@ -235,8 +235,10 @@ class GenericCurrentSource(source.Source):
"""A generic current density source profile."""

SOURCE_NAME: ClassVar[str] = 'generic_current_source'
formula: source.SourceProfileFunction = _calculate_generic_current_face
DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'calc_generic_current_face'
formula: source.SourceProfileFunction = calculate_generic_current_face
hires_formula: source.SourceProfileFunction = _calculate_generic_current_hires
model_func: source.SourceProfileFunction = calculate_generic_current_face

@property
def source_name(self) -> str:
Expand Down
6 changes: 4 additions & 2 deletions torax/sources/generic_ion_el_heat_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def calc_generic_heat_source(


# pytype: disable=name-error
def _default_formula(
def default_formula(
static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
geo: geometry.Geometry,
Expand Down Expand Up @@ -150,7 +150,9 @@ class GenericIonElectronHeatSource(source.Source):
"""Generic heat source for both ion and electron heat."""

SOURCE_NAME: ClassVar[str] = 'generic_ion_el_heat_source'
formula: source.SourceProfileFunction = _default_formula
DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'default_formula'
formula: source.SourceProfileFunction = default_formula
model_func: source.SourceProfileFunction = default_formula

@property
def source_name(self) -> str:
Expand Down
6 changes: 5 additions & 1 deletion torax/sources/impurity_radiation_heat_sink.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"""Basic impurity radiation heat sink for electron heat equation.."""

import dataclasses
from typing import ClassVar

import chex
import jax
Expand Down Expand Up @@ -139,10 +140,13 @@ class ImpurityRadiationHeatSink(source_lib.Source):
"""Impurity radiation heat sink for electron heat equation."""

SOURCE_NAME = "impurity_radiation_heat_sink"
source_models: source_models_lib.SourceModels
DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = (
"radially_constant_fraction_of_Pin"
)
model_func: source_lib.SourceProfileFunction = (
radially_constant_fraction_of_Pin
)
source_models: source_models_lib.SourceModels

@property
def source_name(self) -> str:
Expand Down
20 changes: 12 additions & 8 deletions torax/sources/ion_cyclotron_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ def _helium3_tail_temperature(
return core_profiles.temp_el.value * (1 + epsilon)


def _icrh_model_func(
def icrh_model_func(
static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
geo: geometry.Geometry,
Expand Down Expand Up @@ -485,8 +485,6 @@ def _icrh_model_func(
source_ion += power_deposition_2T * dynamic_source_runtime_params.Ptot

return jnp.stack([source_ion, source_el])


# pylint: enable=invalid-name


Expand All @@ -495,6 +493,7 @@ class IonCyclotronSource(source.Source):
"""Ion cyclotron source with surrogate model."""

SOURCE_NAME: ClassVar[str] = 'ion_cyclotron_source'
DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'icrh_model_func'

@property
def source_name(self) -> str:
Expand Down Expand Up @@ -528,13 +527,18 @@ class IonCyclotronSourceBuilder:
default_factory=RuntimeParams
)
links_back: bool = False
model_func: source.SourceProfileFunction | None = None

def __post_init__(self):
if self.model_func is None:
self.model_func = functools.partial(
icrh_model_func,
toric_nn=ToricNNWrapper(),
)

def __call__(
self,
formula: source.SourceProfileFunction | None = None,
) -> IonCyclotronSource:
model_func: source.SourceProfileFunction = functools.partial(
_icrh_model_func,
toric_nn=ToricNNWrapper(),
)
return IonCyclotronSource(formula=formula, model_func=model_func,)

return IonCyclotronSource(formula=formula, model_func=self.model_func,)
13 changes: 3 additions & 10 deletions torax/sources/ohmic_heat_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def ohmic_model_func(
geo: geometry.Geometry,
source_name: str,
core_profiles: state.CoreProfiles,
source_models: source_models_lib.SourceModels | None = None,
source_models: source_models_lib.SourceModels,
) -> jax.Array:
"""Returns the Ohmic source for electron heat equation."""
del source_name # Unused.
Expand Down Expand Up @@ -190,17 +190,10 @@ class OhmicHeatSource(source_lib.Source):
"""

SOURCE_NAME: ClassVar[str] = 'ohmic_heat_source'
DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'ohmic_model_func'
model_func: source_lib.SourceProfileFunction = ohmic_model_func
# Users must pass in a pointer to the complete set of sources to this object.
source_models: source_models_lib.SourceModels
# The model function is fixed to ohmic_model_func because that is the only
# supported implementation of this source.
# However, since this is a param in the parent dataclass, we need to (a)
# remove the parameter from the init args and (b) set the default to the
# desired value.
model_func: source_lib.SourceProfileFunction | None = dataclasses.field(
init=False,
default_factory=lambda: ohmic_model_func,
)

@property
def source_name(self) -> str:
Expand Down
1 change: 1 addition & 0 deletions torax/sources/qei_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class QeiSource(source.Source):
"""

SOURCE_NAME: ClassVar[str] = 'qei_source'
DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'model_based_qei'

@property
def source_name(self) -> str:
Expand Down
Loading

0 comments on commit 18f5d0a

Please sign in to comment.