Skip to content

Commit

Permalink
[pallas:mosaic_gpu] Addressed a todo in broadcasted_iota lowering
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 709303687
  • Loading branch information
superbobry authored and Google-ML-Automation committed Dec 24, 2024
1 parent 4eff131 commit 3ee576b
Showing 1 changed file with 34 additions and 31 deletions.
65 changes: 34 additions & 31 deletions jax/_src/pallas/mosaic_gpu/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from __future__ import annotations

from collections.abc import Sequence
import enum
import math
from typing import Any, Literal
Expand All @@ -25,7 +26,6 @@
from jax._src import state
from jax._src import tree_util
from jax._src import util
from jax._src.interpreters import mlir
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith as arith_dialect
from jax._src.lib.mlir.dialects import llvm as llvm_dialect
Expand All @@ -36,7 +36,8 @@
from jax._src.state import discharge
from jax._src.state import indexing
from jax._src.state import primitives as state_primitives
import jax.experimental.mosaic.gpu as mgpu
from jax.experimental.mosaic import gpu as mgpu
from jax.experimental.mosaic.gpu import utils as mgpu_utils
import jax.numpy as jnp


Expand Down Expand Up @@ -703,38 +704,40 @@ def _broadcasted_iota_abstract_eval(dtype, shape, dimension, layout):
del layout, dimension
return jax_core.ShapedArray(shape, dtype)

# NOTE(review): this span is a rendered diff of `_broadcasted_iota_lowering`.
# Two versions of the function appear back to back, without +/- markers and
# with indentation stripped, so the text is not runnable as-is. The first
# version still carries the TODO that the second version resolves, so it is
# presumably the removed (pre-change) side of the diff.
#
# --- Version 1 (presumably pre-change): builds the MLIR element type by hand.
@lowering.register_lowering_rule(broadcasted_iota_p)
def _broadcasted_iota_lowering(ctx: lowering.LoweringRuleContext, dtype, shape, dimension, layout):
del ctx
# Unsigned integers (as opposed to signless) cause MLIR verification
# errors so we only use signless like Mosaic GPU does.
#
# TODO(cperivol): use mgpu.utils.dtype_to_ir_type() instead.
# Integer dtypes become a signless MLIR integer of the same bit width; every
# other dtype goes through the generic `mlir.dtype_to_ir_type` conversion.
mlir_dtype = (
ir.IntegerType.get_signless(dtype.itemsize * 8)
if jnp.issubdtype(dtype, jnp.integer)
else mlir.dtype_to_ir_type(dtype)
)
# Placeholder value of the element type; it is passed to `splat` below.
undef = llvm_dialect.mlir_undef(mlir_dtype)
# Signedness is tracked separately because the MLIR integer type is signless.
# It is None for non-integer dtypes.
is_signed = (
jnp.issubdtype(dtype, jnp.signedinteger)
if jnp.issubdtype(dtype, jnp.integer)
else None
)

i32 = ir.IntegerType.get_signless(32)
# Converts an index value to the element type: float targets go
# index -> i32 -> float (via `uitofp`); all other targets use a direct
# `index_cast`.
def _cast(x):
if ir.FloatType.isinstance(mlir_dtype):
x = arith_dialect.index_cast(i32, x)
return arith_dialect.uitofp(mlir_dtype, x)
else:
return arith_dialect.index_cast(mlir_dtype, x)
# --- Version 2 (presumably post-change): delegates the dtype handling to the
# `mgpu_utils` helpers, as the TODO in version 1 requested.
@lowering.register_lowering_rule(broadcasted_iota_p)
def _broadcasted_iota_lowering(
ctx: lowering.LoweringRuleContext, dtype, shape, dimension, layout
):
del ctx # Unused.
mlir_dtype = mgpu_utils.dtype_to_ir_type(dtype)
# Same two conversion paths as `_cast` in version 1, expressed as lambdas;
# the i32 type is only created on the float path, where it is needed.
if ir.FloatType.isinstance(mlir_dtype):
i32 = ir.IntegerType.get_signless(32)
cast = lambda x: arith_dialect.uitofp(
mlir_dtype, arith_dialect.index_cast(i32, x)
)
else:
cast = lambda x: arith_dialect.index_cast(mlir_dtype, x)
is_signed = mgpu_utils.is_signed(dtype)
# The `return` expression below is shared by both versions. Inside each
# argument list the first line belongs to version 1 and the remaining lines
# to version 2. Either way, a placeholder value is splatted over `shape` with
# the requested layout, and `foreach` then maps every element to its
# (converted) index along `dimension`.
return mgpu.FragmentedArray.splat(
# Version 1 arguments (uses the `undef` local defined there):
undef, shape, layout.value, is_signed=is_signed
# Version 2 arguments (creates the undef placeholder inline):
llvm_dialect.mlir_undef(mlir_dtype),
shape,
layout.value,
is_signed=is_signed,
).foreach(
# Version 1 callback (uses the `_cast` helper):
lambda _, idx: _cast(idx[dimension]), create_array=True, is_signed=is_signed
# Version 2 callback (uses the `cast` lambda):
lambda _, idx: cast(idx[dimension]),
create_array=True,
is_signed=is_signed,
)


def broadcasted_iota(dtype, shape, dimension, *, layout: Layout | None = None):
  """Binds ``broadcasted_iota_p`` after normalizing ``dtype``.

  Args:
    dtype: Anything accepted by ``jnp.dtype``; it is converted before binding.
    shape: Shape forwarded unchanged to the primitive.
    dimension: Forwarded unchanged; the lowering rule indexes the per-element
      index tuple with it.
    layout: Optional layout forwarded unchanged to the primitive.

  Returns:
    The result of ``broadcasted_iota_p.bind``.
  """
  bind_params = {
      "dtype": jnp.dtype(dtype),
      "shape": shape,
      "dimension": dimension,
      "layout": layout,
  }
  return broadcasted_iota_p.bind(**bind_params)
def broadcasted_iota(
    dtype: jax.typing.DTypeLike,
    shape: Sequence[int],
    dimension: int,
    *,
    layout: Layout | None = None,
) -> jax.Array:
  """Binds ``broadcasted_iota_p`` after normalizing ``dtype``.

  Args:
    dtype: Anything accepted by ``jnp.dtype``; it is converted before binding.
    shape: Shape forwarded unchanged to the primitive.
    dimension: Forwarded unchanged; the lowering rule indexes the per-element
      index tuple with it.
    layout: Optional layout forwarded unchanged to the primitive.

  Returns:
    The result of ``broadcasted_iota_p.bind``.
  """
  bind_params = {
      "dtype": jnp.dtype(dtype),
      "shape": shape,
      "dimension": dimension,
      "layout": layout,
  }
  return broadcasted_iota_p.bind(**bind_params)

0 comments on commit 3ee576b

Please sign in to comment.