Skip to content

Commit

Permalink
In progress experimentation for supporting JAX Arrays with variable-w…
Browse files Browse the repository at this point in the history
…idth strings (i.e., with dtype = StringDType).

PiperOrigin-RevId: 703603535
  • Loading branch information
Google-ML-Automation committed Dec 17, 2024
1 parent 36b12d5 commit 10f19ba
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 7 deletions.
10 changes: 9 additions & 1 deletion jax/_src/api_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@

import numpy as np

try:
import numpy.dtypes as np_dtypes
except ImportError:
np_dtypes = None

from jax._src import core
from jax._src import dtypes
from jax._src.abstract_arrays import numpy_scalar_types
Expand Down Expand Up @@ -614,7 +619,10 @@ def _str_abstractify(x):

def _numpy_array_abstractify(x: np.ndarray) -> ShapedArray:
  """Return the ShapedArray describing the shape and dtype of a NumPy array.

  Args:
    x: the NumPy array to abstractify.

  Returns:
    A ShapedArray built from ``x.shape`` and the canonicalized ``x.dtype``
    (extended dtypes are allowed during canonicalization).

  Raises:
    Whatever ``dtypes.check_valid_dtype`` raises for an unsupported dtype.
    Arrays whose dtype is a NumPy ``StringDType`` are exempt from that check.
  """
  dtype = x.dtype

  # `np_dtypes` is None when `numpy.dtypes` could not be imported, and the
  # module may lack `StringDType` on older NumPy releases. In both cases the
  # dtype cannot be a StringDType, so validation must still run. The previous
  # combined condition skipped validation entirely in those situations.
  string_dtype = getattr(np_dtypes, "StringDType", None)
  is_string_dtype = string_dtype is not None and isinstance(dtype, string_dtype)
  if not is_string_dtype:
    dtypes.check_valid_dtype(dtype)

  return ShapedArray(x.shape,
                     dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True))
# Map np.ndarray to its abstractification handler in the handler table.
_shaped_abstractify_handlers[np.ndarray] = _numpy_array_abstractify
Expand Down
16 changes: 10 additions & 6 deletions jax/_src/interpreters/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,20 @@
from functools import partial
from typing import Any, Union

import numpy as np

from jax._src import core
from jax._src import dtypes
from jax._src.abstract_arrays import numpy_scalar_types
from jax._src.core import ShapedArray
from jax._src.util import safe_zip, safe_map

from jax._src.lib import xla_client as xc
from jax._src.typing import Shape
from jax._src.util import safe_map, safe_zip
import numpy as np

try:
import numpy.dtypes as np_dtypes
except ImportError:
np_dtypes = None

from jax._src.lib import xla_client as xc

map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
Expand Down Expand Up @@ -170,7 +173,8 @@ def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray:

def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray:
  """Build the ShapedArray describing the shape and dtype of a NumPy array.

  Args:
    x: the NumPy array to describe.

  Returns:
    A ShapedArray built from ``x.shape`` and the canonicalized ``x.dtype``.

  Raises:
    Whatever ``dtypes.check_valid_dtype`` raises for an unsupported dtype.
    Arrays whose dtype is a NumPy ``StringDType`` are exempt from that check.
  """
  dtype = x.dtype
  # `np_dtypes` is None when `numpy.dtypes` could not be imported, and the
  # module may lack `StringDType` on older NumPy releases. In both cases the
  # dtype cannot be a StringDType, so validation must still run. The previous
  # combined condition skipped validation entirely in those situations.
  string_dtype = getattr(np_dtypes, "StringDType", None)
  is_string_dtype = string_dtype is not None and isinstance(dtype, string_dtype)
  if not is_string_dtype:
    dtypes.check_valid_dtype(dtype)
  return ShapedArray(x.shape, dtypes.canonicalize_dtype(dtype))


Expand Down
13 changes: 13 additions & 0 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@
PartitionSpec as P)
from jax.tree_util import tree_flatten, tree_leaves, tree_map
import numpy as np

try:
from numpy import dtypes as np_dtypes
except ImportError:
np_dtypes = None
import opt_einsum

export = set_module('jax.numpy')
Expand Down Expand Up @@ -5575,6 +5580,14 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
# Keep the output uncommitted.
return jax.device_put(object)

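# TODO: Expand this comment. NumPy arrays with a StringDType dtype take an
# early exit here: after the ndmin check they are passed straight to
# device_put and skip the handling below.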
if isinstance(object, np.ndarray) and (np_dtypes is not None) and (getattr(np_dtypes, "StringDType", None) is not None) and (isinstance(object.dtype, np_dtypes.StringDType)): # type: ignore
if (ndmin > 0) and (ndmin != object.ndim):
raise TypeError(
f"ndmin {ndmin} does not match ndims {object.ndim} of input array"
)
return jax.device_put(x=object, device=device)

# For Python scalar literals, call coerce_to_array to catch any overflow
# errors. We don't use dtypes.is_python_scalar because we don't want this
# triggering for traced values. We do this here because it matters whether or
Expand Down

0 comments on commit 10f19ba

Please sign in to comment.