Skip to content

Commit

Permalink
Add e8m0fnu support by conditional dtype.
Browse files Browse the repository at this point in the history
  • Loading branch information
wenscarl committed Dec 18, 2024
1 parent 8abb1a7 commit 7577ac9
Show file tree
Hide file tree
Showing 11 changed files with 30 additions and 1 deletion.
8 changes: 8 additions & 0 deletions jax/_src/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def type(self) -> type: ...
# TODO: remove Optional when minimum ml_dtypes version >= 0.5.0
float8_e3m4: type[np.generic] | None = None
float8_e4m3: type[np.generic] | None = None
float8_e8m0fnu: type[np.generic] | None = None
float8_e4m3b11fnuz: type[np.generic] = ml_dtypes.float8_e4m3b11fnuz
float8_e4m3fn: type[np.generic] = ml_dtypes.float8_e4m3fn
float8_e4m3fnuz: type[np.generic] = ml_dtypes.float8_e4m3fnuz
Expand All @@ -101,6 +102,7 @@ def type(self) -> type: ...

_float8_e3m4_dtype: np.dtype | None = None
_float8_e4m3_dtype: np.dtype | None = None
_float8_e8m0fnu_dtype: np.dtype | None = None
_float8_e4m3b11fnuz_dtype: np.dtype = np.dtype(float8_e4m3b11fnuz)
_float8_e4m3fn_dtype: np.dtype = np.dtype(float8_e4m3fn)
_float8_e4m3fnuz_dtype: np.dtype = np.dtype(float8_e4m3fnuz)
Expand Down Expand Up @@ -155,6 +157,12 @@ def supports_inf(dtype: DTypeLike) -> bool:
_custom_float_scalar_types.insert(0, float8_e3m4) # type: ignore[arg-type]
_custom_float_dtypes.insert(0, _float8_e3m4_dtype)
_float8_dtypes.insert(0, _float8_e3m4_dtype)
if hasattr(ml_dtypes, "float8_e8m0fnu"):
float8_e8m0fnu = ml_dtypes.float8_e8m0fnu
_float8_e8m0fnu_dtype = np.dtype(float8_e8m0fnu)
_custom_float_scalar_types.insert(0, float8_e8m0fnu) # type: ignore[arg-type]
_custom_float_dtypes.insert(0, _float8_e8m0fnu_dtype)
_float8_dtypes.insert(0, _float8_e8m0fnu_dtype)

# 2-bit integer support
int2: type[np.generic] | None = None
Expand Down
1 change: 1 addition & 0 deletions jax/_src/export/serialization.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ enum DType: byte {
f8_e4m3fnuz = 19,
f8_e5m2 = 20,
f8_e5m2fnuz = 21,
f8_e8m0fnu = 25,
}

