Skip to content

Commit

Permalink
Allow addition of model functions (or builders) to the source registry.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 708301156
  • Loading branch information
Nush395 authored and Torax team committed Dec 24, 2024
1 parent 9b6ce80 commit 88c4622
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 0 deletions.
25 changes: 25 additions & 0 deletions torax/sources/register_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,3 +201,28 @@ def get_supported_source(source_name: str) -> SupportedSource:
return _SUPPORTED_SOURCES[source_name]
else:
raise RuntimeError(f'Source:{source_name} has not been registered.')


def register_model_function(
source_name: str,
model_function_name: str,
model_function: source.SourceProfileFunction,
runtime_params_class: Type[runtime_params.RuntimeParams],
source_builder_class: source.SourceBuilderProtocol | None = None,
links_back: bool = False,
) -> None:
"""Register a model function by adding to one of the supported sources in the registry."""
if source_name not in _SUPPORTED_SOURCES:
raise ValueError(f'Source:{source_name} not found under supported sources.')
if model_function in _SUPPORTED_SOURCES[source_name].model_functions:
raise ValueError(
f'Model function:{model_function} has already been registered for'
f' source:{source_name}.'
)
registered_source = _SUPPORTED_SOURCES[source_name]
registered_source.model_functions[model_function_name] = ModelFunction(
source_profile_function=model_function,
runtime_params_class=runtime_params_class,
source_builder_class=source_builder_class,
links_back=links_back,
)
2 changes: 2 additions & 0 deletions torax/sources/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ class Source(abc.ABC):
are in turn used to compute coeffs in sim.py.
Attributes:
SOURCE_NAME: The name of the source.
DEFAULT_MODEL_FUNCTION_NAME: The name of the model function used with this
source if another isn't specified.
runtime_params: Input dataclass containing all the source-specific runtime
Expand All @@ -127,6 +128,7 @@ class Source(abc.ABC):
affected_core_profiles_ints: Derived property from the
affected_core_profiles. Integer values of those enums.
"""
SOURCE_NAME: ClassVar[str] = 'source'
DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'default'
model_func: SourceProfileFunction | None = None

Expand Down

0 comments on commit 88c4622

Please sign in to comment.