Skip to content

Commit

Permalink
[Mosaic:TPU] Replicating retilings with increasing tile size for (a) …
Browse files Browse the repository at this point in the history
…replicated 2nd minor or (b) 32-bit single-row

This is a generalization of the (1, 128) -> (8, 128) 32-bit replicated retiling

PiperOrigin-RevId: 703704106
  • Loading branch information
tlongeri authored and Google-ML-Automation committed Dec 14, 2024
1 parent 21d1d44 commit d1d7634
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 31 deletions.
11 changes: 11 additions & 0 deletions jaxlib/mosaic/dialect/tpu/layout.h
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,17 @@ class VectorLayout {

int8_t bitwidth() const { return bitwidth_; }
const LayoutOffsets &offsets() const { return offsets_; }
const LayoutOffsets getCanonicalOffsets(
const ArrayRef<int64_t> shape,
const std::array<int64_t, 2> target_shape) const {
// For (1, n) tiling with a single row, 2nd minor replication does not
// change anything about the layout - it is equivalent to an offset of 0.
// We choose a replicated offset as "canonical".
const std::array<int64_t, 2> tiled_ishape = getImplicitTiledDims(shape, 1);
return {
(tiling_[0] == 1 && tiled_ishape[0] == 1) ? std::nullopt : offsets_[0],
offsets_[1]};
}
const std::array<int64_t, 2> &tiling() const { return tiling_; }
ImplicitDim implicit_dim() const { return implicit_dim_; }
int packing() const { return 32 / bitwidth_; }
Expand Down
106 changes: 75 additions & 31 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6164,42 +6164,83 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
if (src_tiling == dst_tiling) {
return std::pair(src, std::move(vregs));
}
const LayoutOffsets src_offsets =
src.getCanonicalOffsets(vty.getShape(), ctx.target_shape);
const std::array<int64_t, 2> tiled_ishape =
src.getImplicitTiledDims(vty.getShape(), 1);
const int packing = src.packing();
const int8_t bitwidth = src.bitwidth();
// Handle retiling from (1, 128) to (8, 128) for 32-bit data with replicating
// sublanes.
if (try_replicate_rows && packing == 1 &&
*(vregs.dimensions().end() - 2) == 1 &&
src.tiling() == std::array<int64_t, 2>{1, ctx.target_shape[1]} &&
dst_tiling == ctx.target_shape) {
DCHECK_EQ(src.offsets()[0].value_or(0), 0);
const std::array<int64_t, 2> dst_vreg_slice =
VectorLayout::vregSlice(ctx.target_shape, bitwidth, dst_tiling);

// Fully replicated offsets are handled efficiently elsewhere (in relayout)
CHECK(src.offsets()[0].has_value() || src.offsets()[1].has_value());

// Handle replicating small-to-large retiling for (a) replicated 2nd minor or
// (b) 32-bit single-row.
// This retiling is one-to-many vregs.
// TODO(tlongeri): Large-to-small retiling with replicated minor is analogous
// to this.
if (src_tiling[1] == ctx.target_shape[1] &&
dst_tiling[1] == ctx.target_shape[1] &&
dst_tiling[0] % src_tiling[0] == 0 &&
(!src_offsets[0].has_value() || (packing == 1 && tiled_ishape[0] == 1)) &&
// This relayout relies on gathers, which are cheap on newer generations,
// so we always use it for them.
// TODO(tlongeri): Once we have it, probably also prefer the
// small-to-large rotate+blend relayout if we don't need replication. It's
// slightly cheaper for some dst vregs you rotate by 0.
// TODO(tlongeri): Using store + multiple replicated loads is good on
// older gens. I wonder if we can integrate this logic to scratch retiling
(try_replicate_rows || ctx.hardware_generation >= 5)) {
const LayoutOffset dst_minor_offset =
src.offsets()[1] ? LayoutOffset(*src.offsets()[1] % target_shape[1])
: std::nullopt;
src.offsets()[1].has_value() ? *src.offsets()[1] % dst_vreg_slice[1]
: LayoutOffset();
const VectorLayout dst(bitwidth, {std::nullopt, dst_minor_offset},
dst_tiling, src.implicit_dim());
xla::Array<Value> retiled(
dst.tileArrayImplicitShape(vty.getShape(), target_shape));
retiled.Each([&](absl::Span<const int64_t> idx, Value *tile) {
SmallVector<int64_t> src_idx(idx.begin(), idx.end());
*(src_idx.end() - 2) *= target_shape[0];
if (!src.offsets()[1].has_value()) {
// With (1, 128) tiling each vreg holds values from a single row. This
// means that if the columns are replicated, then the whole vreg is
// already replicated.
*(src_idx.end() - 1) = 0;
*tile = vregs(src_idx);
} else {
// The column (in units of sublanes) of the sublane we want:
const int64_t sublane_column =
*(src_idx.end() - 1) + *src.offsets()[1] / target_shape[1];
*(src_idx.end() - 1) = sublane_column / target_shape[0];
const int64_t src_sl_idx = sublane_column % target_shape[0];
*tile =
broadcastSublane(builder, vregs(src_idx), src_sl_idx, target_shape);
const SmallVector<int64_t> dst_vreg_array_shape =
dst.tileArrayImplicitShape(vty.getShape(), target_shape);
const int64_t src_tiles_per_vreg = src.tilesPerVreg(ctx.target_shape);
const int64_t dst_tiles_per_vreg = dst.tilesPerVreg(ctx.target_shape);
const int64_t src_sublanes_per_tile = src.sublanesPerTile(ctx.target_shape);
const int64_t dst_sublanes_per_tile = dst.sublanesPerTile(ctx.target_shape);
xla::Array<Value> retiled(dst_vreg_array_shape);
SmallVector<int64_t> idxs;
retiled.Each([&](absl::Span<const int64_t> dst_idx, Value *vreg) {
const int64_t dst_col_idx = *(dst_idx.end() - 1);
const int64_t base_dst_tile_idx = dst_col_idx * dst_tiles_per_vreg;
const int64_t base_src_tile_idx =
src_offsets[1].has_value()
? base_dst_tile_idx +
(*src_offsets[1] - *dst_minor_offset) / src_tiling[1]
: 0;
// The following should be true from our choice of minor offset:
DCHECK_EQ(base_src_tile_idx % dst_tiles_per_vreg, 0);
const int64_t src_col_idx = base_src_tile_idx / src_tiles_per_vreg;
SmallVector<int32_t, 8> gather_pattern;
// Iterate over the sublanes in the dst vreg:
for (int32_t sublane = 0; sublane < ctx.target_shape[0]; ++sublane) {
const int64_t dst_tile_idx_in_vreg = sublane / dst_sublanes_per_tile;
const int64_t src_tile_idx_in_vreg =
base_src_tile_idx % src_tiles_per_vreg + dst_tile_idx_in_vreg;
// Although replication may give us several sublanes to choose from,
// we always gather from the first sublane in the source tile. This
// degenerates to a broadcast when dst_tiling is native, which can
// be cheaper than an arbitrary gather (for some hardware gens).
const int64_t src_sublane_in_tile =
src_offsets[0].value_or(0) / packing;
const int64_t src_sublane =
src_tile_idx_in_vreg * src_sublanes_per_tile + src_sublane_in_tile;
gather_pattern.push_back(src_sublane);
}
idxs.assign(dst_idx.begin(), dst_idx.end());
*(idxs.end() - 2) = 0;
*(idxs.end() - 1) = src_col_idx;
Value src_vreg = vregs(idxs);
*vreg = builder.create<tpu::GatherOp>(loc, src_vreg.getType(), src_vreg,
gather_pattern,
/*dimension=*/0);
});
// We have successfully replicated sublanes
return std::pair(dst, std::move(retiled));
}
VectorLayout dst(src.bitwidth(), src.offsets(), dst_tiling,
Expand Down Expand Up @@ -6576,8 +6617,11 @@ FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx,
return assemble_with_mask_check(src_tiles,
/*use_implicit_shape=*/true);
}
if (src.layout_rank() >= dst.layout_rank() && !src.offsets()[0].has_value() &&
!src.offsets()[1].has_value()) {

if (const LayoutOffsets src_offsets =
src.getCanonicalOffsets(vty.getShape(), ctx.target_shape);
src.layout_rank() >= dst.layout_rank() && !src_offsets[0].has_value() &&
!src_offsets[1].has_value()) {
// A fully replicated value is always easy to relayout
xla::Array<Value> dst_tiles(
dst.tileArrayImplicitShape(vty.getShape(), target_shape));
Expand Down

0 comments on commit d1d7634

Please sign in to comment.