From 10996f98aca6fe56daeae0e31a2708c1ee5bb761 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 6 Dec 2024 13:44:52 -0800 Subject: [PATCH] In progress. Not ready for review. Approach 2. PiperOrigin-RevId: 703603535 --- jax/BUILD | 1 + jax/_src/api_util.py | 5 ++++- jax/_src/interpreters/xla.py | 13 ++++++------- jax/_src/numpy/lax_numpy.py | 14 ++++++++++++++ 4 files changed, 25 insertions(+), 8 deletions(-) diff --git a/jax/BUILD b/jax/BUILD index 271f5e95a626..6dd130be857c 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -307,6 +307,7 @@ py_library_providing_imports_info( ":xla_bridge", ":xla_metadata", "//jax/_src/lib", + "//third_party/py/absl/logging", ] + py_deps("numpy") + py_deps("scipy") + py_deps("opt_einsum") + py_deps("flatbuffers") + jax_extra_deps, ) diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py index eb5e7e8bf8de..963efaa2f689 100644 --- a/jax/_src/api_util.py +++ b/jax/_src/api_util.py @@ -21,6 +21,7 @@ from typing import Any import numpy as np +from numpy import dtypes as np_dtypes from jax._src import core from jax._src import dtypes @@ -615,7 +616,9 @@ def _str_abstractify(x): def _numpy_array_abstractify(x: np.ndarray) -> ShapedArray: dtype = x.dtype - dtypes.check_valid_dtype(dtype) + if not isinstance(dtype, np_dtypes.StringDType): # type: ignore + dtypes.check_valid_dtype(dtype) + return ShapedArray(x.shape, dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True)) _shaped_abstractify_handlers[np.ndarray] = _numpy_array_abstractify diff --git a/jax/_src/interpreters/xla.py b/jax/_src/interpreters/xla.py index 46bc7bef7ca7..f9a3f2ae1a98 100644 --- a/jax/_src/interpreters/xla.py +++ b/jax/_src/interpreters/xla.py @@ -20,17 +20,15 @@ from functools import partial from typing import Any, Union -import numpy as np - from jax._src import core from jax._src import dtypes from jax._src.abstract_arrays import numpy_scalar_types from jax._src.core import ShapedArray -from jax._src.util import safe_zip, safe_map - -from jax._src.typing import Shape - from jax._src.lib import xla_client as xc +from jax._src.typing import Shape +from jax._src.util import safe_map, safe_zip +import numpy as np +import numpy.dtypes as np_dtypes map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip @@ -170,7 +168,8 @@ def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray: def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray: dtype = x.dtype - dtypes.check_valid_dtype(dtype) + if not isinstance(dtype, np_dtypes.StringDType): # type: ignore + dtypes.check_valid_dtype(dtype) return ShapedArray(x.shape, dtypes.canonicalize_dtype(dtype)) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 3da4aa462f16..be02eec86f46 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -374,6 +374,10 @@ def load(file: IO[bytes] | str | os.PathLike[Any], *args: Any, **kwargs: Any) -> # Note: this will only work for files created via np.save(), not np.savez(). out = np.load(file, *args, **kwargs) if isinstance(out, np.ndarray): + + if out.dtype == np.object_: + return out + # numpy does not recognize bfloat16, so arrays are serialized as void16 if out.dtype == 'V2': out = out.view(bfloat16) @@ -5575,6 +5579,16 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, # Keep the output uncommitted. return jax.device_put(object) + # 2DO: Comment. + if isinstance(object, np.ndarray) and ( + object.dtype == np.dtypes.StringDType() + ): + if (ndmin > 0) and (ndmin != object.ndim): + raise TypeError( + f"ndmin {ndmin} does not match ndims {object.ndim} of input array" + ) + return jax.device_put(x=object, device=device) + # For Python scalar literals, call coerce_to_array to catch any overflow # errors. We don't use dtypes.is_python_scalar because we don't want this # triggering for traced values. We do this here because it matters whether or