diff --git a/autobound/jax/jax_bound.py b/autobound/jax/jax_bound.py index 944648b..eb98249 100644 --- a/autobound/jax/jax_bound.py +++ b/autobound/jax/jax_bound.py @@ -58,10 +58,10 @@ def upper(self, x): def taylor_bounds( - f: Callable[[jnp.ndarray], jnp.ndarray], + f: Callable[[types.NDArrayLike], types.NDArrayLike], max_degree: int, propagate_trust_regions: bool = False, -) -> Callable[[jnp.ndarray, tuple[jnp.ndarray, jnp.ndarray]], TaylorBounds]: +) -> Callable[[types.NDArrayLike, tuple[types.NDArrayLike, types.NDArrayLike]], TaylorBounds]: """Returns version of f that returns a TaylorBounds object. Args: