Skip to content

Commit

Permalink
Merge pull request #25650 from jakevdp:view-int4
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 708468858
  • Loading branch information
Google-ML-Automation committed Dec 21, 2024
2 parents 0159bea + 75f36dc commit 1c0dee8
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 18 deletions.
8 changes: 5 additions & 3 deletions jax/_src/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,14 +209,16 @@ def bit_width(dtype: DTypeLike) -> int:
"""Number of bits per element for the dtype."""
# Note: we cannot use dtype.itemsize here because this is
# incorrect for sub-byte integer types.
if dtype == bool:
if dtype == np.dtype(bool):
return 8 # physical bit layout for boolean dtype
elif issubdtype(dtype, np.integer):
return iinfo(dtype).bits
elif issubdtype(dtype, np.inexact):
elif issubdtype(dtype, np.floating):
return finfo(dtype).bits
elif issubdtype(dtype, np.complexfloating):
return 2 * finfo(dtype).bits
else:
raise ValueError("unexpected input: {dtype=}")
raise ValueError(f"unexpected input: {dtype=}")

# Trivial vectorspace datatype needed for tangent values of int/bool primals
float0: np.dtype = np.dtype([('float0', np.void, 0)])
Expand Down
18 changes: 10 additions & 8 deletions jax/_src/numpy/array_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,12 +509,15 @@ def _view(self: Array, dtype: DTypeLike | None = None, type: None = None) -> Arr
dtypes.check_user_dtype_supported(dtype, "view")
dtype = dtypes.canonicalize_dtype(dtype)

nbits_in = dtypes.bit_width(self.dtype)
nbits_out = dtypes.bit_width(dtype)

if self.ndim == 0:
if self.dtype.itemsize != dtype.itemsize:
if nbits_in != nbits_out:
raise ValueError("view() of a 0d array is only supported if the itemsize is unchanged.")
return _view(lax.expand_dims(self, (0,)), dtype).squeeze()

if (self.shape[-1] * self.dtype.itemsize) % dtype.itemsize != 0:
if (self.shape[-1] * nbits_in) % nbits_out != 0:
raise ValueError("When changing to a larger dtype, its size must be a divisor "
"of the total size in bytes of the last axis of the array.")

Expand Down Expand Up @@ -543,16 +546,15 @@ def _view(self: Array, dtype: DTypeLike | None = None, type: None = None) -> Arr

# lax.bitcast_convert_type adds or subtracts dimensions depending on the
# relative bitwidths of the dtypes; we account for that with reshapes.
if self.dtype.itemsize < dtype.itemsize:
factor = dtype.itemsize // self.dtype.itemsize
if nbits_in < nbits_out:
factor = nbits_out // nbits_in
out = self.reshape(*self.shape[:-1], self.shape[-1] // factor, factor)
return lax.bitcast_convert_type(out, dtype)

if self.dtype.itemsize > dtype.itemsize:
elif nbits_in > nbits_out:
out = lax.bitcast_convert_type(self, dtype)
return out.reshape(*out.shape[:-2], out.shape[-2] * out.shape[-1])

return lax.bitcast_convert_type(self, dtype)
else:
return lax.bitcast_convert_type(self, dtype)


def _notimplemented_flat(self):
Expand Down
41 changes: 34 additions & 7 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,32 @@
# uint64 is problematic because with any uint type it promotes to float:
int_dtypes_no_uint64 = [d for d in int_dtypes + unsigned_dtypes if d != np.uint64]

def _bitcast_uint4_to_uint8(operand):
# Note: assumes little-endian byte order.
assert operand.dtype == 'uint4'
operand = operand.astype('uint8')
return operand[..., ::2] + (operand[..., 1::2] << 4)

def _bitcast_uint8_to_uint4(operand):
# Note: assumes little-endian byte order.
assert operand.dtype == 'uint8'
result = np.zeros((*operand.shape[:-1], operand.shape[-1] * 2), dtype='uint4')
result[..., ::2] = (operand & 0b00001111).astype('uint4')
result[..., 1::2] = ((operand & 0b11110000) >> 4).astype('uint4')
return result

def np_view(arr, dtype):
# Implementation of np.ndarray.view() that works for int4/uint4
dtype = np.dtype(dtype)
nbits_in = dtypes.bit_width(arr.dtype)
nbits_out = dtypes.bit_width(dtype)
if nbits_in == 4:
arr = _bitcast_uint4_to_uint8(arr.view('uint4'))
if nbits_out == 4:
arr = _bitcast_uint8_to_uint4(arr.view('uint8'))
return arr.view(dtype)


def np_unique_backport(ar, return_index=False, return_inverse=False, return_counts=False,
axis=None, **kwds):
# Wrapper for np.unique, handling the change to inverse_indices in numpy 2.0
Expand Down Expand Up @@ -4244,9 +4270,10 @@ def testItem(self, shape, dtype, num_args, use_tuple):

@jtu.sample_product(
# Final dimension must be a multiple of 16 to ensure compatibility of all dtype pairs.
shape=[(0,), (32,), (2, 16)],
a_dtype=all_dtypes,
dtype=(*all_dtypes, None) if config.enable_x64.value else all_dtypes,
shape=[(0,), (64,), (2, 32)],
a_dtype=(jnp.int4, jnp.uint4, *all_dtypes),
dtype=((jnp.int4, jnp.uint4, *all_dtypes, None)
if config.enable_x64.value else (jnp.int4, jnp.uint4, *all_dtypes)),
)
def testView(self, shape, a_dtype, dtype):
if jtu.test_device_matches(["tpu"]):
Expand All @@ -4259,7 +4286,7 @@ def testView(self, shape, a_dtype, dtype):
self.rng()
)
args_maker = lambda: [rng(shape, a_dtype)]
np_op = lambda x: np.asarray(x).view(dtype)
np_op = lambda x: np_view(x, dtype)
jnp_op = lambda x: jnp.asarray(x).view(dtype)
# Above may produce signaling nans; ignore warnings from invalid values.
with np.errstate(invalid='ignore'):
Expand All @@ -4268,9 +4295,9 @@ def testView(self, shape, a_dtype, dtype):

@jtu.sample_product([
{'a_dtype': a_dtype, 'dtype': dtype}
for a_dtype in all_dtypes
for dtype in all_dtypes
if np.dtype(a_dtype).itemsize == np.dtype(dtype).itemsize
for a_dtype in [jnp.int4, jnp.uint4, *all_dtypes]
for dtype in [jnp.int4, jnp.uint4, *all_dtypes]
if dtypes.bit_width(a_dtype) == dtypes.bit_width(dtype)
])
def testViewScalar(self, a_dtype, dtype):
if jtu.test_device_matches(["tpu"]):
Expand Down

0 comments on commit 1c0dee8

Please sign in to comment.