Skip to content

Commit

Permalink
Migrate from jax.core to jax.extend.core for several deprecated symbols
Browse files Browse the repository at this point in the history
A number of symbols from jax.core are deprecated as of recent JAX releases; some of them are newly available in jax.extend.core.

PiperOrigin-RevId: 706180277
  • Loading branch information
Jake VanderPlas authored and DistraxDev committed Dec 14, 2024
1 parent ddf4c7e commit c02708a
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions distrax/_src/utils/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@

from absl import logging
import jax
import jax.extend as jex
import jax.numpy as jnp

# pylint: disable=g-import-not-at-top
Expand Down Expand Up @@ -156,7 +157,7 @@ def is_constant_jacobian(fn, x=0.0):
jac_jaxpr = jax.make_jaxpr(jac_fn)(jnp.array(x)).jaxpr
dependent_vars = _dependent_variables(jac_jaxpr)

jac_is_constant = not any(isinstance(v, jax.core.Var) and v in dependent_vars
jac_is_constant = not any(isinstance(v, jex.core.Var) and v in dependent_vars
for v in jac_jaxpr.outvars)

return jac_is_constant
Expand Down Expand Up @@ -202,7 +203,7 @@ def _dependent_variables(jaxpr, dependent=None):
if v in subjaxpr_dependent)
else:
for v in eqn.invars:
if isinstance(v, jax.core.Var) and v in dependent:
if isinstance(v, jex.core.Var) and v in dependent:
dependent.update(eqn.outvars)

return dependent
Expand All @@ -226,20 +227,20 @@ def _identify_variable_in_eqn(eqn):
var_idx = 0

elif len(eqn.invars) == 2: # binary operation
if tuple(map(type, eqn.invars)) == (jax.core.Var, jax.core.Literal):
if tuple(map(type, eqn.invars)) == (jex.core.Var, jex.core.Literal):
var_idx = 0

elif tuple(map(type, eqn.invars)) == (jax.core.Literal, jax.core.Var):
elif tuple(map(type, eqn.invars)) == (jex.core.Literal, jex.core.Var):
var_idx = 1

elif tuple(map(type, eqn.invars)) == (jax.core.Var, jax.core.Var):
elif tuple(map(type, eqn.invars)) == (jex.core.Var, jex.core.Var):
raise NotImplementedError(
"Expressions with multiple occurrences of the input variable are "
"not supported. Please rearrange such that the variable appears only "
"once in the expression if possible. If not possible, consider "
"providing both `forward` and `inverse` to Lambda explicitly.")

elif tuple(map(type, eqn.invars)) == (jax.core.Literal, jax.core.Literal):
elif tuple(map(type, eqn.invars)) == (jex.core.Literal, jex.core.Literal):
raise ValueError("Expression appears to contain no variables and "
"therefore cannot be inverted.")

Expand All @@ -259,7 +260,7 @@ def _interpret_inverse(jaxpr, consts, *args):
env = {}

def read(var):
return var.val if isinstance(var, jax.core.Literal) else env[var]
return var.val if isinstance(var, jex.core.Literal) else env[var]
def write(var, val):
env[var] = val

Expand Down

0 comments on commit c02708a

Please sign in to comment.