From 3ee576b4b01f41126794071a80a4bea933907bcd Mon Sep 17 00:00:00 2001
From: Sergei Lebedev
Date: Tue, 24 Dec 2024 03:58:46 -0800
Subject: [PATCH] [pallas:mosaic_gpu] Addressed a todo in `broadcasted_iota` lowering

PiperOrigin-RevId: 709303687
---
 jax/_src/pallas/mosaic_gpu/primitives.py | 65 +++++++++++++-----------
 1 file changed, 34 insertions(+), 31 deletions(-)

diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py
index 85b7364ce2cc..4a6e2764eb94 100644
--- a/jax/_src/pallas/mosaic_gpu/primitives.py
+++ b/jax/_src/pallas/mosaic_gpu/primitives.py
@@ -16,6 +16,7 @@
 
 from __future__ import annotations
 
+from collections.abc import Sequence
 import enum
 import math
 from typing import Any, Literal
@@ -25,7 +26,6 @@
 from jax._src import state
 from jax._src import tree_util
 from jax._src import util
-from jax._src.interpreters import mlir
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import arith as arith_dialect
 from jax._src.lib.mlir.dialects import llvm as llvm_dialect
@@ -36,7 +36,8 @@
 from jax._src.state import discharge
 from jax._src.state import indexing
 from jax._src.state import primitives as state_primitives
-import jax.experimental.mosaic.gpu as mgpu
+from jax.experimental.mosaic import gpu as mgpu
+from jax.experimental.mosaic.gpu import utils as mgpu_utils
 import jax.numpy as jnp
 
 
@@ -703,38 +704,40 @@ def _broadcasted_iota_abstract_eval(dtype, shape, dimension, layout):
   del layout, dimension
   return jax_core.ShapedArray(shape, dtype)
 
 
-@lowering.register_lowering_rule(broadcasted_iota_p)
-def _broadcasted_iota_lowering(ctx: lowering.LoweringRuleContext, dtype, shape, dimension, layout):
-  del ctx
-  # Unsigned integers (as opposed to signless) cause MLIR verification
-  # errors so we only use signless like Mosaic GPU does.
-  #
-  # TODO(cperivol): use mgpu.utils.dtype_to_ir_type() instead.
-  mlir_dtype = (
-      ir.IntegerType.get_signless(dtype.itemsize * 8)
-      if jnp.issubdtype(dtype, jnp.integer)
-      else mlir.dtype_to_ir_type(dtype)
-  )
-  undef = llvm_dialect.mlir_undef(mlir_dtype)
-  is_signed = (
-      jnp.issubdtype(dtype, jnp.signedinteger)
-      if jnp.issubdtype(dtype, jnp.integer)
-      else None
-  )
-  i32 = ir.IntegerType.get_signless(32)
-  def _cast(x):
-    if ir.FloatType.isinstance(mlir_dtype):
-      x = arith_dialect.index_cast(i32, x)
-      return arith_dialect.uitofp(mlir_dtype, x)
-    else:
-      return arith_dialect.index_cast(mlir_dtype, x)
+@lowering.register_lowering_rule(broadcasted_iota_p)
+def _broadcasted_iota_lowering(
+    ctx: lowering.LoweringRuleContext, dtype, shape, dimension, layout
+):
+  del ctx  # Unused.
+  mlir_dtype = mgpu_utils.dtype_to_ir_type(dtype)
+  if ir.FloatType.isinstance(mlir_dtype):
+    i32 = ir.IntegerType.get_signless(32)
+    cast = lambda x: arith_dialect.uitofp(
+        mlir_dtype, arith_dialect.index_cast(i32, x)
+    )
+  else:
+    cast = lambda x: arith_dialect.index_cast(mlir_dtype, x)
+  is_signed = mgpu_utils.is_signed(dtype)
   return mgpu.FragmentedArray.splat(
-      undef, shape, layout.value, is_signed=is_signed
+      llvm_dialect.mlir_undef(mlir_dtype),
+      shape,
+      layout.value,
+      is_signed=is_signed,
   ).foreach(
-      lambda _, idx: _cast(idx[dimension]), create_array=True, is_signed=is_signed
+      lambda _, idx: cast(idx[dimension]),
+      create_array=True,
+      is_signed=is_signed,
   )
 
 
-def broadcasted_iota(dtype, shape, dimension, *, layout: Layout | None = None):
-  return broadcasted_iota_p.bind(dtype=jnp.dtype(dtype), shape=shape, dimension=dimension, layout=layout)
+def broadcasted_iota(
+    dtype: jax.typing.DTypeLike,
+    shape: Sequence[int],
+    dimension: int,
+    *,
+    layout: Layout | None = None,
+) -> jax.Array:
+  return broadcasted_iota_p.bind(
+      dtype=jnp.dtype(dtype), shape=shape, dimension=dimension, layout=layout
+  )
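
Note on semantics (illustrative sketch, not part of the patch): the lowering above splats an undef value and then writes `cast(idx[dimension])` at every position, so the primitive produces the same values as stock `jax.lax.broadcasted_iota`: each element equals its index along `dimension`, broadcast over the remaining axes. Inside a Pallas Mosaic GPU kernel the `broadcasted_iota` defined in this file is what gets called, optionally with an explicit `layout`. A quick reference check of the expected values, using nothing beyond plain JAX:

    import jax
    import jax.numpy as jnp

    # Each element equals its index along `dimension` (here the column index).
    ref = jax.lax.broadcasted_iota(jnp.int32, (4, 3), 1)
    # ref == [[0, 1, 2],
    #         [0, 1, 2],
    #         [0, 1, 2],
    #         [0, 1, 2]]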