diff --git a/equinox/_ad.py b/equinox/_ad.py index 665f7792..4ad3129d 100644 --- a/equinox/_ad.py +++ b/equinox/_ad.py @@ -18,6 +18,7 @@ import jax import jax._src.traceback_util as traceback_util import jax.core +import jax.extend.core import jax.interpreters.ad as ad import jax.numpy as jnp import jax.tree_util as jtu @@ -598,7 +599,7 @@ class _ClosureConvert(Module): # Important that `jaxpr` be a leaf (and not static), so that it is a tuple element # when passing through `filter_primitive_bind` and thus visible to # `jax.core.subjaxprs` - jaxpr: jax.core.Jaxpr + jaxpr: jax.extend.core.Jaxpr consts: PyTree[ArrayLike] # Captured in the PyTree structure of _ClosureConvert in_dynamic_struct: _FlatPyTree[jax.ShapeDtypeStruct] = field(static=True) out_dynamic_struct: _FlatPyTree[jax.ShapeDtypeStruct] = field(static=True) diff --git a/equinox/_make_jaxpr.py b/equinox/_make_jaxpr.py index f26c6439..fb9ff4db 100644 --- a/equinox/_make_jaxpr.py +++ b/equinox/_make_jaxpr.py @@ -4,7 +4,7 @@ import jax import jax._src.traceback_util as traceback_util -import jax.core +import jax.extend.core import jax.tree_util as jtu from jaxtyping import PyTree @@ -49,7 +49,7 @@ def _fn(*_dynamic_flat): def filter_make_jaxpr( fun: Callable[_P, Any], ) -> Callable[ - _P, tuple[jax.core.ClosedJaxpr, PyTree[jax.ShapeDtypeStruct], PyTree[Any]] + _P, tuple[jax.extend.core.ClosedJaxpr, PyTree[jax.ShapeDtypeStruct], PyTree[Any]] ]: """As `jax.make_jaxpr`, but accepts arbitrary PyTrees as input and output. diff --git a/equinox/_unvmap.py b/equinox/_unvmap.py index d6855e5f..d3db6e8c 100644 --- a/equinox/_unvmap.py +++ b/equinox/_unvmap.py @@ -2,6 +2,7 @@ import jax import jax.core +import jax.extend.core import jax.interpreters.batching as batching import jax.interpreters.mlir as mlir import jax.numpy as jnp @@ -10,7 +11,7 @@ # unvmap_all -unvmap_all_p = jax.core.Primitive("unvmap_all") +unvmap_all_p = jax.extend.core.Primitive("unvmap_all") def unvmap_all(x: Bool[ArrayLike, "..."]) -> Bool[Array, ""]: @@ -41,7 +42,7 @@ def _unvmap_all_batch(x, batch_axes): # unvmap_any -unvmap_any_p = jax.core.Primitive("unvmap_any") +unvmap_any_p = jax.extend.core.Primitive("unvmap_any") def unvmap_any(x: Bool[ArrayLike, "..."]) -> Bool[Array, ""]: @@ -72,7 +73,7 @@ def _unvmap_any_batch(x, batch_axes): # unvmap_max -unvmap_max_p = jax.core.Primitive("unvmap_max") +unvmap_max_p = jax.extend.core.Primitive("unvmap_max") def unvmap_max(x: Int[ArrayLike, "..."]) -> Int[Array, ""]: diff --git a/equinox/debug/_announce_transform.py b/equinox/debug/_announce_transform.py index 9e6af450..6de5487a 100644 --- a/equinox/debug/_announce_transform.py +++ b/equinox/debug/_announce_transform.py @@ -2,7 +2,7 @@ from typing import Any import jax -import jax.core +import jax.extend.core import jax.interpreters.ad as ad import jax.interpreters.batching as batching import jax.interpreters.mlir as mlir @@ -124,7 +124,7 @@ def _mlir(*x, stack, name, intermediates, announce): return x -announce_jaxpr_p = jax.core.Primitive("announce_jaxpr") +announce_jaxpr_p = jax.extend.core.Primitive("announce_jaxpr") announce_jaxpr_p.multiple_results = True announce_jaxpr_p.def_impl(_impl) announce_jaxpr_p.def_abstract_eval(_abstract) diff --git a/equinox/internal/_finalise_jaxpr.py b/equinox/internal/_finalise_jaxpr.py index 0edeb951..a01775ae 100644 --- a/equinox/internal/_finalise_jaxpr.py +++ b/equinox/internal/_finalise_jaxpr.py @@ -22,6 +22,7 @@ import jax import jax.core import jax.custom_derivatives +import jax.extend.core import jax.tree_util as jtu from jaxtyping import PyTree @@ -36,13 +37,13 @@ def _safe_map(f, *args): def _maybe_finalise_jaxpr(val: Any): is_open_jaxpr = False - if isinstance(val, jax.core.Jaxpr): + if isinstance(val, jax.extend.core.Jaxpr): if len(val.constvars) == 0: is_open_jaxpr = True - val = jax.core.ClosedJaxpr(val, []) + val = jax.extend.core.ClosedJaxpr(val, []) else: return val - if isinstance(val, jax.core.ClosedJaxpr): + if isinstance(val, jax.extend.core.ClosedJaxpr): val = finalise_jaxpr(val) if is_open_jaxpr: val = val.jaxpr @@ -60,33 +61,33 @@ def _finalise_jaxprs_in_params(params): return new_params -def _default_finalisation(prim: jax.core.Primitive, *args, **kwargs): +def _default_finalisation(prim: jax.extend.core.Primitive, *args, **kwargs): return prim.bind(*args, **kwargs) -def _impl_finalisation(prim: jax.core.Primitive, *args, **kwargs): +def _impl_finalisation(prim: jax.extend.core.Primitive, *args, **kwargs): return prim.impl(*args, **kwargs) primitive_finalisations = {} -def register_impl_finalisation(prim: jax.core.Primitive): +def register_impl_finalisation(prim: jax.extend.core.Primitive): primitive_finalisations[prim] = ft.partial(_impl_finalisation, prim) -def finalise_eval_jaxpr(jaxpr: jax.core.Jaxpr, consts, *args): +def finalise_eval_jaxpr(jaxpr: jax.extend.core.Jaxpr, consts, *args): """As jax.core.eval_jaxpr, but finalises (typically by calling `impl` rather than `bind` for custom primitives). """ def read(v: jax.core.Atom) -> Any: - return v.val if isinstance(v, jax.core.Literal) else env[v] + return v.val if isinstance(v, jax.extend.core.Literal) else env[v] - def write(v: jax.core.Var, val: Any) -> None: + def write(v: jax.extend.core.Var, val: Any) -> None: env[v] = val - env: dict[jax.core.Var, Any] = {} + env: dict[jax.extend.core.Var, Any] = {} _safe_map(write, jaxpr.constvars, consts) _safe_map(write, jaxpr.invars, args) for eqn in jaxpr.eqns: @@ -104,18 +105,18 @@ def write(v: jax.core.Var, val: Any) -> None: return _safe_map(read, jaxpr.outvars) -def finalise_jaxpr_as_fn(jaxpr: jax.core.ClosedJaxpr): +def finalise_jaxpr_as_fn(jaxpr: jax.extend.core.ClosedJaxpr): """As `jax.core.jaxpr_as_fn`, but the result is finalised.""" return ft.partial(finalise_eval_jaxpr, jaxpr.jaxpr, jaxpr.consts) -def finalise_jaxpr(jaxpr: jax.core.ClosedJaxpr) -> jax.core.ClosedJaxpr: +def finalise_jaxpr(jaxpr: jax.extend.core.ClosedJaxpr) -> jax.extend.core.ClosedJaxpr: """A jaxpr-to-jaxpr transformation that performs finalisation.""" fn = finalise_jaxpr_as_fn(jaxpr) args = [ jax.ShapeDtypeStruct(x.aval.shape, x.aval.dtype) for x in jaxpr.jaxpr.invars ] - return cast(jax.core.ClosedJaxpr, jax.make_jaxpr(fn)(*args)) + return cast(jax.extend.core.ClosedJaxpr, jax.make_jaxpr(fn)(*args)) def finalise_fn(fn): @@ -136,13 +137,15 @@ def _finalise_fn(*args): @overload def finalise_make_jaxpr( fn, *, return_shape: Literal[False] = False -) -> Callable[..., jax.core.ClosedJaxpr]: ... +) -> Callable[..., jax.extend.core.ClosedJaxpr]: ... @overload def finalise_make_jaxpr( fn, *, return_shape: Literal[True] = True -) -> Callable[..., tuple[jax.core.ClosedJaxpr, PyTree[jax.ShapeDtypeStruct]]]: ... +) -> Callable[ + ..., tuple[jax.extend.core.ClosedJaxpr, PyTree[jax.ShapeDtypeStruct]] +]: ... @overload @@ -151,7 +154,8 @@ def finalise_make_jaxpr( ) -> Callable[ ..., Union[ - jax.core.ClosedJaxpr, tuple[jax.core.ClosedJaxpr, PyTree[jax.ShapeDtypeStruct]] + jax.extend.core.ClosedJaxpr, + tuple[jax.extend.core.ClosedJaxpr, PyTree[jax.ShapeDtypeStruct]], ], ]: ... @@ -164,12 +168,12 @@ def _finalise_make_jaxpr(*args): *args ) if return_shape: - jaxpr_struct = cast(tuple[jax.core.ClosedJaxpr, Any], jaxpr_struct) + jaxpr_struct = cast(tuple[jax.extend.core.ClosedJaxpr, Any], jaxpr_struct) jaxpr, struct = jaxpr_struct jaxpr = finalise_jaxpr(jaxpr) return jaxpr, struct else: - jaxpr_struct = cast(jax.core.ClosedJaxpr, jaxpr_struct) + jaxpr_struct = cast(jax.extend.core.ClosedJaxpr, jaxpr_struct) jaxpr = finalise_jaxpr(jaxpr_struct) return jaxpr diff --git a/equinox/internal/_loop/common.py b/equinox/internal/_loop/common.py index 5c650559..d72479de 100644 --- a/equinox/internal/_loop/common.py +++ b/equinox/internal/_loop/common.py @@ -2,7 +2,7 @@ from typing import Any, TYPE_CHECKING, Union import jax -import jax.core +import jax.extend.core import jax.interpreters.ad as ad import jax.interpreters.batching as batching import jax.interpreters.mlir as mlir @@ -105,7 +105,7 @@ def _select_if_vmap_batch(axis_size, axis_name, trace, inputs, batch_axes): return out, out_axis -select_if_vmap_p = jax.core.Primitive("select_if_vmap") +select_if_vmap_p = jax.extend.core.Primitive("select_if_vmap") select_if_vmap_p.def_impl(_select_if_vmap_impl) select_if_vmap_p.def_abstract_eval(_select_if_vmap_abstract) ad.primitive_jvps[select_if_vmap_p] = _select_if_vmap_jvp diff --git a/equinox/internal/_noinline.py b/equinox/internal/_noinline.py index 47b38ca3..c3a005a3 100644 --- a/equinox/internal/_noinline.py +++ b/equinox/internal/_noinline.py @@ -4,6 +4,7 @@ import jax import jax.core +import jax.extend.core import jax.interpreters.ad as ad import jax.interpreters.batching as batching import jax.interpreters.mlir as mlir @@ -330,7 +331,7 @@ def _noinline_mlir(ctx, *dynamic, treedef, static, flatten, **kwargs): return result -noinline_p = jax.core.Primitive("noinline") +noinline_p = jax.extend.core.Primitive("noinline") noinline_p.multiple_results = True noinline_p.def_impl(_noinline_impl) noinline_p.def_abstract_eval(_noinline_abstract) diff --git a/equinox/internal/_nontraceable.py b/equinox/internal/_nontraceable.py index 2f313fc3..48f01652 100644 --- a/equinox/internal/_nontraceable.py +++ b/equinox/internal/_nontraceable.py @@ -6,7 +6,7 @@ from typing import Optional import jax -import jax.core +import jax.extend.core import jax.interpreters.ad as ad import jax.interpreters.batching as batching import jax.interpreters.mlir as mlir @@ -29,7 +29,7 @@ def _error(*args, name): return _error -nontraceable_p = jax.core.Primitive("nontraceable") +nontraceable_p = jax.extend.core.Primitive("nontraceable") nontraceable_p.def_impl(_nontraceable_impl) nontraceable_p.def_abstract_eval(_nontraceable_impl) ad.primitive_jvps[nontraceable_p] = _make_error("differentiation") @@ -53,7 +53,7 @@ def nontraceable(x, *, name="nontraceable operation"): return combine(dynamic, static) -nondifferentiable_backward_p = jax.core.Primitive("nondifferentiable_backward") +nondifferentiable_backward_p = jax.extend.core.Primitive("nondifferentiable_backward") def _nondifferentiable_backward_batch(x, batch_axes, *, msg, symbolic): @@ -137,7 +137,7 @@ def _cannot_batch(x, b, *, msg, allow_constant_across_batch): raise ValueError(msg) -nonbatchable_p = jax.core.Primitive("nonbatchable") +nonbatchable_p = jax.extend.core.Primitive("nonbatchable") nonbatchable_p.def_impl(lambda x, *, msg, allow_constant_across_batch: x) nonbatchable_p.def_abstract_eval(lambda x, *, msg, allow_constant_across_batch: x) batching.primitive_batchers[nonbatchable_p] = _cannot_batch diff --git a/equinox/internal/_primitive.py b/equinox/internal/_primitive.py index 400c6893..94ef8de6 100644 --- a/equinox/internal/_primitive.py +++ b/equinox/internal/_primitive.py @@ -3,6 +3,7 @@ import jax import jax.core +import jax.extend.core import jax.interpreters.ad as ad import jax.interpreters.batching as batching import jax.interpreters.mlir as mlir @@ -255,7 +256,7 @@ def _wrapper(dynamic, batch_axes, *, treedef, static, flatten): return _wrapper -def filter_primitive_bind(prim: jax.core.Primitive, *args) -> PyTree: +def filter_primitive_bind(prim: jax.extend.core.Primitive, *args) -> PyTree: """Calls a primitive that has had its rules defined using the filter functions above. """ @@ -301,7 +302,7 @@ def materialise_zeros(primal, tangent, allow_struct=False): def create_vprim(name: str, impl, abstract_eval, jvp, transpose): - prim = jax.core.Primitive(name) + prim = jax.extend.core.Primitive(name) prim.multiple_results = True def batch_rule(axis_size, axis_name, trace_type, inputs, batch_axes, **params): diff --git a/mkdocs.yml b/mkdocs.yml index 48eb6915..0622e74a 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -81,8 +81,9 @@ plugins: - import jaxtyping - jaxtyping.set_array_name_format("array") - import jax + - import jax.extend.core - jax.ShapeDtypeStruct.__module__ = "jax" - - jax.core.ClosedJaxpr.__module__ = "jax.core" + - jax.extend.core.ClosedJaxpr.__module__ = "jax.core" selection: inherited_members: true # Allow looking up inherited methods diff --git a/pyproject.toml b/pyproject.toml index aad27c43..668ddbd9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ name = "equinox" version = "0.11.10" description = "Elegant easy-to-use neural networks in JAX." readme = "README.md" -requires-python =">=3.9" +requires-python =">=3.10" license = {file = "LICENSE"} authors = [ {name = "Patrick Kidger", email = "contact@kidger.site"}, @@ -23,7 +23,7 @@ classifiers = [ "Topic :: Scientific/Engineering :: Mathematics", ] urls = {repository = "https://github.com/patrick-kidger/equinox" } -dependencies = ["jax>=0.4.13,!=0.4.27", "jaxtyping>=0.2.20", "typing_extensions>=4.5.0"] +dependencies = ["jax>=0.4.38", "jaxtyping>=0.2.20", "typing_extensions>=4.5.0"] [build-system] requires = ["hatchling"] diff --git a/tests/test_finalise_jaxpr.py b/tests/test_finalise_jaxpr.py index e347c8e7..80f4a4f8 100644 --- a/tests/test_finalise_jaxpr.py +++ b/tests/test_finalise_jaxpr.py @@ -3,6 +3,7 @@ import equinox.internal as eqxi import jax import jax.core +import jax.extend.core import jax.lax as lax import jax.numpy as jnp @@ -23,7 +24,9 @@ def _assert_vars_equal(obj1, obj2, varnames): assert a.aval.strip_weak_type() == b.aval.strip_weak_type() -def _assert_jaxpr_equal(jaxpr1: jax.core.ClosedJaxpr, jaxpr2: jax.core.ClosedJaxpr): +def _assert_jaxpr_equal( + jaxpr1: jax.extend.core.ClosedJaxpr, jaxpr2: jax.extend.core.ClosedJaxpr +): assert jaxpr1.consts == jaxpr2.consts jaxpr1 = jaxpr1.jaxpr jaxpr2 = jaxpr2.jaxpr @@ -41,7 +44,7 @@ def fn(x): x = x * 2 return x - jaxpr = cast(jax.core.ClosedJaxpr, jax.make_jaxpr(fn)(1)) + jaxpr = cast(jax.extend.core.ClosedJaxpr, jax.make_jaxpr(fn)(1)) jaxpr2 = eqxi.finalise_jaxpr(jaxpr) _assert_jaxpr_equal(jaxpr, jaxpr2) @@ -53,13 +56,13 @@ def fn(x): x = jnp.invert(x) return x - jaxpr = cast(jax.core.ClosedJaxpr, jax.make_jaxpr(fn)(True)) + jaxpr = cast(jax.extend.core.ClosedJaxpr, jax.make_jaxpr(fn)(True)) jaxpr2 = eqxi.finalise_jaxpr(jaxpr) jaxpr3 = eqxi.finalise_jaxpr(jaxpr2) _assert_jaxpr_equal(jaxpr2, jaxpr3) jaxpr = jax.make_jaxpr(jax.vmap(fn))(jnp.array([True, False])) - jaxpr = cast(jax.core.ClosedJaxpr, jaxpr) + jaxpr = cast(jax.extend.core.ClosedJaxpr, jaxpr) jaxpr2 = eqxi.finalise_jaxpr(jaxpr) jaxpr3 = eqxi.finalise_jaxpr(jaxpr2) _assert_jaxpr_equal(jaxpr2, jaxpr3) @@ -78,9 +81,9 @@ def fn(x): assert tree_allclose(fn(-1), finalised_fn(-1)) jaxpr = jax.make_jaxpr(fn)(1) - jaxpr = cast(jax.core.ClosedJaxpr, jaxpr) + jaxpr = cast(jax.extend.core.ClosedJaxpr, jaxpr) finalised_jaxpr = jax.make_jaxpr(finalised_fn)(1) - finalised_jaxpr = cast(jax.core.ClosedJaxpr, finalised_jaxpr) + finalised_jaxpr = cast(jax.extend.core.ClosedJaxpr, finalised_jaxpr) _assert_jaxpr_equal(finalised_jaxpr, jaxpr) @@ -96,9 +99,11 @@ def fn(x): assert tree_allclose(fn(True), finalised_fn(True)) finalised_jaxpr = jax.make_jaxpr(finalised_fn)(True) - finalised_jaxpr = cast(jax.core.ClosedJaxpr, finalised_jaxpr) + finalised_jaxpr = cast(jax.extend.core.ClosedJaxpr, finalised_jaxpr) finalised_finalised_jaxpr = jax.make_jaxpr(eqxi.finalise_fn(finalised_fn))(True) - finalised_finalised_jaxpr = cast(jax.core.ClosedJaxpr, finalised_finalised_jaxpr) + finalised_finalised_jaxpr = cast( + jax.extend.core.ClosedJaxpr, finalised_finalised_jaxpr + ) _assert_jaxpr_equal(finalised_jaxpr, finalised_finalised_jaxpr) for eqn in finalised_jaxpr.eqns: assert eqn.primitive != eqxi.unvmap_any_p @@ -114,19 +119,19 @@ def fn(x): assert tree_allclose(vmap_fn(arg), finalised_vmap_fn(arg)) finalised_vmap_jaxpr = jax.make_jaxpr(finalised_vmap_fn)(jnp.array([False, False])) - finalised_vmap_jaxpr = cast(jax.core.ClosedJaxpr, finalised_vmap_jaxpr) + finalised_vmap_jaxpr = cast(jax.extend.core.ClosedJaxpr, finalised_vmap_jaxpr) finalised_finalised_vmap_jaxpr = jax.make_jaxpr( eqxi.finalise_fn(finalised_vmap_fn) )(jnp.array([False, False])) finalised_finalised_vmap_jaxpr = cast( - jax.core.ClosedJaxpr, finalised_finalised_vmap_jaxpr + jax.extend.core.ClosedJaxpr, finalised_finalised_vmap_jaxpr ) for eqn in finalised_vmap_jaxpr.eqns: assert eqn.primitive != eqxi.unvmap_any_p _assert_jaxpr_equal(finalised_vmap_jaxpr, finalised_finalised_vmap_jaxpr) -def _assert_no_unvmap(jaxpr: jax.core.Jaxpr): +def _assert_no_unvmap(jaxpr: jax.extend.core.Jaxpr): for eqn in jaxpr.eqns: assert eqn.primitive not in (eqxi.unvmap_any_p, eqxi.unvmap_all_p) for subjaxpr in jax.core.subjaxprs(jaxpr): diff --git a/tests/test_nontraceable.py b/tests/test_nontraceable.py index 855fe1fc..23839d79 100644 --- a/tests/test_nontraceable.py +++ b/tests/test_nontraceable.py @@ -4,6 +4,7 @@ import equinox.internal as eqxi import jax import jax.core +import jax.extend.core import jax.numpy as jnp import jax.tree_util as jtu import pytest @@ -73,7 +74,7 @@ def run(dynamic, static): jax.vmap(run, in_axes=(0, None))(dynamic_batch, static) jaxpr = jax.make_jaxpr(run, static_argnums=1)(dynamic, static) - jaxpr = cast(jax.core.ClosedJaxpr, jaxpr) + jaxpr = cast(jax.extend.core.ClosedJaxpr, jaxpr) run2 = jax.core.jaxpr_as_fun(jaxpr) run2(*dynamic_flat) # pyright: ignore diff --git a/tests/test_primitive.py b/tests/test_primitive.py index d8d28597..759a53ca 100644 --- a/tests/test_primitive.py +++ b/tests/test_primitive.py @@ -4,6 +4,7 @@ import equinox.internal as eqxi import jax import jax.core +import jax.extend.core import jax.interpreters.ad as ad import jax.interpreters.batching as batching import jax.interpreters.mlir as mlir @@ -14,7 +15,7 @@ def test_call(): - newprim_p = jax.core.Primitive("newprim") + newprim_p = jax.extend.core.Primitive("newprim") newprim_p.multiple_results = True newprim = ft.partial(eqxi.filter_primitive_bind, newprim_p)