Skip to content

Commit

Permalink
Conditional dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
wenscarl authored Dec 11, 2024
1 parent ad1be72 commit ddbf557
Show file tree
Hide file tree
Showing 10 changed files with 28 additions and 11 deletions.
13 changes: 8 additions & 5 deletions jax/_src/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,21 +93,21 @@ def type(self) -> type: ...
# TODO: remove Optional when minimum ml_dtypes version >= 0.5.0
float8_e3m4: type[np.generic] | None = None
float8_e4m3: type[np.generic] | None = None
float8_e8m0fnu: type[np.generic] | None = None
float8_e4m3b11fnuz: type[np.generic] = ml_dtypes.float8_e4m3b11fnuz
float8_e4m3fn: type[np.generic] = ml_dtypes.float8_e4m3fn
float8_e4m3fnuz: type[np.generic] = ml_dtypes.float8_e4m3fnuz
float8_e5m2: type[np.generic] = ml_dtypes.float8_e5m2
float8_e5m2fnuz: type[np.generic] = ml_dtypes.float8_e5m2fnuz
float8_e8m0fnu: type[np.generic] = ml_dtypes.float8_e8m0fnu

_float8_e3m4_dtype: np.dtype | None = None
_float8_e4m3_dtype: np.dtype | None = None
_float8_e8m0fnu_dtype: np.dtype | None = None
_float8_e4m3b11fnuz_dtype: np.dtype = np.dtype(float8_e4m3b11fnuz)
_float8_e4m3fn_dtype: np.dtype = np.dtype(float8_e4m3fn)
_float8_e4m3fnuz_dtype: np.dtype = np.dtype(float8_e4m3fnuz)
_float8_e5m2_dtype: np.dtype = np.dtype(float8_e5m2)
_float8_e5m2fnuz_dtype: np.dtype = np.dtype(float8_e5m2fnuz)
_float8_e8m0fnu_dtype: np.dtype = np.dtype(float8_e8m0fnu)

def supports_inf(dtype: DTypeLike) -> bool:
"""Return true if the dtype supports infinity, else return False."""
Expand All @@ -126,7 +126,6 @@ def supports_inf(dtype: DTypeLike) -> bool:
float8_e4m3fnuz,
float8_e5m2,
float8_e5m2fnuz,
float8_e8m0fnu,
bfloat16,
]
_custom_float_dtypes = [
Expand All @@ -135,7 +134,6 @@ def supports_inf(dtype: DTypeLike) -> bool:
_float8_e4m3fnuz_dtype,
_float8_e5m2_dtype,
_float8_e5m2fnuz_dtype,
_float8_e8m0fnu_dtype,
_bfloat16_dtype,
]
_float8_dtypes = [
Expand All @@ -144,7 +142,6 @@ def supports_inf(dtype: DTypeLike) -> bool:
_float8_e4m3fnuz_dtype,
_float8_e5m2_dtype,
_float8_e5m2fnuz_dtype,
_float8_e8m0fnu_dtype,
]

# TODO: remove the if statements below when minimum ml_dtypes version >= 0.5.0
Expand All @@ -160,6 +157,12 @@ def supports_inf(dtype: DTypeLike) -> bool:
_custom_float_scalar_types.insert(0, float8_e3m4) # type: ignore[arg-type]
_custom_float_dtypes.insert(0, _float8_e3m4_dtype)
_float8_dtypes.insert(0, _float8_e3m4_dtype)
if hasattr(ml_dtypes, "float8_e8m0fnu"):
float8_e8m0fnu = ml_dtypes.float8_e8m0fnu
_float8_e8m0fnu_dtype = np.dtype(float8_e8m0fnu)
_custom_float_scalar_types.insert(0, float8_e8m0fnu) # type: ignore[arg-type]
_custom_float_dtypes.insert(0, _float8_e8m0fnu_dtype)
_float8_dtypes.insert(0, _float8_e8m0fnu_dtype)

