diff --git a/exir/passes/_quant_patterns_and_replacements.py b/exir/passes/_quant_patterns_and_replacements.py index e29e34b962..5c2a054153 100644 --- a/exir/passes/_quant_patterns_and_replacements.py +++ b/exir/passes/_quant_patterns_and_replacements.py @@ -202,8 +202,12 @@ def embedding_2bit( weight_quant_max: int, indices: torch.Tensor, ) -> torch.Tensor: - assert weight_quant_min == -2, "embedding_2bit in ExecuTorch expects weight_quant_min == -2" - assert weight_quant_max == 1, "embedding_2bit in ExecuTorch expects weight_quant_max == 1" + assert ( + weight_quant_min == -2 + ), "embedding_2bit in ExecuTorch expects weight_quant_min == -2" + assert ( + weight_quant_max == 1 + ), "embedding_2bit in ExecuTorch expects weight_quant_max == 1" embedding_weight_checks(weight, weight_scales, weight_zero_points) group_size = (4 * weight.size(1)) // ( @@ -260,8 +264,12 @@ def embedding_2bit_dtype( indices: torch.Tensor, dtype: Optional[torch.dtype], ) -> torch.Tensor: - assert weight_quant_min == -2, "embedding_2bit_dtype in ExecuTorch expects weight_quant_min == -2" - assert weight_quant_max == 1, "embedding_2bit_dtype in ExecuTorch expects weight_quant_max == 1" + assert ( + weight_quant_min == -2 + ), "embedding_2bit_dtype in ExecuTorch expects weight_quant_min == -2" + assert ( + weight_quant_max == 1 + ), "embedding_2bit_dtype in ExecuTorch expects weight_quant_max == 1" embedding_weight_checks(weight, weight_scales, weight_zero_points) group_size = (4 * weight.size(1)) // ( @@ -340,8 +348,12 @@ def embedding_4bit( weight_quant_max: int, indices: torch.Tensor, ) -> torch.Tensor: - assert weight_quant_min == -8, "embedding_4bit in ExecuTorch expects weight_quant_min == -8" - assert weight_quant_max == 7, "embedding_4bit in ExecuTorch expects weight_quant_max == 7" + assert ( + weight_quant_min == -8 + ), "embedding_4bit in ExecuTorch expects weight_quant_min == -8" + assert ( + weight_quant_max == 7 + ), "embedding_4bit in ExecuTorch expects weight_quant_max == 7" embedding_weight_checks(weight, weight_scales, weight_zero_points) group_size = (2 * weight.size(1)) // ( @@ -396,8 +408,12 @@ def embedding_4bit_dtype( indices: torch.Tensor, dtype: Optional[torch.dtype], ) -> torch.Tensor: - assert weight_quant_min == -8, "embedding_4bit_dtype in ExecuTorch expects weight_quant_min == -8" - assert weight_quant_max == 7, "embedding_4bit_dtype in ExecuTorch expects weight_quant_max == 7" + assert ( + weight_quant_min == -8 + ), "embedding_4bit_dtype in ExecuTorch expects weight_quant_min == -8" + assert ( + weight_quant_max == 7 + ), "embedding_4bit_dtype in ExecuTorch expects weight_quant_max == 7" embedding_weight_checks(weight, weight_scales, weight_zero_points) group_size = (2 * weight.size(1)) // (