Skip to content

Commit

Permalink
Move over remaining Source formula implementations to be defined as m…
Browse files Browse the repository at this point in the history
…odel_func instead.

PiperOrigin-RevId: 707556513
  • Loading branch information
Nush395 authored and Torax team committed Dec 19, 2024
1 parent f6af816 commit edfcba5
Show file tree
Hide file tree
Showing 42 changed files with 596 additions and 607 deletions.
2 changes: 2 additions & 0 deletions torax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def set_jax_precision():
'geo_t',
'geo_t_plus_dt',
'geometry_provider',
'source_name',
'x_old',
'state',
'unused_state',
Expand All @@ -88,6 +89,7 @@ def set_jax_precision():
'source_profiles',
'source_profile',
'explicit_source_profiles',
'model_func',
'source_models',
'pedestal_model',
'time_step_calculator',
Expand Down
17 changes: 15 additions & 2 deletions torax/config/build_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,11 @@ def _build_single_source_builder_from_config(
) -> source_lib.SourceBuilderProtocol:
"""Builds a source builder from the input config."""
registered_source = register_source.get_registered_source(source_name)
runtime_params = registered_source.default_runtime_params_class()
model_func = register_source.DEFAULT_MODEL_FUNCTION_NAME
if 'model_func' in source_config:
model_func = source_config.pop('model_func')
model_function = registered_source.model_functions[model_func]
runtime_params = model_function.runtime_params_class()
# Update the defaults with the config provided.
source_config = copy.copy(source_config)
if 'mode' in source_config:
Expand Down Expand Up @@ -445,7 +449,16 @@ def _build_single_source_builder_from_config(
if formula is not None:
kwargs['formula'] = formula

return registered_source.source_builder_class(**kwargs)
source_builder_class = model_function.source_builder_class
if source_builder_class is None:
source_builder_class = source_lib.make_source_builder(
registered_source.source_class,
runtime_params_type=model_function.runtime_params_class,
links_back=model_function.links_back,
model_func=model_function.source_profile_function,
)

return source_builder_class(**kwargs)


def build_transport_model_builder_from_config(
Expand Down
2 changes: 1 addition & 1 deletion torax/config/tests/build_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ def test_adding_standard_source_via_config(self):
# pytype: enable=attribute-error
self.assertEqual(
source_models_builder.runtime_params['gas_puff_source'].mode,
source_runtime_params_lib.Mode.FORMULA_BASED, # On by default.
source_runtime_params_lib.Mode.MODEL_BASED, # On by default.
)
self.assertEqual(
source_models_builder.runtime_params['ohmic_heat_source'].mode,
Expand Down
36 changes: 0 additions & 36 deletions torax/core_profile_setters.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,11 +280,7 @@ def _prescribe_currents_no_bootstrap(
# form of external current on face grid
generic_current_face = source_models.generic_current_source.get_value(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
dynamic_source_runtime_params=dynamic_generic_current_params,
static_runtime_params_slice=static_runtime_params_slice,
static_source_runtime_params=static_runtime_params_slice.sources[
generic_current_source.GenericCurrentSource.SOURCE_NAME
],
geo=geo,
core_profiles=core_profiles,
)
Expand All @@ -311,7 +307,6 @@ def _prescribe_currents_no_bootstrap(
jtot_hires = _get_jtot_hires(
static_runtime_params_slice,
dynamic_runtime_params_slice,
dynamic_generic_current_params,
geo,
bootstrap_profile,
Iohm,
Expand Down Expand Up @@ -362,13 +357,7 @@ def _prescribe_currents_with_bootstrap(

bootstrap_profile = source_models.j_bootstrap.get_value(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[
source_models.j_bootstrap_name
],
static_runtime_params_slice=static_runtime_params_slice,
static_source_runtime_params=static_runtime_params_slice.sources[
source_models.j_bootstrap_name
],
geo=geo,
core_profiles=core_profiles,
)
Expand All @@ -388,11 +377,7 @@ def _prescribe_currents_with_bootstrap(
# form of external current on face grid
generic_current_face = source_models.generic_current_source.get_value(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
dynamic_source_runtime_params=dynamic_generic_current_params,
static_runtime_params_slice=static_runtime_params_slice,
static_source_runtime_params=static_runtime_params_slice.sources[
source_models.generic_current_source_name
],
geo=geo,
core_profiles=core_profiles,
)
Expand Down Expand Up @@ -423,7 +408,6 @@ def _prescribe_currents_with_bootstrap(
jtot_hires = _get_jtot_hires(
static_runtime_params_slice,
dynamic_runtime_params_slice,
dynamic_generic_current_params,
geo,
bootstrap_profile,
Iohm,
Expand Down Expand Up @@ -477,31 +461,16 @@ def _calculate_currents_from_psi(

bootstrap_profile = source_models.j_bootstrap.get_value(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[
source_models.j_bootstrap_name
],
static_runtime_params_slice=static_runtime_params_slice,
static_source_runtime_params=static_runtime_params_slice.sources[
source_models.j_bootstrap_name
],
geo=geo,
core_profiles=core_profiles,
)

# Calculate splitting of currents depending on input runtime params.
dynamic_generic_current_params = get_generic_current_params(
dynamic_runtime_params_slice, source_models
)

# calculate "External" current profile (e.g. ECCD)
# form of external current on face grid
generic_current_face = source_models.generic_current_source.get_value(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
dynamic_source_runtime_params=dynamic_generic_current_params,
static_runtime_params_slice=static_runtime_params_slice,
static_source_runtime_params=static_runtime_params_slice.sources[
source_models.generic_current_source_name
],
geo=geo,
core_profiles=core_profiles,
)
Expand Down Expand Up @@ -991,7 +960,6 @@ def compute_boundary_conditions(
def _get_jtot_hires(
static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
dynamic_generic_current_params: generic_current_source.DynamicRuntimeParams,
geo: geometry.Geometry,
bootstrap_profile: source_profiles_lib.BootstrapCurrentProfile,
Iohm: jax.Array | float,
Expand All @@ -1005,11 +973,7 @@ def _get_jtot_hires(
# calculate hi-res "External" current profile (e.g. ECCD) on cell grid.
generic_current_hires = generic_current.generic_current_source_hires(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
dynamic_source_runtime_params=dynamic_generic_current_params,
static_runtime_params_slice=static_runtime_params_slice,
static_source_runtime_params=static_runtime_params_slice.sources[
generic_current_source.GenericCurrentSource.SOURCE_NAME
],
geo=geo,
)

Expand Down
3 changes: 0 additions & 3 deletions torax/fvm/calc_coeffs.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,9 +728,6 @@ def _calc_coeffs_full(
qei = source_models.qei_source.get_qei(
static_runtime_params_slice=static_runtime_params_slice,
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[
source_models.qei_source_name
],
geo=geo,
# For Qei, always use the current set of core profiles.
# In the linear solver, core_profiles is the set of profiles at time t (at
Expand Down
20 changes: 0 additions & 20 deletions torax/sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -1349,33 +1349,16 @@ def update_current_distribution(

bootstrap_profile = source_models.j_bootstrap.get_value(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[
source_models.j_bootstrap_name
],
static_runtime_params_slice=static_runtime_params_slice,
static_source_runtime_params=static_runtime_params_slice.sources[
source_models.j_bootstrap_name
],
geo=geo,
core_profiles=core_profiles,
)

# Calculate splitting of currents depending on input runtime params.
dynamic_generic_current_params = (
core_profile_setters.get_generic_current_params(
dynamic_runtime_params_slice, source_models
)
)

# calculate "External" current profile (e.g. ECCD)
# form of external current on face grid
generic_current_face = source_models.generic_current_source.get_value(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
dynamic_source_runtime_params=dynamic_generic_current_params,
static_runtime_params_slice=static_runtime_params_slice,
static_source_runtime_params=static_runtime_params_slice.sources[
source_models.generic_current_source_name
],
geo=geo,
core_profiles=core_profiles,
)
Expand Down Expand Up @@ -1525,9 +1508,6 @@ def get_initial_source_profiles(
qei = source_models.qei_source.get_qei(
static_runtime_params_slice=static_runtime_params_slice,
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[
source_models.qei_source_name
],
geo=geo,
core_profiles=core_profiles,
)
Expand Down
19 changes: 14 additions & 5 deletions torax/sources/bootstrap_current_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ class BootstrapCurrentSource(source.Source):

SOURCE_NAME: ClassVar[str] = 'j_bootstrap'

@property
def source_name(self) -> str:
return self.SOURCE_NAME

@property
def supported_modes(self) -> tuple[runtime_params_lib.Mode, ...]:
return (
Expand All @@ -111,12 +115,16 @@ def affected_core_profiles(self) -> tuple[source.AffectedCoreProfile, ...]:
def get_value(
self,
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams,
static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
static_source_runtime_params: runtime_params_lib.StaticRuntimeParams,
geo: geometry.Geometry,
core_profiles: state.CoreProfiles,
) -> source_profiles.BootstrapCurrentProfile:
static_source_runtime_params = static_runtime_params_slice.sources[
self.source_name
]
dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[
self.source_name
]
# Make sure the input mode requested is supported.
self.check_mode(static_source_runtime_params.mode)
# Make sure the input params are the correct type.
Expand All @@ -127,7 +135,6 @@ def get_value(
)
bootstrap_current = calc_neoclassical(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
dynamic_source_runtime_params=dynamic_source_runtime_params,
geo=geo,
temp_ion=core_profiles.temp_ion,
temp_el=core_profiles.temp_el,
Expand Down Expand Up @@ -175,7 +182,6 @@ def get_source_profile_for_affected_core_profile(
@jax_utils.jit
def calc_neoclassical(
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
dynamic_source_runtime_params: DynamicRuntimeParams,
geo: geometry.Geometry,
temp_ion: cell_variable.CellVariable,
temp_el: cell_variable.CellVariable,
Expand All @@ -187,7 +193,6 @@ def calc_neoclassical(
Args:
dynamic_runtime_params_slice: General configuration parameters.
dynamic_source_runtime_params: Source-specific runtime parameters.
geo: Torus geometry.
temp_ion: Ion temperature. We don't pass in a full `core_profiles` here
because this function is used to create the `Currents` in the initial
Expand All @@ -200,6 +205,10 @@ def calc_neoclassical(
Returns:
A BootstrapCurrentProfile. See that class's docstring for more info.
"""
dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[
BootstrapCurrentSource.SOURCE_NAME
]
assert isinstance(dynamic_source_runtime_params, DynamicRuntimeParams)
# Many variables throughout this function are capitalized based on physics
# notational conventions rather than on Google Python style
# pylint: disable=invalid-name
Expand Down
12 changes: 8 additions & 4 deletions torax/sources/bremsstrahlung_heat_sink.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,19 +122,20 @@ def calc_relativistic_correction() -> jax.Array:

def bremsstrahlung_model_func(
static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
static_source_runtime_params: runtime_params_lib.StaticRuntimeParams,
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams,
geo: geometry.Geometry,
source_name: str,
core_profiles: state.CoreProfiles,
unused_model_func: source_models.SourceModels | None,
) -> jax.Array:
"""Model function for the Bremsstrahlung heat sink."""
del (
static_source_runtime_params,
static_runtime_params_slice,
unused_model_func,
) # Unused.
dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[
source_name
]
assert isinstance(dynamic_source_runtime_params, DynamicRuntimeParams)
_, P_brem_profile = calc_bremsstrahlung(
core_profiles,
Expand All @@ -152,7 +153,10 @@ class BremsstrahlungHeatSink(source.Source):
"""Brehmsstrahlung heat sink for electron heat equation."""

SOURCE_NAME: ClassVar[str] = 'bremsstrahlung_heat_sink'
model_func: source.SourceProfileFunction = bremsstrahlung_model_func

@property
def source_name(self) -> str:
return self.SOURCE_NAME

@property
def supported_modes(self) -> tuple[runtime_params_lib.Mode, ...]:
Expand Down
22 changes: 12 additions & 10 deletions torax/sources/electron_cyclotron_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,14 +103,13 @@ class DynamicRuntimeParams(runtime_params_lib.DynamicRuntimeParams):
gaussian_ec_total_power: array_typing.ScalarFloat


def _calc_heating_and_current(
def calc_heating_and_current(
static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
static_source_runtime_params: runtime_params_lib.StaticRuntimeParams,
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams,
geo: geometry.Geometry,
source_name: str,
core_profiles: state.CoreProfiles,
unused_model_func: source_models.SourceModels,
unused_source_models: source_models.SourceModels | None = None,
) -> jax.Array:
"""Model function for the electron-cyclotron source.
Expand All @@ -119,22 +118,22 @@ def _calc_heating_and_current(
Args:
static_runtime_params_slice: Static runtime parameters.
static_source_runtime_params: Static runtime parameters.
dynamic_runtime_params_slice: Global runtime parameters
dynamic_source_runtime_params: Specific runtime parameters for the
electron-cyclotron source.
geo: Magnetic geometry.
source_name: Name of the source.
core_profiles: CoreProfiles component of the state.
unused_model_func: (unused) source models used in the simulation.
Returns:
2D array of electron cyclotron heating power density and current density.
"""
del (
static_source_runtime_params,
unused_model_func,
unused_source_models,
static_runtime_params_slice,
) # Unused.
dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[
source_name
]
# Helps linter understand the type of dynamic_source_runtime_params.
assert isinstance(dynamic_source_runtime_params, DynamicRuntimeParams)
# Construct the profile
Expand Down Expand Up @@ -188,7 +187,10 @@ class ElectronCyclotronSource(source.Source):
"""Electron cyclotron source for the Te and Psi equations."""

SOURCE_NAME: ClassVar[str] = "electron_cyclotron_source"
model_func: source.SourceProfileFunction = _calc_heating_and_current

@property
def source_name(self) -> str:
return self.SOURCE_NAME

@property
def supported_modes(self) -> tuple[runtime_params_lib.Mode, ...]:
Expand Down
Loading

0 comments on commit edfcba5

Please sign in to comment.