Skip to content

Commit

Permalink
In progress experimention. Add StringDType to JAX's supported types.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 707662268
  • Loading branch information
Google-ML-Automation committed Dec 20, 2024
1 parent 4216f8f commit f9d2b96
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions jax/_src/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,9 +476,15 @@ def _issubdtype_cached(a: type | np.dtype | ExtendedDType,
np.dtype('complex64'),
np.dtype('complex128'),
]
_jax_types = _bool_types + _int_types + _float_types + _complex_types
_jax_dtype_set = {float0, *_bool_types, *_int_types, *_float_types, *_complex_types}

try:
import numpy.dtypes as np_dtypes
_string_types: list[JAXType] = [np_dtypes.StringDType()]
except ImportError:
_string_types: list[JAXType] = []
np_dtypes = None # type: ignore

_jax_types = _bool_types + _int_types + _float_types + _complex_types + _string_types
_jax_dtype_set = {float0, *_bool_types, *_int_types, *_float_types, *_complex_types, *_string_types}

_dtype_kinds: dict[str, set] = {
'bool': {*_bool_types},
Expand Down

0 comments on commit f9d2b96

Please sign in to comment.