# 2-bit integer support
int2: type[np.generic] | None = None
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/export/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,14 +357,14 @@ def _deserialize_pytreedef_to_pytree(p: ser_flatbuf.PyTreeDef):
dtypes._float8_e4m3fnuz_dtype: ser_flatbuf.DType.f8_e4m3fnuz,
dtypes._float8_e5m2_dtype: ser_flatbuf.DType.f8_e5m2,
dtypes._float8_e5m2fnuz_dtype: ser_flatbuf.DType.f8_e5m2fnuz,
dtypes._float8_e8m0fnu_dtype: ser_flatbuf.DType.f8_e8m0fnu,
}

if dtypes._float8_e3m4_dtype is not None:
_dtype_to_dtype_kind[dtypes._float8_e3m4_dtype] = ser_flatbuf.DType.f8_e3m4
if dtypes._float8_e4m3_dtype is not None:
_dtype_to_dtype_kind[dtypes._float8_e4m3_dtype] = ser_flatbuf.DType.f8_e4m3

if dtypes._float8_e8m0fnu_dtype is not None:
_dtype_to_dtype_kind[dtypes._float8_e8m0fnu_dtype] = ser_flatbuf.DType.f8_e8m0fnu
_dtype_kind_to_dtype = {
kind: dtype for dtype, kind in _dtype_to_dtype_kind.items()
}
Expand Down
3 changes: 2 additions & 1 deletion jax/_src/interpreters/mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,6 @@ def _is_ir_values(x: IrValues) -> bool:
np.dtype(dtypes.float8_e4m3fnuz): ir.Float8E4M3FNUZType.get,
np.dtype(dtypes.float8_e5m2): ir.Float8E5M2Type.get,
np.dtype(dtypes.float8_e5m2fnuz): ir.Float8E5M2FNUZType.get,
np.dtype(dtypes.float8_e8m0fnu): ir.Float8E8M0FNUType.get,
np.dtype(dtypes.bfloat16): ir.BF16Type.get,
np.dtype(np.float16): ir.F16Type.get,
np.dtype(np.float32): ir.F32Type.get,
Expand All @@ -192,6 +191,8 @@ def _is_ir_values(x: IrValues) -> bool:
_dtype_to_ir_type[np.dtype(dtypes.float8_e3m4)] = ir.Float8E3M4Type.get
if dtypes.float8_e4m3 is not None:
_dtype_to_ir_type[np.dtype(dtypes.float8_e4m3)] = ir.Float8E4M3Type.get
if dtypes.float8_e8m0fnu is not None:
_dtype_to_ir_type[np.dtype(dtypes.float8_e8m0fnu)] = ir.Float8E8M0FNUType.get

