Skip to content

Commit

Permalink
Internal: use a single registry for abstractify APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Dec 22, 2024
1 parent 44d67e1 commit 0fda407
Show file tree
Hide file tree
Showing 10 changed files with 48 additions and 59 deletions.
4 changes: 0 additions & 4 deletions jax/_src/abstract_arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ def masked_array_error(*args, **kwargs):
"Use arr.filled() to convert the value to a standard numpy array.")

core.pytype_aval_mappings[np.ma.MaskedArray] = masked_array_error
core.shaped_abstractify_handlers[np.ma.MaskedArray] = masked_array_error


def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray:
Expand All @@ -58,7 +57,6 @@ def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray:
return ShapedArray(x.shape, dtypes.canonicalize_dtype(dtype))

core.pytype_aval_mappings[np.ndarray] = _make_shaped_array_for_numpy_array
core.shaped_abstractify_handlers[np.ndarray] = _make_shaped_array_for_numpy_array


def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray:
Expand All @@ -68,7 +66,6 @@ def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray:

for t in numpy_scalar_types:
core.pytype_aval_mappings[t] = _make_shaped_array_for_numpy_scalar
core.shaped_abstractify_handlers[t] = _make_shaped_array_for_numpy_scalar

core.literalable_types.update(array_types)

Expand All @@ -81,6 +78,5 @@ def _make_abstract_python_scalar(typ, val):

for t in dtypes.python_scalar_dtypes:
core.pytype_aval_mappings[t] = partial(_make_abstract_python_scalar, t)
core.shaped_abstractify_handlers[t] = partial(_make_abstract_python_scalar, t)

core.literalable_types.update(dtypes.python_scalar_dtypes.keys())
1 change: 0 additions & 1 deletion jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2564,7 +2564,6 @@ def _sds_aval_mapping(x):
x.shape, dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True),
weak_type=x.weak_type)
core.pytype_aval_mappings[ShapeDtypeStruct] = _sds_aval_mapping
core.shaped_abstractify_handlers[ShapeDtypeStruct] = _sds_aval_mapping


@api_boundary
Expand Down
1 change: 0 additions & 1 deletion jax/_src/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1035,7 +1035,6 @@ def _get_aval_array(self):
else:
return self.aval

core.shaped_abstractify_handlers[ArrayImpl] = _get_aval_array
core.pytype_aval_mappings[ArrayImpl] = _get_aval_array

# TODO(jakevdp) replace this with true inheritance at the C++ level.
Expand Down
87 changes: 45 additions & 42 deletions jax/_src/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,6 +656,13 @@ def check_bool_conversion(arr: Array):
" is ambiguous. Use a.any() or a.all()")


pytype_aval_mappings: dict[type, Callable[[Any], AbstractValue]] = {}

def _str_abstractify(x):
raise TypeError(f"Argument '{x}' of type {type(x)} is not a valid JAX type")
pytype_aval_mappings[str] = _str_abstractify


def _aval_property(name):
return property(lambda self: getattr(self.aval, name))

Expand Down Expand Up @@ -918,6 +925,8 @@ def unsafe_buffer_pointer(self):
aval_property = namedtuple("aval_property", ["fget"])
aval_method = namedtuple("aval_method", ["fun"])

pytype_aval_mappings[Tracer] = lambda x: x.aval

def check_eval_args(args):
for arg in args:
if isinstance(arg, Tracer):
Expand Down Expand Up @@ -1400,45 +1409,49 @@ def check_valid_jaxtype(x):
f"Value {x!r} of type {type(x)} is not a valid JAX type")


def _shaped_abstractify_slow(x):
try:
return x if isinstance(x, AbstractValue) else get_aval(x)
except TypeError:
pass

weak_type = getattr(x, 'weak_type', False)
if hasattr(x, 'dtype'):
dtype = dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True)
else:
raise TypeError(
f"Cannot interpret value of type {type(x)} as an abstract array; it "
"does not have a dtype attribute")
return ShapedArray(np.shape(x), dtype, weak_type=weak_type)
# We have three flavors of abstractification APIs here which each used to have
# their own separate implementation. Now they're effectively the same, with the
# following differences:
#
# - abstractify returns avals for non-traced array-like objects.
# - get_aval is like abstractify, but also accepts tracers.
# - shaped_abstractify is like get_aval, but also accepts duck-typed arrays.
#
# TODO(jakevdp): can these be unified further?

