Skip to content

Commit

Permalink
format fix
Browse files Browse the repository at this point in the history
  • Loading branch information
metascroy committed Dec 20, 2024
1 parent d0ca255 commit a007d75
Showing 1 changed file with 24 additions and 8 deletions.
32 changes: 24 additions & 8 deletions exir/passes/_quant_patterns_and_replacements.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,12 @@ def embedding_2bit(
weight_quant_max: int,
indices: torch.Tensor,
) -> torch.Tensor:
assert weight_quant_min == -2, "embedding_2bit in ExecuTorch expects weight_quant_min == -2"
assert weight_quant_max == 1, "embedding_2bit in ExecuTorch expects weight_quant_max == 1"
assert (
weight_quant_min == -2
), "embedding_2bit in ExecuTorch expects weight_quant_min == -2"
assert (
weight_quant_max == 1
), "embedding_2bit in ExecuTorch expects weight_quant_max == 1"

embedding_weight_checks(weight, weight_scales, weight_zero_points)
group_size = (4 * weight.size(1)) // (
Expand Down Expand Up @@ -260,8 +264,12 @@ def embedding_2bit_dtype(
indices: torch.Tensor,
dtype: Optional[torch.dtype],
) -> torch.Tensor:
assert weight_quant_min == -2, "embedding_2bit_dtype in ExecuTorch expects weight_quant_min == -2"
assert weight_quant_max == 1, "embedding_2bit_dtype in ExecuTorch expects weight_quant_max == 1"
assert (
weight_quant_min == -2
), "embedding_2bit_dtype in ExecuTorch expects weight_quant_min == -2"
assert (
weight_quant_max == 1
), "embedding_2bit_dtype in ExecuTorch expects weight_quant_max == 1"

embedding_weight_checks(weight, weight_scales, weight_zero_points)
group_size = (4 * weight.size(1)) // (
Expand Down Expand Up @@ -340,8 +348,12 @@ def embedding_4bit(
weight_quant_max: int,
indices: torch.Tensor,
) -> torch.Tensor:
assert weight_quant_min == -8, "embedding_4bit in ExecuTorch expects weight_quant_min == -8"
assert weight_quant_max == 7, "embedding_4bit in ExecuTorch expects weight_quant_max == 7"
assert (
weight_quant_min == -8
), "embedding_4bit in ExecuTorch expects weight_quant_min == -8"
assert (
weight_quant_max == 7
), "embedding_4bit in ExecuTorch expects weight_quant_max == 7"

embedding_weight_checks(weight, weight_scales, weight_zero_points)
group_size = (2 * weight.size(1)) // (
Expand Down Expand Up @@ -396,8 +408,12 @@ def embedding_4bit_dtype(
indices: torch.Tensor,
dtype: Optional[torch.dtype],
) -> torch.Tensor:
assert weight_quant_min == -8, "embedding_4bit_dtype in ExecuTorch expects weight_quant_min == -8"
assert weight_quant_max == 7, "embedding_4bit_dtype in ExecuTorch expects weight_quant_max == 7"
assert (
weight_quant_min == -8
), "embedding_4bit_dtype in ExecuTorch expects weight_quant_min == -8"
assert (
weight_quant_max == 7
), "embedding_4bit_dtype in ExecuTorch expects weight_quant_max == 7"

embedding_weight_checks(weight, weight_scales, weight_zero_points)
group_size = (2 * weight.size(1)) // (
Expand Down

0 comments on commit a007d75

Please sign in to comment.