diff --git a/torax/config/tests/build_sim.py b/torax/config/tests/build_sim.py index 39619751..1243ccf2 100644 --- a/torax/config/tests/build_sim.py +++ b/torax/config/tests/build_sim.py @@ -366,7 +366,7 @@ def test_adding_standard_source_via_config(self): # pytype: enable=attribute-error self.assertEqual( source_models_builder.runtime_params['gas_puff_source'].mode, - source_runtime_params_lib.Mode.FORMULA_BASED, # On by default. + source_runtime_params_lib.Mode.MODEL_BASED, # On by default. ) self.assertEqual( source_models_builder.runtime_params['ohmic_heat_source'].mode, diff --git a/torax/fvm/tests/fvm.py b/torax/fvm/tests/fvm.py index cb9f65c7..e78befef 100644 --- a/torax/fvm/tests/fvm.py +++ b/torax/fvm/tests/fvm.py @@ -580,7 +580,7 @@ def test_implicit_solve_block_uses_updated_boundary_conditions(self): ) ) geo = geometry.build_circular_geometry(n_rho=num_cells) - source_models = source_models_lib.SourceModels() + source_models = source_models_builder() initial_core_profiles = core_profile_setters.initial_core_profiles( static_runtime_params_slice, dynamic_runtime_params_slice, @@ -732,7 +732,7 @@ def test_theta_residual_uses_updated_boundary_conditions(self): ), ) - source_models = source_models_lib.SourceModels() + source_models = source_models_builder() pedestal_model = set_tped_nped.SetTemperatureDensityPedestalModel() initial_core_profiles = core_profile_setters.initial_core_profiles( static_runtime_params_slice_theta0, diff --git a/torax/sources/electron_density_sources.py b/torax/sources/electron_density_sources.py index 559c6110..4998df7c 100644 --- a/torax/sources/electron_density_sources.py +++ b/torax/sources/electron_density_sources.py @@ -41,7 +41,7 @@ class GasPuffRuntimeParams(runtime_params_lib.RuntimeParams): puff_decay_length: runtime_params_lib.TimeInterpolatedInput = 0.05 # total gas puff particles/s S_puff_tot: runtime_params_lib.TimeInterpolatedInput = 1e22 - mode: runtime_params_lib.Mode = runtime_params_lib.Mode.FORMULA_BASED + mode: runtime_params_lib.Mode = runtime_params_lib.Mode.MODEL_BASED def make_provider( self, @@ -110,7 +110,6 @@ class GasPuffSource(source.Source): SOURCE_NAME: ClassVar[str] = 'gas_puff_source' DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'calc_puff_source' - formula: source.SourceProfileFunction = calc_puff_source model_func: source.SourceProfileFunction = calc_puff_source @property @@ -132,7 +131,7 @@ class GenericParticleSourceRuntimeParams(runtime_params_lib.RuntimeParams): deposition_location: runtime_params_lib.TimeInterpolatedInput = 0.0 # total particle source S_tot: runtime_params_lib.TimeInterpolatedInput = 1e22 - mode: runtime_params_lib.Mode = runtime_params_lib.Mode.FORMULA_BASED + mode: runtime_params_lib.Mode = runtime_params_lib.Mode.MODEL_BASED def make_provider( self, @@ -208,7 +207,6 @@ class GenericParticleSource(source.Source): SOURCE_NAME: ClassVar[str] = 'generic_particle_source' DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'calc_generic_particle_source' - formula: source.SourceProfileFunction = calc_generic_particle_source model_func: source.SourceProfileFunction = calc_generic_particle_source @property @@ -232,7 +230,7 @@ class PelletRuntimeParams(runtime_params_lib.RuntimeParams): pellet_deposition_location: runtime_params_lib.TimeInterpolatedInput = 0.85 # total pellet particles/s (continuous pellet model) S_pellet_tot: runtime_params_lib.TimeInterpolatedInput = 2e22 - mode: runtime_params_lib.Mode = runtime_params_lib.Mode.FORMULA_BASED + mode: runtime_params_lib.Mode = runtime_params_lib.Mode.MODEL_BASED def make_provider( self, @@ -298,7 +296,6 @@ class PelletSource(source.Source): SOURCE_NAME: ClassVar[str] = 'pellet_source' DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'calc_pellet_source' - formula: source.SourceProfileFunction = calc_pellet_source model_func: source.SourceProfileFunction = calc_pellet_source @property diff --git a/torax/sources/generic_current_source.py b/torax/sources/generic_current_source.py index 9a9b73e4..ba39ad34 100644 --- a/torax/sources/generic_current_source.py +++ b/torax/sources/generic_current_source.py @@ -52,7 +52,7 @@ class RuntimeParams(runtime_params_lib.RuntimeParams): # Toggles if external current is provided absolutely or as a fraction of Ip. use_absolute_current: bool = False - mode: runtime_params_lib.Mode = runtime_params_lib.Mode.FORMULA_BASED + mode: runtime_params_lib.Mode = runtime_params_lib.Mode.MODEL_BASED @property def grid_type(self) -> base.GridType: @@ -236,7 +236,6 @@ class GenericCurrentSource(source.Source): SOURCE_NAME: ClassVar[str] = 'generic_current_source' DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'calc_generic_current_face' - formula: source.SourceProfileFunction = calculate_generic_current_face hires_formula: source.SourceProfileFunction = _calculate_generic_current_hires model_func: source.SourceProfileFunction = calculate_generic_current_face @@ -305,12 +304,8 @@ def generic_current_source_hires( geo=geo, core_profiles=None, # There is no model for this source. - model_func=( - lambda _0, _1, _2, _3, _4, _5: jnp.zeros_like( - geo.rho_hires_norm - ) - ), - formula=self.hires_formula, + model_func=self.hires_formula, + formula=None, output_shape=geo.rho_hires_norm.shape, prescribed_values=hires_prescribed_values, source_models=getattr(self, 'source_models', None), diff --git a/torax/sources/generic_ion_el_heat_source.py b/torax/sources/generic_ion_el_heat_source.py index be827002..f7eca7ba 100644 --- a/torax/sources/generic_ion_el_heat_source.py +++ b/torax/sources/generic_ion_el_heat_source.py @@ -48,7 +48,7 @@ class RuntimeParams(runtime_params_lib.RuntimeParams): Ptot: runtime_params_lib.TimeInterpolatedInput = 120e6 # electron heating fraction el_heat_fraction: runtime_params_lib.TimeInterpolatedInput = 0.66666 - mode: runtime_params_lib.Mode = runtime_params_lib.Mode.FORMULA_BASED + mode: runtime_params_lib.Mode = runtime_params_lib.Mode.MODEL_BASED def make_provider( self, @@ -151,7 +151,6 @@ class GenericIonElectronHeatSource(source.Source): SOURCE_NAME: ClassVar[str] = 'generic_ion_el_heat_source' DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'default_formula' - formula: source.SourceProfileFunction = default_formula model_func: source.SourceProfileFunction = default_formula @property diff --git a/torax/sources/tests/test_lib.py b/torax/sources/tests/test_lib.py index c42fb22e..e587a809 100644 --- a/torax/sources/tests/test_lib.py +++ b/torax/sources/tests/test_lib.py @@ -151,7 +151,7 @@ def test_source_value(self): ) source_models = source_models_builder() source = source_models.sources[self._source_name] - source_builder.runtime_params.mode = source.supported_modes[1] + source_builder.runtime_params.mode = runtime_params_lib.Mode.MODEL_BASED self.assertIsInstance(source, source_lib.Source) geo = geometry.build_circular_geometry() dynamic_runtime_params_slice = ( diff --git a/torax/tests/physics.py b/torax/tests/physics.py index b1cc06f7..c7291789 100644 --- a/torax/tests/physics.py +++ b/torax/tests/physics.py @@ -16,6 +16,7 @@ import dataclasses from typing import Callable + from absl.testing import absltest from absl.testing import parameterized import jax @@ -27,10 +28,12 @@ from torax.config import runtime_params_slice from torax.fvm import cell_variable from torax.geometry import geometry +from torax.sources import generic_current_source from torax.sources import runtime_params as source_runtime_params from torax.sources import source_models as source_models_lib from torax.tests.test_lib import torax_refs + _trapz = jax.scipy.integrate.trapezoid @@ -108,9 +111,9 @@ def test_update_psi_from_j( runtime_params = references.runtime_params source_models_builder = source_models_lib.SourceModelsBuilder() # Turn on the external current source. - source_models_builder.runtime_params['generic_current_source'].mode = ( - source_runtime_params.Mode.FORMULA_BASED - ) + source_models_builder.runtime_params[ + generic_current_source.GenericCurrentSource.SOURCE_NAME + ].mode = source_runtime_params.Mode.MODEL_BASED source_models = source_models_builder() dynamic_runtime_params_slice, geo = ( torax_refs.build_consistent_dynamic_runtime_params_slice_and_geometry(