[Mosaic:TPU] Do some "free" offset changes with scratch retiling
PiperOrigin-RevId: 702557197
tlongeri authored and Google-ML-Automation committed Dec 4, 2024
1 parent 8c78c1e commit 77cbef4
Showing 1 changed file with 144 additions and 61 deletions.
205 changes: 144 additions & 61 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
@@ -5726,7 +5726,9 @@ LogicalResult retileToLargeTileWithScratch(
RewriteContext &ctx, OpBuilder &builder, const Location loc,
xla::Array<Value> &dst_tiles, const std::array<int64_t, 2> &dst_tile,
const xla::Array<Value> &src_tiles, const std::array<int64_t, 2> &src_tile,
TypedValue<MemRefType> scratch_ref) {
TypedValue<MemRefType> scratch_ref, const int64_t store_2nd_minor_offset,
const int64_t load_minor_offset) {
// TODO(DO NOT SUBMIT): Can we infer offsets from vreg array sizes?
if (dst_tile[0] % src_tile[0] != 0) {
return failure();
}
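One reading of the two new parameters, as they are used in the updated loops below and derived in retileWithScratch further down:
// - store_2nd_minor_offset (= *dst_offsets[0] / alignment_2nd_minor at the call
//   site) shifts where each group of source vregs is stored in scratch:
//   base_src_row_idx below subtracts it, in units of source vreg rows.
// - load_minor_offset (= *src_offsets[1] / alignment_minor at the call site)
//   shifts which destination vreg column a load targets: it is added to
//   dst_col_idx, in units of destination vreg columns.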
@@ -5830,7 +5832,7 @@ LogicalResult retileToLargeTileWithScratch(
SmallVector<int64_t, 4> src_idx(rank);
dst_tiles.Each([&](absl::Span<const int64_t> dst_idx, Value *dst_vreg) {
int64_t dst_row_idx = *(dst_idx.end() - 2);
int64_t dst_col_idx = *(dst_idx.end() - 1);
int64_t dst_col_idx = *(dst_idx.end() - 1) + load_minor_offset;
int64_t vreg_idx_in_group = dst_col_idx % vregs_per_group;
int64_t load_offset = sublanes_per_group * stored_group_cnt +
vreg_idx_in_group * sl_per_vreg * stride;
@@ -5841,16 +5843,21 @@
// the vregs from current group and now we need to store corresponding
// group of src vregs before actually emitting the loads.
if (vreg_idx_in_group == vregs_per_group - 1 ||
dst_col_idx == dst_tiles.dimensions().back() - 1) {
auto src_row_idx = dst_row_idx * vregs_per_group;
dst_idx.back() == dst_tiles.dimensions().back() - 1) {
auto base_src_row_idx =
dst_row_idx * vregs_per_group - store_2nd_minor_offset;
auto src_col_idx = dst_col_idx / vregs_per_group;
std::copy(dst_idx.begin(), dst_idx.end(), src_idx.begin());
for (int vi = 0; vi < vregs_per_group; ++vi) {
if (src_row_idx + vi >= src_tiles.dim(rank - 2) ||
const int64_t src_row_idx = base_src_row_idx + vi;
if (src_row_idx < 0) {
continue;
}
if (src_row_idx >= src_tiles.dim(rank - 2) ||
src_col_idx >= src_tiles.dim(rank - 1)) {
break;
}
*(src_idx.end() - 2) = src_row_idx + vi;
*(src_idx.end() - 2) = src_row_idx;
*(src_idx.end() - 1) = src_col_idx;
Value src_vreg = src_tiles(src_idx);
src_vreg =
@@ -5879,7 +5886,8 @@ LogicalResult retileToSmallTileWithScratch(
RewriteContext &ctx, OpBuilder &builder, const Location loc,
xla::Array<Value> &dst_tiles, const std::array<int64_t, 2> &dst_tile,
const xla::Array<Value> &src_tiles, const std::array<int64_t, 2> &src_tile,
TypedValue<MemRefType> scratch_ref) {
TypedValue<MemRefType> scratch_ref, const int64_t store_minor_offset,
const int64_t load_2nd_minor_offset) {
if (src_tile[0] % dst_tile[0] != 0) {
return failure();
}
@@ -6006,7 +6014,7 @@ LogicalResult retileToSmallTileWithScratch(
SmallVector<int64_t, 4> dst_idx(rank);
src_tiles.Each([&](absl::Span<const int64_t> src_idx, Value src_vreg) {
int64_t src_row_idx = *(src_idx.end() - 2);
int64_t src_col_idx = *(src_idx.end() - 1);
int64_t src_col_idx = *(src_idx.end() - 1) + store_minor_offset;
int64_t vreg_idx_in_group = src_col_idx % vregs_per_group;
src_vreg = builder.create<tpu::BitcastVregOp>(loc, temp_vreg_ty, src_vreg);
if (use_shuffled_load) {
@@ -6029,16 +6037,21 @@
// vregs' row, this indicates we have stored all the vregs needed to
// assemble a new group of dst vreg.
if (vreg_idx_in_group == vregs_per_group - 1 ||
src_col_idx == src_tiles.dimensions().back() - 1) {
auto dst_row_idx = src_row_idx * vregs_per_group;
src_idx.back() == src_tiles.dimensions().back() - 1) {
auto base_dst_row_idx =
src_row_idx * vregs_per_group - load_2nd_minor_offset;
auto dst_col_idx = src_col_idx / vregs_per_group;
std::copy(src_idx.begin(), src_idx.end(), dst_idx.begin());
for (int vi = 0; vi < vregs_per_group; ++vi) {
if (dst_row_idx + vi >= dst_tiles.dim(rank - 2) ||
const int64_t dst_row_idx = base_dst_row_idx + vi;
if (dst_row_idx < 0) {
continue;
}
if (dst_row_idx >= dst_tiles.dim(rank - 2) ||
dst_col_idx >= dst_tiles.dim(rank - 1)) {
break;
}
*(dst_idx.end() - 2) = dst_row_idx + vi;
*(dst_idx.end() - 2) = dst_row_idx;
*(dst_idx.end() - 1) = dst_col_idx;
Value *dst_vreg = &dst_tiles(dst_idx);
int64_t load_offset =
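// A worked trace of the bounds handling above (illustrative values): with
// vregs_per_group = 4 and load_2nd_minor_offset = 2, source vreg row 0 assembles
// destination vreg rows {0, 1} (vi = 0 and 1 would target rows -2 and -1 and are
// skipped), and source row 1 assembles rows {2, 3, 4, 5}, subject to the
// dst_tiles bounds.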
@@ -6067,14 +6080,16 @@ LogicalResult retileWithScratch(RewriteContext &ctx, OpBuilder &builder,
const Location loc,
xla::Array<Value> &dst_tiles,
const std::array<int64_t, 2> &dst_tiling,
const LayoutOffsets dst_offsets,
const xla::Array<Value> &src_tiles,
const std::array<int64_t, 2> &src_tiling,
int packing) {
const LayoutOffsets src_offsets, int packing) {
if (!(src_tiling[1] == ctx.target_shape[1] &&
dst_tiling[1] == ctx.target_shape[1] && src_tiling[0] % packing == 0 &&
dst_tiling[0] % packing == 0)) {
return failure();
}
const int bitwidth = 32 / packing;
// Try to get i32 vector scratch space, since we bitcast vregs to i32 vregs
// before using the scratch for retiling. This way we can also handle packed
// types.
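A minimal sketch of that bitcast (the vreg value and type names are illustrative; the op is the same tpu::BitcastVregOp used by the retiling helpers above):
// View a packed vreg (e.g. bf16 or s8) as an i32 vreg of shape target_shape,
// so a single i32 scratch buffer can serve every packed bitwidth.
const VectorType i32_vreg_ty =
    VectorType::get(ctx.target_shape, builder.getI32Type());
Value vreg_as_i32 =
    builder.create<tpu::BitcastVregOp>(loc, i32_vreg_ty, some_vreg);
// ... store vreg_as_i32 into the i32 scratch ref, reload at the retiled
// offsets, then bitcast back to the original vreg type.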
@@ -6089,15 +6104,35 @@ LogicalResult retileWithScratch(RewriteContext &ctx, OpBuilder &builder,
dst_tiling[1]};
std::array<int64_t, 2> vi32_src_tiling = {src_tiling[0] / packing,
src_tiling[1]};
// TODO(tlongeri): Think about replicated offsets
CHECK(dst_offsets[0].has_value());
CHECK(dst_offsets[1].has_value());
CHECK(src_offsets[0].has_value());
CHECK(src_offsets[1].has_value());
const std::array<int64_t, 2> src_vreg_slice =
VectorLayout::vregSlice(ctx.target_shape, bitwidth, src_tiling);
const std::array<int64_t, 2> dst_vreg_slice =
VectorLayout::vregSlice(ctx.target_shape, bitwidth, dst_tiling);
const int64_t alignment_2nd_minor =
std::min(src_vreg_slice[0], dst_vreg_slice[0]);
const int64_t alignment_minor =
std::min(src_vreg_slice[1], dst_vreg_slice[1]);
CHECK_EQ((*dst_offsets[0] - *src_offsets[0]) % alignment_2nd_minor, 0);
CHECK_EQ((*dst_offsets[1] - *src_offsets[1]) % alignment_minor, 0);
if (src_tiling[0] > dst_tiling[0]) {
return retileToSmallTileWithScratch(ctx, builder, loc, dst_tiles,
vi32_dst_tiling, src_tiles,
vi32_src_tiling, ref);
const int64_t store_minor_offset = *dst_offsets[1] / alignment_minor;
const int64_t load_2nd_minor_offset = *src_offsets[0] / alignment_2nd_minor;
return retileToSmallTileWithScratch(
ctx, builder, loc, dst_tiles, vi32_dst_tiling, src_tiles,
vi32_src_tiling, ref, store_minor_offset, load_2nd_minor_offset);
}
if (src_tiling[0] < dst_tiling[0]) {
return retileToLargeTileWithScratch(ctx, builder, loc, dst_tiles,
vi32_dst_tiling, src_tiles,
vi32_src_tiling, ref);
const int64_t store_2nd_minor_offset =
*dst_offsets[0] / alignment_2nd_minor;
const int64_t load_minor_offset = *src_offsets[1] / alignment_minor;
return retileToLargeTileWithScratch(
ctx, builder, loc, dst_tiles, vi32_dst_tiling, src_tiles,
vi32_src_tiling, ref, store_2nd_minor_offset, load_minor_offset);
}
dst_tiles = std::move(src_tiles);
return success();
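A worked instance of the new alignment bookkeeping, under the assumption that for 8-bit data a vreg covers an 8 x 512 element slice at (8, 128) tiling (four tiles packed per vreg) and a 32 x 128 slice at (32, 128) tiling; the offset values are illustrative:
// int8, src_tiling = (8, 128), dst_tiling = (32, 128), packing = 4:
//   src_vreg_slice = (8, 512), dst_vreg_slice = (32, 128)
//   alignment_2nd_minor = min(8, 32)    = 8
//   alignment_minor     = min(512, 128) = 128
// With src_offsets = (0, 128) and dst_offsets = (16, 0), both CHECKs pass
// ((16 - 0) % 8 == 0 and (0 - 128) % 128 == 0), and the large-tile branch gets
//   store_2nd_minor_offset = 16 / 8    = 2
//   load_minor_offset      = 128 / 128 = 1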
Expand All @@ -6106,7 +6141,8 @@ LogicalResult retileWithScratch(RewriteContext &ctx, OpBuilder &builder,
FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
RewriteContext &ctx, OpBuilder &builder, const Location loc, VectorType vty,
const VectorLayout src, xla::Array<Value> vregs,
const std::array<int64_t, 2> dst_tiling, bool try_replicate_rows) {
const std::array<int64_t, 2> dst_tiling,
const LayoutOffsets dst_offset_hints) {
bool has_enough_scratch = ctx.max_sublanes_in_scratch >=
ctx.target_shape[0] * (ctx.target_shape[0] + 1);
const auto &target_shape = ctx.target_shape;
@@ -6116,42 +6152,52 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
}
const int packing = src.packing();
const int8_t bitwidth = src.bitwidth();
VectorLayout dst(src.bitwidth(), src.offsets(), dst_tiling,
src.implicit_dim());
if (!dst.isValid(target_shape)) {
return emitError(loc, "Not implemented: invalid offsets in tiling target");
}
auto dst_tiles_shape =
dst.tileArrayImplicitShape(vty.getShape(), target_shape);
// Handle retiling from (1, 128) to (8, 128) for 32-bit data with replicating
// sublanes.
if (try_replicate_rows && packing == 1 &&
*(vregs.dimensions().end() - 2) == 1 &&
src.offsets() == LayoutOffsets{0, 0} &&
if (!dst_offset_hints[0].has_value() && packing == 1 &&
*(vty.getShape().end() - 2) == 1 &&
src.tiling() == std::array<int64_t, 2>{1, ctx.target_shape[1]} &&
dst_tiling == ctx.target_shape) {
xla::Array<Value> retiled(dst_tiles_shape);
DCHECK_EQ(src.offsets()[0].value_or(0), 0);
const LayoutOffset dst_minor_offset =
src.offsets()[1] ? LayoutOffset(*src.offsets()[1] % target_shape[1])
: std::nullopt;
const VectorLayout dst(bitwidth, {std::nullopt, dst_minor_offset},
dst_tiling, src.implicit_dim());
xla::Array<Value> retiled(
dst.tileArrayImplicitShape(vty.getShape(), target_shape));
retiled.Each([&](absl::Span<const int64_t> idx, Value *tile) {
SmallVector<int64_t> src_idx(idx.begin(), idx.end());
*(src_idx.end() - 2) *= target_shape[0];
*(src_idx.end() - 1) /= target_shape[0];
const int64_t src_sl_idx = *(idx.end() - 1) % target_shape[0];
CHECK_EQ(src.getImplicitTiledDims(vty.getShape(), 1)[0], 1);
*tile =
broadcastSublane(builder, vregs(src_idx), src_sl_idx, target_shape);
if (!src.offsets()[1].has_value()) {
// With (1, 128) tiling and replicated minor, each vreg is already fully
// replicated
*(src_idx.end() - 1) = 0;
*tile = vregs(src_idx);
} else {
// The column (in units of sublanes) of the sublane we want:
const int64_t sublane_column =
*(src_idx.end() - 1) + *src.offsets()[1] / target_shape[1];
*(src_idx.end() - 1) = sublane_column / target_shape[0];
const int64_t src_sl_idx = sublane_column % target_shape[0];
*tile =
broadcastSublane(builder, vregs(src_idx), src_sl_idx, target_shape);
}
});
// We have successfully replicated sublanes.
dst = VectorLayout(bitwidth, {std::nullopt, dst.offsets()[1]}, dst_tiling,
dst.implicit_dim());
// We have successfully replicated sublanes
return std::pair(dst, std::move(retiled));
}
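// A worked trace of the sublane arithmetic above (illustrative offset): with
// target_shape = (8, 128), src tiling (1, 128) and src.offsets()[1] = 300, the
// destination minor offset becomes 300 % 128 = 44. For destination vreg column
// j, the wanted sublane sits at sublane_column = j + 300 / 128 = j + 2, i.e.
// source vreg column (j + 2) / 8, sublane (j + 2) % 8, which broadcastSublane
// then replicates across the destination vreg.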
VectorLayout dst(src.bitwidth(), src.offsets(), dst_tiling,
src.implicit_dim());
// (8,128) -> (8 * packing,128) tiling change for packed type.
if (bitwidth < 32 && 32 % bitwidth == 0 && src_tiling == ctx.target_shape &&
if (dst.isValid(target_shape) && bitwidth < 32 && 32 % bitwidth == 0 &&
src_tiling == ctx.target_shape &&
dst_tiling == std::array<int64_t, 2>{ctx.target_shape[0] * dst.packing(),
ctx.target_shape[1]}) {
// Note: for int4, retiling with scratch is always faster.
if (bitwidth != 4 || !has_enough_scratch) {
xla::Array<Value> retiled(dst_tiles_shape);
xla::Array<Value> retiled(
dst.tileArrayImplicitShape(vty.getShape(), target_shape));
int vty_packing = dst.packing();
VectorType vreg_x32 =
vty.getElementType().isSignlessInteger()
@@ -6189,7 +6235,7 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
// interesting if the next step is a retile, since we can also
// match corresponding elements without shifting. It's just that
// the tiles are not adjacent (no contiguous vreg slice).
if (bitwidth < 32 && 32 % bitwidth == 0 &&
if (dst.isValid(target_shape) && bitwidth < 32 && 32 % bitwidth == 0 &&
src_tiling == std::array<int64_t, 2>{1, ctx.target_shape[1] * packing} &&
dst_tiling == std::array<int64_t, 2>{packing, ctx.target_shape[1]}) {
// To illustrate, consider a 2 x 16 16-bit shape laid out in vregs of
@@ -6232,7 +6278,8 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
// [(a b) (A B) (c d) (C D) ...]. That is, traverse down each column before
// moving to the next one. This is exactly an interleaving of the sublanes
// of the vreg parts.
xla::Array<Value> retiled(dst_tiles_shape);
xla::Array<Value> retiled(
dst.tileArrayImplicitShape(vty.getShape(), target_shape));
const VectorType vreg_x32 =
vty.getElementType().isSignlessInteger()
? VectorType::get(target_shape, builder.getI32Type())
@@ -6258,34 +6305,72 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
return std::pair(dst, std::move(retiled));
}
if (src_tiling[1] == target_shape[1] && dst_tiling[1] == target_shape[1]) {
// TODO(b/368088671): When sublane tiling changes, we should be able to
// preserve some replications from the source layout. But we need to
// make sure they are implemented efficiently and well-tested. For now, we
// just simply use 0 for the replicated offset after retiling.
dst = VectorLayout(
bitwidth, {src.offsets()[0].value_or(0), src.offsets()[1].value_or(0)},
dst_tiling, dst.implicit_dim());

DCHECK(llvm::isPowerOf2_32(src_tiling[0]));
DCHECK(llvm::isPowerOf2_32(dst_tiling[0]));
const std::array<int64_t, 2> src_vreg_slice =
VectorLayout::vregSlice(ctx.target_shape, bitwidth, src_tiling);
const std::array<int64_t, 2> dst_vreg_slice =
VectorLayout::vregSlice(ctx.target_shape, bitwidth, dst_tiling);
// All clauses in the && expression are based on performance benchmarking.
bool use_alu = !has_enough_scratch ||
(ctx.hardware_generation >= 5 && src_tiling[0] != packing &&
dst_tiling[0] != packing);

if (use_alu) {
if (src_tiling[0] > dst_tiling[0]) {
return std::pair(
dst, retileToReducedSublanes(builder, vty.getShape(), src, vregs,
dst, target_shape));
if (src_tiling[0] > dst_tiling[0] &&
src.offsets()[0].value_or(0) < dst_vreg_slice[0] &&
src.offsets()[1].value_or(0) < dst_vreg_slice[1]) {
// retileToReducedSublanes does not support offset changes
return std::pair(dst, retileToReducedSublanes(
builder, vty.getShape(), src, vregs,
VectorLayout(bitwidth,
{src.offsets()[0].value_or(0),
src.offsets()[1].value_or(0)},
dst_tiling, dst.implicit_dim()),
target_shape));
} else if (!has_enough_scratch) {
// TODO(b/357538782): Implement retileToIncreasedSublanes with ALU ops.
return emitError(
loc,
"Not implemented: retiling to increase sublane tiling with ALU");
}
}
xla::Array<Value> retiled(dst_tiles_shape);
if (failed(retileWithScratch(ctx, builder, loc, retiled, dst_tiling, vregs,
src_tiling, packing))) {

const int64_t alignment_2nd_minor =
std::min(src_vreg_slice[0], dst_vreg_slice[0]);
const int64_t alignment_minor =
std::min(src_vreg_slice[1], dst_vreg_slice[1]);
// TODO(b/368088671): When sublane tiling changes, we should be able to
// preserve some replications from the source layout. But we need to
// make sure they are implemented efficiently and well-tested. For now, we
// just simply use 0 for the replicated offset after retiling.
LayoutOffset dst_offset_0 =
src.offsets()[0].value_or(0) % alignment_2nd_minor;
LayoutOffset dst_offset_1 = src.offsets()[1].value_or(0) % alignment_minor;
if (dst_offset_hints[0].has_value() &&
(*dst_offset_hints[0] - src.offsets()[0].value_or(0)) %
alignment_2nd_minor ==
0) {
dst_offset_0 = dst_offset_hints[0];
}
if (dst_offset_hints[1].has_value() &&
(*dst_offset_hints[1] - src.offsets()[1].value_or(0)) %
alignment_minor ==
0) {
dst_offset_1 = dst_offset_hints[1];
}
const LayoutOffsets dst_offsets{dst_offset_0, dst_offset_1};
dst = VectorLayout(bitwidth, dst_offsets, dst_tiling, dst.implicit_dim());
TPU_ASSERT_LOC(loc, dst.isValid(target_shape));

xla::Array<Value> retiled(
dst.tileArrayImplicitShape(vty.getShape(), target_shape));
if (failed(retileWithScratch(ctx, builder, loc, retiled, dst_tiling,
dst_offsets, vregs, src_tiling,
/*src_offsets=*/
LayoutOffsets{src.offsets()[0].value_or(0),
src.offsets()[1].value_or(0)},
packing))) {
return failure();
}
return std::pair(dst, std::move(retiled));
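The hint handling above amounts to a small rule: honor the requested offset when it differs from the source offset by a multiple of the shared vreg-slice alignment (that difference is what scratch retiling provides for free), and otherwise keep the source offset reduced modulo that alignment. A standalone sketch of the rule, with a hypothetical helper name:

#include <cstdint>
#include <optional>

// Picks the destination offset along one dimension, restating the hint logic
// in changeTiling: a hint reachable "for free" (congruent to the source offset
// modulo the shared vreg-slice alignment) is honored; otherwise the source
// offset is kept, reduced into the destination's alignment window.
int64_t pickDstOffset(std::optional<int64_t> hint,
                      std::optional<int64_t> src_offset, int64_t alignment) {
  const int64_t src = src_offset.value_or(0);  // replicated offsets fall back to 0
  if (hint.has_value() && (*hint - src) % alignment == 0) {
    return *hint;
  }
  return src % alignment;
}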
@@ -6523,9 +6608,7 @@ FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx,
FAILUREOR_ASSIGN_OR_RETURN(
std::tie(src, src_tiles),
changeTiling(ctx, builder, v.getLoc(), vty, src, std::move(src_tiles),
dst.tiling(),
dst.offsets()[0] == std::nullopt &&
src.offsets()[0] != std::nullopt));
dst.tiling(), dst.offsets()));

FAILUREOR_ASSIGN_OR_RETURN(
std::tie(src, src_tiles),