diff --git a/jax/_src/abstract_arrays.py b/jax/_src/abstract_arrays.py index 2502b705b8fa..8ddc33fd8983 100644 --- a/jax/_src/abstract_arrays.py +++ b/jax/_src/abstract_arrays.py @@ -49,7 +49,6 @@ def masked_array_error(*args, **kwargs): "Use arr.filled() to convert the value to a standard numpy array.") core.pytype_aval_mappings[np.ma.MaskedArray] = masked_array_error -core.shaped_abstractify_handlers[np.ma.MaskedArray] = masked_array_error def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray: @@ -58,7 +57,6 @@ def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray: return ShapedArray(x.shape, dtypes.canonicalize_dtype(dtype)) core.pytype_aval_mappings[np.ndarray] = _make_shaped_array_for_numpy_array -core.shaped_abstractify_handlers[np.ndarray] = _make_shaped_array_for_numpy_array def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray: @@ -68,7 +66,6 @@ def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray: for t in numpy_scalar_types: core.pytype_aval_mappings[t] = _make_shaped_array_for_numpy_scalar - core.shaped_abstractify_handlers[t] = _make_shaped_array_for_numpy_scalar core.literalable_types.update(array_types) @@ -81,6 +78,5 @@ def _make_abstract_python_scalar(typ, val): for t in dtypes.python_scalar_dtypes: core.pytype_aval_mappings[t] = partial(_make_abstract_python_scalar, t) - core.shaped_abstractify_handlers[t] = partial(_make_abstract_python_scalar, t) core.literalable_types.update(dtypes.python_scalar_dtypes.keys()) diff --git a/jax/_src/api.py b/jax/_src/api.py index 38ba4fd2d381..4bf964a72239 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -2564,7 +2564,6 @@ def _sds_aval_mapping(x): x.shape, dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True), weak_type=x.weak_type) core.pytype_aval_mappings[ShapeDtypeStruct] = _sds_aval_mapping -core.shaped_abstractify_handlers[ShapeDtypeStruct] = _sds_aval_mapping @api_boundary diff --git a/jax/_src/array.py b/jax/_src/array.py index 1ce8e7786bb2..2ee8b01c77d4 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -1035,7 +1035,6 @@ def _get_aval_array(self): else: return self.aval -core.shaped_abstractify_handlers[ArrayImpl] = _get_aval_array core.pytype_aval_mappings[ArrayImpl] = _get_aval_array # TODO(jakevdp) replace this with true inheritance at the C++ level. diff --git a/jax/_src/core.py b/jax/_src/core.py index 5f351bd46883..a6b3c843c9ef 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -656,6 +656,13 @@ def check_bool_conversion(arr: Array): " is ambiguous. Use a.any() or a.all()") +pytype_aval_mappings: dict[type, Callable[[Any], AbstractValue]] = {} + +def _str_abstractify(x): + raise TypeError(f"Argument '{x}' of type {type(x)} is not a valid JAX type") +pytype_aval_mappings[str] = _str_abstractify + + def _aval_property(name): return property(lambda self: getattr(self.aval, name)) @@ -918,6 +925,8 @@ def unsafe_buffer_pointer(self): aval_property = namedtuple("aval_property", ["fget"]) aval_method = namedtuple("aval_method", ["fun"]) +pytype_aval_mappings[Tracer] = lambda x: x.aval + def check_eval_args(args): for arg in args: if isinstance(arg, Tracer): @@ -1400,45 +1409,49 @@ def check_valid_jaxtype(x): f"Value {x!r} of type {type(x)} is not a valid JAX type") -def _shaped_abstractify_slow(x): - try: - return x if isinstance(x, AbstractValue) else get_aval(x) - except TypeError: - pass - - weak_type = getattr(x, 'weak_type', False) - if hasattr(x, 'dtype'): - dtype = dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True) - else: - raise TypeError( - f"Cannot interpret value of type {type(x)} as an abstract array; it " - "does not have a dtype attribute") - return ShapedArray(np.shape(x), dtype, weak_type=weak_type) +# We have three flavors of abstractification APIs here which each used to have +# their own separate implementation. Now they're effectively the same, with the +# following differences: +# +# - abstractify returns avals for non-traced array-like objects. +# - get_aval is like abstractify, but also accepts tracers. +# - shaped_abstractify is like get_aval, but also accepts duck-typed arrays. +# +# TODO(jakevdp): can these be unified further? -# TODO(jakevdp): deduplicate this with abstractify def shaped_abstractify(x): - # This was originally api_util.shaped_abstractify; temporarily moved - # here in order to facilitate combining it with abstractify. - handler = shaped_abstractify_handlers.get(type(x), None) - return handler(x) if handler is not None else _shaped_abstractify_slow(x) + if (aval_fn := pytype_aval_mappings.get(type(x))): + return aval_fn(x) + for typ in type(x).__mro__: + if (aval_fn := pytype_aval_mappings.get(typ)): + return aval_fn(x) + if isinstance(x, AbstractValue): + return x + if hasattr(x, '__jax_array__'): + return shaped_abstractify(x.__jax_array__()) + if hasattr(x, 'dtype'): + return ShapedArray(np.shape(x), x.dtype, weak_type=getattr(x, 'weak_type', False)) + raise TypeError( + f"Cannot interpret value of type {type(x)} as an abstract array; it " + "does not have a dtype attribute") def abstractify(x): - for typ in type(x).__mro__: - aval_fn = pytype_aval_mappings.get(typ) - if aval_fn: return aval_fn(x) - if hasattr(x, '__jax_array__'): - return abstractify(x.__jax_array__()) - raise TypeError(f"Argument '{x}' of type '{type(x)}' is not a valid JAX type") + if isinstance(x, Tracer): + raise TypeError(f"Argument '{x}' of type '{type(x)}' is not a valid JAX type") + return get_aval(x) def get_aval(x): - if isinstance(x, Tracer): - return x.aval - else: - return abstractify(x) + if (aval_fn := pytype_aval_mappings.get(type(x))): + return aval_fn(x) + for typ in type(x).__mro__: + if (aval_fn := pytype_aval_mappings.get(typ)): + return aval_fn(x) + if hasattr(x, '__jax_array__'): + return get_aval(x.__jax_array__()) + raise TypeError(f"Argument '{x}' of type '{type(x)}' is not a valid JAX type") -get_type = get_aval def is_concrete(x): return to_concrete_value(x) is not None @@ -1831,13 +1844,6 @@ def to_tangent_aval(self): return DShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype), self.weak_type) -pytype_aval_mappings: dict[type, Callable[[Any], AbstractValue]] = {} -shaped_abstractify_handlers: dict[Any, Callable[[Any], AbstractValue]] = {} - -def _str_abstractify(x): - raise TypeError(f"Argument '{x}' of type {type(x)} is not a valid JAX type") -pytype_aval_mappings[str] = _str_abstractify -shaped_abstractify_handlers[str] = _str_abstractify class DArray: _aval: DShapedArray @@ -1894,7 +1900,6 @@ def _darray_aval(x): return DShapedArray(x._aval.shape, x._aval.dtype, x._aval.weak_type) pytype_aval_mappings[DArray] = _darray_aval -shaped_abstractify_handlers[DArray] = _darray_aval @dataclass(frozen=True) @@ -1924,11 +1929,10 @@ def __init__(self, aval, buf): aval = property(lambda self: self._aval) shape = property(lambda self: self._aval.shape) dtype = property(lambda self: self._aval.dtype) - def __getitem__(self, idx): return get_aval(self)._getitem(self, idx) - def __setitem__(self, idx, x): return get_aval(self)._setitem(self, idx, x) + def __getitem__(self, idx): return self._aval._getitem(self, idx) + def __setitem__(self, idx, x): return self._aval._setitem(self, idx, x) def __repr__(self) -> str: return 'Mutable' + repr(self[...]) pytype_aval_mappings[MutableArray] = lambda x: x._aval -shaped_abstractify_handlers[MutableArray] = lambda x: x._aval def mutable_array(init_val): return mutable_array_p.bind(init_val) @@ -1984,7 +1988,6 @@ def __init__(self, buf): def block_until_ready(self): self._buf.block_until_ready() pytype_aval_mappings[Token] = lambda _: abstract_token -shaped_abstractify_handlers[Token] = lambda _: abstract_token # TODO(dougalm): Deprecate these. They're just here for backwards compat. diff --git a/jax/_src/earray.py b/jax/_src/earray.py index 25c2bc2bf7ec..98a0a863981e 100644 --- a/jax/_src/earray.py +++ b/jax/_src/earray.py @@ -115,7 +115,6 @@ def _earray_shard_arg_handler(xs, shardings, layouts, copy_semantics): return pxla.shard_args(phys_shardings, layouts, copy_semantics, arrs) pxla.shard_arg_handlers[EArray] = _earray_shard_arg_handler -core.shaped_abstractify_handlers[EArray] = lambda self: self.aval core.pytype_aval_mappings[EArray] = lambda x: x.aval xla.canonicalize_dtype_handlers[EArray] = lambda x: x tree_util.dispatch_registry.register_node( diff --git a/jax/_src/export/shape_poly.py b/jax/_src/export/shape_poly.py index b82890cab682..5462723c8335 100644 --- a/jax/_src/export/shape_poly.py +++ b/jax/_src/export/shape_poly.py @@ -1205,7 +1205,6 @@ def _geq_decision(e1: DimSize, e2: DimSize, cmp_str: Callable[[], str]) -> bool: f"Symbolic dimension comparison {cmp_str()} is inconclusive.{describe_scope}") core.pytype_aval_mappings[_DimExpr] = _DimExpr._get_aval -core.shaped_abstractify_handlers[_DimExpr] = _DimExpr._get_aval dtypes._weak_types.append(_DimExpr) def _convertible_to_int(p: DimSize) -> bool: diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 154b5e972682..ac0ae3a13967 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -1569,10 +1569,7 @@ def get_referent(self): val = frame.constvar_to_val.get(frame.tracer_to_var.get(id(self))) return self if val is None else get_referent(val) - -def _dynamic_jaxpr_tracer_shaped_abstractify(x): - return x.aval -core.shaped_abstractify_handlers[DynamicJaxprTracer] = _dynamic_jaxpr_tracer_shaped_abstractify +core.pytype_aval_mappings[DynamicJaxprTracer] = lambda x: x.aval def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects: sentinel = object() diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 6c907640a985..5aa11ea6a122 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -192,7 +192,6 @@ def __instancecheck__(self, instance: Any) -> bool: def _abstractify_scalar_meta(x): raise TypeError(f"JAX scalar type {x} cannot be interpreted as a JAX array.") core.pytype_aval_mappings[_ScalarMeta] = _abstractify_scalar_meta -core.shaped_abstractify_handlers[_ScalarMeta] = _abstractify_scalar_meta def _make_scalar_type(np_scalar_type: type) -> _ScalarMeta: meta = _ScalarMeta(np_scalar_type.__name__, (object,), diff --git a/jax/_src/prng.py b/jax/_src/prng.py index d29bad5d5304..4f43b54bb478 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -461,8 +461,6 @@ def __hash__(self) -> int: core.pytype_aval_mappings[PRNGKeyArray] = lambda x: x.aval -core.shaped_abstractify_handlers[PRNGKeyArray] = op.attrgetter('aval') - xla.canonicalize_dtype_handlers[PRNGKeyArray] = lambda x: x diff --git a/jax/core.py b/jax/core.py index 54bbdac51c87..ef1551b2f1ba 100644 --- a/jax/core.py +++ b/jax/core.py @@ -128,7 +128,7 @@ _src_core.escaped_tracer_error), "extend_axis_env_nd": ("jax.core.extend_axis_env_nd is deprecated.", _src_core.extend_axis_env_nd), - "get_type": ("jax.core.get_type is deprecated.", _src_core.get_type), + "get_type": ("jax.core.get_type is deprecated.", _src_core.get_aval), "get_referent": ("jax.core.get_referent is deprecated.", _src_core.get_referent), "join_effects": ("jax.core.join_effects is deprecated.", _src_core.join_effects), "leaked_tracer_error": ("jax.core.leaked_tracer_error is deprecated.", @@ -212,7 +212,7 @@ escaped_tracer_error = _src_core.escaped_tracer_error extend_axis_env_nd = _src_core.extend_axis_env_nd full_lower = _src_core.full_lower - get_type = _src_core.get_type + get_type = _src_core.get_aval get_referent = _src_core.get_referent jaxpr_as_fun = _src_core.jaxpr_as_fun join_effects = _src_core.join_effects