Skip to content

Commit

Permalink
Remove core.concrete_aval and replace with abstractify
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Dec 18, 2024
1 parent 1e22149 commit 3cecbf3
Show file tree
Hide file tree
Showing 10 changed files with 30 additions and 56 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.

## Unreleased

* Deprecations
* From {mod}`jax.interpreters.xla`, `abstractify` and `pytype_aval_mappings`
are now deprecated, having been replaced by symbols of the same name
in {mod}`jax.core`.

## jax 0.4.38 (Dec 17, 2024)

* Changes:
Expand Down
4 changes: 0 additions & 4 deletions jax/_src/abstract_arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ def masked_array_error(*args, **kwargs):
"Use arr.filled() to convert the value to a standard numpy array.")

core.pytype_aval_mappings[np.ma.MaskedArray] = masked_array_error
core.xla_pytype_aval_mappings[np.ma.MaskedArray] = masked_array_error


def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray:
Expand All @@ -58,7 +57,6 @@ def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray:
return ShapedArray(x.shape, dtypes.canonicalize_dtype(dtype))

core.pytype_aval_mappings[np.ndarray] = _make_shaped_array_for_numpy_array
core.xla_pytype_aval_mappings[np.ndarray] = _make_shaped_array_for_numpy_array


def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray:
Expand All @@ -68,7 +66,6 @@ def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray:

for t in numpy_scalar_types:
core.pytype_aval_mappings[t] = _make_shaped_array_for_numpy_scalar
core.xla_pytype_aval_mappings[t] = _make_shaped_array_for_numpy_scalar

core.literalable_types.update(array_types)

Expand All @@ -81,6 +78,5 @@ def _make_abstract_python_scalar(typ, val):

for t in dtypes.python_scalar_dtypes:
core.pytype_aval_mappings[t] = partial(_make_abstract_python_scalar, t)
core.xla_pytype_aval_mappings[t] = partial(_make_abstract_python_scalar, t)

core.literalable_types.update(dtypes.python_scalar_dtypes.keys())
1 change: 0 additions & 1 deletion jax/_src/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1038,7 +1038,6 @@ def _get_aval_array(self):

api_util._shaped_abstractify_handlers[ArrayImpl] = _get_aval_array
core.pytype_aval_mappings[ArrayImpl] = _get_aval_array
core.xla_pytype_aval_mappings[ArrayImpl] = _get_aval_array

# TODO(jakevdp) replace this with true inheritance at the C++ level.
basearray.Array.register(ArrayImpl)
Expand Down
38 changes: 4 additions & 34 deletions jax/_src/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1388,7 +1388,7 @@ def lattice_join(x, y):

def valid_jaxtype(x) -> bool:
try:
concrete_aval(x)
abstractify(x)
except TypeError:
return False
else:
Expand All @@ -1400,35 +1400,9 @@ def check_valid_jaxtype(x):
f"Value {x!r} of type {type(x)} is not a valid JAX type")


# TODO(jakevdp): merge concrete_aval and abstractify to the extent possible.
# This is tricky because concrete_aval includes sharding information, and
# abstractify does not; further, because abstractify is in the dispatch path,
# performance is important and simply adding sharding there is not an option.
def concrete_aval(x):
# This differs from abstractify below in that the abstract values
# include sharding where applicable. Historically (before stackless)
# the returned avals were concrete, but after the stackless change
# this returns ShapedArray like abstractify.
# Rules are registered in pytype_aval_mappings.
for typ in type(x).__mro__:
handler = pytype_aval_mappings.get(typ)
if handler: return handler(x)
if hasattr(x, '__jax_array__'):
return concrete_aval(x.__jax_array__())
raise TypeError(f"Value {x!r} with type {type(x)} is not a valid JAX "
"type")


def abstractify(x):
# Historically, this was called xla.abstractify. It differs from
# concrete_aval in that it excludes sharding information, and
# uses a more performant path for accessing avals. Rules are
# registered in xla_pytype_aval_mappings.
typ = type(x)
aval_fn = xla_pytype_aval_mappings.get(typ)
if aval_fn: return aval_fn(x)
for typ in typ.__mro__:
aval_fn = xla_pytype_aval_mappings.get(typ)
for typ in type(x).__mro__:
aval_fn = pytype_aval_mappings.get(typ)
if aval_fn: return aval_fn(x)
if hasattr(x, '__jax_array__'):
return abstractify(x.__jax_array__())
Expand All @@ -1439,7 +1413,7 @@ def get_aval(x):
if isinstance(x, Tracer):
return x.aval
else:
return concrete_aval(x)
return abstractify(x)

get_type = get_aval

Expand Down Expand Up @@ -1835,7 +1809,6 @@ def to_tangent_aval(self):
self.weak_type)

pytype_aval_mappings: dict[type, Callable[[Any], AbstractValue]] = {}
xla_pytype_aval_mappings: dict[type, Callable[[Any], AbstractValue]] = {}


class DArray:
Expand Down Expand Up @@ -1892,7 +1865,6 @@ def data(self):

pytype_aval_mappings[DArray] = \
lambda x: DShapedArray(x._aval.shape, x._aval.dtype, x._aval.weak_type)
xla_pytype_aval_mappings[DArray] = lambda x: x._aval

@dataclass(frozen=True)
class bint(dtypes.ExtendedDType):
Expand Down Expand Up @@ -1925,7 +1897,6 @@ def __getitem__(self, idx): return get_aval(self)._getitem(self, idx)
def __setitem__(self, idx, x): return get_aval(self)._setitem(self, idx, x)
def __repr__(self) -> str: return 'Mutable' + repr(self[...])
pytype_aval_mappings[MutableArray] = lambda x: x._aval
xla_pytype_aval_mappings[MutableArray] = lambda x: x._aval

