From 88c4622264ee51f2a1dcab7042182e69a6699e2c Mon Sep 17 00:00:00 2001 From: Anushan Fernando Date: Fri, 20 Dec 2024 06:09:51 -0800 Subject: [PATCH] Allow addition of model functions (or builders) to the source registry. PiperOrigin-RevId: 708301156 --- torax/sources/register_source.py | 25 +++++++++++++++++++++++++ torax/sources/source.py | 2 ++ 2 files changed, 27 insertions(+) diff --git a/torax/sources/register_source.py b/torax/sources/register_source.py index f8cf784f..7354d968 100644 --- a/torax/sources/register_source.py +++ b/torax/sources/register_source.py @@ -201,3 +201,28 @@ def get_supported_source(source_name: str) -> SupportedSource: return _SUPPORTED_SOURCES[source_name] else: raise RuntimeError(f'Source:{source_name} has not been registered.') + + +def register_model_function( + source_name: str, + model_function_name: str, + model_function: source.SourceProfileFunction, + runtime_params_class: Type[runtime_params.RuntimeParams], + source_builder_class: source.SourceBuilderProtocol | None = None, + links_back: bool = False, +) -> None: + """Register a model function by adding to one of the supported sources in the registry.""" + if source_name not in _SUPPORTED_SOURCES: + raise ValueError(f'Source:{source_name} not found under supported sources.') + if model_function in _SUPPORTED_SOURCES[source_name].model_functions: + raise ValueError( + f'Model function:{model_function} has already been registered for' + f' source:{source_name}.' + ) + registered_source = _SUPPORTED_SOURCES[source_name] + registered_source.model_functions[model_function_name] = ModelFunction( + source_profile_function=model_function, + runtime_params_class=runtime_params_class, + source_builder_class=source_builder_class, + links_back=links_back, + ) diff --git a/torax/sources/source.py b/torax/sources/source.py index 7539e88c..5cf4a5cf 100644 --- a/torax/sources/source.py +++ b/torax/sources/source.py @@ -107,6 +107,7 @@ class Source(abc.ABC): are in turn used to compute coeffs in sim.py. Attributes: + SOURCE_NAME: The name of the source. DEFAULT_MODEL_FUNCTION_NAME: The name of the model function used with this source if another isn't specified. runtime_params: Input dataclass containing all the source-specific runtime @@ -127,6 +128,7 @@ class Source(abc.ABC): affected_core_profiles_ints: Derived property from the affected_core_profiles. Integer values of those enums. """ + SOURCE_NAME: ClassVar[str] = 'source' DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'default' model_func: SourceProfileFunction | None = None