Skip to content

Commit

Permalink
Streamline some core.concrete_aval compute paths
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Dec 17, 2024
1 parent 0fa5419 commit a2ac234
Showing 1 changed file with 4 additions and 9 deletions.
13 changes: 4 additions & 9 deletions jax/_src/abstract_arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,15 @@ def masked_array_error(*args, **kwargs):
"Use arr.filled() to convert the value to a standard numpy array.")

core.pytype_aval_mappings[np.ma.MaskedArray] = masked_array_error
core.xla_pytype_aval_mappings[np.ma.MaskedArray] = masked_array_error


def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray:
dtype = x.dtype
dtypes.check_valid_dtype(dtype)
return ShapedArray(x.shape, dtypes.canonicalize_dtype(dtype))

core.pytype_aval_mappings[np.ndarray] = canonical_concrete_aval
core.pytype_aval_mappings[np.ndarray] = _make_shaped_array_for_numpy_array
core.xla_pytype_aval_mappings[np.ndarray] = _make_shaped_array_for_numpy_array


Expand All @@ -73,26 +74,20 @@ def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray:
return ShapedArray(np.shape(x), dtypes.canonicalize_dtype(dtype))

for t in numpy_scalar_types:
core.pytype_aval_mappings[t] = canonical_concrete_aval
core.pytype_aval_mappings[t] = _make_shaped_array_for_numpy_scalar
core.xla_pytype_aval_mappings[t] = _make_shaped_array_for_numpy_scalar

core.literalable_types.update(array_types)


def _make_concrete_python_scalar(t, x):
dtype = dtypes._scalar_type_to_dtype(t, x)
weak_type = dtypes.is_weakly_typed(x)
return canonical_concrete_aval(np.array(x, dtype=dtype), weak_type=weak_type)


def _make_abstract_python_scalar(typ, val):
# Note: all python scalar types are weak except bool, because bool only
# comes in a single width.
return ShapedArray((), dtypes._scalar_type_to_dtype(typ, val),
weak_type=typ is not bool)

for t in dtypes.python_scalar_dtypes:
core.pytype_aval_mappings[t] = partial(_make_concrete_python_scalar, t)
core.pytype_aval_mappings[t] = partial(_make_abstract_python_scalar, t)
core.xla_pytype_aval_mappings[t] = partial(_make_abstract_python_scalar, t)

core.literalable_types.update(dtypes.python_scalar_dtypes.keys())

0 comments on commit a2ac234

Please sign in to comment.