From 8dbe3921ab225fb6a40b9e2383e295d0297561bd Mon Sep 17 00:00:00 2001 From: Anushan Fernando Date: Wed, 18 Dec 2024 09:43:06 -0800 Subject: [PATCH] Remove the `Exponential` and `Gaussian` formulas from configuration. The formulas are kept for general use but all the configuration objects are removed as they could be directly input by a user if needed (pending addition of a utility to register new model functions). PiperOrigin-RevId: 707580250 --- docs/configuration.rst | 24 +- docs/physics_models.rst | 7 +- torax/__init__.py | 2 + torax/config/build_sim.py | 92 ++----- torax/config/tests/build_sim.py | 30 +-- torax/config/tests/runtime_params_slice.py | 59 ----- torax/core_profile_setters.py | 36 --- torax/fvm/calc_coeffs.py | 3 - torax/fvm/tests/fvm.py | 4 +- torax/sim.py | 20 -- torax/sources/bootstrap_current_source.py | 20 +- torax/sources/bremsstrahlung_heat_sink.py | 12 +- torax/sources/electron_cyclotron_source.py | 24 +- torax/sources/electron_density_sources.py | 70 ++--- torax/sources/formula_config.py | 143 ----------- torax/sources/formulas.py | 77 ------ torax/sources/fusion_heat_source.py | 14 +- torax/sources/generic_current_source.py | 58 +++-- torax/sources/generic_ion_el_heat_source.py | 19 +- torax/sources/impurity_radiation_heat_sink.py | 35 ++- torax/sources/ion_cyclotron_source.py | 55 ++-- torax/sources/ohmic_heat_source.py | 23 +- torax/sources/qei_source.py | 13 +- torax/sources/register_source.py | 200 +++++++++------ torax/sources/runtime_params.py | 10 - torax/sources/source.py | 142 +++++----- torax/sources/source_models.py | 41 +-- .../sources/tests/bootstrap_current_source.py | 2 + .../sources/tests/bremsstrahlung_heat_sink.py | 2 + .../tests/electron_cyclotron_source.py | 26 +- .../sources/tests/electron_density_sources.py | 33 +-- torax/sources/tests/formulas.py | 242 ------------------ torax/sources/tests/fusion_heat_source.py | 2 + torax/sources/tests/generic_current_source.py | 14 +- .../tests/generic_ion_el_heat_source.py | 7 +- .../tests/impurity_radiation_heat_sink.py | 59 ++--- torax/sources/tests/ion_cyclotron_source.py | 19 +- torax/sources/tests/ohmic_heat_source.py | 2 + torax/sources/tests/qei_source.py | 4 +- torax/sources/tests/register_source.py | 19 +- torax/sources/tests/source.py | 111 +++----- torax/sources/tests/source_models.py | 7 +- torax/sources/tests/test_lib.py | 62 +++-- torax/tests/physics.py | 9 +- torax/tests/sim_custom_sources.py | 85 +++--- torax/tests/sim_output_source_profiles.py | 49 +++- torax/tests/state.py | 6 - torax/tests/test_lib/default_sources.py | 21 +- 48 files changed, 681 insertions(+), 1333 deletions(-) delete mode 100644 torax/sources/formula_config.py delete mode 100644 torax/sources/tests/formulas.py diff --git a/docs/configuration.rst b/docs/configuration.rst index 80f8c552..505cb398 100644 --- a/docs/configuration.rst +++ b/docs/configuration.rst @@ -840,25 +840,6 @@ and can be set to anything convenient. beginning of a time step, or do not have any dependance on state. Implicit sources depend on updated states as the iterative solvers evolve the state through the course of a time step. If a source model is complex but evolves over slow timescales compared to the state, it may be beneficial to set it as explicit. -``formula_type`` (str='default') - Sets the formula type if ``mode=='formula'``. The current options are: - -* ``'exponential'`` takes the following arguments: - * c1 (float): Offset location - * c2 (float): Exponential decay parameter - * total (float): integral - - The profile is parameterized as follows :math:`Q = C e^{-(r - c1) / c2}` , where ``C`` is calculated to be consistent with ``total``. If ``use_normalized_r==True``, - then c1 and c2 are interpreted as being in normalized toroidal flux units. - -* ``'gaussian'`` takes the following arguments: - * c1 (float): Gaussian peak Location - * c2 (float): Gaussian width - * total (float): integral - - The profile is parameterized as follows :math:`Q = C e^{-((r - c1)^2) / (2 c2^2)}` , where ``C`` is calculated to be consistent with ``total``. If ``use_normalized_r==True``, - then c1 and c2 are interpreted as being in normalized toroidal flux units. - * ``'default'`` Some sources have default implementations which use the above formulas under the hood with intuitive parameter names for c1 and c2. Consult the list below for further details. @@ -868,10 +849,7 @@ generic_ion_el_heat_source A utility source module that allows for a time dependent Gaussian ion and electron heat source. -``mode`` (str = 'formula') - -``formula_type`` (str = 'default') - Uses the Gaussian formula. +``mode`` (str = 'model') ``rsource`` (float = 0.0), **time-varying-scalar** Gaussian center of source profile in units of :math:`\hat{\rho}`. diff --git a/docs/physics_models.rst b/docs/physics_models.rst index f871a48d..c8db037a 100644 --- a/docs/physics_models.rst +++ b/docs/physics_models.rst @@ -284,11 +284,8 @@ not need to be JAX-compatible, since explicit sources are an input into the PDE and do not require JIT compilation. Conversely, implicit treatment can be important for accurately resolving the impact of fast-evolving source terms. -All sources can optionally be set to zero, prescribed with non-physics-based formulas -(currently Gaussian or exponential) with user-configurable time-dependent parameters like -amplitude, width, and location, or calculated with a dedicated physics-based model. Not -all sources currently have a model implementation. However, the code modular structure -facilitates easy coupling of additional source models in future work. Specifics of source models +All sources can optionally be set to zero, prescribed with explicit values or calculated with a dedicated physics-based model. +However, the code modular structure facilitates easy coupling of additional source models in future work. Specifics of source models currently implemented in TORAX follow: Ion-electron heat exchange diff --git a/torax/__init__.py b/torax/__init__.py index 9ca6d8f7..9f1936f4 100644 --- a/torax/__init__.py +++ b/torax/__init__.py @@ -73,6 +73,7 @@ def set_jax_precision(): 'geo_t', 'geo_t_plus_dt', 'geometry_provider', + 'source_name', 'x_old', 'state', 'unused_state', @@ -88,6 +89,7 @@ def set_jax_precision(): 'source_profiles', 'source_profile', 'explicit_source_profiles', + 'model_func', 'source_models', 'pedestal_model', 'time_step_calculator', diff --git a/torax/config/build_sim.py b/torax/config/build_sim.py index afaf7ddc..a749ed3e 100644 --- a/torax/config/build_sim.py +++ b/torax/config/build_sim.py @@ -25,10 +25,7 @@ from torax.geometry import geometry_provider from torax.pedestal_model import pedestal_model as pedestal_model_lib from torax.pedestal_model import set_tped_nped -from torax.sources import formula_config -from torax.sources import formulas from torax.sources import register_source -from torax.sources import runtime_params as source_runtime_params_lib from torax.sources import source as source_lib from torax.sources import source_models as source_models_lib from torax.stepper import linear_theta_method @@ -353,48 +350,12 @@ def build_sources_builder_from_config( }, } - If the `mode` is set to `formula_based`, then the you can provide a - `formula_type` key which may have the following values: - - - `default`: Uses the default impl (if the source has one) (default) - - - The other config args are based on the source's RuntimeParams object - outlined above. - - - `exponential`: Exponential profile. - - - The other config args are from `sources.formula_config.Exponential`. - - - `gaussian`: Gaussian profile. - - - The other config args are from `sources.formula_config.Gaussian`. - - E.g. for an example heat source: - - .. code-block:: python - - { - mode: 'formula', - formula_type: 'gaussian', - total: 120e6, # total heating - c1: 0.0, # Source Gaussian central location (in normalized r) - c2: 0.25, # Gaussian width in normalized radial coordinates - } - - If you have custom source implementations, you may update this funtion to - handle those new sources and keys, or you may use the "advanced" configuration - method and build your `SourceModel` object directly. - Args: source_configs: Input config dict defining all sources, with a structure as described above. Returns: A `SourceModelsBuilder`. - - Raises: - ValueError if an input key doesn't match one of the source names defined - above. """ source_builders = { @@ -410,42 +371,33 @@ def _build_single_source_builder_from_config( source_config: dict[str, Any], ) -> source_lib.SourceBuilderProtocol: """Builds a source builder from the input config.""" - registered_source = register_source.get_registered_source(source_name) - runtime_params = registered_source.default_runtime_params_class() - # Update the defaults with the config provided. - source_config = copy.copy(source_config) - if 'mode' in source_config: - mode = source_runtime_params_lib.Mode[source_config.pop('mode').upper()] - runtime_params.mode = mode - formula = None - if 'formula_type' in source_config: - func = source_config.pop('formula_type').lower() - if func == 'default': - pass # Nothing to do here. - elif func == 'exponential': - runtime_params.formula = config_args.recursive_replace( - formula_config.Exponential(), - ignore_extra_kwargs=True, - **source_config, - ) - formula = formulas.Exponential() - elif func == 'gaussian': - runtime_params.formula = config_args.recursive_replace( - formula_config.Gaussian(), - ignore_extra_kwargs=True, - **source_config, - ) - formula = formulas.Gaussian() - else: - raise ValueError(f'Unknown formula_type for source {source_name}: {func}') + supported_source = register_source.get_supported_source(source_name) + if 'model_func' in source_config: + # If the user has specified a model function, try to retrive that from the + # registered source model functions. + model_func = source_config.pop('model_func') + model_function = supported_source.model_functions[model_func] + else: + # Otherwise, use the default model function. + model_function = supported_source.model_functions[ + supported_source.source_class.DEFAULT_MODEL_FUNCTION_NAME + ] + runtime_params = model_function.runtime_params_class() runtime_params = config_args.recursive_replace( runtime_params, ignore_extra_kwargs=True, **source_config ) kwargs = {'runtime_params': runtime_params} - if formula is not None: - kwargs['formula'] = formula - return registered_source.source_builder_class(**kwargs) + source_builder_class = model_function.source_builder_class + if source_builder_class is None: + source_builder_class = source_lib.make_source_builder( + supported_source.source_class, + runtime_params_type=model_function.runtime_params_class, + links_back=model_function.links_back, + model_func=model_function.source_profile_function, + ) + + return source_builder_class(**kwargs) def build_transport_model_builder_from_config( diff --git a/torax/config/tests/build_sim.py b/torax/config/tests/build_sim.py index 39619751..445018d6 100644 --- a/torax/config/tests/build_sim.py +++ b/torax/config/tests/build_sim.py @@ -23,8 +23,6 @@ from torax.geometry import geometry from torax.geometry import geometry_provider from torax.pedestal_model import set_tped_nped -from torax.sources import formula_config -from torax.sources import formulas from torax.sources import runtime_params as source_runtime_params_lib from torax.stepper import linear_theta_method from torax.stepper import nonlinear_theta_method @@ -366,39 +364,13 @@ def test_adding_standard_source_via_config(self): # pytype: enable=attribute-error self.assertEqual( source_models_builder.runtime_params['gas_puff_source'].mode, - source_runtime_params_lib.Mode.FORMULA_BASED, # On by default. + source_runtime_params_lib.Mode.MODEL_BASED, # On by default. ) self.assertEqual( source_models_builder.runtime_params['ohmic_heat_source'].mode, source_runtime_params_lib.Mode.ZERO, ) - def test_updating_formula_via_source_config(self): - """Tests that we can set the formula type and params via the config.""" - source_models_builder = build_sim.build_sources_builder_from_config({ - 'gas_puff_source': { - 'formula_type': 'gaussian', - 'total': 1, - 'c1': 2, - 'c2': 3, - } - }) - source_models = source_models_builder() - gas_source = source_models.sources['gas_puff_source'] - self.assertIsInstance(gas_source.formula, formulas.Gaussian) - gas_source_runtime_params = source_models_builder.runtime_params[ - 'gas_puff_source' - ] - self.assertIsInstance( - gas_source_runtime_params.formula, - formula_config.Gaussian, - ) - # pytype: disable=attribute-error - self.assertEqual(gas_source_runtime_params.formula.total, 1) - self.assertEqual(gas_source_runtime_params.formula.c1, 2) - self.assertEqual(gas_source_runtime_params.formula.c2, 3) - # pytype: enable=attribute-error - def test_missing_transport_model_raises_error(self): with self.assertRaises(ValueError): build_sim.build_transport_model_builder_from_config({}) diff --git a/torax/config/tests/runtime_params_slice.py b/torax/config/tests/runtime_params_slice.py index 52f6a5c2..ee110cf9 100644 --- a/torax/config/tests/runtime_params_slice.py +++ b/torax/config/tests/runtime_params_slice.py @@ -26,9 +26,7 @@ from torax.geometry import geometry from torax.pedestal_model import set_tped_nped from torax.sources import electron_density_sources -from torax.sources import formula_config from torax.sources import generic_current_source -from torax.sources import runtime_params as sources_params_lib from torax.stepper import runtime_params as stepper_params_lib from torax.tests.test_lib import default_sources from torax.transport_model import runtime_params as transport_params_lib @@ -248,63 +246,6 @@ def test_source_formula_config_has_time_dependent_params(self): ) np.testing.assert_allclose(generic_particle_source.S_tot, 4.0) - with self.subTest('exponential_formula'): - runtime_params = general_runtime_params.GeneralRuntimeParams() - dcs = runtime_params_slice_lib.DynamicRuntimeParamsSliceProvider( - runtime_params=runtime_params, - sources={ - electron_density_sources.GasPuffSource.SOURCE_NAME: ( - sources_params_lib.RuntimeParams( - formula=formula_config.Exponential( - total={0.0: 0.0, 1.0: 1.0}, - c1={0.0: 0.0, 1.0: 2.0}, - c2={0.0: 0.0, 1.0: 3.0}, - ) - ) - ), - }, - torax_mesh=self._geo.torax_mesh, - )( - t=0.25, - ) - gas_puff_source = dcs.sources[ - electron_density_sources.GasPuffSource.SOURCE_NAME - ] - assert isinstance( - gas_puff_source.formula, - formula_config.DynamicExponential, - ) - np.testing.assert_allclose(gas_puff_source.formula.total, 0.25) - np.testing.assert_allclose(gas_puff_source.formula.c1, 0.5) - np.testing.assert_allclose(gas_puff_source.formula.c2, 0.75) - - with self.subTest('gaussian_formula'): - runtime_params = general_runtime_params.GeneralRuntimeParams() - dcs = runtime_params_slice_lib.DynamicRuntimeParamsSliceProvider( - runtime_params=runtime_params, - sources={ - electron_density_sources.GasPuffSource.SOURCE_NAME: ( - sources_params_lib.RuntimeParams( - formula=formula_config.Gaussian( - total={0.0: 0.0, 1.0: 1.0}, - c1={0.0: 0.0, 1.0: 2.0}, - c2={0.0: 0.0, 1.0: 3.0}, - ) - ) - ), - }, - torax_mesh=self._geo.torax_mesh, - )( - t=0.25, - ) - gas_puff_source = dcs.sources[ - electron_density_sources.GasPuffSource.SOURCE_NAME - ] - assert isinstance(gas_puff_source.formula, formula_config.DynamicGaussian) - np.testing.assert_allclose(gas_puff_source.formula.total, 0.25) - np.testing.assert_allclose(gas_puff_source.formula.c1, 0.5) - np.testing.assert_allclose(gas_puff_source.formula.c2, 0.75) - def test_wext_in_dynamic_runtime_params_cannot_be_negative(self): """Tests that wext cannot be negative.""" runtime_params = general_runtime_params.GeneralRuntimeParams() diff --git a/torax/core_profile_setters.py b/torax/core_profile_setters.py index e7ad62d1..028deda0 100644 --- a/torax/core_profile_setters.py +++ b/torax/core_profile_setters.py @@ -280,11 +280,7 @@ def _prescribe_currents_no_bootstrap( # form of external current on face grid generic_current_face = source_models.generic_current_source.get_value( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_generic_current_params, static_runtime_params_slice=static_runtime_params_slice, - static_source_runtime_params=static_runtime_params_slice.sources[ - generic_current_source.GenericCurrentSource.SOURCE_NAME - ], geo=geo, core_profiles=core_profiles, ) @@ -311,7 +307,6 @@ def _prescribe_currents_no_bootstrap( jtot_hires = _get_jtot_hires( static_runtime_params_slice, dynamic_runtime_params_slice, - dynamic_generic_current_params, geo, bootstrap_profile, Iohm, @@ -362,13 +357,7 @@ def _prescribe_currents_with_bootstrap( bootstrap_profile = source_models.j_bootstrap.get_value( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ - source_models.j_bootstrap_name - ], static_runtime_params_slice=static_runtime_params_slice, - static_source_runtime_params=static_runtime_params_slice.sources[ - source_models.j_bootstrap_name - ], geo=geo, core_profiles=core_profiles, ) @@ -388,11 +377,7 @@ def _prescribe_currents_with_bootstrap( # form of external current on face grid generic_current_face = source_models.generic_current_source.get_value( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_generic_current_params, static_runtime_params_slice=static_runtime_params_slice, - static_source_runtime_params=static_runtime_params_slice.sources[ - source_models.generic_current_source_name - ], geo=geo, core_profiles=core_profiles, ) @@ -423,7 +408,6 @@ def _prescribe_currents_with_bootstrap( jtot_hires = _get_jtot_hires( static_runtime_params_slice, dynamic_runtime_params_slice, - dynamic_generic_current_params, geo, bootstrap_profile, Iohm, @@ -477,31 +461,16 @@ def _calculate_currents_from_psi( bootstrap_profile = source_models.j_bootstrap.get_value( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ - source_models.j_bootstrap_name - ], static_runtime_params_slice=static_runtime_params_slice, - static_source_runtime_params=static_runtime_params_slice.sources[ - source_models.j_bootstrap_name - ], geo=geo, core_profiles=core_profiles, ) - # Calculate splitting of currents depending on input runtime params. - dynamic_generic_current_params = get_generic_current_params( - dynamic_runtime_params_slice, source_models - ) - # calculate "External" current profile (e.g. ECCD) # form of external current on face grid generic_current_face = source_models.generic_current_source.get_value( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_generic_current_params, static_runtime_params_slice=static_runtime_params_slice, - static_source_runtime_params=static_runtime_params_slice.sources[ - source_models.generic_current_source_name - ], geo=geo, core_profiles=core_profiles, ) @@ -991,7 +960,6 @@ def compute_boundary_conditions( def _get_jtot_hires( static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, - dynamic_generic_current_params: generic_current_source.DynamicRuntimeParams, geo: geometry.Geometry, bootstrap_profile: source_profiles_lib.BootstrapCurrentProfile, Iohm: jax.Array | float, @@ -1005,11 +973,7 @@ def _get_jtot_hires( # calculate hi-res "External" current profile (e.g. ECCD) on cell grid. generic_current_hires = generic_current.generic_current_source_hires( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_generic_current_params, static_runtime_params_slice=static_runtime_params_slice, - static_source_runtime_params=static_runtime_params_slice.sources[ - generic_current_source.GenericCurrentSource.SOURCE_NAME - ], geo=geo, ) diff --git a/torax/fvm/calc_coeffs.py b/torax/fvm/calc_coeffs.py index a30cf7b8..13ee3e82 100644 --- a/torax/fvm/calc_coeffs.py +++ b/torax/fvm/calc_coeffs.py @@ -728,9 +728,6 @@ def _calc_coeffs_full( qei = source_models.qei_source.get_qei( static_runtime_params_slice=static_runtime_params_slice, dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ - source_models.qei_source_name - ], geo=geo, # For Qei, always use the current set of core profiles. # In the linear solver, core_profiles is the set of profiles at time t (at diff --git a/torax/fvm/tests/fvm.py b/torax/fvm/tests/fvm.py index bb106b0b..1b95de93 100644 --- a/torax/fvm/tests/fvm.py +++ b/torax/fvm/tests/fvm.py @@ -578,7 +578,7 @@ def test_implicit_solve_block_uses_updated_boundary_conditions(self): ) ) geo = geometry.build_circular_geometry(n_rho=num_cells) - source_models = source_models_lib.SourceModels() + source_models = source_models_builder() initial_core_profiles = core_profile_setters.initial_core_profiles( static_runtime_params_slice, dynamic_runtime_params_slice, @@ -729,7 +729,7 @@ def test_theta_residual_uses_updated_boundary_conditions(self): ), ) - source_models = source_models_lib.SourceModels() + source_models = source_models_builder() pedestal_model = set_tped_nped.SetTemperatureDensityPedestalModel() initial_core_profiles = core_profile_setters.initial_core_profiles( static_runtime_params_slice_theta0, diff --git a/torax/sim.py b/torax/sim.py index 81608854..3318fe8e 100644 --- a/torax/sim.py +++ b/torax/sim.py @@ -1349,33 +1349,16 @@ def update_current_distribution( bootstrap_profile = source_models.j_bootstrap.get_value( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ - source_models.j_bootstrap_name - ], static_runtime_params_slice=static_runtime_params_slice, - static_source_runtime_params=static_runtime_params_slice.sources[ - source_models.j_bootstrap_name - ], geo=geo, core_profiles=core_profiles, ) - # Calculate splitting of currents depending on input runtime params. - dynamic_generic_current_params = ( - core_profile_setters.get_generic_current_params( - dynamic_runtime_params_slice, source_models - ) - ) - # calculate "External" current profile (e.g. ECCD) # form of external current on face grid generic_current_face = source_models.generic_current_source.get_value( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_generic_current_params, static_runtime_params_slice=static_runtime_params_slice, - static_source_runtime_params=static_runtime_params_slice.sources[ - source_models.generic_current_source_name - ], geo=geo, core_profiles=core_profiles, ) @@ -1525,9 +1508,6 @@ def get_initial_source_profiles( qei = source_models.qei_source.get_qei( static_runtime_params_slice=static_runtime_params_slice, dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ - source_models.qei_source_name - ], geo=geo, core_profiles=core_profiles, ) diff --git a/torax/sources/bootstrap_current_source.py b/torax/sources/bootstrap_current_source.py index 1f75a8c8..8f41ff60 100644 --- a/torax/sources/bootstrap_current_source.py +++ b/torax/sources/bootstrap_current_source.py @@ -91,6 +91,11 @@ class BootstrapCurrentSource(source.Source): """ SOURCE_NAME: ClassVar[str] = 'j_bootstrap' + DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'calc_neoclassical' + + @property + def source_name(self) -> str: + return self.SOURCE_NAME @property def supported_modes(self) -> tuple[runtime_params_lib.Mode, ...]: @@ -111,12 +116,16 @@ def affected_core_profiles(self) -> tuple[source.AffectedCoreProfile, ...]: def get_value( self, dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, - dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, - static_source_runtime_params: runtime_params_lib.StaticRuntimeParams, geo: geometry.Geometry, core_profiles: state.CoreProfiles, ) -> source_profiles.BootstrapCurrentProfile: + static_source_runtime_params = static_runtime_params_slice.sources[ + self.source_name + ] + dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[ + self.source_name + ] # Make sure the input mode requested is supported. self.check_mode(static_source_runtime_params.mode) # Make sure the input params are the correct type. @@ -127,7 +136,6 @@ def get_value( ) bootstrap_current = calc_neoclassical( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_source_runtime_params, geo=geo, temp_ion=core_profiles.temp_ion, temp_el=core_profiles.temp_el, @@ -175,7 +183,6 @@ def get_source_profile_for_affected_core_profile( @jax_utils.jit def calc_neoclassical( dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, - dynamic_source_runtime_params: DynamicRuntimeParams, geo: geometry.Geometry, temp_ion: cell_variable.CellVariable, temp_el: cell_variable.CellVariable, @@ -187,7 +194,6 @@ def calc_neoclassical( Args: dynamic_runtime_params_slice: General configuration parameters. - dynamic_source_runtime_params: Source-specific runtime parameters. geo: Torus geometry. temp_ion: Ion temperature. We don't pass in a full `core_profiles` here because this function is used to create the `Currents` in the initial @@ -200,6 +206,10 @@ def calc_neoclassical( Returns: A BootstrapCurrentProfile. See that class's docstring for more info. """ + dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[ + BootstrapCurrentSource.SOURCE_NAME + ] + assert isinstance(dynamic_source_runtime_params, DynamicRuntimeParams) # Many variables throughout this function are capitalized based on physics # notational conventions rather than on Google Python style # pylint: disable=invalid-name diff --git a/torax/sources/bremsstrahlung_heat_sink.py b/torax/sources/bremsstrahlung_heat_sink.py index 312668f1..7bc38cac 100644 --- a/torax/sources/bremsstrahlung_heat_sink.py +++ b/torax/sources/bremsstrahlung_heat_sink.py @@ -122,19 +122,20 @@ def calc_relativistic_correction() -> jax.Array: def bremsstrahlung_model_func( static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, - static_source_runtime_params: runtime_params_lib.StaticRuntimeParams, dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, - dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, + source_name: str, core_profiles: state.CoreProfiles, unused_model_func: source_models.SourceModels | None, ) -> jax.Array: """Model function for the Bremsstrahlung heat sink.""" del ( - static_source_runtime_params, static_runtime_params_slice, unused_model_func, ) # Unused. + dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[ + source_name + ] assert isinstance(dynamic_source_runtime_params, DynamicRuntimeParams) _, P_brem_profile = calc_bremsstrahlung( core_profiles, @@ -152,8 +153,13 @@ class BremsstrahlungHeatSink(source.Source): """Brehmsstrahlung heat sink for electron heat equation.""" SOURCE_NAME: ClassVar[str] = 'bremsstrahlung_heat_sink' + DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'bremsstrahlung_model_func' model_func: source.SourceProfileFunction = bremsstrahlung_model_func + @property + def source_name(self) -> str: + return self.SOURCE_NAME + @property def supported_modes(self) -> tuple[runtime_params_lib.Mode, ...]: """Returns the modes supported by this source.""" diff --git a/torax/sources/electron_cyclotron_source.py b/torax/sources/electron_cyclotron_source.py index e5f02023..e404c15d 100644 --- a/torax/sources/electron_cyclotron_source.py +++ b/torax/sources/electron_cyclotron_source.py @@ -103,14 +103,13 @@ class DynamicRuntimeParams(runtime_params_lib.DynamicRuntimeParams): gaussian_ec_total_power: array_typing.ScalarFloat -def _calc_heating_and_current( +def calc_heating_and_current( static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, - static_source_runtime_params: runtime_params_lib.StaticRuntimeParams, dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, - dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, + source_name: str, core_profiles: state.CoreProfiles, - unused_model_func: source_models.SourceModels, + unused_source_models: source_models.SourceModels | None = None, ) -> jax.Array: """Model function for the electron-cyclotron source. @@ -119,11 +118,9 @@ def _calc_heating_and_current( Args: static_runtime_params_slice: Static runtime parameters. - static_source_runtime_params: Static runtime parameters. dynamic_runtime_params_slice: Global runtime parameters - dynamic_source_runtime_params: Specific runtime parameters for the - electron-cyclotron source. geo: Magnetic geometry. + source_name: Name of the source. core_profiles: CoreProfiles component of the state. unused_model_func: (unused) source models used in the simulation. @@ -131,10 +128,12 @@ def _calc_heating_and_current( 2D array of electron cyclotron heating power density and current density. """ del ( - static_source_runtime_params, - unused_model_func, + unused_source_models, static_runtime_params_slice, ) # Unused. + dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[ + source_name + ] # Helps linter understand the type of dynamic_source_runtime_params. assert isinstance(dynamic_source_runtime_params, DynamicRuntimeParams) # Construct the profile @@ -188,7 +187,12 @@ class ElectronCyclotronSource(source.Source): """Electron cyclotron source for the Te and Psi equations.""" SOURCE_NAME: ClassVar[str] = "electron_cyclotron_source" - model_func: source.SourceProfileFunction = _calc_heating_and_current + DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = "calc_heating_and_current" + model_func: source.SourceProfileFunction = calc_heating_and_current + + @property + def source_name(self) -> str: + return self.SOURCE_NAME @property def supported_modes(self) -> tuple[runtime_params_lib.Mode, ...]: diff --git a/torax/sources/electron_density_sources.py b/torax/sources/electron_density_sources.py index 9038ed23..e03652e2 100644 --- a/torax/sources/electron_density_sources.py +++ b/torax/sources/electron_density_sources.py @@ -41,7 +41,7 @@ class GasPuffRuntimeParams(runtime_params_lib.RuntimeParams): puff_decay_length: runtime_params_lib.TimeInterpolatedInput = 0.05 # total gas puff particles/s S_puff_tot: runtime_params_lib.TimeInterpolatedInput = 1e22 - mode: runtime_params_lib.Mode = runtime_params_lib.Mode.FORMULA_BASED + mode: runtime_params_lib.Mode = runtime_params_lib.Mode.MODEL_BASED def make_provider( self, @@ -76,21 +76,22 @@ class DynamicGasPuffRuntimeParams(runtime_params_lib.DynamicRuntimeParams): # Default formula: exponential with nref normalization. -def _calc_puff_source( +def calc_puff_source( static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, - static_source_runtime_params: runtime_params_lib.StaticRuntimeParams, dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, - dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, + source_name: str, unused_state: state.CoreProfiles | None = None, unused_source_models: source_models.SourceModels | None = None, ) -> jax.Array: """Calculates external source term for n from puffs.""" del ( - static_source_runtime_params, unused_source_models, static_runtime_params_slice, ) # Unused. + dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[ + source_name + ] assert isinstance(dynamic_source_runtime_params, DynamicGasPuffRuntimeParams) return formulas.exponential_profile( c1=1.0, @@ -108,7 +109,12 @@ class GasPuffSource(source.Source): """Gas puff source for the ne equation.""" SOURCE_NAME: ClassVar[str] = 'gas_puff_source' - formula: source.SourceProfileFunction = _calc_puff_source + DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'calc_puff_source' + model_func: source.SourceProfileFunction = calc_puff_source + + @property + def source_name(self) -> str: + return self.SOURCE_NAME @property def affected_core_profiles(self) -> tuple[source.AffectedCoreProfile, ...]: @@ -125,7 +131,7 @@ class GenericParticleSourceRuntimeParams(runtime_params_lib.RuntimeParams): deposition_location: runtime_params_lib.TimeInterpolatedInput = 0.0 # total particle source S_tot: runtime_params_lib.TimeInterpolatedInput = 1e22 - mode: runtime_params_lib.Mode = runtime_params_lib.Mode.FORMULA_BASED + mode: runtime_params_lib.Mode = runtime_params_lib.Mode.MODEL_BASED def make_provider( self, @@ -155,7 +161,6 @@ def build_dynamic_params( particle_width=float(self.particle_width.get_value(t)), deposition_location=float(self.deposition_location.get_value(t)), S_tot=float(self.S_tot.get_value(t)), - formula=self.formula.build_dynamic_params(t), prescribed_values=self.prescribed_values.get_value(t), ) @@ -167,21 +172,22 @@ class DynamicParticleRuntimeParams(runtime_params_lib.DynamicRuntimeParams): S_tot: array_typing.ScalarFloat -def _calc_generic_particle_source( +def calc_generic_particle_source( static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, - static_source_runtime_params: runtime_params_lib.StaticRuntimeParams, dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, - dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, + source_name: str, unused_state: state.CoreProfiles | None = None, unused_source_models: source_models.SourceModels | None = None, ) -> jax.Array: """Calculates external source term for n from SBI.""" del ( - static_source_runtime_params, unused_source_models, static_runtime_params_slice, ) # Unused. + dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[ + source_name + ] assert isinstance(dynamic_source_runtime_params, DynamicParticleRuntimeParams) return formulas.gaussian_profile( c1=dynamic_source_runtime_params.deposition_location, @@ -199,7 +205,12 @@ class GenericParticleSource(source.Source): """Neutral-beam injection source for the ne equation.""" SOURCE_NAME: ClassVar[str] = 'generic_particle_source' - formula: source.SourceProfileFunction = _calc_generic_particle_source + DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'calc_generic_particle_source' + model_func: source.SourceProfileFunction = calc_generic_particle_source + + @property + def source_name(self) -> str: + return self.SOURCE_NAME @property def affected_core_profiles(self) -> tuple[source.AffectedCoreProfile, ...]: @@ -218,7 +229,7 @@ class PelletRuntimeParams(runtime_params_lib.RuntimeParams): pellet_deposition_location: runtime_params_lib.TimeInterpolatedInput = 0.85 # total pellet particles/s (continuous pellet model) S_pellet_tot: runtime_params_lib.TimeInterpolatedInput = 2e22 - mode: runtime_params_lib.Mode = runtime_params_lib.Mode.FORMULA_BASED + mode: runtime_params_lib.Mode = runtime_params_lib.Mode.MODEL_BASED def make_provider( self, @@ -250,21 +261,22 @@ class DynamicPelletRuntimeParams(runtime_params_lib.DynamicRuntimeParams): S_pellet_tot: array_typing.ScalarFloat -def _calc_pellet_source( +def calc_pellet_source( static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, - static_source_runtime_params: runtime_params_lib.StaticRuntimeParams, dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, - dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, + source_name: str, unused_state: state.CoreProfiles | None = None, unused_source_models: source_models.SourceModels | None = None, ) -> jax.Array: """Calculates external source term for n from pellets.""" del ( - static_source_runtime_params, unused_source_models, static_runtime_params_slice, ) # Unused. + dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[ + source_name + ] assert isinstance(dynamic_source_runtime_params, DynamicPelletRuntimeParams) return formulas.gaussian_profile( c1=dynamic_source_runtime_params.pellet_deposition_location, @@ -282,23 +294,13 @@ class PelletSource(source.Source): """Pellet source for the ne equation.""" SOURCE_NAME: ClassVar[str] = 'pellet_source' - formula: source.SourceProfileFunction = _calc_pellet_source + DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'calc_pellet_source' + model_func: source.SourceProfileFunction = calc_pellet_source + + @property + def source_name(self) -> str: + return self.SOURCE_NAME @property def affected_core_profiles(self) -> tuple[source.AffectedCoreProfile, ...]: return (source.AffectedCoreProfile.NE,) - - -# pylint: enable=invalid-name -# The sources below don't have any source-specific implementations, so their -# bodies are empty. You can refer to their base class to see the implementation. -# We define new classes here to: -# a) support any future source-specific implementation. -# b) better readability and human-friendly error messages when debugging. -@dataclasses.dataclass(kw_only=True, frozen=True, eq=True) -class RecombinationDensitySink(source.Source): - """Recombination sink for the electron density equation.""" - - affected_core_profiles: tuple[source.AffectedCoreProfile, ...] = ( - source.AffectedCoreProfile.NE, - ) diff --git a/torax/sources/formula_config.py b/torax/sources/formula_config.py deleted file mode 100644 index 589ef2e6..00000000 --- a/torax/sources/formula_config.py +++ /dev/null @@ -1,143 +0,0 @@ -# Copyright 2024 DeepMind Technologies Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Defines the runtime (dynamic) configuration of formulas used in sources.""" - -from __future__ import annotations - -import dataclasses -from typing import TypeAlias - -import chex -from torax import array_typing -from torax import interpolated_param -from torax.config import base -from torax.config import config_args -from torax.geometry import geometry - - -TimeInterpolatedInput: TypeAlias = interpolated_param.TimeInterpolatedInput - - -@chex.dataclass(frozen=True) -class DynamicFormula: - """Base class for dynamic configs.""" - - -@dataclasses.dataclass -class Exponential(base.RuntimeParametersConfig['ExponentialProvider']): - """Configures an exponential formula. - - See formulas.Exponential for more information on how this config is used. - """ - - # floats to parameterize the different formulas. - total: TimeInterpolatedInput = 1.0 - c1: TimeInterpolatedInput = 1.0 - c2: TimeInterpolatedInput = 1.0 - - def make_provider( - self, torax_mesh: geometry.Grid1D | None = None - ) -> ExponentialProvider: - del torax_mesh # Unused. - return ExponentialProvider( - runtime_params_config=self, - total=config_args.get_interpolated_var_single_axis( - self.total, - ), - c1=config_args.get_interpolated_var_single_axis( - self.c1, - ), - c2=config_args.get_interpolated_var_single_axis( - self.c2, - ), - ) - - -@chex.dataclass -class ExponentialProvider(base.RuntimeParametersProvider['DynamicExponential']): - """Runtime parameter provider for a single source/sink term.""" - - runtime_params_config: Exponential - total: interpolated_param.InterpolatedVarSingleAxis - c1: interpolated_param.InterpolatedVarSingleAxis - c2: interpolated_param.InterpolatedVarSingleAxis - - def build_dynamic_params( - self, - t: chex.Numeric, - ) -> DynamicExponential: - return DynamicExponential(**self.get_dynamic_params_kwargs(t)) - - -@chex.dataclass(frozen=True) -class DynamicExponential(DynamicFormula): - - total: array_typing.ScalarFloat - c1: array_typing.ScalarFloat - c2: array_typing.ScalarFloat - - -@dataclasses.dataclass -class Gaussian(base.RuntimeParametersConfig['GaussianProvider']): - """Configures a Gaussian formula. - - See formulas.Gaussian for more information on how this config is used. - """ - - # floats to parameterize the different formulas. - total: TimeInterpolatedInput = 1.0 - c1: TimeInterpolatedInput = 1.0 - c2: TimeInterpolatedInput = 1.0 - - def make_provider( - self, torax_mesh: geometry.Grid1D | None = None - ) -> GaussianProvider: - del torax_mesh # Unused. - return GaussianProvider( - runtime_params_config=self, - total=config_args.get_interpolated_var_single_axis( - self.total, - ), - c1=config_args.get_interpolated_var_single_axis( - self.c1, - ), - c2=config_args.get_interpolated_var_single_axis( - self.c2, - ), - ) - - -@chex.dataclass -class GaussianProvider(base.RuntimeParametersProvider['DynamicGaussian']): - """Runtime parameter provider for a single source/sink term.""" - - runtime_params_config: Gaussian - total: interpolated_param.InterpolatedVarSingleAxis - c1: interpolated_param.InterpolatedVarSingleAxis - c2: interpolated_param.InterpolatedVarSingleAxis - - def build_dynamic_params( - self, - t: chex.Numeric, - ) -> DynamicGaussian: - return DynamicGaussian(**self.get_dynamic_params_kwargs(t)) - - -@chex.dataclass(frozen=True) -class DynamicGaussian(DynamicFormula): - - total: array_typing.ScalarFloat - c1: array_typing.ScalarFloat - c2: array_typing.ScalarFloat diff --git a/torax/sources/formulas.py b/torax/sources/formulas.py index 9ff681f6..d3c08c0e 100644 --- a/torax/sources/formulas.py +++ b/torax/sources/formulas.py @@ -13,18 +13,9 @@ # limitations under the License. """Prescribed formulas for computing source profiles.""" - -import dataclasses -from typing import Optional import jax from jax import numpy as jnp -from torax import state -from torax.config import runtime_params_slice from torax.geometry import geometry -from torax.sources import formula_config -from torax.sources import runtime_params - - # Many variables throughout this function are capitalized based on physics # notational conventions rather than on Google Python style # pylint: disable=invalid-name @@ -103,71 +94,3 @@ def gaussian_profile( geo.vpr_face * S_face, geo.rho_face_norm ) return C * S - - -# pylint: enable=invalid-name - - -# Callable classes used as arguments for Source formulas. - - -@dataclasses.dataclass(frozen=True) -class Exponential: - """Callable class providing an exponential profile.""" - - def __call__( # pytype: disable=name-error - self, - static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, - static_source_runtime_params: runtime_params.StaticRuntimeParams, - dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, - dynamic_source_runtime_params: runtime_params.DynamicRuntimeParams, - geo: geometry.Geometry, - unused_state: state.CoreProfiles | None, - unused_source_models: Optional['source_models.SourceModels'] = None, - ) -> jax.Array: - del ( - dynamic_runtime_params_slice, - static_runtime_params_slice, - static_source_runtime_params, - unused_state, - unused_source_models, - ) # Unused. - exp_config = dynamic_source_runtime_params.formula - assert isinstance(exp_config, formula_config.DynamicExponential) - return exponential_profile( - c1=exp_config.c1, - c2=exp_config.c2, - total=exp_config.total, - geo=geo, - ) - - -@dataclasses.dataclass(frozen=True) -class Gaussian: - """Callable class providing a gaussian profile.""" - - def __call__( # pytype: disable=name-error - self, - static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, - static_source_runtime_params: runtime_params.StaticRuntimeParams, - dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, - dynamic_source_runtime_params: runtime_params.DynamicRuntimeParams, - geo: geometry.Geometry, - unused_state: state.CoreProfiles | None, - unused_source_models: Optional['source_models.SourceModels'] = None, - ) -> jax.Array: - del ( - dynamic_runtime_params_slice, - static_runtime_params_slice, - static_source_runtime_params, - unused_state, - unused_source_models, - ) # Unused. - gaussian_config = dynamic_source_runtime_params.formula - assert isinstance(gaussian_config, formula_config.DynamicGaussian) - return gaussian_profile( - c1=gaussian_config.c1, - c2=gaussian_config.c2, - total=gaussian_config.total, - geo=geo, - ) diff --git a/torax/sources/fusion_heat_source.py b/torax/sources/fusion_heat_source.py index 9e395874..ce6f4f34 100644 --- a/torax/sources/fusion_heat_source.py +++ b/torax/sources/fusion_heat_source.py @@ -121,20 +121,15 @@ def calc_fusion( # pytype: disable=name-error def fusion_heat_model_func( static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, - static_source_runtime_params: runtime_params_lib.StaticRuntimeParams, dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, - dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, + source_name: str, core_profiles: state.CoreProfiles, unused_source_models: Optional['source_models.SourceModels'], ) -> jax.Array: """Model function for fusion heating.""" # pytype: enable=name-error - del ( - dynamic_source_runtime_params, - static_source_runtime_params, - static_runtime_params_slice, - ) # Unused. + del static_runtime_params_slice, source_name # Unused. # pylint: disable=invalid-name _, Pfus_i, Pfus_e = calc_fusion( geo, core_profiles, dynamic_runtime_params_slice.numerics.nref @@ -148,8 +143,13 @@ class FusionHeatSource(source.Source): """Fusion heat source for both ion and electron heat.""" SOURCE_NAME: ClassVar[str] = 'fusion_heat_source' + DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'fusion_heat_model_func' model_func: source.SourceProfileFunction = fusion_heat_model_func + @property + def source_name(self) -> str: + return self.SOURCE_NAME + @property def supported_modes(self) -> tuple[runtime_params_lib.Mode, ...]: return ( diff --git a/torax/sources/generic_current_source.py b/torax/sources/generic_current_source.py index bef21640..ba39ad34 100644 --- a/torax/sources/generic_current_source.py +++ b/torax/sources/generic_current_source.py @@ -52,7 +52,7 @@ class RuntimeParams(runtime_params_lib.RuntimeParams): # Toggles if external current is provided absolutely or as a fraction of Ip. use_absolute_current: bool = False - mode: runtime_params_lib.Mode = runtime_params_lib.Mode.FORMULA_BASED + mode: runtime_params_lib.Mode = runtime_params_lib.Mode.MODEL_BASED @property def grid_type(self) -> base.GridType: @@ -105,12 +105,11 @@ def __post_init__(self): # pytype bug: does not treat 'source_models.SourceModels' as a forward reference # pytype: disable=name-error -def _calculate_generic_current_face( +def calculate_generic_current_face( static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, - static_source_runtime_params: runtime_params_lib.StaticRuntimeParams, dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, - dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, + source_name: str, unused_state: state.CoreProfiles | None = None, unused_source_models: Optional['source_models.SourceModels'] = None, ) -> jax.Array: @@ -118,23 +117,25 @@ def _calculate_generic_current_face( Args: static_runtime_params_slice: Static runtime parameters. - static_source_runtime_params: Static runtime parameters. dynamic_runtime_params_slice: Parameter configuration at present timestep. - dynamic_source_runtime_params: Source-specific parameters at the present - timestep. geo: Tokamak geometry. + source_name: Name of the source. unused_state: State argument not used in this function but is present to adhere to the source API. + unused_source_models: Source models argument not used in this function but + is present to adhere to the source API. Returns: External current density profile along the face grid. """ del ( - static_source_runtime_params, static_runtime_params_slice, unused_state, unused_source_models, ) # Unused. + dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[ + source_name + ] # pytype: enable=name-error assert isinstance(dynamic_source_runtime_params, DynamicRuntimeParams) Iext = _calculate_Iext( @@ -163,10 +164,9 @@ def _calculate_generic_current_face( # pytype: disable=name-error def _calculate_generic_current_hires( static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, - static_source_runtime_params: runtime_params_lib.StaticRuntimeParams, dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, - dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, + source_name: str, unused_state: state.CoreProfiles | None = None, unused_source_models: Optional['source_models.SourceModels'] = None, ) -> jax.Array: @@ -174,23 +174,25 @@ def _calculate_generic_current_hires( Args: static_runtime_params_slice: Static runtime parameters. - static_source_runtime_params: Static runtime parameters. dynamic_runtime_params_slice: Parameter configuration at present timestep. - dynamic_source_runtime_params: Source-specific parameters at the present - timestep. geo: Tokamak geometry. + source_name: Name of the source. unused_state: State argument not used in this function but is present to adhere to the source API. + unused_source_models: Source models argument not used in this function but + is present to adhere to the source API. Returns: Generic current density profile along the hires cell grid. """ del ( - static_source_runtime_params, static_runtime_params_slice, unused_state, unused_source_models, ) # Unused. + dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[ + source_name + ] # pytype: enable=name-error assert isinstance(dynamic_source_runtime_params, DynamicRuntimeParams) Iext = _calculate_Iext( @@ -233,8 +235,13 @@ class GenericCurrentSource(source.Source): """A generic current density source profile.""" SOURCE_NAME: ClassVar[str] = 'generic_current_source' - formula: source.SourceProfileFunction = _calculate_generic_current_face + DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'calc_generic_current_face' hires_formula: source.SourceProfileFunction = _calculate_generic_current_hires + model_func: source.SourceProfileFunction = calculate_generic_current_face + + @property + def source_name(self) -> str: + return self.SOURCE_NAME @property def affected_core_profiles(self) -> tuple[source.AffectedCoreProfile, ...]: @@ -261,10 +268,8 @@ def get_source_profile_for_affected_core_profile( # pytype: disable=name-error def generic_current_source_hires( self, - dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, - dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, - static_source_runtime_params: runtime_params_lib.StaticRuntimeParams, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, unused_state: state.CoreProfiles | None = None, unused_source_models: Optional['source_models.SourceModels'] = None, @@ -272,6 +277,12 @@ def generic_current_source_hires( # pytype: enable=name-error """Return the current density profile along the hires cell grid.""" del unused_state, unused_source_models # Unused. + dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[ + self.source_name + ] + static_source_runtime_params = static_runtime_params_slice.sources[ + self.source_name + ] assert isinstance(dynamic_source_runtime_params, DynamicRuntimeParams) self.check_mode(static_source_runtime_params.mode) @@ -289,19 +300,14 @@ def generic_current_source_hires( return source.get_source_profiles( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_source_runtime_params, static_runtime_params_slice=static_runtime_params_slice, - static_source_runtime_params=static_source_runtime_params, geo=geo, core_profiles=None, # There is no model for this source. - model_func=( - lambda _0, _1, _2, _3, _4, _5, _6: jnp.zeros_like( - geo.rho_hires_norm - ) - ), - formula=self.hires_formula, + model_func=self.hires_formula, + formula=None, output_shape=geo.rho_hires_norm.shape, prescribed_values=hires_prescribed_values, source_models=getattr(self, 'source_models', None), + source_name=self.source_name, ) diff --git a/torax/sources/generic_ion_el_heat_source.py b/torax/sources/generic_ion_el_heat_source.py index 3dd761fd..f7eca7ba 100644 --- a/torax/sources/generic_ion_el_heat_source.py +++ b/torax/sources/generic_ion_el_heat_source.py @@ -48,7 +48,7 @@ class RuntimeParams(runtime_params_lib.RuntimeParams): Ptot: runtime_params_lib.TimeInterpolatedInput = 120e6 # electron heating fraction el_heat_fraction: runtime_params_lib.TimeInterpolatedInput = 0.66666 - mode: runtime_params_lib.Mode = runtime_params_lib.Mode.FORMULA_BASED + mode: runtime_params_lib.Mode = runtime_params_lib.Mode.MODEL_BASED def make_provider( self, @@ -113,24 +113,24 @@ def calc_generic_heat_source( # pytype: disable=name-error -def _default_formula( +def default_formula( static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, - static_source_runtime_params: runtime_params_lib.StaticRuntimeParams, dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, - dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, + source_name: str, core_profiles: state.CoreProfiles, unused_source_models: Optional['source_models.SourceModels'], ) -> jax.Array: """Returns the default formula-based ion/electron heat source profile.""" # pytype: enable=name-error del ( - dynamic_runtime_params_slice, core_profiles, - static_source_runtime_params, static_runtime_params_slice, unused_source_models, ) # Unused. + dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[ + source_name + ] assert isinstance(dynamic_source_runtime_params, DynamicRuntimeParams) ion, el = calc_generic_heat_source( geo, @@ -150,7 +150,12 @@ class GenericIonElectronHeatSource(source.Source): """Generic heat source for both ion and electron heat.""" SOURCE_NAME: ClassVar[str] = 'generic_ion_el_heat_source' - formula: source.SourceProfileFunction = _default_formula + DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'default_formula' + model_func: source.SourceProfileFunction = default_formula + + @property + def source_name(self) -> str: + return self.SOURCE_NAME @property def affected_core_profiles(self) -> tuple[source.AffectedCoreProfile, ...]: diff --git a/torax/sources/impurity_radiation_heat_sink.py b/torax/sources/impurity_radiation_heat_sink.py index 4e9bff90..06c3e7da 100644 --- a/torax/sources/impurity_radiation_heat_sink.py +++ b/torax/sources/impurity_radiation_heat_sink.py @@ -16,6 +16,7 @@ """Basic impurity radiation heat sink for electron heat equation..""" import dataclasses +from typing import ClassVar import chex import jax @@ -30,12 +31,11 @@ from torax.sources import source_models as source_models_lib -def _radially_constant_fraction_of_Pin( # pylint: disable=invalid-name +def radially_constant_fraction_of_Pin( # pylint: disable=invalid-name static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, - static_source_runtime_params: runtime_params_lib.StaticRuntimeParams, dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, - dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, + source_name: str, core_profiles: state.CoreProfiles, source_models: source_models_lib.SourceModels, ) -> jax.Array: @@ -46,37 +46,30 @@ def _radially_constant_fraction_of_Pin( # pylint: disable=invalid-name Args: static_runtime_params_slice: Static runtime parameters. - static_source_runtime_params: Static source runtime parameters. dynamic_runtime_params_slice: Dynamic runtime parameters. - dynamic_source_runtime_params: Dynamic source runtime parameters. geo: Geometry object. + source_name: Name of the source. core_profiles: Core profiles object. source_models: Source models object. Returns: The heat sink profile. """ - del (static_source_runtime_params,) # Unused + dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[ + source_name + ] assert isinstance(dynamic_source_runtime_params, DynamicRuntimeParams) # Based on source_models.sum_sources_temp_el and source_models.calc_and_sum # sources_psi, but only summing over heating *input* sources # (Pohm + Paux + Palpha + ...) and summing over *both* ion + electron heating - def get_heat_source_profile( - source_name: str, source: source_lib.Source - ) -> jax.Array: + def get_heat_source_profile(source: source_lib.Source) -> jax.Array: # TODO(b/381543891): Currently this recomputes the profile for each source, # which is inefficient. Refactor to avoid this. profile = source.get_value( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ - source_name - ], static_runtime_params_slice=static_runtime_params_slice, - static_source_runtime_params=static_runtime_params_slice.sources[ - source_name - ], geo=geo, core_profiles=core_profiles, ) @@ -95,7 +88,6 @@ def get_heat_source_profile( } source_profiles = jax.tree.map( get_heat_source_profile, - list(heat_sources.keys()), list(heat_sources.values()), ) Qtot_in = jnp.sum(jnp.stack(source_profiles), axis=0) @@ -148,10 +140,17 @@ class ImpurityRadiationHeatSink(source_lib.Source): """Impurity radiation heat sink for electron heat equation.""" SOURCE_NAME = "impurity_radiation_heat_sink" - source_models: source_models_lib.SourceModels + DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = ( + "radially_constant_fraction_of_Pin" + ) model_func: source_lib.SourceProfileFunction = ( - _radially_constant_fraction_of_Pin + radially_constant_fraction_of_Pin ) + source_models: source_models_lib.SourceModels + + @property + def source_name(self) -> str: + return self.SOURCE_NAME @property def supported_modes(self) -> tuple[runtime_params_lib.Mode, ...]: diff --git a/torax/sources/ion_cyclotron_source.py b/torax/sources/ion_cyclotron_source.py index 13e92f4f..f2d88a2d 100644 --- a/torax/sources/ion_cyclotron_source.py +++ b/torax/sources/ion_cyclotron_source.py @@ -364,12 +364,11 @@ def _helium3_tail_temperature( return core_profiles.temp_el.value * (1 + epsilon) -def _icrh_model_func( +def icrh_model_func( static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, - static_source_runtime_params: runtime_params_lib.StaticRuntimeParams, dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, - dynamic_source_runtime_params: DynamicRuntimeParams, geo: geometry.Geometry, + source_name: str, core_profiles: state.CoreProfiles, unused_source_models: source_models.SourceModels | None, toric_nn: ToricNNWrapper, @@ -377,10 +376,12 @@ def _icrh_model_func( """Compute ion/electron heat source terms.""" del ( unused_source_models, - dynamic_runtime_params_slice, - static_source_runtime_params, static_runtime_params_slice, ) # Unused. + dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[ + source_name + ] + assert isinstance(dynamic_source_runtime_params, DynamicRuntimeParams) # Construct inputs for ToricNN. volume = integrate.trapezoid(geo.vpr_face, geo.rho_face_norm) @@ -484,8 +485,6 @@ def _icrh_model_func( source_ion += power_deposition_2T * dynamic_source_runtime_params.Ptot return jnp.stack([source_ion, source_el]) - - # pylint: enable=invalid-name @@ -494,18 +493,11 @@ class IonCyclotronSource(source.Source): """Ion cyclotron source with surrogate model.""" SOURCE_NAME: ClassVar[str] = 'ion_cyclotron_source' - # The model function is fixed to _icrh_model_func because that is the only - # supported implementation of this source. - # However, since this is a param in the parent dataclass, we need to (a) - # remove the parameter from the init args and (b) set the default to the - # desired value. - model_func: source.SourceProfileFunction = dataclasses.field( - init=False, - default_factory=lambda: functools.partial( - _icrh_model_func, - toric_nn=ToricNNWrapper(), - ), - ) + DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'icrh_model_func' + + @property + def source_name(self) -> str: + return self.SOURCE_NAME @property def supported_modes(self) -> tuple[runtime_params_lib.Mode, ...]: @@ -525,3 +517,28 @@ def affected_core_profiles(self) -> tuple[source.AffectedCoreProfile, ...]: @property def output_shape_getter(self) -> source.SourceOutputShapeFunction: return source.get_ion_el_output_shape + + +@dataclasses.dataclass(kw_only=True, frozen=False) +class IonCyclotronSourceBuilder: + """Builder for the IonCyclotronSource.""" + + runtime_params: RuntimeParams = dataclasses.field( + default_factory=RuntimeParams + ) + links_back: bool = False + model_func: source.SourceProfileFunction | None = None + + def __post_init__(self): + if self.model_func is None: + self.model_func = functools.partial( + icrh_model_func, + toric_nn=ToricNNWrapper(), + ) + + def __call__( + self, + formula: source.SourceProfileFunction | None = None, + ) -> IonCyclotronSource: + + return IonCyclotronSource(formula=formula, model_func=self.model_func,) diff --git a/torax/sources/ohmic_heat_source.py b/torax/sources/ohmic_heat_source.py index 7a31037f..0c5c8166 100644 --- a/torax/sources/ohmic_heat_source.py +++ b/torax/sources/ohmic_heat_source.py @@ -145,16 +145,14 @@ def calc_psidot( def ohmic_model_func( static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, - static_source_runtime_params: runtime_params_lib.StaticRuntimeParams, dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, - dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, + source_name: str, core_profiles: state.CoreProfiles, - source_models: source_models_lib.SourceModels | None = None, + source_models: source_models_lib.SourceModels, ) -> jax.Array: """Returns the Ohmic source for electron heat equation.""" - del dynamic_source_runtime_params, static_source_runtime_params - + del source_name # Unused. if source_models is None: raise TypeError('source_models is a required argument for ohmic_model_func') @@ -192,17 +190,14 @@ class OhmicHeatSource(source_lib.Source): """ SOURCE_NAME: ClassVar[str] = 'ohmic_heat_source' + DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'ohmic_model_func' + model_func: source_lib.SourceProfileFunction = ohmic_model_func # Users must pass in a pointer to the complete set of sources to this object. source_models: source_models_lib.SourceModels - # The model function is fixed to ohmic_model_func because that is the only - # supported implementation of this source. - # However, since this is a param in the parent dataclass, we need to (a) - # remove the parameter from the init args and (b) set the default to the - # desired value. - model_func: source_lib.SourceProfileFunction | None = dataclasses.field( - init=False, - default_factory=lambda: ohmic_model_func, - ) + + @property + def source_name(self) -> str: + return self.SOURCE_NAME @property def supported_modes(self) -> tuple[runtime_params_lib.Mode, ...]: diff --git a/torax/sources/qei_source.py b/torax/sources/qei_source.py index 26ada6dd..990bbc25 100644 --- a/torax/sources/qei_source.py +++ b/torax/sources/qei_source.py @@ -72,6 +72,11 @@ class QeiSource(source.Source): """ SOURCE_NAME: ClassVar[str] = 'qei_source' + DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'model_based_qei' + + @property + def source_name(self) -> str: + return self.SOURCE_NAME @property def supported_modes(self) -> tuple[runtime_params_lib.Mode, ...]: @@ -92,14 +97,16 @@ def get_qei( self, static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, - dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, core_profiles: state.CoreProfiles, ) -> source_profiles.QeiInfo: """Computes the value of the source.""" - self.check_mode(static_runtime_params_slice.sources[self.SOURCE_NAME].mode) + self.check_mode(static_runtime_params_slice.sources[self.source_name].mode) + dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[ + self.source_name + ] return jax.lax.cond( - static_runtime_params_slice.sources[self.SOURCE_NAME].mode + static_runtime_params_slice.sources[self.source_name].mode == runtime_params_lib.Mode.MODEL_BASED.value, lambda: _model_based_qei( static_runtime_params_slice, diff --git a/torax/sources/register_source.py b/torax/sources/register_source.py index c705a204..f8cf784f 100644 --- a/torax/sources/register_source.py +++ b/torax/sources/register_source.py @@ -35,6 +35,7 @@ class to build, the runtime associated with that source and (optionally) the expected to grow over time as TORAX becomes more feature rich but ultimately be finite. """ + import dataclasses from typing import Type @@ -54,116 +55,149 @@ class to build, the runtime associated with that source and (optionally) the @dataclasses.dataclass(frozen=True) -class RegisteredSource: - source_class: Type[source.Source] - source_builder_class: source.SourceBuilderProtocol - default_runtime_params_class: Type[runtime_params.RuntimeParams] - - -def _register_new_source( - source_class: Type[source.Source], - default_runtime_params_class: Type[runtime_params.RuntimeParams], - source_builder_class: source.SourceBuilderProtocol | None = None, - links_back: bool = False, -) -> RegisteredSource: - """Register source class, default runtime params and (optional) builder for this source. - - Args: - source_class: The source class. - default_runtime_params_class: The default runtime params class. - source_builder_class: The source builder class. If None, a default builder - is created which uses the source class and default runtime params class to - construct a builder for that source. - links_back: Whether the source requires a reference to all the source - models. - - Returns: - A `RegisteredSource` dataclass containing the source class, source - builder class, and default runtime params class. - """ - if source_builder_class is None: - builder_class = source.make_source_builder( - source_class, - runtime_params_type=default_runtime_params_class, - links_back=links_back, - ) - else: - builder_class = source_builder_class - - return RegisteredSource( - source_class=source_class, - source_builder_class=builder_class, - default_runtime_params_class=default_runtime_params_class, - ) +class ModelFunction: + source_profile_function: source.SourceProfileFunction | None + runtime_params_class: Type[runtime_params.RuntimeParams] + source_builder_class: source.SourceBuilderProtocol | None = None + links_back: bool = False -_REGISTERED_SOURCES = { - bootstrap_current_source.BootstrapCurrentSource.SOURCE_NAME: ( - _register_new_source( - source_class=bootstrap_current_source.BootstrapCurrentSource, - default_runtime_params_class=bootstrap_current_source.RuntimeParams, - ) +@dataclasses.dataclass(frozen=True) +class SupportedSource: + """Source that can be used in TORAX and any associated model functions.""" + source_class: Type[source.Source] + model_functions: dict[str, ModelFunction] + + +_SUPPORTED_SOURCES = { + bootstrap_current_source.BootstrapCurrentSource.SOURCE_NAME: SupportedSource( + source_class=bootstrap_current_source.BootstrapCurrentSource, + model_functions={ + bootstrap_current_source.BootstrapCurrentSource.DEFAULT_MODEL_FUNCTION_NAME: ModelFunction( + source_profile_function=None, + runtime_params_class=bootstrap_current_source.RuntimeParams, + ) + }, ), - generic_current_source.GenericCurrentSource.SOURCE_NAME: ( - _register_new_source( - source_class=generic_current_source.GenericCurrentSource, - default_runtime_params_class=generic_current_source.RuntimeParams, - ) + generic_current_source.GenericCurrentSource.SOURCE_NAME: SupportedSource( + source_class=generic_current_source.GenericCurrentSource, + model_functions={ + generic_current_source.GenericCurrentSource.DEFAULT_MODEL_FUNCTION_NAME: ModelFunction( + source_profile_function=generic_current_source.calculate_generic_current_face, + runtime_params_class=generic_current_source.RuntimeParams, + ) + }, ), - electron_cyclotron_source.ElectronCyclotronSource.SOURCE_NAME: _register_new_source( + electron_cyclotron_source.ElectronCyclotronSource.SOURCE_NAME: SupportedSource( source_class=electron_cyclotron_source.ElectronCyclotronSource, - default_runtime_params_class=electron_cyclotron_source.RuntimeParams, + model_functions={ + electron_cyclotron_source.ElectronCyclotronSource.DEFAULT_MODEL_FUNCTION_NAME: ModelFunction( + source_profile_function=electron_cyclotron_source.calc_heating_and_current, + runtime_params_class=electron_cyclotron_source.RuntimeParams, + ) + }, ), - electron_density_sources.GenericParticleSource.SOURCE_NAME: _register_new_source( + electron_density_sources.GenericParticleSource.SOURCE_NAME: SupportedSource( source_class=electron_density_sources.GenericParticleSource, - default_runtime_params_class=electron_density_sources.GenericParticleSourceRuntimeParams, + model_functions={ + electron_density_sources.GenericParticleSource.DEFAULT_MODEL_FUNCTION_NAME: ModelFunction( + source_profile_function=electron_density_sources.calc_generic_particle_source, + runtime_params_class=electron_density_sources.GenericParticleSourceRuntimeParams, + ) + }, ), - electron_density_sources.GasPuffSource.SOURCE_NAME: _register_new_source( + electron_density_sources.GasPuffSource.SOURCE_NAME: SupportedSource( source_class=electron_density_sources.GasPuffSource, - default_runtime_params_class=electron_density_sources.GasPuffRuntimeParams, + model_functions={ + electron_density_sources.GasPuffSource.DEFAULT_MODEL_FUNCTION_NAME: ModelFunction( + source_profile_function=electron_density_sources.calc_puff_source, + runtime_params_class=electron_density_sources.GasPuffRuntimeParams, + ) + }, ), - electron_density_sources.PelletSource.SOURCE_NAME: _register_new_source( + electron_density_sources.PelletSource.SOURCE_NAME: SupportedSource( source_class=electron_density_sources.PelletSource, - default_runtime_params_class=electron_density_sources.PelletRuntimeParams, + model_functions={ + electron_density_sources.PelletSource.DEFAULT_MODEL_FUNCTION_NAME: ModelFunction( + source_profile_function=electron_density_sources.calc_pellet_source, + runtime_params_class=electron_density_sources.PelletRuntimeParams, + ) + }, ), - ion_el_heat.GenericIonElectronHeatSource.SOURCE_NAME: _register_new_source( + ion_el_heat.GenericIonElectronHeatSource.SOURCE_NAME: SupportedSource( source_class=ion_el_heat.GenericIonElectronHeatSource, - default_runtime_params_class=ion_el_heat.RuntimeParams, + model_functions={ + ion_el_heat.GenericIonElectronHeatSource.DEFAULT_MODEL_FUNCTION_NAME: ModelFunction( + source_profile_function=ion_el_heat.default_formula, + runtime_params_class=ion_el_heat.RuntimeParams, + ) + }, ), - fusion_heat_source.FusionHeatSource.SOURCE_NAME: _register_new_source( + fusion_heat_source.FusionHeatSource.SOURCE_NAME: SupportedSource( source_class=fusion_heat_source.FusionHeatSource, - default_runtime_params_class=fusion_heat_source.FusionHeatSourceRuntimeParams, + model_functions={ + fusion_heat_source.FusionHeatSource.DEFAULT_MODEL_FUNCTION_NAME: ModelFunction( + source_profile_function=fusion_heat_source.fusion_heat_model_func, + runtime_params_class=fusion_heat_source.FusionHeatSourceRuntimeParams, + ) + }, ), - qei_source.QeiSource.SOURCE_NAME: _register_new_source( + qei_source.QeiSource.SOURCE_NAME: SupportedSource( source_class=qei_source.QeiSource, - default_runtime_params_class=qei_source.RuntimeParams, + model_functions={ + qei_source.QeiSource.DEFAULT_MODEL_FUNCTION_NAME: ModelFunction( + source_profile_function=None, + runtime_params_class=qei_source.RuntimeParams, + ) + }, ), - ohmic_heat_source.OhmicHeatSource.SOURCE_NAME: _register_new_source( + ohmic_heat_source.OhmicHeatSource.SOURCE_NAME: SupportedSource( source_class=ohmic_heat_source.OhmicHeatSource, - default_runtime_params_class=ohmic_heat_source.OhmicRuntimeParams, - links_back=True, + model_functions={ + ohmic_heat_source.OhmicHeatSource.DEFAULT_MODEL_FUNCTION_NAME: ( + ModelFunction( + source_profile_function=ohmic_heat_source.ohmic_model_func, + runtime_params_class=ohmic_heat_source.OhmicRuntimeParams, + links_back=True, + ) + ) + }, ), - bremsstrahlung_heat_sink.BremsstrahlungHeatSink.SOURCE_NAME: ( - _register_new_source( - source_class=bremsstrahlung_heat_sink.BremsstrahlungHeatSink, - default_runtime_params_class=bremsstrahlung_heat_sink.RuntimeParams, - ) + bremsstrahlung_heat_sink.BremsstrahlungHeatSink.SOURCE_NAME: SupportedSource( + source_class=bremsstrahlung_heat_sink.BremsstrahlungHeatSink, + model_functions={ + bremsstrahlung_heat_sink.BremsstrahlungHeatSink.DEFAULT_MODEL_FUNCTION_NAME: ModelFunction( + source_profile_function=bremsstrahlung_heat_sink.bremsstrahlung_model_func, + runtime_params_class=bremsstrahlung_heat_sink.RuntimeParams, + ) + }, ), - ion_cyclotron_source.IonCyclotronSource.SOURCE_NAME: _register_new_source( + ion_cyclotron_source.IonCyclotronSource.SOURCE_NAME: SupportedSource( source_class=ion_cyclotron_source.IonCyclotronSource, - default_runtime_params_class=ion_cyclotron_source.RuntimeParams, + model_functions={ + ion_cyclotron_source.IonCyclotronSource.DEFAULT_MODEL_FUNCTION_NAME: ModelFunction( + source_profile_function=None, + runtime_params_class=ion_cyclotron_source.RuntimeParams, + source_builder_class=ion_cyclotron_source.IonCyclotronSourceBuilder, + ) + }, ), - impurity_radiation_heat_sink.ImpurityRadiationHeatSink.SOURCE_NAME: _register_new_source( + impurity_radiation_heat_sink.ImpurityRadiationHeatSink.SOURCE_NAME: SupportedSource( source_class=impurity_radiation_heat_sink.ImpurityRadiationHeatSink, - default_runtime_params_class=impurity_radiation_heat_sink.RuntimeParams, - links_back=True, + model_functions={ + impurity_radiation_heat_sink.ImpurityRadiationHeatSink.DEFAULT_MODEL_FUNCTION_NAME: ModelFunction( + source_profile_function=impurity_radiation_heat_sink.radially_constant_fraction_of_Pin, + runtime_params_class=impurity_radiation_heat_sink.RuntimeParams, + links_back=True, + ) + }, ), } -def get_registered_source(source_name: str) -> RegisteredSource: - """Used when building a simulation to get the registered source.""" - if source_name in _REGISTERED_SOURCES: - return _REGISTERED_SOURCES[source_name] +def get_supported_source(source_name: str) -> SupportedSource: + """Used when building a simulation to get the supported source.""" + if source_name in _SUPPORTED_SOURCES: + return _SUPPORTED_SOURCES[source_name] else: raise RuntimeError(f'Source:{source_name} has not been registered.') diff --git a/torax/sources/runtime_params.py b/torax/sources/runtime_params.py index 4200cffc..b1b38314 100644 --- a/torax/sources/runtime_params.py +++ b/torax/sources/runtime_params.py @@ -25,7 +25,6 @@ from torax import interpolated_param from torax.config import base from torax.geometry import geometry -from torax.sources import formula_config TimeInterpolatedInput = interpolated_param.TimeInterpolatedInput @@ -81,12 +80,6 @@ class RuntimeParams(base.RuntimeParametersConfig): # running the simulation. is_explicit: bool = False - # Parameters used only when the source is using a prescribed formula to - # compute its profile. - formula: base.RuntimeParametersConfig = dataclasses.field( - default_factory=formula_config.Exponential - ) - # Prescribed values for the source. Used only when the source is fully # prescribed (i.e. source.mode == Mode.PRESCRIBED). # The default here is a vector of all zeros along for all rho and time, and @@ -118,7 +111,6 @@ class RuntimeParamsProvider( """Runtime parameter provider for a single source/sink term.""" runtime_params_config: RuntimeParams - formula: base.RuntimeParametersProvider prescribed_values: interpolated_param.InterpolatedVarTimeRho def get_dynamic_params_kwargs( @@ -148,8 +140,6 @@ class DynamicRuntimeParams: stateless, so these params are their inputs to determine their output profiles. """ - - formula: formula_config.DynamicFormula prescribed_values: array_typing.ArrayFloat diff --git a/torax/sources/source.py b/torax/sources/source.py index 34f03278..753ae8ce 100644 --- a/torax/sources/source.py +++ b/torax/sources/source.py @@ -28,7 +28,7 @@ import enum import types import typing -from typing import Any, Callable, Optional, Protocol, TypeAlias +from typing import Any, Callable, ClassVar, Optional, Protocol, TypeAlias # We use Optional here because | doesn't work with string name types. # We use string name 'source_models.SourceModels' in this file to avoid @@ -43,21 +43,23 @@ from torax.sources import runtime_params as runtime_params_lib -# Sources implement these functions to be able to provide source profiles. # pytype bug: 'source_models.SourceModels' not treated as forward reference -SourceProfileFunction: TypeAlias = Callable[ # pytype: disable=name-error - [ # Arguments - runtime_params_slice.StaticRuntimeParamsSlice, # Static runtime params. - runtime_params_lib.StaticRuntimeParams, # Source-specific params. - runtime_params_slice.DynamicRuntimeParamsSlice, # General config params - runtime_params_lib.DynamicRuntimeParams, # Source-specific params. - geometry.Geometry, - state.CoreProfiles, - Optional['source_models.SourceModels'], - ], - # Returns a JAX array, tuple of arrays, or mapping of arrays. - chex.ArrayTree, -] +# pytype: disable=name-error +@typing.runtime_checkable +class SourceProfileFunction(Protocol): + """Sources implement these functions to be able to provide source profiles.""" + + def __call__( + self, + static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, + geo: geometry.Geometry, + source_name: str, + core_profiles: state.CoreProfiles, + source_models: Optional['source_models.SourceModels'], + ) -> chex.ArrayTree: + ... +# pytype: enable=name-error # Any callable which takes the dynamic runtime_params, geometry, and optional @@ -105,6 +107,8 @@ class Source(abc.ABC): are in turn used to compute coeffs in sim.py. Attributes: + DEFAULT_MODEL_FUNCTION_NAME: The name of the model function used with this + source if another isn't specified. runtime_params: Input dataclass containing all the source-specific runtime parameters. At runtime, the parameters here are interpolated to a specific time t and then passed to the model_func or formula, depending on the mode @@ -128,10 +132,15 @@ class Source(abc.ABC): affected_core_profiles_ints: Derived property from the affected_core_profiles. Integer values of those enums. """ - + DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'default' model_func: SourceProfileFunction | None = None formula: SourceProfileFunction | None = None + @property + @abc.abstractmethod + def source_name(self) -> str: + """Returns the name of the source.""" + @property @abc.abstractmethod def affected_core_profiles(self) -> tuple[AffectedCoreProfile, ...]: @@ -148,6 +157,7 @@ def supported_modes(self) -> tuple[runtime_params_lib.Mode, ...]: return ( runtime_params_lib.Mode.ZERO, runtime_params_lib.Mode.FORMULA_BASED, + runtime_params_lib.Mode.MODEL_BASED, runtime_params_lib.Mode.PRESCRIBED, ) @@ -169,9 +179,7 @@ def check_mode( def get_value( self, static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, - static_source_runtime_params: runtime_params_lib.StaticRuntimeParams, dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, - dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, core_profiles: state.CoreProfiles, ) -> chex.ArrayTree: @@ -179,11 +187,8 @@ def get_value( Args: static_runtime_params_slice: Static runtime parameters. - static_source_runtime_params: Static runtime parameters for this source. dynamic_runtime_params_slice: Slice of the general TORAX config that can be used as input for this time step. - dynamic_source_runtime_params: Slice of this source's runtime parameters - at a specific time t. geo: Geometry of the torus. core_profiles: Core plasma profiles. May be the profiles at the start of the time step or a "live" set of core profiles being actively updated @@ -195,31 +200,26 @@ def get_value( Returns: Array, arrays, or nested dataclass/dict of arrays for the source profile. """ + static_source_runtime_params = static_runtime_params_slice.sources[ + self.source_name + ] + dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[ + self.source_name + ] self.check_mode(static_source_runtime_params.mode) output_shape = self.output_shape_getter(geo) - model_func = ( - (lambda _0, _1, _2, _3, _4, _5, _6: jnp.zeros(output_shape)) - if self.model_func is None - else self.model_func - ) - formula = ( - (lambda _0, _1, _2, _3, _4, _5, _6: jnp.zeros(output_shape)) - if self.formula is None - else self.formula - ) return get_source_profiles( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_source_runtime_params, static_runtime_params_slice=static_runtime_params_slice, - static_source_runtime_params=static_source_runtime_params, geo=geo, core_profiles=core_profiles, - model_func=model_func, - formula=formula, + model_func=self.model_func, + formula=self.formula, prescribed_values=dynamic_source_runtime_params.prescribed_values, output_shape=output_shape, source_models=getattr(self, 'source_models', None), + source_name=self.source_name, ) def get_source_profile_for_affected_core_profile( @@ -301,13 +301,12 @@ def get_profile_shape(self, geo: geometry.Geometry) -> tuple[int, ...]: # pytype: disable=name-error def get_source_profiles( static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, - static_source_runtime_params: runtime_params_lib.StaticRuntimeParams, dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, - dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, + source_name: str, core_profiles: state.CoreProfiles, - model_func: SourceProfileFunction, - formula: SourceProfileFunction, + model_func: SourceProfileFunction | None, + formula: SourceProfileFunction | None, prescribed_values: chex.Array, output_shape: tuple[int, ...], source_models: Optional['source_models.SourceModels'], @@ -321,12 +320,10 @@ def get_source_profiles( Args: static_runtime_params_slice: Static runtime parameters. - static_source_runtime_params: Static runtime parameters for this source. dynamic_runtime_params_slice: Slice of the general TORAX config that can be used as input for this time step. - dynamic_source_runtime_params: Slice of this source's runtime parameters at - a specific time t. geo: Geometry information. Used as input to the source profile functions. + source_name: The name of the source. core_profiles: Core plasma profiles. Used as input to the source profile functions. model_func: Model function. @@ -340,25 +337,31 @@ def get_source_profiles( Output array of a profile or concatenated/stacked profiles. """ # pytype: enable=name-error - mode = static_source_runtime_params.mode + mode = static_runtime_params_slice.sources[source_name].mode match mode: case runtime_params_lib.Mode.MODEL_BASED.value: + if model_func is None: + raise ValueError( + 'Source is in MODEL_BASED mode but has no model function.' + ) return model_func( static_runtime_params_slice, - static_source_runtime_params, dynamic_runtime_params_slice, - dynamic_source_runtime_params, geo, + source_name, core_profiles, source_models, ) case runtime_params_lib.Mode.FORMULA_BASED.value: + if formula is None: + raise ValueError( + 'Source is in FORMULA_BASED mode but has no formula function.' + ) return formula( static_runtime_params_slice, - static_source_runtime_params, dynamic_runtime_params_slice, - dynamic_source_runtime_params, geo, + source_name, core_profiles, source_models, ) @@ -428,6 +431,7 @@ def is_source_builder(obj, raise_if_false: bool = False) -> bool: def _convert_source_builder_to_init_kwargs( source_builder: ..., + model_func: SourceProfileFunction | None, ) -> dict[str, Any]: """Returns a dict of init kwargs for the source builder.""" source_init_kwargs = {} @@ -439,12 +443,14 @@ def _convert_source_builder_to_init_kwargs( # including turning custom dataclasses with __call__ methods into # plain Python dictionaries. source_init_kwargs[field.name] = getattr(source_builder, field.name) + source_init_kwargs['model_func'] = model_func return source_init_kwargs def make_source_builder( source_type: ..., runtime_params_type: ... = runtime_params_lib.RuntimeParams, + model_func: SourceProfileFunction | None = None, links_back=False, ) -> SourceBuilderProtocol: """Given a Source type, returns a Builder for that type. @@ -455,6 +461,7 @@ def make_source_builder( source_type: The Source class to make a builder for. runtime_params_type: The type of `runtime_params` field which will be added to the builder dataclass. + model_func: The model function to pass to the source. links_back: If True, the Source class has a `source_models` field linking back to the SourceModels object. This must be passed to the builder's __call__ method. @@ -511,7 +518,10 @@ def check_kwargs(source_init_kwargs, context_msg): assert all([isinstance(var, runtime_params_lib.Mode) for var in v]) elif f.type == 'SourceProfileFunction | None': assert v is None or callable(v) - elif f.type == 'source.SourceProfileFunction': + elif f.type in [ + 'source.SourceProfileFunction', + 'source_lib.SourceProfileFunction', + ]: if not callable(v): raise TypeError( f'While {context_msg} {source_type} got field ' @@ -536,8 +546,8 @@ def check_kwargs(source_init_kwargs, context_msg): raise TypeError(f'Unrecognized type string: {f.type}') # Check if the field is a parameterized generic. - # Python cannot check isinstance for parameterized generics, so we need - # to handle those cases differently. + # Python cannot check isinstance for parameterized generics, so we ignore + # these cases for now. # For instance, if a field type is `tuple[float, ...]` and the value is # valid, like `(1, 2, 3)`, then `isinstance(v, f.type)` would raise a # TypeError. @@ -545,14 +555,16 @@ def check_kwargs(source_init_kwargs, context_msg): type(f.type) == types.GenericAlias # pylint: disable=unidiomatic-typecheck or typing.get_origin(f.type) is not None ): - # Do a superficial check in these instances. Only check that the origin - # type matches the value. Don't look into the rest of the object. - if not isinstance(v, typing.get_origin(f.type)): - raise TypeError( - f'While {context_msg} {source_type} got field {f.name} with ' - f'input type {type(v)} but an expected type {f.type}.' - ) - + # For `Union`s check if the value is a member of the union. + # `typing.Union` is for types defined with `Union[A, B, C]` syntax. + # `types.UnionType` is for types defined with `A | B | C` syntax. + if typing.get_origin(f.type) in [typing.Union, types.UnionType]: + if not isinstance(v, typing.get_args(f.type)): + raise TypeError( + f'While {context_msg} {source_type} got argument ' + f'{f.name} of type {type(v)} but expected ' + f'{f.type}).' + ) else: try: type_works = isinstance(v, f.type) @@ -572,7 +584,9 @@ def check_kwargs(source_init_kwargs, context_msg): # pylint doesn't like this function name because it doesn't realize # this function is to be installed in a class def __post_init__(self): # pylint:disable=invalid-name - source_init_kwargs = _convert_source_builder_to_init_kwargs(self) + source_init_kwargs = _convert_source_builder_to_init_kwargs( + self, model_func + ) check_kwargs(source_init_kwargs, 'making builder') # check_kwargs checks only the kwargs to Source, not SourceBuilder, # so it doesn't check "runtime_params" @@ -598,7 +612,10 @@ def check_source(source): if links_back: def build_source(self, source_models): - source_init_kwargs = _convert_source_builder_to_init_kwargs(self) + source_init_kwargs = _convert_source_builder_to_init_kwargs( + self, + model_func, + ) source_init_kwargs['source_models'] = source_models check_kwargs(source_init_kwargs, 'building') source = source_type(**source_init_kwargs) @@ -608,7 +625,10 @@ def build_source(self, source_models): else: def build_source(self): - source_init_kwargs = _convert_source_builder_to_init_kwargs(self) + source_init_kwargs = _convert_source_builder_to_init_kwargs( + self, + model_func, + ) check_kwargs(source_init_kwargs, 'building') source = source_type(**source_init_kwargs) check_source(source) diff --git a/torax/sources/source_models.py b/torax/sources/source_models.py index 2390d1ed..851a7a38 100644 --- a/torax/sources/source_models.py +++ b/torax/sources/source_models.py @@ -69,15 +69,11 @@ def build_source_profiles( """ # Bootstrap current is a special-case source with multiple outputs, so handle # it here. - dynamic_bootstrap_runtime_params = dynamic_runtime_params_slice.sources[ - source_models.j_bootstrap_name - ] static_bootstrap_runtime_params = static_runtime_params_slice.sources[ source_models.j_bootstrap_name ] bootstrap_profiles = _build_bootstrap_profiles( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_bootstrap_runtime_params, static_runtime_params_slice=static_runtime_params_slice, static_source_runtime_params=static_bootstrap_runtime_params, geo=geo, @@ -106,7 +102,6 @@ def _build_bootstrap_profiles( static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, static_source_runtime_params: runtime_params_lib.StaticRuntimeParams, dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, - dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, core_profiles: state.CoreProfiles, j_bootstrap_source: bootstrap_current_source.BootstrapCurrentSource, @@ -122,8 +117,6 @@ def _build_bootstrap_profiles( bootstrap current source that do not change from time step to time step. dynamic_runtime_params_slice: Input config for this time step. Can change from time step to time step. - dynamic_source_runtime_params: Input runtime parameters for this time step, - specific to the bootstrap current source. geo: Geometry of the torus. core_profiles: Core plasma profiles, either at the start of the time step (if explicit) or the live profiles being evolved during the time step (if @@ -141,9 +134,7 @@ def _build_bootstrap_profiles( """ bootstrap_profile = j_bootstrap_source.get_value( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_source_runtime_params, static_runtime_params_slice=static_runtime_params_slice, - static_source_runtime_params=static_source_runtime_params, geo=geo, core_profiles=core_profiles, ) @@ -240,9 +231,6 @@ def _build_standard_source_profiles( affected_core_profiles_set = set(affected_core_profiles) for source_name, source in source_models.standard_sources.items(): if affected_core_profiles_set.intersection(source.affected_core_profiles): - dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[ - source_name - ] static_source_runtime_params = static_runtime_params_slice.sources[ source_name ] @@ -253,9 +241,7 @@ def _build_standard_source_profiles( ), source.get_value( static_runtime_params_slice, - static_source_runtime_params, dynamic_runtime_params_slice, - dynamic_source_runtime_params, geo, core_profiles, ), @@ -359,15 +345,11 @@ def calc_and_sum_sources_psi( affected_core_profile=source_lib.AffectedCoreProfile.PSI.value, geo=geo, ) - dynamic_bootstrap_runtime_params = dynamic_runtime_params_slice.sources[ - source_models.j_bootstrap_name - ] static_bootstrap_runtime_params = static_runtime_params_slice.sources[ source_models.j_bootstrap_name ] j_bootstrap_profiles = _build_bootstrap_profiles( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_bootstrap_runtime_params, static_runtime_params_slice=static_runtime_params_slice, static_source_runtime_params=static_bootstrap_runtime_params, geo=geo, @@ -404,26 +386,16 @@ class SourceModels: .. code-block:: python # Define an electron-density source with a time-dependent Gaussian profile. - my_custom_source = source.SingleProfileSource( - supported_modes=( - runtime_params_lib.Mode.ZERO, - runtime_params_lib.Mode.FORMULA_BASED, - ), - affected_core_profiles=source.AffectedCoreProfile.NE, - formula=formulas.Gaussian(), - # Define (possibly) time-dependent parameters to feed to the formula. - runtime_params=runtime_params_lib.RuntimeParams( - formula=formula_config.Gaussian( - total={0.0: 1.0, 5.0: 2.0, 10.0: 1.0}, # time-dependent. - c1=2.0, - c2=3.0, - ), - ), + gas_puff_source = register_source.get_registered_source('gas_puff_source') + gas_puff_source_builder = source_lib.make_source_builder( + gas_puff_source.source_class, + runtime_params_type=gas_puff_source.model_functions['default'].runtime_params_class, + model_func=gas_puff_source.model_functions['default'].source_profile_function, ) # Define the collection of sources here, which in this example only includes # one source. all_torax_sources = SourceModels( - sources={'my_custom_source': my_custom_source} + sources={'gas_puff_source': gas_puff_source_builder} ) See runtime_params.py for more details on how to configure all the source/sink @@ -737,6 +709,7 @@ def __init__( ] = source_lib.make_source_builder( generic_current_source.GenericCurrentSource, runtime_params_type=generic_current_source.RuntimeParams, + model_func=generic_current_source.calculate_generic_current_face, )() source_builders[ generic_current_source.GenericCurrentSource.SOURCE_NAME diff --git a/torax/sources/tests/bootstrap_current_source.py b/torax/sources/tests/bootstrap_current_source.py index d4b8cff5..71191626 100644 --- a/torax/sources/tests/bootstrap_current_source.py +++ b/torax/sources/tests/bootstrap_current_source.py @@ -36,6 +36,8 @@ def setUpClass(cls): unsupported_modes=[ runtime_params_lib.Mode.FORMULA_BASED, ], + source_name=bootstrap_current_source.BootstrapCurrentSource.SOURCE_NAME, + model_func=None, ) def test_extraction_of_relevant_profile_from_output(self): diff --git a/torax/sources/tests/bremsstrahlung_heat_sink.py b/torax/sources/tests/bremsstrahlung_heat_sink.py index fc47b0e9..eacf3eea 100644 --- a/torax/sources/tests/bremsstrahlung_heat_sink.py +++ b/torax/sources/tests/bremsstrahlung_heat_sink.py @@ -41,6 +41,8 @@ def setUpClass(cls): unsupported_modes=[ runtime_params_lib.Mode.FORMULA_BASED, ], + source_name=bremsstrahlung_heat_sink.BremsstrahlungHeatSink.SOURCE_NAME, + model_func=bremsstrahlung_heat_sink.bremsstrahlung_model_func, ) @parameterized.parameters([ diff --git a/torax/sources/tests/electron_cyclotron_source.py b/torax/sources/tests/electron_cyclotron_source.py index ee84da1d..96d82400 100644 --- a/torax/sources/tests/electron_cyclotron_source.py +++ b/torax/sources/tests/electron_cyclotron_source.py @@ -41,6 +41,8 @@ def setUpClass(cls): unsupported_modes=[ runtime_params_lib.Mode.FORMULA_BASED, ], + source_name=electron_cyclotron_source.ElectronCyclotronSource.SOURCE_NAME, + model_func=electron_cyclotron_source.calc_heating_and_current, ) def test_source_value(self): @@ -50,10 +52,10 @@ def test_source_value(self): raise TypeError(f"{type(self)} has a bad _source_class_builder") runtime_params = general_runtime_params.GeneralRuntimeParams() source_models_builder = source_models_lib.SourceModelsBuilder( - {"foo": source_builder}, + {self._source_name: source_builder}, ) source_models = source_models_builder() - source = source_models.sources["foo"] + source = source_models.sources[self._source_name] source_builder.runtime_params.mode = runtime_params_lib.Mode.MODEL_BASED self.assertIsInstance(source, source_lib.Source) geo = geometry.build_circular_geometry() @@ -81,11 +83,7 @@ def test_source_value(self): ) value = source.get_value( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ - "foo" - ], static_runtime_params_slice=static_runtime_params_slice, - static_source_runtime_params=static_runtime_params_slice.sources["foo"], geo=geo, core_profiles=core_profiles, ) @@ -101,10 +99,16 @@ def test_invalid_source_types_raise_errors(self): geo = geometry.build_circular_geometry() source_builder = self._source_class_builder() source_models_builder = source_models_lib.SourceModelsBuilder( - {"foo": source_builder}, + { + electron_cyclotron_source.ElectronCyclotronSource.SOURCE_NAME: ( + source_builder + ) + }, ) source_models = source_models_builder() - source = source_models.sources["foo"] + source = source_models.sources[ + electron_cyclotron_source.ElectronCyclotronSource.SOURCE_NAME + ] self.assertIsInstance(source, source_lib.Source) dynamic_runtime_params_slice_provider = ( runtime_params_slice.DynamicRuntimeParamsSliceProvider( @@ -145,13 +149,7 @@ def test_invalid_source_types_raise_errors(self): with self.assertRaises(ValueError): source.get_value( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ - "foo" - ], static_runtime_params_slice=static_runtime_params_slice, - static_source_runtime_params=static_runtime_params_slice.sources[ - "foo" - ], geo=geo, core_profiles=core_profiles, ) diff --git a/torax/sources/tests/electron_density_sources.py b/torax/sources/tests/electron_density_sources.py index e36081ce..d7cf0802 100644 --- a/torax/sources/tests/electron_density_sources.py +++ b/torax/sources/tests/electron_density_sources.py @@ -16,7 +16,6 @@ from absl.testing import absltest from torax.sources import electron_density_sources as eds -from torax.sources import runtime_params as runtime_params_lib from torax.sources.tests import test_lib @@ -28,9 +27,9 @@ def setUpClass(cls): super().setUpClass( source_class=eds.GasPuffSource, runtime_params_class=eds.GasPuffRuntimeParams, - unsupported_modes=[ - runtime_params_lib.Mode.MODEL_BASED, - ], + unsupported_modes=[], + source_name=eds.GasPuffSource.SOURCE_NAME, + model_func=eds.calc_puff_source, ) @@ -42,9 +41,9 @@ def setUpClass(cls): super().setUpClass( source_class=eds.PelletSource, runtime_params_class=eds.PelletRuntimeParams, - unsupported_modes=[ - runtime_params_lib.Mode.MODEL_BASED, - ], + unsupported_modes=[], + source_name=eds.PelletSource.SOURCE_NAME, + model_func=eds.calc_pellet_source, ) @@ -56,23 +55,9 @@ def setUpClass(cls): super().setUpClass( source_class=eds.GenericParticleSource, runtime_params_class=eds.GenericParticleSourceRuntimeParams, - unsupported_modes=[ - runtime_params_lib.Mode.MODEL_BASED, - ], - ) - - -class RecombinationDensitySinkTest(test_lib.SingleProfileSourceTestCase): - """Tests for RecombinationDensitySink.""" - - @classmethod - def setUpClass(cls): - super().setUpClass( - source_class=eds.RecombinationDensitySink, - runtime_params_class=runtime_params_lib.RuntimeParams, - unsupported_modes=[ - runtime_params_lib.Mode.MODEL_BASED, - ], + unsupported_modes=[], + source_name=eds.GenericParticleSource.SOURCE_NAME, + model_func=eds.calc_generic_particle_source, ) diff --git a/torax/sources/tests/formulas.py b/torax/sources/tests/formulas.py deleted file mode 100644 index 1c197580..00000000 --- a/torax/sources/tests/formulas.py +++ /dev/null @@ -1,242 +0,0 @@ -# Copyright 2024 DeepMind Technologies Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for sources/formulas.py.""" - -from absl.testing import absltest -import chex -from torax import output -from torax import sim as sim_lib -from torax import simulation_app -from torax.config import build_sim -from torax.config import numerics as numerics_lib -from torax.config import profile_conditions as profile_conditions_lib -from torax.config import runtime_params as general_runtime_params -from torax.pedestal_model import set_tped_nped -from torax.sources import bremsstrahlung_heat_sink -from torax.sources import electron_density_sources -from torax.sources import formula_config -from torax.sources import formulas -from torax.sources import fusion_heat_source -from torax.sources import ohmic_heat_source -from torax.sources import runtime_params as runtime_params_lib -from torax.sources.tests import test_lib -from torax.stepper import linear_theta_method -from torax.tests.test_lib import default_sources -from torax.tests.test_lib import sim_test_case -from torax.transport_model import constant as constant_transport_model - - -_ALL_PROFILES = ('temp_ion', 'temp_el', 'psi', 'q_face', 's_face', 'ne') - - -class FormulasIntegrationTest(sim_test_case.SimTestCase): - """Integration tests for using non-default formulas.""" - - def test_custom_exponential_source_can_replace_puff_source(self): - """Replaces one the default ne source with a custom one.""" - # The default puff source gives an exponential profile. In this test, we - # zero out the default puff source and introduce a new custom source that - # should give the same profiles throughout the entire simulation run as the - # original puff source. - - # For this test, use test_particle_sources_constant with the linear stepper. - custom_source_name = 'custom_exponential_source' - - # Copy the test_particle_sources_constant config in here for clarity. - test_particle_sources_constant_runtime_params = general_runtime_params.GeneralRuntimeParams( - profile_conditions=profile_conditions_lib.ProfileConditions( - set_pedestal=True, - nbar=0.85, - nu=0, - ne_bound_right=0.5, - ), - numerics=numerics_lib.Numerics( - ion_heat_eq=True, - el_heat_eq=True, - dens_eq=True, # This is important to be True to test ne sources. - current_eq=True, - resistivity_mult=100, - t_final=2, - ), - ) - basic_pedestal_model_builder = ( - set_tped_nped.SetTemperatureDensityPedestalModelBuilder() - ) - # Set the sources to match test_particle_sources_constant as well. - source_models_builder = default_sources.get_default_sources_builder() - source_models_builder.runtime_params[ - electron_density_sources.PelletSource.SOURCE_NAME - ].S_pellet_tot = 2.0e22 - S_puff_tot = 1.0e22 # pylint: disable=invalid-name - puff_decay_length = 0.05 - source_models_builder.runtime_params[ - electron_density_sources.GasPuffSource.SOURCE_NAME - ].S_puff_tot = S_puff_tot - source_models_builder.runtime_params[ - electron_density_sources.GasPuffSource.SOURCE_NAME - ].puff_decay_length = puff_decay_length - source_models_builder.runtime_params[ - electron_density_sources.GenericParticleSource.SOURCE_NAME - ].S_tot = 0.0 - # We need to turn off some other sources for test_particle_sources_constant - # that are unrelated to our test for the ne custom source. - source_models_builder.runtime_params[ - fusion_heat_source.FusionHeatSource.SOURCE_NAME - ].mode = runtime_params_lib.Mode.ZERO - source_models_builder.runtime_params[ - ohmic_heat_source.OhmicHeatSource.SOURCE_NAME - ].mode = runtime_params_lib.Mode.ZERO - source_models_builder.runtime_params[ - bremsstrahlung_heat_sink.BremsstrahlungHeatSink.SOURCE_NAME - ].mode = runtime_params_lib.Mode.ZERO - - # Add the custom source to the source_models, but keep it turned off for the - # first run. - source_models_builder.source_builders[custom_source_name] = ( - test_lib.TestSourceBuilder( - formula=formulas.Exponential(), - runtime_params=runtime_params_lib.RuntimeParams( - mode=runtime_params_lib.Mode.ZERO, - # will override these later, but defining here because, due to - # how JAX works, this function is still evaluated even when the - # mode is set to ZERO. So the runtime config needs to be set - # with the correct params. - formula=formula_config.Exponential(), - ), - ) - ) - - # Load reference profiles - ref_profiles, ref_time = self._get_refs( - 'test_particle_sources_constant.nc', _ALL_PROFILES - ) - - # We set up the sim only once and update the config on each run below in a - # way that does not trigger recompiles. This way we only trace the code - # once. - geo_provider = build_sim.build_geometry_provider_from_config( - {'geometry_type': 'circular'} - ) - transport_model_builder = ( - constant_transport_model.ConstantTransportModelBuilder( - runtime_params=constant_transport_model.RuntimeParams( - De_const=0.5, - Ve_const=-0.2, - ) - ) - ) - sim = sim_lib.build_sim_object( - runtime_params=test_particle_sources_constant_runtime_params, - geometry_provider=geo_provider, - stepper_builder=linear_theta_method.LinearThetaMethodBuilder( - runtime_params=linear_theta_method.LinearRuntimeParams( - predictor_corrector=False, - ) - ), - transport_model_builder=transport_model_builder, - source_models_builder=source_models_builder, - pedestal_model_builder=basic_pedestal_model_builder, - ) - - # Make sure the config copied here works with these references. - with self.subTest('with_puff_and_without_custom_source'): - # Need to run the sim once to build the step_fn. - sim_outputs = sim.run() - history = output.StateHistory(sim_outputs, sim.source_models) - self._check_profiles_vs_expected( - core_profiles=history.core_profiles, - t=history.times, - ref_time=ref_time, - ref_profiles=ref_profiles, - rtol=self.rtol, - atol=self.atol, - ) - - with self.subTest('without_puff_and_with_custom_source'): - # Now turn on the custom source. - source_models_builder.runtime_params[custom_source_name].mode = ( - runtime_params_lib.Mode.FORMULA_BASED - ) - source_models_builder.runtime_params[custom_source_name].formula = ( - formula_config.Exponential( - total=( - S_puff_tot - / test_particle_sources_constant_runtime_params.numerics.nref - ), - c1=1.0, - c2=puff_decay_length, - ) - ) - # And turn off the gas puff source it is replacing. - source_models_builder.runtime_params[ - electron_density_sources.GasPuffSource.SOURCE_NAME - ].mode = runtime_params_lib.Mode.ZERO - sim = simulation_app.update_sim( - sim, - test_particle_sources_constant_runtime_params, - sim.geometry_provider, - transport_model_builder.runtime_params, - source_models_builder.runtime_params, - linear_theta_method.LinearRuntimeParams(predictor_corrector=False), - pedestal_runtime_params=basic_pedestal_model_builder.runtime_params, - ) - self._run_sim_and_check(sim, ref_profiles, ref_time) - - with self.subTest('without_puff_and_without_custom_source'): - # Confirm that the custom source actual has an effect. - # Turn it off as well, and the check shouldn't pass. - source_models_builder.runtime_params[custom_source_name].mode = ( - runtime_params_lib.Mode.ZERO - ) - sim = simulation_app.update_sim( - sim, - test_particle_sources_constant_runtime_params, - sim.geometry_provider, - transport_model_builder.runtime_params, - source_models_builder.runtime_params, - linear_theta_method.LinearRuntimeParams(predictor_corrector=False), - pedestal_runtime_params=basic_pedestal_model_builder.runtime_params, - ) - with self.assertRaises(AssertionError): - self._run_sim_and_check(sim, ref_profiles, ref_time) - - def _run_sim_and_check( - self, - sim: sim_lib.Sim, - ref_profiles: dict[str, chex.ArrayTree], - ref_time: chex.Array, - ): - """Runs sim with new runtime params and checks the profiles vs. expected.""" - sim_outputs = sim_lib.run_simulation( - static_runtime_params_slice=sim.static_runtime_params_slice, - dynamic_runtime_params_slice_provider=sim.dynamic_runtime_params_slice_provider, - geometry_provider=sim.geometry_provider, - initial_state=sim.initial_state, - time_step_calculator=sim.time_step_calculator, - step_fn=sim.step_fn, - ) - history = output.StateHistory(sim_outputs, sim.source_models) - self._check_profiles_vs_expected( - core_profiles=history.core_profiles, - t=history.times, - ref_time=ref_time, - ref_profiles=ref_profiles, - rtol=self.rtol, - atol=self.atol, - ) - - -if __name__ == '__main__': - absltest.main() diff --git a/torax/sources/tests/fusion_heat_source.py b/torax/sources/tests/fusion_heat_source.py index 6dab4b4b..d178bc44 100644 --- a/torax/sources/tests/fusion_heat_source.py +++ b/torax/sources/tests/fusion_heat_source.py @@ -40,6 +40,8 @@ def setUpClass(cls): unsupported_modes=[ runtime_params_lib.Mode.FORMULA_BASED, ], + source_name=fusion_heat_source.FusionHeatSource.SOURCE_NAME, + model_func=fusion_heat_source.fusion_heat_model_func, ) @parameterized.parameters([ diff --git a/torax/sources/tests/generic_current_source.py b/torax/sources/tests/generic_current_source.py index a21cd50a..041aac24 100644 --- a/torax/sources/tests/generic_current_source.py +++ b/torax/sources/tests/generic_current_source.py @@ -38,6 +38,8 @@ def setUpClass(cls): unsupported_modes=[ runtime_params_lib.Mode.MODEL_BASED, ], + source_name=generic_current_source.GenericCurrentSource.SOURCE_NAME, + model_func=generic_current_source.calculate_generic_current_face, ) def test_generic_current_hires(self): @@ -70,13 +72,7 @@ def test_generic_current_hires(self): self.assertIsNotNone( source.generic_current_source_hires( dynamic_runtime_params_slice=dynamic_slice, - dynamic_source_runtime_params=dynamic_slice.sources[ - generic_current_source.GenericCurrentSource.SOURCE_NAME - ], static_runtime_params_slice=static_slice, - static_source_runtime_params=static_slice.sources[ - generic_current_source.GenericCurrentSource.SOURCE_NAME - ], geo=geo, ) ) @@ -113,13 +109,7 @@ def test_profile_is_on_face_grid(self): self.assertEqual( source.get_value( static_slice, - static_slice.sources[ - generic_current_source.GenericCurrentSource.SOURCE_NAME - ], dynamic_runtime_params_slice, - dynamic_runtime_params_slice.sources[ - generic_current_source.GenericCurrentSource.SOURCE_NAME - ], geo, core_profiles=None, ).shape, diff --git a/torax/sources/tests/generic_ion_el_heat_source.py b/torax/sources/tests/generic_ion_el_heat_source.py index 79555718..fb644f47 100644 --- a/torax/sources/tests/generic_ion_el_heat_source.py +++ b/torax/sources/tests/generic_ion_el_heat_source.py @@ -16,7 +16,6 @@ from absl.testing import absltest from torax.sources import generic_ion_el_heat_source -from torax.sources import runtime_params as runtime_params_lib from torax.sources.tests import test_lib @@ -28,9 +27,9 @@ def setUpClass(cls): super().setUpClass( source_class=generic_ion_el_heat_source.GenericIonElectronHeatSource, runtime_params_class=generic_ion_el_heat_source.RuntimeParams, - unsupported_modes=[ - runtime_params_lib.Mode.MODEL_BASED, - ], + unsupported_modes=[], + source_name=generic_ion_el_heat_source.GenericIonElectronHeatSource.SOURCE_NAME, + model_func=generic_ion_el_heat_source.default_formula, ) diff --git a/torax/sources/tests/impurity_radiation_heat_sink.py b/torax/sources/tests/impurity_radiation_heat_sink.py index 485d2c26..b22237c4 100644 --- a/torax/sources/tests/impurity_radiation_heat_sink.py +++ b/torax/sources/tests/impurity_radiation_heat_sink.py @@ -42,11 +42,10 @@ def setUpClass(cls): super().setUpClass( source_class=impurity_radiation_heat_sink_lib.ImpurityRadiationHeatSink, runtime_params_class=impurity_radiation_heat_sink_lib.RuntimeParams, - unsupported_modes=[ - runtime_params_lib.Mode.MODEL_BASED, - runtime_params_lib.Mode.PRESCRIBED, - ], + unsupported_modes=[], + source_name=impurity_radiation_heat_sink_lib.ImpurityRadiationHeatSink.SOURCE_NAME, links_back=True, + model_func=impurity_radiation_heat_sink_lib.radially_constant_fraction_of_Pin ) def test_source_value(self): @@ -64,8 +63,11 @@ def test_source_value(self): heat_source_builder_builder = source_lib.make_source_builder( source_type=generic_ion_el_heat_source.GenericIonElectronHeatSource, runtime_params_type=generic_ion_el_heat_source.RuntimeParams, + model_func=generic_ion_el_heat_source.default_formula, + ) + heat_source_builder = heat_source_builder_builder( + model_func=generic_ion_el_heat_source.default_formula ) - heat_source_builder = heat_source_builder_builder() # Runtime params runtime_params = general_runtime_params.GeneralRuntimeParams() @@ -73,9 +75,7 @@ def test_source_value(self): # Source models source_models_builder = source_models_lib.SourceModelsBuilder( { - impurity_radiation_heat_sink_lib.ImpurityRadiationHeatSink.SOURCE_NAME: ( - impurity_radiation_sink_builder - ), + self._source_name: impurity_radiation_sink_builder, generic_ion_el_heat_source.GenericIonElectronHeatSource.SOURCE_NAME: ( heat_source_builder ), @@ -84,9 +84,7 @@ def test_source_value(self): source_models = source_models_builder() # Extract the source we're testing and check that it's been built correctly - impurity_radiation_sink = source_models.sources[ - impurity_radiation_heat_sink_lib.ImpurityRadiationHeatSink.SOURCE_NAME - ] + impurity_radiation_sink = source_models.sources[self._source_name] self.assertIsInstance(impurity_radiation_sink, source_lib.Source) # Geometry, profiles, and dynamic runtime params @@ -110,12 +108,9 @@ def test_source_value(self): geo=geo, source_models=source_models, ) - impurity_radiation_sink_dynamic_runtime_params_slice = dynamic_runtime_params_slice.sources[ - impurity_radiation_heat_sink_lib.ImpurityRadiationHeatSink.SOURCE_NAME - ] - impurity_radiation_sink_static_runtime_params_slice = static_slice.sources[ - impurity_radiation_heat_sink_lib.ImpurityRadiationHeatSink.SOURCE_NAME - ] + impurity_radiation_sink_dynamic_runtime_params_slice = ( + dynamic_runtime_params_slice.sources[self._source_name] + ) heat_source_dynamic_runtime_params_slice = ( dynamic_runtime_params_slice.sources[ @@ -131,13 +126,13 @@ def test_source_value(self): heat_source_dynamic_runtime_params_slice, generic_ion_el_heat_source.DynamicRuntimeParams, ) - impurity_radiation_heat_sink_power_density = impurity_radiation_sink.get_value( - static_runtime_params_slice=static_slice, - static_source_runtime_params=impurity_radiation_sink_static_runtime_params_slice, - dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=impurity_radiation_sink_dynamic_runtime_params_slice, - geo=geo, - core_profiles=core_profiles, + impurity_radiation_heat_sink_power_density = ( + impurity_radiation_sink.get_value( + static_runtime_params_slice=static_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, + geo=geo, + core_profiles=core_profiles, + ) ) # ImpurityRadiationHeatSink provides TEMP_EL only @@ -162,10 +157,12 @@ def test_invalid_source_types_raise_errors(self): geo = geometry.build_circular_geometry() source_builder = self._source_class_builder() source_models_builder = source_models_lib.SourceModelsBuilder( - {"foo": source_builder}, + { + self._source_name: source_builder + }, ) source_models = source_models_builder() - source = source_models.sources["foo"] + source = source_models.sources[self._source_name] self.assertIsInstance(source, source_lib.Source) dynamic_runtime_params_slice_provider = ( runtime_params_slice.DynamicRuntimeParamsSliceProvider( @@ -207,13 +204,7 @@ def test_invalid_source_types_raise_errors(self): with self.assertRaises(RuntimeError): source.get_value( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ - "foo" - ], static_runtime_params_slice=static_runtime_params_slice, - static_source_runtime_params=static_runtime_params_slice.sources[ - "foo" - ], geo=geo, core_profiles=core_profiles, ) @@ -223,10 +214,10 @@ def test_extraction_of_relevant_profile_from_output(self): geo = geometry.build_circular_geometry() source_builder = self._source_class_builder() source_models_builder = source_models_lib.SourceModelsBuilder( - {"foo": source_builder}, + {self._source_name: source_builder}, ) source_models = source_models_builder() - source = source_models.sources["foo"] + source = source_models.sources[self._source_name] self.assertIsInstance(source, source_lib.Source) cell = source_lib.ProfileType.CELL.get_profile_shape(geo) fake_profile = jnp.ones(cell) diff --git a/torax/sources/tests/ion_cyclotron_source.py b/torax/sources/tests/ion_cyclotron_source.py index 495deb00..bf309dfb 100644 --- a/torax/sources/tests/ion_cyclotron_source.py +++ b/torax/sources/tests/ion_cyclotron_source.py @@ -99,6 +99,9 @@ def setUpClass(cls): source_class=ion_cyclotron_source.IonCyclotronSource, runtime_params_class=ion_cyclotron_source.RuntimeParams, unsupported_modes=[runtime_params_lib.Mode.FORMULA_BASED], + source_class_builder=ion_cyclotron_source.IonCyclotronSourceBuilder, + source_name=ion_cyclotron_source.IonCyclotronSource.SOURCE_NAME, + model_func=None, ) @parameterized.product( @@ -163,13 +166,7 @@ def test_icrh_output_matches_total_power( ) icrh_output = icrh_source.get_value( static_slice, - static_slice.sources[ - ion_cyclotron_source.IonCyclotronSource.SOURCE_NAME - ], dynamic_runtime_params_slice, - dynamic_runtime_params_slice.sources[ - ion_cyclotron_source.IonCyclotronSource.SOURCE_NAME - ], geo, core_profiles, ) @@ -218,10 +215,12 @@ def test_source_value(self, mock_path): runtime_params = general_runtime_params.GeneralRuntimeParams() geo = geometry.build_circular_geometry() source_models_builder = source_models_lib.SourceModelsBuilder( - {"foo": source_builder}, + {ion_cyclotron_source.IonCyclotronSource.SOURCE_NAME: source_builder}, ) source_models = source_models_builder() - source = source_models.sources["foo"] + source = source_models.sources[ + ion_cyclotron_source.IonCyclotronSource.SOURCE_NAME + ] self.assertIsInstance(source, source_lib.Source) dynamic_runtime_params_slice = ( runtime_params_slice.DynamicRuntimeParamsSliceProvider( @@ -244,11 +243,7 @@ def test_source_value(self, mock_path): ) ion_and_el = source.get_value( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ - "foo" - ], static_runtime_params_slice=static_slice, - static_source_runtime_params=static_slice.sources["foo"], geo=geo, core_profiles=core_profiles, ) diff --git a/torax/sources/tests/ohmic_heat_source.py b/torax/sources/tests/ohmic_heat_source.py index a6cdc5f2..597d7dc3 100644 --- a/torax/sources/tests/ohmic_heat_source.py +++ b/torax/sources/tests/ohmic_heat_source.py @@ -29,7 +29,9 @@ def setUpClass(cls): unsupported_modes=[ runtime_params_lib.Mode.FORMULA_BASED, ], + source_name=ohmic_heat_source.OhmicHeatSource.SOURCE_NAME, links_back=True, + model_func=ohmic_heat_source.ohmic_model_func, ) diff --git a/torax/sources/tests/qei_source.py b/torax/sources/tests/qei_source.py index 3df12b31..36c2cb80 100644 --- a/torax/sources/tests/qei_source.py +++ b/torax/sources/tests/qei_source.py @@ -37,6 +37,8 @@ def setUpClass(cls): unsupported_modes=[ runtime_params_lib.Mode.FORMULA_BASED, ], + source_name=qei_source.QeiSource.SOURCE_NAME, + model_func=None, ) def test_source_value(self): @@ -70,7 +72,6 @@ def test_source_value(self): qei = source.get_qei( static_slice, dynamic_slice, - dynamic_slice.sources['qei_source'], geo, core_profiles, ) @@ -119,7 +120,6 @@ def test_invalid_source_types_raise_errors(self): source.get_qei( static_slice, dynamic_slice, - dynamic_slice.sources['qei_source'], geo, core_profiles, ) diff --git a/torax/sources/tests/register_source.py b/torax/sources/tests/register_source.py index 47a6465f..dd0100db 100644 --- a/torax/sources/tests/register_source.py +++ b/torax/sources/tests/register_source.py @@ -13,7 +13,6 @@ # limitations under the License. """Tests for the source registry.""" - from absl.testing import absltest from absl.testing import parameterized from torax.sources import bootstrap_current_source @@ -27,6 +26,7 @@ from torax.sources import ohmic_heat_source from torax.sources import qei_source from torax.sources import register_source +from torax.sources import source as source_lib from torax.sources import source_models as source_models_lib @@ -48,10 +48,21 @@ class SourceTest(parameterized.TestCase): ) def test_sources_in_registry_build_successfully(self, source_name: str): """Test that all sources in the registry build successfully.""" - registered_source = register_source.get_registered_source(source_name) + registered_source = register_source.get_supported_source(source_name) source_class = registered_source.source_class - source_runtime_params_class = registered_source.default_runtime_params_class - source_builder_class = registered_source.source_builder_class + model_function = registered_source.model_functions[ + source_class.DEFAULT_MODEL_FUNCTION_NAME + ] + source_builder_class = model_function.source_builder_class + source_runtime_params_class = model_function.runtime_params_class + if source_builder_class is None: + source_builder_class = source_lib.make_source_builder( + registered_source.source_class, + runtime_params_type=source_runtime_params_class, + links_back=model_function.links_back, + model_func=model_function.source_profile_function, + ) + source_runtime_params_class = model_function.runtime_params_class source_builder = source_builder_class() self.assertIsInstance( source_builder.runtime_params, source_runtime_params_class diff --git a/torax/sources/tests/source.py b/torax/sources/tests/source.py index da68237a..79619187 100644 --- a/torax/sources/tests/source.py +++ b/torax/sources/tests/source.py @@ -45,12 +45,21 @@ class PsiTestSource(source_lib.Source): def affected_core_profiles(self): return (source_lib.AffectedCoreProfile.PSI,) + @property + def source_name(self) -> str: + return 'foo' + PsiTestSourceBuilder = source_lib.make_source_builder(PsiTestSource) @dataclasses.dataclass(frozen=True, eq=True) class IonElTestSource(source_lib.Source): + """Test source that affects ion and electron profiles.""" + + @property + def source_name(self) -> str: + return 'foo' @property def affected_core_profiles(self): @@ -147,6 +156,7 @@ class NotEq: @dataclasses.dataclass(frozen=True, eq=True) class MySource: my_field: int + model_func: source_lib.SourceProfileFunction | None = None # pylint doesn't realize this is a class MySourceBuilder = source_lib.make_source_builder( # pylint: disable=invalid-name @@ -194,11 +204,7 @@ def test_zero_profile_works_by_default(self): ) profile = source.get_value( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ - 'foo' - ], static_runtime_params_slice=static_slice, - static_source_runtime_params=static_slice.sources['foo'], geo=geo, core_profiles=core_profiles, ) @@ -213,6 +219,10 @@ def test_unsupported_modes_raise_errors(self): class TestSource(source_lib.Source): """A test source.""" + @property + def source_name(self) -> str: + return 'foo' + @property def affected_core_profiles( self, @@ -256,11 +266,7 @@ def supported_modes(self) -> tuple[runtime_params_lib.Mode, ...]: with self.assertRaises(ValueError): source.get_value( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ - 'foo' - ], static_runtime_params_slice=static_slice, - static_source_runtime_params=static_slice.sources['foo'], geo=geo, core_profiles=core_profiles, ) @@ -281,10 +287,10 @@ def test_correct_mode_called(self, mode, expected_profile): source = source_models.sources['foo'] source = dataclasses.replace( source, - model_func=lambda _0, _1, _2, _3, _4, _5, _6: jnp.ones( + model_func=lambda _0, _1, _2, _3, _4, _5: jnp.ones( source_lib.ProfileType.CELL.get_profile_shape(geo) ), - formula=lambda _0, _1, _2, _3, _4, _5, _6: jnp.ones( + formula=lambda _0, _1, _2, _3, _4, _5: jnp.ones( source_lib.ProfileType.CELL.get_profile_shape(geo) ) * 2, @@ -318,11 +324,7 @@ def test_correct_mode_called(self, mode, expected_profile): ) profile = source.get_value( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ - 'foo' - ], static_runtime_params_slice=static_slice, - static_source_runtime_params=static_slice.sources['foo'], geo=geo, core_profiles=core_profiles, ) @@ -370,20 +372,13 @@ def test_defaults_output_zeros(self): ) }, ) - profile = source.get_value( - dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ - 'foo' - ], - static_runtime_params_slice=static_slice, - static_source_runtime_params=static_slice.sources['foo'], - geo=geo, - core_profiles=core_profiles, - ) - np.testing.assert_allclose( - profile, - get_zero_profile(source_lib.ProfileType.CELL, geo), - ) + with self.assertRaises(ValueError): + source.get_value( + dynamic_runtime_params_slice=dynamic_runtime_params_slice, + static_runtime_params_slice=static_slice, + geo=geo, + core_profiles=core_profiles, + ) with self.subTest('formula'): static_slice = runtime_params_slice.build_static_runtime_params_slice( runtime_params, @@ -394,20 +389,13 @@ def test_defaults_output_zeros(self): ) }, ) - profile = source.get_value( - dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ - 'foo' - ], - static_runtime_params_slice=static_slice, - static_source_runtime_params=static_slice.sources['foo'], - geo=geo, - core_profiles=core_profiles, - ) - np.testing.assert_allclose( - profile, - get_zero_profile(source_lib.ProfileType.CELL, geo), - ) + with self.assertRaises(ValueError): + source.get_value( + dynamic_runtime_params_slice=dynamic_runtime_params_slice, + static_runtime_params_slice=static_slice, + geo=geo, + core_profiles=core_profiles, + ) with self.subTest('prescribed'): static_slice = runtime_params_slice.build_static_runtime_params_slice( runtime_params, @@ -420,11 +408,7 @@ def test_defaults_output_zeros(self): ) profile = source.get_value( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ - 'foo' - ], static_runtime_params_slice=static_slice, - static_source_runtime_params=static_slice.sources['foo'], geo=geo, core_profiles=core_profiles, ) @@ -439,7 +423,7 @@ def test_overriding_default_formula(self): output_shape = source_lib.ProfileType.CELL.get_profile_shape(geo) expected_output = jnp.ones(output_shape) source_builder = IonElTestSourceBuilder( - formula=lambda _0, _1, _2, _3, _4, _5, _6: expected_output, + formula=lambda _0, _1, _2, _3, _4, _5: expected_output, ) source_builder.runtime_params.mode = runtime_params_lib.Mode.FORMULA_BASED source_models_builder = source_models_lib.SourceModelsBuilder( @@ -469,11 +453,7 @@ def test_overriding_default_formula(self): ) profile = source.get_value( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ - 'foo' - ], static_runtime_params_slice=static_slice, - static_source_runtime_params=static_slice.sources['foo'], geo=geo, core_profiles=core_profiles, ) @@ -484,9 +464,10 @@ def test_overriding_model(self): geo = geometry.build_circular_geometry() output_shape = source_lib.ProfileType.CELL.get_profile_shape(geo) expected_output = jnp.ones(output_shape) - source_builder = IonElTestSourceBuilder( - model_func=lambda _0, _1, _2, _3, _4, _5, _6: expected_output, - ) + source_builder = source_lib.make_source_builder( + IonElTestSource, + model_func=lambda _0, _1, _2, _3, _4, _5: expected_output, + )() source_builder.runtime_params.mode = runtime_params_lib.Mode.MODEL_BASED source_models_builder = source_models_lib.SourceModelsBuilder( {'foo': source_builder}, @@ -515,11 +496,7 @@ def test_overriding_model(self): ) profile = source.get_value( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ - 'foo' - ], static_runtime_params_slice=static_slice, - static_source_runtime_params=static_slice.sources['foo'], geo=geo, core_profiles=core_profiles, ) @@ -565,11 +542,7 @@ def test_overriding_prescribed_values(self): ) profile = source.get_value( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ - 'foo' - ], static_runtime_params_slice=static_slice, - static_source_runtime_params=static_slice.sources['foo'], geo=geo, core_profiles=core_profiles, ) @@ -583,6 +556,10 @@ def test_retrieving_profile_for_affected_state(self): class TestSource(source_lib.Source): output_shape_getter = lambda _0: output_shape + @property + def source_name(self) -> str: + return 'foo' + @property def affected_core_profiles(self): return ( @@ -592,7 +569,7 @@ def affected_core_profiles(self): profile = jnp.asarray([[1, 2, 3, 4], [5, 6, 7, 8]]) # from get_value() source = TestSource( - model_func=lambda _0, _1, _2, _3, _4, _5, _6: profile, + model_func=lambda _0, _1, _2, _3, _4, _5: profile, ) geo = geometry.build_circular_geometry(n_rho=4) psi_profile = source.get_source_profile_for_affected_core_profile( @@ -622,7 +599,7 @@ def test_custom_formula(self): geo = geometry.build_circular_geometry(n_rho=5) expected_output = jnp.ones((5)) # 5 matches the geo. source_builder = PsiTestSourceBuilder( - formula=lambda _0, _1, _2, _3, _4, _5, _6: expected_output, + formula=lambda _0, _1, _2, _3, _4, _5: expected_output, ) source_builder.runtime_params.mode = runtime_params_lib.Mode.FORMULA_BASED source_models_builder = source_models_lib.SourceModelsBuilder( @@ -651,11 +628,7 @@ def test_custom_formula(self): ) profile = source.get_value( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ - 'foo' - ], static_runtime_params_slice=static_slice, - static_source_runtime_params=static_slice.sources['foo'], geo=geo, core_profiles=core_profiles, ) @@ -666,7 +639,7 @@ def test_retrieving_profile_for_affected_state(self): profile = jnp.asarray([1, 2, 3, 4]) # from get_value() source = test_lib.TestSource( - model_func=lambda _0, _1, _2, _3, _4, _5, _6: profile, + model_func=lambda _0, _1, _2, _3, _4, _5: profile, ) geo = geometry.build_circular_geometry(n_rho=4) psi_profile = source.get_source_profile_for_affected_core_profile( diff --git a/torax/sources/tests/source_models.py b/torax/sources/tests/source_models.py index fef76763..1bdb6915 100644 --- a/torax/sources/tests/source_models.py +++ b/torax/sources/tests/source_models.py @@ -37,6 +37,10 @@ class FooSource(source_lib.Source): """A test source.""" + @property + def source_name(self) -> str: + return 'foo' + @property def affected_core_profiles( self, @@ -175,10 +179,9 @@ def test_custom_source_profiles_dont_change_when_jitted(self): def foo_formula( unused_dcs, - unused_sc, unused_static_runtime_params_slice, - unused_static_source_runtime_params, geo: geometry.Geometry, + unused_source_name: str, unused_state, unused_source_models, ): diff --git a/torax/sources/tests/test_lib.py b/torax/sources/tests/test_lib.py index 28ae9792..048c05dc 100644 --- a/torax/sources/tests/test_lib.py +++ b/torax/sources/tests/test_lib.py @@ -37,6 +37,10 @@ class TestSource(source_lib.Source): """A test source.""" + @property + def source_name(self) -> str: + return 'foo' + @property def affected_core_profiles( self, @@ -66,6 +70,8 @@ class SourceTestCase(parameterized.TestCase): _source_class_builder: source_lib.SourceBuilderProtocol _config_attr_name: str _unsupported_modes: Sequence[runtime_params_lib.Mode] + _source_name: str + _runtime_params_class: Type[runtime_params_lib.RuntimeParams] @classmethod def setUpClass( @@ -73,18 +79,26 @@ def setUpClass( source_class: Type[source_lib.Source], runtime_params_class: Type[runtime_params_lib.RuntimeParams], unsupported_modes: Sequence[runtime_params_lib.Mode], + source_name: str, + model_func: source_lib.SourceProfileFunction | None, links_back: bool = False, + source_class_builder: source_lib.SourceBuilderProtocol | None = None, ): super().setUpClass() cls._source_class = source_class - cls._source_class_builder = source_lib.make_source_builder( - source_type=source_class, - runtime_params_type=runtime_params_class, - links_back=links_back, - ) + if source_class_builder is None: + cls._source_class_builder = source_lib.make_source_builder( + source_type=source_class, + runtime_params_type=runtime_params_class, + links_back=links_back, + model_func=model_func, + ) + else: + cls._source_class_builder = source_class_builder cls._runtime_params_class = runtime_params_class cls._unsupported_modes = unsupported_modes cls._links_back = links_back + cls._source_name = source_name def test_runtime_params_builds_dynamic_params(self): runtime_params = self._runtime_params_class() @@ -127,17 +141,17 @@ def test_source_value(self): # SingleProfileSource subclasses should have default names and be # instantiable without any __init__ arguments. # pylint: disable=missing-kwoa - source_builder = self._source_class_builder() # pytype: disable=missing-parameter + source_builder = self._source_class_builder() if not source_lib.is_source_builder(source_builder): raise TypeError(f'{type(self)} has a bad _source_class_builder') # pylint: enable=missing-kwoa runtime_params = general_runtime_params.GeneralRuntimeParams() source_models_builder = source_models_lib.SourceModelsBuilder( - {'foo': source_builder}, + {self._source_name: source_builder}, ) source_models = source_models_builder() - source = source_models.sources['foo'] - source_builder.runtime_params.mode = source.supported_modes[0] + source = source_models.sources[self._source_name] + source_builder.runtime_params.mode = runtime_params_lib.Mode.MODEL_BASED self.assertIsInstance(source, source_lib.Source) geo = geometry.build_circular_geometry() dynamic_runtime_params_slice = ( @@ -161,11 +175,7 @@ def test_source_value(self): ) value = source.get_value( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ - 'foo' - ], static_runtime_params_slice=static_slice, - static_source_runtime_params=static_slice.sources['foo'], geo=geo, core_profiles=core_profiles, ) @@ -179,10 +189,10 @@ def test_invalid_source_types_raise_errors(self): source_builder = self._source_class_builder() # pytype: disable=missing-parameter # pylint: enable=missing-kwoa source_models_builder = source_models_lib.SourceModelsBuilder( - {'foo': source_builder}, + {self._source_name: source_builder}, ) source_models = source_models_builder() - source = source_models.sources['foo'] + source = source_models.sources[self._source_name] self.assertIsInstance(source, source_lib.Source) dynamic_runtime_params_slice = ( runtime_params_slice.DynamicRuntimeParamsSliceProvider( @@ -213,11 +223,7 @@ def test_invalid_source_types_raise_errors(self): with self.assertRaises(ValueError): source.get_value( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ - 'foo' - ], static_runtime_params_slice=static_slice, - static_source_runtime_params=static_slice.sources['foo'], geo=geo, core_profiles=core_profiles, ) @@ -229,15 +235,15 @@ class IonElSourceTestCase(SourceTestCase): def test_source_value(self): """Tests that the source can provide a value by default.""" # pylint: disable=missing-kwoa - source_builder = self._source_class_builder() # pytype: disable=missing-parameter + source_builder = self._source_class_builder() # pylint: enable=missing-kwoa runtime_params = general_runtime_params.GeneralRuntimeParams() geo = geometry.build_circular_geometry() source_models_builder = source_models_lib.SourceModelsBuilder( - {'foo': source_builder}, + {self._source_name: source_builder}, ) source_models = source_models_builder() - source = source_models.sources['foo'] + source = source_models.sources[self._source_name] self.assertIsInstance(source, source_lib.Source) dynamic_runtime_params_slice = ( runtime_params_slice.DynamicRuntimeParamsSliceProvider( @@ -260,11 +266,7 @@ def test_source_value(self): ) ion_and_el = source.get_value( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ - 'foo' - ], static_runtime_params_slice=static_slice, - static_source_runtime_params=static_slice.sources['foo'], geo=geo, core_profiles=core_profiles, ) @@ -278,10 +280,10 @@ def test_invalid_source_types_raise_errors(self): source_builder = self._source_class_builder() # pytype: disable=missing-parameter # pylint: enable=missing-kwoa source_models_builder = source_models_lib.SourceModelsBuilder( - {'foo': source_builder}, + {self._source_name: source_builder}, ) source_models = source_models_builder() - source = source_models.sources['foo'] + source = source_models.sources[self._source_name] self.assertIsInstance(source, source_lib.Source) dynamic_runtime_params_slice = ( runtime_params_slice.DynamicRuntimeParamsSliceProvider( @@ -312,11 +314,7 @@ def test_invalid_source_types_raise_errors(self): with self.assertRaises(ValueError): source.get_value( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ - 'foo' - ], static_runtime_params_slice=static_slice, - static_source_runtime_params=static_slice.sources['foo'], geo=geo, core_profiles=core_profiles, ) diff --git a/torax/tests/physics.py b/torax/tests/physics.py index 085f3d98..aba8fb06 100644 --- a/torax/tests/physics.py +++ b/torax/tests/physics.py @@ -16,6 +16,7 @@ import dataclasses from typing import Callable + from absl.testing import absltest from absl.testing import parameterized import jax @@ -27,10 +28,12 @@ from torax.config import runtime_params_slice from torax.fvm import cell_variable from torax.geometry import geometry +from torax.sources import generic_current_source from torax.sources import runtime_params as source_runtime_params from torax.sources import source_models as source_models_lib from torax.tests.test_lib import torax_refs + _trapz = jax.scipy.integrate.trapezoid @@ -108,9 +111,9 @@ def test_update_psi_from_j( runtime_params = references.runtime_params source_models_builder = source_models_lib.SourceModelsBuilder() # Turn on the external current source. - source_models_builder.runtime_params['generic_current_source'].mode = ( - source_runtime_params.Mode.FORMULA_BASED - ) + source_models_builder.runtime_params[ + generic_current_source.GenericCurrentSource.SOURCE_NAME + ].mode = source_runtime_params.Mode.MODEL_BASED source_models = source_models_builder() dynamic_runtime_params_slice, geo = ( torax_refs.build_consistent_dynamic_runtime_params_slice_and_geometry( diff --git a/torax/tests/sim_custom_sources.py b/torax/tests/sim_custom_sources.py index a34c0ed4..360c6859 100644 --- a/torax/tests/sim_custom_sources.py +++ b/torax/tests/sim_custom_sources.py @@ -35,6 +35,7 @@ from torax.pedestal_model import set_tped_nped from torax.sources import electron_density_sources from torax.sources import runtime_params as runtime_params_lib +from torax.sources import source as source_lib from torax.sources.tests import test_lib from torax.stepper import linear_theta_method from torax.tests.test_lib import default_sources @@ -94,64 +95,36 @@ def test_custom_ne_source_can_replace_defaults(self): # For this example, use test_particle_sources_constant with the linear # stepper. - custom_source_name = 'custom_ne_source' + custom_source_name = 'foo' def custom_source_formula( static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, - static_source_runtime_params: runtime_params_lib.RuntimeParams, dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, - dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, + unused_source_name: str, unused_state: state_lib.CoreProfiles | None, unused_source_models: ..., ): # Combine the outputs. - assert isinstance( - dynamic_source_runtime_params, _CustomSourceDynamicRuntimeParams - ) - ignored_default_kwargs = dict( - formula=dynamic_source_runtime_params.formula, - prescribed_values=dynamic_source_runtime_params.prescribed_values, - ) - puff_params = electron_density_sources.DynamicGasPuffRuntimeParams( - puff_decay_length=dynamic_source_runtime_params.puff_decay_length, - S_puff_tot=dynamic_source_runtime_params.S_puff_tot, - **ignored_default_kwargs, - ) - params = electron_density_sources.DynamicParticleRuntimeParams( - deposition_location=dynamic_source_runtime_params.deposition_location, - particle_width=dynamic_source_runtime_params.particle_width, - S_tot=dynamic_source_runtime_params.S_tot, - **ignored_default_kwargs, - ) - pellet_params = electron_density_sources.DynamicPelletRuntimeParams( - pellet_deposition_location=dynamic_source_runtime_params.pellet_deposition_location, - pellet_width=dynamic_source_runtime_params.pellet_width, - S_pellet_tot=dynamic_source_runtime_params.S_pellet_tot, - **ignored_default_kwargs, - ) # pylint: disable=protected-access return ( - electron_density_sources._calc_puff_source( + electron_density_sources.calc_puff_source( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=puff_params, static_runtime_params_slice=static_runtime_params_slice, - static_source_runtime_params=static_source_runtime_params, geo=geo, + source_name=electron_density_sources.GasPuffSource.SOURCE_NAME, ) - + electron_density_sources._calc_generic_particle_source( + + electron_density_sources.calc_generic_particle_source( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=params, static_runtime_params_slice=static_runtime_params_slice, - static_source_runtime_params=static_source_runtime_params, geo=geo, + source_name=electron_density_sources.GenericParticleSource.SOURCE_NAME, ) - + electron_density_sources._calc_pellet_source( + + electron_density_sources.calc_pellet_source( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=pellet_params, static_runtime_params_slice=static_runtime_params_slice, - static_source_runtime_params=static_source_runtime_params, geo=geo, + source_name=electron_density_sources.PelletSource.SOURCE_NAME, ) ) # pylint: enable=protected-access @@ -189,20 +162,25 @@ def custom_source_formula( # Add the custom source with the correct params, but keep it turned off to # start. + source_builder = source_lib.make_source_builder( + test_lib.TestSource, + runtime_params_type=_CustomSourceRuntimeParams, + model_func=custom_source_formula, + ) + runtime_params = _CustomSourceRuntimeParams( + mode=runtime_params_lib.Mode.ZERO, + puff_decay_length=gas_puff_params.puff_decay_length, + S_puff_tot=gas_puff_params.S_puff_tot, + particle_width=params.particle_width, + deposition_location=params.deposition_location, + S_tot=params.S_tot, + pellet_width=pellet_params.pellet_width, + pellet_deposition_location=pellet_params.pellet_deposition_location, + S_pellet_tot=pellet_params.S_pellet_tot, + ) source_models_builder.source_builders[custom_source_name] = ( - test_lib.TestSourceBuilder( - formula=custom_source_formula, - runtime_params=_CustomSourceRuntimeParams( - mode=runtime_params_lib.Mode.ZERO, - puff_decay_length=gas_puff_params.puff_decay_length, - S_puff_tot=gas_puff_params.S_puff_tot, - particle_width=params.particle_width, - deposition_location=params.deposition_location, - S_tot=params.S_tot, - pellet_width=pellet_params.pellet_width, - pellet_deposition_location=pellet_params.pellet_deposition_location, - S_pellet_tot=pellet_params.S_pellet_tot, - ), + source_builder( + runtime_params=runtime_params, ) ) @@ -241,16 +219,12 @@ def custom_source_formula( params.mode = runtime_params_lib.Mode.ZERO pellet_params.mode = runtime_params_lib.Mode.ZERO gas_puff_params.mode = runtime_params_lib.Mode.ZERO - source_models_builder.runtime_params[custom_source_name].mode = ( - runtime_params_lib.Mode.FORMULA_BASED - ) + runtime_params.mode = runtime_params_lib.Mode.MODEL_BASED self._run_sim_and_check(sim, ref_profiles, ref_time) with self.subTest('without_defaults_and_without_custom_source'): # Confirm that the custom source actual has an effect. - source_models_builder.runtime_params[custom_source_name].mode = ( - runtime_params_lib.Mode.ZERO - ) + runtime_params.mode = runtime_params_lib.Mode.ZERO with self.assertRaises(AssertionError): self._run_sim_and_check(sim, ref_profiles, ref_time) @@ -311,7 +285,6 @@ def make_provider( raise ValueError('torax_mesh is required for CustomSourceRuntimeParams.') return _CustomSourceRuntimeParamsProvider( runtime_params_config=self, - formula=self.formula.make_provider(torax_mesh), prescribed_values=config_args.get_interpolated_var_2d( self.prescribed_values, torax_mesh.cell_centers ), diff --git a/torax/tests/sim_output_source_profiles.py b/torax/tests/sim_output_source_profiles.py index 68764537..e68b18f7 100644 --- a/torax/tests/sim_output_source_profiles.py +++ b/torax/tests/sim_output_source_profiles.py @@ -38,6 +38,7 @@ from torax.geometry import geometry_provider as geometry_provider_lib from torax.pedestal_model import set_tped_nped from torax.sources import runtime_params as runtime_params_lib +from torax.sources import source as source_lib from torax.sources import source_models as source_models_lib from torax.sources import source_profiles as source_profiles_lib from torax.sources.tests import test_lib @@ -52,6 +53,30 @@ _ALL_PROFILES = ('temp_ion', 'temp_el', 'psi', 'q_face', 's_face', 'ne') +class TestImplicitNeSource(test_lib.TestSource): + """A test source.""" + + @property + def source_name(self) -> str: + return 'implicit_ne_source' + + +class TestExplicitNeSource(test_lib.TestSource): + """A test source.""" + + @property + def source_name(self) -> str: + return 'explicit_ne_source' + + +TestImplicitNeSourceBuilder = source_lib.make_source_builder( + TestImplicitNeSource +) +TestExplicitNeSourceBuilder = source_lib.make_source_builder( + TestExplicitNeSource +) + + class SimOutputSourceProfilesTest(sim_test_case.SimTestCase): """Tests checking the output core_sources profiles from run_simulation().""" @@ -101,28 +126,31 @@ def test_first_and_last_source_profiles(self): # This is not physically realistic, just for testing purposes. def custom_source_formula( unused_static_runtime_params_slice, - unused_static_source_runtime_params, - unused_dynamic_runtime_params, - source_conf, + dynamic_runtime_params, geo, + source_name, unused_state, unused_source_models, ): - return jnp.ones_like(geo.rho) * source_conf.foo + dynamic_source_params = dynamic_runtime_params.sources[source_name] + return jnp.ones_like(geo.rho) * dynamic_source_params.foo # Include 2 versions of this source, one implicit and one explicit. + source_builder = source_lib.make_source_builder( + TestImplicitNeSource, + runtime_params_type=_FakeSourceRuntimeParams, + model_func=custom_source_formula, + ) source_models_builder = source_models_lib.SourceModelsBuilder({ - 'implicit_ne_source': test_lib.TestSourceBuilder( - formula=custom_source_formula, + 'implicit_ne_source': source_builder( runtime_params=_FakeSourceRuntimeParams( - mode=runtime_params_lib.Mode.FORMULA_BASED, + mode=runtime_params_lib.Mode.MODEL_BASED, foo={0.0: 1.0, 1.0: 2.0, 2.0: 3.0, 3.0: 4.0}, ), ), - 'explicit_ne_source': test_lib.TestSourceBuilder( - formula=custom_source_formula, + 'explicit_ne_source': source_builder( runtime_params=_FakeSourceRuntimeParams( - mode=runtime_params_lib.Mode.FORMULA_BASED, + mode=runtime_params_lib.Mode.MODEL_BASED, foo={0.0: 1.0, 1.0: 2.0, 2.0: 3.0, 3.0: 4.0}, ), ), @@ -249,7 +277,6 @@ def make_provider( raise ValueError('torax_mesh is required for FakeSourceRuntimeParams.') return _FakeSourceRuntimeParamsProvider( runtime_params_config=self, - formula=self.formula.make_provider(torax_mesh), prescribed_values=config_args.get_interpolated_var_2d( self.prescribed_values, torax_mesh.cell_centers ), diff --git a/torax/tests/state.py b/torax/tests/state.py index c55d3090..a6679903 100644 --- a/torax/tests/state.py +++ b/torax/tests/state.py @@ -397,13 +397,7 @@ def test_initial_psi_from_j( source_models = source_models_lib.SourceModels() bootstrap_profile = source_models.j_bootstrap.get_value( dynamic_runtime_params_slice=dcs3, - dynamic_source_runtime_params=dcs3.sources[ - source_models.j_bootstrap_name - ], static_runtime_params_slice=static_slice, - static_source_runtime_params=static_slice.sources[ - source_models.j_bootstrap_name - ], geo=geo, core_profiles=core_profiles3_helper, ) diff --git a/torax/tests/test_lib/default_sources.py b/torax/tests/test_lib/default_sources.py index fd45273a..959b6924 100644 --- a/torax/tests/test_lib/default_sources.py +++ b/torax/tests/test_lib/default_sources.py @@ -13,7 +13,9 @@ # limitations under the License. """Utilities to help with testing sources.""" + from torax.sources import register_source +from torax.sources import source as source_lib from torax.sources import source_models as source_models_lib @@ -60,9 +62,18 @@ def get_default_sources_builder() -> source_models_lib.SourceModelsBuilder: ] source_builders = {} for name in names: - registered_source = register_source.get_registered_source(name) - runtime_params = registered_source.default_runtime_params_class() - source_builders[name] = registered_source.source_builder_class( - runtime_params=runtime_params - ) + registered_source = register_source.get_supported_source(name) + model_function = registered_source.model_functions[ + registered_source.source_class.DEFAULT_MODEL_FUNCTION_NAME + ] + runtime_params = model_function.runtime_params_class() + source_builder_class = model_function.source_builder_class + if source_builder_class is None: + source_builder_class = source_lib.make_source_builder( + registered_source.source_class, + runtime_params_type=model_function.runtime_params_class, + links_back=model_function.links_back, + model_func=model_function.source_profile_function, + ) + source_builders[name] = source_builder_class(runtime_params=runtime_params) return source_models_lib.SourceModelsBuilder(source_builders)