diff --git a/torax/__init__.py b/torax/__init__.py index fb1eb0e4..9f1936f4 100644 --- a/torax/__init__.py +++ b/torax/__init__.py @@ -89,6 +89,7 @@ def set_jax_precision(): 'source_profiles', 'source_profile', 'explicit_source_profiles', + 'model_func', 'source_models', 'pedestal_model', 'time_step_calculator', diff --git a/torax/config/build_sim.py b/torax/config/build_sim.py index afaf7ddc..9e3d1823 100644 --- a/torax/config/build_sim.py +++ b/torax/config/build_sim.py @@ -410,8 +410,18 @@ def _build_single_source_builder_from_config( source_config: dict[str, Any], ) -> source_lib.SourceBuilderProtocol: """Builds a source builder from the input config.""" - registered_source = register_source.get_registered_source(source_name) - runtime_params = registered_source.default_runtime_params_class() + supported_source = register_source.get_supported_source(source_name) + if 'model_func' in source_config: + # If the user has specified a model function, try to retrive that from the + # registered source model functions. + model_func = source_config.pop('model_func') + model_function = supported_source.model_functions[model_func] + else: + # Otherwise, use the default model function. + model_function = supported_source.model_functions[ + supported_source.source_class.DEFAULT_MODEL_FUNCTION_NAME + ] + runtime_params = model_function.runtime_params_class() # Update the defaults with the config provided. source_config = copy.copy(source_config) if 'mode' in source_config: @@ -445,7 +455,16 @@ def _build_single_source_builder_from_config( if formula is not None: kwargs['formula'] = formula - return registered_source.source_builder_class(**kwargs) + source_builder_class = model_function.source_builder_class + if source_builder_class is None: + source_builder_class = source_lib.make_source_builder( + supported_source.source_class, + runtime_params_type=model_function.runtime_params_class, + links_back=model_function.links_back, + model_func=model_function.source_profile_function, + ) + + return source_builder_class(**kwargs) def build_transport_model_builder_from_config( diff --git a/torax/sources/bootstrap_current_source.py b/torax/sources/bootstrap_current_source.py index cf3caa12..8f41ff60 100644 --- a/torax/sources/bootstrap_current_source.py +++ b/torax/sources/bootstrap_current_source.py @@ -91,6 +91,7 @@ class BootstrapCurrentSource(source.Source): """ SOURCE_NAME: ClassVar[str] = 'j_bootstrap' + DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'calc_neoclassical' @property def source_name(self) -> str: diff --git a/torax/sources/bremsstrahlung_heat_sink.py b/torax/sources/bremsstrahlung_heat_sink.py index a2b85f37..7bc38cac 100644 --- a/torax/sources/bremsstrahlung_heat_sink.py +++ b/torax/sources/bremsstrahlung_heat_sink.py @@ -153,6 +153,7 @@ class BremsstrahlungHeatSink(source.Source): """Brehmsstrahlung heat sink for electron heat equation.""" SOURCE_NAME: ClassVar[str] = 'bremsstrahlung_heat_sink' + DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'bremsstrahlung_model_func' model_func: source.SourceProfileFunction = bremsstrahlung_model_func @property diff --git a/torax/sources/electron_cyclotron_source.py b/torax/sources/electron_cyclotron_source.py index 9a6e3535..e404c15d 100644 --- a/torax/sources/electron_cyclotron_source.py +++ b/torax/sources/electron_cyclotron_source.py @@ -103,13 +103,13 @@ class DynamicRuntimeParams(runtime_params_lib.DynamicRuntimeParams): gaussian_ec_total_power: array_typing.ScalarFloat -def _calc_heating_and_current( +def calc_heating_and_current( static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, source_name: str, core_profiles: state.CoreProfiles, - unused_model_func: source_models.SourceModels, + unused_source_models: source_models.SourceModels | None = None, ) -> jax.Array: """Model function for the electron-cyclotron source. @@ -128,7 +128,7 @@ def _calc_heating_and_current( 2D array of electron cyclotron heating power density and current density. """ del ( - unused_model_func, + unused_source_models, static_runtime_params_slice, ) # Unused. dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[ @@ -187,7 +187,8 @@ class ElectronCyclotronSource(source.Source): """Electron cyclotron source for the Te and Psi equations.""" SOURCE_NAME: ClassVar[str] = "electron_cyclotron_source" - model_func: source.SourceProfileFunction = _calc_heating_and_current + DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = "calc_heating_and_current" + model_func: source.SourceProfileFunction = calc_heating_and_current @property def source_name(self) -> str: diff --git a/torax/sources/electron_density_sources.py b/torax/sources/electron_density_sources.py index 70fba4ae..559c6110 100644 --- a/torax/sources/electron_density_sources.py +++ b/torax/sources/electron_density_sources.py @@ -76,7 +76,7 @@ class DynamicGasPuffRuntimeParams(runtime_params_lib.DynamicRuntimeParams): # Default formula: exponential with nref normalization. -def _calc_puff_source( +def calc_puff_source( static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, @@ -109,7 +109,9 @@ class GasPuffSource(source.Source): """Gas puff source for the ne equation.""" SOURCE_NAME: ClassVar[str] = 'gas_puff_source' - formula: source.SourceProfileFunction = _calc_puff_source + DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'calc_puff_source' + formula: source.SourceProfileFunction = calc_puff_source + model_func: source.SourceProfileFunction = calc_puff_source @property def source_name(self) -> str: @@ -172,7 +174,7 @@ class DynamicParticleRuntimeParams(runtime_params_lib.DynamicRuntimeParams): S_tot: array_typing.ScalarFloat -def _calc_generic_particle_source( +def calc_generic_particle_source( static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, @@ -205,7 +207,9 @@ class GenericParticleSource(source.Source): """Neutral-beam injection source for the ne equation.""" SOURCE_NAME: ClassVar[str] = 'generic_particle_source' - formula: source.SourceProfileFunction = _calc_generic_particle_source + DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'calc_generic_particle_source' + formula: source.SourceProfileFunction = calc_generic_particle_source + model_func: source.SourceProfileFunction = calc_generic_particle_source @property def source_name(self) -> str: @@ -260,7 +264,7 @@ class DynamicPelletRuntimeParams(runtime_params_lib.DynamicRuntimeParams): S_pellet_tot: array_typing.ScalarFloat -def _calc_pellet_source( +def calc_pellet_source( static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, @@ -293,7 +297,9 @@ class PelletSource(source.Source): """Pellet source for the ne equation.""" SOURCE_NAME: ClassVar[str] = 'pellet_source' - formula: source.SourceProfileFunction = _calc_pellet_source + DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'calc_pellet_source' + formula: source.SourceProfileFunction = calc_pellet_source + model_func: source.SourceProfileFunction = calc_pellet_source @property def source_name(self) -> str: diff --git a/torax/sources/fusion_heat_source.py b/torax/sources/fusion_heat_source.py index afa97401..ce6f4f34 100644 --- a/torax/sources/fusion_heat_source.py +++ b/torax/sources/fusion_heat_source.py @@ -143,6 +143,7 @@ class FusionHeatSource(source.Source): """Fusion heat source for both ion and electron heat.""" SOURCE_NAME: ClassVar[str] = 'fusion_heat_source' + DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'fusion_heat_model_func' model_func: source.SourceProfileFunction = fusion_heat_model_func @property diff --git a/torax/sources/generic_current_source.py b/torax/sources/generic_current_source.py index 7208c573..9a9b73e4 100644 --- a/torax/sources/generic_current_source.py +++ b/torax/sources/generic_current_source.py @@ -105,7 +105,7 @@ def __post_init__(self): # pytype bug: does not treat 'source_models.SourceModels' as a forward reference # pytype: disable=name-error -def _calculate_generic_current_face( +def calculate_generic_current_face( static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, @@ -235,8 +235,10 @@ class GenericCurrentSource(source.Source): """A generic current density source profile.""" SOURCE_NAME: ClassVar[str] = 'generic_current_source' - formula: source.SourceProfileFunction = _calculate_generic_current_face + DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'calc_generic_current_face' + formula: source.SourceProfileFunction = calculate_generic_current_face hires_formula: source.SourceProfileFunction = _calculate_generic_current_hires + model_func: source.SourceProfileFunction = calculate_generic_current_face @property def source_name(self) -> str: diff --git a/torax/sources/generic_ion_el_heat_source.py b/torax/sources/generic_ion_el_heat_source.py index 23ae64ad..be827002 100644 --- a/torax/sources/generic_ion_el_heat_source.py +++ b/torax/sources/generic_ion_el_heat_source.py @@ -113,7 +113,7 @@ def calc_generic_heat_source( # pytype: disable=name-error -def _default_formula( +def default_formula( static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, @@ -150,7 +150,9 @@ class GenericIonElectronHeatSource(source.Source): """Generic heat source for both ion and electron heat.""" SOURCE_NAME: ClassVar[str] = 'generic_ion_el_heat_source' - formula: source.SourceProfileFunction = _default_formula + DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'default_formula' + formula: source.SourceProfileFunction = default_formula + model_func: source.SourceProfileFunction = default_formula @property def source_name(self) -> str: diff --git a/torax/sources/impurity_radiation_heat_sink.py b/torax/sources/impurity_radiation_heat_sink.py index ed96edfe..06c3e7da 100644 --- a/torax/sources/impurity_radiation_heat_sink.py +++ b/torax/sources/impurity_radiation_heat_sink.py @@ -16,6 +16,7 @@ """Basic impurity radiation heat sink for electron heat equation..""" import dataclasses +from typing import ClassVar import chex import jax @@ -139,10 +140,13 @@ class ImpurityRadiationHeatSink(source_lib.Source): """Impurity radiation heat sink for electron heat equation.""" SOURCE_NAME = "impurity_radiation_heat_sink" - source_models: source_models_lib.SourceModels + DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = ( + "radially_constant_fraction_of_Pin" + ) model_func: source_lib.SourceProfileFunction = ( radially_constant_fraction_of_Pin ) + source_models: source_models_lib.SourceModels @property def source_name(self) -> str: diff --git a/torax/sources/ion_cyclotron_source.py b/torax/sources/ion_cyclotron_source.py index 529ac125..f2d88a2d 100644 --- a/torax/sources/ion_cyclotron_source.py +++ b/torax/sources/ion_cyclotron_source.py @@ -364,7 +364,7 @@ def _helium3_tail_temperature( return core_profiles.temp_el.value * (1 + epsilon) -def _icrh_model_func( +def icrh_model_func( static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, @@ -485,8 +485,6 @@ def _icrh_model_func( source_ion += power_deposition_2T * dynamic_source_runtime_params.Ptot return jnp.stack([source_ion, source_el]) - - # pylint: enable=invalid-name @@ -495,6 +493,7 @@ class IonCyclotronSource(source.Source): """Ion cyclotron source with surrogate model.""" SOURCE_NAME: ClassVar[str] = 'ion_cyclotron_source' + DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'icrh_model_func' @property def source_name(self) -> str: @@ -528,13 +527,18 @@ class IonCyclotronSourceBuilder: default_factory=RuntimeParams ) links_back: bool = False + model_func: source.SourceProfileFunction | None = None + + def __post_init__(self): + if self.model_func is None: + self.model_func = functools.partial( + icrh_model_func, + toric_nn=ToricNNWrapper(), + ) def __call__( self, formula: source.SourceProfileFunction | None = None, ) -> IonCyclotronSource: - model_func: source.SourceProfileFunction = functools.partial( - _icrh_model_func, - toric_nn=ToricNNWrapper(), - ) - return IonCyclotronSource(formula=formula, model_func=model_func,) + + return IonCyclotronSource(formula=formula, model_func=self.model_func,) diff --git a/torax/sources/ohmic_heat_source.py b/torax/sources/ohmic_heat_source.py index a0dc6d64..0c5c8166 100644 --- a/torax/sources/ohmic_heat_source.py +++ b/torax/sources/ohmic_heat_source.py @@ -149,7 +149,7 @@ def ohmic_model_func( geo: geometry.Geometry, source_name: str, core_profiles: state.CoreProfiles, - source_models: source_models_lib.SourceModels | None = None, + source_models: source_models_lib.SourceModels, ) -> jax.Array: """Returns the Ohmic source for electron heat equation.""" del source_name # Unused. @@ -190,17 +190,10 @@ class OhmicHeatSource(source_lib.Source): """ SOURCE_NAME: ClassVar[str] = 'ohmic_heat_source' + DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'ohmic_model_func' + model_func: source_lib.SourceProfileFunction = ohmic_model_func # Users must pass in a pointer to the complete set of sources to this object. source_models: source_models_lib.SourceModels - # The model function is fixed to ohmic_model_func because that is the only - # supported implementation of this source. - # However, since this is a param in the parent dataclass, we need to (a) - # remove the parameter from the init args and (b) set the default to the - # desired value. - model_func: source_lib.SourceProfileFunction | None = dataclasses.field( - init=False, - default_factory=lambda: ohmic_model_func, - ) @property def source_name(self) -> str: diff --git a/torax/sources/qei_source.py b/torax/sources/qei_source.py index dee3bdfa..990bbc25 100644 --- a/torax/sources/qei_source.py +++ b/torax/sources/qei_source.py @@ -72,6 +72,7 @@ class QeiSource(source.Source): """ SOURCE_NAME: ClassVar[str] = 'qei_source' + DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'model_based_qei' @property def source_name(self) -> str: diff --git a/torax/sources/register_source.py b/torax/sources/register_source.py index e5e48fce..5a043b68 100644 --- a/torax/sources/register_source.py +++ b/torax/sources/register_source.py @@ -54,117 +54,149 @@ class to build, the runtime associated with that source and (optionally) the @dataclasses.dataclass(frozen=True) -class RegisteredSource: - source_class: Type[source.Source] - source_builder_class: source.SourceBuilderProtocol - default_runtime_params_class: Type[runtime_params.RuntimeParams] - - -def _register_new_source( - source_class: Type[source.Source], - default_runtime_params_class: Type[runtime_params.RuntimeParams], - source_builder_class: source.SourceBuilderProtocol | None = None, - links_back: bool = False, -) -> RegisteredSource: - """Register source class, default runtime params and (optional) builder for this source. +class ModelFunction: + source_profile_function: source.SourceProfileFunction | None + runtime_params_class: Type[runtime_params.RuntimeParams] + source_builder_class: source.SourceBuilderProtocol | None = None + links_back: bool = False - Args: - source_class: The source class. - default_runtime_params_class: The default runtime params class. - source_builder_class: The source builder class. If None, a default builder - is created which uses the source class and default runtime params class to - construct a builder for that source. - links_back: Whether the source requires a reference to all the source - models. - - Returns: - A `RegisteredSource` dataclass containing the source class, source - builder class, and default runtime params class. - """ - if source_builder_class is None: - builder_class = source.make_source_builder( - source_class, - runtime_params_type=default_runtime_params_class, - links_back=links_back, - ) - else: - builder_class = source_builder_class - return RegisteredSource( - source_class=source_class, - source_builder_class=builder_class, - default_runtime_params_class=default_runtime_params_class, - ) - - -_REGISTERED_SOURCES = { - bootstrap_current_source.BootstrapCurrentSource.SOURCE_NAME: ( - _register_new_source( - source_class=bootstrap_current_source.BootstrapCurrentSource, - default_runtime_params_class=bootstrap_current_source.RuntimeParams, - ) +@dataclasses.dataclass(frozen=True) +class SupportedSource: + """Source that can be used in TORAX and any associated model functions.""" + source_class: Type[source.Source] + model_functions: dict[str, ModelFunction] + + +_SUPPORTED_SOURCES = { + bootstrap_current_source.BootstrapCurrentSource.SOURCE_NAME: SupportedSource( + source_class=bootstrap_current_source.BootstrapCurrentSource, + model_functions={ + bootstrap_current_source.BootstrapCurrentSource.DEFAULT_MODEL_FUNCTION_NAME: ModelFunction( + source_profile_function=None, + runtime_params_class=bootstrap_current_source.RuntimeParams, + ) + }, ), - generic_current_source.GenericCurrentSource.SOURCE_NAME: ( - _register_new_source( - source_class=generic_current_source.GenericCurrentSource, - default_runtime_params_class=generic_current_source.RuntimeParams, - ) + generic_current_source.GenericCurrentSource.SOURCE_NAME: SupportedSource( + source_class=generic_current_source.GenericCurrentSource, + model_functions={ + generic_current_source.GenericCurrentSource.DEFAULT_MODEL_FUNCTION_NAME: ModelFunction( + source_profile_function=generic_current_source.calculate_generic_current_face, + runtime_params_class=generic_current_source.RuntimeParams, + ) + }, ), - electron_cyclotron_source.ElectronCyclotronSource.SOURCE_NAME: _register_new_source( + electron_cyclotron_source.ElectronCyclotronSource.SOURCE_NAME: SupportedSource( source_class=electron_cyclotron_source.ElectronCyclotronSource, - default_runtime_params_class=electron_cyclotron_source.RuntimeParams, + model_functions={ + electron_cyclotron_source.ElectronCyclotronSource.DEFAULT_MODEL_FUNCTION_NAME: ModelFunction( + source_profile_function=electron_cyclotron_source.calc_heating_and_current, + runtime_params_class=electron_cyclotron_source.RuntimeParams, + ) + }, ), - electron_density_sources.GenericParticleSource.SOURCE_NAME: _register_new_source( + electron_density_sources.GenericParticleSource.SOURCE_NAME: SupportedSource( source_class=electron_density_sources.GenericParticleSource, - default_runtime_params_class=electron_density_sources.GenericParticleSourceRuntimeParams, + model_functions={ + electron_density_sources.GenericParticleSource.DEFAULT_MODEL_FUNCTION_NAME: ModelFunction( + source_profile_function=electron_density_sources.calc_generic_particle_source, + runtime_params_class=electron_density_sources.GenericParticleSourceRuntimeParams, + ) + }, ), - electron_density_sources.GasPuffSource.SOURCE_NAME: _register_new_source( + electron_density_sources.GasPuffSource.SOURCE_NAME: SupportedSource( source_class=electron_density_sources.GasPuffSource, - default_runtime_params_class=electron_density_sources.GasPuffRuntimeParams, + model_functions={ + electron_density_sources.GasPuffSource.DEFAULT_MODEL_FUNCTION_NAME: ModelFunction( + source_profile_function=electron_density_sources.calc_puff_source, + runtime_params_class=electron_density_sources.GasPuffRuntimeParams, + ) + }, ), - electron_density_sources.PelletSource.SOURCE_NAME: _register_new_source( + electron_density_sources.PelletSource.SOURCE_NAME: SupportedSource( source_class=electron_density_sources.PelletSource, - default_runtime_params_class=electron_density_sources.PelletRuntimeParams, + model_functions={ + electron_density_sources.PelletSource.DEFAULT_MODEL_FUNCTION_NAME: ModelFunction( + source_profile_function=electron_density_sources.calc_pellet_source, + runtime_params_class=electron_density_sources.PelletRuntimeParams, + ) + }, ), - ion_el_heat.GenericIonElectronHeatSource.SOURCE_NAME: _register_new_source( + ion_el_heat.GenericIonElectronHeatSource.SOURCE_NAME: SupportedSource( source_class=ion_el_heat.GenericIonElectronHeatSource, - default_runtime_params_class=ion_el_heat.RuntimeParams, + model_functions={ + ion_el_heat.GenericIonElectronHeatSource.DEFAULT_MODEL_FUNCTION_NAME: ModelFunction( + source_profile_function=ion_el_heat.default_formula, + runtime_params_class=ion_el_heat.RuntimeParams, + ) + }, ), - fusion_heat_source.FusionHeatSource.SOURCE_NAME: _register_new_source( + fusion_heat_source.FusionHeatSource.SOURCE_NAME: SupportedSource( source_class=fusion_heat_source.FusionHeatSource, - default_runtime_params_class=fusion_heat_source.FusionHeatSourceRuntimeParams, + model_functions={ + fusion_heat_source.FusionHeatSource.DEFAULT_MODEL_FUNCTION_NAME: ModelFunction( + source_profile_function=fusion_heat_source.fusion_heat_model_func, + runtime_params_class=fusion_heat_source.FusionHeatSourceRuntimeParams, + ) + }, ), - qei_source.QeiSource.SOURCE_NAME: _register_new_source( + qei_source.QeiSource.SOURCE_NAME: SupportedSource( source_class=qei_source.QeiSource, - default_runtime_params_class=qei_source.RuntimeParams, + model_functions={ + qei_source.QeiSource.DEFAULT_MODEL_FUNCTION_NAME: ModelFunction( + source_profile_function=None, + runtime_params_class=qei_source.RuntimeParams, + ) + }, ), - ohmic_heat_source.OhmicHeatSource.SOURCE_NAME: _register_new_source( + ohmic_heat_source.OhmicHeatSource.SOURCE_NAME: SupportedSource( source_class=ohmic_heat_source.OhmicHeatSource, - default_runtime_params_class=ohmic_heat_source.OhmicRuntimeParams, - links_back=True, + model_functions={ + ohmic_heat_source.OhmicHeatSource.DEFAULT_MODEL_FUNCTION_NAME: ( + ModelFunction( + source_profile_function=ohmic_heat_source.ohmic_model_func, + runtime_params_class=ohmic_heat_source.OhmicRuntimeParams, + links_back=True, + ) + ) + }, ), - bremsstrahlung_heat_sink.BremsstrahlungHeatSink.SOURCE_NAME: ( - _register_new_source( - source_class=bremsstrahlung_heat_sink.BremsstrahlungHeatSink, - default_runtime_params_class=bremsstrahlung_heat_sink.RuntimeParams, - ) + bremsstrahlung_heat_sink.BremsstrahlungHeatSink.SOURCE_NAME: SupportedSource( + source_class=bremsstrahlung_heat_sink.BremsstrahlungHeatSink, + model_functions={ + bremsstrahlung_heat_sink.BremsstrahlungHeatSink.DEFAULT_MODEL_FUNCTION_NAME: ModelFunction( + source_profile_function=bremsstrahlung_heat_sink.bremsstrahlung_model_func, + runtime_params_class=bremsstrahlung_heat_sink.RuntimeParams, + ) + }, ), - ion_cyclotron_source.IonCyclotronSource.SOURCE_NAME: _register_new_source( + ion_cyclotron_source.IonCyclotronSource.SOURCE_NAME: SupportedSource( source_class=ion_cyclotron_source.IonCyclotronSource, - default_runtime_params_class=ion_cyclotron_source.RuntimeParams, - source_builder_class=ion_cyclotron_source.IonCyclotronSourceBuilder, + model_functions={ + ion_cyclotron_source.IonCyclotronSource.DEFAULT_MODEL_FUNCTION_NAME: ModelFunction( + source_profile_function=None, + runtime_params_class=ion_cyclotron_source.RuntimeParams, + source_builder_class=ion_cyclotron_source.IonCyclotronSourceBuilder, + ) + }, ), - impurity_radiation_heat_sink.ImpurityRadiationHeatSink.SOURCE_NAME: _register_new_source( + impurity_radiation_heat_sink.ImpurityRadiationHeatSink.SOURCE_NAME: SupportedSource( source_class=impurity_radiation_heat_sink.ImpurityRadiationHeatSink, - default_runtime_params_class=impurity_radiation_heat_sink.RuntimeParams, - links_back=True, + model_functions={ + impurity_radiation_heat_sink.ImpurityRadiationHeatSink.DEFAULT_MODEL_FUNCTION_NAME: ModelFunction( + source_profile_function=impurity_radiation_heat_sink.radially_constant_fraction_of_Pin, + runtime_params_class=impurity_radiation_heat_sink.RuntimeParams, + links_back=True, + ) + }, ), } -def get_registered_source(source_name: str) -> RegisteredSource: - """Used when building a simulation to get the registered source.""" - if source_name in _REGISTERED_SOURCES: - return _REGISTERED_SOURCES[source_name] +def get_supported_source(source_name: str) -> SupportedSource: + """Used when building a simulation to get the supported source.""" + if source_name in _SUPPORTED_SOURCES: + return _SUPPORTED_SOURCES[source_name] else: raise RuntimeError(f'Source:{source_name} has not been registered.') diff --git a/torax/sources/source.py b/torax/sources/source.py index b49d55a0..c276553a 100644 --- a/torax/sources/source.py +++ b/torax/sources/source.py @@ -28,7 +28,7 @@ import enum import types import typing -from typing import Any, Callable, Optional, Protocol, TypeAlias +from typing import Any, Callable, ClassVar, Optional, Protocol, TypeAlias # We use Optional here because | doesn't work with string name types. # We use string name 'source_models.SourceModels' in this file to avoid @@ -104,6 +104,8 @@ class Source(abc.ABC): are in turn used to compute coeffs in sim.py. Attributes: + DEFAULT_MODEL_FUNCTION_NAME: The name of the model function used with this + source if another isn't specified. runtime_params: Input dataclass containing all the source-specific runtime parameters. At runtime, the parameters here are interpolated to a specific time t and then passed to the model_func or formula, depending on the mode @@ -127,7 +129,7 @@ class Source(abc.ABC): affected_core_profiles_ints: Derived property from the affected_core_profiles. Integer values of those enums. """ - + DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'default' model_func: SourceProfileFunction | None = None formula: SourceProfileFunction | None = None @@ -152,6 +154,7 @@ def supported_modes(self) -> tuple[runtime_params_lib.Mode, ...]: return ( runtime_params_lib.Mode.ZERO, runtime_params_lib.Mode.FORMULA_BASED, + runtime_params_lib.Mode.MODEL_BASED, runtime_params_lib.Mode.PRESCRIBED, ) @@ -202,24 +205,14 @@ def get_value( ] self.check_mode(static_source_runtime_params.mode) output_shape = self.output_shape_getter(geo) - model_func = ( - (lambda _0, _1, _2, _3, _4, _5: jnp.zeros(output_shape)) - if self.model_func is None - else self.model_func - ) - formula = ( - (lambda _0, _1, _2, _3, _4, _5: jnp.zeros(output_shape)) - if self.formula is None - else self.formula - ) return get_source_profiles( dynamic_runtime_params_slice=dynamic_runtime_params_slice, static_runtime_params_slice=static_runtime_params_slice, geo=geo, core_profiles=core_profiles, - model_func=model_func, - formula=formula, + model_func=self.model_func, + formula=self.formula, prescribed_values=dynamic_source_runtime_params.prescribed_values, output_shape=output_shape, source_models=getattr(self, 'source_models', None), @@ -309,8 +302,8 @@ def get_source_profiles( geo: geometry.Geometry, source_name: str, core_profiles: state.CoreProfiles, - model_func: SourceProfileFunction, - formula: SourceProfileFunction, + model_func: SourceProfileFunction | None, + formula: SourceProfileFunction | None, prescribed_values: chex.Array, output_shape: tuple[int, ...], source_models: Optional['source_models.SourceModels'], @@ -344,6 +337,10 @@ def get_source_profiles( mode = static_runtime_params_slice.sources[source_name].mode match mode: case runtime_params_lib.Mode.MODEL_BASED.value: + if model_func is None: + raise ValueError( + 'Source is in MODEL_BASED mode but has no model function.' + ) return model_func( static_runtime_params_slice, dynamic_runtime_params_slice, @@ -353,6 +350,10 @@ def get_source_profiles( source_models, ) case runtime_params_lib.Mode.FORMULA_BASED.value: + if formula is None: + raise ValueError( + 'Source is in FORMULA_BASED mode but has no formula function.' + ) return formula( static_runtime_params_slice, dynamic_runtime_params_slice, @@ -427,6 +428,7 @@ def is_source_builder(obj, raise_if_false: bool = False) -> bool: def _convert_source_builder_to_init_kwargs( source_builder: ..., + model_func: SourceProfileFunction | None, ) -> dict[str, Any]: """Returns a dict of init kwargs for the source builder.""" source_init_kwargs = {} @@ -438,12 +440,14 @@ def _convert_source_builder_to_init_kwargs( # including turning custom dataclasses with __call__ methods into # plain Python dictionaries. source_init_kwargs[field.name] = getattr(source_builder, field.name) + source_init_kwargs['model_func'] = model_func return source_init_kwargs def make_source_builder( source_type: ..., runtime_params_type: ... = runtime_params_lib.RuntimeParams, + model_func: SourceProfileFunction | None = None, links_back=False, ) -> SourceBuilderProtocol: """Given a Source type, returns a Builder for that type. @@ -454,6 +458,7 @@ def make_source_builder( source_type: The Source class to make a builder for. runtime_params_type: The type of `runtime_params` field which will be added to the builder dataclass. + model_func: The model function to pass to the source. links_back: If True, the Source class has a `source_models` field linking back to the SourceModels object. This must be passed to the builder's __call__ method. @@ -510,7 +515,10 @@ def check_kwargs(source_init_kwargs, context_msg): assert all([isinstance(var, runtime_params_lib.Mode) for var in v]) elif f.type == 'SourceProfileFunction | None': assert v is None or callable(v) - elif f.type == 'source.SourceProfileFunction': + elif f.type in [ + 'source.SourceProfileFunction', + 'source_lib.SourceProfileFunction', + ]: if not callable(v): raise TypeError( f'While {context_msg} {source_type} got field ' @@ -535,8 +543,8 @@ def check_kwargs(source_init_kwargs, context_msg): raise TypeError(f'Unrecognized type string: {f.type}') # Check if the field is a parameterized generic. - # Python cannot check isinstance for parameterized generics, so we need - # to handle those cases differently. + # Python cannot check isinstance for parameterized generics, so we ignore + # these cases for now. # For instance, if a field type is `tuple[float, ...]` and the value is # valid, like `(1, 2, 3)`, then `isinstance(v, f.type)` would raise a # TypeError. @@ -544,13 +552,7 @@ def check_kwargs(source_init_kwargs, context_msg): type(f.type) == types.GenericAlias # pylint: disable=unidiomatic-typecheck or typing.get_origin(f.type) is not None ): - # Do a superficial check in these instances. Only check that the origin - # type matches the value. Don't look into the rest of the object. - if not isinstance(v, typing.get_origin(f.type)): - raise TypeError( - f'While {context_msg} {source_type} got field {f.name} with ' - f'input type {type(v)} but an expected type {f.type}.' - ) + pass else: try: @@ -571,7 +573,9 @@ def check_kwargs(source_init_kwargs, context_msg): # pylint doesn't like this function name because it doesn't realize # this function is to be installed in a class def __post_init__(self): # pylint:disable=invalid-name - source_init_kwargs = _convert_source_builder_to_init_kwargs(self) + source_init_kwargs = _convert_source_builder_to_init_kwargs( + self, model_func + ) check_kwargs(source_init_kwargs, 'making builder') # check_kwargs checks only the kwargs to Source, not SourceBuilder, # so it doesn't check "runtime_params" @@ -597,7 +601,10 @@ def check_source(source): if links_back: def build_source(self, source_models): - source_init_kwargs = _convert_source_builder_to_init_kwargs(self) + source_init_kwargs = _convert_source_builder_to_init_kwargs( + self, + model_func, + ) source_init_kwargs['source_models'] = source_models check_kwargs(source_init_kwargs, 'building') source = source_type(**source_init_kwargs) @@ -607,7 +614,10 @@ def build_source(self, source_models): else: def build_source(self): - source_init_kwargs = _convert_source_builder_to_init_kwargs(self) + source_init_kwargs = _convert_source_builder_to_init_kwargs( + self, + model_func, + ) check_kwargs(source_init_kwargs, 'building') source = source_type(**source_init_kwargs) check_source(source) diff --git a/torax/sources/source_models.py b/torax/sources/source_models.py index 4d30a6b2..45cfbfb3 100644 --- a/torax/sources/source_models.py +++ b/torax/sources/source_models.py @@ -719,6 +719,7 @@ def __init__( ] = source_lib.make_source_builder( generic_current_source.GenericCurrentSource, runtime_params_type=generic_current_source.RuntimeParams, + model_func=generic_current_source.calculate_generic_current_face, )() source_builders[ generic_current_source.GenericCurrentSource.SOURCE_NAME diff --git a/torax/sources/tests/bootstrap_current_source.py b/torax/sources/tests/bootstrap_current_source.py index 390ff553..71191626 100644 --- a/torax/sources/tests/bootstrap_current_source.py +++ b/torax/sources/tests/bootstrap_current_source.py @@ -37,6 +37,7 @@ def setUpClass(cls): runtime_params_lib.Mode.FORMULA_BASED, ], source_name=bootstrap_current_source.BootstrapCurrentSource.SOURCE_NAME, + model_func=None, ) def test_extraction_of_relevant_profile_from_output(self): diff --git a/torax/sources/tests/bremsstrahlung_heat_sink.py b/torax/sources/tests/bremsstrahlung_heat_sink.py index 64c53669..35575ceb 100644 --- a/torax/sources/tests/bremsstrahlung_heat_sink.py +++ b/torax/sources/tests/bremsstrahlung_heat_sink.py @@ -42,6 +42,7 @@ def setUpClass(cls): runtime_params_lib.Mode.FORMULA_BASED, ], source_name=bremsstrahlung_heat_sink.BremsstrahlungHeatSink.SOURCE_NAME, + model_func=bremsstrahlung_heat_sink.bremsstrahlung_model_func, ) @parameterized.parameters([ diff --git a/torax/sources/tests/electron_cyclotron_source.py b/torax/sources/tests/electron_cyclotron_source.py index 670fb21c..23b9cf46 100644 --- a/torax/sources/tests/electron_cyclotron_source.py +++ b/torax/sources/tests/electron_cyclotron_source.py @@ -42,6 +42,7 @@ def setUpClass(cls): runtime_params_lib.Mode.FORMULA_BASED, ], source_name=electron_cyclotron_source.ElectronCyclotronSource.SOURCE_NAME, + model_func=electron_cyclotron_source.calc_heating_and_current, ) def test_source_value(self): @@ -51,16 +52,10 @@ def test_source_value(self): raise TypeError(f"{type(self)} has a bad _source_class_builder") runtime_params = general_runtime_params.GeneralRuntimeParams() source_models_builder = source_models_lib.SourceModelsBuilder( - { - electron_cyclotron_source.ElectronCyclotronSource.SOURCE_NAME: ( - source_builder - ) - }, + {self._source_name: source_builder}, ) source_models = source_models_builder() - source = source_models.sources[ - electron_cyclotron_source.ElectronCyclotronSource.SOURCE_NAME - ] + source = source_models.sources[self._source_name] source_builder.runtime_params.mode = runtime_params_lib.Mode.MODEL_BASED self.assertIsInstance(source, source_lib.Source) geo = geometry.build_circular_geometry() diff --git a/torax/sources/tests/electron_density_sources.py b/torax/sources/tests/electron_density_sources.py index 1f745c98..d7cf0802 100644 --- a/torax/sources/tests/electron_density_sources.py +++ b/torax/sources/tests/electron_density_sources.py @@ -16,7 +16,6 @@ from absl.testing import absltest from torax.sources import electron_density_sources as eds -from torax.sources import runtime_params as runtime_params_lib from torax.sources.tests import test_lib @@ -28,10 +27,9 @@ def setUpClass(cls): super().setUpClass( source_class=eds.GasPuffSource, runtime_params_class=eds.GasPuffRuntimeParams, - unsupported_modes=[ - runtime_params_lib.Mode.MODEL_BASED, - ], + unsupported_modes=[], source_name=eds.GasPuffSource.SOURCE_NAME, + model_func=eds.calc_puff_source, ) @@ -43,10 +41,9 @@ def setUpClass(cls): super().setUpClass( source_class=eds.PelletSource, runtime_params_class=eds.PelletRuntimeParams, - unsupported_modes=[ - runtime_params_lib.Mode.MODEL_BASED, - ], + unsupported_modes=[], source_name=eds.PelletSource.SOURCE_NAME, + model_func=eds.calc_pellet_source, ) @@ -58,10 +55,9 @@ def setUpClass(cls): super().setUpClass( source_class=eds.GenericParticleSource, runtime_params_class=eds.GenericParticleSourceRuntimeParams, - unsupported_modes=[ - runtime_params_lib.Mode.MODEL_BASED, - ], + unsupported_modes=[], source_name=eds.GenericParticleSource.SOURCE_NAME, + model_func=eds.calc_generic_particle_source, ) diff --git a/torax/sources/tests/fusion_heat_source.py b/torax/sources/tests/fusion_heat_source.py index 9329c9ff..6053dc6d 100644 --- a/torax/sources/tests/fusion_heat_source.py +++ b/torax/sources/tests/fusion_heat_source.py @@ -41,6 +41,7 @@ def setUpClass(cls): runtime_params_lib.Mode.FORMULA_BASED, ], source_name=fusion_heat_source.FusionHeatSource.SOURCE_NAME, + model_func=fusion_heat_source.fusion_heat_model_func, ) @parameterized.parameters([ diff --git a/torax/sources/tests/generic_current_source.py b/torax/sources/tests/generic_current_source.py index 0f2298f5..15dc1ba0 100644 --- a/torax/sources/tests/generic_current_source.py +++ b/torax/sources/tests/generic_current_source.py @@ -39,6 +39,7 @@ def setUpClass(cls): runtime_params_lib.Mode.MODEL_BASED, ], source_name=generic_current_source.GenericCurrentSource.SOURCE_NAME, + model_func=generic_current_source.calculate_generic_current_face, ) def test_generic_current_hires(self): diff --git a/torax/sources/tests/generic_ion_el_heat_source.py b/torax/sources/tests/generic_ion_el_heat_source.py index 86fc59e5..fb644f47 100644 --- a/torax/sources/tests/generic_ion_el_heat_source.py +++ b/torax/sources/tests/generic_ion_el_heat_source.py @@ -16,7 +16,6 @@ from absl.testing import absltest from torax.sources import generic_ion_el_heat_source -from torax.sources import runtime_params as runtime_params_lib from torax.sources.tests import test_lib @@ -28,10 +27,9 @@ def setUpClass(cls): super().setUpClass( source_class=generic_ion_el_heat_source.GenericIonElectronHeatSource, runtime_params_class=generic_ion_el_heat_source.RuntimeParams, - unsupported_modes=[ - runtime_params_lib.Mode.MODEL_BASED, - ], + unsupported_modes=[], source_name=generic_ion_el_heat_source.GenericIonElectronHeatSource.SOURCE_NAME, + model_func=generic_ion_el_heat_source.default_formula, ) diff --git a/torax/sources/tests/impurity_radiation_heat_sink.py b/torax/sources/tests/impurity_radiation_heat_sink.py index 21bf57c8..dd1135de 100644 --- a/torax/sources/tests/impurity_radiation_heat_sink.py +++ b/torax/sources/tests/impurity_radiation_heat_sink.py @@ -45,6 +45,7 @@ def setUpClass(cls): unsupported_modes=[], source_name=impurity_radiation_heat_sink_lib.ImpurityRadiationHeatSink.SOURCE_NAME, links_back=True, + model_func=impurity_radiation_heat_sink_lib.radially_constant_fraction_of_Pin ) def test_source_value(self): @@ -62,8 +63,11 @@ def test_source_value(self): heat_source_builder_builder = source_lib.make_source_builder( source_type=generic_ion_el_heat_source.GenericIonElectronHeatSource, runtime_params_type=generic_ion_el_heat_source.RuntimeParams, + model_func=generic_ion_el_heat_source.default_formula, + ) + heat_source_builder = heat_source_builder_builder( + model_func=generic_ion_el_heat_source.default_formula ) - heat_source_builder = heat_source_builder_builder() # Runtime params runtime_params = general_runtime_params.GeneralRuntimeParams() diff --git a/torax/sources/tests/ion_cyclotron_source.py b/torax/sources/tests/ion_cyclotron_source.py index 1ddda493..ee380f73 100644 --- a/torax/sources/tests/ion_cyclotron_source.py +++ b/torax/sources/tests/ion_cyclotron_source.py @@ -101,6 +101,7 @@ def setUpClass(cls): unsupported_modes=[runtime_params_lib.Mode.FORMULA_BASED], source_class_builder=ion_cyclotron_source.IonCyclotronSourceBuilder, source_name=ion_cyclotron_source.IonCyclotronSource.SOURCE_NAME, + model_func=None, ) @parameterized.product( diff --git a/torax/sources/tests/ohmic_heat_source.py b/torax/sources/tests/ohmic_heat_source.py index b4dfb1c2..597d7dc3 100644 --- a/torax/sources/tests/ohmic_heat_source.py +++ b/torax/sources/tests/ohmic_heat_source.py @@ -31,6 +31,7 @@ def setUpClass(cls): ], source_name=ohmic_heat_source.OhmicHeatSource.SOURCE_NAME, links_back=True, + model_func=ohmic_heat_source.ohmic_model_func, ) diff --git a/torax/sources/tests/qei_source.py b/torax/sources/tests/qei_source.py index 43e53db6..bf64cf78 100644 --- a/torax/sources/tests/qei_source.py +++ b/torax/sources/tests/qei_source.py @@ -38,6 +38,7 @@ def setUpClass(cls): runtime_params_lib.Mode.FORMULA_BASED, ], source_name=qei_source.QeiSource.SOURCE_NAME, + model_func=None, ) def test_source_value(self): diff --git a/torax/sources/tests/register_source.py b/torax/sources/tests/register_source.py index 47a6465f..dd0100db 100644 --- a/torax/sources/tests/register_source.py +++ b/torax/sources/tests/register_source.py @@ -13,7 +13,6 @@ # limitations under the License. """Tests for the source registry.""" - from absl.testing import absltest from absl.testing import parameterized from torax.sources import bootstrap_current_source @@ -27,6 +26,7 @@ from torax.sources import ohmic_heat_source from torax.sources import qei_source from torax.sources import register_source +from torax.sources import source as source_lib from torax.sources import source_models as source_models_lib @@ -48,10 +48,21 @@ class SourceTest(parameterized.TestCase): ) def test_sources_in_registry_build_successfully(self, source_name: str): """Test that all sources in the registry build successfully.""" - registered_source = register_source.get_registered_source(source_name) + registered_source = register_source.get_supported_source(source_name) source_class = registered_source.source_class - source_runtime_params_class = registered_source.default_runtime_params_class - source_builder_class = registered_source.source_builder_class + model_function = registered_source.model_functions[ + source_class.DEFAULT_MODEL_FUNCTION_NAME + ] + source_builder_class = model_function.source_builder_class + source_runtime_params_class = model_function.runtime_params_class + if source_builder_class is None: + source_builder_class = source_lib.make_source_builder( + registered_source.source_class, + runtime_params_type=source_runtime_params_class, + links_back=model_function.links_back, + model_func=model_function.source_profile_function, + ) + source_runtime_params_class = model_function.runtime_params_class source_builder = source_builder_class() self.assertIsInstance( source_builder.runtime_params, source_runtime_params_class diff --git a/torax/sources/tests/source.py b/torax/sources/tests/source.py index 09434650..c2ac684a 100644 --- a/torax/sources/tests/source.py +++ b/torax/sources/tests/source.py @@ -156,6 +156,7 @@ class NotEq: @dataclasses.dataclass(frozen=True, eq=True) class MySource: my_field: int + model_func: source_lib.SourceProfileFunction | None = None # pylint doesn't realize this is a class MySourceBuilder = source_lib.make_source_builder( # pylint: disable=invalid-name @@ -376,16 +377,13 @@ def test_defaults_output_zeros(self): }, torax_mesh=geo.torax_mesh, ) - profile = source.get_value( - dynamic_runtime_params_slice=dynamic_runtime_params_slice, - static_runtime_params_slice=static_slice, - geo=geo, - core_profiles=core_profiles, - ) - np.testing.assert_allclose( - profile, - get_zero_profile(source_lib.ProfileType.CELL, geo), - ) + with self.assertRaises(ValueError): + source.get_value( + dynamic_runtime_params_slice=dynamic_runtime_params_slice, + static_runtime_params_slice=static_slice, + geo=geo, + core_profiles=core_profiles, + ) with self.subTest('formula'): static_slice = runtime_params_slice.build_static_runtime_params_slice( runtime_params=runtime_params, @@ -397,16 +395,13 @@ def test_defaults_output_zeros(self): }, torax_mesh=geo.torax_mesh, ) - profile = source.get_value( - dynamic_runtime_params_slice=dynamic_runtime_params_slice, - static_runtime_params_slice=static_slice, - geo=geo, - core_profiles=core_profiles, - ) - np.testing.assert_allclose( - profile, - get_zero_profile(source_lib.ProfileType.CELL, geo), - ) + with self.assertRaises(ValueError): + source.get_value( + dynamic_runtime_params_slice=dynamic_runtime_params_slice, + static_runtime_params_slice=static_slice, + geo=geo, + core_profiles=core_profiles, + ) with self.subTest('prescribed'): static_slice = runtime_params_slice.build_static_runtime_params_slice( runtime_params=runtime_params, @@ -477,9 +472,10 @@ def test_overriding_model(self): geo = geometry.build_circular_geometry() output_shape = source_lib.ProfileType.CELL.get_profile_shape(geo) expected_output = jnp.ones(output_shape) - source_builder = IonElTestSourceBuilder( + source_builder = source_lib.make_source_builder( + IonElTestSource, model_func=lambda _0, _1, _2, _3, _4, _5: expected_output, - ) + )() source_builder.runtime_params.mode = runtime_params_lib.Mode.MODEL_BASED source_models_builder = source_models_lib.SourceModelsBuilder( {'foo': source_builder}, diff --git a/torax/sources/tests/test_lib.py b/torax/sources/tests/test_lib.py index f0538225..c42fb22e 100644 --- a/torax/sources/tests/test_lib.py +++ b/torax/sources/tests/test_lib.py @@ -71,6 +71,7 @@ class SourceTestCase(parameterized.TestCase): _config_attr_name: str _unsupported_modes: Sequence[runtime_params_lib.Mode] _source_name: str + _runtime_params_class: Type[runtime_params_lib.RuntimeParams] @classmethod def setUpClass( @@ -79,6 +80,7 @@ def setUpClass( runtime_params_class: Type[runtime_params_lib.RuntimeParams], unsupported_modes: Sequence[runtime_params_lib.Mode], source_name: str, + model_func: source_lib.SourceProfileFunction | None, links_back: bool = False, source_class_builder: source_lib.SourceBuilderProtocol | None = None, ): @@ -89,6 +91,7 @@ def setUpClass( source_type=source_class, runtime_params_type=runtime_params_class, links_back=links_back, + model_func=model_func, ) else: cls._source_class_builder = source_class_builder @@ -138,7 +141,7 @@ def test_source_value(self): # SingleProfileSource subclasses should have default names and be # instantiable without any __init__ arguments. # pylint: disable=missing-kwoa - source_builder = self._source_class_builder() # pytype: disable=missing-parameter + source_builder = self._source_class_builder() if not source_lib.is_source_builder(source_builder): raise TypeError(f'{type(self)} has a bad _source_class_builder') # pylint: enable=missing-kwoa @@ -148,7 +151,7 @@ def test_source_value(self): ) source_models = source_models_builder() source = source_models.sources[self._source_name] - source_builder.runtime_params.mode = source.supported_modes[0] + source_builder.runtime_params.mode = source.supported_modes[1] self.assertIsInstance(source, source_lib.Source) geo = geometry.build_circular_geometry() dynamic_runtime_params_slice = ( @@ -235,7 +238,7 @@ class IonElSourceTestCase(SourceTestCase): def test_source_value(self): """Tests that the source can provide a value by default.""" # pylint: disable=missing-kwoa - source_builder = self._source_class_builder() # pytype: disable=missing-parameter + source_builder = self._source_class_builder() # pylint: enable=missing-kwoa runtime_params = general_runtime_params.GeneralRuntimeParams() geo = geometry.build_circular_geometry() diff --git a/torax/tests/sim_custom_sources.py b/torax/tests/sim_custom_sources.py index 215eade5..2e35174f 100644 --- a/torax/tests/sim_custom_sources.py +++ b/torax/tests/sim_custom_sources.py @@ -107,19 +107,19 @@ def custom_source_formula( # Combine the outputs. # pylint: disable=protected-access return ( - electron_density_sources._calc_puff_source( + electron_density_sources.calc_puff_source( dynamic_runtime_params_slice=dynamic_runtime_params_slice, static_runtime_params_slice=static_runtime_params_slice, geo=geo, source_name=electron_density_sources.GasPuffSource.SOURCE_NAME, ) - + electron_density_sources._calc_generic_particle_source( + + electron_density_sources.calc_generic_particle_source( dynamic_runtime_params_slice=dynamic_runtime_params_slice, static_runtime_params_slice=static_runtime_params_slice, geo=geo, source_name=electron_density_sources.GenericParticleSource.SOURCE_NAME, ) - + electron_density_sources._calc_pellet_source( + + electron_density_sources.calc_pellet_source( dynamic_runtime_params_slice=dynamic_runtime_params_slice, static_runtime_params_slice=static_runtime_params_slice, geo=geo, diff --git a/torax/tests/test_lib/default_sources.py b/torax/tests/test_lib/default_sources.py index fd45273a..959b6924 100644 --- a/torax/tests/test_lib/default_sources.py +++ b/torax/tests/test_lib/default_sources.py @@ -13,7 +13,9 @@ # limitations under the License. """Utilities to help with testing sources.""" + from torax.sources import register_source +from torax.sources import source as source_lib from torax.sources import source_models as source_models_lib @@ -60,9 +62,18 @@ def get_default_sources_builder() -> source_models_lib.SourceModelsBuilder: ] source_builders = {} for name in names: - registered_source = register_source.get_registered_source(name) - runtime_params = registered_source.default_runtime_params_class() - source_builders[name] = registered_source.source_builder_class( - runtime_params=runtime_params - ) + registered_source = register_source.get_supported_source(name) + model_function = registered_source.model_functions[ + registered_source.source_class.DEFAULT_MODEL_FUNCTION_NAME + ] + runtime_params = model_function.runtime_params_class() + source_builder_class = model_function.source_builder_class + if source_builder_class is None: + source_builder_class = source_lib.make_source_builder( + registered_source.source_class, + runtime_params_type=model_function.runtime_params_class, + links_back=model_function.links_back, + model_func=model_function.source_profile_function, + ) + source_builders[name] = source_builder_class(runtime_params=runtime_params) return source_models_lib.SourceModelsBuilder(source_builders)