From acbc5e46dc56861ec01606befb9a28b7ecf9e5aa Mon Sep 17 00:00:00 2001
From: Zhihao Lin <36994684+LZHgrla@users.noreply.github.com>
Date: Mon, 22 Apr 2024 19:54:48 +0800
Subject: [PATCH] [Fix] Delete frozen parameters when using `paramwise_cfg`
 (#1441)

---
 .../optim/optimizer/default_constructor.py |  5 +++-
 .../test_optimizer/test_optimizer.py       | 30 +++++++------------
 2 files changed, 15 insertions(+), 20 deletions(-)

diff --git a/mmengine/optim/optimizer/default_constructor.py b/mmengine/optim/optimizer/default_constructor.py
index 95233d86bb..ec223a7967 100644
--- a/mmengine/optim/optimizer/default_constructor.py
+++ b/mmengine/optim/optimizer/default_constructor.py
@@ -213,7 +213,10 @@ def add_params(self,
                     level=logging.WARNING)
                 continue
             if not param.requires_grad:
-                params.append(param_group)
+                print_log((f'{prefix}.{name} is skipped since its '
+                           f'requires_grad={param.requires_grad}'),
+                          logger='current',
+                          level=logging.WARNING)
                 continue
 
             # if the parameter match one of the custom keys, ignore other rules
diff --git a/tests/test_optim/test_optimizer/test_optimizer.py b/tests/test_optim/test_optimizer/test_optimizer.py
index 0cf60fcb83..113aacd6c8 100644
--- a/tests/test_optim/test_optimizer/test_optimizer.py
+++ b/tests/test_optim/test_optimizer/test_optimizer.py
@@ -549,7 +549,8 @@ def test_default_optimizer_constructor_with_empty_paramwise_cfg(self):
                 weight_decay=self.base_wd,
                 momentum=self.momentum))
         paramwise_cfg = dict()
-        optim_constructor = DefaultOptimWrapperConstructor(optim_wrapper_cfg)
+        optim_constructor = DefaultOptimWrapperConstructor(
+            optim_wrapper_cfg, paramwise_cfg)
         optim_wrapper = optim_constructor(model)
         self._check_default_optimizer(optim_wrapper.optimizer, model)
 
@@ -591,23 +592,16 @@ def test_default_optimizer_constructor_no_grad(self):
             dwconv_decay_mult=0.1,
             dcn_offset_lr_mult=0.1)
 
-        for param in self.model.parameters():
-            param.requires_grad = False
+        self.model.conv1.requires_grad_(False)
         optim_constructor = DefaultOptimWrapperConstructor(
             optim_wrapper_cfg, paramwise_cfg)
         optim_wrapper = optim_constructor(self.model)
-        optimizer = optim_wrapper.optimizer
-        param_groups = optimizer.param_groups
-        assert isinstance(optim_wrapper.optimizer, torch.optim.SGD)
-        assert optimizer.defaults['lr'] == self.base_lr
-        assert optimizer.defaults['momentum'] == self.momentum
-        assert optimizer.defaults['weight_decay'] == self.base_wd
-        for i, (name, param) in enumerate(self.model.named_parameters()):
-            param_group = param_groups[i]
-            assert torch.equal(param_group['params'][0], param)
-            assert param_group['momentum'] == self.momentum
-            assert param_group['lr'] == self.base_lr
-            assert param_group['weight_decay'] == self.base_wd
+
+        all_params = []
+        for pg in optim_wrapper.param_groups:
+            all_params.extend(map(id, pg['params']))
+        self.assertNotIn(id(self.model.conv1.weight), all_params)
+        self.assertIn(id(self.model.conv2.weight), all_params)
 
     def test_default_optimizer_constructor_bypass_duplicate(self):
         # paramwise_cfg with bypass_duplicate option
@@ -663,10 +657,8 @@ def test_default_optimizer_constructor_bypass_duplicate(self):
         optim_wrapper = optim_constructor(model)
         model_parameters = list(model.parameters())
         num_params = 14 if MMCV_FULL_AVAILABLE else 11
-        assert len(optim_wrapper.optimizer.param_groups) == len(
-            model_parameters) == num_params
-        self._check_sgd_optimizer(optim_wrapper.optimizer, model,
-                                  **paramwise_cfg)
+        assert len(optim_wrapper.optimizer.param_groups
+                   ) == len(model_parameters) - 1 == num_params - 1
 
     def test_default_optimizer_constructor_custom_key(self):
         # test DefaultOptimWrapperConstructor with custom_keys and
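
For context, the sketch below (not part of the patch) illustrates the behaviour this change produces: with a paramwise_cfg in effect, parameters whose requires_grad is False are no longer added to the optimizer's param groups. The model layout, config values, and paramwise keys here are illustrative assumptions; only DefaultOptimWrapperConstructor, the optim_wrapper_cfg/paramwise_cfg dicts, and the param_groups check mirror the usage visible in the test diff above.

    # Minimal sketch of the fixed behaviour (assumes mmengine >= the patched version).
    import torch.nn as nn
    from mmengine.optim import DefaultOptimWrapperConstructor

    # Hypothetical two-layer model; freeze the first conv.
    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Conv2d(8, 8, 3))
    model[0].requires_grad_(False)

    optim_wrapper_cfg = dict(
        type='OptimWrapper',
        optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=1e-4))
    paramwise_cfg = dict(bias_lr_mult=2.0)  # any paramwise_cfg triggers add_params

    constructor = DefaultOptimWrapperConstructor(optim_wrapper_cfg, paramwise_cfg)
    optim_wrapper = constructor(model)

    # Frozen parameters are skipped (with a warning log) instead of being
    # appended to the param groups.
    tracked = {id(p) for g in optim_wrapper.param_groups for p in g['params']}
    assert id(model[0].weight) not in tracked
    assert id(model[1].weight) in tracked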