Skip to content

Commit

Permalink
In progress. Not ready for review. Approach 2.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 703603535
  • Loading branch information
Google-ML-Automation committed Dec 10, 2024
1 parent 2ff9038 commit d03c1cb
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 8 deletions.
1 change: 1 addition & 0 deletions jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,7 @@ py_library_providing_imports_info(
":xla_bridge",
":xla_metadata",
"//jax/_src/lib",
"//third_party/py/absl/logging",
] + py_deps("numpy") + py_deps("scipy") + py_deps("opt_einsum") + py_deps("flatbuffers") + jax_extra_deps,
)

Expand Down
5 changes: 4 additions & 1 deletion jax/_src/api_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from typing import Any

import numpy as np
from numpy import dtypes as np_dtypes

from jax._src import core
from jax._src import dtypes
Expand Down Expand Up @@ -615,7 +616,9 @@ def _str_abstractify(x):

def _numpy_array_abstractify(x: np.ndarray) -> ShapedArray:
dtype = x.dtype
dtypes.check_valid_dtype(dtype)
if not isinstance(dtype, np_dtypes.StringDType): # type: ignore
dtypes.check_valid_dtype(dtype)

return ShapedArray(x.shape,
dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True))
_shaped_abstractify_handlers[np.ndarray] = _numpy_array_abstractify
Expand Down
13 changes: 6 additions & 7 deletions jax/_src/interpreters/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,15 @@
from functools import partial
from typing import Any, Union

import numpy as np

from jax._src import core
from jax._src import dtypes
from jax._src.abstract_arrays import numpy_scalar_types
from jax._src.core import ShapedArray
from jax._src.util import safe_zip, safe_map

from jax._src.typing import Shape

from jax._src.lib import xla_client as xc
from jax._src.typing import Shape
from jax._src.util import safe_map, safe_zip
import numpy as np
import numpy.dtypes as np_dtypes

map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
Expand Down Expand Up @@ -170,7 +168,8 @@ def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray:

def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray:
dtype = x.dtype
dtypes.check_valid_dtype(dtype)
if not isinstance(dtype, np_dtypes.StringDType): # type: ignore
dtypes.check_valid_dtype(dtype)
return ShapedArray(x.shape, dtypes.canonicalize_dtype(dtype))


Expand Down
14 changes: 14 additions & 0 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,10 @@ def load(file: IO[bytes] | str | os.PathLike[Any], *args: Any, **kwargs: Any) ->
# Note: this will only work for files created via np.save(), not np.savez().
out = np.load(file, *args, **kwargs)
if isinstance(out, np.ndarray):

if out.dtype == np.object_:
return out

# numpy does not recognize bfloat16, so arrays are serialized as void16
if out.dtype == 'V2':
out = out.view(bfloat16)
Expand Down Expand Up @@ -5575,6 +5579,16 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
# Keep the output uncommitted.
return jax.device_put(object)

# 2DO: Comment.
if isinstance(object, np.ndarray) and (
object.dtype == np.dtypes.StringDType()
):
if (ndmin > 0) and (ndmin != object.ndim):
raise TypeError(
f"ndmin {ndmin} does not match ndims {object.ndim} of input array"
)
return jax.device_put(x=object, device=device)

# For Python scalar literals, call coerce_to_array to catch any overflow
# errors. We don't use dtypes.is_python_scalar because we don't want this
# triggering for traced values. We do this here because it matters whether or
Expand Down

0 comments on commit d03c1cb

Please sign in to comment.