diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py
index 30c0c5a3f..424bccd17 100644
--- a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py
+++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py
@@ -31,15 +31,20 @@
 OPSET_MERGE_OPS = "MergeOps"
 OPSET_CONV = "Conv"
 OPSET_FULLY_CONNECTED = "FullyConnected"
+OPSET_BATCH_NORM = "BatchNorm"
 OPSET_ANY_RELU = "AnyReLU"
 OPSET_ADD = "Add"
 OPSET_SUB = "Sub"
 OPSET_MUL = "Mul"
 OPSET_DIV = "Div"
+OPSET_MIN_MAX = "MinMax"
 OPSET_PRELU = "PReLU"
 OPSET_SWISH = "Swish"
 OPSET_SIGMOID = "Sigmoid"
 OPSET_TANH = "Tanh"
+OPSET_GELU = "Gelu"
+OPSET_HARDSIGMOID = "HardSigmoid"
+OPSET_HARDSWISH = "HardSwish"
 
 
 def get_tp_model() -> TargetPlatformModel:
@@ -172,6 +177,11 @@ def generate_tp_model(default_config: OpQuantizationConfig,
     # If the QuantizationConfigOptions contains only one configuration,
     # this configuration will be used for the operation quantization:
     default_configuration_options = tp.QuantizationConfigOptions([default_config])
+    default_config_input16 = default_config.clone_and_edit(supported_input_activation_n_bits=(8, 16))
+    default_config_options_16bit = tp.QuantizationConfigOptions([default_config_input16,
+                                                                 default_config_input16.clone_and_edit(activation_n_bits=16,
+                                                                                                       signedness=Signedness.SIGNED)],
+                                                                base_config=default_config_input16)
 
     # Create a QuantizationConfigOptions for quantizing constants in functional ops.
     # Constant configuration is similar to the default eight bit configuration except for PoT
@@ -212,6 +222,9 @@
                                                              weights_per_channel_threshold=False))
     qpreserving_const_config_options = tp.QuantizationConfigOptions([qpreserving_const_config])
 
+    mp_cfg_list_16bit = [mp_cfg.clone_and_edit(activation_n_bits=16, signedness=Signedness.SIGNED)
+                         for mp_cfg in mixed_precision_cfg_list]
+
     # Create a TargetPlatformModel and set its default quantization config.
     # This default configuration will be used for all operations
     # unless specified otherwise (see OperatorsSet, for example):
