Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade jax and fix deprecation warnings #915

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion equinox/_ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import jax
import jax._src.traceback_util as traceback_util
import jax.core
import jax.extend.core
import jax.interpreters.ad as ad
import jax.numpy as jnp
import jax.tree_util as jtu
Expand Down Expand Up @@ -598,7 +599,7 @@ class _ClosureConvert(Module):
# Important that `jaxpr` be a leaf (and not static), so that it is a tuple element
# when passing through `filter_primitive_bind` and thus visible to
# `jax.core.subjaxprs`
jaxpr: jax.core.Jaxpr
jaxpr: jax.extend.core.Jaxpr
consts: PyTree[ArrayLike] # Captured in the PyTree structure of _ClosureConvert
in_dynamic_struct: _FlatPyTree[jax.ShapeDtypeStruct] = field(static=True)
out_dynamic_struct: _FlatPyTree[jax.ShapeDtypeStruct] = field(static=True)
Expand Down
4 changes: 2 additions & 2 deletions equinox/_make_jaxpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import jax
import jax._src.traceback_util as traceback_util
import jax.core
import jax.extend.core
import jax.tree_util as jtu
from jaxtyping import PyTree

Expand Down Expand Up @@ -49,7 +49,7 @@ def _fn(*_dynamic_flat):
def filter_make_jaxpr(
fun: Callable[_P, Any],
) -> Callable[
_P, tuple[jax.core.ClosedJaxpr, PyTree[jax.ShapeDtypeStruct], PyTree[Any]]
_P, tuple[jax.extend.core.ClosedJaxpr, PyTree[jax.ShapeDtypeStruct], PyTree[Any]]
]:
"""As `jax.make_jaxpr`, but accepts arbitrary PyTrees as input and output.

Expand Down
7 changes: 4 additions & 3 deletions equinox/_unvmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import jax
import jax.core
import jax.extend.core
import jax.interpreters.batching as batching
import jax.interpreters.mlir as mlir
import jax.numpy as jnp
Expand All @@ -10,7 +11,7 @@

# unvmap_all

unvmap_all_p = jax.core.Primitive("unvmap_all")
unvmap_all_p = jax.extend.core.Primitive("unvmap_all")


def unvmap_all(x: Bool[ArrayLike, "..."]) -> Bool[Array, ""]:
Expand Down Expand Up @@ -41,7 +42,7 @@ def _unvmap_all_batch(x, batch_axes):

# unvmap_any

unvmap_any_p = jax.core.Primitive("unvmap_any")
unvmap_any_p = jax.extend.core.Primitive("unvmap_any")


def unvmap_any(x: Bool[ArrayLike, "..."]) -> Bool[Array, ""]:
Expand Down Expand Up @@ -72,7 +73,7 @@ def _unvmap_any_batch(x, batch_axes):

# unvmap_max

unvmap_max_p = jax.core.Primitive("unvmap_max")
unvmap_max_p = jax.extend.core.Primitive("unvmap_max")


def unvmap_max(x: Int[ArrayLike, "..."]) -> Int[Array, ""]:
Expand Down
4 changes: 2 additions & 2 deletions equinox/debug/_announce_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Any

import jax
import jax.core
import jax.extend.core
import jax.interpreters.ad as ad
import jax.interpreters.batching as batching
import jax.interpreters.mlir as mlir
Expand Down Expand Up @@ -124,7 +124,7 @@ def _mlir(*x, stack, name, intermediates, announce):
return x


announce_jaxpr_p = jax.core.Primitive("announce_jaxpr")
announce_jaxpr_p = jax.extend.core.Primitive("announce_jaxpr")
announce_jaxpr_p.multiple_results = True
announce_jaxpr_p.def_impl(_impl)
announce_jaxpr_p.def_abstract_eval(_abstract)
Expand Down
40 changes: 22 additions & 18 deletions equinox/internal/_finalise_jaxpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import jax
import jax.core
import jax.custom_derivatives
import jax.extend.core
import jax.tree_util as jtu
from jaxtyping import PyTree

Expand All @@ -36,13 +37,13 @@ def _safe_map(f, *args):

def _maybe_finalise_jaxpr(val: Any):
is_open_jaxpr = False
if isinstance(val, jax.core.Jaxpr):
if isinstance(val, jax.extend.core.Jaxpr):
if len(val.constvars) == 0:
is_open_jaxpr = True
val = jax.core.ClosedJaxpr(val, [])
val = jax.extend.core.ClosedJaxpr(val, [])
else:
return val
if isinstance(val, jax.core.ClosedJaxpr):
if isinstance(val, jax.extend.core.ClosedJaxpr):
val = finalise_jaxpr(val)
if is_open_jaxpr:
val = val.jaxpr
Expand All @@ -60,33 +61,33 @@ def _finalise_jaxprs_in_params(params):
return new_params


def _default_finalisation(prim: jax.core.Primitive, *args, **kwargs):
def _default_finalisation(prim: jax.extend.core.Primitive, *args, **kwargs):
return prim.bind(*args, **kwargs)


def _impl_finalisation(prim: jax.core.Primitive, *args, **kwargs):
def _impl_finalisation(prim: jax.extend.core.Primitive, *args, **kwargs):
return prim.impl(*args, **kwargs)


primitive_finalisations = {}


def register_impl_finalisation(prim: jax.core.Primitive):
def register_impl_finalisation(prim: jax.extend.core.Primitive):
primitive_finalisations[prim] = ft.partial(_impl_finalisation, prim)


def finalise_eval_jaxpr(jaxpr: jax.core.Jaxpr, consts, *args):
def finalise_eval_jaxpr(jaxpr: jax.extend.core.Jaxpr, consts, *args):
"""As jax.core.eval_jaxpr, but finalises (typically by calling `impl` rather than
`bind` for custom primitives).
"""

def read(v: jax.core.Atom) -> Any:
return v.val if isinstance(v, jax.core.Literal) else env[v]
return v.val if isinstance(v, jax.extend.core.Literal) else env[v]

def write(v: jax.core.Var, val: Any) -> None:
def write(v: jax.extend.core.Var, val: Any) -> None:
env[v] = val

env: dict[jax.core.Var, Any] = {}
env: dict[jax.extend.core.Var, Any] = {}
_safe_map(write, jaxpr.constvars, consts)
_safe_map(write, jaxpr.invars, args)
for eqn in jaxpr.eqns:
Expand All @@ -104,18 +105,18 @@ def write(v: jax.core.Var, val: Any) -> None:
return _safe_map(read, jaxpr.outvars)


def finalise_jaxpr_as_fn(jaxpr: jax.core.ClosedJaxpr):
def finalise_jaxpr_as_fn(jaxpr: jax.extend.core.ClosedJaxpr):
"""As `jax.core.jaxpr_as_fn`, but the result is finalised."""
return ft.partial(finalise_eval_jaxpr, jaxpr.jaxpr, jaxpr.consts)


def finalise_jaxpr(jaxpr: jax.core.ClosedJaxpr) -> jax.core.ClosedJaxpr:
def finalise_jaxpr(jaxpr: jax.extend.core.ClosedJaxpr) -> jax.extend.core.ClosedJaxpr:
"""A jaxpr-to-jaxpr transformation that performs finalisation."""
fn = finalise_jaxpr_as_fn(jaxpr)
args = [
jax.ShapeDtypeStruct(x.aval.shape, x.aval.dtype) for x in jaxpr.jaxpr.invars
]
return cast(jax.core.ClosedJaxpr, jax.make_jaxpr(fn)(*args))
return cast(jax.extend.core.ClosedJaxpr, jax.make_jaxpr(fn)(*args))


def finalise_fn(fn):
Expand All @@ -136,13 +137,15 @@ def _finalise_fn(*args):
@overload
def finalise_make_jaxpr(
fn, *, return_shape: Literal[False] = False
) -> Callable[..., jax.core.ClosedJaxpr]: ...
) -> Callable[..., jax.extend.core.ClosedJaxpr]: ...


@overload
def finalise_make_jaxpr(
fn, *, return_shape: Literal[True] = True
) -> Callable[..., tuple[jax.core.ClosedJaxpr, PyTree[jax.ShapeDtypeStruct]]]: ...
) -> Callable[
..., tuple[jax.extend.core.ClosedJaxpr, PyTree[jax.ShapeDtypeStruct]]
]: ...


@overload
Expand All @@ -151,7 +154,8 @@ def finalise_make_jaxpr(
) -> Callable[
...,
Union[
jax.core.ClosedJaxpr, tuple[jax.core.ClosedJaxpr, PyTree[jax.ShapeDtypeStruct]]
jax.extend.core.ClosedJaxpr,
tuple[jax.extend.core.ClosedJaxpr, PyTree[jax.ShapeDtypeStruct]],
],
]: ...

Expand All @@ -164,12 +168,12 @@ def _finalise_make_jaxpr(*args):
*args
)
if return_shape:
jaxpr_struct = cast(tuple[jax.core.ClosedJaxpr, Any], jaxpr_struct)
jaxpr_struct = cast(tuple[jax.extend.core.ClosedJaxpr, Any], jaxpr_struct)
jaxpr, struct = jaxpr_struct
jaxpr = finalise_jaxpr(jaxpr)
return jaxpr, struct
else:
jaxpr_struct = cast(jax.core.ClosedJaxpr, jaxpr_struct)
jaxpr_struct = cast(jax.extend.core.ClosedJaxpr, jaxpr_struct)
jaxpr = finalise_jaxpr(jaxpr_struct)
return jaxpr

Expand Down
4 changes: 2 additions & 2 deletions equinox/internal/_loop/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Any, TYPE_CHECKING, Union

import jax
import jax.core
import jax.extend.core
import jax.interpreters.ad as ad
import jax.interpreters.batching as batching
import jax.interpreters.mlir as mlir
Expand Down Expand Up @@ -105,7 +105,7 @@ def _select_if_vmap_batch(axis_size, axis_name, trace, inputs, batch_axes):
return out, out_axis


select_if_vmap_p = jax.core.Primitive("select_if_vmap")
select_if_vmap_p = jax.extend.core.Primitive("select_if_vmap")
select_if_vmap_p.def_impl(_select_if_vmap_impl)
select_if_vmap_p.def_abstract_eval(_select_if_vmap_abstract)
ad.primitive_jvps[select_if_vmap_p] = _select_if_vmap_jvp
Expand Down
3 changes: 2 additions & 1 deletion equinox/internal/_noinline.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import jax
import jax.core
import jax.extend.core
import jax.interpreters.ad as ad
import jax.interpreters.batching as batching
import jax.interpreters.mlir as mlir
Expand Down Expand Up @@ -330,7 +331,7 @@ def _noinline_mlir(ctx, *dynamic, treedef, static, flatten, **kwargs):
return result


noinline_p = jax.core.Primitive("noinline")
noinline_p = jax.extend.core.Primitive("noinline")
noinline_p.multiple_results = True
noinline_p.def_impl(_noinline_impl)
noinline_p.def_abstract_eval(_noinline_abstract)
Expand Down
8 changes: 4 additions & 4 deletions equinox/internal/_nontraceable.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Optional

import jax
import jax.core
import jax.extend.core
import jax.interpreters.ad as ad
import jax.interpreters.batching as batching
import jax.interpreters.mlir as mlir
Expand All @@ -29,7 +29,7 @@ def _error(*args, name):
return _error


nontraceable_p = jax.core.Primitive("nontraceable")
nontraceable_p = jax.extend.core.Primitive("nontraceable")
nontraceable_p.def_impl(_nontraceable_impl)
nontraceable_p.def_abstract_eval(_nontraceable_impl)
ad.primitive_jvps[nontraceable_p] = _make_error("differentiation")
Expand All @@ -53,7 +53,7 @@ def nontraceable(x, *, name="nontraceable operation"):
return combine(dynamic, static)


nondifferentiable_backward_p = jax.core.Primitive("nondifferentiable_backward")
nondifferentiable_backward_p = jax.extend.core.Primitive("nondifferentiable_backward")


def _nondifferentiable_backward_batch(x, batch_axes, *, msg, symbolic):
Expand Down Expand Up @@ -137,7 +137,7 @@ def _cannot_batch(x, b, *, msg, allow_constant_across_batch):
raise ValueError(msg)


nonbatchable_p = jax.core.Primitive("nonbatchable")
nonbatchable_p = jax.extend.core.Primitive("nonbatchable")
nonbatchable_p.def_impl(lambda x, *, msg, allow_constant_across_batch: x)
nonbatchable_p.def_abstract_eval(lambda x, *, msg, allow_constant_across_batch: x)
batching.primitive_batchers[nonbatchable_p] = _cannot_batch
Expand Down
5 changes: 3 additions & 2 deletions equinox/internal/_primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import jax
import jax.core
import jax.extend.core
import jax.interpreters.ad as ad
import jax.interpreters.batching as batching
import jax.interpreters.mlir as mlir
Expand Down Expand Up @@ -255,7 +256,7 @@ def _wrapper(dynamic, batch_axes, *, treedef, static, flatten):
return _wrapper


def filter_primitive_bind(prim: jax.core.Primitive, *args) -> PyTree:
def filter_primitive_bind(prim: jax.extend.core.Primitive, *args) -> PyTree:
"""Calls a primitive that has had its rules defined using the filter
functions above.
"""
Expand Down Expand Up @@ -301,7 +302,7 @@ def materialise_zeros(primal, tangent, allow_struct=False):


def create_vprim(name: str, impl, abstract_eval, jvp, transpose):
prim = jax.core.Primitive(name)
prim = jax.extend.core.Primitive(name)
prim.multiple_results = True

def batch_rule(axis_size, axis_name, trace_type, inputs, batch_axes, **params):
Expand Down
3 changes: 2 additions & 1 deletion mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,9 @@ plugins:
- import jaxtyping
- jaxtyping.set_array_name_format("array")
- import jax
- import jax.extend.core
- jax.ShapeDtypeStruct.__module__ = "jax"
- jax.core.ClosedJaxpr.__module__ = "jax.core"
- jax.extend.core.ClosedJaxpr.__module__ = "jax.core"

selection:
inherited_members: true # Allow looking up inherited methods
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name = "equinox"
version = "0.11.10"
description = "Elegant easy-to-use neural networks in JAX."
readme = "README.md"
requires-python =">=3.9"
requires-python =">=3.10"
license = {file = "LICENSE"}
authors = [
{name = "Patrick Kidger", email = "[email protected]"},
Expand All @@ -23,7 +23,7 @@ classifiers = [
"Topic :: Scientific/Engineering :: Mathematics",
]
urls = {repository = "https://github.com/patrick-kidger/equinox" }
dependencies = ["jax>=0.4.13,!=0.4.27", "jaxtyping>=0.2.20", "typing_extensions>=4.5.0"]
dependencies = ["jax>=0.4.38", "jaxtyping>=0.2.20", "typing_extensions>=4.5.0"]

[build-system]
requires = ["hatchling"]
Expand Down
Loading
Loading