From 227e271fa175a1a04a4b587e6af6434ba8f85339 Mon Sep 17 00:00:00 2001
From: reuvenp
Date: Wed, 4 Dec 2024 12:07:04 +0200
Subject: [PATCH] fix type hints in sla loss function

---
 model_compression_toolkit/gptq/keras/gptq_loss.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/model_compression_toolkit/gptq/keras/gptq_loss.py b/model_compression_toolkit/gptq/keras/gptq_loss.py
index 39bac7b6c..8baa43e1f 100644
--- a/model_compression_toolkit/gptq/keras/gptq_loss.py
+++ b/model_compression_toolkit/gptq/keras/gptq_loss.py
@@ -14,7 +14,7 @@
 # ==============================================================================
 
 import tensorflow as tf
-from typing import List
+from typing import List, Tuple
 
 
 def mse_loss(y: tf.Tensor, x: tf.Tensor, normalized: bool = True) -> tf.Tensor:
@@ -72,7 +72,7 @@ def sample_layer_attention_loss(y_list: List[tf.Tensor],
                                 flp_w_list,
                                 act_bn_mean,
                                 act_bn_std,
-                                loss_weights: tf.Tensor) -> tf.Tensor:
+                                loss_weights: Tuple[tf.Tensor]) -> tf.Tensor:
     """
     Compute Sample Layer Attention loss between two lists of tensors using TensorFlow.
 
@@ -80,7 +80,7 @@ def sample_layer_attention_loss(y_list: List[tf.Tensor],
         y_list: First list of tensors.
         x_list: Second list of tensors.
         fxp_w_list, flp_w_list, act_bn_mean, act_bn_std: unused (needed to comply with the interface).
-        loss_weights: layer-sample weights tensor of shape (batch X layers)
+        loss_weights: layer-sample attention scores (a tuple of the same length as the number of layers, where each element is a tf.Tensor vector whose length is the number of samples).
 
     Returns:
         Sample Layer Attention loss (a scalar).
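
For reviewers, a minimal sketch of the shape contract the new hint encodes: loss_weights carries one per-sample attention vector per compared layer, rather than one (batch x layers) tensor. The toy shapes and the weighted-MSE reduction below are illustrative assumptions, not the function's actual internals; all names other than tf and the typing imports are hypothetical.

import tensorflow as tf
from typing import List, Tuple

# Hypothetical toy shapes: 3 layers compared, batch of 4 samples.
n_layers, batch = 3, 4
y_list: List[tf.Tensor] = [tf.random.normal((batch, 8)) for _ in range(n_layers)]
x_list: List[tf.Tensor] = [tf.random.normal((batch, 8)) for _ in range(n_layers)]

# Under the corrected hint, loss_weights is a tuple with one per-sample
# attention vector per layer, each of shape (batch,) -- not a single
# (batch x layers) tensor as the old docstring suggested.
loss_weights: Tuple[tf.Tensor, ...] = tuple(
    tf.random.uniform((batch,)) for _ in range(n_layers))

# A plausible reduction consistent with the docstring (an assumption, not
# the library's implementation): per-sample MSE per layer, scaled by that
# layer's attention scores, averaged over samples and layers.
per_layer = [tf.reduce_mean(w * tf.reduce_mean(tf.square(y - x), axis=-1))
             for y, x, w in zip(y_list, x_list, loss_weights)]
loss = tf.add_n(per_layer) / n_layers  # a scalar, as the docstring promises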