From 7ee4ca944d75c33d1403122f7ccf141bc390a55e Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Sat, 21 Dec 2024 22:51:14 +0100 Subject: [PATCH] More fixes for warnings raised in JAX 0.4.38. --- equinox/_ad.py | 4 +++- equinox/_callback.py | 30 +++++++++++++++++++++--------- equinox/debug/_dce.py | 10 ++++++---- mkdocs.yml | 2 +- pyproject.toml | 2 +- tests/test_nontraceable.py | 2 +- 6 files changed, 33 insertions(+), 17 deletions(-) diff --git a/equinox/_ad.py b/equinox/_ad.py index 4ad3129d..3527dc97 100644 --- a/equinox/_ad.py +++ b/equinox/_ad.py @@ -881,13 +881,15 @@ def _none_to_zero(ct, x): if x is None: return None else: - aval = jax.core.raise_to_shaped(jax.core.get_aval(x)) + aval = jax.core.get_aval(x) if hasattr(aval, "to_tangent_aval"): # Earlier versions of JAX were internally inconsistent, and expected # e.g. integer primals to have integer tangents from `custom_{jvp,vjp}` # rules. # That changed in JAX 0.4.34. aval = aval.to_tangent_aval() # pyright: ignore + else: + aval = jax.core.raise_to_shaped(aval) # pyright: ignore return jax.custom_derivatives.SymbolicZero(aval) else: return ct diff --git a/equinox/_callback.py b/equinox/_callback.py index fb6f322c..fd3949aa 100644 --- a/equinox/_callback.py +++ b/equinox/_callback.py @@ -12,22 +12,29 @@ def filter_pure_callback( callback, *args, result_shape_dtypes, - vectorized=False, + sharding=None, + vmap_method=None, + vectorized=None, **kwargs, ): """Calls a Python function inside a JIT region. As `jax.pure_callback` but accepts arbitrary Python objects as inputs and outputs. (Not just JAXable types.) + Note that unlike `jax.pure_callback`, then the `result_shape_dtypes` argument must + be passed as a keyword argument. + **Arguments:** - `callback`: The Python function to call. - - `args`, `kwargs`: The function will be called as `callback(*args, **kwargs)`. + - `*args`, `**kwargs`: The function will be called as `callback(*args, **kwargs)`. These may be arbitrary Python objects. - `result_shape_dtypes`: A PyTree specifying the output of `callback`. It should - have a `jax.ShapeDtypeStruct` in place of any JAX arrays. - - `vectorized`: If `True` then `callback` is batched(when transformed by `vmap`) - by calling it directly on the batched arrays. If `False` then `callback` is - called on each batch element individually. + have a `jax.ShapeDtypeStruct` in place of any JAX arrays. Note that unlike + `jax.pure_callback`, this must be passed as a keyword-only argument. + - `sharding`: optional sharding that specifies the device from which the callback + should be invoked. + - `vmap_method`, `vectorized`: these specify how the callback transforms under + `vmap()` as described in the documentation for `jax.pure_callback`. **Returns:** @@ -44,7 +51,12 @@ def _callback(_dynamic): raise ValueError("Callback did not return matching static elements") return _dynamic_out - dynamic_out = jax.pure_callback( - _callback, dynamic_struct, dynamic, vectorized=vectorized - ) + keywords = {} + if sharding is not None: + keywords["sharding"] = sharding + if vectorized is not None: + keywords["vectorized"] = vectorized + if vmap_method is not None: + keywords["vmap_method"] = vmap_method + dynamic_out = jax.pure_callback(_callback, dynamic_struct, dynamic, **keywords) return combine(dynamic_out, static_struct) diff --git a/equinox/debug/_dce.py b/equinox/debug/_dce.py index 8b70a74a..be92f9df 100644 --- a/equinox/debug/_dce.py +++ b/equinox/debug/_dce.py @@ -14,10 +14,10 @@ _dce_store = {} -def _register_alive(name: Hashable, tag: object): - def _register_alive_impl(i, x): +def _register_alive(name: Hashable, tag: object, i: int): + def _register_alive_impl(x): leaves, _, _ = _dce_store[name][tag] - leaves[i.item()] = (x.shape, x.dtype.name) + leaves[i] = (x.shape, x.dtype.name) return x return _register_alive_impl @@ -70,7 +70,9 @@ def f(x): tag_store = _dce_store[name] = {} tag_store[tag] = ({}, treedef, static) leaves = [ - jax.pure_callback(_register_alive(name, tag), x, i, x, vectorized=True) + jax.pure_callback( + _register_alive(name, tag, i), x, x, vmap_method="expand_dims" + ) for i, x in enumerate(leaves) ] dynamic_out = jtu.tree_unflatten(treedef, leaves) diff --git a/mkdocs.yml b/mkdocs.yml index 0622e74a..574bbf44 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -83,7 +83,7 @@ plugins: - import jax - import jax.extend.core - jax.ShapeDtypeStruct.__module__ = "jax" - - jax.extend.core.ClosedJaxpr.__module__ = "jax.core" + - jax.extend.core.ClosedJaxpr.__module__ = "jax.extend.core" selection: inherited_members: true # Allow looking up inherited methods diff --git a/pyproject.toml b/pyproject.toml index 668ddbd9..56d27f3b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,7 @@ build-backend = "hatchling.build" include = ["equinox/*"] [tool.pytest.ini_options] -addopts = "--jaxtyping-packages=equinox,beartype.beartype(conf=beartype.BeartypeConf(strategy=beartype.BeartypeStrategy.On))" +addopts = "-Werror --jaxtyping-packages=equinox,beartype.beartype(conf=beartype.BeartypeConf(strategy=beartype.BeartypeStrategy.On))" [tool.ruff] extend-include = ["*.ipynb"] diff --git a/tests/test_nontraceable.py b/tests/test_nontraceable.py index 23839d79..33bd93b5 100644 --- a/tests/test_nontraceable.py +++ b/tests/test_nontraceable.py @@ -75,7 +75,7 @@ def run(dynamic, static): jaxpr = jax.make_jaxpr(run, static_argnums=1)(dynamic, static) jaxpr = cast(jax.extend.core.ClosedJaxpr, jaxpr) - run2 = jax.core.jaxpr_as_fun(jaxpr) + run2 = jax.extend.core.jaxpr_as_fun(jaxpr) run2(*dynamic_flat) # pyright: ignore jax.jit(run2)(*dynamic_flat) # pyright: ignore