Skip to content

Commit

Permalink
[Mosaic] Add support for generalized reduction unrolling to infer_vector_layout
Browse files Browse the repository at this point in the history

Otherwise it's inaccessible to users.

PiperOrigin-RevId: 578471167
  • Loading branch information
apaszke authored and jax authors committed Nov 1, 2023
1 parent f17f549 commit 5d28961
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 24 deletions.
15 changes: 15 additions & 0 deletions jaxlib/mosaic/dialect/tpu/layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -774,6 +774,21 @@ llvm::hash_code hash_value(const VectorLayout& layout) {
return llvm::hash_value(layout.as_tuple());
}

// Writes a short textual tag for an implicit dimension to `os` and returns
// the same stream so that insertions can be chained:
//   kNone        -> "none"
//   kMinor       -> "-1"
//   kSecondMinor -> "-2"
// NOTE(review): the numeric tags presumably mirror negative axis indices --
// confirm against the layout printer/parser.
std::ostream &operator<<(std::ostream &os, VectorLayout::ImplicitDim dim) {
  switch (dim) {
    case VectorLayout::ImplicitDim::kNone:
      return os << "none";
    case VectorLayout::ImplicitDim::kMinor:
      return os << "-1";
    case VectorLayout::ImplicitDim::kSecondMinor:
      return os << "-2";
  }
  // A value matching none of the enumerators above writes nothing.
  return os;
}

std::optional<Layout> parseLayout(mlir::AsmParser& parser) {
std::string layout_str;
if (failed(parser.parseString(&layout_str))) {
Expand Down
1 change: 1 addition & 0 deletions jaxlib/mosaic/dialect/tpu/layout.h
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,7 @@ std::ostream &operator<<(std::ostream &os, const Layout &v);
llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Layout &v);
llvm::hash_code hash_value(const VectorLayout &layout);
mlir::Diagnostic &operator<<(mlir::Diagnostic &diag, const Layout &v);
std::ostream &operator<<(std::ostream &os, VectorLayout::ImplicitDim dim);

std::optional<Layout> parseLayout(mlir::AsmParser &parser);

Expand Down
7 changes: 5 additions & 2 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2209,7 +2209,8 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
}
const std::array<bool, 2> allow_replicated = {!reduces[0], !reduces[1]};

