Skip to content

Commit

Permalink
[Mosaic:TPU] In infer ext rule, avoid assigning offsets outside of ds…
Browse files Browse the repository at this point in the history
…t first tile

Note that offsets outside of first tile are still disabled (for both infer and apply), and once we support it we will want to assign offsets differently, this is mostly to avoid assigning invalid layouts (that may not just be outside the first tile, but outside the vreg slice)

PiperOrigin-RevId: 707326013
  • Loading branch information
tlongeri authored and Google-ML-Automation committed Dec 18, 2024
1 parent cca9afa commit 9222b12
Showing 1 changed file with 19 additions and 7 deletions.
26 changes: 19 additions & 7 deletions jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1668,10 +1668,17 @@ class VectorLayoutInferer {
Layout dst_layout;
if (layout.tiling() == nativeTiling(src_bitwidth)) {
// If the source is already in native tiling, we can unpack it directly.
src_layout = layout;
std::array<int64_t, 2> dst_native_tiling = nativeTiling(dst_bitwidth);
LayoutOffsets offsets = {layout.offsets()[0]
? *layout.offsets()[0] % dst_native_tiling[0]
: LayoutOffset(),
layout.offsets()[1]};
DCHECK_LT(offsets[1].value_or(0), dst_native_tiling[1]);
src_layout = VectorLayout(src_bitwidth, offsets, layout.tiling(),
layout.implicit_dim());
dst_layout =
VectorLayout(dst_bitwidth, layout.offsets(),
nativeTiling(dst_bitwidth), layout.implicit_dim());
VectorLayout(dst_bitwidth, offsets, nativeTiling(dst_bitwidth),
layout.implicit_dim());
} else if (dst_bitwidth == 32 &&
default_tiling_[0] % layout.tiling()[0] == 0 &&
default_tiling_[1] == layout.tiling()[1]) {
Expand All @@ -1680,13 +1687,18 @@ class VectorLayoutInferer {
// tiling through the op.
// TODO(jevinjiang): we can relax this for non-32bit as well.
src_layout = layout;
dst_layout = VectorLayout(32, layout.offsets(), src_layout->tiling(),
layout.implicit_dim());
dst_layout = VectorLayout(dst_bitwidth, layout.offsets(),
src_layout->tiling(), layout.implicit_dim());
} else {
// TODO(b/335863273): we should also reduce offsets.
src_layout = VectorLayout(src_bitwidth, layout.offsets(), default_tiling_,
LayoutOffsets offsets = {
layout.offsets()[0] ? *layout.offsets()[0] % default_tiling_[0]
: LayoutOffset(),
layout.offsets()[1] ? *layout.offsets()[1] % default_tiling_[1]
: LayoutOffset()};
src_layout = VectorLayout(src_bitwidth, offsets, default_tiling_,
layout.implicit_dim());
dst_layout = VectorLayout(dst_bitwidth, layout.offsets(), default_tiling_,
dst_layout = VectorLayout(dst_bitwidth, offsets, default_tiling_,
layout.implicit_dim());
}
setLayout(op, src_layout, dst_layout);
Expand Down

0 comments on commit 9222b12

Please sign in to comment.