Skip to content

Commit

Permalink
Upgrade jax and fix deprecation warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
DrJessop committed Dec 18, 2024
1 parent d7d2cb9 commit 37b7c1e
Show file tree
Hide file tree
Showing 24 changed files with 63 additions and 40 deletions.
1 change: 1 addition & 0 deletions equinox/_ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import jax
import jax._src.traceback_util as traceback_util
import jax.core
import jax.extend
import jax.interpreters.ad as ad
import jax.numpy as jnp
import jax.tree_util as jtu
Expand Down
1 change: 1 addition & 0 deletions equinox/_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import jax._src.traceback_util as traceback_util
import jax.core
import jax.extend
import jax.numpy as jnp
import numpy as np
from jaxtyping import Array, ArrayLike, Bool, Int
Expand Down
1 change: 1 addition & 0 deletions equinox/_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import jax
import jax._src.traceback_util as traceback_util
import jax.core
import jax.extend
import jax.lax as lax
import jax.numpy as jnp
import jax.tree_util as jtu
Expand Down
1 change: 1 addition & 0 deletions equinox/_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import jax._src.dispatch
import jax._src.traceback_util as traceback_util
import jax.core
import jax.extend
import jax.errors
import jax.numpy as jnp
from jaxtyping import PyTree
Expand Down
3 changes: 2 additions & 1 deletion equinox/_make_jaxpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import jax
import jax._src.traceback_util as traceback_util
import jax.core
import jax.extend
import jax.tree_util as jtu
from jaxtyping import PyTree