if (!src_layout.hasNativeTiling(ctx.target_shape)) {
if ((reduces[0] || reduces[1]) &&
!src_layout.hasNativeTiling(ctx.target_shape)) {
return multi_reduction_op.emitOpError(
"Not implemented: Unsupported input layout: ")
<< src_layout;
Expand Down Expand Up @@ -2243,11 +2244,13 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
dst_implicit_dim =
VectorLayout::ImplicitDim::kSecondMinor; // Anything works.
} else if (reduces[0]) {
CHECK_EQ(src_layout.implicit_dim(), VectorLayout::ImplicitDim::kNone);
dst_implicit_dim = VectorLayout::ImplicitDim::kSecondMinor;
} else if (reduces[1]) {
CHECK_EQ(src_layout.implicit_dim(), VectorLayout::ImplicitDim::kNone);
dst_implicit_dim = VectorLayout::ImplicitDim::kMinor;
} else {
dst_implicit_dim = VectorLayout::ImplicitDim::kNone;
dst_implicit_dim = src_layout.implicit_dim();
}
if (dst_layout.implicit_dim() != dst_implicit_dim) {
return multi_reduction_op.emitOpError(
Expand Down
72 changes: 52 additions & 20 deletions jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ limitations under the License.
#include "absl/log/log.h"
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/include/mlir/IR/Attributes.h"
#include "jaxlib/mosaic/dialect/tpu/layout.h"
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
#include "xla/layout.h"
Expand Down Expand Up @@ -956,34 +957,65 @@ class VectorLayoutInferer {
auto src_ty = op.getSourceVectorType();
auto dst_ty = dyn_cast<VectorType>(op.getDestType());
TPU_CHECK_OP(dst_ty, "only reductions with vector results supported");
TPU_CHECK_OP(src_ty.getRank() == dst_ty.getRank() + 1,
"only 1D reductions supported");
int64_t dim = cast<IntegerAttr>(op.getReductionDims()[0]).getInt();
SmallVector<int64_t> dims;
dims.reserve(op.getReductionDims().size());
for (Attribute dim_attr : op.getReductionDims()) {
dims.push_back(cast<IntegerAttr>(dim_attr).getInt());
}
int64_t src_rank = src_ty.getRank();
auto acc_pad = getLayout(op.getAcc());
TPU_CHECK_OP(is_fully_replicated(acc_pad),
auto acc_layout = getLayout(op.getAcc());
TPU_CHECK_OP(is_fully_replicated(acc_layout),
"only constant accumulators supported");
TPU_CHECK_OP(src_ty.getElementTypeBitWidth() == kNativeBitwidth,
"only 32-bit reductions supported");
auto some_src_layout = getLayout(op.getSource());
TPU_CHECK_OP(some_src_layout, "missing vector layout");
auto &src_layout = *some_src_layout;
TPU_CHECK_OP(src_layout.implicit_dim() == ImplicitDim::kNone,
"only 2D layouts supported");
if (dim == src_rank - 1) {
setLayout(
op, {src_layout, acc_pad},
VectorLayout(kNativeBitwidth, {src_layout.offsets()[0], std::nullopt},
default_tiling_, ImplicitDim::kMinor));
} else if (dim == src_rank - 2) {
setLayout(
op, {src_layout, acc_pad},
VectorLayout(kNativeBitwidth, {std::nullopt, src_layout.offsets()[1]},
default_tiling_, ImplicitDim::kSecondMinor));
} else {
// Reduction happens over the unrolled dimension --- we can keep layout.
setLayout(op, {src_layout, acc_pad}, src_layout);
std::array<bool, 2> reduces;
switch (src_layout.implicit_dim()) {
case VectorLayout::ImplicitDim::kNone:
reduces = {
std::find(dims.begin(), dims.end(), src_rank - 2) != dims.end(),
std::find(dims.begin(), dims.end(), src_rank - 1) != dims.end()};
break;
case VectorLayout::ImplicitDim::kSecondMinor:
reduces = {false, std::find(dims.begin(), dims.end(), src_rank - 1) !=
dims.end()};
break;
case VectorLayout::ImplicitDim::kMinor:
reduces = {
std::find(dims.begin(), dims.end(), src_rank - 1) != dims.end(),
false};
break;
}
if ((reduces[0] || reduces[1]) &&
!src_layout.hasNativeTiling(target_shape_)) {
src_layout = VectorLayout(kNativeBitwidth, src_layout.offsets(),
default_tiling_, src_layout.implicit_dim());
}
LayoutOffsets out_offsets = src_layout.offsets();
for (int i = 0; i < out_offsets.size(); ++i) {
if (reduces[i]) {
out_offsets[i] = std::nullopt;
}
}
ImplicitDim out_implicit_dim = src_layout.implicit_dim();
if ((reduces[0] && reduces[1]) ||
(src_layout.implicit_dim() != ImplicitDim::kNone &&
(reduces[0] || reduces[1]))) {
TPU_CHECK_OP(
dst_ty.getRank() > 0 && *(dst_ty.getShape().end() - 1) == 1,
"Not implemented: reductions over both trailing dimensions are only "
"supported when the resulting value has a trailing axis of size 1");
out_implicit_dim = VectorLayout::ImplicitDim::kSecondMinor;
} else if (reduces[0]) {
out_implicit_dim = VectorLayout::ImplicitDim::kSecondMinor;
} else if (reduces[1]) {
out_implicit_dim = VectorLayout::ImplicitDim::kMinor;
}
setLayout(op, {src_layout, acc_layout},
VectorLayout(src_layout.bitwidth(), out_offsets,
src_layout.tiling(), out_implicit_dim));
return success();
}

Expand Down
6 changes: 4 additions & 2 deletions jaxlib/mosaic/python/apply_vector_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -3060,7 +3060,7 @@ def _vector_multi_reduction_rule( # pylint: disable=missing-function-docstring
reduces = TargetTuple((src_rank - 1) in dims, False)
allow_replicated = TargetTuple(not reduces.sublanes, not reduces.lanes)

if not src_layout.has_native_tiling:
if any(reduces) and not src_layout.has_native_tiling:
raise NotImplementedError("unsupported input layout")
if src_layout.tiling != dst_layout.tiling:
raise NotImplementedError("tiling shouldn't change")
Expand All @@ -3082,11 +3082,13 @@ def _vector_multi_reduction_rule( # pylint: disable=missing-function-docstring
)
dst_implicit_dim = ImplicitDim.SECOND_MINOR # Whatever works.
elif reduces.lanes:
assert src_layout.implicit_dim is None
dst_implicit_dim = ImplicitDim.MINOR
elif reduces.sublanes:
assert src_layout.implicit_dim is None
dst_implicit_dim = ImplicitDim.SECOND_MINOR
else:
dst_implicit_dim = None
dst_implicit_dim = src_layout.implicit_dim
if dst_implicit_dim != dst_layout.implicit_dim:
raise NotImplementedError("unsupported output implicit dim")

Expand Down

0 comments on commit 5d28961

Please sign in to comment.