Skip to content

Commit

Permalink
[Fix] Use HistoryBuffer
Browse files Browse the repository at this point in the history
  • Loading branch information
fanqiNO1 committed Feb 23, 2024
1 parent 0d5c490 commit 3295976
Showing 1 changed file with 23 additions and 12 deletions.
35 changes: 23 additions & 12 deletions mmengine/runner/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from torch.utils.data import DataLoader

from mmengine.evaluator import Evaluator
from mmengine.logging import print_log
from mmengine.logging import HistoryBuffer, print_log
from mmengine.registry import LOOPS
from mmengine.structures import BaseDataElement
from mmengine.utils import is_list_of
Expand Down Expand Up @@ -363,7 +363,7 @@ def __init__(self,
logger='current',
level=logging.WARNING)
self.fp16 = fp16
self.val_loss: Dict[str, list] = dict()
self.val_loss: Dict[str, HistoryBuffer] = dict()

def run(self) -> dict:
"""Launch validation."""
Expand All @@ -378,7 +378,7 @@ def run(self) -> dict:
# get val loss and save to metrics
val_loss = 0
for loss_name, loss_value in self.val_loss.items():
avg_loss = sum(loss_value) / len(loss_value)
avg_loss = loss_value.mean()
metrics[loss_name] = avg_loss
if 'loss' in loss_name:
val_loss += avg_loss # type: ignore
Expand Down Expand Up @@ -408,13 +408,19 @@ def run_iter(self, idx, data_batch: Sequence[dict]):
else:
loss = dict()
# get val loss and avoid breaking change
# similar to MessageHub
for loss_name, loss_value in loss.items():
if loss_name not in self.val_loss:
self.val_loss[loss_name] = []
self.val_loss[loss_name] = HistoryBuffer()
if isinstance(loss_value, torch.Tensor):
self.val_loss[loss_name].append(loss_value.item())
loss_value = loss_value.mean().item()
elif is_list_of(loss_value, torch.Tensor):
self.val_loss[loss_name].extend([v.item() for v in loss_value])
loss_value = sum([v.mean()
for v in loss_value]).item() # type: ignore
else:
raise TypeError(
f'{loss_name} is not a tensor or list of tensors')
self.val_loss[loss_name].update(loss_value)

self.evaluator.process(data_samples=outputs, data_batch=data_batch)
self.runner.call_hook(
Expand Down Expand Up @@ -460,7 +466,7 @@ def __init__(self,
logger='current',
level=logging.WARNING)
self.fp16 = fp16
self.test_loss: Dict[str, list] = dict()
self.test_loss: Dict[str, HistoryBuffer] = dict()

def run(self) -> dict:
"""Launch test."""
Expand All @@ -475,7 +481,7 @@ def run(self) -> dict:
# get test loss and save to metrics
test_loss = 0
for loss_name, loss_value in self.test_loss.items():
avg_loss = sum(loss_value) / len(loss_value)
avg_loss = loss_value.mean()
metrics[loss_name] = avg_loss
if 'loss' in loss_name:
test_loss += avg_loss # type: ignore
Expand Down Expand Up @@ -504,14 +510,19 @@ def run_iter(self, idx, data_batch: Sequence[dict]) -> None:
else:
loss = dict()
# get test loss and avoid breaking change
# similar to MessageHub
for loss_name, loss_value in loss.items():
if loss_name not in self.test_loss:
self.test_loss[loss_name] = []
self.test_loss[loss_name] = HistoryBuffer()
if isinstance(loss_value, torch.Tensor):
self.test_loss[loss_name].append(loss_value.item())
loss_value = loss_value.mean().item()
elif is_list_of(loss_value, torch.Tensor):
self.test_loss[loss_name].extend(
[v.item() for v in loss_value])
loss_value = sum([v.mean()
for v in loss_value]).item() # type: ignore
else:
raise TypeError(
f'{loss_name} is not a tensor or list of tensors')
self.test_loss[loss_name].update(loss_value)

self.evaluator.process(data_samples=outputs, data_batch=data_batch)
self.runner.call_hook(
Expand Down

0 comments on commit 3295976

Please sign in to comment.