
Commit d95b95b
[Mosaic TPU] Add support for exp, exp2 and log in bf16 on TPUv6
PiperOrigin-RevId: 707520511
apaszke authored and Google-ML-Automation committed Dec 18, 2024
1 parent 5c9756b commit d95b95b
Showing 2 changed files with 16 additions and 7 deletions.
6 changes: 4 additions & 2 deletions jax/_src/pallas/mosaic/lowering.py
@@ -2178,8 +2178,10 @@ def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, *, y):
 def _exp2_lowering_rule(ctx: LoweringRuleContext, x):
   # exp2 in JAX lowers to exp(ln2 * x), not to pow2. We match that behavior
   # here.
-  return lower_fun(lambda x: jnp.exp(np.log(2) * x), multiple_results=False)(
-      ctx, x)
+  return lower_fun(
+      lambda x: jnp.exp(jnp.astype(np.log(2), x.dtype) * x),
+      multiple_results=False,
+  )(ctx, x)
 
 
 lowering_rules[lax.exp2_p] = _exp2_lowering_rule
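The substantive change above is the jnp.astype cast: np.log(2) is a NumPy
float64 scalar with a strong dtype, so the old lambda promoted bfloat16
inputs before the multiply instead of keeping the computation in bf16. A
minimal sketch of the same decomposition in plain JAX (the exp2_via_exp
helper is ours for illustration, not part of the commit):

import jax.numpy as jnp
import numpy as np

def exp2_via_exp(x):
  # exp2(x) == exp(ln(2) * x). Cast the constant so a bfloat16 input stays
  # bfloat16 instead of being promoted by the strongly-typed float64 scalar.
  ln2 = jnp.astype(np.log(2), x.dtype)
  return jnp.exp(ln2 * x)

x = jnp.array([0.42, 1.26], dtype=jnp.bfloat16)
print(exp2_via_exp(x).dtype)  # bfloat16
print(jnp.exp2(x))            # reference for comparison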
17 changes: 12 additions & 5 deletions tests/pallas/ops_test.py
@@ -816,7 +816,7 @@ def kernel(x_ref, o_ref):
     ([jnp.ceil, jnp.floor], ["bfloat16", "float32", "float64", "int32"]),
     (
         [jnp.exp, jnp.exp2, jnp.sin, jnp.cos, jnp.log, jnp.sqrt],
-        ["float16", "float32", "float64"],
+        ["bfloat16", "float16", "float32", "float64"],
     ),
     (
         # fmt: off
@@ -843,11 +843,13 @@ def test_elementwise(self, fn, dtype):
     if dtype in ("int16", "float16"):
       self.skipTest("int16 and float16 are not supported on TPU")
     if (
-        fn in (jnp.ceil, jnp.floor, jnp.negative)
+        fn in (jnp.ceil, jnp.floor, jnp.negative, jnp.exp, jnp.exp2, jnp.log)
         and dtype == "bfloat16"
         and not jtu.is_device_tpu_at_least(6)
     ):
       self.skipTest(f"bfloat16 {fn.__name__} is only supported on TPU v6+")
+    if fn in (jnp.sqrt, jnp.sin, jnp.cos) and dtype == "bfloat16":
+      self.skipTest(f"bfloat16 {fn.__name__} is not supported on TPU")
     # TODO(b/370578663): implement these lowerings on TPU
     if fn in (
         jnp.acos, jnp.acosh, jnp.asin, jnp.asinh, jnp.atan, jnp.atanh,
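Taken together, the skip conditions encode a per-op bf16 support matrix on
TPU; a condensed summary (our reading of the test logic, not text from the
commit):

# bf16 elementwise support implied by the skips above (our summary):
BF16_TPU_SUPPORT = {
    ("ceil", "floor", "negative", "exp", "exp2", "log"): "TPU v6 and newer",
    ("sqrt", "sin", "cos"): "not supported on any TPU generation",
}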
@@ -870,8 +872,13 @@ def kernel(x_ref, o_ref):
       o_ref[:] = fn(x_ref[...])
 
     # create an array with shape (8, 128)
-    x = jnp.array([0.42, 2.4] * (8 * 128 // 2)).reshape(8, 128).astype(dtype)
-    self.assertAllClose(kernel(x), fn(x), rtol=1e-6)
+    if fn in (jnp.exp, jnp.exp2) and dtype == "bfloat16":
+      x = jnp.array([0.42, 1.26] * (8 * 128 // 2)).reshape(8, 128).astype(dtype)
+      rtol = 2e-3
+    else:
+      x = jnp.array([0.42, 2.4] * (8 * 128 // 2)).reshape(8, 128).astype(dtype)
+      rtol = 1e-6
+    self.assertAllClose(kernel(x), fn(x), rtol=rtol)
 
   @parameterized.named_parameters(
       (f"{fn.__name__}_{dtype}", fn, dtype)
@@ -919,7 +926,7 @@ def kernel(x_ref, o_ref):
       o_ref[0] = fn(x_ref[0])
       o_ref[1] = fn(x_ref[1])
 
-    x = jnp.array([0.42, 2.4]).astype(dtype)
+    x = jnp.array([0.42, 1.4]).astype(dtype)
     self.assertAllClose(kernel(x), fn(x), rtol=1e-6)
 
   def test_abs_weak_type(self):
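The commit does not say why the scalar test's second input moves from 2.4
to 1.4; a plausible reading is that it reduces rounding error now that
lower-precision dtypes exercise this path. A quick way to inspect how such
constants round (illustrative only):

import jax.numpy as jnp

# How the candidate test constants round in low-precision dtypes.
for v in (0.42, 1.4, 2.4):
  for dt in (jnp.bfloat16, jnp.float16):
    r = float(jnp.asarray(v, dtype=dt))
    print(f"{v} as {dt.__name__}: {r} (rel err {abs(r - v) / v:.1e})")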
