Skip to content

Commit

Permalink
Make SourceProfileFunction as a Protocol.
Browse files Browse the repository at this point in the history
Drive-by:
- fix the broken type check in the auto SourceBuilder builder. Currently all types are stringified (with `from __future__ import annotations`) so this branch isn't hit anyway. In a follow up we will replace this flexible logic with logic that just checks the runtime params and model func explicitly (rather than the general `Builder` logic we have currently).
PiperOrigin-RevId: 707541015
  • Loading branch information
Nush395 authored and Torax team committed Dec 24, 2024
1 parent 542bbf4 commit 0ab1a7e
Showing 1 changed file with 26 additions and 15 deletions.
41 changes: 26 additions & 15 deletions torax/sources/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,20 +43,23 @@
from torax.sources import runtime_params as runtime_params_lib


# Sources implement these functions to be able to provide source profiles.
# pytype bug: 'source_models.SourceModels' not treated as forward reference
SourceProfileFunction: TypeAlias = Callable[ # pytype: disable=name-error
[ # Arguments
runtime_params_slice.StaticRuntimeParamsSlice, # Static runtime params.
runtime_params_slice.DynamicRuntimeParamsSlice, # General config params
geometry.Geometry,
str, # Source name
state.CoreProfiles,
Optional['source_models.SourceModels'],
],
# Returns a JAX array, tuple of arrays, or mapping of arrays.
chex.ArrayTree,
]
# pytype: disable=name-error
@typing.runtime_checkable
class SourceProfileFunction(Protocol):
"""Sources implement these functions to be able to provide source profiles."""

def __call__(
self,
static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
geo: geometry.Geometry,
source_name: str,
core_profiles: state.CoreProfiles,
source_models: Optional['source_models.SourceModels'],
) -> chex.ArrayTree:
...
# pytype: enable=name-error


# Any callable which takes the dynamic runtime_params, geometry, and optional
Expand Down Expand Up @@ -552,8 +555,16 @@ def check_kwargs(source_init_kwargs, context_msg):
type(f.type) == types.GenericAlias # pylint: disable=unidiomatic-typecheck
or typing.get_origin(f.type) is not None
):
pass

# For `Union`s check if the value is a member of the union.
# `typing.Union` is for types defined with `Union[A, B, C]` syntax.
# `types.UnionType` is for types defined with `A | B | C` syntax.
if typing.get_origin(f.type) in [typing.Union, types.UnionType]:
if not isinstance(v, typing.get_args(f.type)):
raise TypeError(
f'While {context_msg} {source_type} got argument '
f'{f.name} of type {type(v)} but expected '
f'{f.type}).'
)
else:
try:
type_works = isinstance(v, f.type)
Expand Down

0 comments on commit 0ab1a7e

Please sign in to comment.