diff --git a/docs/configuration.rst b/docs/configuration.rst index 80f8c552..5227eb08 100644 --- a/docs/configuration.rst +++ b/docs/configuration.rst @@ -803,10 +803,6 @@ The configurable runtime parameters of each source are as follows: Source values come from a model in code. Specific model selection is not yet available in TORAX since there are no source components with more than one physics model. However, this will be straightforward to develop when that occurs. -* ``'FORMULA'`` - Source values come from a prescribed (possibly time-dependent) formula that is not dependent on the state of the system. The formula type (Gaussian, exponential) - is set by ``formula_type``. - * ``'PRESCRIBED'`` Source values are arbitrarily prescribed by the user. The value is set by ``prescribed_values``, and can contain the same data structures as :ref:`Time-varying arrays`. @@ -840,38 +836,13 @@ and can be set to anything convenient. beginning of a time step, or do not have any dependance on state. Implicit sources depend on updated states as the iterative solvers evolve the state through the course of a time step. If a source model is complex but evolves over slow timescales compared to the state, it may be beneficial to set it as explicit. -``formula_type`` (str='default') - Sets the formula type if ``mode=='formula'``. The current options are: - -* ``'exponential'`` takes the following arguments: - * c1 (float): Offset location - * c2 (float): Exponential decay parameter - * total (float): integral - - The profile is parameterized as follows :math:`Q = C e^{-(r - c1) / c2}` , where ``C`` is calculated to be consistent with ``total``. If ``use_normalized_r==True``, - then c1 and c2 are interpreted as being in normalized toroidal flux units. - -* ``'gaussian'`` takes the following arguments: - * c1 (float): Gaussian peak Location - * c2 (float): Gaussian width - * total (float): integral - - The profile is parameterized as follows :math:`Q = C e^{-((r - c1)^2) / (2 c2^2)}` , where ``C`` is calculated to be consistent with ``total``. If ``use_normalized_r==True``, - then c1 and c2 are interpreted as being in normalized toroidal flux units. - -* ``'default'`` - Some sources have default implementations which use the above formulas under the hood with intuitive parameter names for c1 and c2. - Consult the list below for further details. generic_ion_el_heat_source ^^^^^^^^^^^^^^^^^^^^^^^^^^ A utility source module that allows for a time dependent Gaussian ion and electron heat source. -``mode`` (str = 'formula') - -``formula_type`` (str = 'default') - Uses the Gaussian formula. +``mode`` (str = 'model') ``rsource`` (float = 0.0), **time-varying-scalar** Gaussian center of source profile in units of :math:`\hat{\rho}`. @@ -912,12 +883,9 @@ Fusion power assuming a 50-50 D-T ion distribution. gas_puff_source ^^^^^^^^^^^^^^^ -Formula based exponential gas puff source. No first-principle-based model is yet implemented in TORAX. +Exponential based gas puff source. No first-principle-based model is yet implemented in TORAX. -``mode`` (str = 'formula') - -``formula_type`` (str = 'default') - Uses the exponential formula with ``c1=1``. +``mode`` (str = 'model') ``puff_decay_length`` (float = 0.05), **time-varying-scalar** Gas puff decay length from edge in units of :math:`\hat{\rho}`. @@ -930,10 +898,7 @@ pellet_source Time dependent Gaussian pellet source. No first-principle-based model is yet implemented in TORAX. -``mode`` (str = 'formula') - -``formula_type`` (str = 'default') - Uses the Gaussian formula. +``mode`` (str = 'model') ``pellet_deposition_location`` (float = 0.85), **time-varying-scalar** Gaussian center of source profile in units of :math:`\hat{\rho}`. @@ -949,10 +914,7 @@ generic_particle_source Time dependent Gaussian particle source. No first-principle-based model is yet implemented in TORAX. -``mode`` (str = 'formula') - -``formula_type`` (str = 'default') - Uses the Gaussian formula with. +``mode`` (str = 'model') ``deposition_location`` (float = 0.0), **time-varying-scalar** Gaussian center of source profile in units of :math:`\hat{\rho}`. @@ -978,10 +940,7 @@ generic_current_source Generic external current profile, parameterized as a Gaussian. -``mode`` (str = 'formula') - -``formula_type`` (str = 'default') - Uses the Gaussian formula. +``mode`` (str = 'model') ``rext`` (float = 0.4), **time-varying-scalar** Gaussian center of current profile in units of :math:`\hat{\rho}`. diff --git a/docs/physics_models.rst b/docs/physics_models.rst index f871a48d..c8db037a 100644 --- a/docs/physics_models.rst +++ b/docs/physics_models.rst @@ -284,11 +284,8 @@ not need to be JAX-compatible, since explicit sources are an input into the PDE and do not require JIT compilation. Conversely, implicit treatment can be important for accurately resolving the impact of fast-evolving source terms. -All sources can optionally be set to zero, prescribed with non-physics-based formulas -(currently Gaussian or exponential) with user-configurable time-dependent parameters like -amplitude, width, and location, or calculated with a dedicated physics-based model. Not -all sources currently have a model implementation. However, the code modular structure -facilitates easy coupling of additional source models in future work. Specifics of source models +All sources can optionally be set to zero, prescribed with explicit values or calculated with a dedicated physics-based model. +However, the code modular structure facilitates easy coupling of additional source models in future work. Specifics of source models currently implemented in TORAX follow: Ion-electron heat exchange diff --git a/torax/__init__.py b/torax/__init__.py index 9ca6d8f7..9f1936f4 100644 --- a/torax/__init__.py +++ b/torax/__init__.py @@ -73,6 +73,7 @@ def set_jax_precision(): 'geo_t', 'geo_t_plus_dt', 'geometry_provider', + 'source_name', 'x_old', 'state', 'unused_state', @@ -88,6 +89,7 @@ def set_jax_precision(): 'source_profiles', 'source_profile', 'explicit_source_profiles', + 'model_func', 'source_models', 'pedestal_model', 'time_step_calculator', diff --git a/torax/config/build_sim.py b/torax/config/build_sim.py index afaf7ddc..a749ed3e 100644 --- a/torax/config/build_sim.py +++ b/torax/config/build_sim.py @@ -25,10 +25,7 @@ from torax.geometry import geometry_provider from torax.pedestal_model import pedestal_model as pedestal_model_lib from torax.pedestal_model import set_tped_nped -from torax.sources import formula_config -from torax.sources import formulas from torax.sources import register_source -from torax.sources import runtime_params as source_runtime_params_lib from torax.sources import source as source_lib from torax.sources import source_models as source_models_lib from torax.stepper import linear_theta_method @@ -353,48 +350,12 @@ def build_sources_builder_from_config( }, } - If the `mode` is set to `formula_based`, then the you can provide a - `formula_type` key which may have the following values: - - - `default`: Uses the default impl (if the source has one) (default) - - - The other config args are based on the source's RuntimeParams object - outlined above. - - - `exponential`: Exponential profile. - - - The other config args are from `sources.formula_config.Exponential`. - - - `gaussian`: Gaussian profile. - - - The other config args are from `sources.formula_config.Gaussian`. - - E.g. for an example heat source: - - .. code-block:: python - - { - mode: 'formula', - formula_type: 'gaussian', - total: 120e6, # total heating - c1: 0.0, # Source Gaussian central location (in normalized r) - c2: 0.25, # Gaussian width in normalized radial coordinates - } - - If you have custom source implementations, you may update this funtion to - handle those new sources and keys, or you may use the "advanced" configuration - method and build your `SourceModel` object directly. - Args: source_configs: Input config dict defining all sources, with a structure as described above. Returns: A `SourceModelsBuilder`. - - Raises: - ValueError if an input key doesn't match one of the source names defined - above. """ source_builders = { @@ -410,42 +371,33 @@ def _build_single_source_builder_from_config( source_config: dict[str, Any], ) -> source_lib.SourceBuilderProtocol: """Builds a source builder from the input config.""" - registered_source = register_source.get_registered_source(source_name) - runtime_params = registered_source.default_runtime_params_class() - # Update the defaults with the config provided. - source_config = copy.copy(source_config) - if 'mode' in source_config: - mode = source_runtime_params_lib.Mode[source_config.pop('mode').upper()] - runtime_params.mode = mode - formula = None - if 'formula_type' in source_config: - func = source_config.pop('formula_type').lower() - if func == 'default': - pass # Nothing to do here. - elif func == 'exponential': - runtime_params.formula = config_args.recursive_replace( - formula_config.Exponential(), - ignore_extra_kwargs=True, - **source_config, - ) - formula = formulas.Exponential() - elif func == 'gaussian': - runtime_params.formula = config_args.recursive_replace( - formula_config.Gaussian(), - ignore_extra_kwargs=True, - **source_config, - ) - formula = formulas.Gaussian() - else: - raise ValueError(f'Unknown formula_type for source {source_name}: {func}') + supported_source = register_source.get_supported_source(source_name) + if 'model_func' in source_config: + # If the user has specified a model function, try to retrive that from the + # registered source model functions. + model_func = source_config.pop('model_func') + model_function = supported_source.model_functions[model_func] + else: + # Otherwise, use the default model function. + model_function = supported_source.model_functions[ + supported_source.source_class.DEFAULT_MODEL_FUNCTION_NAME + ] + runtime_params = model_function.runtime_params_class() runtime_params = config_args.recursive_replace( runtime_params, ignore_extra_kwargs=True, **source_config ) kwargs = {'runtime_params': runtime_params} - if formula is not None: - kwargs['formula'] = formula - return registered_source.source_builder_class(**kwargs) + source_builder_class = model_function.source_builder_class + if source_builder_class is None: + source_builder_class = source_lib.make_source_builder( + supported_source.source_class, + runtime_params_type=model_function.runtime_params_class, + links_back=model_function.links_back, + model_func=model_function.source_profile_function, + ) + + return source_builder_class(**kwargs) def build_transport_model_builder_from_config( diff --git a/torax/config/tests/build_sim.py b/torax/config/tests/build_sim.py index 39619751..445018d6 100644 --- a/torax/config/tests/build_sim.py +++ b/torax/config/tests/build_sim.py @@ -23,8 +23,6 @@ from torax.geometry import geometry from torax.geometry import geometry_provider from torax.pedestal_model import set_tped_nped -from torax.sources import formula_config -from torax.sources import formulas from torax.sources import runtime_params as source_runtime_params_lib from torax.stepper import linear_theta_method from torax.stepper import nonlinear_theta_method @@ -366,39 +364,13 @@ def test_adding_standard_source_via_config(self): # pytype: enable=attribute-error self.assertEqual( source_models_builder.runtime_params['gas_puff_source'].mode, - source_runtime_params_lib.Mode.FORMULA_BASED, # On by default. + source_runtime_params_lib.Mode.MODEL_BASED, # On by default. ) self.assertEqual( source_models_builder.runtime_params['ohmic_heat_source'].mode, source_runtime_params_lib.Mode.ZERO, ) - def test_updating_formula_via_source_config(self): - """Tests that we can set the formula type and params via the config.""" - source_models_builder = build_sim.build_sources_builder_from_config({ - 'gas_puff_source': { - 'formula_type': 'gaussian', - 'total': 1, - 'c1': 2, - 'c2': 3, - } - }) - source_models = source_models_builder() - gas_source = source_models.sources['gas_puff_source'] - self.assertIsInstance(gas_source.formula, formulas.Gaussian) - gas_source_runtime_params = source_models_builder.runtime_params[ - 'gas_puff_source' - ] - self.assertIsInstance( - gas_source_runtime_params.formula, - formula_config.Gaussian, - ) - # pytype: disable=attribute-error - self.assertEqual(gas_source_runtime_params.formula.total, 1) - self.assertEqual(gas_source_runtime_params.formula.c1, 2) - self.assertEqual(gas_source_runtime_params.formula.c2, 3) - # pytype: enable=attribute-error - def test_missing_transport_model_raises_error(self): with self.assertRaises(ValueError): build_sim.build_transport_model_builder_from_config({}) diff --git a/torax/config/tests/runtime_params_slice.py b/torax/config/tests/runtime_params_slice.py index 52f6a5c2..ee110cf9 100644 --- a/torax/config/tests/runtime_params_slice.py +++ b/torax/config/tests/runtime_params_slice.py @@ -26,9 +26,7 @@ from torax.geometry import geometry from torax.pedestal_model import set_tped_nped from torax.sources import electron_density_sources -from torax.sources import formula_config from torax.sources import generic_current_source -from torax.sources import runtime_params as sources_params_lib from torax.stepper import runtime_params as stepper_params_lib from torax.tests.test_lib import default_sources from torax.transport_model import runtime_params as transport_params_lib @@ -248,63 +246,6 @@ def test_source_formula_config_has_time_dependent_params(self): ) np.testing.assert_allclose(generic_particle_source.S_tot, 4.0) - with self.subTest('exponential_formula'): - runtime_params = general_runtime_params.GeneralRuntimeParams() - dcs = runtime_params_slice_lib.DynamicRuntimeParamsSliceProvider( - runtime_params=runtime_params, - sources={ - electron_density_sources.GasPuffSource.SOURCE_NAME: ( - sources_params_lib.RuntimeParams( - formula=formula_config.Exponential( - total={0.0: 0.0, 1.0: 1.0}, - c1={0.0: 0.0, 1.0: 2.0}, - c2={0.0: 0.0, 1.0: 3.0}, - ) - ) - ), - }, - torax_mesh=self._geo.torax_mesh, - )( - t=0.25, - ) - gas_puff_source = dcs.sources[ - electron_density_sources.GasPuffSource.SOURCE_NAME - ] - assert isinstance( - gas_puff_source.formula, - formula_config.DynamicExponential, - ) - np.testing.assert_allclose(gas_puff_source.formula.total, 0.25) - np.testing.assert_allclose(gas_puff_source.formula.c1, 0.5) - np.testing.assert_allclose(gas_puff_source.formula.c2, 0.75) - - with self.subTest('gaussian_formula'): - runtime_params = general_runtime_params.GeneralRuntimeParams() - dcs = runtime_params_slice_lib.DynamicRuntimeParamsSliceProvider( - runtime_params=runtime_params, - sources={ - electron_density_sources.GasPuffSource.SOURCE_NAME: ( - sources_params_lib.RuntimeParams( - formula=formula_config.Gaussian( - total={0.0: 0.0, 1.0: 1.0}, - c1={0.0: 0.0, 1.0: 2.0}, - c2={0.0: 0.0, 1.0: 3.0}, - ) - ) - ), - }, - torax_mesh=self._geo.torax_mesh, - )( - t=0.25, - ) - gas_puff_source = dcs.sources[ - electron_density_sources.GasPuffSource.SOURCE_NAME - ] - assert isinstance(gas_puff_source.formula, formula_config.DynamicGaussian) - np.testing.assert_allclose(gas_puff_source.formula.total, 0.25) - np.testing.assert_allclose(gas_puff_source.formula.c1, 0.5) - np.testing.assert_allclose(gas_puff_source.formula.c2, 0.75) - def test_wext_in_dynamic_runtime_params_cannot_be_negative(self): """Tests that wext cannot be negative.""" runtime_params = general_runtime_params.GeneralRuntimeParams() diff --git a/torax/core_profile_setters.py b/torax/core_profile_setters.py index e7ad62d1..028deda0 100644 --- a/torax/core_profile_setters.py +++ b/torax/core_profile_setters.py @@ -280,11 +280,7 @@ def _prescribe_currents_no_bootstrap( # form of external current on face grid generic_current_face = source_models.generic_current_source.get_value( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_generic_current_params, static_runtime_params_slice=static_runtime_params_slice, - static_source_runtime_params=static_runtime_params_slice.sources[ - generic_current_source.GenericCurrentSource.SOURCE_NAME - ], geo=geo, core_profiles=core_profiles, ) @@ -311,7 +307,6 @@ def _prescribe_currents_no_bootstrap( jtot_hires = _get_jtot_hires( static_runtime_params_slice, dynamic_runtime_params_slice, - dynamic_generic_current_params, geo, bootstrap_profile, Iohm, @@ -362,13 +357,7 @@ def _prescribe_currents_with_bootstrap( bootstrap_profile = source_models.j_bootstrap.get_value( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ - source_models.j_bootstrap_name - ], static_runtime_params_slice=static_runtime_params_slice, - static_source_runtime_params=static_runtime_params_slice.sources[ - source_models.j_bootstrap_name - ], geo=geo, core_profiles=core_profiles, ) @@ -388,11 +377,7 @@ def _prescribe_currents_with_bootstrap( # form of external current on face grid generic_current_face = source_models.generic_current_source.get_value( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_generic_current_params, static_runtime_params_slice=static_runtime_params_slice, - static_source_runtime_params=static_runtime_params_slice.sources[ - source_models.generic_current_source_name - ], geo=geo, core_profiles=core_profiles, ) @@ -423,7 +408,6 @@ def _prescribe_currents_with_bootstrap( jtot_hires = _get_jtot_hires( static_runtime_params_slice, dynamic_runtime_params_slice, - dynamic_generic_current_params, geo, bootstrap_profile, Iohm, @@ -477,31 +461,16 @@ def _calculate_currents_from_psi( bootstrap_profile = source_models.j_bootstrap.get_value( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ - source_models.j_bootstrap_name - ], static_runtime_params_slice=static_runtime_params_slice, - static_source_runtime_params=static_runtime_params_slice.sources[ - source_models.j_bootstrap_name - ], geo=geo, core_profiles=core_profiles, ) - # Calculate splitting of currents depending on input runtime params. - dynamic_generic_current_params = get_generic_current_params( - dynamic_runtime_params_slice, source_models - ) - # calculate "External" current profile (e.g. ECCD) # form of external current on face grid generic_current_face = source_models.generic_current_source.get_value( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_generic_current_params, static_runtime_params_slice=static_runtime_params_slice, - static_source_runtime_params=static_runtime_params_slice.sources[ - source_models.generic_current_source_name - ], geo=geo, core_profiles=core_profiles, ) @@ -991,7 +960,6 @@ def compute_boundary_conditions( def _get_jtot_hires( static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, - dynamic_generic_current_params: generic_current_source.DynamicRuntimeParams, geo: geometry.Geometry, bootstrap_profile: source_profiles_lib.BootstrapCurrentProfile, Iohm: jax.Array | float, @@ -1005,11 +973,7 @@ def _get_jtot_hires( # calculate hi-res "External" current profile (e.g. ECCD) on cell grid. generic_current_hires = generic_current.generic_current_source_hires( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_generic_current_params, static_runtime_params_slice=static_runtime_params_slice, - static_source_runtime_params=static_runtime_params_slice.sources[ - generic_current_source.GenericCurrentSource.SOURCE_NAME - ], geo=geo, ) diff --git a/torax/fvm/calc_coeffs.py b/torax/fvm/calc_coeffs.py index a30cf7b8..13ee3e82 100644 --- a/torax/fvm/calc_coeffs.py +++ b/torax/fvm/calc_coeffs.py @@ -728,9 +728,6 @@ def _calc_coeffs_full( qei = source_models.qei_source.get_qei( static_runtime_params_slice=static_runtime_params_slice, dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ - source_models.qei_source_name - ], geo=geo, # For Qei, always use the current set of core profiles. # In the linear solver, core_profiles is the set of profiles at time t (at diff --git a/torax/fvm/tests/fvm.py b/torax/fvm/tests/fvm.py index bb106b0b..1b95de93 100644 --- a/torax/fvm/tests/fvm.py +++ b/torax/fvm/tests/fvm.py @@ -578,7 +578,7 @@ def test_implicit_solve_block_uses_updated_boundary_conditions(self): ) ) geo = geometry.build_circular_geometry(n_rho=num_cells) - source_models = source_models_lib.SourceModels() + source_models = source_models_builder() initial_core_profiles = core_profile_setters.initial_core_profiles( static_runtime_params_slice, dynamic_runtime_params_slice, @@ -729,7 +729,7 @@ def test_theta_residual_uses_updated_boundary_conditions(self): ), ) - source_models = source_models_lib.SourceModels() + source_models = source_models_builder() pedestal_model = set_tped_nped.SetTemperatureDensityPedestalModel() initial_core_profiles = core_profile_setters.initial_core_profiles( static_runtime_params_slice_theta0, diff --git a/torax/sim.py b/torax/sim.py index 81608854..3318fe8e 100644 --- a/torax/sim.py +++ b/torax/sim.py @@ -1349,33 +1349,16 @@ def update_current_distribution( bootstrap_profile = source_models.j_bootstrap.get_value( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ - source_models.j_bootstrap_name - ], static_runtime_params_slice=static_runtime_params_slice, - static_source_runtime_params=static_runtime_params_slice.sources[ - source_models.j_bootstrap_name - ], geo=geo, core_profiles=core_profiles, ) - # Calculate splitting of currents depending on input runtime params. - dynamic_generic_current_params = ( - core_profile_setters.get_generic_current_params( - dynamic_runtime_params_slice, source_models - ) - ) - # calculate "External" current profile (e.g. ECCD) # form of external current on face grid generic_current_face = source_models.generic_current_source.get_value( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_generic_current_params, static_runtime_params_slice=static_runtime_params_slice, - static_source_runtime_params=static_runtime_params_slice.sources[ - source_models.generic_current_source_name - ], geo=geo, core_profiles=core_profiles, ) @@ -1525,9 +1508,6 @@ def get_initial_source_profiles( qei = source_models.qei_source.get_qei( static_runtime_params_slice=static_runtime_params_slice, dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ - source_models.qei_source_name - ], geo=geo, core_profiles=core_profiles, ) diff --git a/torax/sources/bootstrap_current_source.py b/torax/sources/bootstrap_current_source.py index 1f75a8c8..4c41994c 100644 --- a/torax/sources/bootstrap_current_source.py +++ b/torax/sources/bootstrap_current_source.py @@ -91,6 +91,11 @@ class BootstrapCurrentSource(source.Source): """ SOURCE_NAME: ClassVar[str] = 'j_bootstrap' + DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'calc_neoclassical' + + @property + def source_name(self) -> str: + return self.SOURCE_NAME @property def supported_modes(self) -> tuple[runtime_params_lib.Mode, ...]: @@ -111,14 +116,16 @@ def affected_core_profiles(self) -> tuple[source.AffectedCoreProfile, ...]: def get_value( self, dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, - dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, - static_source_runtime_params: runtime_params_lib.StaticRuntimeParams, geo: geometry.Geometry, core_profiles: state.CoreProfiles, ) -> source_profiles.BootstrapCurrentProfile: - # Make sure the input mode requested is supported. - self.check_mode(static_source_runtime_params.mode) + static_source_runtime_params = static_runtime_params_slice.sources[ + self.source_name + ] + dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[ + self.source_name + ] # Make sure the input params are the correct type. if not isinstance(dynamic_source_runtime_params, DynamicRuntimeParams): raise TypeError( @@ -127,7 +134,6 @@ def get_value( ) bootstrap_current = calc_neoclassical( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_source_runtime_params, geo=geo, temp_ion=core_profiles.temp_ion, temp_el=core_profiles.temp_el, @@ -175,7 +181,6 @@ def get_source_profile_for_affected_core_profile( @jax_utils.jit def calc_neoclassical( dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, - dynamic_source_runtime_params: DynamicRuntimeParams, geo: geometry.Geometry, temp_ion: cell_variable.CellVariable, temp_el: cell_variable.CellVariable, @@ -187,7 +192,6 @@ def calc_neoclassical( Args: dynamic_runtime_params_slice: General configuration parameters. - dynamic_source_runtime_params: Source-specific runtime parameters. geo: Torus geometry. temp_ion: Ion temperature. We don't pass in a full `core_profiles` here because this function is used to create the `Currents` in the initial @@ -200,6 +204,10 @@ def calc_neoclassical( Returns: A BootstrapCurrentProfile. See that class's docstring for more info. """ + dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[ + BootstrapCurrentSource.SOURCE_NAME + ] + assert isinstance(dynamic_source_runtime_params, DynamicRuntimeParams) # Many variables throughout this function are capitalized based on physics # notational conventions rather than on Google Python style # pylint: disable=invalid-name diff --git a/torax/sources/bremsstrahlung_heat_sink.py b/torax/sources/bremsstrahlung_heat_sink.py index 312668f1..9033c5b7 100644 --- a/torax/sources/bremsstrahlung_heat_sink.py +++ b/torax/sources/bremsstrahlung_heat_sink.py @@ -122,19 +122,20 @@ def calc_relativistic_correction() -> jax.Array: def bremsstrahlung_model_func( static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, - static_source_runtime_params: runtime_params_lib.StaticRuntimeParams, dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, - dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, + source_name: str, core_profiles: state.CoreProfiles, unused_model_func: source_models.SourceModels | None, ) -> jax.Array: """Model function for the Bremsstrahlung heat sink.""" del ( - static_source_runtime_params, static_runtime_params_slice, unused_model_func, ) # Unused. + dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[ + source_name + ] assert isinstance(dynamic_source_runtime_params, DynamicRuntimeParams) _, P_brem_profile = calc_bremsstrahlung( core_profiles, @@ -152,16 +153,12 @@ class BremsstrahlungHeatSink(source.Source): """Brehmsstrahlung heat sink for electron heat equation.""" SOURCE_NAME: ClassVar[str] = 'bremsstrahlung_heat_sink' + DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'bremsstrahlung_model_func' model_func: source.SourceProfileFunction = bremsstrahlung_model_func @property - def supported_modes(self) -> tuple[runtime_params_lib.Mode, ...]: - """Returns the modes supported by this source.""" - return ( - runtime_params_lib.Mode.ZERO, - runtime_params_lib.Mode.MODEL_BASED, - runtime_params_lib.Mode.PRESCRIBED, - ) + def source_name(self) -> str: + return self.SOURCE_NAME @property def affected_core_profiles(self) -> tuple[source.AffectedCoreProfile, ...]: diff --git a/torax/sources/electron_cyclotron_source.py b/torax/sources/electron_cyclotron_source.py index e5f02023..e404c15d 100644 --- a/torax/sources/electron_cyclotron_source.py +++ b/torax/sources/electron_cyclotron_source.py @@ -103,14 +103,13 @@ class DynamicRuntimeParams(runtime_params_lib.DynamicRuntimeParams): gaussian_ec_total_power: array_typing.ScalarFloat -def _calc_heating_and_current( +def calc_heating_and_current( static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, - static_source_runtime_params: runtime_params_lib.StaticRuntimeParams, dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, - dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, + source_name: str, core_profiles: state.CoreProfiles, - unused_model_func: source_models.SourceModels, + unused_source_models: source_models.SourceModels | None = None, ) -> jax.Array: """Model function for the electron-cyclotron source. @@ -119,11 +118,9 @@ def _calc_heating_and_current( Args: static_runtime_params_slice: Static runtime parameters. - static_source_runtime_params: Static runtime parameters. dynamic_runtime_params_slice: Global runtime parameters - dynamic_source_runtime_params: Specific runtime parameters for the - electron-cyclotron source. geo: Magnetic geometry. + source_name: Name of the source. core_profiles: CoreProfiles component of the state. unused_model_func: (unused) source models used in the simulation. @@ -131,10 +128,12 @@ def _calc_heating_and_current( 2D array of electron cyclotron heating power density and current density. """ del ( - static_source_runtime_params, - unused_model_func, + unused_source_models, static_runtime_params_slice, ) # Unused. + dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[ + source_name + ] # Helps linter understand the type of dynamic_source_runtime_params. assert isinstance(dynamic_source_runtime_params, DynamicRuntimeParams) # Construct the profile @@ -188,7 +187,12 @@ class ElectronCyclotronSource(source.Source): """Electron cyclotron source for the Te and Psi equations.""" SOURCE_NAME: ClassVar[str] = "electron_cyclotron_source" - model_func: source.SourceProfileFunction = _calc_heating_and_current + DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = "calc_heating_and_current" + model_func: source.SourceProfileFunction = calc_heating_and_current + + @property + def source_name(self) -> str: + return self.SOURCE_NAME @property def supported_modes(self) -> tuple[runtime_params_lib.Mode, ...]: diff --git a/torax/sources/electron_density_sources.py b/torax/sources/electron_density_sources.py index 9038ed23..e03652e2 100644 --- a/torax/sources/electron_density_sources.py +++ b/torax/sources/electron_density_sources.py @@ -41,7 +41,7 @@ class GasPuffRuntimeParams(runtime_params_lib.RuntimeParams): puff_decay_length: runtime_params_lib.TimeInterpolatedInput = 0.05 # total gas puff particles/s S_puff_tot: runtime_params_lib.TimeInterpolatedInput = 1e22 - mode: runtime_params_lib.Mode = runtime_params_lib.Mode.FORMULA_BASED + mode: runtime_params_lib.Mode = runtime_params_lib.Mode.MODEL_BASED def make_provider( self, @@ -76,21 +76,22 @@ class DynamicGasPuffRuntimeParams(runtime_params_lib.DynamicRuntimeParams): # Default formula: exponential with nref normalization. -def _calc_puff_source( +def calc_puff_source( static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, - static_source_runtime_params: runtime_params_lib.StaticRuntimeParams, dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, - dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, + source_name: str, unused_state: state.CoreProfiles | None = None, unused_source_models: source_models.SourceModels | None = None, ) -> jax.Array: """Calculates external source term for n from puffs.""" del ( - static_source_runtime_params, unused_source_models, static_runtime_params_slice, ) # Unused. + dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[ + source_name + ] assert isinstance(dynamic_source_runtime_params, DynamicGasPuffRuntimeParams) return formulas.exponential_profile( c1=1.0, @@ -108,7 +109,12 @@ class GasPuffSource(source.Source): """Gas puff source for the ne equation.""" SOURCE_NAME: ClassVar[str] = 'gas_puff_source' - formula: source.SourceProfileFunction = _calc_puff_source + DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'calc_puff_source' + model_func: source.SourceProfileFunction = calc_puff_source + + @property + def source_name(self) -> str: + return self.SOURCE_NAME @property def affected_core_profiles(self) -> tuple[source.AffectedCoreProfile, ...]: @@ -125,7 +131,7 @@ class GenericParticleSourceRuntimeParams(runtime_params_lib.RuntimeParams): deposition_location: runtime_params_lib.TimeInterpolatedInput = 0.0 # total particle source S_tot: runtime_params_lib.TimeInterpolatedInput = 1e22 - mode: runtime_params_lib.Mode = runtime_params_lib.Mode.FORMULA_BASED + mode: runtime_params_lib.Mode = runtime_params_lib.Mode.MODEL_BASED def make_provider( self, @@ -155,7 +161,6 @@ def build_dynamic_params( particle_width=float(self.particle_width.get_value(t)), deposition_location=float(self.deposition_location.get_value(t)), S_tot=float(self.S_tot.get_value(t)), - formula=self.formula.build_dynamic_params(t), prescribed_values=self.prescribed_values.get_value(t), ) @@ -167,21 +172,22 @@ class DynamicParticleRuntimeParams(runtime_params_lib.DynamicRuntimeParams): S_tot: array_typing.ScalarFloat -def _calc_generic_particle_source( +def calc_generic_particle_source( static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, - static_source_runtime_params: runtime_params_lib.StaticRuntimeParams, dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, - dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, + source_name: str, unused_state: state.CoreProfiles | None = None, unused_source_models: source_models.SourceModels | None = None, ) -> jax.Array: """Calculates external source term for n from SBI.""" del ( - static_source_runtime_params, unused_source_models, static_runtime_params_slice, ) # Unused. + dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[ + source_name + ] assert isinstance(dynamic_source_runtime_params, DynamicParticleRuntimeParams) return formulas.gaussian_profile( c1=dynamic_source_runtime_params.deposition_location, @@ -199,7 +205,12 @@ class GenericParticleSource(source.Source): """Neutral-beam injection source for the ne equation.""" SOURCE_NAME: ClassVar[str] = 'generic_particle_source' - formula: source.SourceProfileFunction = _calc_generic_particle_source + DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'calc_generic_particle_source' + model_func: source.SourceProfileFunction = calc_generic_particle_source + + @property + def source_name(self) -> str: + return self.SOURCE_NAME @property def affected_core_profiles(self) -> tuple[source.AffectedCoreProfile, ...]: @@ -218,7 +229,7 @@ class PelletRuntimeParams(runtime_params_lib.RuntimeParams): pellet_deposition_location: runtime_params_lib.TimeInterpolatedInput = 0.85 # total pellet particles/s (continuous pellet model) S_pellet_tot: runtime_params_lib.TimeInterpolatedInput = 2e22 - mode: runtime_params_lib.Mode = runtime_params_lib.Mode.FORMULA_BASED + mode: runtime_params_lib.Mode = runtime_params_lib.Mode.MODEL_BASED def make_provider( self, @@ -250,21 +261,22 @@ class DynamicPelletRuntimeParams(runtime_params_lib.DynamicRuntimeParams): S_pellet_tot: array_typing.ScalarFloat -def _calc_pellet_source( +def calc_pellet_source( static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, - static_source_runtime_params: runtime_params_lib.StaticRuntimeParams, dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, - dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, + source_name: str, unused_state: state.CoreProfiles | None = None, unused_source_models: source_models.SourceModels | None = None, ) -> jax.Array: """Calculates external source term for n from pellets.""" del ( - static_source_runtime_params, unused_source_models, static_runtime_params_slice, ) # Unused. + dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[ + source_name + ] assert isinstance(dynamic_source_runtime_params, DynamicPelletRuntimeParams) return formulas.gaussian_profile( c1=dynamic_source_runtime_params.pellet_deposition_location, @@ -282,23 +294,13 @@ class PelletSource(source.Source): """Pellet source for the ne equation.""" SOURCE_NAME: ClassVar[str] = 'pellet_source' - formula: source.SourceProfileFunction = _calc_pellet_source + DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'calc_pellet_source' + model_func: source.SourceProfileFunction = calc_pellet_source + + @property + def source_name(self) -> str: + return self.SOURCE_NAME @property def affected_core_profiles(self) -> tuple[source.AffectedCoreProfile, ...]: return (source.AffectedCoreProfile.NE,) - - -# pylint: enable=invalid-name -# The sources below don't have any source-specific implementations, so their -# bodies are empty. You can refer to their base class to see the implementation. -# We define new classes here to: -# a) support any future source-specific implementation. -# b) better readability and human-friendly error messages when debugging. -@dataclasses.dataclass(kw_only=True, frozen=True, eq=True) -class RecombinationDensitySink(source.Source): - """Recombination sink for the electron density equation.""" - - affected_core_profiles: tuple[source.AffectedCoreProfile, ...] = ( - source.AffectedCoreProfile.NE, - ) diff --git a/torax/sources/formula_config.py b/torax/sources/formula_config.py deleted file mode 100644 index 589ef2e6..00000000 --- a/torax/sources/formula_config.py +++ /dev/null @@ -1,143 +0,0 @@ -# Copyright 2024 DeepMind Technologies Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Defines the runtime (dynamic) configuration of formulas used in sources.""" - -from __future__ import annotations - -import dataclasses -from typing import TypeAlias - -import chex -from torax import array_typing -from torax import interpolated_param -from torax.config import base -from torax.config import config_args -from torax.geometry import geometry - - -TimeInterpolatedInput: TypeAlias = interpolated_param.TimeInterpolatedInput - - -@chex.dataclass(frozen=True) -class DynamicFormula: - """Base class for dynamic configs.""" - - -@dataclasses.dataclass -class Exponential(base.RuntimeParametersConfig['ExponentialProvider']): - """Configures an exponential formula. - - See formulas.Exponential for more information on how this config is used. - """ - - # floats to parameterize the different formulas. - total: TimeInterpolatedInput = 1.0 - c1: TimeInterpolatedInput = 1.0 - c2: TimeInterpolatedInput = 1.0 - - def make_provider( - self, torax_mesh: geometry.Grid1D | None = None - ) -> ExponentialProvider: - del torax_mesh # Unused. - return ExponentialProvider( - runtime_params_config=self, - total=config_args.get_interpolated_var_single_axis( - self.total, - ), - c1=config_args.get_interpolated_var_single_axis( - self.c1, - ), - c2=config_args.get_interpolated_var_single_axis( - self.c2, - ), - ) - - -@chex.dataclass -class ExponentialProvider(base.RuntimeParametersProvider['DynamicExponential']): - """Runtime parameter provider for a single source/sink term.""" - - runtime_params_config: Exponential - total: interpolated_param.InterpolatedVarSingleAxis - c1: interpolated_param.InterpolatedVarSingleAxis - c2: interpolated_param.InterpolatedVarSingleAxis - - def build_dynamic_params( - self, - t: chex.Numeric, - ) -> DynamicExponential: - return DynamicExponential(**self.get_dynamic_params_kwargs(t)) - - -@chex.dataclass(frozen=True) -class DynamicExponential(DynamicFormula): - - total: array_typing.ScalarFloat - c1: array_typing.ScalarFloat - c2: array_typing.ScalarFloat - - -@dataclasses.dataclass -class Gaussian(base.RuntimeParametersConfig['GaussianProvider']): - """Configures a Gaussian formula. - - See formulas.Gaussian for more information on how this config is used. - """ - - # floats to parameterize the different formulas. - total: TimeInterpolatedInput = 1.0 - c1: TimeInterpolatedInput = 1.0 - c2: TimeInterpolatedInput = 1.0 - - def make_provider( - self, torax_mesh: geometry.Grid1D | None = None - ) -> GaussianProvider: - del torax_mesh # Unused. - return GaussianProvider( - runtime_params_config=self, - total=config_args.get_interpolated_var_single_axis( - self.total, - ), - c1=config_args.get_interpolated_var_single_axis( - self.c1, - ), - c2=config_args.get_interpolated_var_single_axis( - self.c2, - ), - ) - - -@chex.dataclass -class GaussianProvider(base.RuntimeParametersProvider['DynamicGaussian']): - """Runtime parameter provider for a single source/sink term.""" - - runtime_params_config: Gaussian - total: interpolated_param.InterpolatedVarSingleAxis - c1: interpolated_param.InterpolatedVarSingleAxis - c2: interpolated_param.InterpolatedVarSingleAxis - - def build_dynamic_params( - self, - t: chex.Numeric, - ) -> DynamicGaussian: - return DynamicGaussian(**self.get_dynamic_params_kwargs(t)) - - -@chex.dataclass(frozen=True) -class DynamicGaussian(DynamicFormula): - - total: array_typing.ScalarFloat - c1: array_typing.ScalarFloat - c2: array_typing.ScalarFloat diff --git a/torax/sources/formulas.py b/torax/sources/formulas.py index 9ff681f6..d3c08c0e 100644 --- a/torax/sources/formulas.py +++ b/torax/sources/formulas.py @@ -13,18 +13,9 @@ # limitations under the License. """Prescribed formulas for computing source profiles.""" - -import dataclasses -from typing import Optional import jax from jax import numpy as jnp -from torax import state -from torax.config import runtime_params_slice from torax.geometry import geometry -from torax.sources import formula_config -from torax.sources import runtime_params - - # Many variables throughout this function are capitalized based on physics # notational conventions rather than on Google Python style # pylint: disable=invalid-name @@ -103,71 +94,3 @@ def gaussian_profile( geo.vpr_face * S_face, geo.rho_face_norm ) return C * S - - -# pylint: enable=invalid-name - - -# Callable classes used as arguments for Source formulas. - - -@dataclasses.dataclass(frozen=True) -class Exponential: - """Callable class providing an exponential profile.""" - - def __call__( # pytype: disable=name-error - self, - static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, - static_source_runtime_params: runtime_params.StaticRuntimeParams, - dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, - dynamic_source_runtime_params: runtime_params.DynamicRuntimeParams, - geo: geometry.Geometry, - unused_state: state.CoreProfiles | None, - unused_source_models: Optional['source_models.SourceModels'] = None, - ) -> jax.Array: - del ( - dynamic_runtime_params_slice, - static_runtime_params_slice, - static_source_runtime_params, - unused_state, - unused_source_models, - ) # Unused. - exp_config = dynamic_source_runtime_params.formula - assert isinstance(exp_config, formula_config.DynamicExponential) - return exponential_profile( - c1=exp_config.c1, - c2=exp_config.c2, - total=exp_config.total, - geo=geo, - ) - - -@dataclasses.dataclass(frozen=True) -class Gaussian: - """Callable class providing a gaussian profile.""" - - def __call__( # pytype: disable=name-error - self, - static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, - static_source_runtime_params: runtime_params.StaticRuntimeParams, - dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, - dynamic_source_runtime_params: runtime_params.DynamicRuntimeParams, - geo: geometry.Geometry, - unused_state: state.CoreProfiles | None, - unused_source_models: Optional['source_models.SourceModels'] = None, - ) -> jax.Array: - del ( - dynamic_runtime_params_slice, - static_runtime_params_slice, - static_source_runtime_params, - unused_state, - unused_source_models, - ) # Unused. - gaussian_config = dynamic_source_runtime_params.formula - assert isinstance(gaussian_config, formula_config.DynamicGaussian) - return gaussian_profile( - c1=gaussian_config.c1, - c2=gaussian_config.c2, - total=gaussian_config.total, - geo=geo, - ) diff --git a/torax/sources/fusion_heat_source.py b/torax/sources/fusion_heat_source.py index 9e395874..ce6f4f34 100644 --- a/torax/sources/fusion_heat_source.py +++ b/torax/sources/fusion_heat_source.py @@ -121,20 +121,15 @@ def calc_fusion( # pytype: disable=name-error def fusion_heat_model_func( static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, - static_source_runtime_params: runtime_params_lib.StaticRuntimeParams, dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, - dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, + source_name: str, core_profiles: state.CoreProfiles, unused_source_models: Optional['source_models.SourceModels'], ) -> jax.Array: """Model function for fusion heating.""" # pytype: enable=name-error - del ( - dynamic_source_runtime_params, - static_source_runtime_params, - static_runtime_params_slice, - ) # Unused. + del static_runtime_params_slice, source_name # Unused. # pylint: disable=invalid-name _, Pfus_i, Pfus_e = calc_fusion( geo, core_profiles, dynamic_runtime_params_slice.numerics.nref @@ -148,8 +143,13 @@ class FusionHeatSource(source.Source): """Fusion heat source for both ion and electron heat.""" SOURCE_NAME: ClassVar[str] = 'fusion_heat_source' + DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'fusion_heat_model_func' model_func: source.SourceProfileFunction = fusion_heat_model_func + @property + def source_name(self) -> str: + return self.SOURCE_NAME + @property def supported_modes(self) -> tuple[runtime_params_lib.Mode, ...]: return ( diff --git a/torax/sources/generic_current_source.py b/torax/sources/generic_current_source.py index bef21640..a578c734 100644 --- a/torax/sources/generic_current_source.py +++ b/torax/sources/generic_current_source.py @@ -52,7 +52,7 @@ class RuntimeParams(runtime_params_lib.RuntimeParams): # Toggles if external current is provided absolutely or as a fraction of Ip. use_absolute_current: bool = False - mode: runtime_params_lib.Mode = runtime_params_lib.Mode.FORMULA_BASED + mode: runtime_params_lib.Mode = runtime_params_lib.Mode.MODEL_BASED @property def grid_type(self) -> base.GridType: @@ -105,12 +105,11 @@ def __post_init__(self): # pytype bug: does not treat 'source_models.SourceModels' as a forward reference # pytype: disable=name-error -def _calculate_generic_current_face( +def calculate_generic_current_face( static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, - static_source_runtime_params: runtime_params_lib.StaticRuntimeParams, dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, - dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, + source_name: str, unused_state: state.CoreProfiles | None = None, unused_source_models: Optional['source_models.SourceModels'] = None, ) -> jax.Array: @@ -118,23 +117,25 @@ def _calculate_generic_current_face( Args: static_runtime_params_slice: Static runtime parameters. - static_source_runtime_params: Static runtime parameters. dynamic_runtime_params_slice: Parameter configuration at present timestep. - dynamic_source_runtime_params: Source-specific parameters at the present - timestep. geo: Tokamak geometry. + source_name: Name of the source. unused_state: State argument not used in this function but is present to adhere to the source API. + unused_source_models: Source models argument not used in this function but + is present to adhere to the source API. Returns: External current density profile along the face grid. """ del ( - static_source_runtime_params, static_runtime_params_slice, unused_state, unused_source_models, ) # Unused. + dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[ + source_name + ] # pytype: enable=name-error assert isinstance(dynamic_source_runtime_params, DynamicRuntimeParams) Iext = _calculate_Iext( @@ -163,10 +164,9 @@ def _calculate_generic_current_face( # pytype: disable=name-error def _calculate_generic_current_hires( static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, - static_source_runtime_params: runtime_params_lib.StaticRuntimeParams, dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, - dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, + source_name: str, unused_state: state.CoreProfiles | None = None, unused_source_models: Optional['source_models.SourceModels'] = None, ) -> jax.Array: @@ -174,23 +174,25 @@ def _calculate_generic_current_hires( Args: static_runtime_params_slice: Static runtime parameters. - static_source_runtime_params: Static runtime parameters. dynamic_runtime_params_slice: Parameter configuration at present timestep. - dynamic_source_runtime_params: Source-specific parameters at the present - timestep. geo: Tokamak geometry. + source_name: Name of the source. unused_state: State argument not used in this function but is present to adhere to the source API. + unused_source_models: Source models argument not used in this function but + is present to adhere to the source API. Returns: Generic current density profile along the hires cell grid. """ del ( - static_source_runtime_params, static_runtime_params_slice, unused_state, unused_source_models, ) # Unused. + dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[ + source_name + ] # pytype: enable=name-error assert isinstance(dynamic_source_runtime_params, DynamicRuntimeParams) Iext = _calculate_Iext( @@ -233,8 +235,13 @@ class GenericCurrentSource(source.Source): """A generic current density source profile.""" SOURCE_NAME: ClassVar[str] = 'generic_current_source' - formula: source.SourceProfileFunction = _calculate_generic_current_face + DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'calc_generic_current_face' hires_formula: source.SourceProfileFunction = _calculate_generic_current_hires + model_func: source.SourceProfileFunction = calculate_generic_current_face + + @property + def source_name(self) -> str: + return self.SOURCE_NAME @property def affected_core_profiles(self) -> tuple[source.AffectedCoreProfile, ...]: @@ -261,10 +268,8 @@ def get_source_profile_for_affected_core_profile( # pytype: disable=name-error def generic_current_source_hires( self, - dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, - dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, - static_source_runtime_params: runtime_params_lib.StaticRuntimeParams, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, unused_state: state.CoreProfiles | None = None, unused_source_models: Optional['source_models.SourceModels'] = None, @@ -272,8 +277,13 @@ def generic_current_source_hires( # pytype: enable=name-error """Return the current density profile along the hires cell grid.""" del unused_state, unused_source_models # Unused. + dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[ + self.source_name + ] + static_source_runtime_params = static_runtime_params_slice.sources[ + self.source_name + ] assert isinstance(dynamic_source_runtime_params, DynamicRuntimeParams) - self.check_mode(static_source_runtime_params.mode) # Interpolate prescribed values onto the hires grid hires_prescribed_values = jnp.where( @@ -289,19 +299,13 @@ def generic_current_source_hires( return source.get_source_profiles( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_source_runtime_params, static_runtime_params_slice=static_runtime_params_slice, - static_source_runtime_params=static_source_runtime_params, geo=geo, core_profiles=None, # There is no model for this source. - model_func=( - lambda _0, _1, _2, _3, _4, _5, _6: jnp.zeros_like( - geo.rho_hires_norm - ) - ), - formula=self.hires_formula, + model_func=self.hires_formula, output_shape=geo.rho_hires_norm.shape, prescribed_values=hires_prescribed_values, source_models=getattr(self, 'source_models', None), + source_name=self.source_name, ) diff --git a/torax/sources/generic_ion_el_heat_source.py b/torax/sources/generic_ion_el_heat_source.py index 3dd761fd..f7eca7ba 100644 --- a/torax/sources/generic_ion_el_heat_source.py +++ b/torax/sources/generic_ion_el_heat_source.py @@ -48,7 +48,7 @@ class RuntimeParams(runtime_params_lib.RuntimeParams): Ptot: runtime_params_lib.TimeInterpolatedInput = 120e6 # electron heating fraction el_heat_fraction: runtime_params_lib.TimeInterpolatedInput = 0.66666 - mode: runtime_params_lib.Mode = runtime_params_lib.Mode.FORMULA_BASED + mode: runtime_params_lib.Mode = runtime_params_lib.Mode.MODEL_BASED def make_provider( self, @@ -113,24 +113,24 @@ def calc_generic_heat_source( # pytype: disable=name-error -def _default_formula( +def default_formula( static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, - static_source_runtime_params: runtime_params_lib.StaticRuntimeParams, dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, - dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, + source_name: str, core_profiles: state.CoreProfiles, unused_source_models: Optional['source_models.SourceModels'], ) -> jax.Array: """Returns the default formula-based ion/electron heat source profile.""" # pytype: enable=name-error del ( - dynamic_runtime_params_slice, core_profiles, - static_source_runtime_params, static_runtime_params_slice, unused_source_models, ) # Unused. + dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[ + source_name + ] assert isinstance(dynamic_source_runtime_params, DynamicRuntimeParams) ion, el = calc_generic_heat_source( geo, @@ -150,7 +150,12 @@ class GenericIonElectronHeatSource(source.Source): """Generic heat source for both ion and electron heat.""" SOURCE_NAME: ClassVar[str] = 'generic_ion_el_heat_source' - formula: source.SourceProfileFunction = _default_formula + DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'default_formula' + model_func: source.SourceProfileFunction = default_formula + + @property + def source_name(self) -> str: + return self.SOURCE_NAME @property def affected_core_profiles(self) -> tuple[source.AffectedCoreProfile, ...]: diff --git a/torax/sources/impurity_radiation_heat_sink.py b/torax/sources/impurity_radiation_heat_sink.py index 4e9bff90..f2587c74 100644 --- a/torax/sources/impurity_radiation_heat_sink.py +++ b/torax/sources/impurity_radiation_heat_sink.py @@ -16,6 +16,7 @@ """Basic impurity radiation heat sink for electron heat equation..""" import dataclasses +from typing import ClassVar import chex import jax @@ -30,12 +31,11 @@ from torax.sources import source_models as source_models_lib -def _radially_constant_fraction_of_Pin( # pylint: disable=invalid-name +def radially_constant_fraction_of_Pin( # pylint: disable=invalid-name static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, - static_source_runtime_params: runtime_params_lib.StaticRuntimeParams, dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, - dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, + source_name: str, core_profiles: state.CoreProfiles, source_models: source_models_lib.SourceModels, ) -> jax.Array: @@ -46,37 +46,30 @@ def _radially_constant_fraction_of_Pin( # pylint: disable=invalid-name Args: static_runtime_params_slice: Static runtime parameters. - static_source_runtime_params: Static source runtime parameters. dynamic_runtime_params_slice: Dynamic runtime parameters. - dynamic_source_runtime_params: Dynamic source runtime parameters. geo: Geometry object. + source_name: Name of the source. core_profiles: Core profiles object. source_models: Source models object. Returns: The heat sink profile. """ - del (static_source_runtime_params,) # Unused + dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[ + source_name + ] assert isinstance(dynamic_source_runtime_params, DynamicRuntimeParams) # Based on source_models.sum_sources_temp_el and source_models.calc_and_sum # sources_psi, but only summing over heating *input* sources # (Pohm + Paux + Palpha + ...) and summing over *both* ion + electron heating - def get_heat_source_profile( - source_name: str, source: source_lib.Source - ) -> jax.Array: + def get_heat_source_profile(source: source_lib.Source) -> jax.Array: # TODO(b/381543891): Currently this recomputes the profile for each source, # which is inefficient. Refactor to avoid this. profile = source.get_value( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ - source_name - ], static_runtime_params_slice=static_runtime_params_slice, - static_source_runtime_params=static_runtime_params_slice.sources[ - source_name - ], geo=geo, core_profiles=core_profiles, ) @@ -95,7 +88,6 @@ def get_heat_source_profile( } source_profiles = jax.tree.map( get_heat_source_profile, - list(heat_sources.keys()), list(heat_sources.values()), ) Qtot_in = jnp.sum(jnp.stack(source_profiles), axis=0) @@ -148,19 +140,17 @@ class ImpurityRadiationHeatSink(source_lib.Source): """Impurity radiation heat sink for electron heat equation.""" SOURCE_NAME = "impurity_radiation_heat_sink" - source_models: source_models_lib.SourceModels + DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = ( + "radially_constant_fraction_of_Pin" + ) model_func: source_lib.SourceProfileFunction = ( - _radially_constant_fraction_of_Pin + radially_constant_fraction_of_Pin ) + source_models: source_models_lib.SourceModels @property - def supported_modes(self) -> tuple[runtime_params_lib.Mode, ...]: - """Returns the modes supported by this source.""" - return ( - runtime_params_lib.Mode.ZERO, - runtime_params_lib.Mode.MODEL_BASED, - runtime_params_lib.Mode.PRESCRIBED, - ) + def source_name(self) -> str: + return self.SOURCE_NAME @property def affected_core_profiles( diff --git a/torax/sources/ion_cyclotron_source.py b/torax/sources/ion_cyclotron_source.py index 13e92f4f..1439f198 100644 --- a/torax/sources/ion_cyclotron_source.py +++ b/torax/sources/ion_cyclotron_source.py @@ -364,12 +364,11 @@ def _helium3_tail_temperature( return core_profiles.temp_el.value * (1 + epsilon) -def _icrh_model_func( +def icrh_model_func( static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, - static_source_runtime_params: runtime_params_lib.StaticRuntimeParams, dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, - dynamic_source_runtime_params: DynamicRuntimeParams, geo: geometry.Geometry, + source_name: str, core_profiles: state.CoreProfiles, unused_source_models: source_models.SourceModels | None, toric_nn: ToricNNWrapper, @@ -377,10 +376,12 @@ def _icrh_model_func( """Compute ion/electron heat source terms.""" del ( unused_source_models, - dynamic_runtime_params_slice, - static_source_runtime_params, static_runtime_params_slice, ) # Unused. + dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[ + source_name + ] + assert isinstance(dynamic_source_runtime_params, DynamicRuntimeParams) # Construct inputs for ToricNN. volume = integrate.trapezoid(geo.vpr_face, geo.rho_face_norm) @@ -484,8 +485,6 @@ def _icrh_model_func( source_ion += power_deposition_2T * dynamic_source_runtime_params.Ptot return jnp.stack([source_ion, source_el]) - - # pylint: enable=invalid-name @@ -494,26 +493,11 @@ class IonCyclotronSource(source.Source): """Ion cyclotron source with surrogate model.""" SOURCE_NAME: ClassVar[str] = 'ion_cyclotron_source' - # The model function is fixed to _icrh_model_func because that is the only - # supported implementation of this source. - # However, since this is a param in the parent dataclass, we need to (a) - # remove the parameter from the init args and (b) set the default to the - # desired value. - model_func: source.SourceProfileFunction = dataclasses.field( - init=False, - default_factory=lambda: functools.partial( - _icrh_model_func, - toric_nn=ToricNNWrapper(), - ), - ) + DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'icrh_model_func' @property - def supported_modes(self) -> tuple[runtime_params_lib.Mode, ...]: - return ( - runtime_params_lib.Mode.ZERO, - runtime_params_lib.Mode.MODEL_BASED, - runtime_params_lib.Mode.PRESCRIBED, - ) + def source_name(self) -> str: + return self.SOURCE_NAME @property def affected_core_profiles(self) -> tuple[source.AffectedCoreProfile, ...]: @@ -525,3 +509,29 @@ def affected_core_profiles(self) -> tuple[source.AffectedCoreProfile, ...]: @property def output_shape_getter(self) -> source.SourceOutputShapeFunction: return source.get_ion_el_output_shape + + +@dataclasses.dataclass(kw_only=True, frozen=False) +class IonCyclotronSourceBuilder: + """Builder for the IonCyclotronSource.""" + + runtime_params: RuntimeParams = dataclasses.field( + default_factory=RuntimeParams + ) + links_back: bool = False + model_func: source.SourceProfileFunction | None = None + + def __post_init__(self): + if self.model_func is None: + self.model_func = functools.partial( + icrh_model_func, + toric_nn=ToricNNWrapper(), + ) + + def __call__( + self, + ) -> IonCyclotronSource: + + return IonCyclotronSource( + model_func=self.model_func, + ) diff --git a/torax/sources/ohmic_heat_source.py b/torax/sources/ohmic_heat_source.py index 7a31037f..a364c0f3 100644 --- a/torax/sources/ohmic_heat_source.py +++ b/torax/sources/ohmic_heat_source.py @@ -145,16 +145,14 @@ def calc_psidot( def ohmic_model_func( static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, - static_source_runtime_params: runtime_params_lib.StaticRuntimeParams, dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, - dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, + source_name: str, core_profiles: state.CoreProfiles, - source_models: source_models_lib.SourceModels | None = None, + source_models: source_models_lib.SourceModels, ) -> jax.Array: """Returns the Ohmic source for electron heat equation.""" - del dynamic_source_runtime_params, static_source_runtime_params - + del source_name # Unused. if source_models is None: raise TypeError('source_models is a required argument for ohmic_model_func') @@ -192,25 +190,14 @@ class OhmicHeatSource(source_lib.Source): """ SOURCE_NAME: ClassVar[str] = 'ohmic_heat_source' + DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'ohmic_model_func' + model_func: source_lib.SourceProfileFunction = ohmic_model_func # Users must pass in a pointer to the complete set of sources to this object. source_models: source_models_lib.SourceModels - # The model function is fixed to ohmic_model_func because that is the only - # supported implementation of this source. - # However, since this is a param in the parent dataclass, we need to (a) - # remove the parameter from the init args and (b) set the default to the - # desired value. - model_func: source_lib.SourceProfileFunction | None = dataclasses.field( - init=False, - default_factory=lambda: ohmic_model_func, - ) @property - def supported_modes(self) -> tuple[runtime_params_lib.Mode, ...]: - return ( - runtime_params_lib.Mode.ZERO, - runtime_params_lib.Mode.MODEL_BASED, - runtime_params_lib.Mode.PRESCRIBED, - ) + def source_name(self) -> str: + return self.SOURCE_NAME @property def affected_core_profiles( diff --git a/torax/sources/qei_source.py b/torax/sources/qei_source.py index 26ada6dd..3b002bd2 100644 --- a/torax/sources/qei_source.py +++ b/torax/sources/qei_source.py @@ -72,14 +72,11 @@ class QeiSource(source.Source): """ SOURCE_NAME: ClassVar[str] = 'qei_source' + DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'model_based_qei' @property - def supported_modes(self) -> tuple[runtime_params_lib.Mode, ...]: - return ( - runtime_params_lib.Mode.ZERO, - runtime_params_lib.Mode.MODEL_BASED, - runtime_params_lib.Mode.PRESCRIBED, - ) + def source_name(self) -> str: + return self.SOURCE_NAME @property def affected_core_profiles(self) -> tuple[source.AffectedCoreProfile, ...]: @@ -92,14 +89,15 @@ def get_qei( self, static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, - dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, core_profiles: state.CoreProfiles, ) -> source_profiles.QeiInfo: """Computes the value of the source.""" - self.check_mode(static_runtime_params_slice.sources[self.SOURCE_NAME].mode) + dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[ + self.source_name + ] return jax.lax.cond( - static_runtime_params_slice.sources[self.SOURCE_NAME].mode + static_runtime_params_slice.sources[self.source_name].mode == runtime_params_lib.Mode.MODEL_BASED.value, lambda: _model_based_qei( static_runtime_params_slice, diff --git a/torax/sources/register_source.py b/torax/sources/register_source.py index c705a204..f8cf784f 100644 --- a/torax/sources/register_source.py +++ b/torax/sources/register_source.py @@ -35,6 +35,7 @@ class to build, the runtime associated with that source and (optionally) the expected to grow over time as TORAX becomes more feature rich but ultimately be finite. """ + import dataclasses from typing import Type @@ -54,116 +55,149 @@ class to build, the runtime associated with that source and (optionally) the @dataclasses.dataclass(frozen=True) -class RegisteredSource: - source_class: Type[source.Source] - source_builder_class: source.SourceBuilderProtocol - default_runtime_params_class: Type[runtime_params.RuntimeParams] - - -def _register_new_source( - source_class: Type[source.Source], - default_runtime_params_class: Type[runtime_params.RuntimeParams], - source_builder_class: source.SourceBuilderProtocol | None = None, - links_back: bool = False, -) -> RegisteredSource: - """Register source class, default runtime params and (optional) builder for this source. - - Args: - source_class: The source class. - default_runtime_params_class: The default runtime params class. - source_builder_class: The source builder class. If None, a default builder - is created which uses the source class and default runtime params class to - construct a builder for that source. - links_back: Whether the source requires a reference to all the source - models. - - Returns: - A `RegisteredSource` dataclass containing the source class, source - builder class, and default runtime params class. - """ - if source_builder_class is None: - builder_class = source.make_source_builder( - source_class, - runtime_params_type=default_runtime_params_class, - links_back=links_back, - ) - else: - builder_class = source_builder_class - - return RegisteredSource( - source_class=source_class, - source_builder_class=builder_class, - default_runtime_params_class=default_runtime_params_class, - ) +class ModelFunction: + source_profile_function: source.SourceProfileFunction | None + runtime_params_class: Type[runtime_params.RuntimeParams] + source_builder_class: source.SourceBuilderProtocol | None = None + links_back: bool = False -_REGISTERED_SOURCES = { - bootstrap_current_source.BootstrapCurrentSource.SOURCE_NAME: ( - _register_new_source( - source_class=bootstrap_current_source.BootstrapCurrentSource, - default_runtime_params_class=bootstrap_current_source.RuntimeParams, - ) +@dataclasses.dataclass(frozen=True) +class SupportedSource: + """Source that can be used in TORAX and any associated model functions.""" + source_class: Type[source.Source] + model_functions: dict[str, ModelFunction] + + +_SUPPORTED_SOURCES = { + bootstrap_current_source.BootstrapCurrentSource.SOURCE_NAME: SupportedSource( + source_class=bootstrap_current_source.BootstrapCurrentSource, + model_functions={ + bootstrap_current_source.BootstrapCurrentSource.DEFAULT_MODEL_FUNCTION_NAME: ModelFunction( + source_profile_function=None, + runtime_params_class=bootstrap_current_source.RuntimeParams, + ) + }, ), - generic_current_source.GenericCurrentSource.SOURCE_NAME: ( - _register_new_source( - source_class=generic_current_source.GenericCurrentSource, - default_runtime_params_class=generic_current_source.RuntimeParams, - ) + generic_current_source.GenericCurrentSource.SOURCE_NAME: SupportedSource( + source_class=generic_current_source.GenericCurrentSource, + model_functions={ + generic_current_source.GenericCurrentSource.DEFAULT_MODEL_FUNCTION_NAME: ModelFunction( + source_profile_function=generic_current_source.calculate_generic_current_face, + runtime_params_class=generic_current_source.RuntimeParams, + ) + }, ), - electron_cyclotron_source.ElectronCyclotronSource.SOURCE_NAME: _register_new_source( + electron_cyclotron_source.ElectronCyclotronSource.SOURCE_NAME: SupportedSource( source_class=electron_cyclotron_source.ElectronCyclotronSource, - default_runtime_params_class=electron_cyclotron_source.RuntimeParams, + model_functions={ + electron_cyclotron_source.ElectronCyclotronSource.DEFAULT_MODEL_FUNCTION_NAME: ModelFunction( + source_profile_function=electron_cyclotron_source.calc_heating_and_current, + runtime_params_class=electron_cyclotron_source.RuntimeParams, + ) + }, ), - electron_density_sources.GenericParticleSource.SOURCE_NAME: _register_new_source( + electron_density_sources.GenericParticleSource.SOURCE_NAME: SupportedSource( source_class=electron_density_sources.GenericParticleSource, - default_runtime_params_class=electron_density_sources.GenericParticleSourceRuntimeParams, + model_functions={ + electron_density_sources.GenericParticleSource.DEFAULT_MODEL_FUNCTION_NAME: ModelFunction( + source_profile_function=electron_density_sources.calc_generic_particle_source, + runtime_params_class=electron_density_sources.GenericParticleSourceRuntimeParams, + ) + }, ), - electron_density_sources.GasPuffSource.SOURCE_NAME: _register_new_source( + electron_density_sources.GasPuffSource.SOURCE_NAME: SupportedSource( source_class=electron_density_sources.GasPuffSource, - default_runtime_params_class=electron_density_sources.GasPuffRuntimeParams, + model_functions={ + electron_density_sources.GasPuffSource.DEFAULT_MODEL_FUNCTION_NAME: ModelFunction( + source_profile_function=electron_density_sources.calc_puff_source, + runtime_params_class=electron_density_sources.GasPuffRuntimeParams, + ) + }, ), - electron_density_sources.PelletSource.SOURCE_NAME: _register_new_source( + electron_density_sources.PelletSource.SOURCE_NAME: SupportedSource( source_class=electron_density_sources.PelletSource, - default_runtime_params_class=electron_density_sources.PelletRuntimeParams, + model_functions={ + electron_density_sources.PelletSource.DEFAULT_MODEL_FUNCTION_NAME: ModelFunction( + source_profile_function=electron_density_sources.calc_pellet_source, + runtime_params_class=electron_density_sources.PelletRuntimeParams, + ) + }, ), - ion_el_heat.GenericIonElectronHeatSource.SOURCE_NAME: _register_new_source( + ion_el_heat.GenericIonElectronHeatSource.SOURCE_NAME: SupportedSource( source_class=ion_el_heat.GenericIonElectronHeatSource, - default_runtime_params_class=ion_el_heat.RuntimeParams, + model_functions={ + ion_el_heat.GenericIonElectronHeatSource.DEFAULT_MODEL_FUNCTION_NAME: ModelFunction( + source_profile_function=ion_el_heat.default_formula, + runtime_params_class=ion_el_heat.RuntimeParams, + ) + }, ), - fusion_heat_source.FusionHeatSource.SOURCE_NAME: _register_new_source( + fusion_heat_source.FusionHeatSource.SOURCE_NAME: SupportedSource( source_class=fusion_heat_source.FusionHeatSource, - default_runtime_params_class=fusion_heat_source.FusionHeatSourceRuntimeParams, + model_functions={ + fusion_heat_source.FusionHeatSource.DEFAULT_MODEL_FUNCTION_NAME: ModelFunction( + source_profile_function=fusion_heat_source.fusion_heat_model_func, + runtime_params_class=fusion_heat_source.FusionHeatSourceRuntimeParams, + ) + }, ), - qei_source.QeiSource.SOURCE_NAME: _register_new_source( + qei_source.QeiSource.SOURCE_NAME: SupportedSource( source_class=qei_source.QeiSource, - default_runtime_params_class=qei_source.RuntimeParams, + model_functions={ + qei_source.QeiSource.DEFAULT_MODEL_FUNCTION_NAME: ModelFunction( + source_profile_function=None, + runtime_params_class=qei_source.RuntimeParams, + ) + }, ), - ohmic_heat_source.OhmicHeatSource.SOURCE_NAME: _register_new_source( + ohmic_heat_source.OhmicHeatSource.SOURCE_NAME: SupportedSource( source_class=ohmic_heat_source.OhmicHeatSource, - default_runtime_params_class=ohmic_heat_source.OhmicRuntimeParams, - links_back=True, + model_functions={ + ohmic_heat_source.OhmicHeatSource.DEFAULT_MODEL_FUNCTION_NAME: ( + ModelFunction( + source_profile_function=ohmic_heat_source.ohmic_model_func, + runtime_params_class=ohmic_heat_source.OhmicRuntimeParams, + links_back=True, + ) + ) + }, ), - bremsstrahlung_heat_sink.BremsstrahlungHeatSink.SOURCE_NAME: ( - _register_new_source( - source_class=bremsstrahlung_heat_sink.BremsstrahlungHeatSink, - default_runtime_params_class=bremsstrahlung_heat_sink.RuntimeParams, - ) + bremsstrahlung_heat_sink.BremsstrahlungHeatSink.SOURCE_NAME: SupportedSource( + source_class=bremsstrahlung_heat_sink.BremsstrahlungHeatSink, + model_functions={ + bremsstrahlung_heat_sink.BremsstrahlungHeatSink.DEFAULT_MODEL_FUNCTION_NAME: ModelFunction( + source_profile_function=bremsstrahlung_heat_sink.bremsstrahlung_model_func, + runtime_params_class=bremsstrahlung_heat_sink.RuntimeParams, + ) + }, ), - ion_cyclotron_source.IonCyclotronSource.SOURCE_NAME: _register_new_source( + ion_cyclotron_source.IonCyclotronSource.SOURCE_NAME: SupportedSource( source_class=ion_cyclotron_source.IonCyclotronSource, - default_runtime_params_class=ion_cyclotron_source.RuntimeParams, + model_functions={ + ion_cyclotron_source.IonCyclotronSource.DEFAULT_MODEL_FUNCTION_NAME: ModelFunction( + source_profile_function=None, + runtime_params_class=ion_cyclotron_source.RuntimeParams, + source_builder_class=ion_cyclotron_source.IonCyclotronSourceBuilder, + ) + }, ), - impurity_radiation_heat_sink.ImpurityRadiationHeatSink.SOURCE_NAME: _register_new_source( + impurity_radiation_heat_sink.ImpurityRadiationHeatSink.SOURCE_NAME: SupportedSource( source_class=impurity_radiation_heat_sink.ImpurityRadiationHeatSink, - default_runtime_params_class=impurity_radiation_heat_sink.RuntimeParams, - links_back=True, + model_functions={ + impurity_radiation_heat_sink.ImpurityRadiationHeatSink.DEFAULT_MODEL_FUNCTION_NAME: ModelFunction( + source_profile_function=impurity_radiation_heat_sink.radially_constant_fraction_of_Pin, + runtime_params_class=impurity_radiation_heat_sink.RuntimeParams, + links_back=True, + ) + }, ), } -def get_registered_source(source_name: str) -> RegisteredSource: - """Used when building a simulation to get the registered source.""" - if source_name in _REGISTERED_SOURCES: - return _REGISTERED_SOURCES[source_name] +def get_supported_source(source_name: str) -> SupportedSource: + """Used when building a simulation to get the supported source.""" + if source_name in _SUPPORTED_SOURCES: + return _SUPPORTED_SOURCES[source_name] else: raise RuntimeError(f'Source:{source_name} has not been registered.') diff --git a/torax/sources/runtime_params.py b/torax/sources/runtime_params.py index 4200cffc..b42bac7a 100644 --- a/torax/sources/runtime_params.py +++ b/torax/sources/runtime_params.py @@ -25,7 +25,6 @@ from torax import interpolated_param from torax.config import base from torax.geometry import geometry -from torax.sources import formula_config TimeInterpolatedInput = interpolated_param.TimeInterpolatedInput @@ -42,15 +41,10 @@ class Mode(enum.Enum): # explicit depending on the model implementation. MODEL_BASED = 1 - # Source values come from a prescribed (possibly time-dependent) formula that - # is not dependent on the state of the system. These formulas may be dependent - # on the config and geometry of the system. - FORMULA_BASED = 2 - # Source values come from a pre-determined set of values, that may evolve in # time. Values can be drawn from a file or an array. These sources are always # explicit. - PRESCRIBED = 3 + PRESCRIBED = 2 @dataclasses.dataclass @@ -81,12 +75,6 @@ class RuntimeParams(base.RuntimeParametersConfig): # running the simulation. is_explicit: bool = False - # Parameters used only when the source is using a prescribed formula to - # compute its profile. - formula: base.RuntimeParametersConfig = dataclasses.field( - default_factory=formula_config.Exponential - ) - # Prescribed values for the source. Used only when the source is fully # prescribed (i.e. source.mode == Mode.PRESCRIBED). # The default here is a vector of all zeros along for all rho and time, and @@ -118,7 +106,6 @@ class RuntimeParamsProvider( """Runtime parameter provider for a single source/sink term.""" runtime_params_config: RuntimeParams - formula: base.RuntimeParametersProvider prescribed_values: interpolated_param.InterpolatedVarTimeRho def get_dynamic_params_kwargs( @@ -148,8 +135,6 @@ class DynamicRuntimeParams: stateless, so these params are their inputs to determine their output profiles. """ - - formula: formula_config.DynamicFormula prescribed_values: array_typing.ArrayFloat diff --git a/torax/sources/source.py b/torax/sources/source.py index 34f03278..7539e88c 100644 --- a/torax/sources/source.py +++ b/torax/sources/source.py @@ -28,7 +28,7 @@ import enum import types import typing -from typing import Any, Callable, Optional, Protocol, TypeAlias +from typing import Any, Callable, ClassVar, Optional, Protocol, TypeAlias # We use Optional here because | doesn't work with string name types. # We use string name 'source_models.SourceModels' in this file to avoid @@ -43,21 +43,23 @@ from torax.sources import runtime_params as runtime_params_lib -# Sources implement these functions to be able to provide source profiles. # pytype bug: 'source_models.SourceModels' not treated as forward reference -SourceProfileFunction: TypeAlias = Callable[ # pytype: disable=name-error - [ # Arguments - runtime_params_slice.StaticRuntimeParamsSlice, # Static runtime params. - runtime_params_lib.StaticRuntimeParams, # Source-specific params. - runtime_params_slice.DynamicRuntimeParamsSlice, # General config params - runtime_params_lib.DynamicRuntimeParams, # Source-specific params. - geometry.Geometry, - state.CoreProfiles, - Optional['source_models.SourceModels'], - ], - # Returns a JAX array, tuple of arrays, or mapping of arrays. - chex.ArrayTree, -] +# pytype: disable=name-error +@typing.runtime_checkable +class SourceProfileFunction(Protocol): + """Sources implement these functions to be able to provide source profiles.""" + + def __call__( + self, + static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, + geo: geometry.Geometry, + source_name: str, + core_profiles: state.CoreProfiles, + source_models: Optional['source_models.SourceModels'], + ) -> chex.ArrayTree: + ... +# pytype: enable=name-error # Any callable which takes the dynamic runtime_params, geometry, and optional @@ -105,6 +107,8 @@ class Source(abc.ABC): are in turn used to compute coeffs in sim.py. Attributes: + DEFAULT_MODEL_FUNCTION_NAME: The name of the model function used with this + source if another isn't specified. runtime_params: Input dataclass containing all the source-specific runtime parameters. At runtime, the parameters here are interpolated to a specific time t and then passed to the model_func or formula, depending on the mode @@ -114,11 +118,6 @@ class Source(abc.ABC): By default, the number of affected core profiles should equal the rank of the output shape returned by output_shape_getter. Subclasses may override this requirement. - supported_modes: Defines how the source computes its profile. Can be set to - zero, model-based, etc. At runtime, the input config (the RuntimeParams or - the DynamicRuntimeParams) will specify which supported type the Source is - running with. If the runtime config specifies an unsupported type, an - error will raise. output_shape_getter: Callable which returns the shape of the profiles given by this source. model_func: The function used when the the runtime type is set to @@ -128,9 +127,13 @@ class Source(abc.ABC): affected_core_profiles_ints: Derived property from the affected_core_profiles. Integer values of those enums. """ - + DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'default' model_func: SourceProfileFunction | None = None - formula: SourceProfileFunction | None = None + + @property + @abc.abstractmethod + def source_name(self) -> str: + """Returns the name of the source.""" @property @abc.abstractmethod @@ -142,36 +145,14 @@ def output_shape_getter(self) -> SourceOutputShapeFunction: """Returns a function which returns the shape of the source's output.""" return get_cell_profile_shape - @property - def supported_modes(self) -> tuple[runtime_params_lib.Mode, ...]: - """Returns the modes supported by this source.""" - return ( - runtime_params_lib.Mode.ZERO, - runtime_params_lib.Mode.FORMULA_BASED, - runtime_params_lib.Mode.PRESCRIBED, - ) - @property def affected_core_profiles_ints(self) -> tuple[int, ...]: return tuple([int(cp) for cp in self.affected_core_profiles]) - def check_mode( - self, - mode: int, - ): - """Raises an error if the source type is not supported.""" - if runtime_params_lib.Mode(mode) not in self.supported_modes: - raise ValueError( - f'This source supports the following modes: {self.supported_modes}.' - f' Unsupported mode provided: {mode}.' - ) - def get_value( self, static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, - static_source_runtime_params: runtime_params_lib.StaticRuntimeParams, dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, - dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, core_profiles: state.CoreProfiles, ) -> chex.ArrayTree: @@ -179,11 +160,8 @@ def get_value( Args: static_runtime_params_slice: Static runtime parameters. - static_source_runtime_params: Static runtime parameters for this source. dynamic_runtime_params_slice: Slice of the general TORAX config that can be used as input for this time step. - dynamic_source_runtime_params: Slice of this source's runtime parameters - at a specific time t. geo: Geometry of the torus. core_profiles: Core plasma profiles. May be the profiles at the start of the time step or a "live" set of core profiles being actively updated @@ -195,31 +173,21 @@ def get_value( Returns: Array, arrays, or nested dataclass/dict of arrays for the source profile. """ - self.check_mode(static_source_runtime_params.mode) + dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[ + self.source_name + ] output_shape = self.output_shape_getter(geo) - model_func = ( - (lambda _0, _1, _2, _3, _4, _5, _6: jnp.zeros(output_shape)) - if self.model_func is None - else self.model_func - ) - formula = ( - (lambda _0, _1, _2, _3, _4, _5, _6: jnp.zeros(output_shape)) - if self.formula is None - else self.formula - ) return get_source_profiles( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_source_runtime_params, static_runtime_params_slice=static_runtime_params_slice, - static_source_runtime_params=static_source_runtime_params, geo=geo, core_profiles=core_profiles, - model_func=model_func, - formula=formula, + model_func=self.model_func, prescribed_values=dynamic_source_runtime_params.prescribed_values, output_shape=output_shape, source_models=getattr(self, 'source_models', None), + source_name=self.source_name, ) def get_source_profile_for_affected_core_profile( @@ -301,13 +269,11 @@ def get_profile_shape(self, geo: geometry.Geometry) -> tuple[int, ...]: # pytype: disable=name-error def get_source_profiles( static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, - static_source_runtime_params: runtime_params_lib.StaticRuntimeParams, dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, - dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, + source_name: str, core_profiles: state.CoreProfiles, - model_func: SourceProfileFunction, - formula: SourceProfileFunction, + model_func: SourceProfileFunction | None, prescribed_values: chex.Array, output_shape: tuple[int, ...], source_models: Optional['source_models.SourceModels'], @@ -321,16 +287,13 @@ def get_source_profiles( Args: static_runtime_params_slice: Static runtime parameters. - static_source_runtime_params: Static runtime parameters for this source. dynamic_runtime_params_slice: Slice of the general TORAX config that can be used as input for this time step. - dynamic_source_runtime_params: Slice of this source's runtime parameters at - a specific time t. geo: Geometry information. Used as input to the source profile functions. + source_name: The name of the source. core_profiles: Core plasma profiles. Used as input to the source profile functions. model_func: Model function. - formula: Formula implementation. prescribed_values: Array of values for this timeslice, interpolated onto the grid (ie with shape output_shape) output_shape: Expected shape of the output array. @@ -340,25 +303,18 @@ def get_source_profiles( Output array of a profile or concatenated/stacked profiles. """ # pytype: enable=name-error - mode = static_source_runtime_params.mode + mode = static_runtime_params_slice.sources[source_name].mode match mode: case runtime_params_lib.Mode.MODEL_BASED.value: + if model_func is None: + raise ValueError( + 'Source is in MODEL_BASED mode but has no model function.' + ) return model_func( static_runtime_params_slice, - static_source_runtime_params, - dynamic_runtime_params_slice, - dynamic_source_runtime_params, - geo, - core_profiles, - source_models, - ) - case runtime_params_lib.Mode.FORMULA_BASED.value: - return formula( - static_runtime_params_slice, - static_source_runtime_params, dynamic_runtime_params_slice, - dynamic_source_runtime_params, geo, + source_name, core_profiles, source_models, ) @@ -428,6 +384,7 @@ def is_source_builder(obj, raise_if_false: bool = False) -> bool: def _convert_source_builder_to_init_kwargs( source_builder: ..., + model_func: SourceProfileFunction | None, ) -> dict[str, Any]: """Returns a dict of init kwargs for the source builder.""" source_init_kwargs = {} @@ -439,12 +396,14 @@ def _convert_source_builder_to_init_kwargs( # including turning custom dataclasses with __call__ methods into # plain Python dictionaries. source_init_kwargs[field.name] = getattr(source_builder, field.name) + source_init_kwargs['model_func'] = model_func return source_init_kwargs def make_source_builder( source_type: ..., runtime_params_type: ... = runtime_params_lib.RuntimeParams, + model_func: SourceProfileFunction | None = None, links_back=False, ) -> SourceBuilderProtocol: """Given a Source type, returns a Builder for that type. @@ -455,6 +414,7 @@ def make_source_builder( source_type: The Source class to make a builder for. runtime_params_type: The type of `runtime_params` field which will be added to the builder dataclass. + model_func: The model function to pass to the source. links_back: If True, the Source class has a `source_models` field linking back to the SourceModels object. This must be passed to the builder's __call__ method. @@ -511,7 +471,10 @@ def check_kwargs(source_init_kwargs, context_msg): assert all([isinstance(var, runtime_params_lib.Mode) for var in v]) elif f.type == 'SourceProfileFunction | None': assert v is None or callable(v) - elif f.type == 'source.SourceProfileFunction': + elif f.type in [ + 'source.SourceProfileFunction', + 'source_lib.SourceProfileFunction', + ]: if not callable(v): raise TypeError( f'While {context_msg} {source_type} got field ' @@ -536,8 +499,8 @@ def check_kwargs(source_init_kwargs, context_msg): raise TypeError(f'Unrecognized type string: {f.type}') # Check if the field is a parameterized generic. - # Python cannot check isinstance for parameterized generics, so we need - # to handle those cases differently. + # Python cannot check isinstance for parameterized generics, so we ignore + # these cases for now. # For instance, if a field type is `tuple[float, ...]` and the value is # valid, like `(1, 2, 3)`, then `isinstance(v, f.type)` would raise a # TypeError. @@ -545,14 +508,16 @@ def check_kwargs(source_init_kwargs, context_msg): type(f.type) == types.GenericAlias # pylint: disable=unidiomatic-typecheck or typing.get_origin(f.type) is not None ): - # Do a superficial check in these instances. Only check that the origin - # type matches the value. Don't look into the rest of the object. - if not isinstance(v, typing.get_origin(f.type)): - raise TypeError( - f'While {context_msg} {source_type} got field {f.name} with ' - f'input type {type(v)} but an expected type {f.type}.' - ) - + # For `Union`s check if the value is a member of the union. + # `typing.Union` is for types defined with `Union[A, B, C]` syntax. + # `types.UnionType` is for types defined with `A | B | C` syntax. + if typing.get_origin(f.type) in [typing.Union, types.UnionType]: + if not isinstance(v, typing.get_args(f.type)): + raise TypeError( + f'While {context_msg} {source_type} got argument ' + f'{f.name} of type {type(v)} but expected ' + f'{f.type}).' + ) else: try: type_works = isinstance(v, f.type) @@ -572,7 +537,9 @@ def check_kwargs(source_init_kwargs, context_msg): # pylint doesn't like this function name because it doesn't realize # this function is to be installed in a class def __post_init__(self): # pylint:disable=invalid-name - source_init_kwargs = _convert_source_builder_to_init_kwargs(self) + source_init_kwargs = _convert_source_builder_to_init_kwargs( + self, model_func + ) check_kwargs(source_init_kwargs, 'making builder') # check_kwargs checks only the kwargs to Source, not SourceBuilder, # so it doesn't check "runtime_params" @@ -598,7 +565,10 @@ def check_source(source): if links_back: def build_source(self, source_models): - source_init_kwargs = _convert_source_builder_to_init_kwargs(self) + source_init_kwargs = _convert_source_builder_to_init_kwargs( + self, + model_func, + ) source_init_kwargs['source_models'] = source_models check_kwargs(source_init_kwargs, 'building') source = source_type(**source_init_kwargs) @@ -608,7 +578,10 @@ def build_source(self, source_models): else: def build_source(self): - source_init_kwargs = _convert_source_builder_to_init_kwargs(self) + source_init_kwargs = _convert_source_builder_to_init_kwargs( + self, + model_func, + ) check_kwargs(source_init_kwargs, 'building') source = source_type(**source_init_kwargs) check_source(source) diff --git a/torax/sources/source_models.py b/torax/sources/source_models.py index 2390d1ed..851a7a38 100644 --- a/torax/sources/source_models.py +++ b/torax/sources/source_models.py @@ -69,15 +69,11 @@ def build_source_profiles( """ # Bootstrap current is a special-case source with multiple outputs, so handle # it here. - dynamic_bootstrap_runtime_params = dynamic_runtime_params_slice.sources[ - source_models.j_bootstrap_name - ] static_bootstrap_runtime_params = static_runtime_params_slice.sources[ source_models.j_bootstrap_name ] bootstrap_profiles = _build_bootstrap_profiles( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_bootstrap_runtime_params, static_runtime_params_slice=static_runtime_params_slice, static_source_runtime_params=static_bootstrap_runtime_params, geo=geo, @@ -106,7 +102,6 @@ def _build_bootstrap_profiles( static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, static_source_runtime_params: runtime_params_lib.StaticRuntimeParams, dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, - dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, core_profiles: state.CoreProfiles, j_bootstrap_source: bootstrap_current_source.BootstrapCurrentSource, @@ -122,8 +117,6 @@ def _build_bootstrap_profiles( bootstrap current source that do not change from time step to time step. dynamic_runtime_params_slice: Input config for this time step. Can change from time step to time step. - dynamic_source_runtime_params: Input runtime parameters for this time step, - specific to the bootstrap current source. geo: Geometry of the torus. core_profiles: Core plasma profiles, either at the start of the time step (if explicit) or the live profiles being evolved during the time step (if @@ -141,9 +134,7 @@ def _build_bootstrap_profiles( """ bootstrap_profile = j_bootstrap_source.get_value( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_source_runtime_params, static_runtime_params_slice=static_runtime_params_slice, - static_source_runtime_params=static_source_runtime_params, geo=geo, core_profiles=core_profiles, ) @@ -240,9 +231,6 @@ def _build_standard_source_profiles( affected_core_profiles_set = set(affected_core_profiles) for source_name, source in source_models.standard_sources.items(): if affected_core_profiles_set.intersection(source.affected_core_profiles): - dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[ - source_name - ] static_source_runtime_params = static_runtime_params_slice.sources[ source_name ] @@ -253,9 +241,7 @@ def _build_standard_source_profiles( ), source.get_value( static_runtime_params_slice, - static_source_runtime_params, dynamic_runtime_params_slice, - dynamic_source_runtime_params, geo, core_profiles, ), @@ -359,15 +345,11 @@ def calc_and_sum_sources_psi( affected_core_profile=source_lib.AffectedCoreProfile.PSI.value, geo=geo, ) - dynamic_bootstrap_runtime_params = dynamic_runtime_params_slice.sources[ - source_models.j_bootstrap_name - ] static_bootstrap_runtime_params = static_runtime_params_slice.sources[ source_models.j_bootstrap_name ] j_bootstrap_profiles = _build_bootstrap_profiles( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_bootstrap_runtime_params, static_runtime_params_slice=static_runtime_params_slice, static_source_runtime_params=static_bootstrap_runtime_params, geo=geo, @@ -404,26 +386,16 @@ class SourceModels: .. code-block:: python # Define an electron-density source with a time-dependent Gaussian profile. - my_custom_source = source.SingleProfileSource( - supported_modes=( - runtime_params_lib.Mode.ZERO, - runtime_params_lib.Mode.FORMULA_BASED, - ), - affected_core_profiles=source.AffectedCoreProfile.NE, - formula=formulas.Gaussian(), - # Define (possibly) time-dependent parameters to feed to the formula. - runtime_params=runtime_params_lib.RuntimeParams( - formula=formula_config.Gaussian( - total={0.0: 1.0, 5.0: 2.0, 10.0: 1.0}, # time-dependent. - c1=2.0, - c2=3.0, - ), - ), + gas_puff_source = register_source.get_registered_source('gas_puff_source') + gas_puff_source_builder = source_lib.make_source_builder( + gas_puff_source.source_class, + runtime_params_type=gas_puff_source.model_functions['default'].runtime_params_class, + model_func=gas_puff_source.model_functions['default'].source_profile_function, ) # Define the collection of sources here, which in this example only includes # one source. all_torax_sources = SourceModels( - sources={'my_custom_source': my_custom_source} + sources={'gas_puff_source': gas_puff_source_builder} ) See runtime_params.py for more details on how to configure all the source/sink @@ -737,6 +709,7 @@ def __init__( ] = source_lib.make_source_builder( generic_current_source.GenericCurrentSource, runtime_params_type=generic_current_source.RuntimeParams, + model_func=generic_current_source.calculate_generic_current_face, )() source_builders[ generic_current_source.GenericCurrentSource.SOURCE_NAME diff --git a/torax/sources/tests/bootstrap_current_source.py b/torax/sources/tests/bootstrap_current_source.py index d4b8cff5..421d6b6b 100644 --- a/torax/sources/tests/bootstrap_current_source.py +++ b/torax/sources/tests/bootstrap_current_source.py @@ -19,7 +19,6 @@ import numpy as np from torax.geometry import geometry from torax.sources import bootstrap_current_source -from torax.sources import runtime_params as runtime_params_lib from torax.sources import source as source_lib from torax.sources import source_profiles from torax.sources.tests import test_lib @@ -33,9 +32,8 @@ def setUpClass(cls): super().setUpClass( source_class=bootstrap_current_source.BootstrapCurrentSource, runtime_params_class=bootstrap_current_source.RuntimeParams, - unsupported_modes=[ - runtime_params_lib.Mode.FORMULA_BASED, - ], + source_name=bootstrap_current_source.BootstrapCurrentSource.SOURCE_NAME, + model_func=None, ) def test_extraction_of_relevant_profile_from_output(self): diff --git a/torax/sources/tests/bremsstrahlung_heat_sink.py b/torax/sources/tests/bremsstrahlung_heat_sink.py index fc47b0e9..8d8422b7 100644 --- a/torax/sources/tests/bremsstrahlung_heat_sink.py +++ b/torax/sources/tests/bremsstrahlung_heat_sink.py @@ -23,7 +23,6 @@ from torax import core_profile_setters from torax.config import runtime_params_slice from torax.sources import bremsstrahlung_heat_sink -from torax.sources import runtime_params as runtime_params_lib from torax.sources import source_models as source_models_lib from torax.sources.tests import test_lib from torax.stepper import runtime_params as stepper_runtime_params @@ -38,9 +37,8 @@ def setUpClass(cls): super().setUpClass( source_class=bremsstrahlung_heat_sink.BremsstrahlungHeatSink, runtime_params_class=bremsstrahlung_heat_sink.RuntimeParams, - unsupported_modes=[ - runtime_params_lib.Mode.FORMULA_BASED, - ], + source_name=bremsstrahlung_heat_sink.BremsstrahlungHeatSink.SOURCE_NAME, + model_func=bremsstrahlung_heat_sink.bremsstrahlung_model_func, ) @parameterized.parameters([ diff --git a/torax/sources/tests/electron_cyclotron_source.py b/torax/sources/tests/electron_cyclotron_source.py index ee84da1d..82a8106a 100644 --- a/torax/sources/tests/electron_cyclotron_source.py +++ b/torax/sources/tests/electron_cyclotron_source.py @@ -38,9 +38,8 @@ def setUpClass(cls): super().setUpClass( source_class=electron_cyclotron_source.ElectronCyclotronSource, runtime_params_class=electron_cyclotron_source.RuntimeParams, - unsupported_modes=[ - runtime_params_lib.Mode.FORMULA_BASED, - ], + source_name=electron_cyclotron_source.ElectronCyclotronSource.SOURCE_NAME, + model_func=electron_cyclotron_source.calc_heating_and_current, ) def test_source_value(self): @@ -50,10 +49,10 @@ def test_source_value(self): raise TypeError(f"{type(self)} has a bad _source_class_builder") runtime_params = general_runtime_params.GeneralRuntimeParams() source_models_builder = source_models_lib.SourceModelsBuilder( - {"foo": source_builder}, + {self._source_name: source_builder}, ) source_models = source_models_builder() - source = source_models.sources["foo"] + source = source_models.sources[self._source_name] source_builder.runtime_params.mode = runtime_params_lib.Mode.MODEL_BASED self.assertIsInstance(source, source_lib.Source) geo = geometry.build_circular_geometry() @@ -81,11 +80,7 @@ def test_source_value(self): ) value = source.get_value( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ - "foo" - ], static_runtime_params_slice=static_runtime_params_slice, - static_source_runtime_params=static_runtime_params_slice.sources["foo"], geo=geo, core_profiles=core_profiles, ) @@ -95,67 +90,6 @@ def test_source_value(self): if jnp.any(jnp.isnan(value)): raise AssertionError(f"Source value contains NaNs: {value}") - def test_invalid_source_types_raise_errors(self): - """Tests that using unsupported types raises an error.""" - runtime_params = general_runtime_params.GeneralRuntimeParams() - geo = geometry.build_circular_geometry() - source_builder = self._source_class_builder() - source_models_builder = source_models_lib.SourceModelsBuilder( - {"foo": source_builder}, - ) - source_models = source_models_builder() - source = source_models.sources["foo"] - self.assertIsInstance(source, source_lib.Source) - dynamic_runtime_params_slice_provider = ( - runtime_params_slice.DynamicRuntimeParamsSliceProvider( - runtime_params=runtime_params, - sources=source_models_builder.runtime_params, - torax_mesh=geo.torax_mesh, - ) - ) - # This slice is needed to create the core_profiles - dynamic_runtime_params_slice = dynamic_runtime_params_slice_provider( - t=runtime_params.numerics.t_initial, - ) - static_runtime_params_slice = ( - runtime_params_slice.build_static_runtime_params_slice( - runtime_params, - stepper=stepper_runtime_params.RuntimeParams(), - source_runtime_params=source_models_builder.runtime_params, - ) - ) - core_profiles = core_profile_setters.initial_core_profiles( - dynamic_runtime_params_slice=dynamic_runtime_params_slice, - static_runtime_params_slice=static_runtime_params_slice, - geo=geo, - source_models=source_models, - ) - - for unsupported_mode in self._unsupported_modes: - source_builder.runtime_params.mode = unsupported_mode - # Construct a new slice with the given mode - static_runtime_params_slice = ( - runtime_params_slice.build_static_runtime_params_slice( - runtime_params, - stepper=stepper_runtime_params.RuntimeParams(), - source_runtime_params=source_models_builder.runtime_params, - ) - ) - with self.subTest(unsupported_mode.name): - with self.assertRaises(ValueError): - source.get_value( - dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ - "foo" - ], - static_runtime_params_slice=static_runtime_params_slice, - static_source_runtime_params=static_runtime_params_slice.sources[ - "foo" - ], - geo=geo, - core_profiles=core_profiles, - ) - def test_extraction_of_relevant_profile_from_output(self): """Tests that the relevant profile is extracted from the output.""" geo = geometry.build_circular_geometry() diff --git a/torax/sources/tests/electron_density_sources.py b/torax/sources/tests/electron_density_sources.py index e36081ce..d31ab9af 100644 --- a/torax/sources/tests/electron_density_sources.py +++ b/torax/sources/tests/electron_density_sources.py @@ -16,7 +16,6 @@ from absl.testing import absltest from torax.sources import electron_density_sources as eds -from torax.sources import runtime_params as runtime_params_lib from torax.sources.tests import test_lib @@ -28,9 +27,8 @@ def setUpClass(cls): super().setUpClass( source_class=eds.GasPuffSource, runtime_params_class=eds.GasPuffRuntimeParams, - unsupported_modes=[ - runtime_params_lib.Mode.MODEL_BASED, - ], + source_name=eds.GasPuffSource.SOURCE_NAME, + model_func=eds.calc_puff_source, ) @@ -42,9 +40,8 @@ def setUpClass(cls): super().setUpClass( source_class=eds.PelletSource, runtime_params_class=eds.PelletRuntimeParams, - unsupported_modes=[ - runtime_params_lib.Mode.MODEL_BASED, - ], + source_name=eds.PelletSource.SOURCE_NAME, + model_func=eds.calc_pellet_source, ) @@ -56,23 +53,8 @@ def setUpClass(cls): super().setUpClass( source_class=eds.GenericParticleSource, runtime_params_class=eds.GenericParticleSourceRuntimeParams, - unsupported_modes=[ - runtime_params_lib.Mode.MODEL_BASED, - ], - ) - - -class RecombinationDensitySinkTest(test_lib.SingleProfileSourceTestCase): - """Tests for RecombinationDensitySink.""" - - @classmethod - def setUpClass(cls): - super().setUpClass( - source_class=eds.RecombinationDensitySink, - runtime_params_class=runtime_params_lib.RuntimeParams, - unsupported_modes=[ - runtime_params_lib.Mode.MODEL_BASED, - ], + source_name=eds.GenericParticleSource.SOURCE_NAME, + model_func=eds.calc_generic_particle_source, ) diff --git a/torax/sources/tests/formulas.py b/torax/sources/tests/formulas.py deleted file mode 100644 index 1c197580..00000000 --- a/torax/sources/tests/formulas.py +++ /dev/null @@ -1,242 +0,0 @@ -# Copyright 2024 DeepMind Technologies Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for sources/formulas.py.""" - -from absl.testing import absltest -import chex -from torax import output -from torax import sim as sim_lib -from torax import simulation_app -from torax.config import build_sim -from torax.config import numerics as numerics_lib -from torax.config import profile_conditions as profile_conditions_lib -from torax.config import runtime_params as general_runtime_params -from torax.pedestal_model import set_tped_nped -from torax.sources import bremsstrahlung_heat_sink -from torax.sources import electron_density_sources -from torax.sources import formula_config -from torax.sources import formulas -from torax.sources import fusion_heat_source -from torax.sources import ohmic_heat_source -from torax.sources import runtime_params as runtime_params_lib -from torax.sources.tests import test_lib -from torax.stepper import linear_theta_method -from torax.tests.test_lib import default_sources -from torax.tests.test_lib import sim_test_case -from torax.transport_model import constant as constant_transport_model - - -_ALL_PROFILES = ('temp_ion', 'temp_el', 'psi', 'q_face', 's_face', 'ne') - - -class FormulasIntegrationTest(sim_test_case.SimTestCase): - """Integration tests for using non-default formulas.""" - - def test_custom_exponential_source_can_replace_puff_source(self): - """Replaces one the default ne source with a custom one.""" - # The default puff source gives an exponential profile. In this test, we - # zero out the default puff source and introduce a new custom source that - # should give the same profiles throughout the entire simulation run as the - # original puff source. - - # For this test, use test_particle_sources_constant with the linear stepper. - custom_source_name = 'custom_exponential_source' - - # Copy the test_particle_sources_constant config in here for clarity. - test_particle_sources_constant_runtime_params = general_runtime_params.GeneralRuntimeParams( - profile_conditions=profile_conditions_lib.ProfileConditions( - set_pedestal=True, - nbar=0.85, - nu=0, - ne_bound_right=0.5, - ), - numerics=numerics_lib.Numerics( - ion_heat_eq=True, - el_heat_eq=True, - dens_eq=True, # This is important to be True to test ne sources. - current_eq=True, - resistivity_mult=100, - t_final=2, - ), - ) - basic_pedestal_model_builder = ( - set_tped_nped.SetTemperatureDensityPedestalModelBuilder() - ) - # Set the sources to match test_particle_sources_constant as well. - source_models_builder = default_sources.get_default_sources_builder() - source_models_builder.runtime_params[ - electron_density_sources.PelletSource.SOURCE_NAME - ].S_pellet_tot = 2.0e22 - S_puff_tot = 1.0e22 # pylint: disable=invalid-name - puff_decay_length = 0.05 - source_models_builder.runtime_params[ - electron_density_sources.GasPuffSource.SOURCE_NAME - ].S_puff_tot = S_puff_tot - source_models_builder.runtime_params[ - electron_density_sources.GasPuffSource.SOURCE_NAME - ].puff_decay_length = puff_decay_length - source_models_builder.runtime_params[ - electron_density_sources.GenericParticleSource.SOURCE_NAME - ].S_tot = 0.0 - # We need to turn off some other sources for test_particle_sources_constant - # that are unrelated to our test for the ne custom source. - source_models_builder.runtime_params[ - fusion_heat_source.FusionHeatSource.SOURCE_NAME - ].mode = runtime_params_lib.Mode.ZERO - source_models_builder.runtime_params[ - ohmic_heat_source.OhmicHeatSource.SOURCE_NAME - ].mode = runtime_params_lib.Mode.ZERO - source_models_builder.runtime_params[ - bremsstrahlung_heat_sink.BremsstrahlungHeatSink.SOURCE_NAME - ].mode = runtime_params_lib.Mode.ZERO - - # Add the custom source to the source_models, but keep it turned off for the - # first run. - source_models_builder.source_builders[custom_source_name] = ( - test_lib.TestSourceBuilder( - formula=formulas.Exponential(), - runtime_params=runtime_params_lib.RuntimeParams( - mode=runtime_params_lib.Mode.ZERO, - # will override these later, but defining here because, due to - # how JAX works, this function is still evaluated even when the - # mode is set to ZERO. So the runtime config needs to be set - # with the correct params. - formula=formula_config.Exponential(), - ), - ) - ) - - # Load reference profiles - ref_profiles, ref_time = self._get_refs( - 'test_particle_sources_constant.nc', _ALL_PROFILES - ) - - # We set up the sim only once and update the config on each run below in a - # way that does not trigger recompiles. This way we only trace the code - # once. - geo_provider = build_sim.build_geometry_provider_from_config( - {'geometry_type': 'circular'} - ) - transport_model_builder = ( - constant_transport_model.ConstantTransportModelBuilder( - runtime_params=constant_transport_model.RuntimeParams( - De_const=0.5, - Ve_const=-0.2, - ) - ) - ) - sim = sim_lib.build_sim_object( - runtime_params=test_particle_sources_constant_runtime_params, - geometry_provider=geo_provider, - stepper_builder=linear_theta_method.LinearThetaMethodBuilder( - runtime_params=linear_theta_method.LinearRuntimeParams( - predictor_corrector=False, - ) - ), - transport_model_builder=transport_model_builder, - source_models_builder=source_models_builder, - pedestal_model_builder=basic_pedestal_model_builder, - ) - - # Make sure the config copied here works with these references. - with self.subTest('with_puff_and_without_custom_source'): - # Need to run the sim once to build the step_fn. - sim_outputs = sim.run() - history = output.StateHistory(sim_outputs, sim.source_models) - self._check_profiles_vs_expected( - core_profiles=history.core_profiles, - t=history.times, - ref_time=ref_time, - ref_profiles=ref_profiles, - rtol=self.rtol, - atol=self.atol, - ) - - with self.subTest('without_puff_and_with_custom_source'): - # Now turn on the custom source. - source_models_builder.runtime_params[custom_source_name].mode = ( - runtime_params_lib.Mode.FORMULA_BASED - ) - source_models_builder.runtime_params[custom_source_name].formula = ( - formula_config.Exponential( - total=( - S_puff_tot - / test_particle_sources_constant_runtime_params.numerics.nref - ), - c1=1.0, - c2=puff_decay_length, - ) - ) - # And turn off the gas puff source it is replacing. - source_models_builder.runtime_params[ - electron_density_sources.GasPuffSource.SOURCE_NAME - ].mode = runtime_params_lib.Mode.ZERO - sim = simulation_app.update_sim( - sim, - test_particle_sources_constant_runtime_params, - sim.geometry_provider, - transport_model_builder.runtime_params, - source_models_builder.runtime_params, - linear_theta_method.LinearRuntimeParams(predictor_corrector=False), - pedestal_runtime_params=basic_pedestal_model_builder.runtime_params, - ) - self._run_sim_and_check(sim, ref_profiles, ref_time) - - with self.subTest('without_puff_and_without_custom_source'): - # Confirm that the custom source actual has an effect. - # Turn it off as well, and the check shouldn't pass. - source_models_builder.runtime_params[custom_source_name].mode = ( - runtime_params_lib.Mode.ZERO - ) - sim = simulation_app.update_sim( - sim, - test_particle_sources_constant_runtime_params, - sim.geometry_provider, - transport_model_builder.runtime_params, - source_models_builder.runtime_params, - linear_theta_method.LinearRuntimeParams(predictor_corrector=False), - pedestal_runtime_params=basic_pedestal_model_builder.runtime_params, - ) - with self.assertRaises(AssertionError): - self._run_sim_and_check(sim, ref_profiles, ref_time) - - def _run_sim_and_check( - self, - sim: sim_lib.Sim, - ref_profiles: dict[str, chex.ArrayTree], - ref_time: chex.Array, - ): - """Runs sim with new runtime params and checks the profiles vs. expected.""" - sim_outputs = sim_lib.run_simulation( - static_runtime_params_slice=sim.static_runtime_params_slice, - dynamic_runtime_params_slice_provider=sim.dynamic_runtime_params_slice_provider, - geometry_provider=sim.geometry_provider, - initial_state=sim.initial_state, - time_step_calculator=sim.time_step_calculator, - step_fn=sim.step_fn, - ) - history = output.StateHistory(sim_outputs, sim.source_models) - self._check_profiles_vs_expected( - core_profiles=history.core_profiles, - t=history.times, - ref_time=ref_time, - ref_profiles=ref_profiles, - rtol=self.rtol, - atol=self.atol, - ) - - -if __name__ == '__main__': - absltest.main() diff --git a/torax/sources/tests/fusion_heat_source.py b/torax/sources/tests/fusion_heat_source.py index 6dab4b4b..af7fe558 100644 --- a/torax/sources/tests/fusion_heat_source.py +++ b/torax/sources/tests/fusion_heat_source.py @@ -23,7 +23,6 @@ from torax import core_profile_setters from torax.config import runtime_params_slice from torax.sources import fusion_heat_source -from torax.sources import runtime_params as runtime_params_lib from torax.sources import source_models as source_models_lib from torax.sources.tests import test_lib from torax.tests.test_lib import torax_refs @@ -37,9 +36,8 @@ def setUpClass(cls): super().setUpClass( source_class=fusion_heat_source.FusionHeatSource, runtime_params_class=fusion_heat_source.FusionHeatSourceRuntimeParams, - unsupported_modes=[ - runtime_params_lib.Mode.FORMULA_BASED, - ], + source_name=fusion_heat_source.FusionHeatSource.SOURCE_NAME, + model_func=fusion_heat_source.fusion_heat_model_func, ) @parameterized.parameters([ diff --git a/torax/sources/tests/generic_current_source.py b/torax/sources/tests/generic_current_source.py index a21cd50a..3b0e1ca0 100644 --- a/torax/sources/tests/generic_current_source.py +++ b/torax/sources/tests/generic_current_source.py @@ -22,7 +22,6 @@ from torax.config import runtime_params_slice from torax.geometry import geometry from torax.sources import generic_current_source -from torax.sources import runtime_params as runtime_params_lib from torax.sources import source as source_lib from torax.sources.tests import test_lib @@ -35,9 +34,8 @@ def setUpClass(cls): super().setUpClass( source_class=generic_current_source.GenericCurrentSource, runtime_params_class=generic_current_source.RuntimeParams, - unsupported_modes=[ - runtime_params_lib.Mode.MODEL_BASED, - ], + source_name=generic_current_source.GenericCurrentSource.SOURCE_NAME, + model_func=generic_current_source.calculate_generic_current_face, ) def test_generic_current_hires(self): @@ -70,13 +68,7 @@ def test_generic_current_hires(self): self.assertIsNotNone( source.generic_current_source_hires( dynamic_runtime_params_slice=dynamic_slice, - dynamic_source_runtime_params=dynamic_slice.sources[ - generic_current_source.GenericCurrentSource.SOURCE_NAME - ], static_runtime_params_slice=static_slice, - static_source_runtime_params=static_slice.sources[ - generic_current_source.GenericCurrentSource.SOURCE_NAME - ], geo=geo, ) ) @@ -113,13 +105,7 @@ def test_profile_is_on_face_grid(self): self.assertEqual( source.get_value( static_slice, - static_slice.sources[ - generic_current_source.GenericCurrentSource.SOURCE_NAME - ], dynamic_runtime_params_slice, - dynamic_runtime_params_slice.sources[ - generic_current_source.GenericCurrentSource.SOURCE_NAME - ], geo, core_profiles=None, ).shape, diff --git a/torax/sources/tests/generic_ion_el_heat_source.py b/torax/sources/tests/generic_ion_el_heat_source.py index 79555718..f0e94254 100644 --- a/torax/sources/tests/generic_ion_el_heat_source.py +++ b/torax/sources/tests/generic_ion_el_heat_source.py @@ -16,7 +16,6 @@ from absl.testing import absltest from torax.sources import generic_ion_el_heat_source -from torax.sources import runtime_params as runtime_params_lib from torax.sources.tests import test_lib @@ -28,9 +27,8 @@ def setUpClass(cls): super().setUpClass( source_class=generic_ion_el_heat_source.GenericIonElectronHeatSource, runtime_params_class=generic_ion_el_heat_source.RuntimeParams, - unsupported_modes=[ - runtime_params_lib.Mode.MODEL_BASED, - ], + source_name=generic_ion_el_heat_source.GenericIonElectronHeatSource.SOURCE_NAME, + model_func=generic_ion_el_heat_source.default_formula, ) diff --git a/torax/sources/tests/impurity_radiation_heat_sink.py b/torax/sources/tests/impurity_radiation_heat_sink.py index 485d2c26..f88797b8 100644 --- a/torax/sources/tests/impurity_radiation_heat_sink.py +++ b/torax/sources/tests/impurity_radiation_heat_sink.py @@ -42,11 +42,9 @@ def setUpClass(cls): super().setUpClass( source_class=impurity_radiation_heat_sink_lib.ImpurityRadiationHeatSink, runtime_params_class=impurity_radiation_heat_sink_lib.RuntimeParams, - unsupported_modes=[ - runtime_params_lib.Mode.MODEL_BASED, - runtime_params_lib.Mode.PRESCRIBED, - ], + source_name=impurity_radiation_heat_sink_lib.ImpurityRadiationHeatSink.SOURCE_NAME, links_back=True, + model_func=impurity_radiation_heat_sink_lib.radially_constant_fraction_of_Pin ) def test_source_value(self): @@ -64,8 +62,11 @@ def test_source_value(self): heat_source_builder_builder = source_lib.make_source_builder( source_type=generic_ion_el_heat_source.GenericIonElectronHeatSource, runtime_params_type=generic_ion_el_heat_source.RuntimeParams, + model_func=generic_ion_el_heat_source.default_formula, + ) + heat_source_builder = heat_source_builder_builder( + model_func=generic_ion_el_heat_source.default_formula ) - heat_source_builder = heat_source_builder_builder() # Runtime params runtime_params = general_runtime_params.GeneralRuntimeParams() @@ -73,9 +74,7 @@ def test_source_value(self): # Source models source_models_builder = source_models_lib.SourceModelsBuilder( { - impurity_radiation_heat_sink_lib.ImpurityRadiationHeatSink.SOURCE_NAME: ( - impurity_radiation_sink_builder - ), + self._source_name: impurity_radiation_sink_builder, generic_ion_el_heat_source.GenericIonElectronHeatSource.SOURCE_NAME: ( heat_source_builder ), @@ -84,9 +83,7 @@ def test_source_value(self): source_models = source_models_builder() # Extract the source we're testing and check that it's been built correctly - impurity_radiation_sink = source_models.sources[ - impurity_radiation_heat_sink_lib.ImpurityRadiationHeatSink.SOURCE_NAME - ] + impurity_radiation_sink = source_models.sources[self._source_name] self.assertIsInstance(impurity_radiation_sink, source_lib.Source) # Geometry, profiles, and dynamic runtime params @@ -110,12 +107,9 @@ def test_source_value(self): geo=geo, source_models=source_models, ) - impurity_radiation_sink_dynamic_runtime_params_slice = dynamic_runtime_params_slice.sources[ - impurity_radiation_heat_sink_lib.ImpurityRadiationHeatSink.SOURCE_NAME - ] - impurity_radiation_sink_static_runtime_params_slice = static_slice.sources[ - impurity_radiation_heat_sink_lib.ImpurityRadiationHeatSink.SOURCE_NAME - ] + impurity_radiation_sink_dynamic_runtime_params_slice = ( + dynamic_runtime_params_slice.sources[self._source_name] + ) heat_source_dynamic_runtime_params_slice = ( dynamic_runtime_params_slice.sources[ @@ -131,13 +125,13 @@ def test_source_value(self): heat_source_dynamic_runtime_params_slice, generic_ion_el_heat_source.DynamicRuntimeParams, ) - impurity_radiation_heat_sink_power_density = impurity_radiation_sink.get_value( - static_runtime_params_slice=static_slice, - static_source_runtime_params=impurity_radiation_sink_static_runtime_params_slice, - dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=impurity_radiation_sink_dynamic_runtime_params_slice, - geo=geo, - core_profiles=core_profiles, + impurity_radiation_heat_sink_power_density = ( + impurity_radiation_sink.get_value( + static_runtime_params_slice=static_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, + geo=geo, + core_profiles=core_profiles, + ) ) # ImpurityRadiationHeatSink provides TEMP_EL only @@ -156,77 +150,15 @@ def test_source_value(self): rtol=1e-2, # TODO(b/382682284): this rtol seems v. high ) - def test_invalid_source_types_raise_errors(self): - """Tests that using unsupported types raises an error.""" - runtime_params = general_runtime_params.GeneralRuntimeParams() - geo = geometry.build_circular_geometry() - source_builder = self._source_class_builder() - source_models_builder = source_models_lib.SourceModelsBuilder( - {"foo": source_builder}, - ) - source_models = source_models_builder() - source = source_models.sources["foo"] - self.assertIsInstance(source, source_lib.Source) - dynamic_runtime_params_slice_provider = ( - runtime_params_slice.DynamicRuntimeParamsSliceProvider( - runtime_params=runtime_params, - sources=source_models_builder.runtime_params, - torax_mesh=geo.torax_mesh, - ) - ) - # This slice is needed to create the core_profiles - dynamic_runtime_params_slice = dynamic_runtime_params_slice_provider( - t=runtime_params.numerics.t_initial, - ) - static_runtime_params_slice = ( - runtime_params_slice.build_static_runtime_params_slice( - runtime_params, - source_runtime_params=source_models_builder.runtime_params, - ) - ) - core_profiles = core_profile_setters.initial_core_profiles( - dynamic_runtime_params_slice=dynamic_runtime_params_slice, - static_runtime_params_slice=static_runtime_params_slice, - geo=geo, - source_models=source_models, - ) - - for unsupported_mode in self._unsupported_modes: - source_builder.runtime_params.mode = unsupported_mode - # Construct a new slice with the given mode - dynamic_runtime_params_slice = ( - runtime_params_slice.DynamicRuntimeParamsSliceProvider( - runtime_params=runtime_params, - sources=source_models_builder.runtime_params, - torax_mesh=geo.torax_mesh, - )( - t=runtime_params.numerics.t_initial, - ) - ) - with self.subTest(unsupported_mode.name): - with self.assertRaises(RuntimeError): - source.get_value( - dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ - "foo" - ], - static_runtime_params_slice=static_runtime_params_slice, - static_source_runtime_params=static_runtime_params_slice.sources[ - "foo" - ], - geo=geo, - core_profiles=core_profiles, - ) - def test_extraction_of_relevant_profile_from_output(self): """Tests that the relevant profile is extracted from the output.""" geo = geometry.build_circular_geometry() source_builder = self._source_class_builder() source_models_builder = source_models_lib.SourceModelsBuilder( - {"foo": source_builder}, + {self._source_name: source_builder}, ) source_models = source_models_builder() - source = source_models.sources["foo"] + source = source_models.sources[self._source_name] self.assertIsInstance(source, source_lib.Source) cell = source_lib.ProfileType.CELL.get_profile_shape(geo) fake_profile = jnp.ones(cell) diff --git a/torax/sources/tests/ion_cyclotron_source.py b/torax/sources/tests/ion_cyclotron_source.py index 495deb00..58d32662 100644 --- a/torax/sources/tests/ion_cyclotron_source.py +++ b/torax/sources/tests/ion_cyclotron_source.py @@ -29,7 +29,6 @@ from torax.config import runtime_params_slice from torax.geometry import geometry from torax.sources import ion_cyclotron_source -from torax.sources import runtime_params as runtime_params_lib from torax.sources import source as source_lib from torax.sources import source_models as source_models_lib from torax.sources.tests import test_lib @@ -98,7 +97,9 @@ def setUpClass(cls): super().setUpClass( source_class=ion_cyclotron_source.IonCyclotronSource, runtime_params_class=ion_cyclotron_source.RuntimeParams, - unsupported_modes=[runtime_params_lib.Mode.FORMULA_BASED], + source_class_builder=ion_cyclotron_source.IonCyclotronSourceBuilder, + source_name=ion_cyclotron_source.IonCyclotronSource.SOURCE_NAME, + model_func=None, ) @parameterized.product( @@ -163,13 +164,7 @@ def test_icrh_output_matches_total_power( ) icrh_output = icrh_source.get_value( static_slice, - static_slice.sources[ - ion_cyclotron_source.IonCyclotronSource.SOURCE_NAME - ], dynamic_runtime_params_slice, - dynamic_runtime_params_slice.sources[ - ion_cyclotron_source.IonCyclotronSource.SOURCE_NAME - ], geo, core_profiles, ) @@ -218,10 +213,12 @@ def test_source_value(self, mock_path): runtime_params = general_runtime_params.GeneralRuntimeParams() geo = geometry.build_circular_geometry() source_models_builder = source_models_lib.SourceModelsBuilder( - {"foo": source_builder}, + {ion_cyclotron_source.IonCyclotronSource.SOURCE_NAME: source_builder}, ) source_models = source_models_builder() - source = source_models.sources["foo"] + source = source_models.sources[ + ion_cyclotron_source.IonCyclotronSource.SOURCE_NAME + ] self.assertIsInstance(source, source_lib.Source) dynamic_runtime_params_slice = ( runtime_params_slice.DynamicRuntimeParamsSliceProvider( @@ -244,11 +241,7 @@ def test_source_value(self, mock_path): ) ion_and_el = source.get_value( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ - "foo" - ], static_runtime_params_slice=static_slice, - static_source_runtime_params=static_slice.sources["foo"], geo=geo, core_profiles=core_profiles, ) diff --git a/torax/sources/tests/ohmic_heat_source.py b/torax/sources/tests/ohmic_heat_source.py index a6cdc5f2..f5ed1f55 100644 --- a/torax/sources/tests/ohmic_heat_source.py +++ b/torax/sources/tests/ohmic_heat_source.py @@ -14,7 +14,6 @@ """Tests for ohmic_heat_source.""" from absl.testing import absltest from torax.sources import ohmic_heat_source -from torax.sources import runtime_params as runtime_params_lib from torax.sources.tests import test_lib @@ -26,10 +25,9 @@ def setUpClass(cls): super().setUpClass( source_class=ohmic_heat_source.OhmicHeatSource, runtime_params_class=ohmic_heat_source.OhmicRuntimeParams, - unsupported_modes=[ - runtime_params_lib.Mode.FORMULA_BASED, - ], + source_name=ohmic_heat_source.OhmicHeatSource.SOURCE_NAME, links_back=True, + model_func=ohmic_heat_source.ohmic_model_func, ) diff --git a/torax/sources/tests/qei_source.py b/torax/sources/tests/qei_source.py index 3df12b31..71754e6f 100644 --- a/torax/sources/tests/qei_source.py +++ b/torax/sources/tests/qei_source.py @@ -13,15 +13,12 @@ # limitations under the License. """Tests for qei_source.""" - -import dataclasses from absl.testing import absltest from torax import core_profile_setters from torax.config import runtime_params as general_runtime_params from torax.config import runtime_params_slice from torax.geometry import geometry from torax.sources import qei_source -from torax.sources import runtime_params as runtime_params_lib from torax.sources import source_models as source_models_lib from torax.sources.tests import test_lib @@ -34,9 +31,8 @@ def setUpClass(cls): super().setUpClass( source_class=qei_source.QeiSource, runtime_params_class=qei_source.RuntimeParams, - unsupported_modes=[ - runtime_params_lib.Mode.FORMULA_BASED, - ], + source_name=qei_source.QeiSource.SOURCE_NAME, + model_func=None, ) def test_source_value(self): @@ -70,60 +66,11 @@ def test_source_value(self): qei = source.get_qei( static_slice, dynamic_slice, - dynamic_slice.sources['qei_source'], geo, core_profiles, ) self.assertIsNotNone(qei) - def test_invalid_source_types_raise_errors(self): - source_builder = self._source_class_builder() - source_models_builder = source_models_lib.SourceModelsBuilder( - {'qei_source': source_builder} - ) - source_models = source_models_builder() - source = source_models.sources['qei_source'] - runtime_params = general_runtime_params.GeneralRuntimeParams() - geo = geometry.build_circular_geometry() - static_slice = runtime_params_slice.build_static_runtime_params_slice( - runtime_params, - source_runtime_params=source_models_builder.runtime_params, - ) - dynamic_slice = runtime_params_slice.DynamicRuntimeParamsSliceProvider( - runtime_params, - sources=source_models_builder.runtime_params, - torax_mesh=geo.torax_mesh, - )( - t=runtime_params.numerics.t_initial, - ) - core_profiles = core_profile_setters.initial_core_profiles( - dynamic_runtime_params_slice=dynamic_slice, - static_runtime_params_slice=static_slice, - geo=geo, - source_models=source_models, - ) - for unsupported_mode in self._unsupported_modes: - with self.subTest(unsupported_mode.name): - with self.assertRaises(ValueError): - static_slice = runtime_params_slice.build_static_runtime_params_slice( - runtime_params, - source_runtime_params={ - 'qei_source': dataclasses.replace( - source_builder.runtime_params, - mode=unsupported_mode, - ) - }, - ) - # Force pytype to recognize `source` has `get_qei` - assert isinstance(source, qei_source.QeiSource) - source.get_qei( - static_slice, - dynamic_slice, - dynamic_slice.sources['qei_source'], - geo, - core_profiles, - ) - if __name__ == '__main__': absltest.main() diff --git a/torax/sources/tests/register_source.py b/torax/sources/tests/register_source.py index 47a6465f..dd0100db 100644 --- a/torax/sources/tests/register_source.py +++ b/torax/sources/tests/register_source.py @@ -13,7 +13,6 @@ # limitations under the License. """Tests for the source registry.""" - from absl.testing import absltest from absl.testing import parameterized from torax.sources import bootstrap_current_source @@ -27,6 +26,7 @@ from torax.sources import ohmic_heat_source from torax.sources import qei_source from torax.sources import register_source +from torax.sources import source as source_lib from torax.sources import source_models as source_models_lib @@ -48,10 +48,21 @@ class SourceTest(parameterized.TestCase): ) def test_sources_in_registry_build_successfully(self, source_name: str): """Test that all sources in the registry build successfully.""" - registered_source = register_source.get_registered_source(source_name) + registered_source = register_source.get_supported_source(source_name) source_class = registered_source.source_class - source_runtime_params_class = registered_source.default_runtime_params_class - source_builder_class = registered_source.source_builder_class + model_function = registered_source.model_functions[ + source_class.DEFAULT_MODEL_FUNCTION_NAME + ] + source_builder_class = model_function.source_builder_class + source_runtime_params_class = model_function.runtime_params_class + if source_builder_class is None: + source_builder_class = source_lib.make_source_builder( + registered_source.source_class, + runtime_params_type=source_runtime_params_class, + links_back=model_function.links_back, + model_func=model_function.source_profile_function, + ) + source_runtime_params_class = model_function.runtime_params_class source_builder = source_builder_class() self.assertIsInstance( source_builder.runtime_params, source_runtime_params_class diff --git a/torax/sources/tests/source.py b/torax/sources/tests/source.py index da68237a..04a08caf 100644 --- a/torax/sources/tests/source.py +++ b/torax/sources/tests/source.py @@ -45,12 +45,21 @@ class PsiTestSource(source_lib.Source): def affected_core_profiles(self): return (source_lib.AffectedCoreProfile.PSI,) + @property + def source_name(self) -> str: + return 'foo' + PsiTestSourceBuilder = source_lib.make_source_builder(PsiTestSource) @dataclasses.dataclass(frozen=True, eq=True) class IonElTestSource(source_lib.Source): + """Test source that affects ion and electron profiles.""" + + @property + def source_name(self) -> str: + return 'foo' @property def affected_core_profiles(self): @@ -59,14 +68,6 @@ def affected_core_profiles(self): source_lib.AffectedCoreProfile.TEMP_EL, ) - @property - def supported_modes(self) -> tuple[runtime_params_lib.Mode, ...]: - return ( - runtime_params_lib.Mode.FORMULA_BASED, - runtime_params_lib.Mode.MODEL_BASED, - runtime_params_lib.Mode.PRESCRIBED, - ) - IonElTestSourceBuilder = source_lib.make_source_builder(IonElTestSource) @@ -147,6 +148,7 @@ class NotEq: @dataclasses.dataclass(frozen=True, eq=True) class MySource: my_field: int + model_func: source_lib.SourceProfileFunction | None = None # pylint doesn't realize this is a class MySourceBuilder = source_lib.make_source_builder( # pylint: disable=invalid-name @@ -194,11 +196,7 @@ def test_zero_profile_works_by_default(self): ) profile = source.get_value( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ - 'foo' - ], static_runtime_params_slice=static_slice, - static_source_runtime_params=static_slice.sources['foo'], geo=geo, core_profiles=core_profiles, ) @@ -207,88 +205,24 @@ def test_zero_profile_works_by_default(self): get_zero_profile(source_lib.ProfileType.CELL, geo), ) - def test_unsupported_modes_raise_errors(self): - """Calling with an unsupported type should raise an error.""" - - class TestSource(source_lib.Source): - """A test source.""" - - @property - def affected_core_profiles( - self, - ) -> tuple[source_lib.AffectedCoreProfile, ...]: - return (source_lib.AffectedCoreProfile.NE,) - - @property - def supported_modes(self) -> tuple[runtime_params_lib.Mode, ...]: - return (runtime_params_lib.Mode.FORMULA_BASED,) - - source_builder = source_lib.make_source_builder(TestSource)() - # But set the runtime params of the source to use ZERO as the mode. - source_builder.runtime_params.mode = runtime_params_lib.Mode.ZERO - source_models_builder = source_models_lib.SourceModelsBuilder( - {'foo': source_builder}, - ) - source_models = source_models_builder() - source = source_models.sources['foo'] - runtime_params = general_runtime_params.GeneralRuntimeParams() - geo = geometry.build_circular_geometry() - dynamic_runtime_params_slice = ( - runtime_params_slice.DynamicRuntimeParamsSliceProvider( - runtime_params, - sources=source_models_builder.runtime_params, - torax_mesh=geo.torax_mesh, - )( - t=runtime_params.numerics.t_initial, - ) - ) - static_slice = runtime_params_slice.build_static_runtime_params_slice( - runtime_params, - source_runtime_params=source_models_builder.runtime_params, - ) - core_profiles = core_profile_setters.initial_core_profiles( - dynamic_runtime_params_slice=dynamic_runtime_params_slice, - static_runtime_params_slice=static_slice, - geo=geo, - source_models=source_models, - ) - # But calling requesting ZERO shouldn't work. - with self.assertRaises(ValueError): - source.get_value( - dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ - 'foo' - ], - static_runtime_params_slice=static_slice, - static_source_runtime_params=static_slice.sources['foo'], - geo=geo, - core_profiles=core_profiles, - ) - @parameterized.parameters( (runtime_params_lib.Mode.ZERO, np.array([0, 0, 0, 0])), - (runtime_params_lib.Mode.MODEL_BASED, np.array([1, 1, 1, 1])), - (runtime_params_lib.Mode.FORMULA_BASED, np.array([2, 2, 2, 2])), + (runtime_params_lib.Mode.MODEL_BASED, np.array([2, 2, 2, 2])), (runtime_params_lib.Mode.PRESCRIBED, np.array([3, 3, 3, 3])), ) def test_correct_mode_called(self, mode, expected_profile): """The correct mode should be called.""" - source_builder = test_lib.TestSourceBuilder() + source_builder = source_lib.make_source_builder( + test_lib.TestSource, + model_func=lambda _0, _1, _2, _3, _4, _5: jnp.ones( + source_lib.ProfileType.CELL.get_profile_shape(geo) + ) * 2, + )() source_models_builder = source_models_lib.SourceModelsBuilder( {'foo': source_builder}, ) source_models = source_models_builder() source = source_models.sources['foo'] - source = dataclasses.replace( - source, - model_func=lambda _0, _1, _2, _3, _4, _5, _6: jnp.ones( - source_lib.ProfileType.CELL.get_profile_shape(geo) - ), - formula=lambda _0, _1, _2, _3, _4, _5, _6: jnp.ones( - source_lib.ProfileType.CELL.get_profile_shape(geo) - ) - * 2, - ) source_runtime_params = source_models_builder.runtime_params runtime_params = general_runtime_params.GeneralRuntimeParams() geo = geometry.build_circular_geometry(n_rho=4) @@ -318,11 +252,7 @@ def test_correct_mode_called(self, mode, expected_profile): ) profile = source.get_value( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ - 'foo' - ], static_runtime_params_slice=static_slice, - static_source_runtime_params=static_slice.sources['foo'], geo=geo, core_profiles=core_profiles, ) @@ -370,44 +300,13 @@ def test_defaults_output_zeros(self): ) }, ) - profile = source.get_value( - dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ - 'foo' - ], - static_runtime_params_slice=static_slice, - static_source_runtime_params=static_slice.sources['foo'], - geo=geo, - core_profiles=core_profiles, - ) - np.testing.assert_allclose( - profile, - get_zero_profile(source_lib.ProfileType.CELL, geo), - ) - with self.subTest('formula'): - static_slice = runtime_params_slice.build_static_runtime_params_slice( - runtime_params, - source_runtime_params={ - 'foo': dataclasses.replace( - source_builder.runtime_params, - mode=runtime_params_lib.Mode.FORMULA_BASED, - ) - }, - ) - profile = source.get_value( - dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ - 'foo' - ], - static_runtime_params_slice=static_slice, - static_source_runtime_params=static_slice.sources['foo'], - geo=geo, - core_profiles=core_profiles, - ) - np.testing.assert_allclose( - profile, - get_zero_profile(source_lib.ProfileType.CELL, geo), - ) + with self.assertRaises(ValueError): + source.get_value( + dynamic_runtime_params_slice=dynamic_runtime_params_slice, + static_runtime_params_slice=static_slice, + geo=geo, + core_profiles=core_profiles, + ) with self.subTest('prescribed'): static_slice = runtime_params_slice.build_static_runtime_params_slice( runtime_params, @@ -420,11 +319,7 @@ def test_defaults_output_zeros(self): ) profile = source.get_value( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ - 'foo' - ], static_runtime_params_slice=static_slice, - static_source_runtime_params=static_slice.sources['foo'], geo=geo, core_profiles=core_profiles, ) @@ -433,60 +328,15 @@ def test_defaults_output_zeros(self): get_zero_profile(source_lib.ProfileType.CELL, geo), ) - def test_overriding_default_formula(self): - """The user-specified formula should override the default formula.""" - geo = geometry.build_circular_geometry() - output_shape = source_lib.ProfileType.CELL.get_profile_shape(geo) - expected_output = jnp.ones(output_shape) - source_builder = IonElTestSourceBuilder( - formula=lambda _0, _1, _2, _3, _4, _5, _6: expected_output, - ) - source_builder.runtime_params.mode = runtime_params_lib.Mode.FORMULA_BASED - source_models_builder = source_models_lib.SourceModelsBuilder( - {'foo': source_builder}, - ) - source_models = source_models_builder() - source = source_models.sources['foo'] - runtime_params = general_runtime_params.GeneralRuntimeParams() - dynamic_runtime_params_slice = ( - runtime_params_slice.DynamicRuntimeParamsSliceProvider( - runtime_params, - sources=source_models_builder.runtime_params, - torax_mesh=geo.torax_mesh, - )( - t=runtime_params.numerics.t_initial, - ) - ) - static_slice = runtime_params_slice.build_static_runtime_params_slice( - runtime_params, - source_runtime_params=source_models_builder.runtime_params, - ) - core_profiles = core_profile_setters.initial_core_profiles( - dynamic_runtime_params_slice=dynamic_runtime_params_slice, - static_runtime_params_slice=static_slice, - geo=geo, - source_models=source_models, - ) - profile = source.get_value( - dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ - 'foo' - ], - static_runtime_params_slice=static_slice, - static_source_runtime_params=static_slice.sources['foo'], - geo=geo, - core_profiles=core_profiles, - ) - np.testing.assert_allclose(profile, expected_output) - def test_overriding_model(self): """The user-specified model should override the default model.""" geo = geometry.build_circular_geometry() output_shape = source_lib.ProfileType.CELL.get_profile_shape(geo) expected_output = jnp.ones(output_shape) - source_builder = IonElTestSourceBuilder( - model_func=lambda _0, _1, _2, _3, _4, _5, _6: expected_output, - ) + source_builder = source_lib.make_source_builder( + IonElTestSource, + model_func=lambda _0, _1, _2, _3, _4, _5: expected_output, + )() source_builder.runtime_params.mode = runtime_params_lib.Mode.MODEL_BASED source_models_builder = source_models_lib.SourceModelsBuilder( {'foo': source_builder}, @@ -515,11 +365,7 @@ def test_overriding_model(self): ) profile = source.get_value( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ - 'foo' - ], static_runtime_params_slice=static_slice, - static_source_runtime_params=static_slice.sources['foo'], geo=geo, core_profiles=core_profiles, ) @@ -565,11 +411,7 @@ def test_overriding_prescribed_values(self): ) profile = source.get_value( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ - 'foo' - ], static_runtime_params_slice=static_slice, - static_source_runtime_params=static_slice.sources['foo'], geo=geo, core_profiles=core_profiles, ) @@ -583,6 +425,10 @@ def test_retrieving_profile_for_affected_state(self): class TestSource(source_lib.Source): output_shape_getter = lambda _0: output_shape + @property + def source_name(self) -> str: + return 'foo' + @property def affected_core_profiles(self): return ( @@ -592,7 +438,7 @@ def affected_core_profiles(self): profile = jnp.asarray([[1, 2, 3, 4], [5, 6, 7, 8]]) # from get_value() source = TestSource( - model_func=lambda _0, _1, _2, _3, _4, _5, _6: profile, + model_func=lambda _0, _1, _2, _3, _4, _5: profile, ) geo = geometry.build_circular_geometry(n_rho=4) psi_profile = source.get_source_profile_for_affected_core_profile( @@ -616,57 +462,12 @@ def affected_core_profiles(self): class SingleProfileSourceTest(parameterized.TestCase): """Tests for SingleProfileSource.""" - def test_custom_formula(self): - """The user-specified formula should override the default formula.""" - runtime_params = general_runtime_params.GeneralRuntimeParams() - geo = geometry.build_circular_geometry(n_rho=5) - expected_output = jnp.ones((5)) # 5 matches the geo. - source_builder = PsiTestSourceBuilder( - formula=lambda _0, _1, _2, _3, _4, _5, _6: expected_output, - ) - source_builder.runtime_params.mode = runtime_params_lib.Mode.FORMULA_BASED - source_models_builder = source_models_lib.SourceModelsBuilder( - {'foo': source_builder}, - ) - source_models = source_models_builder() - source = source_models.sources['foo'] - dynamic_runtime_params_slice = ( - runtime_params_slice.DynamicRuntimeParamsSliceProvider( - runtime_params, - sources=source_models_builder.runtime_params, - torax_mesh=geo.torax_mesh, - )( - t=runtime_params.numerics.t_initial, - ) - ) - static_slice = runtime_params_slice.build_static_runtime_params_slice( - runtime_params, - source_runtime_params=source_models_builder.runtime_params, - ) - core_profiles = core_profile_setters.initial_core_profiles( - dynamic_runtime_params_slice=dynamic_runtime_params_slice, - static_runtime_params_slice=static_slice, - geo=geo, - source_models=source_models, - ) - profile = source.get_value( - dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ - 'foo' - ], - static_runtime_params_slice=static_slice, - static_source_runtime_params=static_slice.sources['foo'], - geo=geo, - core_profiles=core_profiles, - ) - np.testing.assert_allclose(profile, expected_output) - def test_retrieving_profile_for_affected_state(self): """Grabbing the correct profile works for all mesh state attributes.""" profile = jnp.asarray([1, 2, 3, 4]) # from get_value() source = test_lib.TestSource( - model_func=lambda _0, _1, _2, _3, _4, _5, _6: profile, + model_func=lambda _0, _1, _2, _3, _4, _5: profile, ) geo = geometry.build_circular_geometry(n_rho=4) psi_profile = source.get_source_profile_for_affected_core_profile( diff --git a/torax/sources/tests/source_models.py b/torax/sources/tests/source_models.py index fef76763..2aa969ef 100644 --- a/torax/sources/tests/source_models.py +++ b/torax/sources/tests/source_models.py @@ -37,6 +37,10 @@ class FooSource(source_lib.Source): """A test source.""" + @property + def source_name(self) -> str: + return 'foo' + @property def affected_core_profiles( self, @@ -50,10 +54,6 @@ def affected_core_profiles( def output_shape_getter(self) -> source_lib.SourceOutputShapeFunction: return source_lib.get_ion_el_output_shape - @property - def supported_modes(self) -> tuple[source_runtime_params_lib.Mode, ...]: - return (source_runtime_params_lib.Mode.FORMULA_BASED,) - _FooSourceBuilder = source_lib.make_source_builder( FooSource, @@ -175,10 +175,9 @@ def test_custom_source_profiles_dont_change_when_jitted(self): def foo_formula( unused_dcs, - unused_sc, unused_static_runtime_params_slice, - unused_static_source_runtime_params, geo: geometry.Geometry, + unused_source_name: str, unused_state, unused_source_models, ): @@ -187,16 +186,12 @@ def foo_formula( jnp.ones(source_lib.ProfileType.CELL.get_profile_shape(geo)), ]) - foo_source_builder = _FooSourceBuilder( - formula=foo_formula, - ) - foo_source_builder.affected_core_profiles = ( - source_lib.AffectedCoreProfile.TEMP_EL, - source_lib.AffectedCoreProfile.NE, - ) - # Set the source mode to FORMULA. + foo_source_builder = source_lib.make_source_builder( + FooSource, model_func=foo_formula + )() + # Set the source mode to MODEL_BASED. foo_source_builder.runtime_params.mode = ( - source_runtime_params_lib.Mode.FORMULA_BASED + source_runtime_params_lib.Mode.MODEL_BASED ) source_models_builder = source_models_lib.SourceModelsBuilder( {source_name: foo_source_builder}, diff --git a/torax/sources/tests/test_lib.py b/torax/sources/tests/test_lib.py index 28ae9792..8ca38384 100644 --- a/torax/sources/tests/test_lib.py +++ b/torax/sources/tests/test_lib.py @@ -14,7 +14,7 @@ """Utilities to help with testing sources.""" -from typing import Sequence, Type +from typing import Type from absl.testing import parameterized import chex @@ -37,21 +37,16 @@ class TestSource(source_lib.Source): """A test source.""" + @property + def source_name(self) -> str: + return 'foo' + @property def affected_core_profiles( self, ) -> tuple[source_lib.AffectedCoreProfile, ...]: return (source_lib.AffectedCoreProfile.NE,) - @property - def supported_modes(self) -> tuple[runtime_params_lib.Mode, ...]: - return ( - runtime_params_lib.Mode.ZERO, - runtime_params_lib.Mode.FORMULA_BASED, - runtime_params_lib.Mode.MODEL_BASED, - runtime_params_lib.Mode.PRESCRIBED, - ) - TestSourceBuilder = source_lib.make_source_builder(TestSource) @@ -65,26 +60,33 @@ class SourceTestCase(parameterized.TestCase): _source_class: Type[source_lib.Source] _source_class_builder: source_lib.SourceBuilderProtocol _config_attr_name: str - _unsupported_modes: Sequence[runtime_params_lib.Mode] + _source_name: str + _runtime_params_class: Type[runtime_params_lib.RuntimeParams] @classmethod def setUpClass( cls, source_class: Type[source_lib.Source], runtime_params_class: Type[runtime_params_lib.RuntimeParams], - unsupported_modes: Sequence[runtime_params_lib.Mode], + source_name: str, + model_func: source_lib.SourceProfileFunction | None, links_back: bool = False, + source_class_builder: source_lib.SourceBuilderProtocol | None = None, ): super().setUpClass() cls._source_class = source_class - cls._source_class_builder = source_lib.make_source_builder( - source_type=source_class, - runtime_params_type=runtime_params_class, - links_back=links_back, - ) + if source_class_builder is None: + cls._source_class_builder = source_lib.make_source_builder( + source_type=source_class, + runtime_params_type=runtime_params_class, + links_back=links_back, + model_func=model_func, + ) + else: + cls._source_class_builder = source_class_builder cls._runtime_params_class = runtime_params_class - cls._unsupported_modes = unsupported_modes cls._links_back = links_back + cls._source_name = source_name def test_runtime_params_builds_dynamic_params(self): runtime_params = self._runtime_params_class() @@ -100,7 +102,6 @@ def test_runtime_params_builds_dynamic_params(self): mode=( runtime_params_lib.Mode.ZERO, runtime_params_lib.Mode.MODEL_BASED, - runtime_params_lib.Mode.FORMULA_BASED, runtime_params_lib.Mode.PRESCRIBED, ), is_explicit=(True, False), @@ -127,17 +128,17 @@ def test_source_value(self): # SingleProfileSource subclasses should have default names and be # instantiable without any __init__ arguments. # pylint: disable=missing-kwoa - source_builder = self._source_class_builder() # pytype: disable=missing-parameter + source_builder = self._source_class_builder() if not source_lib.is_source_builder(source_builder): raise TypeError(f'{type(self)} has a bad _source_class_builder') # pylint: enable=missing-kwoa runtime_params = general_runtime_params.GeneralRuntimeParams() source_models_builder = source_models_lib.SourceModelsBuilder( - {'foo': source_builder}, + {self._source_name: source_builder}, ) source_models = source_models_builder() - source = source_models.sources['foo'] - source_builder.runtime_params.mode = source.supported_modes[0] + source = source_models.sources[self._source_name] + source_builder.runtime_params.mode = runtime_params_lib.Mode.MODEL_BASED self.assertIsInstance(source, source_lib.Source) geo = geometry.build_circular_geometry() dynamic_runtime_params_slice = ( @@ -161,67 +162,12 @@ def test_source_value(self): ) value = source.get_value( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ - 'foo' - ], static_runtime_params_slice=static_slice, - static_source_runtime_params=static_slice.sources['foo'], geo=geo, core_profiles=core_profiles, ) chex.assert_rank(value, 1) - def test_invalid_source_types_raise_errors(self): - """Tests that using unsupported types raises an error.""" - runtime_params = general_runtime_params.GeneralRuntimeParams() - geo = geometry.build_circular_geometry() - # pylint: disable=missing-kwoa - source_builder = self._source_class_builder() # pytype: disable=missing-parameter - # pylint: enable=missing-kwoa - source_models_builder = source_models_lib.SourceModelsBuilder( - {'foo': source_builder}, - ) - source_models = source_models_builder() - source = source_models.sources['foo'] - self.assertIsInstance(source, source_lib.Source) - dynamic_runtime_params_slice = ( - runtime_params_slice.DynamicRuntimeParamsSliceProvider( - runtime_params=runtime_params, - sources=source_models_builder.runtime_params, - torax_mesh=geo.torax_mesh, - )( - t=runtime_params.numerics.t_initial, - ) - ) - static_slice = runtime_params_slice.build_static_runtime_params_slice( - runtime_params, - source_runtime_params=source_models_builder.runtime_params, - ) - core_profiles = core_profile_setters.initial_core_profiles( - dynamic_runtime_params_slice=dynamic_runtime_params_slice, - static_runtime_params_slice=static_slice, - geo=geo, - source_models=source_models, - ) - for unsupported_mode in self._unsupported_modes: - source_builder.runtime_params.mode = unsupported_mode - static_slice = runtime_params_slice.build_static_runtime_params_slice( - runtime_params, - source_runtime_params=source_models_builder.runtime_params, - ) - with self.subTest(unsupported_mode.name): - with self.assertRaises(ValueError): - source.get_value( - dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ - 'foo' - ], - static_runtime_params_slice=static_slice, - static_source_runtime_params=static_slice.sources['foo'], - geo=geo, - core_profiles=core_profiles, - ) - class IonElSourceTestCase(SourceTestCase): """Base test class for IonElSource subclasses.""" @@ -229,15 +175,15 @@ class IonElSourceTestCase(SourceTestCase): def test_source_value(self): """Tests that the source can provide a value by default.""" # pylint: disable=missing-kwoa - source_builder = self._source_class_builder() # pytype: disable=missing-parameter + source_builder = self._source_class_builder() # pylint: enable=missing-kwoa runtime_params = general_runtime_params.GeneralRuntimeParams() geo = geometry.build_circular_geometry() source_models_builder = source_models_lib.SourceModelsBuilder( - {'foo': source_builder}, + {self._source_name: source_builder}, ) source_models = source_models_builder() - source = source_models.sources['foo'] + source = source_models.sources[self._source_name] self.assertIsInstance(source, source_lib.Source) dynamic_runtime_params_slice = ( runtime_params_slice.DynamicRuntimeParamsSliceProvider( @@ -260,67 +206,12 @@ def test_source_value(self): ) ion_and_el = source.get_value( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ - 'foo' - ], static_runtime_params_slice=static_slice, - static_source_runtime_params=static_slice.sources['foo'], geo=geo, core_profiles=core_profiles, ) chex.assert_rank(ion_and_el, 2) - def test_invalid_source_types_raise_errors(self): - """Tests that using unsupported types raises an error.""" - runtime_params = general_runtime_params.GeneralRuntimeParams() - geo = geometry.build_circular_geometry() - # pylint: disable=missing-kwoa - source_builder = self._source_class_builder() # pytype: disable=missing-parameter - # pylint: enable=missing-kwoa - source_models_builder = source_models_lib.SourceModelsBuilder( - {'foo': source_builder}, - ) - source_models = source_models_builder() - source = source_models.sources['foo'] - self.assertIsInstance(source, source_lib.Source) - dynamic_runtime_params_slice = ( - runtime_params_slice.DynamicRuntimeParamsSliceProvider( - runtime_params=runtime_params, - sources=source_models_builder.runtime_params, - torax_mesh=geo.torax_mesh, - )( - t=runtime_params.numerics.t_initial, - ) - ) - static_slice = runtime_params_slice.build_static_runtime_params_slice( - runtime_params, - source_runtime_params=source_models_builder.runtime_params, - ) - core_profiles = core_profile_setters.initial_core_profiles( - dynamic_runtime_params_slice=dynamic_runtime_params_slice, - static_runtime_params_slice=static_slice, - geo=geo, - source_models=source_models, - ) - for unsupported_mode in self._unsupported_modes: - source_builder.runtime_params.mode = unsupported_mode - static_slice = runtime_params_slice.build_static_runtime_params_slice( - runtime_params, - source_runtime_params=source_models_builder.runtime_params, - ) - with self.subTest(unsupported_mode.name): - with self.assertRaises(ValueError): - source.get_value( - dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ - 'foo' - ], - static_runtime_params_slice=static_slice, - static_source_runtime_params=static_slice.sources['foo'], - geo=geo, - core_profiles=core_profiles, - ) - def test_extraction_of_relevant_profile_from_output(self): """Tests that the relevant profile is extracted from the output.""" geo = geometry.build_circular_geometry() diff --git a/torax/tests/physics.py b/torax/tests/physics.py index 085f3d98..aba8fb06 100644 --- a/torax/tests/physics.py +++ b/torax/tests/physics.py @@ -16,6 +16,7 @@ import dataclasses from typing import Callable + from absl.testing import absltest from absl.testing import parameterized import jax @@ -27,10 +28,12 @@ from torax.config import runtime_params_slice from torax.fvm import cell_variable from torax.geometry import geometry +from torax.sources import generic_current_source from torax.sources import runtime_params as source_runtime_params from torax.sources import source_models as source_models_lib from torax.tests.test_lib import torax_refs + _trapz = jax.scipy.integrate.trapezoid @@ -108,9 +111,9 @@ def test_update_psi_from_j( runtime_params = references.runtime_params source_models_builder = source_models_lib.SourceModelsBuilder() # Turn on the external current source. - source_models_builder.runtime_params['generic_current_source'].mode = ( - source_runtime_params.Mode.FORMULA_BASED - ) + source_models_builder.runtime_params[ + generic_current_source.GenericCurrentSource.SOURCE_NAME + ].mode = source_runtime_params.Mode.MODEL_BASED source_models = source_models_builder() dynamic_runtime_params_slice, geo = ( torax_refs.build_consistent_dynamic_runtime_params_slice_and_geometry( diff --git a/torax/tests/sim_custom_sources.py b/torax/tests/sim_custom_sources.py index a34c0ed4..360c6859 100644 --- a/torax/tests/sim_custom_sources.py +++ b/torax/tests/sim_custom_sources.py @@ -35,6 +35,7 @@ from torax.pedestal_model import set_tped_nped from torax.sources import electron_density_sources from torax.sources import runtime_params as runtime_params_lib +from torax.sources import source as source_lib from torax.sources.tests import test_lib from torax.stepper import linear_theta_method from torax.tests.test_lib import default_sources @@ -94,64 +95,36 @@ def test_custom_ne_source_can_replace_defaults(self): # For this example, use test_particle_sources_constant with the linear # stepper. - custom_source_name = 'custom_ne_source' + custom_source_name = 'foo' def custom_source_formula( static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, - static_source_runtime_params: runtime_params_lib.RuntimeParams, dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, - dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, + unused_source_name: str, unused_state: state_lib.CoreProfiles | None, unused_source_models: ..., ): # Combine the outputs. - assert isinstance( - dynamic_source_runtime_params, _CustomSourceDynamicRuntimeParams - ) - ignored_default_kwargs = dict( - formula=dynamic_source_runtime_params.formula, - prescribed_values=dynamic_source_runtime_params.prescribed_values, - ) - puff_params = electron_density_sources.DynamicGasPuffRuntimeParams( - puff_decay_length=dynamic_source_runtime_params.puff_decay_length, - S_puff_tot=dynamic_source_runtime_params.S_puff_tot, - **ignored_default_kwargs, - ) - params = electron_density_sources.DynamicParticleRuntimeParams( - deposition_location=dynamic_source_runtime_params.deposition_location, - particle_width=dynamic_source_runtime_params.particle_width, - S_tot=dynamic_source_runtime_params.S_tot, - **ignored_default_kwargs, - ) - pellet_params = electron_density_sources.DynamicPelletRuntimeParams( - pellet_deposition_location=dynamic_source_runtime_params.pellet_deposition_location, - pellet_width=dynamic_source_runtime_params.pellet_width, - S_pellet_tot=dynamic_source_runtime_params.S_pellet_tot, - **ignored_default_kwargs, - ) # pylint: disable=protected-access return ( - electron_density_sources._calc_puff_source( + electron_density_sources.calc_puff_source( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=puff_params, static_runtime_params_slice=static_runtime_params_slice, - static_source_runtime_params=static_source_runtime_params, geo=geo, + source_name=electron_density_sources.GasPuffSource.SOURCE_NAME, ) - + electron_density_sources._calc_generic_particle_source( + + electron_density_sources.calc_generic_particle_source( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=params, static_runtime_params_slice=static_runtime_params_slice, - static_source_runtime_params=static_source_runtime_params, geo=geo, + source_name=electron_density_sources.GenericParticleSource.SOURCE_NAME, ) - + electron_density_sources._calc_pellet_source( + + electron_density_sources.calc_pellet_source( dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=pellet_params, static_runtime_params_slice=static_runtime_params_slice, - static_source_runtime_params=static_source_runtime_params, geo=geo, + source_name=electron_density_sources.PelletSource.SOURCE_NAME, ) ) # pylint: enable=protected-access @@ -189,20 +162,25 @@ def custom_source_formula( # Add the custom source with the correct params, but keep it turned off to # start. + source_builder = source_lib.make_source_builder( + test_lib.TestSource, + runtime_params_type=_CustomSourceRuntimeParams, + model_func=custom_source_formula, + ) + runtime_params = _CustomSourceRuntimeParams( + mode=runtime_params_lib.Mode.ZERO, + puff_decay_length=gas_puff_params.puff_decay_length, + S_puff_tot=gas_puff_params.S_puff_tot, + particle_width=params.particle_width, + deposition_location=params.deposition_location, + S_tot=params.S_tot, + pellet_width=pellet_params.pellet_width, + pellet_deposition_location=pellet_params.pellet_deposition_location, + S_pellet_tot=pellet_params.S_pellet_tot, + ) source_models_builder.source_builders[custom_source_name] = ( - test_lib.TestSourceBuilder( - formula=custom_source_formula, - runtime_params=_CustomSourceRuntimeParams( - mode=runtime_params_lib.Mode.ZERO, - puff_decay_length=gas_puff_params.puff_decay_length, - S_puff_tot=gas_puff_params.S_puff_tot, - particle_width=params.particle_width, - deposition_location=params.deposition_location, - S_tot=params.S_tot, - pellet_width=pellet_params.pellet_width, - pellet_deposition_location=pellet_params.pellet_deposition_location, - S_pellet_tot=pellet_params.S_pellet_tot, - ), + source_builder( + runtime_params=runtime_params, ) ) @@ -241,16 +219,12 @@ def custom_source_formula( params.mode = runtime_params_lib.Mode.ZERO pellet_params.mode = runtime_params_lib.Mode.ZERO gas_puff_params.mode = runtime_params_lib.Mode.ZERO - source_models_builder.runtime_params[custom_source_name].mode = ( - runtime_params_lib.Mode.FORMULA_BASED - ) + runtime_params.mode = runtime_params_lib.Mode.MODEL_BASED self._run_sim_and_check(sim, ref_profiles, ref_time) with self.subTest('without_defaults_and_without_custom_source'): # Confirm that the custom source actual has an effect. - source_models_builder.runtime_params[custom_source_name].mode = ( - runtime_params_lib.Mode.ZERO - ) + runtime_params.mode = runtime_params_lib.Mode.ZERO with self.assertRaises(AssertionError): self._run_sim_and_check(sim, ref_profiles, ref_time) @@ -311,7 +285,6 @@ def make_provider( raise ValueError('torax_mesh is required for CustomSourceRuntimeParams.') return _CustomSourceRuntimeParamsProvider( runtime_params_config=self, - formula=self.formula.make_provider(torax_mesh), prescribed_values=config_args.get_interpolated_var_2d( self.prescribed_values, torax_mesh.cell_centers ), diff --git a/torax/tests/sim_output_source_profiles.py b/torax/tests/sim_output_source_profiles.py index 68764537..e68b18f7 100644 --- a/torax/tests/sim_output_source_profiles.py +++ b/torax/tests/sim_output_source_profiles.py @@ -38,6 +38,7 @@ from torax.geometry import geometry_provider as geometry_provider_lib from torax.pedestal_model import set_tped_nped from torax.sources import runtime_params as runtime_params_lib +from torax.sources import source as source_lib from torax.sources import source_models as source_models_lib from torax.sources import source_profiles as source_profiles_lib from torax.sources.tests import test_lib @@ -52,6 +53,30 @@ _ALL_PROFILES = ('temp_ion', 'temp_el', 'psi', 'q_face', 's_face', 'ne') +class TestImplicitNeSource(test_lib.TestSource): + """A test source.""" + + @property + def source_name(self) -> str: + return 'implicit_ne_source' + + +class TestExplicitNeSource(test_lib.TestSource): + """A test source.""" + + @property + def source_name(self) -> str: + return 'explicit_ne_source' + + +TestImplicitNeSourceBuilder = source_lib.make_source_builder( + TestImplicitNeSource +) +TestExplicitNeSourceBuilder = source_lib.make_source_builder( + TestExplicitNeSource +) + + class SimOutputSourceProfilesTest(sim_test_case.SimTestCase): """Tests checking the output core_sources profiles from run_simulation().""" @@ -101,28 +126,31 @@ def test_first_and_last_source_profiles(self): # This is not physically realistic, just for testing purposes. def custom_source_formula( unused_static_runtime_params_slice, - unused_static_source_runtime_params, - unused_dynamic_runtime_params, - source_conf, + dynamic_runtime_params, geo, + source_name, unused_state, unused_source_models, ): - return jnp.ones_like(geo.rho) * source_conf.foo + dynamic_source_params = dynamic_runtime_params.sources[source_name] + return jnp.ones_like(geo.rho) * dynamic_source_params.foo # Include 2 versions of this source, one implicit and one explicit. + source_builder = source_lib.make_source_builder( + TestImplicitNeSource, + runtime_params_type=_FakeSourceRuntimeParams, + model_func=custom_source_formula, + ) source_models_builder = source_models_lib.SourceModelsBuilder({ - 'implicit_ne_source': test_lib.TestSourceBuilder( - formula=custom_source_formula, + 'implicit_ne_source': source_builder( runtime_params=_FakeSourceRuntimeParams( - mode=runtime_params_lib.Mode.FORMULA_BASED, + mode=runtime_params_lib.Mode.MODEL_BASED, foo={0.0: 1.0, 1.0: 2.0, 2.0: 3.0, 3.0: 4.0}, ), ), - 'explicit_ne_source': test_lib.TestSourceBuilder( - formula=custom_source_formula, + 'explicit_ne_source': source_builder( runtime_params=_FakeSourceRuntimeParams( - mode=runtime_params_lib.Mode.FORMULA_BASED, + mode=runtime_params_lib.Mode.MODEL_BASED, foo={0.0: 1.0, 1.0: 2.0, 2.0: 3.0, 3.0: 4.0}, ), ), @@ -249,7 +277,6 @@ def make_provider( raise ValueError('torax_mesh is required for FakeSourceRuntimeParams.') return _FakeSourceRuntimeParamsProvider( runtime_params_config=self, - formula=self.formula.make_provider(torax_mesh), prescribed_values=config_args.get_interpolated_var_2d( self.prescribed_values, torax_mesh.cell_centers ), diff --git a/torax/tests/state.py b/torax/tests/state.py index c55d3090..a6679903 100644 --- a/torax/tests/state.py +++ b/torax/tests/state.py @@ -397,13 +397,7 @@ def test_initial_psi_from_j( source_models = source_models_lib.SourceModels() bootstrap_profile = source_models.j_bootstrap.get_value( dynamic_runtime_params_slice=dcs3, - dynamic_source_runtime_params=dcs3.sources[ - source_models.j_bootstrap_name - ], static_runtime_params_slice=static_slice, - static_source_runtime_params=static_slice.sources[ - source_models.j_bootstrap_name - ], geo=geo, core_profiles=core_profiles3_helper, ) diff --git a/torax/tests/test_lib/default_sources.py b/torax/tests/test_lib/default_sources.py index fd45273a..959b6924 100644 --- a/torax/tests/test_lib/default_sources.py +++ b/torax/tests/test_lib/default_sources.py @@ -13,7 +13,9 @@ # limitations under the License. """Utilities to help with testing sources.""" + from torax.sources import register_source +from torax.sources import source as source_lib from torax.sources import source_models as source_models_lib @@ -60,9 +62,18 @@ def get_default_sources_builder() -> source_models_lib.SourceModelsBuilder: ] source_builders = {} for name in names: - registered_source = register_source.get_registered_source(name) - runtime_params = registered_source.default_runtime_params_class() - source_builders[name] = registered_source.source_builder_class( - runtime_params=runtime_params - ) + registered_source = register_source.get_supported_source(name) + model_function = registered_source.model_functions[ + registered_source.source_class.DEFAULT_MODEL_FUNCTION_NAME + ] + runtime_params = model_function.runtime_params_class() + source_builder_class = model_function.source_builder_class + if source_builder_class is None: + source_builder_class = source_lib.make_source_builder( + registered_source.source_class, + runtime_params_type=model_function.runtime_params_class, + links_back=model_function.links_back, + model_func=model_function.source_profile_function, + ) + source_builders[name] = source_builder_class(runtime_params=runtime_params) return source_models_lib.SourceModelsBuilder(source_builders)