From a007d75282842504ed3e65692c883dff5823d8ec Mon Sep 17 00:00:00 2001
From: Scott Roy <161522778+metascroy@users.noreply.github.com>
Date: Fri, 20 Dec 2024 10:57:08 -0800
Subject: [PATCH] format fix

---
 .../_quant_patterns_and_replacements.py       | 32 ++++++++++++++-----
 1 file changed, 24 insertions(+), 8 deletions(-)

diff --git a/exir/passes/_quant_patterns_and_replacements.py b/exir/passes/_quant_patterns_and_replacements.py
index e29e34b962..5c2a054153 100644
--- a/exir/passes/_quant_patterns_and_replacements.py
+++ b/exir/passes/_quant_patterns_and_replacements.py
@@ -202,8 +202,12 @@ def embedding_2bit(
     weight_quant_max: int,
     indices: torch.Tensor,
 ) -> torch.Tensor:
-    assert weight_quant_min == -2, "embedding_2bit in ExecuTorch expects weight_quant_min == -2"
-    assert weight_quant_max == 1, "embedding_2bit in ExecuTorch expects weight_quant_max == 1"
+    assert (
+        weight_quant_min == -2
+    ), "embedding_2bit in ExecuTorch expects weight_quant_min == -2"
+    assert (
+        weight_quant_max == 1
+    ), "embedding_2bit in ExecuTorch expects weight_quant_max == 1"
 
     embedding_weight_checks(weight, weight_scales, weight_zero_points)
     group_size = (4 * weight.size(1)) // (
@@ -260,8 +264,12 @@ def embedding_2bit_dtype(
     indices: torch.Tensor,
     dtype: Optional[torch.dtype],
 ) -> torch.Tensor:
-    assert weight_quant_min == -2, "embedding_2bit_dtype in ExecuTorch expects weight_quant_min == -2"
-    assert weight_quant_max == 1, "embedding_2bit_dtype in ExecuTorch expects weight_quant_max == 1"
+    assert (
+        weight_quant_min == -2
+    ), "embedding_2bit_dtype in ExecuTorch expects weight_quant_min == -2"
+    assert (
+        weight_quant_max == 1
+    ), "embedding_2bit_dtype in ExecuTorch expects weight_quant_max == 1"
 
     embedding_weight_checks(weight, weight_scales, weight_zero_points)
     group_size = (4 * weight.size(1)) // (
@@ -340,8 +348,12 @@ def embedding_4bit(
     weight_quant_max: int,
     indices: torch.Tensor,
 ) -> torch.Tensor:
-    assert weight_quant_min == -8, "embedding_4bit in ExecuTorch expects weight_quant_min == -8"
-    assert weight_quant_max == 7, "embedding_4bit in ExecuTorch expects weight_quant_max == 7"
+    assert (
+        weight_quant_min == -8
+    ), "embedding_4bit in ExecuTorch expects weight_quant_min == -8"
+    assert (
+        weight_quant_max == 7
+    ), "embedding_4bit in ExecuTorch expects weight_quant_max == 7"
 
     embedding_weight_checks(weight, weight_scales, weight_zero_points)
     group_size = (2 * weight.size(1)) // (
@@ -396,8 +408,12 @@ def embedding_4bit_dtype(
     indices: torch.Tensor,
     dtype: Optional[torch.dtype],
 ) -> torch.Tensor:
-    assert weight_quant_min == -8, "embedding_4bit_dtype in ExecuTorch expects weight_quant_min == -8"
-    assert weight_quant_max == 7, "embedding_4bit_dtype in ExecuTorch expects weight_quant_max == 7"
+    assert (
+        weight_quant_min == -8
+    ), "embedding_4bit_dtype in ExecuTorch expects weight_quant_min == -8"
+    assert (
+        weight_quant_max == 7
+    ), "embedding_4bit_dtype in ExecuTorch expects weight_quant_max == 7"
 
     embedding_weight_checks(weight, weight_scales, weight_zero_points)
     group_size = (2 * weight.size(1)) // (