From ee34d952e792fd1adea2c2e397b29faff68eaec9 Mon Sep 17 00:00:00 2001 From: Jintao Lin <528557675@qq.com> Date: Sun, 7 Mar 2021 14:38:29 +0800 Subject: [PATCH] [Bug] Add missing 'loss_aux' and related unittest (#683) * fix aux_loss bug * add related unittest --- mmaction/models/recognizers/recognizer2d.py | 8 +++----- tests/test_models/test_recognizers/test_recognizer2d.py | 1 + 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/mmaction/models/recognizers/recognizer2d.py b/mmaction/models/recognizers/recognizer2d.py index 76208921f2..65f27ddb0a 100644 --- a/mmaction/models/recognizers/recognizer2d.py +++ b/mmaction/models/recognizers/recognizer2d.py @@ -21,9 +21,10 @@ def forward_train(self, imgs, labels, **kwargs): each.shape[1:]).transpose(1, 2).contiguous() for each in x ] - x, _ = self.neck(x, labels.squeeze()) + x, loss_aux = self.neck(x, labels.squeeze()) x = x.squeeze(2) num_segs = 1 + losses.update(loss_aux) cls_score = self.cls_head(x, num_segs) gt_labels = labels.squeeze() @@ -40,8 +41,6 @@ def _do_test(self, imgs): imgs = imgs.reshape((-1, ) + imgs.shape[2:]) num_segs = imgs.shape[0] // batches - losses = dict() - x = self.extract_feat(imgs) if hasattr(self, 'neck'): x = [ @@ -49,9 +48,8 @@ def _do_test(self, imgs): each.shape[1:]).transpose(1, 2).contiguous() for each in x ] - x, loss_aux = self.neck(x) + x, _ = self.neck(x) x = x.squeeze(2) - losses.update(loss_aux) num_segs = 1 # When using `TSNHead` or `TPNHead`, shape is [batch_size, num_classes] diff --git a/tests/test_models/test_recognizers/test_recognizer2d.py b/tests/test_models/test_recognizers/test_recognizer2d.py index 95d23d9009..f73e778263 100644 --- a/tests/test_models/test_recognizers/test_recognizer2d.py +++ b/tests/test_models/test_recognizers/test_recognizer2d.py @@ -86,6 +86,7 @@ def test_tpn(): losses = recognizer(imgs, gt_labels) assert isinstance(losses, dict) + assert 'loss_aux' in losses and 'loss_cls' in losses # Test forward test with torch.no_grad():