def dtype_to_ir_type(dtype: core.bint | np.dtype | np.generic) -> ir.Type:
if isinstance(dtype, core.bint):
Expand Down
4 changes: 4 additions & 0 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -964,6 +964,8 @@ def _convert_to_hlo_attr(self, lhs_dtype: DTypeLike,
fp8_dtypes += [np.dtype(dtypes.float8_e3m4)]
if dtypes.float8_e4m3 is not None:
fp8_dtypes += [np.dtype(dtypes.float8_e4m3)]
if dtypes.float8_e8m0fnu is not None:
fp8_dtypes += [np.dtype(dtypes.float8_e8m0fnu)]
if lhs_dtype not in fp8_dtypes or rhs_dtype not in fp8_dtypes:
raise ValueError(
f"The dot algorithm '{self}' requires both inputs to have float8 "
Expand Down Expand Up @@ -3738,6 +3740,8 @@ def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes):
fp8_dtypes += (dtypes.float8_e3m4,)
if dtypes.float8_e4m3 is not None:
fp8_dtypes += (dtypes.float8_e4m3,)
if dtypes.float8_e8m0fnu is not None:
fp8_dtypes += (dtypes.float8_e8m0fnu,)
return _lhs_dtypes in fp8_dtypes and _rhs_dtypes in fp8_dtypes
del preferred_element_type # Implied by the output aval
lhs_aval, rhs_aval = ctx.avals_in
Expand Down
3 changes: 2 additions & 1 deletion jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,11 +225,12 @@ def _make_scalar_type(np_scalar_type: type) -> _ScalarMeta:
float8_e3m4 = _make_scalar_type(dtypes.float8_e3m4)
if dtypes.float8_e4m3 is not None:
float8_e4m3 = _make_scalar_type(dtypes.float8_e4m3)
if dtypes.float8_e8m0fnu is not None:
float8_e8m0fnu = _make_scalar_type(dtypes.float8_e8m0fnu)
float8_e4m3fn = _make_scalar_type(dtypes.float8_e4m3fn)
float8_e4m3fnuz = _make_scalar_type(dtypes.float8_e4m3fnuz)
float8_e5m2 = _make_scalar_type(dtypes.float8_e5m2)
float8_e5m2fnuz = _make_scalar_type(dtypes.float8_e5m2fnuz)
float8_e8m0fnu = _make_scalar_type(dtypes.float8_e8m0fnu)
float8_e4m3b11fnuz = _make_scalar_type(dtypes.float8_e4m3b11fnuz)
bfloat16 = _make_scalar_type(dtypes.bfloat16)
float16 = _make_scalar_type(np.float16)
Expand Down
5 changes: 5 additions & 0 deletions jax/_src/public_test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ def default_tolerance():
if _dtypes.float8_e4m3 is not None:
_default_tolerance[np.dtype(_dtypes.float8_e4m3)] = 1e-1
default_gradient_tolerance[np.dtype(_dtypes.float8_e4m3)] = 1e-1
if _dtypes.float8_e8m0fnu is not None:
_default_tolerance[np.dtype(_dtypes.float8_e8m0fnu)] = 1e0
default_gradient_tolerance[np.dtype(_dtypes.float8_e8m0fnu)] = 1e0

def is_python_scalar(val):
return not isinstance(val, np.generic) and isinstance(val, (bool, int, float, complex))
Expand All @@ -119,6 +122,8 @@ def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''):
custom_float_dtypes.insert(0, _dtypes.float8_e4m3)
if _dtypes.float8_e3m4 is not None:
custom_float_dtypes.insert(0, _dtypes.float8_e3m4)
if _dtypes.float8_e8m0fnu is not None:
custom_float_dtypes.insert(0, _dtypes.float8_e8m0fnu)

def maybe_upcast(x):
if x.dtype in custom_float_dtypes:
Expand Down
2 changes: 2 additions & 0 deletions jax/_src/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -1489,6 +1489,8 @@ def custom_floats(self):
float_dtypes += [_dtypes.float8_e3m4]
if _dtypes.float8_e4m3 is not None:
float_dtypes += [_dtypes.float8_e4m3]
if _dtypes.float8_e8m0fnu is not None:
float_dtypes += [_dtypes.float8_e8m0fnu]
return self.supported(float_dtypes)

@_cached_property
Expand Down
2 changes: 1 addition & 1 deletion jax/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@
float8_e4m3fnuz as float8_e4m3fnuz,
float8_e5m2 as float8_e5m2,
float8_e5m2fnuz as float8_e5m2fnuz,
float8_e8m0fnu as float8_e8m0fnu,
float_ as float_,
floating as floating,
fmax as fmax,
Expand Down Expand Up @@ -280,6 +279,7 @@
from jax._src.numpy.lax_numpy import (
float8_e3m4 as float8_e3m4,
float8_e4m3 as float8_e4m3,
float8_e8m0fnu as float8_e8m0fnu,
)
except ImportError:
pass
Expand Down
1 change: 0 additions & 1 deletion jax/numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,6 @@ float8_e4m3fn: Any
float8_e4m3fnuz: Any
float8_e5m2: Any
float8_e5m2fnuz: Any
float8_e8m0fnu: Any
float_: Any
def float_power(x: ArrayLike, y: ArrayLike, /) -> Array: ...
floating = _np.floating
Expand Down
2 changes: 2 additions & 0 deletions tests/dtypes_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@
fp8_dtypes += [np.dtype(dtypes.float8_e3m4)]
if dtypes.float8_e4m3 is not None:
fp8_dtypes += [np.dtype(dtypes.float8_e4m3)]
if dtypes.float8_e8m0fnu is not None:
fp8_dtypes += [np.dtype(dtypes.float8_e8m0fnu)]
float_dtypes += fp8_dtypes
custom_float_dtypes += fp8_dtypes

Expand Down

0 comments on commit ddbf557

Please sign in to comment.