Skip to content

Commit

Permalink
Create an explicit builder for the IonCyclotronSource.
Browse files Browse the repository at this point in the history
This is useful for not providing a specific builder for the ion cyclotron source when we move to flexible model funcs (each of which has its own builder) which otherwise all have the same arguments.

PiperOrigin-RevId: 709295177
  • Loading branch information
Nush395 authored and Torax team committed Dec 24, 2024
1 parent 2c348d0 commit c835d07
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 17 deletions.
32 changes: 20 additions & 12 deletions torax/sources/ion_cyclotron_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,18 +495,6 @@ class IonCyclotronSource(source.Source):
"""Ion cyclotron source with surrogate model."""

SOURCE_NAME: ClassVar[str] = 'ion_cyclotron_source'
# The model function is fixed to _icrh_model_func because that is the only
# supported implementation of this source.
# However, since this is a param in the parent dataclass, we need to (a)
# remove the parameter from the init args and (b) set the default to the
# desired value.
model_func: source.SourceProfileFunction = dataclasses.field(
init=False,
default_factory=lambda: functools.partial(
_icrh_model_func,
toric_nn=ToricNNWrapper(),
),
)

@property
def source_name(self) -> str:
Expand All @@ -530,3 +518,23 @@ def affected_core_profiles(self) -> tuple[source.AffectedCoreProfile, ...]:
@property
def output_shape_getter(self) -> source.SourceOutputShapeFunction:
return source.get_ion_el_output_shape


@dataclasses.dataclass(kw_only=True, frozen=False)
class IonCyclotronSourceBuilder:
"""Builder for the IonCyclotronSource."""

runtime_params: RuntimeParams = dataclasses.field(
default_factory=RuntimeParams
)
links_back: bool = False

def __call__(
self,
formula: source.SourceProfileFunction | None = None,
) -> IonCyclotronSource:
model_func: source.SourceProfileFunction = functools.partial(
_icrh_model_func,
toric_nn=ToricNNWrapper(),
)
return IonCyclotronSource(formula=formula, model_func=model_func,)
1 change: 1 addition & 0 deletions torax/sources/register_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ def _register_new_source(
ion_cyclotron_source.IonCyclotronSource.SOURCE_NAME: _register_new_source(
source_class=ion_cyclotron_source.IonCyclotronSource,
default_runtime_params_class=ion_cyclotron_source.RuntimeParams,
source_builder_class=ion_cyclotron_source.IonCyclotronSourceBuilder,
),
impurity_radiation_heat_sink.ImpurityRadiationHeatSink.SOURCE_NAME: _register_new_source(
source_class=impurity_radiation_heat_sink.ImpurityRadiationHeatSink,
Expand Down
1 change: 1 addition & 0 deletions torax/sources/tests/ion_cyclotron_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def setUpClass(cls):
source_class=ion_cyclotron_source.IonCyclotronSource,
runtime_params_class=ion_cyclotron_source.RuntimeParams,
unsupported_modes=[runtime_params_lib.Mode.FORMULA_BASED],
source_class_builder=ion_cyclotron_source.IonCyclotronSourceBuilder,
source_name=ion_cyclotron_source.IonCyclotronSource.SOURCE_NAME,
)

Expand Down
14 changes: 9 additions & 5 deletions torax/sources/tests/test_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,18 @@ def setUpClass(
unsupported_modes: Sequence[runtime_params_lib.Mode],
source_name: str,
links_back: bool = False,
source_class_builder: source_lib.SourceBuilderProtocol | None = None,
):
super().setUpClass()
cls._source_class = source_class
cls._source_class_builder = source_lib.make_source_builder(
source_type=source_class,
runtime_params_type=runtime_params_class,
links_back=links_back,
)
if source_class_builder is None:
cls._source_class_builder = source_lib.make_source_builder(
source_type=source_class,
runtime_params_type=runtime_params_class,
links_back=links_back,
)
else:
cls._source_class_builder = source_class_builder
cls._runtime_params_class = runtime_params_class
cls._unsupported_modes = unsupported_modes
cls._links_back = links_back
Expand Down

0 comments on commit c835d07

Please sign in to comment.