Skip to content

Commit

Permalink
Allow Sources to accept multiple model functions.
Browse files Browse the repository at this point in the history
This modifies the sources registry to register a dict of model functions. These will be tied to runtime params and either autocreate a builder (or a given source builder class can be provided that will be used)

Follow up changes to:
- move all `function`s to be used through `model_func`.
- remove `function` and `affected_core_profiles`.
- add a `Protocol` for `SourceProfileFunctions` that will define the interface.
- open the registry to accept new model functions.

PiperOrigin-RevId: 706769008
  • Loading branch information
Nush395 authored and Torax team committed Dec 18, 2024
1 parent 7ec5d8e commit af68785
Show file tree
Hide file tree
Showing 40 changed files with 563 additions and 577 deletions.
2 changes: 2 additions & 0 deletions torax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def set_jax_precision():
'geo_t',
'geo_t_plus_dt',
'geometry_provider',
'source_name',
'x_old',
'state',
'unused_state',
Expand All @@ -88,6 +89,7 @@ def set_jax_precision():
'source_profiles',
'source_profile',
'explicit_source_profiles',
'model_func',
'source_models',
'pedestal_model',
'time_step_calculator',
Expand Down
3 changes: 0 additions & 3 deletions torax/calc_coeffs.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,9 +749,6 @@ def _calc_coeffs_full(
qei = source_models.qei_source.get_qei(
static_runtime_params_slice=static_runtime_params_slice,
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[
source_models.qei_source_name
],
geo=geo,
# For Qei, always use the current set of core profiles.
# In the linear solver, core_profiles is the set of profiles at time t (at
Expand Down
17 changes: 15 additions & 2 deletions torax/config/build_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,11 @@ def _build_single_source_builder_from_config(
) -> source_lib.SourceBuilderProtocol:
"""Builds a source builder from the input config."""
registered_source = register_source.get_registered_source(source_name)
runtime_params = registered_source.default_runtime_params_class()
model_func = register_source.DEFAULT_MODEL_FUNCTION_NAME
if 'model_func' in source_config:
model_func = source_config.pop('model_func')
model_function = registered_source.model_functions[model_func]
runtime_params = model_function.runtime_params_class()
# Update the defaults with the config provided.
source_config = copy.copy(source_config)
if 'mode' in source_config:
Expand Down Expand Up @@ -444,7 +448,16 @@ def _build_single_source_builder_from_config(
if formula is not None:
kwargs['formula'] = formula

return registered_source.source_builder_class(**kwargs)
source_builder_class = model_function.source_builder_class
if source_builder_class is None:
source_builder_class = source_lib.make_source_builder(
registered_source.source_class,
runtime_params_type=model_function.runtime_params_class,
links_back=model_function.links_back,
model_func=model_function.source_profile_function,
)

return source_builder_class(**kwargs)


def build_transport_model_builder_from_config(
Expand Down
36 changes: 0 additions & 36 deletions torax/core_profile_setters.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,11 +281,7 @@ def _prescribe_currents_no_bootstrap(
# form of external current on face grid
generic_current_face = source_models.generic_current_source.get_value(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
dynamic_source_runtime_params=dynamic_generic_current_params,
static_runtime_params_slice=static_runtime_params_slice,
static_source_runtime_params=static_runtime_params_slice.sources[
generic_current_source.GenericCurrentSource.SOURCE_NAME
],
geo=geo,
core_profiles=core_profiles,
)
Expand All @@ -312,7 +308,6 @@ def _prescribe_currents_no_bootstrap(
jtot_hires = _get_jtot_hires(
static_runtime_params_slice,
dynamic_runtime_params_slice,
dynamic_generic_current_params,
geo,
bootstrap_profile,
Iohm,
Expand Down Expand Up @@ -363,13 +358,7 @@ def _prescribe_currents_with_bootstrap(

bootstrap_profile = source_models.j_bootstrap.get_value(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[
source_models.j_bootstrap_name
],
static_runtime_params_slice=static_runtime_params_slice,
static_source_runtime_params=static_runtime_params_slice.sources[
source_models.j_bootstrap_name
],
geo=geo,
core_profiles=core_profiles,
)
Expand All @@ -389,11 +378,7 @@ def _prescribe_currents_with_bootstrap(
# form of external current on face grid
generic_current_face = source_models.generic_current_source.get_value(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
dynamic_source_runtime_params=dynamic_generic_current_params,
static_runtime_params_slice=static_runtime_params_slice,
static_source_runtime_params=static_runtime_params_slice.sources[
source_models.generic_current_source_name
],
geo=geo,
core_profiles=core_profiles,
)
Expand Down Expand Up @@ -424,7 +409,6 @@ def _prescribe_currents_with_bootstrap(
jtot_hires = _get_jtot_hires(
static_runtime_params_slice,
dynamic_runtime_params_slice,
dynamic_generic_current_params,
geo,
bootstrap_profile,
Iohm,
Expand Down Expand Up @@ -478,31 +462,16 @@ def _calculate_currents_from_psi(

bootstrap_profile = source_models.j_bootstrap.get_value(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[
source_models.j_bootstrap_name
],
static_runtime_params_slice=static_runtime_params_slice,
static_source_runtime_params=static_runtime_params_slice.sources[
source_models.j_bootstrap_name
],
geo=geo,
core_profiles=core_profiles,
)

# Calculate splitting of currents depending on input runtime params.
dynamic_generic_current_params = get_generic_current_params(
dynamic_runtime_params_slice, source_models
)

# calculate "External" current profile (e.g. ECCD)
# form of external current on face grid
generic_current_face = source_models.generic_current_source.get_value(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
dynamic_source_runtime_params=dynamic_generic_current_params,
static_runtime_params_slice=static_runtime_params_slice,
static_source_runtime_params=static_runtime_params_slice.sources[
source_models.generic_current_source_name
],
geo=geo,
core_profiles=core_profiles,
)
Expand Down Expand Up @@ -994,7 +963,6 @@ def compute_boundary_conditions(
def _get_jtot_hires(
static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
dynamic_generic_current_params: generic_current_source.DynamicRuntimeParams,
geo: Geometry,
bootstrap_profile: source_profiles_lib.BootstrapCurrentProfile,
Iohm: jax.Array | float,
Expand All @@ -1008,11 +976,7 @@ def _get_jtot_hires(
# calculate hi-res "External" current profile (e.g. ECCD) on cell grid.
generic_current_hires = generic_current.generic_current_source_hires(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
dynamic_source_runtime_params=dynamic_generic_current_params,
static_runtime_params_slice=static_runtime_params_slice,
static_source_runtime_params=static_runtime_params_slice.sources[
generic_current_source.GenericCurrentSource.SOURCE_NAME
],
geo=geo,
)

Expand Down
20 changes: 0 additions & 20 deletions torax/sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -1292,33 +1292,16 @@ def update_current_distribution(

bootstrap_profile = source_models.j_bootstrap.get_value(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[
source_models.j_bootstrap_name
],
static_runtime_params_slice=static_runtime_params_slice,
static_source_runtime_params=static_runtime_params_slice.sources[
source_models.j_bootstrap_name
],
geo=geo,
core_profiles=core_profiles,
)

# Calculate splitting of currents depending on input runtime params.
dynamic_generic_current_params = (
core_profile_setters.get_generic_current_params(
dynamic_runtime_params_slice, source_models
)
)

# calculate "External" current profile (e.g. ECCD)
# form of external current on face grid
generic_current_face = source_models.generic_current_source.get_value(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
dynamic_source_runtime_params=dynamic_generic_current_params,
static_runtime_params_slice=static_runtime_params_slice,
static_source_runtime_params=static_runtime_params_slice.sources[
source_models.generic_current_source_name
],
geo=geo,
core_profiles=core_profiles,
)
Expand Down Expand Up @@ -1468,9 +1451,6 @@ def get_initial_source_profiles(
qei = source_models.qei_source.get_qei(
static_runtime_params_slice=static_runtime_params_slice,
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[
source_models.qei_source_name
],
geo=geo,
core_profiles=core_profiles,
)
Expand Down
19 changes: 14 additions & 5 deletions torax/sources/bootstrap_current_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ class BootstrapCurrentSource(source.Source):
"""
SOURCE_NAME: ClassVar[str] = 'j_bootstrap'

@property
def source_name(self) -> str:
return self.SOURCE_NAME

@property
def supported_modes(self) -> tuple[runtime_params_lib.Mode, ...]:
return (
Expand All @@ -110,12 +114,16 @@ def affected_core_profiles(self) -> tuple[source.AffectedCoreProfile, ...]:
def get_value(
self,
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams,
static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
static_source_runtime_params: runtime_params_lib.StaticRuntimeParams,
geo: geometry.Geometry,
core_profiles: state.CoreProfiles,
) -> source_profiles.BootstrapCurrentProfile:
static_source_runtime_params = static_runtime_params_slice.sources[
self.source_name
]
dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[
self.source_name
]
# Make sure the input mode requested is supported.
self.check_mode(static_source_runtime_params.mode)
# Make sure the input params are the correct type.
Expand All @@ -126,7 +134,6 @@ def get_value(
)
bootstrap_current = calc_neoclassical(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
dynamic_source_runtime_params=dynamic_source_runtime_params,
geo=geo,
temp_ion=core_profiles.temp_ion,
temp_el=core_profiles.temp_el,
Expand Down Expand Up @@ -174,7 +181,6 @@ def get_source_profile_for_affected_core_profile(
@jax_utils.jit
def calc_neoclassical(
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
dynamic_source_runtime_params: DynamicRuntimeParams,
geo: geometry.Geometry,
temp_ion: cell_variable.CellVariable,
temp_el: cell_variable.CellVariable,
Expand All @@ -186,7 +192,6 @@ def calc_neoclassical(
Args:
dynamic_runtime_params_slice: General configuration parameters.
dynamic_source_runtime_params: Source-specific runtime parameters.
geo: Torus geometry.
temp_ion: Ion temperature. We don't pass in a full `core_profiles` here
because this function is used to create the `Currents` in the initial
Expand All @@ -199,6 +204,10 @@ def calc_neoclassical(
Returns:
A BootstrapCurrentProfile. See that class's docstring for more info.
"""
dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[
BootstrapCurrentSource.SOURCE_NAME
]
assert isinstance(dynamic_source_runtime_params, DynamicRuntimeParams)
# Many variables throughout this function are capitalized based on physics
# notational conventions rather than on Google Python style
# pylint: disable=invalid-name
Expand Down
12 changes: 8 additions & 4 deletions torax/sources/bremsstrahlung_heat_sink.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,19 +121,20 @@ def calc_relativistic_correction() -> jax.Array:

def bremsstrahlung_model_func(
static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
static_source_runtime_params: runtime_params_lib.StaticRuntimeParams,
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams,
geo: geometry.Geometry,
source_name: str,
core_profiles: state.CoreProfiles,
unused_model_func: source_models.SourceModels | None,
) -> jax.Array:
"""Model function for the Bremsstrahlung heat sink."""
del (
static_source_runtime_params,
static_runtime_params_slice,
unused_model_func,
) # Unused.
dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[
source_name
]
assert isinstance(dynamic_source_runtime_params, DynamicRuntimeParams)
_, P_brem_profile = calc_bremsstrahlung(
core_profiles,
Expand All @@ -150,7 +151,10 @@ def bremsstrahlung_model_func(
class BremsstrahlungHeatSink(source.Source):
"""Brehmsstrahlung heat sink for electron heat equation."""
SOURCE_NAME: ClassVar[str] = 'bremsstrahlung_heat_sink'
model_func: source.SourceProfileFunction = bremsstrahlung_model_func

@property
def source_name(self) -> str:
return self.SOURCE_NAME

@property
def supported_modes(self) -> tuple[runtime_params_lib.Mode, ...]:
Expand Down
22 changes: 12 additions & 10 deletions torax/sources/electron_cyclotron_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,14 +103,13 @@ class DynamicRuntimeParams(runtime_params_lib.DynamicRuntimeParams):
gaussian_ec_total_power: array_typing.ScalarFloat


def _calc_heating_and_current(
def calc_heating_and_current(
static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
static_source_runtime_params: runtime_params_lib.StaticRuntimeParams,
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams,
geo: geometry.Geometry,
source_name: str,
core_profiles: state.CoreProfiles,
unused_model_func: source_models.SourceModels,
unused_source_models: source_models.SourceModels | None = None,
) -> jax.Array:
"""Model function for the electron-cyclotron source.
Expand All @@ -119,22 +118,22 @@ def _calc_heating_and_current(
Args:
static_runtime_params_slice: Static runtime parameters.
static_source_runtime_params: Static runtime parameters.
dynamic_runtime_params_slice: Global runtime parameters
dynamic_source_runtime_params: Specific runtime parameters for the
electron-cyclotron source.
geo: Magnetic geometry.
source_name: Name of the source.
core_profiles: CoreProfiles component of the state.
unused_model_func: (unused) source models used in the simulation.
Returns:
2D array of electron cyclotron heating power density and current density.
"""
del (
static_source_runtime_params,
unused_model_func,
unused_source_models,
static_runtime_params_slice,
) # Unused.
dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[
source_name
]
# Helps linter understand the type of dynamic_source_runtime_params.
assert isinstance(dynamic_source_runtime_params, DynamicRuntimeParams)
# Construct the profile
Expand Down Expand Up @@ -187,7 +186,10 @@ def _get_ec_output_shape(geo: geometry.Geometry) -> tuple[int, ...]:
class ElectronCyclotronSource(source.Source):
"""Electron cyclotron source for the Te and Psi equations."""
SOURCE_NAME: ClassVar[str] = "electron_cyclotron_source"
model_func: source.SourceProfileFunction = _calc_heating_and_current

@property
def source_name(self) -> str:
return self.SOURCE_NAME

@property
def supported_modes(self) -> tuple[runtime_params_lib.Mode, ...]:
Expand Down
Loading

0 comments on commit af68785

Please sign in to comment.