diff --git a/equinox/_ad.py b/equinox/_ad.py index 665f7792..bbb1e7c3 100644 --- a/equinox/_ad.py +++ b/equinox/_ad.py @@ -18,6 +18,7 @@ import jax import jax._src.traceback_util as traceback_util import jax.core +import jax.extend import jax.interpreters.ad as ad import jax.numpy as jnp import jax.tree_util as jtu diff --git a/equinox/_enum.py b/equinox/_enum.py index 1716a60c..7d107f08 100644 --- a/equinox/_enum.py +++ b/equinox/_enum.py @@ -3,6 +3,7 @@ import jax._src.traceback_util as traceback_util import jax.core +import jax.extend import jax.numpy as jnp import numpy as np from jaxtyping import Array, ArrayLike, Bool, Int diff --git a/equinox/_errors.py b/equinox/_errors.py index 3297fb3d..0ce20bed 100644 --- a/equinox/_errors.py +++ b/equinox/_errors.py @@ -9,6 +9,7 @@ import jax import jax._src.traceback_util as traceback_util import jax.core +import jax.extend import jax.lax as lax import jax.numpy as jnp import jax.tree_util as jtu diff --git a/equinox/_jit.py b/equinox/_jit.py index 1d6b190a..4a2d5445 100644 --- a/equinox/_jit.py +++ b/equinox/_jit.py @@ -11,6 +11,7 @@ import jax._src.dispatch import jax._src.traceback_util as traceback_util import jax.core +import jax.extend import jax.errors import jax.numpy as jnp from jaxtyping import PyTree diff --git a/equinox/_make_jaxpr.py b/equinox/_make_jaxpr.py index f26c6439..4ccee31a 100644 --- a/equinox/_make_jaxpr.py +++ b/equinox/_make_jaxpr.py @@ -5,6 +5,7 @@ import jax import jax._src.traceback_util as traceback_util import jax.core +import jax.extend import jax.tree_util as jtu from jaxtyping import PyTree @@ -49,7 +50,7 @@ def _fn(*_dynamic_flat): def filter_make_jaxpr( fun: Callable[_P, Any], ) -> Callable[ - _P, tuple[jax.core.ClosedJaxpr, PyTree[jax.ShapeDtypeStruct], PyTree[Any]] + _P, tuple[jax.extend.core.ClosedJaxpr, PyTree[jax.ShapeDtypeStruct], PyTree[Any]] ]: """As `jax.make_jaxpr`, but accepts arbitrary PyTrees as input and output. diff --git a/equinox/_misc.py b/equinox/_misc.py index 14183278..20127f31 100644 --- a/equinox/_misc.py +++ b/equinox/_misc.py @@ -1,5 +1,6 @@ import jax import jax.core +import jax.extend import jax.numpy as jnp from jaxtyping import Array diff --git a/equinox/_unvmap.py b/equinox/_unvmap.py index d6855e5f..780c1276 100644 --- a/equinox/_unvmap.py +++ b/equinox/_unvmap.py @@ -1,7 +1,9 @@ from typing import cast import jax +import jax.extend import jax.core +import jax.extend import jax.interpreters.batching as batching import jax.interpreters.mlir as mlir import jax.numpy as jnp @@ -10,7 +12,7 @@ # unvmap_all -unvmap_all_p = jax.core.Primitive("unvmap_all") +unvmap_all_p = jax.extend.core.Primitive("unvmap_all") def unvmap_all(x: Bool[ArrayLike, "..."]) -> Bool[Array, ""]: @@ -41,7 +43,7 @@ def _unvmap_all_batch(x, batch_axes): # unvmap_any -unvmap_any_p = jax.core.Primitive("unvmap_any") +unvmap_any_p = jax.extend.core.Primitive("unvmap_any") def unvmap_any(x: Bool[ArrayLike, "..."]) -> Bool[Array, ""]: @@ -72,7 +74,7 @@ def _unvmap_any_batch(x, batch_axes): # unvmap_max -unvmap_max_p = jax.core.Primitive("unvmap_max") +unvmap_max_p = jax.extend.core.Primitive("unvmap_max") def unvmap_max(x: Int[ArrayLike, "..."]) -> Int[Array, ""]: diff --git a/equinox/_vmap_pmap.py b/equinox/_vmap_pmap.py index ef917500..588e2aa4 100644 --- a/equinox/_vmap_pmap.py +++ b/equinox/_vmap_pmap.py @@ -8,6 +8,7 @@ import jax import jax._src.traceback_util as traceback_util import jax.core +import jax.extend import jax.numpy as jnp import jax.tree_util as jtu import numpy as np diff --git a/equinox/debug/_announce_transform.py b/equinox/debug/_announce_transform.py index 9e6af450..ee1a33f5 100644 --- a/equinox/debug/_announce_transform.py +++ b/equinox/debug/_announce_transform.py @@ -3,6 +3,7 @@ import jax import jax.core +import jax.extend import jax.interpreters.ad as ad import jax.interpreters.batching as batching import jax.interpreters.mlir as mlir @@ -124,7 +125,7 @@ def _mlir(*x, stack, name, intermediates, announce): return x -announce_jaxpr_p = jax.core.Primitive("announce_jaxpr") +announce_jaxpr_p = jax.extend.core.Primitive("announce_jaxpr") announce_jaxpr_p.multiple_results = True announce_jaxpr_p.def_impl(_impl) announce_jaxpr_p.def_abstract_eval(_abstract) diff --git a/equinox/debug/_backward_nan.py b/equinox/debug/_backward_nan.py index b1c78b16..b8ba0fc0 100644 --- a/equinox/debug/_backward_nan.py +++ b/equinox/debug/_backward_nan.py @@ -3,6 +3,7 @@ import jax import jax._src.traceback_util as traceback_util import jax.core +import jax.extend import jax.numpy as jnp import jax.tree_util as jtu diff --git a/equinox/debug/_breakpoint_if.py b/equinox/debug/_breakpoint_if.py index f28e015d..780c31c5 100644 --- a/equinox/debug/_breakpoint_if.py +++ b/equinox/debug/_breakpoint_if.py @@ -1,6 +1,7 @@ import jax import jax._src.traceback_util as traceback_util import jax.core +import jax.extend import jax.lax as lax from jaxtyping import Array, Bool diff --git a/equinox/debug/_dce.py b/equinox/debug/_dce.py index 8b70a74a..10578f93 100644 --- a/equinox/debug/_dce.py +++ b/equinox/debug/_dce.py @@ -2,6 +2,7 @@ import jax import jax.core +import jax.extend import jax.numpy as jnp import jax.tree_util as jtu from jaxtyping import PyTree diff --git a/equinox/internal/_finalise_jaxpr.py b/equinox/internal/_finalise_jaxpr.py index 0edeb951..cce81690 100644 --- a/equinox/internal/_finalise_jaxpr.py +++ b/equinox/internal/_finalise_jaxpr.py @@ -21,6 +21,7 @@ import jax import jax.core +import jax.extend import jax.custom_derivatives import jax.tree_util as jtu from jaxtyping import PyTree @@ -39,10 +40,10 @@ def _maybe_finalise_jaxpr(val: Any): if isinstance(val, jax.core.Jaxpr): if len(val.constvars) == 0: is_open_jaxpr = True - val = jax.core.ClosedJaxpr(val, []) + val = jax.extend.core.ClosedJaxpr(val, []) else: return val - if isinstance(val, jax.core.ClosedJaxpr): + if isinstance(val, jax.extend.core.ClosedJaxpr): val = finalise_jaxpr(val) if is_open_jaxpr: val = val.jaxpr @@ -60,18 +61,18 @@ def _finalise_jaxprs_in_params(params): return new_params -def _default_finalisation(prim: jax.core.Primitive, *args, **kwargs): +def _default_finalisation(prim: jax.extend.core.Primitive, *args, **kwargs): return prim.bind(*args, **kwargs) -def _impl_finalisation(prim: jax.core.Primitive, *args, **kwargs): +def _impl_finalisation(prim: jax.extend.core.Primitive, *args, **kwargs): return prim.impl(*args, **kwargs) primitive_finalisations = {} -def register_impl_finalisation(prim: jax.core.Primitive): +def register_impl_finalisation(prim: jax.extend.core.Primitive): primitive_finalisations[prim] = ft.partial(_impl_finalisation, prim) @@ -104,18 +105,18 @@ def write(v: jax.core.Var, val: Any) -> None: return _safe_map(read, jaxpr.outvars) -def finalise_jaxpr_as_fn(jaxpr: jax.core.ClosedJaxpr): +def finalise_jaxpr_as_fn(jaxpr: jax.extend.core.ClosedJaxpr): """As `jax.core.jaxpr_as_fn`, but the result is finalised.""" return ft.partial(finalise_eval_jaxpr, jaxpr.jaxpr, jaxpr.consts) -def finalise_jaxpr(jaxpr: jax.core.ClosedJaxpr) -> jax.core.ClosedJaxpr: +def finalise_jaxpr(jaxpr: jax.extend.core.ClosedJaxpr) -> jax.extend.core.ClosedJaxpr: """A jaxpr-to-jaxpr transformation that performs finalisation.""" fn = finalise_jaxpr_as_fn(jaxpr) args = [ jax.ShapeDtypeStruct(x.aval.shape, x.aval.dtype) for x in jaxpr.jaxpr.invars ] - return cast(jax.core.ClosedJaxpr, jax.make_jaxpr(fn)(*args)) + return cast(jax.extend.core.ClosedJaxpr, jax.make_jaxpr(fn)(*args)) def finalise_fn(fn): @@ -136,13 +137,13 @@ def _finalise_fn(*args): @overload def finalise_make_jaxpr( fn, *, return_shape: Literal[False] = False -) -> Callable[..., jax.core.ClosedJaxpr]: ... +) -> Callable[..., jax.extend.core.ClosedJaxpr]: ... @overload def finalise_make_jaxpr( fn, *, return_shape: Literal[True] = True -) -> Callable[..., tuple[jax.core.ClosedJaxpr, PyTree[jax.ShapeDtypeStruct]]]: ... +) -> Callable[..., tuple[jax.extend.core.ClosedJaxpr, PyTree[jax.ShapeDtypeStruct]]]: ... @overload @@ -151,7 +152,7 @@ def finalise_make_jaxpr( ) -> Callable[ ..., Union[ - jax.core.ClosedJaxpr, tuple[jax.core.ClosedJaxpr, PyTree[jax.ShapeDtypeStruct]] + jax.extend.core.ClosedJaxpr, tuple[jax.extend.core.ClosedJaxpr, PyTree[jax.ShapeDtypeStruct]] ], ]: ... @@ -164,12 +165,12 @@ def _finalise_make_jaxpr(*args): *args ) if return_shape: - jaxpr_struct = cast(tuple[jax.core.ClosedJaxpr, Any], jaxpr_struct) + jaxpr_struct = cast(tuple[jax.extend.core.ClosedJaxpr, Any], jaxpr_struct) jaxpr, struct = jaxpr_struct jaxpr = finalise_jaxpr(jaxpr) return jaxpr, struct else: - jaxpr_struct = cast(jax.core.ClosedJaxpr, jaxpr_struct) + jaxpr_struct = cast(jax.extend.core.ClosedJaxpr, jaxpr_struct) jaxpr = finalise_jaxpr(jaxpr_struct) return jaxpr diff --git a/equinox/internal/_loop/checkpointed.py b/equinox/internal/_loop/checkpointed.py index fd5ed7c8..3973121f 100644 --- a/equinox/internal/_loop/checkpointed.py +++ b/equinox/internal/_loop/checkpointed.py @@ -54,6 +54,7 @@ import jax import jax.core +import jax.extend import jax.lax as lax import jax.numpy as jnp import jax.tree_util as jtu diff --git a/equinox/internal/_loop/common.py b/equinox/internal/_loop/common.py index 5c650559..3fb5c33d 100644 --- a/equinox/internal/_loop/common.py +++ b/equinox/internal/_loop/common.py @@ -3,6 +3,7 @@ import jax import jax.core +import jax.extend import jax.interpreters.ad as ad import jax.interpreters.batching as batching import jax.interpreters.mlir as mlir @@ -105,7 +106,7 @@ def _select_if_vmap_batch(axis_size, axis_name, trace, inputs, batch_axes): return out, out_axis -select_if_vmap_p = jax.core.Primitive("select_if_vmap") +select_if_vmap_p = jax.extend.core.Primitive("select_if_vmap") select_if_vmap_p.def_impl(_select_if_vmap_impl) select_if_vmap_p.def_abstract_eval(_select_if_vmap_abstract) ad.primitive_jvps[select_if_vmap_p] = _select_if_vmap_jvp diff --git a/equinox/internal/_noinline.py b/equinox/internal/_noinline.py index 47b38ca3..2cc598e0 100644 --- a/equinox/internal/_noinline.py +++ b/equinox/internal/_noinline.py @@ -4,6 +4,7 @@ import jax import jax.core +import jax.extend import jax.interpreters.ad as ad import jax.interpreters.batching as batching import jax.interpreters.mlir as mlir @@ -330,7 +331,7 @@ def _noinline_mlir(ctx, *dynamic, treedef, static, flatten, **kwargs): return result -noinline_p = jax.core.Primitive("noinline") +noinline_p = jax.extend.core.Primitive("noinline") noinline_p.multiple_results = True noinline_p.def_impl(_noinline_impl) noinline_p.def_abstract_eval(_noinline_abstract) diff --git a/equinox/internal/_nontraceable.py b/equinox/internal/_nontraceable.py index 2f313fc3..e4056d18 100644 --- a/equinox/internal/_nontraceable.py +++ b/equinox/internal/_nontraceable.py @@ -7,6 +7,7 @@ import jax import jax.core +import jax.extend import jax.interpreters.ad as ad import jax.interpreters.batching as batching import jax.interpreters.mlir as mlir @@ -29,7 +30,7 @@ def _error(*args, name): return _error -nontraceable_p = jax.core.Primitive("nontraceable") +nontraceable_p = jax.extend.core.Primitive("nontraceable") nontraceable_p.def_impl(_nontraceable_impl) nontraceable_p.def_abstract_eval(_nontraceable_impl) ad.primitive_jvps[nontraceable_p] = _make_error("differentiation") @@ -53,7 +54,7 @@ def nontraceable(x, *, name="nontraceable operation"): return combine(dynamic, static) -nondifferentiable_backward_p = jax.core.Primitive("nondifferentiable_backward") +nondifferentiable_backward_p = jax.extend.core.Primitive("nondifferentiable_backward") def _nondifferentiable_backward_batch(x, batch_axes, *, msg, symbolic): @@ -137,7 +138,7 @@ def _cannot_batch(x, b, *, msg, allow_constant_across_batch): raise ValueError(msg) -nonbatchable_p = jax.core.Primitive("nonbatchable") +nonbatchable_p = jax.extend.core.Primitive("nonbatchable") nonbatchable_p.def_impl(lambda x, *, msg, allow_constant_across_batch: x) nonbatchable_p.def_abstract_eval(lambda x, *, msg, allow_constant_across_batch: x) batching.primitive_batchers[nonbatchable_p] = _cannot_batch diff --git a/equinox/internal/_primitive.py b/equinox/internal/_primitive.py index 400c6893..71c9be58 100644 --- a/equinox/internal/_primitive.py +++ b/equinox/internal/_primitive.py @@ -3,6 +3,7 @@ import jax import jax.core +import jax.extend import jax.interpreters.ad as ad import jax.interpreters.batching as batching import jax.interpreters.mlir as mlir @@ -255,7 +256,7 @@ def _wrapper(dynamic, batch_axes, *, treedef, static, flatten): return _wrapper -def filter_primitive_bind(prim: jax.core.Primitive, *args) -> PyTree: +def filter_primitive_bind(prim: jax.extend.core.Primitive, *args) -> PyTree: """Calls a primitive that has had its rules defined using the filter functions above. """ @@ -301,7 +302,7 @@ def materialise_zeros(primal, tangent, allow_struct=False): def create_vprim(name: str, impl, abstract_eval, jvp, transpose): - prim = jax.core.Primitive(name) + prim = jax.extend.core.Primitive(name) prim.multiple_results = True def batch_rule(axis_size, axis_name, trace_type, inputs, batch_axes, **params): diff --git a/mkdocs.yml b/mkdocs.yml index 48eb6915..a4d8bd0f 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -82,7 +82,7 @@ plugins: - jaxtyping.set_array_name_format("array") - import jax - jax.ShapeDtypeStruct.__module__ = "jax" - - jax.core.ClosedJaxpr.__module__ = "jax.core" + - jax.extend.core.ClosedJaxpr.__module__ = "jax.core" selection: inherited_members: true # Allow looking up inherited methods diff --git a/pyproject.toml b/pyproject.toml index aad27c43..668ddbd9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ name = "equinox" version = "0.11.10" description = "Elegant easy-to-use neural networks in JAX." readme = "README.md" -requires-python =">=3.9" +requires-python =">=3.10" license = {file = "LICENSE"} authors = [ {name = "Patrick Kidger", email = "contact@kidger.site"}, @@ -23,7 +23,7 @@ classifiers = [ "Topic :: Scientific/Engineering :: Mathematics", ] urls = {repository = "https://github.com/patrick-kidger/equinox" } -dependencies = ["jax>=0.4.13,!=0.4.27", "jaxtyping>=0.2.20", "typing_extensions>=4.5.0"] +dependencies = ["jax>=0.4.38", "jaxtyping>=0.2.20", "typing_extensions>=4.5.0"] [build-system] requires = ["hatchling"] diff --git a/tests/test_finalise_jaxpr.py b/tests/test_finalise_jaxpr.py index e347c8e7..01bad6bb 100644 --- a/tests/test_finalise_jaxpr.py +++ b/tests/test_finalise_jaxpr.py @@ -3,6 +3,7 @@ import equinox.internal as eqxi import jax import jax.core +import jax.extend import jax.lax as lax import jax.numpy as jnp @@ -23,7 +24,7 @@ def _assert_vars_equal(obj1, obj2, varnames): assert a.aval.strip_weak_type() == b.aval.strip_weak_type() -def _assert_jaxpr_equal(jaxpr1: jax.core.ClosedJaxpr, jaxpr2: jax.core.ClosedJaxpr): +def _assert_jaxpr_equal(jaxpr1: jax.extend.core.ClosedJaxpr, jaxpr2: jax.extend.core.ClosedJaxpr): assert jaxpr1.consts == jaxpr2.consts jaxpr1 = jaxpr1.jaxpr jaxpr2 = jaxpr2.jaxpr @@ -41,7 +42,7 @@ def fn(x): x = x * 2 return x - jaxpr = cast(jax.core.ClosedJaxpr, jax.make_jaxpr(fn)(1)) + jaxpr = cast(jax.extend.core.ClosedJaxpr, jax.make_jaxpr(fn)(1)) jaxpr2 = eqxi.finalise_jaxpr(jaxpr) _assert_jaxpr_equal(jaxpr, jaxpr2) @@ -53,13 +54,13 @@ def fn(x): x = jnp.invert(x) return x - jaxpr = cast(jax.core.ClosedJaxpr, jax.make_jaxpr(fn)(True)) + jaxpr = cast(jax.extend.core.ClosedJaxpr, jax.make_jaxpr(fn)(True)) jaxpr2 = eqxi.finalise_jaxpr(jaxpr) jaxpr3 = eqxi.finalise_jaxpr(jaxpr2) _assert_jaxpr_equal(jaxpr2, jaxpr3) jaxpr = jax.make_jaxpr(jax.vmap(fn))(jnp.array([True, False])) - jaxpr = cast(jax.core.ClosedJaxpr, jaxpr) + jaxpr = cast(jax.extend.core.ClosedJaxpr, jaxpr) jaxpr2 = eqxi.finalise_jaxpr(jaxpr) jaxpr3 = eqxi.finalise_jaxpr(jaxpr2) _assert_jaxpr_equal(jaxpr2, jaxpr3) @@ -78,9 +79,9 @@ def fn(x): assert tree_allclose(fn(-1), finalised_fn(-1)) jaxpr = jax.make_jaxpr(fn)(1) - jaxpr = cast(jax.core.ClosedJaxpr, jaxpr) + jaxpr = cast(jax.extend.core.ClosedJaxpr, jaxpr) finalised_jaxpr = jax.make_jaxpr(finalised_fn)(1) - finalised_jaxpr = cast(jax.core.ClosedJaxpr, finalised_jaxpr) + finalised_jaxpr = cast(jax.extend.core.ClosedJaxpr, finalised_jaxpr) _assert_jaxpr_equal(finalised_jaxpr, jaxpr) @@ -96,9 +97,9 @@ def fn(x): assert tree_allclose(fn(True), finalised_fn(True)) finalised_jaxpr = jax.make_jaxpr(finalised_fn)(True) - finalised_jaxpr = cast(jax.core.ClosedJaxpr, finalised_jaxpr) + finalised_jaxpr = cast(jax.extend.core.ClosedJaxpr, finalised_jaxpr) finalised_finalised_jaxpr = jax.make_jaxpr(eqxi.finalise_fn(finalised_fn))(True) - finalised_finalised_jaxpr = cast(jax.core.ClosedJaxpr, finalised_finalised_jaxpr) + finalised_finalised_jaxpr = cast(jax.extend.core.ClosedJaxpr, finalised_finalised_jaxpr) _assert_jaxpr_equal(finalised_jaxpr, finalised_finalised_jaxpr) for eqn in finalised_jaxpr.eqns: assert eqn.primitive != eqxi.unvmap_any_p @@ -114,12 +115,12 @@ def fn(x): assert tree_allclose(vmap_fn(arg), finalised_vmap_fn(arg)) finalised_vmap_jaxpr = jax.make_jaxpr(finalised_vmap_fn)(jnp.array([False, False])) - finalised_vmap_jaxpr = cast(jax.core.ClosedJaxpr, finalised_vmap_jaxpr) + finalised_vmap_jaxpr = cast(jax.extend.core.ClosedJaxpr, finalised_vmap_jaxpr) finalised_finalised_vmap_jaxpr = jax.make_jaxpr( eqxi.finalise_fn(finalised_vmap_fn) )(jnp.array([False, False])) finalised_finalised_vmap_jaxpr = cast( - jax.core.ClosedJaxpr, finalised_finalised_vmap_jaxpr + jax.extend.core.ClosedJaxpr, finalised_finalised_vmap_jaxpr ) for eqn in finalised_vmap_jaxpr.eqns: assert eqn.primitive != eqxi.unvmap_any_p diff --git a/tests/test_nontraceable.py b/tests/test_nontraceable.py index 855fe1fc..d75b3496 100644 --- a/tests/test_nontraceable.py +++ b/tests/test_nontraceable.py @@ -4,6 +4,7 @@ import equinox.internal as eqxi import jax import jax.core +import jax.extend import jax.numpy as jnp import jax.tree_util as jtu import pytest @@ -73,7 +74,7 @@ def run(dynamic, static): jax.vmap(run, in_axes=(0, None))(dynamic_batch, static) jaxpr = jax.make_jaxpr(run, static_argnums=1)(dynamic, static) - jaxpr = cast(jax.core.ClosedJaxpr, jaxpr) + jaxpr = cast(jax.extend.core.ClosedJaxpr, jaxpr) run2 = jax.core.jaxpr_as_fun(jaxpr) run2(*dynamic_flat) # pyright: ignore diff --git a/tests/test_primitive.py b/tests/test_primitive.py index d8d28597..4a217e86 100644 --- a/tests/test_primitive.py +++ b/tests/test_primitive.py @@ -4,6 +4,7 @@ import equinox.internal as eqxi import jax import jax.core +import jax.extend import jax.interpreters.ad as ad import jax.interpreters.batching as batching import jax.interpreters.mlir as mlir @@ -14,7 +15,7 @@ def test_call(): - newprim_p = jax.core.Primitive("newprim") + newprim_p = jax.extend.core.Primitive("newprim") newprim_p.multiple_results = True newprim = ft.partial(eqxi.filter_primitive_bind, newprim_p) diff --git a/tests/test_tree.py b/tests/test_tree.py index b4cae5b1..85c85342 100644 --- a/tests/test_tree.py +++ b/tests/test_tree.py @@ -3,6 +3,7 @@ import equinox as eqx import jax import jax.core +import jax.extend import jax.nn as jnn import jax.numpy as jnp import jax.random as jrandom