diff --git a/mmaction/evaluation/metrics/acc_metric.py b/mmaction/evaluation/metrics/acc_metric.py index 04985e5938..389a37af98 100644 --- a/mmaction/evaluation/metrics/acc_metric.py +++ b/mmaction/evaluation/metrics/acc_metric.py @@ -248,28 +248,67 @@ def process(self, data_batch, data_samples: Sequence[dict]) -> None: for data_sample in data_samples: pred_scores = data_sample.get('pred_score') gt_label = data_sample['gt_label'] - if pred_scores is not None: - pred_label = pred_scores.argmax(dim=0, keepdim=True) - self.num_classes = pred_scores.size(0) + + # Ad-hoc for RGBPoseConv3D + if isinstance(pred_scores, dict): + pred = {} + for item_name, score in pred_scores.items(): + pred[item_name] = score.to(self.collect_device) + self.num_classes = score.size(0) + self.results.append({'pred': pred, 'gt_label': gt_label}) else: - pred_label = data_sample['pred_label'] + if pred_scores is not None: + pred_label = pred_scores.argmax(dim=0, keepdim=True) + self.num_classes = pred_scores.size(0) + else: + pred_label = data_sample['pred_label'] - self.results.append({ - 'pred_label': pred_label, - 'gt_label': gt_label - }) + self.results.append({ + 'pred_label': pred_label, + 'gt_label': gt_label + }) def compute_metrics(self, results: list) -> dict: - pred_labels = [] - gt_labels = [] - for result in results: - pred_labels.append(result['pred_label']) - gt_labels.append(result['gt_label']) - confusion_matrix = ConfusionMatrix.calculate( - torch.cat(pred_labels), - torch.cat(gt_labels), - num_classes=self.num_classes) - return {'result': confusion_matrix} + # Ad-hoc for RGBPoseConv3D + if 'pred' in results[0]: + out = {} + gt_labels = torch.cat([x['gt_label'] for x in results]) + + for item_name in results[0]['pred'].keys(): + pred_labels = [x['pred'][item_name].argmax() for x in results] + pred_labels = torch.tensor(pred_labels) + out[item_name] = ConfusionMatrix.calculate( + pred_labels, gt_labels, num_classes=self.num_classes) + + if len(results[0]['pred']) == 2 and \ + 'rgb' in results[0]['pred'] and \ + 'pose' in results[0]['pred']: + + rgb = [x['pred']['rgb'] for x in results] + pose = [x['pred']['pose'] for x in results] + + preds = { + '1:1': get_weighted_score([rgb, pose], [1, 1]), + '2:1': get_weighted_score([rgb, pose], [2, 1]), + '1:2': get_weighted_score([rgb, pose], [1, 2]) + } + for item_name, pred in preds.items(): + pred = torch.tensor(pred) + pred_labels = pred.argmax(dim=1) + out[f'RGBPose_{item_name}'] = ConfusionMatrix.calculate( + pred_labels, gt_labels, num_classes=self.num_classes) + return out + else: + pred_labels = [] + gt_labels = [] + for result in results: + pred_labels.append(result['pred_label']) + gt_labels.append(result['gt_label']) + confusion_matrix = ConfusionMatrix.calculate( + torch.cat(pred_labels), + torch.cat(gt_labels), + num_classes=self.num_classes) + return {'result': confusion_matrix} @staticmethod def calculate(pred, target, num_classes=None) -> dict: