diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index dd63ba66cb9c..148230750408 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -1668,10 +1668,17 @@ class VectorLayoutInferer { Layout dst_layout; if (layout.tiling() == nativeTiling(src_bitwidth)) { // If the source is already in native tiling, we can unpack it directly. - src_layout = layout; + std::array dst_native_tiling = nativeTiling(dst_bitwidth); + LayoutOffsets offsets = {layout.offsets()[0] + ? *layout.offsets()[0] % dst_native_tiling[0] + : LayoutOffset(), + layout.offsets()[1]}; + DCHECK_LT(offsets[1].value_or(0), dst_native_tiling[1]); + src_layout = VectorLayout(src_bitwidth, offsets, layout.tiling(), + layout.implicit_dim()); dst_layout = - VectorLayout(dst_bitwidth, layout.offsets(), - nativeTiling(dst_bitwidth), layout.implicit_dim()); + VectorLayout(dst_bitwidth, offsets, nativeTiling(dst_bitwidth), + layout.implicit_dim()); } else if (dst_bitwidth == 32 && default_tiling_[0] % layout.tiling()[0] == 0 && default_tiling_[1] == layout.tiling()[1]) { @@ -1680,13 +1687,18 @@ class VectorLayoutInferer { // tiling through the op. // TODO(jevinjiang): we can relax this for non-32bit as well. src_layout = layout; - dst_layout = VectorLayout(32, layout.offsets(), src_layout->tiling(), - layout.implicit_dim()); + dst_layout = VectorLayout(dst_bitwidth, layout.offsets(), + src_layout->tiling(), layout.implicit_dim()); } else { // TODO(b/335863273): we should also reduce offsets. - src_layout = VectorLayout(src_bitwidth, layout.offsets(), default_tiling_, + LayoutOffsets offsets = { + layout.offsets()[0] ? *layout.offsets()[0] % default_tiling_[0] + : LayoutOffset(), + layout.offsets()[1] ? *layout.offsets()[1] % default_tiling_[1] + : LayoutOffset()}; + src_layout = VectorLayout(src_bitwidth, offsets, default_tiling_, layout.implicit_dim()); - dst_layout = VectorLayout(dst_bitwidth, layout.offsets(), default_tiling_, + dst_layout = VectorLayout(dst_bitwidth, offsets, default_tiling_, layout.implicit_dim()); } setLayout(op, src_layout, dst_layout);