Skip to content

Commit

Permalink
[Mosaic] Fix insufficient checks in lane shifts in relayout
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 571890703
  • Loading branch information
apaszke authored and jax authors committed Oct 9, 2023
1 parent d0976f0 commit 1441436
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 13 deletions.
17 changes: 9 additions & 8 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3013,11 +3013,10 @@ FailureOr<Value> relayout(RewriteContext &ctx, OpBuilder &builder, Value v,
}
// Try to reconcile differences in implicit dim.
if (src.implicit_dim() != dst.implicit_dim()) {
const ArrayRef<int64_t> shape = vty.getShape();
if (dst.implicit_dim() == VectorLayout::ImplicitDim::kNone &&
shape[shape.size() - xla::to_underlying(src.implicit_dim())] == 1) {
src = VectorLayout(src.bitwidth(), src.offsets(), src.tiling(),
VectorLayout::ImplicitDim::kNone);
VectorLayout candidate(src.bitwidth(), src.offsets(), src.tiling(),
dst.implicit_dim());
if (candidate.equivalentTo(src, vty.getShape(), ctx.target_shape)) {
src = candidate;
}
}

Expand Down Expand Up @@ -3162,6 +3161,7 @@ FailureOr<Value> relayout(RewriteContext &ctx, OpBuilder &builder, Value v,
// Fix up the offsets, assuming everything else matches between src and dst.
if (src.tiling() == dst.tiling() &&
src.implicit_dim() == dst.implicit_dim()) {
const auto &tiling = src.tiling();
// TODO(apaszke): Changing an offset might add or remove one vreg.
if (dst_tiles_shape != src_tiles.dimensions()) {
return emitError(v.getLoc(), "Offsets changing the vreg array shape");
Expand Down Expand Up @@ -3245,9 +3245,10 @@ FailureOr<Value> relayout(RewriteContext &ctx, OpBuilder &builder, Value v,
if (col_diff < 0) {
return emitError(v.getLoc(), "Not implemented: Shifts to the left");
}
if (bitwidth != 32) {
return emitError(
v.getLoc(), "Not implemented: Only 32-bit column shifts supported");
if (bitwidth != 32 || tiling != ctx.target_shape) {
return emitError(v.getLoc(),
"Not implemented: Only 32-bit column shifts for "
"native layouts supported");
}
const int64_t sublane_diff = col_diff;
CHECK_GE(src_tiles.num_dimensions(), 1);
Expand Down
16 changes: 11 additions & 5 deletions jaxlib/mosaic/python/apply_vector_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -1223,8 +1223,11 @@ def relayout(

# Try to reconcile differences in implicit dim.
if src.implicit_dim != dst.implicit_dim:
if dst.implicit_dim is None and vty.shape[src.implicit_dim] == 1:
src = VectorLayout(src.bitwidth, src.offsets, src.tiling, None)
candidate = VectorLayout(
src.bitwidth, src.offsets, src.tiling, dst.implicit_dim
)
if candidate.equivalent_to(src, vty.shape):
src = candidate

# Handle retiling from (1, 128) to (8, 128) for 32-bit data.
if (
Expand Down Expand Up @@ -1315,7 +1318,7 @@ def relayout(
src = new_src
src_tiles = src_tiles_retiled

# (8, 128) -> (32, 128) for int8. Useful for preparing data for matmuls.
# (8, 128) -> (32, 128) for int8. Useful for preparing data for matmuls.
if (
src.implicit_dim is None
and dst.implicit_dim is None
Expand Down Expand Up @@ -1357,6 +1360,7 @@ def relayout(

# Fix up the offsets, assuming everything else matches between src and dst.
if src.tiling == dst.tiling and src.implicit_dim == dst.implicit_dim:
tiling = src.tiling
# TODO(apaszke): Changing an offset might add or remove one vreg.
if dst_tiles_shape != src_tiles.shape:
raise NotImplementedError("Offsets changing the vreg array shape")
Expand Down Expand Up @@ -1420,8 +1424,10 @@ def relayout(
raise NotImplementedError("Both columns and rows are shifted")
if col_diff < 0:
raise NotImplementedError("Shifts to the left")
if bitwidth != 32:
raise NotImplementedError("Only 32-bit column shifts supported")
if bitwidth != 32 or tiling != TARGET_SHAPE:
raise NotImplementedError(
"Only 32-bit column shifts for native layouts supported"
)
sublane_diff = col_diff
sublane_diff_attr = ir.IntegerAttr.get(
ir.IntegerType.get_signed(32), sublane_diff
Expand Down

0 comments on commit 1441436

Please sign in to comment.