diff --git a/jax/core.py b/jax/core.py index 6a135a105ca9..62a1b0d28bc2 100644 --- a/jax/core.py +++ b/jax/core.py @@ -96,9 +96,9 @@ "OutDBIdx": ("jax.core.OutDBIdx is deprecated.", _src_core.OutDBIdx), "TRACER_LEAK_DEBUGGER_WARNING": ("jax.core.TRACER_LEAK_DEBUGGER_WARNING is deprecated.", _src_core.TRACER_LEAK_DEBUGGER_WARNING), - "call_p": ("jax.core.call_p is deprecated. Use jax.extend.primitives.call_p", + "call_p": ("jax.core.call_p is deprecated. Use jax.extend.core.primitives.call_p", _src_core.call_p), - "closed_call_p": ("jax.core.closed_call_p is deprecated. Use jax.extend.primitives.closed_call_p", + "closed_call_p": ("jax.core.closed_call_p is deprecated. Use jax.extend.core.primitives.closed_call_p", _src_core.closed_call_p), "concrete_aval": ("jax.core.concrete_aval is deprecated.", _src_core.concrete_aval), "dedup_referents": ("jax.core.dedup_referents is deprecated.", _src_core.dedup_referents),