From d95b95b405c31fe3a1af91bb3e1d7252c670f9f7 Mon Sep 17 00:00:00 2001
From: Adam Paszke
Date: Wed, 18 Dec 2024 05:57:53 -0800
Subject: [PATCH] [Mosaic TPU] Add support for exp, exp2 and log in bf16 on
 TPUv6

PiperOrigin-RevId: 707520511
---
 jax/_src/pallas/mosaic/lowering.py |  6 ++++--
 tests/pallas/ops_test.py           | 17 ++++++++++++-----
 2 files changed, 16 insertions(+), 7 deletions(-)

diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py
index d0acc655a6a5..5c9f1932ae54 100644
--- a/jax/_src/pallas/mosaic/lowering.py
+++ b/jax/_src/pallas/mosaic/lowering.py
@@ -2178,8 +2178,10 @@ def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, *, y):
 def _exp2_lowering_rule(ctx: LoweringRuleContext, x):
   # exp2 in JAX lowers to exp(ln2 * x), not to pow2. We match that behavior
   # here.
-  return lower_fun(lambda x: jnp.exp(np.log(2) * x), multiple_results=False)(
-      ctx, x)
+  return lower_fun(
+      lambda x: jnp.exp(jnp.astype(np.log(2), x.dtype) * x),
+      multiple_results=False,
+  )(ctx, x)


 lowering_rules[lax.exp2_p] = _exp2_lowering_rule

diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py
index abbabf401278..7296b2b8ae13 100644
--- a/tests/pallas/ops_test.py
+++ b/tests/pallas/ops_test.py
@@ -816,7 +816,7 @@ def kernel(x_ref, o_ref):
     ([jnp.ceil, jnp.floor], ["bfloat16", "float32", "float64", "int32"]),
     (
         [jnp.exp, jnp.exp2, jnp.sin, jnp.cos, jnp.log, jnp.sqrt],
-        ["float16", "float32", "float64"],
+        ["bfloat16", "float16", "float32", "float64"],
     ),
     (
         # fmt: off
@@ -843,11 +843,13 @@ def test_elementwise(self, fn, dtype):
     if dtype in ("int16", "float16"):
       self.skipTest("int16 and float16 are not supported on TPU")
     if (
-        fn in (jnp.ceil, jnp.floor, jnp.negative)
+        fn in (jnp.ceil, jnp.floor, jnp.negative, jnp.exp, jnp.exp2, jnp.log)
         and dtype == "bfloat16"
         and not jtu.is_device_tpu_at_least(6)
     ):
       self.skipTest(f"bfloat16 {fn.__name__} is only supported on TPU v6+")
+    if fn in (jnp.sqrt, jnp.sin, jnp.cos) and dtype == "bfloat16":
+      self.skipTest(f"bfloat16 {fn.__name__} is not supported on TPU")
     # TODO(b/370578663): implement these lowerings on TPU
     if fn in (
         jnp.acos, jnp.acosh, jnp.asin, jnp.asinh, jnp.atan, jnp.atanh,
@@ -870,8 +872,13 @@ def kernel(x_ref, o_ref):
       o_ref[:] = fn(x_ref[...])

     # create an array with shape (8, 128)
-    x = jnp.array([0.42, 2.4] * (8 * 128 // 2)).reshape(8, 128).astype(dtype)
-    self.assertAllClose(kernel(x), fn(x), rtol=1e-6)
+    if fn in (jnp.exp, jnp.exp2) and dtype == "bfloat16":
+      x = jnp.array([0.42, 1.26] * (8 * 128 // 2)).reshape(8, 128).astype(dtype)
+      rtol = 2e-3
+    else:
+      x = jnp.array([0.42, 2.4] * (8 * 128 // 2)).reshape(8, 128).astype(dtype)
+      rtol = 1e-6
+    self.assertAllClose(kernel(x), fn(x), rtol=rtol)

   @parameterized.named_parameters(
       (f"{fn.__name__}_{dtype}", fn, dtype)
@@ -919,7 +926,7 @@ def kernel(x_ref, o_ref):
       o_ref[0] = fn(x_ref[0])
       o_ref[1] = fn(x_ref[1])

-    x = jnp.array([0.42, 2.4]).astype(dtype)
+    x = jnp.array([0.42, 1.4]).astype(dtype)
     self.assertAllClose(kernel(x), fn(x), rtol=1e-6)

   def test_abs_weak_type(self):
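
Note (not part of the patch): a minimal usage sketch of what this change enables, assuming a TPU v6 device and the public Pallas API. The kernel, shape, and input value mirror the test above; everything else here is illustrative, not from the commit.

    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl

    def kernel(x_ref, o_ref):
      # exp2 lowers to exp(ln2 * x); the patched rule casts ln2 to the
      # input dtype so the computation stays in bfloat16.
      o_ref[...] = jnp.exp2(x_ref[...])

    # (8, 128) bf16 input, as in test_elementwise above.
    x = jnp.full((8, 128), 0.42, dtype=jnp.bfloat16)
    y = pl.pallas_call(
        kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)
    )(x)

Before this patch, the lowering multiplied by a float64/float32 np.log(2) constant, which has no bf16 lowering path; casting the constant to x.dtype keeps the whole expression in bf16, which TPUv6 supports natively.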