Skip to content

Commit

Permalink
Support qmatmul with tensors of different dimensions
Browse files Browse the repository at this point in the history
Summary:
MobileBERT exposes an issue in our kernel: the input tensors can have batch dimensions that are different but still compatible (for PyTorch).

This diff changes the meta kernel to support that case (the actual kernel already handles it).

Differential Revision: D60314979
  • Loading branch information
mcremon-meta authored and facebook-github-bot committed Jul 26, 2024
1 parent 5a20a49 commit 5524ec1
Showing 1 changed file with 24 additions and 22 deletions.
46 changes: 24 additions & 22 deletions backends/cadence/aot/ops_registrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from math import prod
from typing import Optional, Tuple

import torch
Expand Down Expand Up @@ -186,28 +187,29 @@ def quantized_matmul_meta(
X_size = list(X.size())
Y_size = list(Y.size())

assert len(X_size) == len(
Y_size
), "quantized matmul not supported for tensors of different dimensions"

if len(X_size) == 3:
assert (
X_size[0] == Y_size[0]
), "quantized matmul only supported for batch dimension of same size"
if transposed:
assert X_size[2] == Y_size[2], "matrices cannot be multiplied"
out_size = X_size[:2] + [Y_size[1]]
else:
assert X_size[2] == Y_size[1], "matrices cannot be multiplied"
out_size = X_size[:2] + [Y_size[2]]
elif len(X_size) == 2:
if transposed:
assert X_size[1] == Y_size[1], "matrices cannot be multiplied"
out_size = [X_size[0], Y_size[0]]
else:
assert X_size[1] == Y_size[0], "matrices cannot be multiplied"
out_size = [X_size[0], Y_size[1]]
# Get the batch dimensions for both tensors
X_batch_dims = X_size[:-2]
Y_batch_dims = Y_size[:-2]

# If they don't match, check that they're compatible
if X_batch_dims != Y_batch_dims:
assert prod(X_batch_dims) == prod(
Y_batch_dims
), f"Batch dimensions of X and Y do not match: {X_batch_dims} vs {Y_batch_dims}"

# Get the matmul output size
if transposed:
assert X_size[-1] == Y_size[-1], "matrices cannot be multiplied"
mat_size = [X_size[-2], Y_size[-2]]
else:
raise AssertionError("quantized matmul only supported for 2D or 3D tensors")
assert X_size[-1] == Y_size[-2], "matrices cannot be multiplied"
mat_size = [X_size[-2], Y_size[-1]]

# Combine the larger batch dimensions with the matmul output size
out_size = (
X_batch_dims + mat_size
if len(X_batch_dims) > len(Y_batch_dims)
else Y_batch_dims + mat_size
)

return X.new_empty(out_size, dtype=X.dtype)

0 comments on commit 5524ec1

Please sign in to comment.