@@ -246,30 +259,37 @@ def generate_tp_model(default_config: OpQuantizationConfig,
     tp.OperatorsSet(OPSET_MERGE_OPS, const_configuration_options_inout16_per_tensor)
 
     # Create Mixed-Precision quantization configuration options from the given list of OpQuantizationConfig objects
-    mixed_precision_configuration_options = tp.QuantizationConfigOptions(mixed_precision_cfg_list,
+    mixed_precision_configuration_options = tp.QuantizationConfigOptions(mixed_precision_cfg_list + mp_cfg_list_16bit,
                                                                          base_config=base_config)
 
     # Define operator sets that use mixed_precision_configuration_options:
     conv = tp.OperatorsSet(OPSET_CONV, mixed_precision_configuration_options)
     fc = tp.OperatorsSet(OPSET_FULLY_CONNECTED, mixed_precision_configuration_options)
 
-    # Define operations sets without quantization configuration
-    # options (useful for creating fusing patterns, for example):
-    any_relu = tp.OperatorsSet(OPSET_ANY_RELU)
+    tp.OperatorsSet(OPSET_BATCH_NORM, default_config_options_16bit)
+
+    # Note: Operations sets without quantization configuration are useful for creating fusing patterns
+    any_relu = tp.OperatorsSet(OPSET_ANY_RELU, default_config_options_16bit)
     add = tp.OperatorsSet(OPSET_ADD, const_configuration_options_inout16)
     sub = tp.OperatorsSet(OPSET_SUB, const_configuration_options_inout16)
     mul = tp.OperatorsSet(OPSET_MUL, const_configuration_options_inout16)
     div = tp.OperatorsSet(OPSET_DIV, const_configuration_options)
-    prelu = tp.OperatorsSet(OPSET_PRELU)
-    swish = tp.OperatorsSet(OPSET_SWISH)
-    sigmoid = tp.OperatorsSet(OPSET_SIGMOID)
-    tanh = tp.OperatorsSet(OPSET_TANH)
+    tp.OperatorsSet(OPSET_MIN_MAX, const_configuration_options_inout16)
+    prelu = tp.OperatorsSet(OPSET_PRELU, default_config_options_16bit)
+    swish = tp.OperatorsSet(OPSET_SWISH, default_config_options_16bit)
+    sigmoid = tp.OperatorsSet(OPSET_SIGMOID, default_config_options_16bit)
+    tanh = tp.OperatorsSet(OPSET_TANH, default_config_options_16bit)
+    gelu = tp.OperatorsSet(OPSET_GELU, default_config_options_16bit)
+    hardsigmoid = tp.OperatorsSet(OPSET_HARDSIGMOID, default_config_options_16bit)
+    hardswish = tp.OperatorsSet(OPSET_HARDSWISH, default_config_options_16bit)
 
     # Combine multiple operators into a single operator to avoid quantization between
     # them. To do this we define fusing patterns using the OperatorsSets that were created.
     # To group multiple sets with regard to fusing, an OperatorSetConcat can be created
-    activations_after_conv_to_fuse = tp.OperatorSetConcat(any_relu, swish, prelu, sigmoid, tanh)
-    activations_after_fc_to_fuse = tp.OperatorSetConcat(any_relu, swish, sigmoid)
+    activations_after_conv_to_fuse = tp.OperatorSetConcat(any_relu, swish, prelu, sigmoid,
+                                                          tanh, gelu, hardswish, hardsigmoid)
+    activations_after_fc_to_fuse = tp.OperatorSetConcat(any_relu, swish, sigmoid, tanh, gelu,
+                                                        hardswish, hardsigmoid)
     any_binary = tp.OperatorSetConcat(add, sub, mul, div)
 
     # ------------------- #
diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_keras.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_keras.py
index 76ff28af6..656d57116 100644
--- a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_keras.py
+++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_keras.py
@@ -26,11 +26,11 @@
 if version.parse(tf.__version__) >= version.parse("2.13"):
     from keras.src.layers import Conv2D, DepthwiseConv2D, Dense, Reshape, ZeroPadding2D, Dropout, \
         MaxPooling2D, Activation, ReLU, Add, Subtract, Multiply, PReLU, Flatten, Cropping2D, LeakyReLU, Permute, \
-        Conv2DTranspose, Identity, Concatenate
+        Conv2DTranspose, Identity, Concatenate, BatchNormalization, Minimum, Maximum
 else:
     from keras.layers import Conv2D, DepthwiseConv2D, Dense, Reshape, ZeroPadding2D, Dropout, \
         MaxPooling2D, Activation, ReLU, Add, Subtract, Multiply, PReLU, Flatten, Cropping2D, LeakyReLU, Permute, \
-        Conv2DTranspose, Identity, Concatenate
+        Conv2DTranspose, Identity, Concatenate, BatchNormalization, Minimum, Maximum
 
 from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v4.tp_model import get_tp_model
 import model_compression_toolkit as mct
@@ -38,7 +38,7 @@
 from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v4.tp_model import OPSET_NO_QUANTIZATION, \
     OPSET_QUANTIZATION_PRESERVING, OPSET_DIMENSION_MANIPULATION_OPS_WITH_WEIGHTS, OPSET_DIMENSION_MANIPULATION_OPS, \
     OPSET_MERGE_OPS, OPSET_CONV, OPSET_FULLY_CONNECTED, OPSET_ANY_RELU, OPSET_ADD, OPSET_SUB, OPSET_MUL, OPSET_DIV, \
-    OPSET_PRELU, OPSET_SWISH, OPSET_SIGMOID, OPSET_TANH
+    OPSET_PRELU, OPSET_SWISH, OPSET_SIGMOID, OPSET_TANH, OPSET_GELU, OPSET_BATCH_NORM, OPSET_MIN_MAX, OPSET_HARDSIGMOID
 
 tp = mct.target_platform
 
@@ -117,6 +117,7 @@ def generate_keras_tpc(name: str, tp_model: tp.TargetPlatformModel):
     tp.OperationsSetToLayers(OPSET_FULLY_CONNECTED, [Dense],
                              attr_mapping={KERNEL_ATTR: DefaultDict(default_value=KERAS_KERNEL),
                                            BIAS_ATTR: DefaultDict(default_value=BIAS)})
+    tp.OperationsSetToLayers(OPSET_BATCH_NORM, [BatchNormalization])
     tp.OperationsSetToLayers(OPSET_ANY_RELU, [tf.nn.relu,
                                               tf.nn.relu6,
                                               tf.nn.leaky_relu,
@@ -128,9 +129,13 @@ def generate_keras_tpc(name: str, tp_model: tp.TargetPlatformModel):
     tp.OperationsSetToLayers(OPSET_SUB, [tf.subtract, Subtract])
     tp.OperationsSetToLayers(OPSET_MUL, [tf.math.multiply, Multiply])
     tp.OperationsSetToLayers(OPSET_DIV, [tf.math.divide, tf.math.truediv])
+    tp.OperationsSetToLayers(OPSET_MIN_MAX, [tf.math.minimum, tf.math.maximum, Minimum, Maximum])
     tp.OperationsSetToLayers(OPSET_PRELU, [PReLU])
     tp.OperationsSetToLayers(OPSET_SWISH, [tf.nn.swish, tp.LayerFilterParams(Activation, activation="swish")])
     tp.OperationsSetToLayers(OPSET_SIGMOID, [tf.nn.sigmoid, tp.LayerFilterParams(Activation, activation="sigmoid")])
     tp.OperationsSetToLayers(OPSET_TANH, [tf.nn.tanh, tp.LayerFilterParams(Activation, activation="tanh")])
+    tp.OperationsSetToLayers(OPSET_GELU, [tf.nn.gelu, tp.LayerFilterParams(Activation, activation="gelu")])
+    tp.OperationsSetToLayers(OPSET_HARDSIGMOID, [tf.keras.activations.hard_sigmoid,
+                                                 tp.LayerFilterParams(Activation, activation="hard_sigmoid")])
 
     return keras_tpc
diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_pytorch.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_pytorch.py
index 75ebf787e..0fa7bda97 100644
--- a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_pytorch.py
+++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_pytorch.py
@@ -17,11 +17,13 @@
 import torch
 from torch import add, sub, mul, div, divide, flatten, reshape, split, unsqueeze, dropout, sigmoid, tanh, \
-    chunk, unbind, topk, gather, equal, transpose, permute, argmax, squeeze, multiply, subtract
-from torch.nn import Conv2d, Linear, ConvTranspose2d, MaxPool2d
+    chunk, unbind, topk, gather, equal, transpose, permute, argmax, squeeze, multiply, subtract, minimum, \
+    maximum
+from torch.nn import Conv2d, Linear, ConvTranspose2d, MaxPool2d, BatchNorm2d
 from torch.nn import Dropout, Flatten, Hardtanh
-from torch.nn import ReLU, ReLU6, PReLU, SiLU, Sigmoid, Tanh, Hardswish, LeakyReLU
-from torch.nn.functional import relu, relu6, prelu, silu, hardtanh, hardswish, leaky_relu
+from torch.nn import ReLU, ReLU6, PReLU, SiLU, Sigmoid, Tanh, Hardswish, Hardsigmoid, LeakyReLU, GELU
+import torch.nn.functional as F
+from torch.nn.functional import relu, relu6, prelu, silu, hardtanh, hardswish, hardsigmoid, leaky_relu, gelu
 
 from model_compression_toolkit.defaultdict import DefaultDict
 from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, BIAS_ATTR, PYTORCH_KERNEL, \
@@ -32,7 +34,8 @@
 from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v4.tp_model import OPSET_NO_QUANTIZATION, \
     OPSET_QUANTIZATION_PRESERVING, OPSET_DIMENSION_MANIPULATION_OPS_WITH_WEIGHTS, OPSET_DIMENSION_MANIPULATION_OPS, \
     OPSET_MERGE_OPS, OPSET_CONV, OPSET_FULLY_CONNECTED, OPSET_ANY_RELU, OPSET_ADD, OPSET_SUB, OPSET_MUL, OPSET_DIV, \
-    OPSET_PRELU, OPSET_SWISH, OPSET_SIGMOID, OPSET_TANH
+    OPSET_PRELU, OPSET_SWISH, OPSET_SIGMOID, OPSET_TANH, OPSET_GELU, OPSET_BATCH_NORM, OPSET_MIN_MAX, OPSET_HARDSIGMOID, \
+    OPSET_HARDSWISH
 
 tp = mct.target_platform
 
@@ -95,6 +98,7 @@ def generate_pytorch_tpc(name: str, tp_model: tp.TargetPlatformModel):
                              attr_mapping=pytorch_linear_attr_mapping)
     tp.OperationsSetToLayers(OPSET_FULLY_CONNECTED, [Linear],
                              attr_mapping=pytorch_linear_attr_mapping)
+    tp.OperationsSetToLayers(OPSET_BATCH_NORM, [BatchNorm2d])
     tp.OperationsSetToLayers(OPSET_ANY_RELU, [torch.relu,
                                               ReLU,
                                               ReLU6,
@@ -109,9 +113,13 @@ def generate_pytorch_tpc(name: str, tp_model: tp.TargetPlatformModel):
     tp.OperationsSetToLayers(OPSET_SUB, [operator.sub, sub, subtract])
     tp.OperationsSetToLayers(OPSET_MUL, [operator.mul, mul, multiply])
     tp.OperationsSetToLayers(OPSET_DIV, [operator.truediv, div, divide])
+    tp.OperationsSetToLayers(OPSET_MIN_MAX, [minimum, maximum])
     tp.OperationsSetToLayers(OPSET_PRELU, [PReLU, prelu])
-    tp.OperationsSetToLayers(OPSET_SWISH, [SiLU, silu, Hardswish, hardswish])
-    tp.OperationsSetToLayers(OPSET_SIGMOID, [Sigmoid, sigmoid])
-    tp.OperationsSetToLayers(OPSET_TANH, [Tanh, tanh])
+    tp.OperationsSetToLayers(OPSET_SWISH, [SiLU, silu])
+    tp.OperationsSetToLayers(OPSET_SIGMOID, [Sigmoid, sigmoid, F.sigmoid])
+    tp.OperationsSetToLayers(OPSET_TANH, [Tanh, tanh, F.tanh])
+    tp.OperationsSetToLayers(OPSET_GELU, [GELU, gelu])
+    tp.OperationsSetToLayers(OPSET_HARDSIGMOID, [Hardsigmoid, hardsigmoid])
+    tp.OperationsSetToLayers(OPSET_HARDSWISH, [Hardswish, hardswish])
 
     return pytorch_tpc
diff --git a/tests/keras_tests/layer_tests/test_layers_runner.py b/tests/keras_tests/layer_tests/test_layers_runner.py
index 572b69d6a..bc72fc578 100644
--- a/tests/keras_tests/layer_tests/test_layers_runner.py
+++ b/tests/keras_tests/layer_tests/test_layers_runner.py
@@ -30,6 +30,7 @@ def test_activation(self):
         BaseKerasLayerTest(self,
                            [Activation('linear'),
                             Activation('hard_sigmoid'),
+                            tf.keras.activations.hard_sigmoid,
                             Activation('exponential')]).run_test()
 
     def test_softplus(self):
diff --git a/tests/pytorch_tests/model_tests/feature_models/activation_16bit_test.py b/tests/pytorch_tests/model_tests/feature_models/activation_16bit_test.py
index bf1f3bf74..cfc6fa2e8 100644
--- a/tests/pytorch_tests/model_tests/feature_models/activation_16bit_test.py
+++ b/tests/pytorch_tests/model_tests/feature_models/activation_16bit_test.py
@@ -19,6 +19,8 @@
 from model_compression_toolkit.constants import PYTORCH
 from model_compression_toolkit.core import MixedPrecisionQuantizationConfig
 from model_compression_toolkit.target_platform_capabilities.constants import IMX500_TP_MODEL
+from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v4.tp_model import \
+    OPSET_MUL, OPSET_GELU, OPSET_TANH
 from model_compression_toolkit.core.pytorch.utils import get_working_device
 from tests.pytorch_tests.model_tests.base_pytorch_feature_test import BasePytorchFeatureNetworkTest
@@ -28,10 +30,14 @@
 
 
 class Activation16BitNet(torch.nn.Module):
-    def __init__(self, use_concat=True):
+    def __init__(self, use_concat=True, enable_head=True):
         super().__init__()
         self.use_concat = use_concat
+        self.enable_head = enable_head
         self.conv = torch.nn.Conv2d(3, 3, 1)
+        if enable_head:
+            self.conv_a = torch.nn.Conv2d(3, 3, 1)
+            self.conv_b = torch.nn.Conv2d(3, 3, 1)
         self.register_buffer('add_const', torch.rand((3, 1, 1)))
         self.register_buffer('sub_const', torch.rand((3, 1, 1)))
         self.register_buffer('div_const', 2*torch.ones((3, 1, 1)))
@@ -47,20 +53,32 @@ def forward(self, x):
         x = torch.reshape(x, (-1, 3, 8*(1+int(self.use_concat)), 8))
         x = self.conv(x)
         x = torch.divide(x, self.div_const)
+
+        if self.enable_head:
+            x = torch.cat([torch.nn.functional.gelu(self.conv_a(x)),
+                           torch.nn.functional.tanh(self.conv_b(x))], dim=1)
+
         return x
 
 
+def set_16bit_as_default(tpc, required_op_set, required_ops_list):
+    op_set = get_op_set(required_op_set, tpc.tp_model.operator_set)
+    op_set.qc_options.base_config = [l for l in op_set.qc_options.quantization_config_list if l.activation_n_bits == 16][0]
+    for op in required_ops_list:
+        tpc.layer2qco[op].base_config = [l for l in tpc.layer2qco[op].quantization_config_list if l.activation_n_bits == 16][0]
+
+
 class Activation16BitTest(BasePytorchFeatureNetworkTest):
 
     def get_tpc(self):
         tpc = mct.get_target_platform_capabilities(PYTORCH, IMX500_TP_MODEL, 'v4')
-        mul_op_set = get_op_set('Mul', tpc.tp_model.operator_set)
-        mul_op_set.qc_options.base_config = [l for l in mul_op_set.qc_options.quantization_config_list if l.activation_n_bits == 16][0]
-        tpc.layer2qco[torch.mul].base_config = mul_op_set.qc_options.base_config
-        tpc.layer2qco[mul].base_config = mul_op_set.qc_options.base_config
+        set_16bit_as_default(tpc, OPSET_MUL, [torch.mul, mul])
+        set_16bit_as_default(tpc, OPSET_GELU, [torch.nn.GELU, torch.nn.functional.gelu])
+        set_16bit_as_default(tpc, OPSET_TANH, [torch.nn.Tanh, torch.nn.functional.tanh, torch.tanh])
         return tpc
 
     def create_networks(self):
+        # Activation16BitNet()(torch.from_numpy(self.generate_inputs()[0]).type(torch.float32))
         return Activation16BitNet()
 
     def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
@@ -77,6 +95,10 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=
                                   "1st mul activation should be forced by TPC to be signed, even though activations as all positive.")
         self.unit_test.assertTrue(mul2_act_quant.activation_holder_quantizer.num_bits == 8,
                                   "2nd mul activation bits should be 8 bits because of following div node.")
+        self.unit_test.assertTrue(quantized_model.gelu_activation_holder_quantizer.activation_holder_quantizer.num_bits == 16,
+                                  "gelu activation bits should be 16 bits because of following concat node.")
+        self.unit_test.assertTrue(quantized_model.tanh_activation_holder_quantizer.activation_holder_quantizer.num_bits == 16,
+                                  "tanh activation bits should be 16 bits because of following concat node.")
 
 
 class Activation16BitMixedPrecisionTest(Activation16BitTest):
@@ -103,7 +125,7 @@ def get_resource_utilization(self):
         return mct.core.ResourceUtilization(activation_memory=200)
 
     def create_networks(self):
-        return Activation16BitNet(use_concat=False)
+        return Activation16BitNet(use_concat=False, enable_head=False)
 
     def get_mixed_precision_config(self):
         return MixedPrecisionQuantizationConfig()