From 32959761042bb6ee5f1c7fa6256a06cb2e51b5d4 Mon Sep 17 00:00:00 2001 From: fanqiNO1 <1848839264@qq.com> Date: Fri, 23 Feb 2024 15:07:45 +0800 Subject: [PATCH] [Fix] Use HistoryBuffer --- mmengine/runner/loops.py | 35 +++++++++++++++++++++++------------ 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/mmengine/runner/loops.py b/mmengine/runner/loops.py index 270819d1bd..17beaf8d95 100644 --- a/mmengine/runner/loops.py +++ b/mmengine/runner/loops.py @@ -8,7 +8,7 @@ from torch.utils.data import DataLoader from mmengine.evaluator import Evaluator -from mmengine.logging import print_log +from mmengine.logging import HistoryBuffer, print_log from mmengine.registry import LOOPS from mmengine.structures import BaseDataElement from mmengine.utils import is_list_of @@ -363,7 +363,7 @@ def __init__(self, logger='current', level=logging.WARNING) self.fp16 = fp16 - self.val_loss: Dict[str, list] = dict() + self.val_loss: Dict[str, HistoryBuffer] = dict() def run(self) -> dict: """Launch validation.""" @@ -378,7 +378,7 @@ def run(self) -> dict: # get val loss and save to metrics val_loss = 0 for loss_name, loss_value in self.val_loss.items(): - avg_loss = sum(loss_value) / len(loss_value) + avg_loss = loss_value.mean() metrics[loss_name] = avg_loss if 'loss' in loss_name: val_loss += avg_loss # type: ignore @@ -408,13 +408,19 @@ def run_iter(self, idx, data_batch: Sequence[dict]): else: loss = dict() # get val loss and avoid breaking change + # similar to MessageHub for loss_name, loss_value in loss.items(): if loss_name not in self.val_loss: - self.val_loss[loss_name] = [] + self.val_loss[loss_name] = HistoryBuffer() if isinstance(loss_value, torch.Tensor): - self.val_loss[loss_name].append(loss_value.item()) + loss_value = loss_value.mean().item() elif is_list_of(loss_value, torch.Tensor): - self.val_loss[loss_name].extend([v.item() for v in loss_value]) + loss_value = sum([v.mean() + for v in loss_value]).item() # type: ignore + else: + raise TypeError( + f'{loss_name} is not a tensor or list of tensors') + self.val_loss[loss_name].update(loss_value) self.evaluator.process(data_samples=outputs, data_batch=data_batch) self.runner.call_hook( @@ -460,7 +466,7 @@ def __init__(self, logger='current', level=logging.WARNING) self.fp16 = fp16 - self.test_loss: Dict[str, list] = dict() + self.test_loss: Dict[str, HistoryBuffer] = dict() def run(self) -> dict: """Launch test.""" @@ -475,7 +481,7 @@ def run(self) -> dict: # get test loss and save to metrics test_loss = 0 for loss_name, loss_value in self.test_loss.items(): - avg_loss = sum(loss_value) / len(loss_value) + avg_loss = loss_value.mean() metrics[loss_name] = avg_loss if 'loss' in loss_name: test_loss += avg_loss # type: ignore @@ -504,14 +510,19 @@ def run_iter(self, idx, data_batch: Sequence[dict]) -> None: else: loss = dict() # get val loss and avoid breaking change + # similar to MessageHub for loss_name, loss_value in loss.items(): if loss_name not in self.test_loss: - self.test_loss[loss_name] = [] + self.test_loss[loss_name] = HistoryBuffer() if isinstance(loss_value, torch.Tensor): - self.test_loss[loss_name].append(loss_value.item()) + loss_value = loss_value.mean().item() elif is_list_of(loss_value, torch.Tensor): - self.test_loss[loss_name].extend( - [v.item() for v in loss_value]) + loss_value = sum([v.mean() + for v in loss_value]).item() # type: ignore + else: + raise TypeError( + f'{loss_name} is not a tensor or list of tensors') + self.test_loss[loss_name].update(loss_value) self.evaluator.process(data_samples=outputs, data_batch=data_batch) self.runner.call_hook(