Skip to content

Commit

Permalink
UPDATE ConfusionMatrix support multimodal predicts
Browse files Browse the repository at this point in the history
  • Loading branch information
HoBeom committed Apr 4, 2024
1 parent 4d6c934 commit 789c12d
Showing 1 changed file with 57 additions and 18 deletions.
75 changes: 57 additions & 18 deletions mmaction/evaluation/metrics/acc_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,28 +248,67 @@ def process(self, data_batch, data_samples: Sequence[dict]) -> None:
for data_sample in data_samples:
pred_scores = data_sample.get('pred_score')
gt_label = data_sample['gt_label']
if pred_scores is not None:
pred_label = pred_scores.argmax(dim=0, keepdim=True)
self.num_classes = pred_scores.size(0)

# Ad-hoc for RGBPoseConv3D
if isinstance(pred_scores, dict):
pred = {}
for item_name, score in pred_scores.items():
pred[item_name] = score.to(self.collect_device)
self.num_classes = score.size(0)
self.results.append({'pred': pred, 'gt_label': gt_label})
else:
pred_label = data_sample['pred_label']
if pred_scores is not None:
pred_label = pred_scores.argmax(dim=0, keepdim=True)
self.num_classes = pred_scores.size(0)
else:
pred_label = data_sample['pred_label']

self.results.append({
'pred_label': pred_label,
'gt_label': gt_label
})
self.results.append({
'pred_label': pred_label,
'gt_label': gt_label
})

def compute_metrics(self, results: list) -> dict:
pred_labels = []
gt_labels = []
for result in results:
pred_labels.append(result['pred_label'])
gt_labels.append(result['gt_label'])
confusion_matrix = ConfusionMatrix.calculate(
torch.cat(pred_labels),
torch.cat(gt_labels),
num_classes=self.num_classes)
return {'result': confusion_matrix}
# Ad-hoc for RGBPoseConv3D
if 'pred' in results[0]:
out = {}
gt_labels = torch.cat([x['gt_label'] for x in results])

for item_name in results[0]['pred'].keys():
pred_labels = [x['pred'][item_name].argmax() for x in results]
pred_labels = torch.tensor(pred_labels)
out[item_name] = ConfusionMatrix.calculate(
pred_labels, gt_labels, num_classes=self.num_classes)

if len(results[0]['pred']) == 2 and \
'rgb' in results[0]['pred'] and \
'pose' in results[0]['pred']:

rgb = [x['pred']['rgb'] for x in results]
pose = [x['pred']['pose'] for x in results]

preds = {
'1:1': get_weighted_score([rgb, pose], [1, 1]),
'2:1': get_weighted_score([rgb, pose], [2, 1]),
'1:2': get_weighted_score([rgb, pose], [1, 2])
}
for item_name, pred in preds.items():
pred = torch.tensor(pred)
pred_labels = pred.argmax(dim=1)
out[f'RGBPose_{item_name}'] = ConfusionMatrix.calculate(
pred_labels, gt_labels, num_classes=self.num_classes)
return out
else:
pred_labels = []
gt_labels = []
for result in results:
pred_labels.append(result['pred_label'])
gt_labels.append(result['gt_label'])
confusion_matrix = ConfusionMatrix.calculate(
torch.cat(pred_labels),
torch.cat(gt_labels),
num_classes=self.num_classes)
return {'result': confusion_matrix}

@staticmethod
def calculate(pred, target, num_classes=None) -> dict:
Expand Down

0 comments on commit 789c12d

Please sign in to comment.