Skip to content

Commit

Permalink
Deprecate symbols in jax.interpreters.xla
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Dec 17, 2024
1 parent 63d73a5 commit 254be60
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 10 deletions.
7 changes: 0 additions & 7 deletions jax/_src/interpreters/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,13 +146,6 @@ def _canonicalize_python_scalar_dtype(typ, x):
canonicalize_dtype_handlers[core.DArray] = identity
canonicalize_dtype_handlers[core.MutableArray] = identity

# TODO(jakevdp): deprecate and remove this.
def abstractify(x) -> Any:
return core.abstractify(x)

# TODO(jakevdp): deprecate and remove this alias.
pytype_aval_mappings: dict[Any, Callable[[Any], core.AbstractValue]] = core.pytype_aval_mappings

initial_style_primitives: set[core.Primitive] = set()

def register_initial_style_primitive(prim: core.Primitive):
Expand Down
22 changes: 19 additions & 3 deletions jax/interpreters/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,8 @@
# limitations under the License.

from jax._src.interpreters.xla import (
abstractify as abstractify,
canonicalize_dtype as canonicalize_dtype,
canonicalize_dtype_handlers as canonicalize_dtype_handlers,
pytype_aval_mappings as pytype_aval_mappings,
)

from jax._src.dispatch import (
Expand All @@ -27,8 +25,19 @@
Backend = _xc._xla.Client
del _xc

from jax._src import core as _src_core

# Deprecations
_deprecations = {
# Added 2024-12-17
"abstractify": (
"jax.interpreters.xla.abstractify is deprecated.",
_src_core.abstractify
),
"pytype_aval_mappings": (
"jax.interpreters.xla.pytype_aval_mappings is deprecated.",
_src_core.pytype_aval_mappings
),
# Finalized 2024-10-24; remove after 2025-01-24
"xb": (
("jax.interpreters.xla.xb was removed in JAX v0.4.36. "
Expand All @@ -44,6 +53,13 @@
),
}

import typing as _typing
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
if _typing.TYPE_CHECKING:
abstractify = _src_core.abstractify
pytype_aval_mappings = _src_core.pytype_aval_mappings
else:
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del _typing
del _src_core

0 comments on commit 254be60

Please sign in to comment.