# TODO(jakevdp): deduplicate this with abstractify
def shaped_abstractify(x):
# This was originally api_util.shaped_abstractify; temporarily moved
# here in order to facilitate combining it with abstractify.
handler = shaped_abstractify_handlers.get(type(x), None)
return handler(x) if handler is not None else _shaped_abstractify_slow(x)
if (aval_fn := pytype_aval_mappings.get(type(x))):
return aval_fn(x)
for typ in type(x).__mro__:
if (aval_fn := pytype_aval_mappings.get(typ)):
return aval_fn(x)
if isinstance(x, AbstractValue):
return x
if hasattr(x, '__jax_array__'):
return shaped_abstractify(x.__jax_array__())
if hasattr(x, 'dtype'):
return ShapedArray(np.shape(x), x.dtype, weak_type=getattr(x, 'weak_type', False))
raise TypeError(
f"Cannot interpret value of type {type(x)} as an abstract array; it "
"does not have a dtype attribute")


def abstractify(x):
for typ in type(x).__mro__:
aval_fn = pytype_aval_mappings.get(typ)
if aval_fn: return aval_fn(x)
if hasattr(x, '__jax_array__'):
return abstractify(x.__jax_array__())
raise TypeError(f"Argument '{x}' of type '{type(x)}' is not a valid JAX type")
if isinstance(x, Tracer):
raise TypeError(f"Argument '{x}' of type '{type(x)}' is not a valid JAX type")
return get_aval(x)


def get_aval(x):
if isinstance(x, Tracer):
return x.aval
else:
return abstractify(x)
if (aval_fn := pytype_aval_mappings.get(type(x))):
return aval_fn(x)
for typ in type(x).__mro__:
if (aval_fn := pytype_aval_mappings.get(typ)):
return aval_fn(x)
if hasattr(x, '__jax_array__'):
return get_aval(x.__jax_array__())
raise TypeError(f"Argument '{x}' of type '{type(x)}' is not a valid JAX type")

get_type = get_aval

def is_concrete(x):
return to_concrete_value(x) is not None
Expand Down Expand Up @@ -1831,13 +1844,6 @@ def to_tangent_aval(self):
return DShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype),
self.weak_type)

pytype_aval_mappings: dict[type, Callable[[Any], AbstractValue]] = {}
shaped_abstractify_handlers: dict[Any, Callable[[Any], AbstractValue]] = {}

def _str_abstractify(x):
raise TypeError(f"Argument '{x}' of type {type(x)} is not a valid JAX type")
pytype_aval_mappings[str] = _str_abstractify
shaped_abstractify_handlers[str] = _str_abstractify

class DArray:
_aval: DShapedArray
Expand Down Expand Up @@ -1894,7 +1900,6 @@ def _darray_aval(x):
return DShapedArray(x._aval.shape, x._aval.dtype, x._aval.weak_type)

pytype_aval_mappings[DArray] = _darray_aval
shaped_abstractify_handlers[DArray] = _darray_aval


@dataclass(frozen=True)
Expand Down Expand Up @@ -1924,11 +1929,10 @@ def __init__(self, aval, buf):
aval = property(lambda self: self._aval)
shape = property(lambda self: self._aval.shape)
dtype = property(lambda self: self._aval.dtype)
def __getitem__(self, idx): return get_aval(self)._getitem(self, idx)
def __setitem__(self, idx, x): return get_aval(self)._setitem(self, idx, x)
def __getitem__(self, idx): return self._aval._getitem(self, idx)
def __setitem__(self, idx, x): return self._aval._setitem(self, idx, x)
def __repr__(self) -> str: return 'Mutable' + repr(self[...])
pytype_aval_mappings[MutableArray] = lambda x: x._aval
shaped_abstractify_handlers[MutableArray] = lambda x: x._aval

def mutable_array(init_val):
return mutable_array_p.bind(init_val)
Expand Down Expand Up @@ -1984,7 +1988,6 @@ def __init__(self, buf):
def block_until_ready(self):
self._buf.block_until_ready()
pytype_aval_mappings[Token] = lambda _: abstract_token
shaped_abstractify_handlers[Token] = lambda _: abstract_token


