Commit

[Fix] Delete frozen parameters when using paramwise_cfg (open-mmlab…
LZHgrla authored Apr 22, 2024
1 parent 9ecced8 commit acbc5e4
Showing 2 changed files with 15 additions and 20 deletions.
mmengine/optim/optimizer/default_constructor.py: 5 changes (4 additions, 1 deletion)
@@ -213,7 +213,10 @@ def add_params(self,
                     level=logging.WARNING)
                 continue
             if not param.requires_grad:
-                params.append(param_group)
+                print_log((f'{prefix}.{name} is skipped since its '
+                           f'requires_grad={param.requires_grad}'),
+                          logger='current',
+                          level=logging.WARNING)
                 continue

             # if the parameter match one of the custom keys, ignore other rules
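Note (not part of the commit): with this change, add_params treats a frozen parameter the same way it already treats a duplicate one, logging a warning and skipping it instead of appending a param group, so frozen parameters never reach the optimizer when paramwise_cfg is used. Below is a minimal sketch of how the new behavior can be observed from user code; ToyModel and the config values are illustrative and not taken from the repository:

import torch.nn as nn

from mmengine.optim import DefaultOptimWrapperConstructor


class ToyModel(nn.Module):
    """Illustrative two-layer model; any module hierarchy behaves the same."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, 3)
        self.conv2 = nn.Conv2d(8, 8, 3)


model = ToyModel()
model.conv1.requires_grad_(False)  # freeze one submodule

optim_wrapper_cfg = dict(
    type='OptimWrapper',
    optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=1e-4))
paramwise_cfg = dict(bias_lr_mult=2, norm_decay_mult=0)

optim_constructor = DefaultOptimWrapperConstructor(optim_wrapper_cfg,
                                                   paramwise_cfg)
optim_wrapper = optim_constructor(model)

# After this fix, frozen parameters are not added to any param group;
# previously they were appended and the optimizer still tracked them.
tracked = {id(p) for group in optim_wrapper.optimizer.param_groups
           for p in group['params']}
assert id(model.conv1.weight) not in tracked
assert id(model.conv2.weight) in tracked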
tests/test_optim/test_optimizer/test_optimizer.py: 30 changes (11 additions, 19 deletions)
@@ -549,7 +549,8 @@ def test_default_optimizer_constructor_with_empty_paramwise_cfg(self):
                 weight_decay=self.base_wd,
                 momentum=self.momentum))
         paramwise_cfg = dict()
-        optim_constructor = DefaultOptimWrapperConstructor(optim_wrapper_cfg)
+        optim_constructor = DefaultOptimWrapperConstructor(
+            optim_wrapper_cfg, paramwise_cfg)
         optim_wrapper = optim_constructor(model)
         self._check_default_optimizer(optim_wrapper.optimizer, model)

@@ -591,23 +592,16 @@ def test_default_optimizer_constructor_no_grad(self):
             dwconv_decay_mult=0.1,
             dcn_offset_lr_mult=0.1)

-        for param in self.model.parameters():
-            param.requires_grad = False
+        self.model.conv1.requires_grad_(False)
         optim_constructor = DefaultOptimWrapperConstructor(
             optim_wrapper_cfg, paramwise_cfg)
         optim_wrapper = optim_constructor(self.model)
-        optimizer = optim_wrapper.optimizer
-        param_groups = optimizer.param_groups
-        assert isinstance(optim_wrapper.optimizer, torch.optim.SGD)
-        assert optimizer.defaults['lr'] == self.base_lr
-        assert optimizer.defaults['momentum'] == self.momentum
-        assert optimizer.defaults['weight_decay'] == self.base_wd
-        for i, (name, param) in enumerate(self.model.named_parameters()):
-            param_group = param_groups[i]
-            assert torch.equal(param_group['params'][0], param)
-            assert param_group['momentum'] == self.momentum
-            assert param_group['lr'] == self.base_lr
-            assert param_group['weight_decay'] == self.base_wd
+
+        all_params = []
+        for pg in optim_wrapper.param_groups:
+            all_params.extend(map(id, pg['params']))
+        self.assertNotIn(id(self.model.conv1.weight), all_params)
+        self.assertIn(id(self.model.conv2.weight), all_params)

     def test_default_optimizer_constructor_bypass_duplicate(self):
         # paramwise_cfg with bypass_duplicate option
@@ -663,10 +657,8 @@ def test_default_optimizer_constructor_bypass_duplicate(self):
         optim_wrapper = optim_constructor(model)
         model_parameters = list(model.parameters())
         num_params = 14 if MMCV_FULL_AVAILABLE else 11
-        assert len(optim_wrapper.optimizer.param_groups) == len(
-            model_parameters) == num_params
-        self._check_sgd_optimizer(optim_wrapper.optimizer, model,
-                                  **paramwise_cfg)
+        assert len(optim_wrapper.optimizer.param_groups
+                   ) == len(model_parameters) - 1 == num_params - 1

     def test_default_optimizer_constructor_custom_key(self):
         # test DefaultOptimWrapperConstructor with custom_keys and
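A note on the rewritten no_grad assertions above (an observation, not text from the commit): the test gathers the Python id of every tensor in the param groups and checks membership on those ids, because testing a tensor against a list directly triggers elementwise equality and makes the membership check ambiguous. A small standalone sketch of the same pattern, with illustrative names:

import torch

# Two parameters the optimizer tracks and one frozen parameter it should not.
tracked_params = [torch.nn.Parameter(torch.randn(4)) for _ in range(2)]
frozen_param = torch.nn.Parameter(torch.randn(4), requires_grad=False)

# Membership by object identity: map every tracked tensor to its id().
tracked_ids = set(map(id, tracked_params))
assert id(tracked_params[0]) in tracked_ids
assert id(frozen_param) not in tracked_ids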
