diff --git a/distrax/_src/utils/transformations.py b/distrax/_src/utils/transformations.py index e754c05..8d84549 100644 --- a/distrax/_src/utils/transformations.py +++ b/distrax/_src/utils/transformations.py @@ -258,7 +258,7 @@ def write(var, val): # if primitive is an xla_call, get subexpressions and evaluate recursively call_jaxpr, params = _extract_call_jaxpr(eqn.primitive, params) if call_jaxpr: - subfuns = [jax.linear_util.wrap_init( + subfuns = [jax.extend.linear_util.wrap_init( functools.partial(_interpret_inverse, call_jaxpr, ()))] prim_inv = eqn.primitive