From 75f36dc3ea613d17ebd89c2b1e3ef7e957960876 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 20 Dec 2024 13:57:40 -0800 Subject: [PATCH] Support int4/uint4 in jnp.ndarray.view --- jax/_src/dtypes.py | 8 ++++--- jax/_src/numpy/array_methods.py | 18 ++++++++------- tests/lax_numpy_test.py | 41 +++++++++++++++++++++++++++------ 3 files changed, 49 insertions(+), 18 deletions(-) diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index 52cb3d87bbda..04b07843a324 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -209,14 +209,16 @@ def bit_width(dtype: DTypeLike) -> int: """Number of bits per element for the dtype.""" # Note: we cannot use dtype.itemsize here because this is # incorrect for sub-byte integer types. - if dtype == bool: + if dtype == np.dtype(bool): return 8 # physical bit layout for boolean dtype elif issubdtype(dtype, np.integer): return iinfo(dtype).bits - elif issubdtype(dtype, np.inexact): + elif issubdtype(dtype, np.floating): return finfo(dtype).bits + elif issubdtype(dtype, np.complexfloating): + return 2 * finfo(dtype).bits else: - raise ValueError("unexpected input: {dtype=}") + raise ValueError(f"unexpected input: {dtype=}") # Trivial vectorspace datatype needed for tangent values of int/bool primals float0: np.dtype = np.dtype([('float0', np.void, 0)]) diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py index 617213ca03de..2cecc55f6489 100644 --- a/jax/_src/numpy/array_methods.py +++ b/jax/_src/numpy/array_methods.py @@ -509,12 +509,15 @@ def _view(self: Array, dtype: DTypeLike | None = None, type: None = None) -> Arr dtypes.check_user_dtype_supported(dtype, "view") dtype = dtypes.canonicalize_dtype(dtype) + nbits_in = dtypes.bit_width(self.dtype) + nbits_out = dtypes.bit_width(dtype) + if self.ndim == 0: - if self.dtype.itemsize != dtype.itemsize: + if nbits_in != nbits_out: raise ValueError("view() of a 0d array is only supported if the itemsize is unchanged.") return _view(lax.expand_dims(self, (0,)), dtype).squeeze() - if (self.shape[-1] * self.dtype.itemsize) % dtype.itemsize != 0: + if (self.shape[-1] * nbits_in) % nbits_out != 0: raise ValueError("When changing to a larger dtype, its size must be a divisor " "of the total size in bytes of the last axis of the array.") @@ -543,16 +546,15 @@ def _view(self: Array, dtype: DTypeLike | None = None, type: None = None) -> Arr # lax.bitcast_convert_type adds or subtracts dimensions depending on the # relative bitwidths of the dtypes; we account for that with reshapes. - if self.dtype.itemsize < dtype.itemsize: - factor = dtype.itemsize // self.dtype.itemsize + if nbits_in < nbits_out: + factor = nbits_out // nbits_in out = self.reshape(*self.shape[:-1], self.shape[-1] // factor, factor) return lax.bitcast_convert_type(out, dtype) - - if self.dtype.itemsize > dtype.itemsize: + elif nbits_in > nbits_out: out = lax.bitcast_convert_type(self, dtype) return out.reshape(*out.shape[:-2], out.shape[-2] * out.shape[-1]) - - return lax.bitcast_convert_type(self, dtype) + else: + return lax.bitcast_convert_type(self, dtype) def _notimplemented_flat(self): diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 7d26b1df849e..d507761abba7 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -87,6 +87,32 @@ # uint64 is problematic because with any uint type it promotes to float: int_dtypes_no_uint64 = [d for d in int_dtypes + unsigned_dtypes if d != np.uint64] +def _bitcast_uint4_to_uint8(operand): + # Note: assumes little-endian byte order. + assert operand.dtype == 'uint4' + operand = operand.astype('uint8') + return operand[..., ::2] + (operand[..., 1::2] << 4) + +def _bitcast_uint8_to_uint4(operand): + # Note: assumes little-endian byte order. + assert operand.dtype == 'uint8' + result = np.zeros((*operand.shape[:-1], operand.shape[-1] * 2), dtype='uint4') + result[..., ::2] = (operand & 0b00001111).astype('uint4') + result[..., 1::2] = ((operand & 0b11110000) >> 4).astype('uint4') + return result + +def np_view(arr, dtype): + # Implementation of np.ndarray.view() that works for int4/uint4 + dtype = np.dtype(dtype) + nbits_in = dtypes.bit_width(arr.dtype) + nbits_out = dtypes.bit_width(dtype) + if nbits_in == 4: + arr = _bitcast_uint4_to_uint8(arr.view('uint4')) + if nbits_out == 4: + arr = _bitcast_uint8_to_uint4(arr.view('uint8')) + return arr.view(dtype) + + def np_unique_backport(ar, return_index=False, return_inverse=False, return_counts=False, axis=None, **kwds): # Wrapper for np.unique, handling the change to inverse_indices in numpy 2.0 @@ -4244,9 +4270,10 @@ def testItem(self, shape, dtype, num_args, use_tuple): @jtu.sample_product( # Final dimension must be a multiple of 16 to ensure compatibility of all dtype pairs. - shape=[(0,), (32,), (2, 16)], - a_dtype=all_dtypes, - dtype=(*all_dtypes, None) if config.enable_x64.value else all_dtypes, + shape=[(0,), (64,), (2, 32)], + a_dtype=(jnp.int4, jnp.uint4, *all_dtypes), + dtype=((jnp.int4, jnp.uint4, *all_dtypes, None) + if config.enable_x64.value else (jnp.int4, jnp.uint4, *all_dtypes)), ) def testView(self, shape, a_dtype, dtype): if jtu.test_device_matches(["tpu"]): @@ -4259,7 +4286,7 @@ def testView(self, shape, a_dtype, dtype): self.rng() ) args_maker = lambda: [rng(shape, a_dtype)] - np_op = lambda x: np.asarray(x).view(dtype) + np_op = lambda x: np_view(x, dtype) jnp_op = lambda x: jnp.asarray(x).view(dtype) # Above may produce signaling nans; ignore warnings from invalid values. with np.errstate(invalid='ignore'): @@ -4268,9 +4295,9 @@ def testView(self, shape, a_dtype, dtype): @jtu.sample_product([ {'a_dtype': a_dtype, 'dtype': dtype} - for a_dtype in all_dtypes - for dtype in all_dtypes - if np.dtype(a_dtype).itemsize == np.dtype(dtype).itemsize + for a_dtype in [jnp.int4, jnp.uint4, *all_dtypes] + for dtype in [jnp.int4, jnp.uint4, *all_dtypes] + if dtypes.bit_width(a_dtype) == dtypes.bit_width(dtype) ]) def testViewScalar(self, a_dtype, dtype): if jtu.test_device_matches(["tpu"]):