def mutable_array(init_val):
return mutable_array_p.bind(init_val)
Expand Down Expand Up @@ -1979,7 +1950,6 @@ def __init__(self, buf):
def block_until_ready(self):
self._buf.block_until_ready()
pytype_aval_mappings[Token] = lambda _: abstract_token
xla_pytype_aval_mappings[Token] = lambda _: abstract_token


# TODO(dougalm): Deprecate these. They're just here for backwards compat.
Expand Down
1 change: 0 additions & 1 deletion jax/_src/export/shape_poly.py
Original file line number Diff line number Diff line change
Expand Up @@ -1205,7 +1205,6 @@ def _geq_decision(e1: DimSize, e2: DimSize, cmp_str: Callable[[], str]) -> bool:
f"Symbolic dimension comparison {cmp_str()} is inconclusive.{describe_scope}")

core.pytype_aval_mappings[_DimExpr] = _DimExpr._get_aval
core.xla_pytype_aval_mappings[_DimExpr] = _DimExpr._get_aval
dtypes._weak_types.append(_DimExpr)

def _convertible_to_int(p: DimSize) -> bool:
Expand Down
7 changes: 0 additions & 7 deletions jax/_src/interpreters/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,13 +146,6 @@ def _canonicalize_python_scalar_dtype(typ, x):
canonicalize_dtype_handlers[core.DArray] = identity
canonicalize_dtype_handlers[core.MutableArray] = identity

# TODO(jakevdp): deprecate and remove this.
def abstractify(x) -> Any:
return core.abstractify(x)

# TODO(jakevdp): deprecate and remove this.
pytype_aval_mappings: dict[Any, Callable[[Any], core.AbstractValue]] = core.xla_pytype_aval_mappings

initial_style_primitives: set[core.Primitive] = set()

def register_initial_style_primitive(prim: core.Primitive):
Expand Down
1 change: 0 additions & 1 deletion jax/_src/prng.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,6 @@ def __hash__(self) -> int:


core.pytype_aval_mappings[PRNGKeyArray] = lambda x: x.aval
core.xla_pytype_aval_mappings[PRNGKeyArray] = lambda x: x.aval

xla.canonicalize_dtype_handlers[PRNGKeyArray] = lambda x: x

Expand Down
4 changes: 2 additions & 2 deletions jax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@
_src_core.call_p),
"closed_call_p": ("jax.core.closed_call_p is deprecated. Use jax.extend.core.primitives.closed_call_p",
_src_core.closed_call_p),
"concrete_aval": ("jax.core.concrete_aval is deprecated.", _src_core.concrete_aval),
"concrete_aval": ("jax.core.concrete_aval is deprecated.", _src_core.abstractify),
"dedup_referents": ("jax.core.dedup_referents is deprecated.", _src_core.dedup_referents),
"escaped_tracer_error": ("jax.core.escaped_tracer_error is deprecated.",
_src_core.escaped_tracer_error),
Expand Down Expand Up @@ -207,7 +207,7 @@
axis_frame = _src_core.axis_frame
call_p = _src_core.call_p
closed_call_p = _src_core.closed_call_p
concrete_aval = _src_core.concrete_aval
concrete_aval = _src_core.abstractify
dedup_referents = _src_core.dedup_referents
escaped_tracer_error = _src_core.escaped_tracer_error
extend_axis_env_nd = _src_core.extend_axis_env_nd
Expand Down
22 changes: 19 additions & 3 deletions jax/interpreters/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,8 @@
# limitations under the License.

from jax._src.interpreters.xla import (
abstractify as abstractify,
canonicalize_dtype as canonicalize_dtype,
canonicalize_dtype_handlers as canonicalize_dtype_handlers,
pytype_aval_mappings as pytype_aval_mappings,
)

from jax._src.dispatch import (
Expand All @@ -27,8 +25,19 @@
Backend = _xc._xla.Client
del _xc

from jax._src import core as _src_core

# Deprecations
_deprecations = {
# Added 2024-12-17
"abstractify": (
"jax.interpreters.xla.abstractify is deprecated.",
_src_core.abstractify
),
"pytype_aval_mappings": (
"jax.interpreters.xla.pytype_aval_mappings is deprecated.",
_src_core.pytype_aval_mappings
),
# Finalized 2024-10-24; remove after 2025-01-24
"xb": (
("jax.interpreters.xla.xb was removed in JAX v0.4.36. "
Expand All @@ -44,6 +53,13 @@
),
}

import typing as _typing
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
if _typing.TYPE_CHECKING:
abstractify = _src_core.abstractify
pytype_aval_mappings = _src_core.pytype_aval_mappings
else:
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del _typing
del _src_core
3 changes: 0 additions & 3 deletions tests/lax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3959,8 +3959,6 @@ def setUp(self):
core.pytype_aval_mappings[FooArray] = \
lambda x: core.ShapedArray(x.shape, FooTy())
xla.canonicalize_dtype_handlers[FooArray] = lambda x: x
core.xla_pytype_aval_mappings[FooArray] = \
lambda x: core.ShapedArray(x.shape, FooTy())
pxla.shard_arg_handlers[FooArray] = shard_foo_array_handler
mlir._constant_handlers[FooArray] = foo_array_constant_handler
mlir.register_lowering(make_p, mlir.lower_fun(make_lowering, False))
Expand All @@ -3973,7 +3971,6 @@ def setUp(self):
def tearDown(self):
del core.pytype_aval_mappings[FooArray]
del xla.canonicalize_dtype_handlers[FooArray]
del core.xla_pytype_aval_mappings[FooArray]
del mlir._constant_handlers[FooArray]
del mlir._lowerings[make_p]
del mlir._lowerings[bake_p]
Expand Down

0 comments on commit 3cecbf3

Please sign in to comment.