diff --git a/equinox/nn/_stateful.py b/equinox/nn/_stateful.py index d1f6bf5c..a8b413b1 100644 --- a/equinox/nn/_stateful.py +++ b/equinox/nn/_stateful.py @@ -18,6 +18,17 @@ _T = TypeVar("_T") +class _Sentinel(Module): + """A module for sentinels that can be passed dynamically.""" + + pass + + +# Used as a sentinel in two ways: keeping track of updated `State`s, and keeping track +# of deleted initial states. +_sentinel = _Sentinel() + + class StateIndex(Module, Generic[_Value], strict=True): """This wraps together (a) a unique dictionary key used for looking up a stateful value, and (b) how that stateful value should be initialised. @@ -43,10 +54,10 @@ def __call__(self, x: Array, state: eqx.nn.State) -> tuple[Array, eqx.nn.State]: [`equinox.nn.BatchNorm`][] for further reference. """ # noqa: E501 - # Starts off as an `object` when initialised; later replaced with an `int` inside + # Starts off as None when initialised; later replaced with an `int` inside # `make_with_state`. marker: Union[object, int] = field(static=True) - init: _Value + init: Union[_Value, _Sentinel] def __init__(self, init: _Value): """**Arguments:** @@ -70,11 +81,6 @@ def _is_index(x: Any) -> bool: return isinstance(x, StateIndex) -# Used as a sentinel in two ways: keeping track of updated `State`s, and keeping track -# of deleted initial states. -_sentinel = object() - - _state_error = """ Attempted to use old state. Probably you have done something like: ``` @@ -117,13 +123,13 @@ def __init__(self, model: PyTree): leaves = jtu.tree_leaves(model, is_leaf=_is_index) for leaf in leaves: if _is_index(leaf): - if leaf.init is _sentinel: + if isinstance(leaf.init, _Sentinel): raise ValueError( "Do not call `eqx.nn.State(model)` directly. You should call " "`eqx.nn.make_with_state(ModelClass)(...args...)` instead." ) state[leaf.marker] = jtu.tree_map(jnp.asarray, leaf.init) - self._state = state + self._state: Union[_Sentinel, dict[object | int, Any]] = state def get(self, item: StateIndex[_Value]) -> _Value: """Given an [`equinox.nn.StateIndex`][], returns the value of its state. @@ -136,11 +142,11 @@ def get(self, item: StateIndex[_Value]) -> _Value: The current state associated with that index. """ - if self._state is _sentinel: + if isinstance(self._state, _Sentinel): raise ValueError(_state_error) if type(item) is not StateIndex: raise ValueError("Can only use `eqx.nn.StateIndex`s as state keys.") - return self._state[item.marker] # pyright: ignore + return self._state[item.marker] def set(self, item: StateIndex[_Value], value: _Value) -> "State": """Sets a new value for an [`equinox.nn.StateIndex`][], **and returns the @@ -158,11 +164,11 @@ def set(self, item: StateIndex[_Value], value: _Value) -> "State": As a safety guard against accidentally writing `state.set(item, value)` without assigning it to a new value, then the old object (`self`) will become invalid. """ - if self._state is _sentinel: + if isinstance(self._state, _Sentinel): raise ValueError(_state_error) if type(item) is not StateIndex: raise ValueError("Can only use `eqx.nn.StateIndex`s as state keys.") - old_value = self._state[item.marker] # pyright: ignore + old_value = self._state[item.marker] value = jtu.tree_map(jnp.asarray, value) old_struct = jax.eval_shape(lambda: old_value) new_struct = jax.eval_shape(lambda: value) @@ -194,7 +200,7 @@ def substate(self, pytree: PyTree) -> "State": A new [`equinox.nn.State`][] object, which tracks only some of the overall states. """ - if self._state is _sentinel: + if isinstance(self._state, _Sentinel): raise ValueError(_state_error) leaves = jtu.tree_leaves(pytree, is_leaf=_is_index) markers = [x.marker for x in leaves if _is_index(x)] @@ -218,7 +224,7 @@ def update(self, substate: "State") -> "State": As a safety guard against accidentally writing `state.set(item, value)` without assigning it to a new value, then the old object (`self`) will become invalid. """ - if self._state is _sentinel: + if isinstance(self._state, _Sentinel): raise ValueError(_state_error) if type(substate) is not State: raise ValueError("Can only use `eqx.nn.State`s in `update`.") @@ -239,7 +245,7 @@ def __repr__(self): return tree_pformat(self) def __tree_pp__(self, **kwargs): - if self._state is _sentinel: + if isinstance(self._state, _Sentinel): return text("State(~old~)") else: objs = named_objs( @@ -258,7 +264,7 @@ def __tree_pp__(self, **kwargs): ) def tree_flatten(self): - if self._state is _sentinel: + if isinstance(self._state, _Sentinel): raise ValueError(_state_error) keys = tuple(self._state.keys()) # pyright: ignore values = tuple(self._state[k] for k in keys) # pyright: ignore