Skip to content

Commit

Permalink
Replace deprecated jax.linear_util.wrap_init with jax.extend.linear_u…
Browse files Browse the repository at this point in the history
…til.wrap_init.

PiperOrigin-RevId: 562017973
  • Loading branch information
suryabhupa authored and DistraxDev committed Sep 1, 2023
1 parent 93c54a8 commit 175cace
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion distrax/_src/utils/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def write(var, val):
# if primitive is an xla_call, get subexpressions and evaluate recursively
call_jaxpr, params = _extract_call_jaxpr(eqn.primitive, params)
if call_jaxpr:
subfuns = [jax.linear_util.wrap_init(
subfuns = [jax.extend.linear_util.wrap_init(
functools.partial(_interpret_inverse, call_jaxpr, ()))]
prim_inv = eqn.primitive

Expand Down

0 comments on commit 175cace

Please sign in to comment.