diff --git a/equinox/_filters.py b/equinox/_filters.py index 77640d9d..94705772 100644 --- a/equinox/_filters.py +++ b/equinox/_filters.py @@ -1,5 +1,5 @@ from collections.abc import Callable -from typing import Any, Optional, Union +from typing import Any, Optional, TypeVar, Union, overload import jax import jax.numpy as jnp @@ -163,6 +163,15 @@ def _is_none(x): return x is None +_T = TypeVar("_T", bound=PyTree) + + +@overload +def combine(*pytrees: _T, is_leaf: Optional[Callable[[Any], bool]] = None) -> _T: ... +@overload +def combine( + *pytrees: PyTree, is_leaf: Optional[Callable[[Any], bool]] = None +) -> PyTree: ... def combine( *pytrees: PyTree, is_leaf: Optional[Callable[[Any], bool]] = None ) -> PyTree: