diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
index cac197479f4a..aa1df2ee47e2 100644
--- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
+++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
@@ -5726,7 +5726,9 @@ LogicalResult retileToLargeTileWithScratch(
     RewriteContext &ctx, OpBuilder &builder, const Location loc,
     xla::Array<Value> &dst_tiles, const std::array<int64_t, 2> &dst_tile,
     const xla::Array<Value> &src_tiles, const std::array<int64_t, 2> &src_tile,
-    TypedValue<MemRefType> scratch_ref) {
+    TypedValue<MemRefType> scratch_ref, const int64_t store_2nd_minor_offset,
+    const int64_t load_minor_offset) {
+  // TODO(DO NOT SUBMIT): Can we infer offsets from vreg array sizes?
   if (dst_tile[0] % src_tile[0] != 0) {
     return failure();
   }
@@ -5830,7 +5832,7 @@
   SmallVector<int64_t> src_idx(rank);
   dst_tiles.Each([&](absl::Span<const int64_t> dst_idx, Value *dst_vreg) {
     int64_t dst_row_idx = *(dst_idx.end() - 2);
-    int64_t dst_col_idx = *(dst_idx.end() - 1);
+    int64_t dst_col_idx = *(dst_idx.end() - 1) + load_minor_offset;
     int64_t vreg_idx_in_group = dst_col_idx % vregs_per_group;
     int64_t load_offset = sublanes_per_group * stored_group_cnt +
                           vreg_idx_in_group * sl_per_vreg * stride;
@@ -5841,16 +5843,21 @@
     // the vregs from current group and now we need to store corresponding
     // group of src vregs before actually emitting the loads.
     if (vreg_idx_in_group == vregs_per_group - 1 ||
-        dst_col_idx == dst_tiles.dimensions().back() - 1) {
-      auto src_row_idx = dst_row_idx * vregs_per_group;
+        dst_idx.back() == dst_tiles.dimensions().back() - 1) {
+      auto base_src_row_idx =
+          dst_row_idx * vregs_per_group - store_2nd_minor_offset;
       auto src_col_idx = dst_col_idx / vregs_per_group;
       std::copy(dst_idx.begin(), dst_idx.end(), src_idx.begin());
       for (int vi = 0; vi < vregs_per_group; ++vi) {
-        if (src_row_idx + vi >= src_tiles.dim(rank - 2) ||
+        const int64_t src_row_idx = base_src_row_idx + vi;
+        if (src_row_idx < 0) {
+          continue;
+        }
+        if (src_row_idx >= src_tiles.dim(rank - 2) ||
             src_col_idx >= src_tiles.dim(rank - 1)) {
           break;
         }
-        *(src_idx.end() - 2) = src_row_idx + vi;
+        *(src_idx.end() - 2) = src_row_idx;
         *(src_idx.end() - 1) = src_col_idx;
         Value src_vreg = src_tiles(src_idx);
         src_vreg =
@@ -5879,7 +5886,8 @@ LogicalResult retileToSmallTileWithScratch(
     RewriteContext &ctx, OpBuilder &builder, const Location loc,
     xla::Array<Value> &dst_tiles, const std::array<int64_t, 2> &dst_tile,
     const xla::Array<Value> &src_tiles, const std::array<int64_t, 2> &src_tile,
-    TypedValue<MemRefType> scratch_ref) {
+    TypedValue<MemRefType> scratch_ref, const int64_t store_minor_offset,
+    const int64_t load_2nd_minor_offset) {
   if (src_tile[0] % dst_tile[0] != 0) {
     return failure();
   }
@@ -6006,7 +6014,7 @@
   SmallVector<int64_t> dst_idx(rank);
   src_tiles.Each([&](absl::Span<const int64_t> src_idx, Value src_vreg) {
     int64_t src_row_idx = *(src_idx.end() - 2);
-    int64_t src_col_idx = *(src_idx.end() - 1);
+    int64_t src_col_idx = *(src_idx.end() - 1) + store_minor_offset;
     int64_t vreg_idx_in_group = src_col_idx % vregs_per_group;
     src_vreg = builder.create<tpu::BitcastVregOp>(loc, temp_vreg_ty, src_vreg);
     if (use_shuffled_load) {
@@ -6029,16 +6037,21 @@
     // vregs' row, this indicates we have stored all the vregs needed to
     // assemble a new group of dst vreg.
     if (vreg_idx_in_group == vregs_per_group - 1 ||
-        src_col_idx == src_tiles.dimensions().back() - 1) {
-      auto dst_row_idx = src_row_idx * vregs_per_group;
+        src_idx.back() == src_tiles.dimensions().back() - 1) {
+      auto base_dst_row_idx =
+          src_row_idx * vregs_per_group - load_2nd_minor_offset;
       auto dst_col_idx = src_col_idx / vregs_per_group;
       std::copy(src_idx.begin(), src_idx.end(), dst_idx.begin());
       for (int vi = 0; vi < vregs_per_group; ++vi) {
-        if (dst_row_idx + vi >= dst_tiles.dim(rank - 2) ||
+        const int64_t dst_row_idx = base_dst_row_idx + vi;
+        if (dst_row_idx < 0) {
+          continue;
+        }
+        if (dst_row_idx >= dst_tiles.dim(rank - 2) ||
             dst_col_idx >= dst_tiles.dim(rank - 1)) {
           break;
         }
-        *(dst_idx.end() - 2) = dst_row_idx + vi;
+        *(dst_idx.end() - 2) = dst_row_idx;
         *(dst_idx.end() - 1) = dst_col_idx;
         Value *dst_vreg = &dst_tiles(dst_idx);
         int64_t load_offset =
@@ -6067,14 +6080,16 @@ LogicalResult retileWithScratch(RewriteContext &ctx, OpBuilder &builder,
                                 const Location loc,
                                 xla::Array<Value> &dst_tiles,
                                 const std::array<int64_t, 2> &dst_tiling,
+                                const LayoutOffsets dst_offsets,
                                 const xla::Array<Value> &src_tiles,
                                 const std::array<int64_t, 2> &src_tiling,
-                                int packing) {
+                                const LayoutOffsets src_offsets, int packing) {
   if (!(src_tiling[1] == ctx.target_shape[1] &&
         dst_tiling[1] == ctx.target_shape[1] && src_tiling[0] % packing == 0 &&
         dst_tiling[0] % packing == 0)) {
     return failure();
   }
+  const int bitwidth = 32 / packing;
   // Try to get i32 vector scratch space. Because we will bitcast vregs to
   // i32 vregs before using scratch for retiling. Through this way we can
   // handle packed types as well.
@@ -6089,15 +6104,35 @@ LogicalResult retileWithScratch(RewriteContext &ctx, OpBuilder &builder,
                                             dst_tiling[1]};
   std::array<int64_t, 2> vi32_src_tiling = {src_tiling[0] / packing,
                                             src_tiling[1]};
+  // TODO(tlongeri): Think about replicated offsets
+  CHECK(dst_offsets[0].has_value());
+  CHECK(dst_offsets[1].has_value());
+  CHECK(src_offsets[0].has_value());
+  CHECK(src_offsets[1].has_value());
+  const std::array<int64_t, 2> src_vreg_slice =
+      VectorLayout::vregSlice(ctx.target_shape, bitwidth, src_tiling);
+  const std::array<int64_t, 2> dst_vreg_slice =
+      VectorLayout::vregSlice(ctx.target_shape, bitwidth, dst_tiling);
+  const int64_t alignment_2nd_minor =
+      std::min(src_vreg_slice[0], dst_vreg_slice[0]);
+  const int64_t alignment_minor =
+      std::min(src_vreg_slice[1], dst_vreg_slice[1]);
+  CHECK_EQ((*dst_offsets[0] - *src_offsets[0]) % alignment_2nd_minor, 0);
+  CHECK_EQ((*dst_offsets[1] - *src_offsets[1]) % alignment_minor, 0);
   if (src_tiling[0] > dst_tiling[0]) {
-    return retileToSmallTileWithScratch(ctx, builder, loc, dst_tiles,
-                                        vi32_dst_tiling, src_tiles,
-                                        vi32_src_tiling, ref);
+    const int64_t store_minor_offset = *dst_offsets[1] / alignment_minor;
+    const int64_t load_2nd_minor_offset = *src_offsets[0] / alignment_2nd_minor;
+    return retileToSmallTileWithScratch(
+        ctx, builder, loc, dst_tiles, vi32_dst_tiling, src_tiles,
+        vi32_src_tiling, ref, store_minor_offset, load_2nd_minor_offset);
   }
   if (src_tiling[0] < dst_tiling[0]) {
-    return retileToLargeTileWithScratch(ctx, builder, loc, dst_tiles,
-                                        vi32_dst_tiling, src_tiles,
-                                        vi32_src_tiling, ref);
+    const int64_t store_2nd_minor_offset =
+        *dst_offsets[0] / alignment_2nd_minor;
+    const int64_t load_minor_offset = *src_offsets[1] / alignment_minor;
+    return retileToLargeTileWithScratch(
+        ctx, builder, loc, dst_tiles, vi32_dst_tiling, src_tiles,
+        vi32_src_tiling, ref, store_2nd_minor_offset, load_minor_offset);
   }
   dst_tiles = std::move(src_tiles);
   return success();
@@ -6106,7 +6141,8 @@ LogicalResult retileWithScratch(RewriteContext &ctx, OpBuilder &builder,
 FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
     RewriteContext &ctx, OpBuilder &builder, const Location loc,
     VectorType vty, const VectorLayout src, xla::Array<Value> vregs,
-    const std::array<int64_t, 2> dst_tiling, bool try_replicate_rows) {
+    const std::array<int64_t, 2> dst_tiling,
+    const LayoutOffsets dst_offset_hints) {
   bool has_enough_scratch = ctx.max_sublanes_in_scratch >=
                             ctx.target_shape[0] * (ctx.target_shape[0] + 1);
   const auto &target_shape = ctx.target_shape;
@@ -6116,42 +6152,52 @@
   }
   const int packing = src.packing();
   const int8_t bitwidth = src.bitwidth();
-  VectorLayout dst(src.bitwidth(), src.offsets(), dst_tiling,
-                   src.implicit_dim());
-  if (!dst.isValid(target_shape)) {
-    return emitError(loc, "Not implemented: invalid offsets in tiling target");
-  }
-  auto dst_tiles_shape =
-      dst.tileArrayImplicitShape(vty.getShape(), target_shape);
   // Handle retiling from (1, 128) to (8, 128) for 32-bit data with replicating
   // sublanes.
-  if (try_replicate_rows && packing == 1 &&
-      *(vregs.dimensions().end() - 2) == 1 &&
-      src.offsets() == LayoutOffsets{0, 0} &&
+  if (!dst_offset_hints[0].has_value() && packing == 1 &&
+      *(vty.getShape().end() - 2) == 1 &&
       src.tiling() == std::array<int64_t, 2>{1, ctx.target_shape[1]} &&
       dst_tiling == ctx.target_shape) {
-    xla::Array<Value> retiled(dst_tiles_shape);
+    DCHECK_EQ(src.offsets()[0].value_or(0), 0);
+    const LayoutOffset dst_minor_offset =
+        src.offsets()[1] ? LayoutOffset(*src.offsets()[1] % target_shape[1])
+                         : std::nullopt;
+    const VectorLayout dst(bitwidth, {std::nullopt, dst_minor_offset},
+                           dst_tiling, src.implicit_dim());
+    xla::Array<Value> retiled(
+        dst.tileArrayImplicitShape(vty.getShape(), target_shape));
     retiled.Each([&](absl::Span<const int64_t> idx, Value *tile) {
       SmallVector<int64_t> src_idx(idx.begin(), idx.end());
       *(src_idx.end() - 2) *= target_shape[0];
-      *(src_idx.end() - 1) /= target_shape[0];
-      const int64_t src_sl_idx = *(idx.end() - 1) % target_shape[0];
-      CHECK_EQ(src.getImplicitTiledDims(vty.getShape(), 1)[0], 1);
-      *tile =
-          broadcastSublane(builder, vregs(src_idx), src_sl_idx, target_shape);
+      if (!src.offsets()[1].has_value()) {
+        // With (1, 128) tiling and replicated minor, each vreg is already fully
+        // replicated
+        *(src_idx.end() - 1) = 0;
+        *tile = vregs(src_idx);
+      } else {
+        // The column (in units of sublanes) of the sublane we want:
+        const int64_t sublane_column =
+            *(src_idx.end() - 1) + *src.offsets()[1] / target_shape[1];
+        *(src_idx.end() - 1) = sublane_column / target_shape[0];
+        const int64_t src_sl_idx = sublane_column % target_shape[0];
+        *tile =
+            broadcastSublane(builder, vregs(src_idx), src_sl_idx, target_shape);
+      }
     });
-    // We have successfully replicated sublanes.
-    dst = VectorLayout(bitwidth, {std::nullopt, dst.offsets()[1]}, dst_tiling,
-                       dst.implicit_dim());
+    // We have successfully replicated sublanes
    return std::pair(dst, std::move(retiled));
  }
+  VectorLayout dst(src.bitwidth(), src.offsets(), dst_tiling,
+                   src.implicit_dim());
   // (8,128) -> (8 * packing,128) tiling change for packed type.
-  if (bitwidth < 32 && 32 % bitwidth == 0 && src_tiling == ctx.target_shape &&
+  if (dst.isValid(target_shape) && bitwidth < 32 && 32 % bitwidth == 0 &&
+      src_tiling == ctx.target_shape &&
       dst_tiling == std::array<int64_t, 2>{ctx.target_shape[0] * dst.packing(),
                                            ctx.target_shape[1]}) {
     // Note: for int4, retiling with scratch is always faster.
     if (bitwidth != 4 || !has_enough_scratch) {
-      xla::Array<Value> retiled(dst_tiles_shape);
+      xla::Array<Value> retiled(
+          dst.tileArrayImplicitShape(vty.getShape(), target_shape));
       int vty_packing = dst.packing();
       VectorType vreg_x32 =
           vty.getElementType().isSignlessInteger()
              ? VectorType::get(target_shape, builder.getI32Type())
@@ -6189,7 +6235,7 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
   // interesting if the next step is a retile, since we can also
   // match corresponding elements without shifting. It's just that
   // the tiles are not adjacent (no contiguous vreg slice).
-  if (bitwidth < 32 && 32 % bitwidth == 0 &&
+  if (dst.isValid(target_shape) && bitwidth < 32 && 32 % bitwidth == 0 &&
      src_tiling == std::array<int64_t, 2>{1, ctx.target_shape[1] * packing} &&
      dst_tiling == std::array<int64_t, 2>{packing, ctx.target_shape[1]}) {
    // To illustrate, consider a 2 x 16 16-bit shape laid out in vregs of
@@ -6232,7 +6278,8 @@
   // [(a b) (A B) (c d) (C D) ...]. That is, traverse down each column before
   // moving to the next one. This is exactly an interleaving of the sublanes
   // of the vreg parts.
-  xla::Array<Value> retiled(dst_tiles_shape);
+  xla::Array<Value> retiled(
+      dst.tileArrayImplicitShape(vty.getShape(), target_shape));
   const VectorType vreg_x32 =
       vty.getElementType().isSignlessInteger()
           ? VectorType::get(target_shape, builder.getI32Type())
@@ -6258,24 +6305,29 @@
     return std::pair(dst, std::move(retiled));
   }
   if (src_tiling[1] == target_shape[1] && dst_tiling[1] == target_shape[1]) {
-    // TODO(b/368088671): When sublane tiling changes, we should be able to
-    // preserve some replications from the source layout. But we need to
-    // make sure they are implemented efficiently and well-tested. For now, we
-    // just simply use 0 for the replicated offset after retiling.
-    dst = VectorLayout(
-        bitwidth, {src.offsets()[0].value_or(0), src.offsets()[1].value_or(0)},
-        dst_tiling, dst.implicit_dim());
-
+    DCHECK(llvm::isPowerOf2_32(src_tiling[0]));
+    DCHECK(llvm::isPowerOf2_32(dst_tiling[0]));
+    const std::array<int64_t, 2> src_vreg_slice =
+        VectorLayout::vregSlice(ctx.target_shape, bitwidth, src_tiling);
+    const std::array<int64_t, 2> dst_vreg_slice =
+        VectorLayout::vregSlice(ctx.target_shape, bitwidth, dst_tiling);
     // All clauses in the and expression are based on performance benchmarking.
     bool use_alu = !has_enough_scratch ||
                    (ctx.hardware_generation >= 5 && src_tiling[0] != packing &&
                     dst_tiling[0] != packing);
     if (use_alu) {
-      if (src_tiling[0] > dst_tiling[0]) {
-        return std::pair(
-            dst, retileToReducedSublanes(builder, vty.getShape(), src, vregs,
-                                         dst, target_shape));
+      if (src_tiling[0] > dst_tiling[0] &&
+          src.offsets()[0].value_or(0) < dst_vreg_slice[0] &&
+          src.offsets()[1].value_or(0) < dst_vreg_slice[1]) {
+        // retileToReducedSublanes does not support offset changes
+        return std::pair(dst, retileToReducedSublanes(
+                                  builder, vty.getShape(), src, vregs,
+                                  VectorLayout(bitwidth,
+                                               {src.offsets()[0].value_or(0),
+                                                src.offsets()[1].value_or(0)},
+                                               dst_tiling, dst.implicit_dim()),
+                                  target_shape));
       } else if (!has_enough_scratch) {
         // TODO(b/357538782): Implement retileToIncreasedSublanes with ALU ops.
         return emitError(
@@ -6283,9 +6335,42 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
             "Not implemented: retiling to increase sublane tiling with ALU");
       }
     }
-    xla::Array<Value> retiled(dst_tiles_shape);
-    if (failed(retileWithScratch(ctx, builder, loc, retiled, dst_tiling, vregs,
-                                 src_tiling, packing))) {
+
+    const int64_t alignment_2nd_minor =
+        std::min(src_vreg_slice[0], dst_vreg_slice[0]);
+    const int64_t alignment_minor =
+        std::min(src_vreg_slice[1], dst_vreg_slice[1]);
+    // TODO(b/368088671): When sublane tiling changes, we should be able to
+    // preserve some replications from the source layout. But we need to
+    // make sure they are implemented efficiently and well-tested. For now, we
+    // just simply use 0 for the replicated offset after retiling.
+    LayoutOffset dst_offset_0 =
+        src.offsets()[0].value_or(0) % alignment_2nd_minor;
+    LayoutOffset dst_offset_1 = src.offsets()[1].value_or(0) % alignment_minor;
+    if (dst_offset_hints[0].has_value() &&
+        (*dst_offset_hints[0] - src.offsets()[0].value_or(0)) %
+                alignment_2nd_minor ==
+            0) {
+      dst_offset_0 = dst_offset_hints[0];
+    }
+    if (dst_offset_hints[1].has_value() &&
+        (*dst_offset_hints[1] - src.offsets()[1].value_or(0)) %
+                alignment_minor ==
+            0) {
+      dst_offset_1 = dst_offset_hints[1];
+    }
+    const LayoutOffsets dst_offsets{dst_offset_0, dst_offset_1};
+    dst = VectorLayout(bitwidth, dst_offsets, dst_tiling, dst.implicit_dim());
+    TPU_ASSERT_LOC(loc, dst.isValid(target_shape));
+
+    xla::Array<Value> retiled(
+        dst.tileArrayImplicitShape(vty.getShape(), target_shape));
+    if (failed(retileWithScratch(ctx, builder, loc, retiled, dst_tiling,
+                                 dst_offsets, vregs, src_tiling,
+                                 /*src_offsets=*/
+                                 LayoutOffsets{src.offsets()[0].value_or(0),
+                                               src.offsets()[1].value_or(0)},
+                                 packing))) {
       return failure();
     }
     return std::pair(dst, std::move(retiled));
@@ -6523,9 +6608,7 @@ FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx,
   FAILUREOR_ASSIGN_OR_RETURN(
       std::tie(src, src_tiles),
       changeTiling(ctx, builder, v.getLoc(), vty, src, std::move(src_tiles),
-                   dst.tiling(),
-                   dst.offsets()[0] == std::nullopt &&
-                       src.offsets()[0] != std::nullopt));
+                   dst.tiling(), dst.offsets()));
   FAILUREOR_ASSIGN_OR_RETURN(
       std::tie(src, src_tiles),
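
Note on the offset/alignment rule introduced in retileWithScratch and changeTiling above: a vreg slice is the 2D window of the value covered by one vreg, the per-dimension alignment is the minimum of the source and destination slices, and a destination offset hint is only honored when it differs from the source offset by a whole multiple of that alignment (otherwise data would have to move within vregs, which this path does not implement). The following is a self-contained C++ sketch of that rule only; the helper names are invented here, an (8, 128) native vreg shape is assumed, and this is not the Mosaic API.

#include <algorithm>
#include <array>
#include <cstdint>
#include <optional>

constexpr std::array<int64_t, 2> kTargetShape = {8, 128};  // (sublanes, lanes)
using Offset = std::optional<int64_t>;

// 2D window of the (implicitly 2D) value covered by a single vreg for a given
// tiling and bitwidth: {tiling[0], tiling[1] * tiles_per_vreg}.
std::array<int64_t, 2> VregSlice(std::array<int64_t, 2> tiling, int bitwidth) {
  const int64_t packing = 32 / bitwidth;
  const int64_t tiles_per_vreg =
      kTargetShape[0] * kTargetShape[1] * packing / (tiling[0] * tiling[1]);
  return {tiling[0], tiling[1] * tiles_per_vreg};
}

// Default to the source offset folded into the alignment, but honor a hint
// whenever it differs from the source offset by a whole number of aligned
// rows/columns (a shift the scratch-based retiling can absorb).
std::array<int64_t, 2> PickDstOffsets(std::array<int64_t, 2> src_offsets,
                                      std::array<Offset, 2> hints,
                                      std::array<int64_t, 2> src_tiling,
                                      std::array<int64_t, 2> dst_tiling,
                                      int bitwidth) {
  const std::array<int64_t, 2> src_slice = VregSlice(src_tiling, bitwidth);
  const std::array<int64_t, 2> dst_slice = VregSlice(dst_tiling, bitwidth);
  std::array<int64_t, 2> dst_offsets;
  for (int i = 0; i < 2; ++i) {
    const int64_t alignment = std::min(src_slice[i], dst_slice[i]);
    dst_offsets[i] = src_offsets[i] % alignment;
    if (hints[i].has_value() && (*hints[i] - src_offsets[i]) % alignment == 0) {
      dst_offsets[i] = *hints[i];
    }
  }
  return dst_offsets;
}

For example, a 32-bit retile from (1, 128) to (8, 128) has vreg slices (1, 1024) and (8, 128), so the alignments are (1, 128): the sublane offset may change freely, but the lane offset can only move in steps of 128.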
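The rewritten (1, 128) -> (8, 128) replicated-sublane branch in changeTiling picks, for every destination vreg column, the source sublane to broadcast by first converting the column into sublane-column units. A minimal sketch of that index mapping, assuming 8 sublanes and 128 lanes per vreg and plain integers in place of MLIR values (helper name invented):

#include <cstdint>
#include <utility>

// For a one-row value, (1, 128) tiling stores 8 consecutive 128-lane chunks in
// one vreg, one chunk per sublane. Retiling to (8, 128) with a minor offset,
// destination column `dst_col` needs the data at sublane column
// `dst_col + minor_offset / lanes`, i.e. sublane (column % sublanes) of source
// vreg (column / sublanes), which the pass then broadcasts to all sublanes.
std::pair<int64_t, int64_t> SourceVregAndSublane(int64_t dst_col,
                                                 int64_t minor_offset,
                                                 int64_t sublanes = 8,
                                                 int64_t lanes = 128) {
  const int64_t sublane_column = dst_col + minor_offset / lanes;
  return {sublane_column / sublanes, sublane_column % sublanes};
}

For instance, with a source minor offset of 256, destination column 3 maps to sublane column 5, i.e. sublane 5 of source vreg 0. When the source minor offset is replicated, the branch skips the broadcast entirely, since each (1, 128) vreg is already fully replicated.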
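Inside retileToLargeTileWithScratch and retileToSmallTileWithScratch, one vreg row on the larger-tile side corresponds to a group of vregs_per_group rows on the smaller-tile side. The new 2nd-minor offset shifts that group, so rows that fall before the start of the vreg array are skipped, while rows past the end still terminate the loop. A hedged, standalone sketch of that mapping (names invented, not the pass's actual helpers):

#include <cstdint>
#include <vector>

// Which rows of the grouped (smaller-tile) vreg array a row of the other side
// touches, given the 2nd-minor offset in whole vreg rows. Mirrors the
// base_src_row_idx / base_dst_row_idx loops in the hunks above.
std::vector<int64_t> GroupRows(int64_t row, int64_t vregs_per_group,
                               int64_t offset_2nd_minor, int64_t num_rows) {
  std::vector<int64_t> rows;
  const int64_t base = row * vregs_per_group - offset_2nd_minor;
  for (int64_t vi = 0; vi < vregs_per_group; ++vi) {
    const int64_t r = base + vi;
    if (r < 0) continue;       // shifted before the start: nothing to touch
    if (r >= num_rows) break;  // past the end of the array: done
    rows.push_back(r);
  }
  return rows;
}

With vregs_per_group = 4, an offset of 1, and enough rows, row 0 touches rows {0, 1, 2} and row 1 touches rows {3, 4, 5, 6}.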