Skip to content

Commit

Permalink
fix type hints in sla loss function
Browse files Browse the repository at this point in the history
  • Loading branch information
reuvenp committed Dec 4, 2024
1 parent b417db6 commit 227e271
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions model_compression_toolkit/gptq/keras/gptq_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# ==============================================================================

import tensorflow as tf
from typing import List
from typing import List, Tuple


def mse_loss(y: tf.Tensor, x: tf.Tensor, normalized: bool = True) -> tf.Tensor:
Expand Down Expand Up @@ -72,15 +72,15 @@ def sample_layer_attention_loss(y_list: List[tf.Tensor],
flp_w_list,
act_bn_mean,
act_bn_std,
loss_weights: tf.Tensor) -> tf.Tensor:
loss_weights: Tuple[tf.Tensor]) -> tf.Tensor:
"""
Compute Sample Layer Attention loss between two lists of tensors using TensorFlow.
Args:
y_list: First list of tensors.
x_list: Second list of tensors.
fxp_w_list, flp_w_list, act_bn_mean, act_bn_std: unused (needed to comply with the interface).
loss_weights: layer-sample weights tensor of shape (batch X layers)
        loss_weights: layer-sample attention scores (a tuple of the same length as the number of layers, where each element is a tf.Tensor vector whose length equals the number of samples).
Returns:
Sample Layer Attention loss (a scalar).
Expand Down

0 comments on commit 227e271

Please sign in to comment.