Skip to content

Commit

Permalink
Remove references to jax.core.raise_to_shaped
Browse files Browse the repository at this point in the history
As of JAX v0.4.36, `core.raise_to_shaped` is deprecated, and simply returns the input unchanged.

PiperOrigin-RevId: 705117138
  • Loading branch information
Jake VanderPlas authored and DistraxDev committed Dec 11, 2024
1 parent 45e8a8e commit ddf4c7e
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion distrax/_src/utils/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def multiply_no_nan_jvp(
x, y = primals
x_dot, y_dot = tangents
primal_out = multiply_no_nan(x, y)
primal_aval = jax_core.raise_to_shaped(jax_core.get_aval(primal_out))
primal_aval = jax_core.get_aval(primal_out)
result_aval = primal_aval.at_least_vspace()
tangent_out_1 = scale_maybe_symbolic(result_aval, x_dot, y)
tangent_out_2 = scale_maybe_symbolic(result_aval, y_dot, x)
Expand Down

0 comments on commit ddf4c7e

Please sign in to comment.