table AbstractValue {
Expand Down
3 changes: 2 additions & 1 deletion jax/_src/export/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,8 @@ def _deserialize_pytreedef_to_pytree(p: ser_flatbuf.PyTreeDef):
_dtype_to_dtype_kind[dtypes._float8_e3m4_dtype] = ser_flatbuf.DType.f8_e3m4
if dtypes._float8_e4m3_dtype is not None:
_dtype_to_dtype_kind[dtypes._float8_e4m3_dtype] = ser_flatbuf.DType.f8_e4m3

if dtypes._float8_e8m0fnu_dtype is not None:
_dtype_to_dtype_kind[dtypes._float8_e8m0fnu_dtype] = ser_flatbuf.DType.f8_e8m0fnu
_dtype_kind_to_dtype = {
kind: dtype for dtype, kind in _dtype_to_dtype_kind.items()
}
Expand Down
1 change: 1 addition & 0 deletions jax/_src/export/serialization_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class DType(object):
f8_e5m2 = 20
f8_e5m2fnuz = 21
f0 = 22
f8_e8m0fnu = 25


class ShardingKind(object):
Expand Down
2 changes: 2 additions & 0 deletions jax/_src/interpreters/mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,8 @@ def _is_ir_values(x: IrValues) -> bool:
_dtype_to_ir_type[np.dtype(dtypes.float8_e3m4)] = ir.Float8E3M4Type.get
if dtypes.float8_e4m3 is not None:
_dtype_to_ir_type[np.dtype(dtypes.float8_e4m3)] = ir.Float8E4M3Type.get
if dtypes.float8_e8m0fnu is not None:
_dtype_to_ir_type[np.dtype(dtypes.float8_e8m0fnu)] = ir.Float8E8M0FNUType.get

def dtype_to_ir_type(dtype: core.bint | np.dtype | np.generic) -> ir.Type:
if isinstance(dtype, core.bint):
Expand Down
4 changes: 4 additions & 0 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -1021,6 +1021,8 @@ def _convert_to_hlo_attr(self, lhs_dtype: DTypeLike,
fp8_dtypes += [np.dtype(dtypes.float8_e3m4)]
if dtypes.float8_e4m3 is not None:
fp8_dtypes += [np.dtype(dtypes.float8_e4m3)]
if dtypes.float8_e8m0fnu is not None:
fp8_dtypes += [np.dtype(dtypes.float8_e8m0fnu)]
if lhs_dtype not in fp8_dtypes or rhs_dtype not in fp8_dtypes:
raise ValueError(
f"The dot algorithm '{self}' requires both inputs to have float8 "
Expand Down Expand Up @@ -3764,6 +3766,8 @@ def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes):
fp8_dtypes += (dtypes.float8_e3m4,)
if dtypes.float8_e4m3 is not None:
fp8_dtypes += (dtypes.float8_e4m3,)
if dtypes.float8_e8m0fnu is not None:
fp8_dtypes += (dtypes.float8_e8m0fnu,)
return _lhs_dtypes in fp8_dtypes and _rhs_dtypes in fp8_dtypes
del preferred_element_type # Implied by the output aval
lhs_aval, rhs_aval = ctx.avals_in
Expand Down
2 changes: 2 additions & 0 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,8 @@ def _make_scalar_type(np_scalar_type: type) -> _ScalarMeta:
float8_e3m4 = _make_scalar_type(dtypes.float8_e3m4)
if dtypes.float8_e4m3 is not None:
float8_e4m3 = _make_scalar_type(dtypes.float8_e4m3)
if dtypes.float8_e8m0fnu is not None:
float8_e8m0fnu = _make_scalar_type(dtypes.float8_e8m0fnu)
float8_e4m3fn = _make_scalar_type(dtypes.float8_e4m3fn)
float8_e4m3fnuz = _make_scalar_type(dtypes.float8_e4m3fnuz)
float8_e5m2 = _make_scalar_type(dtypes.float8_e5m2)
Expand Down
5 changes: 5 additions & 0 deletions jax/_src/public_test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ def default_tolerance():
if _dtypes.float8_e4m3 is not None:
_default_tolerance[np.dtype(_dtypes.float8_e4m3)] = 1e-1
default_gradient_tolerance[np.dtype(_dtypes.float8_e4m3)] = 1e-1
if _dtypes.float8_e8m0fnu is not None:
_default_tolerance[np.dtype(_dtypes.float8_e8m0fnu)] = 1e0
default_gradient_tolerance[np.dtype(_dtypes.float8_e8m0fnu)] = 1e0

def is_python_scalar(val):
return not isinstance(val, np.generic) and isinstance(val, (bool, int, float, complex))
Expand All @@ -119,6 +122,8 @@ def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''):
custom_float_dtypes.insert(0, _dtypes.float8_e4m3)
if _dtypes.float8_e3m4 is not None:
custom_float_dtypes.insert(0, _dtypes.float8_e3m4)
if _dtypes.float8_e8m0fnu is not None:
custom_float_dtypes.insert(0, _dtypes.float8_e8m0fnu)

def maybe_upcast(x):
if x.dtype in custom_float_dtypes:
Expand Down
2 changes: 2 additions & 0 deletions jax/_src/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -1385,6 +1385,8 @@ def custom_floats(self):
float_dtypes += [_dtypes.float8_e3m4]
if _dtypes.float8_e4m3 is not None:
float_dtypes += [_dtypes.float8_e4m3]
if _dtypes.float8_e8m0fnu is not None:
float_dtypes += [_dtypes.float8_e8m0fnu]
return self.supported(float_dtypes)

@_cached_property
Expand Down
1 change: 1 addition & 0 deletions jax/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@
from jax._src.numpy.lax_numpy import (
float8_e3m4 as float8_e3m4,
float8_e4m3 as float8_e4m3,
float8_e8m0fnu as float8_e8m0fnu,
)
except ImportError:
pass
Expand Down
2 changes: 2 additions & 0 deletions tests/dtypes_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@
fp8_dtypes += [np.dtype(dtypes.float8_e3m4)]
if dtypes.float8_e4m3 is not None:
fp8_dtypes += [np.dtype(dtypes.float8_e4m3)]
if dtypes.float8_e8m0fnu is not None:
fp8_dtypes += [np.dtype(dtypes.float8_e8m0fnu)]
float_dtypes += fp8_dtypes
custom_float_dtypes += fp8_dtypes

Expand Down

0 comments on commit 7577ac9

Please sign in to comment.