Skip to content

Commit

Permalink
Add overloads to eqx.combine
Browse files Browse the repository at this point in the history
In the common case wherein all trees have the same structure, the return
should have the same structure.  Helps type checking a lot.
  • Loading branch information
NeilGirdhar committed Sep 14, 2024
1 parent 97ac55a commit de6c21a
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion equinox/_filters.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections.abc import Callable
from typing import Any, Optional, Union
from typing import Any, Optional, TypeVar, Union, overload

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -163,6 +163,15 @@ def _is_none(x):
return x is None


_T = TypeVar("_T", bound=PyTree)


@overload
def combine(*pytrees: _T, is_leaf: Optional[Callable[[Any], bool]] = None) -> _T: ...
@overload
def combine(
*pytrees: PyTree, is_leaf: Optional[Callable[[Any], bool]] = None
) -> PyTree: ...
def combine(
*pytrees: PyTree, is_leaf: Optional[Callable[[Any], bool]] = None
) -> PyTree:
Expand Down

0 comments on commit de6c21a

Please sign in to comment.