Skip to content

Commit

Permalink
More fixes for warnings raised in JAX 0.4.38.
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Dec 21, 2024
1 parent 9780faf commit 7ee4ca9
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 17 deletions.
4 changes: 3 additions & 1 deletion equinox/_ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -881,13 +881,15 @@ def _none_to_zero(ct, x):
if x is None:
return None
else:
aval = jax.core.raise_to_shaped(jax.core.get_aval(x))
aval = jax.core.get_aval(x)
if hasattr(aval, "to_tangent_aval"):
# Earlier versions of JAX were internally inconsistent, and expected
# e.g. integer primals to have integer tangents from `custom_{jvp,vjp}`
# rules.
# That changed in JAX 0.4.34.
aval = aval.to_tangent_aval() # pyright: ignore
else:
aval = jax.core.raise_to_shaped(aval) # pyright: ignore
return jax.custom_derivatives.SymbolicZero(aval)
else:
return ct
Expand Down
30 changes: 21 additions & 9 deletions equinox/_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,29 @@ def filter_pure_callback(
callback,
*args,
result_shape_dtypes,
vectorized=False,
sharding=None,
vmap_method=None,
vectorized=None,
**kwargs,
):
"""Calls a Python function inside a JIT region. As `jax.pure_callback` but accepts
arbitrary Python objects as inputs and outputs. (Not just JAXable types.)
Note that unlike `jax.pure_callback`, then the `result_shape_dtypes` argument must
be passed as a keyword argument.
**Arguments:**
- `callback`: The Python function to call.
- `args`, `kwargs`: The function will be called as `callback(*args, **kwargs)`.
- `*args`, `**kwargs`: The function will be called as `callback(*args, **kwargs)`.
These may be arbitrary Python objects.
- `result_shape_dtypes`: A PyTree specifying the output of `callback`. It should
have a `jax.ShapeDtypeStruct` in place of any JAX arrays.
- `vectorized`: If `True` then `callback` is batched(when transformed by `vmap`)
by calling it directly on the batched arrays. If `False` then `callback` is
called on each batch element individually.
have a `jax.ShapeDtypeStruct` in place of any JAX arrays. Note that unlike
`jax.pure_callback`, this must be passed as a keyword-only argument.
- `sharding`: optional sharding that specifies the device from which the callback
should be invoked.
- `vmap_method`, `vectorized`: these specify how the callback transforms under
`vmap()` as described in the documentation for `jax.pure_callback`.
**Returns:**
Expand All @@ -44,7 +51,12 @@ def _callback(_dynamic):
raise ValueError("Callback did not return matching static elements")
return _dynamic_out

dynamic_out = jax.pure_callback(
_callback, dynamic_struct, dynamic, vectorized=vectorized
)
keywords = {}
if sharding is not None:
keywords["sharding"] = sharding
if vectorized is not None:
keywords["vectorized"] = vectorized
if vmap_method is not None:
keywords["vmap_method"] = vmap_method
dynamic_out = jax.pure_callback(_callback, dynamic_struct, dynamic, **keywords)
return combine(dynamic_out, static_struct)
10 changes: 6 additions & 4 deletions equinox/debug/_dce.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
_dce_store = {}


def _register_alive(name: Hashable, tag: object):
def _register_alive_impl(i, x):
def _register_alive(name: Hashable, tag: object, i: int):
def _register_alive_impl(x):
leaves, _, _ = _dce_store[name][tag]
leaves[i.item()] = (x.shape, x.dtype.name)
leaves[i] = (x.shape, x.dtype.name)
return x

return _register_alive_impl
Expand Down Expand Up @@ -70,7 +70,9 @@ def f(x):
tag_store = _dce_store[name] = {}
tag_store[tag] = ({}, treedef, static)
leaves = [
jax.pure_callback(_register_alive(name, tag), x, i, x, vectorized=True)
jax.pure_callback(
_register_alive(name, tag, i), x, x, vmap_method="expand_dims"
)
for i, x in enumerate(leaves)
]
dynamic_out = jtu.tree_unflatten(treedef, leaves)
Expand Down
2 changes: 1 addition & 1 deletion mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ plugins:
- import jax
- import jax.extend.core
- jax.ShapeDtypeStruct.__module__ = "jax"
- jax.extend.core.ClosedJaxpr.__module__ = "jax.core"
- jax.extend.core.ClosedJaxpr.__module__ = "jax.extend.core"

selection:
inherited_members: true # Allow looking up inherited methods
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ build-backend = "hatchling.build"
include = ["equinox/*"]

[tool.pytest.ini_options]
addopts = "--jaxtyping-packages=equinox,beartype.beartype(conf=beartype.BeartypeConf(strategy=beartype.BeartypeStrategy.On))"
addopts = "-Werror --jaxtyping-packages=equinox,beartype.beartype(conf=beartype.BeartypeConf(strategy=beartype.BeartypeStrategy.On))"

[tool.ruff]
extend-include = ["*.ipynb"]
Expand Down
2 changes: 1 addition & 1 deletion tests/test_nontraceable.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def run(dynamic, static):

jaxpr = jax.make_jaxpr(run, static_argnums=1)(dynamic, static)
jaxpr = cast(jax.extend.core.ClosedJaxpr, jaxpr)
run2 = jax.core.jaxpr_as_fun(jaxpr)
run2 = jax.extend.core.jaxpr_as_fun(jaxpr)

run2(*dynamic_flat) # pyright: ignore
jax.jit(run2)(*dynamic_flat) # pyright: ignore
Expand Down

0 comments on commit 7ee4ca9

Please sign in to comment.