Expand Down Expand Up @@ -49,7 +50,7 @@ def _fn(*_dynamic_flat):
def filter_make_jaxpr(
fun: Callable[_P, Any],
) -> Callable[
_P, tuple[jax.core.ClosedJaxpr, PyTree[jax.ShapeDtypeStruct], PyTree[Any]]
_P, tuple[jax.extend.core.ClosedJaxpr, PyTree[jax.ShapeDtypeStruct], PyTree[Any]]
]:
"""As `jax.make_jaxpr`, but accepts arbitrary PyTrees as input and output.
Expand Down
1 change: 1 addition & 0 deletions equinox/_misc.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import jax
import jax.core
import jax.extend
import jax.numpy as jnp
from jaxtyping import Array

Expand Down
8 changes: 5 additions & 3 deletions equinox/_unvmap.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from typing import cast

import jax
import jax.extend
import jax.core
import jax.extend
import jax.interpreters.batching as batching
import jax.interpreters.mlir as mlir
import jax.numpy as jnp
Expand All @@ -10,7 +12,7 @@

# unvmap_all

unvmap_all_p = jax.core.Primitive("unvmap_all")
unvmap_all_p = jax.extend.core.Primitive("unvmap_all")


def unvmap_all(x: Bool[ArrayLike, "..."]) -> Bool[Array, ""]:
Expand Down Expand Up @@ -41,7 +43,7 @@ def _unvmap_all_batch(x, batch_axes):

# unvmap_any

unvmap_any_p = jax.core.Primitive("unvmap_any")
unvmap_any_p = jax.extend.core.Primitive("unvmap_any")


def unvmap_any(x: Bool[ArrayLike, "..."]) -> Bool[Array, ""]:
Expand Down Expand Up @@ -72,7 +74,7 @@ def _unvmap_any_batch(x, batch_axes):

# unvmap_max

unvmap_max_p = jax.core.Primitive("unvmap_max")
unvmap_max_p = jax.extend.core.Primitive("unvmap_max")


def unvmap_max(x: Int[ArrayLike, "..."]) -> Int[Array, ""]:
Expand Down
1 change: 1 addition & 0 deletions equinox/_vmap_pmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import jax
import jax._src.traceback_util as traceback_util
import jax.core
import jax.extend
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
Expand Down
3 changes: 2 additions & 1 deletion equinox/debug/_announce_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import jax
import jax.core
import jax.extend
import jax.interpreters.ad as ad
import jax.interpreters.batching as batching
import jax.interpreters.mlir as mlir
Expand Down Expand Up @@ -124,7 +125,7 @@ def _mlir(*x, stack, name, intermediates, announce):
return x


announce_jaxpr_p = jax.core.Primitive("announce_jaxpr")
announce_jaxpr_p = jax.extend.core.Primitive("announce_jaxpr")
announce_jaxpr_p.multiple_results = True
announce_jaxpr_p.def_impl(_impl)
announce_jaxpr_p.def_abstract_eval(_abstract)
Expand Down
1 change: 1 addition & 0 deletions equinox/debug/_backward_nan.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import jax
import jax._src.traceback_util as traceback_util
import jax.core
import jax.extend
import jax.numpy as jnp
import jax.tree_util as jtu

Expand Down
1 change: 1 addition & 0 deletions equinox/debug/_breakpoint_if.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import jax
import jax._src.traceback_util as traceback_util
import jax.core
import jax.extend
import jax.lax as lax
from jaxtyping import Array, Bool

Expand Down
1 change: 1 addition & 0 deletions equinox/debug/_dce.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import jax
import jax.core
import jax.extend
import jax.numpy as jnp
import jax.tree_util as jtu
from jaxtyping import PyTree
Expand Down
27 changes: 14 additions & 13 deletions equinox/internal/_finalise_jaxpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import jax
import jax.core
import jax.extend
import jax.custom_derivatives
import jax.tree_util as jtu
from jaxtyping import PyTree
Expand All @@ -39,10 +40,10 @@ def _maybe_finalise_jaxpr(val: Any):
if isinstance(val, jax.core.Jaxpr):
if len(val.constvars) == 0:
is_open_jaxpr = True
val = jax.core.ClosedJaxpr(val, [])
val = jax.extend.core.ClosedJaxpr(val, [])
else:
return val
if isinstance(val, jax.core.ClosedJaxpr):
if isinstance(val, jax.extend.core.ClosedJaxpr):
val = finalise_jaxpr(val)
if is_open_jaxpr:
val = val.jaxpr
Expand All @@ -60,18 +61,18 @@ def _finalise_jaxprs_in_params(params):
return new_params


def _default_finalisation(prim: jax.core.Primitive, *args, **kwargs):
def _default_finalisation(prim: jax.extend.core.Primitive, *args, **kwargs):
return prim.bind(*args, **kwargs)


def _impl_finalisation(prim: jax.core.Primitive, *args, **kwargs):
def _impl_finalisation(prim: jax.extend.core.Primitive, *args, **kwargs):
return prim.impl(*args, **kwargs)


primitive_finalisations = {}


def register_impl_finalisation(prim: jax.core.Primitive):
def register_impl_finalisation(prim: jax.extend.core.Primitive):
primitive_finalisations[prim] = ft.partial(_impl_finalisation, prim)


Expand Down Expand Up @@ -104,18 +105,18 @@ def write(v: jax.core.Var, val: Any) -> None:
return _safe_map(read, jaxpr.outvars)


def finalise_jaxpr_as_fn(jaxpr: jax.core.ClosedJaxpr):
def finalise_jaxpr_as_fn(jaxpr: jax.extend.core.ClosedJaxpr):
"""As `jax.core.jaxpr_as_fn`, but the result is finalised."""
return ft.partial(finalise_eval_jaxpr, jaxpr.jaxpr, jaxpr.consts)


def finalise_jaxpr(jaxpr: jax.core.ClosedJaxpr) -> jax.core.ClosedJaxpr:
def finalise_jaxpr(jaxpr: jax.extend.core.ClosedJaxpr) -> jax.extend.core.ClosedJaxpr:
"""A jaxpr-to-jaxpr transformation that performs finalisation."""
fn = finalise_jaxpr_as_fn(jaxpr)
args = [
jax.ShapeDtypeStruct(x.aval.shape, x.aval.dtype) for x in jaxpr.jaxpr.invars
]
return cast(jax.core.ClosedJaxpr, jax.make_jaxpr(fn)(*args))
return cast(jax.extend.core.ClosedJaxpr, jax.make_jaxpr(fn)(*args))


def finalise_fn(fn):
Expand All @@ -136,13 +137,13 @@ def _finalise_fn(*args):
@overload
def finalise_make_jaxpr(
fn, *, return_shape: Literal[False] = False
) -> Callable[..., jax.core.ClosedJaxpr]: ...
) -> Callable[..., jax.extend.core.ClosedJaxpr]: ...


@overload
def finalise_make_jaxpr(
fn, *, return_shape: Literal[True] = True
) -> Callable[..., tuple[jax.core.ClosedJaxpr, PyTree[jax.ShapeDtypeStruct]]]: ...
) -> Callable[..., tuple[jax.extend.core.ClosedJaxpr, PyTree[jax.ShapeDtypeStruct]]]: ...


@overload
Expand All @@ -151,7 +152,7 @@ def finalise_make_jaxpr(
) -> Callable[
...,
Union[
jax.core.ClosedJaxpr, tuple[jax.core.ClosedJaxpr, PyTree[jax.ShapeDtypeStruct]]
jax.extend.core.ClosedJaxpr, tuple[jax.extend.core.ClosedJaxpr, PyTree[jax.ShapeDtypeStruct]]
],
]: ...

Expand All @@ -164,12 +165,12 @@ def _finalise_make_jaxpr(*args):
*args
)
if return_shape:
jaxpr_struct = cast(tuple[jax.core.ClosedJaxpr, Any], jaxpr_struct)
jaxpr_struct = cast(tuple[jax.extend.core.ClosedJaxpr, Any], jaxpr_struct)
jaxpr, struct = jaxpr_struct
jaxpr = finalise_jaxpr(jaxpr)
return jaxpr, struct
else:
jaxpr_struct = cast(jax.core.ClosedJaxpr, jaxpr_struct)
jaxpr_struct = cast(jax.extend.core.ClosedJaxpr, jaxpr_struct)
jaxpr = finalise_jaxpr(jaxpr_struct)
return jaxpr

Expand Down
1 change: 1 addition & 0 deletions equinox/internal/_loop/checkpointed.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@

import jax
import jax.core
import jax.extend
import jax.lax as lax
import jax.numpy as jnp
import jax.tree_util as jtu
Expand Down
3 changes: 2 additions & 1 deletion equinox/internal/_loop/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import jax
import jax.core
import jax.extend
import jax.interpreters.ad as ad
import jax.interpreters.batching as batching
import jax.interpreters.mlir as mlir
Expand Down Expand Up @@ -105,7 +106,7 @@ def _select_if_vmap_batch(axis_size, axis_name, trace, inputs, batch_axes):
return out, out_axis


select_if_vmap_p = jax.core.Primitive("select_if_vmap")
select_if_vmap_p = jax.extend.core.Primitive("select_if_vmap")
select_if_vmap_p.def_impl(_select_if_vmap_impl)
select_if_vmap_p.def_abstract_eval(_select_if_vmap_abstract)
ad.primitive_jvps[select_if_vmap_p] = _select_if_vmap_jvp
Expand Down
3 changes: 2 additions & 1 deletion equinox/internal/_noinline.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import jax
import jax.core
import jax.extend
import jax.interpreters.ad as ad
import jax.interpreters.batching as batching
import jax.interpreters.mlir as mlir
Expand Down Expand Up @@ -330,7 +331,7 @@ def _noinline_mlir(ctx, *dynamic, treedef, static, flatten, **kwargs):
return result


noinline_p = jax.core.Primitive("noinline")
noinline_p = jax.extend.core.Primitive("noinline")
noinline_p.multiple_results = True
noinline_p.def_impl(_noinline_impl)
noinline_p.def_abstract_eval(_noinline_abstract)
Expand Down
7 changes: 4 additions & 3 deletions equinox/internal/_nontraceable.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import jax
import jax.core
import jax.extend
import jax.interpreters.ad as ad
import jax.interpreters.batching as batching
import jax.interpreters.mlir as mlir
Expand All @@ -29,7 +30,7 @@ def _error(*args, name):
return _error


nontraceable_p = jax.core.Primitive("nontraceable")
nontraceable_p = jax.extend.core.Primitive("nontraceable")
nontraceable_p.def_impl(_nontraceable_impl)
nontraceable_p.def_abstract_eval(_nontraceable_impl)
ad.primitive_jvps[nontraceable_p] = _make_error("differentiation")
Expand All @@ -53,7 +54,7 @@ def nontraceable(x, *, name="nontraceable operation"):
return combine(dynamic, static)


nondifferentiable_backward_p = jax.core.Primitive("nondifferentiable_backward")
nondifferentiable_backward_p = jax.extend.core.Primitive("nondifferentiable_backward")


def _nondifferentiable_backward_batch(x, batch_axes, *, msg, symbolic):
Expand Down Expand Up @@ -137,7 +138,7 @@ def _cannot_batch(x, b, *, msg, allow_constant_across_batch):
raise ValueError(msg)


nonbatchable_p = jax.core.Primitive("nonbatchable")
nonbatchable_p = jax.extend.core.Primitive("nonbatchable")
nonbatchable_p.def_impl(lambda x, *, msg, allow_constant_across_batch: x)
nonbatchable_p.def_abstract_eval(lambda x, *, msg, allow_constant_across_batch: x)
batching.primitive_batchers[nonbatchable_p] = _cannot_batch
Expand Down
5 changes: 3 additions & 2 deletions equinox/internal/_primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import jax
import jax.core
import jax.extend
import jax.interpreters.ad as ad
import jax.interpreters.batching as batching
import jax.interpreters.mlir as mlir
Expand Down Expand Up @@ -255,7 +256,7 @@ def _wrapper(dynamic, batch_axes, *, treedef, static, flatten):
return _wrapper


def filter_primitive_bind(prim: jax.core.Primitive, *args) -> PyTree:
def filter_primitive_bind(prim: jax.extend.core.Primitive, *args) -> PyTree:
"""Calls a primitive that has had its rules defined using the filter
functions above.
"""
Expand Down Expand Up @@ -301,7 +302,7 @@ def materialise_zeros(primal, tangent, allow_struct=False):


def create_vprim(name: str, impl, abstract_eval, jvp, transpose):
prim = jax.core.Primitive(name)
prim = jax.extend.core.Primitive(name)
prim.multiple_results = True

def batch_rule(axis_size, axis_name, trace_type, inputs, batch_axes, **params):
Expand Down
2 changes: 1 addition & 1 deletion mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ plugins:
- jaxtyping.set_array_name_format("array")
- import jax
- jax.ShapeDtypeStruct.__module__ = "jax"
- jax.core.ClosedJaxpr.__module__ = "jax.core"
- jax.extend.core.ClosedJaxpr.__module__ = "jax.core"

selection:
inherited_members: true # Allow looking up inherited methods
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name = "equinox"
version = "0.11.10"
description = "Elegant easy-to-use neural networks in JAX."
readme = "README.md"
requires-python =">=3.9"
requires-python =">=3.10"
license = {file = "LICENSE"}
authors = [
{name = "Patrick Kidger", email = "[email protected]"},
Expand All @@ -23,7 +23,7 @@ classifiers = [
"Topic :: Scientific/Engineering :: Mathematics",
]
urls = {repository = "https://github.com/patrick-kidger/equinox" }
dependencies = ["jax>=0.4.13,!=0.4.27", "jaxtyping>=0.2.20", "typing_extensions>=4.5.0"]
dependencies = ["jax>=0.4.38", "jaxtyping>=0.2.20", "typing_extensions>=4.5.0"]

[build-system]
requires = ["hatchling"]
Expand Down
Loading

0 comments on commit 37b7c1e

Please sign in to comment.