diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 5c9b3d178c15..7ca4204db343 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -4623,7 +4623,9 @@ LogicalResult vector_transpose_rule(RewriteContext &ctx, Operation &op, incremented_batch_idx.end()); src_slice_ends.append({(src_row + 1) * vregs_per_tile, src_col_end}); xla::Array src_tile_vregs = src_vregs.Slice( - src_slice_starts, src_slice_ends, /*out_of_bounds_ok=*/true); + src_slice_starts, src_slice_ends, + builder.create( + op.getLoc(), builder.getZeroAttr(src_vregs.begin()->getType()))); // Drop leading singleton (batch) dimensions to have a shape that conforms // with the vreg array shape specified by layout_in, as expected by assemble src_tile_vregs.Reshape(