# TODO(dougalm): Deprecate these. They're just here for backwards compat.
Expand Down
1 change: 0 additions & 1 deletion jax/_src/earray.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@ def _earray_shard_arg_handler(xs, shardings, layouts, copy_semantics):
return pxla.shard_args(phys_shardings, layouts, copy_semantics, arrs)
pxla.shard_arg_handlers[EArray] = _earray_shard_arg_handler

core.shaped_abstractify_handlers[EArray] = lambda self: self.aval
core.pytype_aval_mappings[EArray] = lambda x: x.aval
xla.canonicalize_dtype_handlers[EArray] = lambda x: x
tree_util.dispatch_registry.register_node(
Expand Down
1 change: 0 additions & 1 deletion jax/_src/export/shape_poly.py
Original file line number Diff line number Diff line change
Expand Up @@ -1205,7 +1205,6 @@ def _geq_decision(e1: DimSize, e2: DimSize, cmp_str: Callable[[], str]) -> bool:
f"Symbolic dimension comparison {cmp_str()} is inconclusive.{describe_scope}")

core.pytype_aval_mappings[_DimExpr] = _DimExpr._get_aval
core.shaped_abstractify_handlers[_DimExpr] = _DimExpr._get_aval
dtypes._weak_types.append(_DimExpr)

def _convertible_to_int(p: DimSize) -> bool:
Expand Down
5 changes: 1 addition & 4 deletions jax/_src/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -1569,10 +1569,7 @@ def get_referent(self):
val = frame.constvar_to_val.get(frame.tracer_to_var.get(id(self)))
return self if val is None else get_referent(val)


def _dynamic_jaxpr_tracer_shaped_abstractify(x):
return x.aval
core.shaped_abstractify_handlers[DynamicJaxprTracer] = _dynamic_jaxpr_tracer_shaped_abstractify
core.pytype_aval_mappings[DynamicJaxprTracer] = lambda x: x.aval

def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects:
sentinel = object()
Expand Down
1 change: 0 additions & 1 deletion jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,6 @@ def __instancecheck__(self, instance: Any) -> bool:
def _abstractify_scalar_meta(x):
raise TypeError(f"JAX scalar type {x} cannot be interpreted as a JAX array.")
core.pytype_aval_mappings[_ScalarMeta] = _abstractify_scalar_meta
core.shaped_abstractify_handlers[_ScalarMeta] = _abstractify_scalar_meta

def _make_scalar_type(np_scalar_type: type) -> _ScalarMeta:
meta = _ScalarMeta(np_scalar_type.__name__, (object,),
Expand Down
2 changes: 0 additions & 2 deletions jax/_src/prng.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,8 +461,6 @@ def __hash__(self) -> int:


core.pytype_aval_mappings[PRNGKeyArray] = lambda x: x.aval
core.shaped_abstractify_handlers[PRNGKeyArray] = op.attrgetter('aval')

xla.canonicalize_dtype_handlers[PRNGKeyArray] = lambda x: x


Expand Down
4 changes: 2 additions & 2 deletions jax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@
_src_core.escaped_tracer_error),
"extend_axis_env_nd": ("jax.core.extend_axis_env_nd is deprecated.",
_src_core.extend_axis_env_nd),
"get_type": ("jax.core.get_type is deprecated.", _src_core.get_type),
"get_type": ("jax.core.get_type is deprecated.", _src_core.get_aval),
"get_referent": ("jax.core.get_referent is deprecated.", _src_core.get_referent),
"join_effects": ("jax.core.join_effects is deprecated.", _src_core.join_effects),
"leaked_tracer_error": ("jax.core.leaked_tracer_error is deprecated.",
Expand Down Expand Up @@ -212,7 +212,7 @@
escaped_tracer_error = _src_core.escaped_tracer_error
extend_axis_env_nd = _src_core.extend_axis_env_nd
full_lower = _src_core.full_lower
get_type = _src_core.get_type
get_type = _src_core.get_aval
get_referent = _src_core.get_referent
jaxpr_as_fun = _src_core.jaxpr_as_fun
join_effects = _src_core.join_effects
Expand Down

0 comments on commit 0fda407

Please sign in to comment.