From 815539b4de657f3f0253cf612470aa2b34ea8607 Mon Sep 17 00:00:00 2001
From: Lingrui Gu
Date: Fri, 7 Apr 2023 18:18:10 +0800
Subject: [PATCH 1/8] add clip backbone

---
 ...v5_s_test_clip_backbone_freeze_backbone.py |  6 ++
 .../yolov5_s_test_clip_backbone_freeze_bn.py  | 46 ++++++++++++
 mmyolo/engine/hooks/__init__.py               |  5 +-
 .../hooks/yolov5_param_scheduler_hook.py      | 13 ++++
 mmyolo/engine/optimizers/__init__.py          |  8 +-
 .../optimizers/yolov5_optim_constructor.py    | 74 +++++++++++++++++++
 mmyolo/models/backbones/__init__.py           |  3 +-
 mmyolo/models/backbones/clip_backbone.py      | 42 +++++++++++
 mmyolo/models/necks/__init__.py               |  4 +-
 mmyolo/models/necks/yolov5_pafpn.py           | 22 ++++++
 tools/model_converters/clip_to_mmyolo.py      | 42 +++++++++++
 11 files changed, 258 insertions(+), 7 deletions(-)
 create mode 100644 configs/yolov5/clip_backbone/yolov5_s_test_clip_backbone_freeze_backbone.py
 create mode 100644 configs/yolov5/clip_backbone/yolov5_s_test_clip_backbone_freeze_bn.py
 create mode 100644 mmyolo/models/backbones/clip_backbone.py
 create mode 100644 tools/model_converters/clip_to_mmyolo.py

diff --git a/configs/yolov5/clip_backbone/yolov5_s_test_clip_backbone_freeze_backbone.py b/configs/yolov5/clip_backbone/yolov5_s_test_clip_backbone_freeze_backbone.py
new file mode 100644
index 000000000..ae6adfcfa
--- /dev/null
+++ b/configs/yolov5/clip_backbone/yolov5_s_test_clip_backbone_freeze_backbone.py
@@ -0,0 +1,6 @@
+_base_ = 'yolov5_s_test_clip_backbone_freeze_bn.py'
+
+model = dict(
+    backbone=dict(
+        freeze_backbone=True,  # freeze the whole backbone
+        freeze_bn=False))
diff --git a/configs/yolov5/clip_backbone/yolov5_s_test_clip_backbone_freeze_bn.py b/configs/yolov5/clip_backbone/yolov5_s_test_clip_backbone_freeze_bn.py
new file mode 100644
index 000000000..139990099
--- /dev/null
+++ b/configs/yolov5/clip_backbone/yolov5_s_test_clip_backbone_freeze_bn.py
@@ -0,0 +1,46 @@
+_base_ = '../yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py'
+
+model = dict(
+    data_preprocessor=dict(
+        type='YOLOv5DetDataPreprocessor',
+        # follow the preprocessing used in CLIP
+        mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255],
+        std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255]),
+    backbone=dict(
+        _delete_=True,
+        type='CLIPModifiedResNet',
+        freeze_backbone=False,  # freeze only the BN layers, not the backbone
+        freeze_bn=True,
+        output_dim=1024,
+        layers=[3, 4, 6, 3],
+        width=64,
+        heads=64 * 32 // 64),
+    neck=[
+        dict(
+            type='TempCLIPdownsampleneck',
+            in_channels=[512, 1024, 2048],
+            output_channels=[128, 256, 512],
+            norm_cfg=_base_.norm_cfg,
+            act_cfg=dict(type='SiLU', inplace=True)),
+        dict(
+            type='YOLOv5PAFPN',
+            deepen_factor=_base_.deepen_factor,
+            widen_factor=_base_.widen_factor,
+            in_channels=[256, 512, 1024],
+            out_channels=[256, 512, 1024],
+            num_csp_blocks=3,
+            norm_cfg=_base_.norm_cfg,
+            act_cfg=dict(type='SiLU', inplace=True))
+    ])
+
+base_lr = _base_.base_lr
+optim_wrapper = dict(
+    type='OptimWrapper',
+    optimizer=dict(lr=base_lr),
+    constructor='TempCLIPBackboneConstructor')  # separate param group for the backbone
+
+default_hooks = dict(
+    param_scheduler=dict(
+        type='TempCLIPParamSchedulerHook', backbone_lr_scale=0.1))
+
+load_from = 'CLIPResNet50.pth'
diff --git a/mmyolo/engine/hooks/__init__.py b/mmyolo/engine/hooks/__init__.py
index 0b8deebc8..83b9b9217 100644
--- a/mmyolo/engine/hooks/__init__.py
+++ b/mmyolo/engine/hooks/__init__.py
@@ -1,10 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .ppyoloe_param_scheduler_hook import PPYOLOEParamSchedulerHook
 from .switch_to_deploy_hook import SwitchToDeployHook
-from .yolov5_param_scheduler_hook import YOLOv5ParamSchedulerHook
+from .yolov5_param_scheduler_hook import (TempCLIPParamSchedulerHook,
+                                          YOLOv5ParamSchedulerHook)
 from .yolox_mode_switch_hook import YOLOXModeSwitchHook
 
 __all__ = [
     'YOLOv5ParamSchedulerHook', 'YOLOXModeSwitchHook', 'SwitchToDeployHook',
-    'PPYOLOEParamSchedulerHook'
+    'PPYOLOEParamSchedulerHook', 'TempCLIPParamSchedulerHook'
 ]
diff --git a/mmyolo/engine/hooks/yolov5_param_scheduler_hook.py b/mmyolo/engine/hooks/yolov5_param_scheduler_hook.py
index 777bb49d7..7a7848561 100644
--- a/mmyolo/engine/hooks/yolov5_param_scheduler_hook.py
+++ b/mmyolo/engine/hooks/yolov5_param_scheduler_hook.py
@@ -128,3 +128,16 @@ def after_train_epoch(self, runner: Runner):
         for group_idx, param in enumerate(optimizer.param_groups):
             param['lr'] = self._base_lr[group_idx] * self.scheduler_fn(
                 cur_epoch)
+
+
+@HOOKS.register_module()
+class TempCLIPParamSchedulerHook(YOLOv5ParamSchedulerHook):
+
+    def __init__(self, backbone_lr_scale: float = 0.1, *args, **kwargs):
+        self.backbone_lr_scale = backbone_lr_scale
+        super().__init__(*args, **kwargs)
+
+    def before_train(self, runner: Runner):
+        super().before_train(runner)
+        # scale down the lr of the backbone param group
+        self._base_lr[3] *= self.backbone_lr_scale
diff --git a/mmyolo/engine/optimizers/__init__.py b/mmyolo/engine/optimizers/__init__.py
index b598020d0..e19c5ac2c 100644
--- a/mmyolo/engine/optimizers/__init__.py
+++ b/mmyolo/engine/optimizers/__init__.py
@@ -1,5 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .yolov5_optim_constructor import YOLOv5OptimizerConstructor
+from .yolov5_optim_constructor import (TempCLIPBackboneConstructor,
+                                       YOLOv5OptimizerConstructor)
 from .yolov7_optim_wrapper_constructor import YOLOv7OptimWrapperConstructor
 
-__all__ = ['YOLOv5OptimizerConstructor', 'YOLOv7OptimWrapperConstructor']
+__all__ = [
+    'YOLOv5OptimizerConstructor', 'YOLOv7OptimWrapperConstructor',
+    'TempCLIPBackboneConstructor'
+]
diff --git a/mmyolo/engine/optimizers/yolov5_optim_constructor.py b/mmyolo/engine/optimizers/yolov5_optim_constructor.py
index 5e5f42cb5..59e20f71b 100644
--- a/mmyolo/engine/optimizers/yolov5_optim_constructor.py
+++ b/mmyolo/engine/optimizers/yolov5_optim_constructor.py
@@ -130,3 +130,77 @@ def __call__(self, model: nn.Module) -> OptimWrapper:
         optim_wrapper = OPTIM_WRAPPERS.build(
             self.optim_wrapper_cfg, default_args=dict(optimizer=optimizer))
         return optim_wrapper
+
+
+@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
+class TempCLIPBackboneConstructor(YOLOv5OptimizerConstructor):
+
+    def __call__(self, model: nn.Module) -> OptimWrapper:
+        if is_model_wrapper(model):
+            model = model.module
+        optimizer_cfg = self.optimizer_cfg.copy()
+        weight_decay = optimizer_cfg.pop('weight_decay', 0)
+
+        if 'batch_size_per_gpu' in optimizer_cfg:
+            batch_size_per_gpu = optimizer_cfg.pop('batch_size_per_gpu')
+            # No scaling if total_batch_size is less than
+            # base_total_batch_size, otherwise linear scaling.
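+            # Worked example: with base_total_batch_size = 64 and a run on
+            # 3 GPUs x 16 images (total_batch_size = 48), accumulate =
+            # max(round(64 / 48), 1) = 1 and scale_factor = 48 * 1 / 64 =
+            # 0.75, so weight_decay shrinks by 0.75. With 2 GPUs x 16
+            # images, accumulate = 2 and scale_factor = 32 * 2 / 64 = 1:
+            # gradient accumulation restores the effective batch size and
+            # weight_decay is left unchanged.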
+            total_batch_size = get_world_size() * batch_size_per_gpu
+            accumulate = max(
+                round(self.base_total_batch_size / total_batch_size), 1)
+            scale_factor = total_batch_size * \
+                accumulate / self.base_total_batch_size
+
+            if scale_factor != 1:
+                weight_decay *= scale_factor
+                print_log(f'Scaled weight_decay to {weight_decay}', 'current')
+
+        params_groups = [], [], [], []
+
+        # for v in model.modules():
+        #     if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):
+        #         params_groups[2].append(v.bias)
+        #     # Includes SyncBatchNorm
+        #     if isinstance(v, nn.modules.batchnorm._NormBase):
+        #         params_groups[1].append(v.weight)
+        #     elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):
+        #         params_groups[0].append(v.weight)
+
+        params = list(model.named_parameters())
+        for k, v in params:
+            # pull out the backbone params so they get their own lr
+            if 'backbone.' in k:
+                params_groups[3].append(v)
+                continue
+            if 'bias' in k:
+                params_groups[2].append(v)
+                continue
+            if 'bn' in k:
+                params_groups[1].append(v)
+            elif 'weight' in k and isinstance(v, nn.Parameter):
+                params_groups[0].append(v)
+
+        # Note: keep this group order; TempCLIPParamSchedulerHook expects
+        # the backbone params in group 3
+        optimizer_cfg['params'] = []
+        # conv
+        optimizer_cfg['params'].append({
+            'params': params_groups[0],
+            'weight_decay': weight_decay
+        })
+        # bn
+        optimizer_cfg['params'].append({'params': params_groups[1]})
+        # bias
+        optimizer_cfg['params'].append({'params': params_groups[2]})
+        # backbone
+        optimizer_cfg['params'].append({'params': params_groups[3]})
+
+        print_log(
+            'Optimizer groups: %g .bias, %g conv.weight, %g other, %g backbone'
+            % (len(params_groups[2]), len(params_groups[0]),
+               len(params_groups[1]), len(params_groups[3])), 'current')
+        del params_groups
+
+        optimizer = OPTIMIZERS.build(optimizer_cfg)
+        optim_wrapper = OPTIM_WRAPPERS.build(
+            self.optim_wrapper_cfg, default_args=dict(optimizer=optimizer))
+        return optim_wrapper
diff --git a/mmyolo/models/backbones/__init__.py b/mmyolo/models/backbones/__init__.py
index 48c8e28b1..df2fd267c 100644
--- a/mmyolo/models/backbones/__init__.py
+++ b/mmyolo/models/backbones/__init__.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .base_backbone import BaseBackbone
+from .clip_backbone import CLIPModifiedResNet
 from .csp_darknet import YOLOv5CSPDarknet, YOLOv8CSPDarknet, YOLOXCSPDarknet
 from .csp_resnet import PPYOLOECSPResNet
 from .cspnext import CSPNeXt
@@ -9,5 +10,5 @@
 __all__ = [
     'YOLOv5CSPDarknet', 'BaseBackbone', 'YOLOv6EfficientRep', 'YOLOv6CSPBep',
     'YOLOXCSPDarknet', 'CSPNeXt', 'YOLOv7Backbone', 'PPYOLOECSPResNet',
-    'YOLOv8CSPDarknet'
+    'YOLOv8CSPDarknet', 'CLIPModifiedResNet'
 ]
diff --git a/mmyolo/models/backbones/clip_backbone.py b/mmyolo/models/backbones/clip_backbone.py
new file mode 100644
index 000000000..85d049d17
--- /dev/null
+++ b/mmyolo/models/backbones/clip_backbone.py
@@ -0,0 +1,42 @@
+# Copyright (c) OpenMMLab. All rights reserved.
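+# CLIP's ModifiedResNet differs from the torchvision ResNet: a 3-conv stem
+# with an average pool replaces the single 7x7 conv + max pool, and the
+# final global pool is a QKV attention pool. For detection the attention
+# pool is skipped and the outputs of layer2/3/4 (512/1024/2048 channels
+# for RN50, strides 8/16/32) are returned as multi-scale features.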
+from clip.model import ModifiedResNet
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from mmyolo.registry import MODELS
+
+
+@MODELS.register_module()
+class CLIPModifiedResNet(ModifiedResNet):
+
+    def __init__(self, freeze_backbone, freeze_bn, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.freeze_backbone = freeze_backbone
+        self.freeze_bn = freeze_bn
+
+    def forward(self, x):
+
+        def stem(x):
+            x = self.relu1(self.bn1(self.conv1(x)))
+            x = self.relu2(self.bn2(self.conv2(x)))
+            x = self.relu3(self.bn3(self.conv3(x)))
+            x = self.avgpool(x)
+            return x
+
+        x = x.type(self.conv1.weight.dtype)
+        x = stem(x)
+        x = self.layer1(x)
+        x1 = self.layer2(x)
+        x2 = self.layer3(x1)
+        x3 = self.layer4(x2)
+        # x = self.attnpool(x)
+        # print('!!!', x1.shape, x2.shape, x3.shape)
+        return [x1, x2, x3]
+
+    def train(self, mode=True):
+        if self.freeze_bn:
+            for m in self.modules():
+                if isinstance(m, _BatchNorm):
+                    m.eval()
+        elif self.freeze_backbone:
+            for m in self.modules():
+                m.eval()
diff --git a/mmyolo/models/necks/__init__.py b/mmyolo/models/necks/__init__.py
index 6da9641ce..5364de703 100644
--- a/mmyolo/models/necks/__init__.py
+++ b/mmyolo/models/necks/__init__.py
@@ -2,7 +2,7 @@
 from .base_yolo_neck import BaseYOLONeck
 from .cspnext_pafpn import CSPNeXtPAFPN
 from .ppyoloe_csppan import PPYOLOECSPPAFPN
-from .yolov5_pafpn import YOLOv5PAFPN
+from .yolov5_pafpn import TempCLIPdownsampleneck, YOLOv5PAFPN
 from .yolov6_pafpn import YOLOv6CSPRepPAFPN, YOLOv6RepPAFPN
 from .yolov7_pafpn import YOLOv7PAFPN
 from .yolov8_pafpn import YOLOv8PAFPN
@@ -11,5 +11,5 @@
 __all__ = [
     'YOLOv5PAFPN', 'BaseYOLONeck', 'YOLOv6RepPAFPN', 'YOLOXPAFPN',
     'CSPNeXtPAFPN', 'YOLOv7PAFPN', 'PPYOLOECSPPAFPN', 'YOLOv6CSPRepPAFPN',
-    'YOLOv8PAFPN'
+    'YOLOv8PAFPN', 'TempCLIPdownsampleneck'
 ]
diff --git a/mmyolo/models/necks/yolov5_pafpn.py b/mmyolo/models/necks/yolov5_pafpn.py
index b95147fc5..5e9d0755f 100644
--- a/mmyolo/models/necks/yolov5_pafpn.py
+++ b/mmyolo/models/necks/yolov5_pafpn.py
@@ -12,6 +12,28 @@
 from .base_yolo_neck import BaseYOLONeck
 
 
+@MODELS.register_module()
+class TempCLIPdownsampleneck(nn.Module):
+
+    def __init__(self, in_channels, output_channels, norm_cfg, act_cfg):
+        super().__init__()
+        self.layers = nn.ModuleList()
+        for in_channel, output_channel in zip(in_channels, output_channels):
+            self.layers.append(
+                ConvModule(
+                    in_channel,
+                    output_channel,
+                    kernel_size=1,
+                    norm_cfg=norm_cfg,
+                    act_cfg=act_cfg))
+
+    def forward(self, x):
+        res = []
+        for ind in range(len(x)):
+            res.append(self.layers[ind](x[ind]))
+        return res
+
+
 @MODELS.register_module()
 class YOLOv5PAFPN(BaseYOLONeck):
     """Path Aggregation Network used in YOLOv5.
diff --git a/tools/model_converters/clip_to_mmyolo.py b/tools/model_converters/clip_to_mmyolo.py
new file mode 100644
index 000000000..dd0297e68
--- /dev/null
+++ b/tools/model_converters/clip_to_mmyolo.py
@@ -0,0 +1,42 @@
+import argparse
+from collections import OrderedDict
+
+import clip
+import torch
+
+# _MODELS = {
+#     "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",  # noqa
+#     "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",  # noqa
+#     "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",  # noqa
+#     "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",  # noqa
+#     "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",  # noqa
+#     "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",  # noqa
+#     "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",  # noqa
+#     "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",  # noqa
+#     "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",  # noqa
+# }
+
+
+def convert(src, dst):
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    model, _ = clip.load(src, device=device)
+    visual_model = model.visual
+    state_dict = visual_model.state_dict()
+    new_state_dict = OrderedDict()
+    for key, value in state_dict.items():
+        new_state_dict['backbone.' + key] = value
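+        # e.g. 'layer1.0.conv1.weight' -> 'backbone.layer1.0.conv1.weight',
+        # the prefix under which the detector stores its backbone weights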
+    data = {'state_dict': new_state_dict}
+    torch.save(data, dst)
+    print(f'Saved the converted checkpoint to {dst}')
+
+
+def main():
+    parser = argparse.ArgumentParser(description='Convert model keys')
+    parser.add_argument('--src', default='RN50', help='src clip model key')
+    parser.add_argument('--dst', default='CLIPResNet50.pth', help='save path')
+    args = parser.parse_args()
+    convert(args.src, args.dst)
+
+
+if __name__ == '__main__':
+    main()

From 17bef93f8c0a90afe098430f3ff6265b454d7077 Mon Sep 17 00:00:00 2001
From: Lingrui Gu
Date: Tue, 18 Apr 2023 10:11:48 +0800
Subject: [PATCH 2/8] add clip backbone

---
 .../yolov5_s_test_clip_backbone_freeze_bn.py  | 21 ++++--
 mmyolo/engine/hooks/__init__.py               |  5 +-
 .../hooks/yolov5_param_scheduler_hook.py      | 13 ----
 mmyolo/engine/optimizers/__init__.py          |  8 +-
 .../optimizers/yolov5_optim_constructor.py    | 74 -------------------
 5 files changed, 20 insertions(+), 101 deletions(-)

diff --git a/configs/yolov5/clip_backbone/yolov5_s_test_clip_backbone_freeze_bn.py b/configs/yolov5/clip_backbone/yolov5_s_test_clip_backbone_freeze_bn.py
index 139990099..53068e525 100644
--- a/configs/yolov5/clip_backbone/yolov5_s_test_clip_backbone_freeze_bn.py
+++ b/configs/yolov5/clip_backbone/yolov5_s_test_clip_backbone_freeze_bn.py
@@ -33,14 +33,25 @@
             act_cfg=dict(type='SiLU', inplace=True))
     ])
 
-base_lr = _base_.base_lr
+base_lr = 0.004
+weight_decay = 0.05
 optim_wrapper = dict(
+    _delete_=True,
     type='OptimWrapper',
-    optimizer=dict(lr=base_lr),
-    constructor='TempCLIPBackboneConstructor')  # separate param group for the backbone
+    optimizer=dict(type='AdamW', lr=base_lr, weight_decay=weight_decay),
+    paramwise_cfg=dict(
+        custom_keys=dict(backbone=dict(lr_mult=0.1, decay_mult=0.1)),
+        norm_decay_mult=0,
+        bias_decay_mult=0,
+        bypass_duplicate=True))
 
+save_checkpoint_intervals = 10
+max_keep_ckpts = 3
 default_hooks = dict(
-    param_scheduler=dict(
-        type='TempCLIPParamSchedulerHook', backbone_lr_scale=0.1))
+    checkpoint=dict(
+        type='CheckpointHook',
+        interval=save_checkpoint_intervals,
+        max_keep_ckpts=max_keep_ckpts  # only keep latest 3 checkpoints
+    ))
 
 load_from = 'CLIPResNet50.pth'
diff --git a/mmyolo/engine/hooks/__init__.py b/mmyolo/engine/hooks/__init__.py
index 83b9b9217..0b8deebc8 100644
--- a/mmyolo/engine/hooks/__init__.py
+++ b/mmyolo/engine/hooks/__init__.py
@@ -1,11 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .ppyoloe_param_scheduler_hook import PPYOLOEParamSchedulerHook
 from .switch_to_deploy_hook import SwitchToDeployHook
-from .yolov5_param_scheduler_hook import (TempCLIPParamSchedulerHook,
-                                          YOLOv5ParamSchedulerHook)
+from .yolov5_param_scheduler_hook import YOLOv5ParamSchedulerHook
 from .yolox_mode_switch_hook import YOLOXModeSwitchHook
 
 __all__ = [
     'YOLOv5ParamSchedulerHook', 'YOLOXModeSwitchHook', 'SwitchToDeployHook',
-    'PPYOLOEParamSchedulerHook', 'TempCLIPParamSchedulerHook'
+    'PPYOLOEParamSchedulerHook'
 ]
diff --git a/mmyolo/engine/hooks/yolov5_param_scheduler_hook.py b/mmyolo/engine/hooks/yolov5_param_scheduler_hook.py
index 7a7848561..777bb49d7 100644
--- a/mmyolo/engine/hooks/yolov5_param_scheduler_hook.py
+++ b/mmyolo/engine/hooks/yolov5_param_scheduler_hook.py
@@ -128,16 +128,3 @@ def after_train_epoch(self, runner: Runner):
         for group_idx, param in enumerate(optimizer.param_groups):
             param['lr'] = self._base_lr[group_idx] * self.scheduler_fn(
                 cur_epoch)
-
-
-@HOOKS.register_module()
-class TempCLIPParamSchedulerHook(YOLOv5ParamSchedulerHook):
-
-    def __init__(self, backbone_lr_scale: float = 0.1, *args, **kwargs):
-        self.backbone_lr_scale = backbone_lr_scale
-        super().__init__(*args, **kwargs)
-
-    def before_train(self, runner: Runner):
-        super().before_train(runner)
-        # scale down the lr of the backbone param group
-        self._base_lr[3] *= self.backbone_lr_scale
diff --git a/mmyolo/engine/optimizers/__init__.py b/mmyolo/engine/optimizers/__init__.py
index e19c5ac2c..b598020d0 100644
--- a/mmyolo/engine/optimizers/__init__.py
+++ b/mmyolo/engine/optimizers/__init__.py
@@ -1,9 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .yolov5_optim_constructor import (TempCLIPBackboneConstructor,
-                                       YOLOv5OptimizerConstructor)
+from .yolov5_optim_constructor import YOLOv5OptimizerConstructor
 from .yolov7_optim_wrapper_constructor import YOLOv7OptimWrapperConstructor
 
-__all__ = [
-    'YOLOv5OptimizerConstructor', 'YOLOv7OptimWrapperConstructor',
-    'TempCLIPBackboneConstructor'
-]
+__all__ = ['YOLOv5OptimizerConstructor', 'YOLOv7OptimWrapperConstructor']
diff --git a/mmyolo/engine/optimizers/yolov5_optim_constructor.py b/mmyolo/engine/optimizers/yolov5_optim_constructor.py
index 59e20f71b..5e5f42cb5 100644
--- a/mmyolo/engine/optimizers/yolov5_optim_constructor.py
+++ b/mmyolo/engine/optimizers/yolov5_optim_constructor.py
@@ -130,77 +130,3 @@ def __call__(self, model: nn.Module) -> OptimWrapper:
         optim_wrapper = OPTIM_WRAPPERS.build(
             self.optim_wrapper_cfg, default_args=dict(optimizer=optimizer))
         return optim_wrapper
-
-
-@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
-class TempCLIPBackboneConstructor(YOLOv5OptimizerConstructor):
-
-    def __call__(self, model: nn.Module) -> OptimWrapper:
-        if is_model_wrapper(model):
-            model = model.module
-        optimizer_cfg = self.optimizer_cfg.copy()
-        weight_decay = optimizer_cfg.pop('weight_decay', 0)
-
-        if 'batch_size_per_gpu' in optimizer_cfg:
-            batch_size_per_gpu = optimizer_cfg.pop('batch_size_per_gpu')
-            # No scaling if total_batch_size is less than
-            # base_total_batch_size, otherwise linear scaling.
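-            # Worked example: with base_total_batch_size = 64 and a run on
-            # 3 GPUs x 16 images (total_batch_size = 48), accumulate =
-            # max(round(64 / 48), 1) = 1 and scale_factor = 48 * 1 / 64 =
-            # 0.75, so weight_decay shrinks by 0.75. With 2 GPUs x 16
-            # images, accumulate = 2 and scale_factor = 32 * 2 / 64 = 1:
-            # gradient accumulation restores the effective batch size and
-            # weight_decay is left unchanged.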
-            total_batch_size = get_world_size() * batch_size_per_gpu
-            accumulate = max(
-                round(self.base_total_batch_size / total_batch_size), 1)
-            scale_factor = total_batch_size * \
-                accumulate / self.base_total_batch_size
-
-            if scale_factor != 1:
-                weight_decay *= scale_factor
-                print_log(f'Scaled weight_decay to {weight_decay}', 'current')
-
-        params_groups = [], [], [], []
-
-        # for v in model.modules():
-        #     if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):
-        #         params_groups[2].append(v.bias)
-        #     # Includes SyncBatchNorm
-        #     if isinstance(v, nn.modules.batchnorm._NormBase):
-        #         params_groups[1].append(v.weight)
-        #     elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):
-        #         params_groups[0].append(v.weight)
-
-        params = list(model.named_parameters())
-        for k, v in params:
-            # pull out the backbone params so they get their own lr
-            if 'backbone.' in k:
-                params_groups[3].append(v)
-                continue
-            if 'bias' in k:
-                params_groups[2].append(v)
-                continue
-            if 'bn' in k:
-                params_groups[1].append(v)
-            elif 'weight' in k and isinstance(v, nn.Parameter):
-                params_groups[0].append(v)
-
-        # Note: keep this group order; TempCLIPParamSchedulerHook expects
-        # the backbone params in group 3
-        optimizer_cfg['params'] = []
-        # conv
-        optimizer_cfg['params'].append({
-            'params': params_groups[0],
-            'weight_decay': weight_decay
-        })
-        # bn
-        optimizer_cfg['params'].append({'params': params_groups[1]})
-        # bias
-        optimizer_cfg['params'].append({'params': params_groups[2]})
-        # backbone
-        optimizer_cfg['params'].append({'params': params_groups[3]})
-
-        print_log(
-            'Optimizer groups: %g .bias, %g conv.weight, %g other, %g backbone'
-            % (len(params_groups[2]), len(params_groups[0]),
-               len(params_groups[1]), len(params_groups[3])), 'current')
-        del params_groups
-
-        optimizer = OPTIMIZERS.build(optimizer_cfg)
-        optim_wrapper = OPTIM_WRAPPERS.build(
-            self.optim_wrapper_cfg, default_args=dict(optimizer=optimizer))
-        return optim_wrapper

From 310878c9795ea723e0225c173d5f11a55daed78f Mon Sep 17 00:00:00 2001
From: Lingrui Gu
Date: Tue, 18 Apr 2023 11:25:29 +0800
Subject: [PATCH 3/8] yolov5 dynamic neck output shape

---
 .../yolov5_s_test_clip_backbone_freeze_bn.py  | 30 ++++++++-----------
 mmyolo/models/necks/yolov5_pafpn.py           | 11 +++----
 2 files changed, 19 insertions(+), 22 deletions(-)

diff --git a/configs/yolov5/clip_backbone/yolov5_s_test_clip_backbone_freeze_bn.py b/configs/yolov5/clip_backbone/yolov5_s_test_clip_backbone_freeze_bn.py
index 53068e525..8ddfed5cf 100644
--- a/configs/yolov5/clip_backbone/yolov5_s_test_clip_backbone_freeze_bn.py
+++ b/configs/yolov5/clip_backbone/yolov5_s_test_clip_backbone_freeze_bn.py
@@ -15,23 +15,19 @@
         layers=[3, 4, 6, 3],
         width=64,
         heads=64 * 32 // 64),
-    neck=[
-        dict(
-            type='TempCLIPdownsampleneck',
-            in_channels=[512, 1024, 2048],
-            output_channels=[128, 256, 512],
-            norm_cfg=_base_.norm_cfg,
-            act_cfg=dict(type='SiLU', inplace=True)),
-        dict(
-            type='YOLOv5PAFPN',
-            deepen_factor=_base_.deepen_factor,
-            widen_factor=_base_.widen_factor,
-            in_channels=[256, 512, 1024],
-            out_channels=[256, 512, 1024],
-            num_csp_blocks=3,
-            norm_cfg=_base_.norm_cfg,
-            act_cfg=dict(type='SiLU', inplace=True))
-    ])
+    neck=dict(
+        type='YOLOv5PAFPN',
+        deepen_factor=_base_.deepen_factor,
+        widen_factor=_base_.widen_factor,
+        in_channels=[
+            int(512 / _base_.widen_factor),
+            int(1024 / _base_.widen_factor),
+            int(2048 / _base_.widen_factor)
+        ],
+        out_channels=[256, 512, 1024],
+        num_csp_blocks=3,
+        norm_cfg=_base_.norm_cfg,
+        act_cfg=dict(type='SiLU', inplace=True)))
 
 base_lr = 0.004
 weight_decay = 0.05
diff --git a/mmyolo/models/necks/yolov5_pafpn.py b/mmyolo/models/necks/yolov5_pafpn.py
index 5e9d0755f..8b96b199c 100644
--- a/mmyolo/models/necks/yolov5_pafpn.py
+++ b/mmyolo/models/necks/yolov5_pafpn.py
@@ -127,7 +127,7 @@ def build_top_down_layer(self, idx: int):
 
         return CSPLayer(
             make_divisible(self.in_channels[idx - 1] * 2, self.widen_factor),
-            make_divisible(self.in_channels[idx - 1], self.widen_factor),
+            make_divisible(self.out_channels[idx - 1], self.widen_factor),
             num_blocks=make_round(self.num_csp_blocks, self.deepen_factor),
             add_identity=False,
             norm_cfg=self.norm_cfg,
@@ -163,8 +163,8 @@ def build_downsample_layer(self, idx: int) -> nn.Module:
             nn.Module: The downsample layer.
         """
         return ConvModule(
-            make_divisible(self.in_channels[idx], self.widen_factor),
-            make_divisible(self.in_channels[idx], self.widen_factor),
+            make_divisible(self.out_channels[idx], self.widen_factor),
+            make_divisible(self.out_channels[idx], self.widen_factor),
             kernel_size=3,
             stride=2,
             padding=1,
@@ -181,8 +181,9 @@ def build_bottom_up_layer(self, idx: int) -> nn.Module:
             nn.Module: The bottom up layer.
         """
         return CSPLayer(
-            make_divisible(self.in_channels[idx] * 2, self.widen_factor),
-            make_divisible(self.in_channels[idx + 1], self.widen_factor),
+            make_divisible(self.in_channels[idx] + self.out_channels[idx],
+                           self.widen_factor),
+            make_divisible(self.out_channels[idx + 1], self.widen_factor),
             num_blocks=make_round(self.num_csp_blocks, self.deepen_factor),
             add_identity=False,
             norm_cfg=self.norm_cfg,

From b73c07f6a45adec58517ed784be26081ee13dfe9 Mon Sep 17 00:00:00 2001
From: Lingrui Gu
Date: Tue, 18 Apr 2023 11:32:51 +0800
Subject: [PATCH 4/8] del temp neck

---
 mmyolo/models/necks/__init__.py     |  4 ++--
 mmyolo/models/necks/yolov5_pafpn.py | 22 ----------------------
 2 files changed, 2 insertions(+), 24 deletions(-)

diff --git a/mmyolo/models/necks/__init__.py b/mmyolo/models/necks/__init__.py
index 5364de703..6da9641ce 100644
--- a/mmyolo/models/necks/__init__.py
+++ b/mmyolo/models/necks/__init__.py
@@ -2,7 +2,7 @@
 from .base_yolo_neck import BaseYOLONeck
 from .cspnext_pafpn import CSPNeXtPAFPN
 from .ppyoloe_csppan import PPYOLOECSPPAFPN
-from .yolov5_pafpn import TempCLIPdownsampleneck, YOLOv5PAFPN
+from .yolov5_pafpn import YOLOv5PAFPN
 from .yolov6_pafpn import YOLOv6CSPRepPAFPN, YOLOv6RepPAFPN
 from .yolov7_pafpn import YOLOv7PAFPN
 from .yolov8_pafpn import YOLOv8PAFPN
@@ -11,5 +11,5 @@
 __all__ = [
     'YOLOv5PAFPN', 'BaseYOLONeck', 'YOLOv6RepPAFPN', 'YOLOXPAFPN',
     'CSPNeXtPAFPN', 'YOLOv7PAFPN', 'PPYOLOECSPPAFPN', 'YOLOv6CSPRepPAFPN',
-    'YOLOv8PAFPN', 'TempCLIPdownsampleneck'
+    'YOLOv8PAFPN'
 ]
diff --git a/mmyolo/models/necks/yolov5_pafpn.py b/mmyolo/models/necks/yolov5_pafpn.py
index 8b96b199c..3f0206469 100644
--- a/mmyolo/models/necks/yolov5_pafpn.py
+++ b/mmyolo/models/necks/yolov5_pafpn.py
@@ -12,28 +12,6 @@
 from .base_yolo_neck import BaseYOLONeck
 
 
-@MODELS.register_module()
-class TempCLIPdownsampleneck(nn.Module):
-
-    def __init__(self, in_channels, output_channels, norm_cfg, act_cfg):
-        super().__init__()
-        self.layers = nn.ModuleList()
-        for in_channel, output_channel in zip(in_channels, output_channels):
-            self.layers.append(
-                ConvModule(
-                    in_channel,
-                    output_channel,
-                    kernel_size=1,
-                    norm_cfg=norm_cfg,
-                    act_cfg=act_cfg))
-
-    def forward(self, x):
-        res = []
-        for ind in range(len(x)):
-            res.append(self.layers[ind](x[ind]))
-        return res
-
-
 @MODELS.register_module()
 class YOLOv5PAFPN(BaseYOLONeck):
     """Path Aggregation Network used in YOLOv5.
From b225f276b8980129a81c35f14cd25134814f3855 Mon Sep 17 00:00:00 2001
From: Lingrui Gu
Date: Thu, 20 Apr 2023 09:54:39 +0800
Subject: [PATCH 5/8] l mode

---
 ...kbone.py => yolov5_l_clip_backbone_freeze_backbone.py} | 2 +-
 ...e_freeze_bn.py => yolov5_l_clip_backbone_freeze_bn.py} | 2 +-
 mmyolo/models/backbones/clip_backbone.py                  | 8 ++++----
 3 files changed, 6 insertions(+), 6 deletions(-)
 rename configs/yolov5/clip_backbone/{yolov5_s_test_clip_backbone_freeze_backbone.py => yolov5_l_clip_backbone_freeze_backbone.py} (67%)
 rename configs/yolov5/clip_backbone/{yolov5_s_test_clip_backbone_freeze_bn.py => yolov5_l_clip_backbone_freeze_bn.py} (96%)

diff --git a/configs/yolov5/clip_backbone/yolov5_s_test_clip_backbone_freeze_backbone.py b/configs/yolov5/clip_backbone/yolov5_l_clip_backbone_freeze_backbone.py
similarity index 67%
rename from configs/yolov5/clip_backbone/yolov5_s_test_clip_backbone_freeze_backbone.py
rename to configs/yolov5/clip_backbone/yolov5_l_clip_backbone_freeze_backbone.py
index ae6adfcfa..1666b937f 100644
--- a/configs/yolov5/clip_backbone/yolov5_s_test_clip_backbone_freeze_backbone.py
+++ b/configs/yolov5/clip_backbone/yolov5_l_clip_backbone_freeze_backbone.py
@@ -1,4 +1,4 @@
-_base_ = 'yolov5_s_test_clip_backbone_freeze_bn.py'
+_base_ = 'yolov5_l_clip_backbone_freeze_bn.py'
 
 model = dict(
     backbone=dict(
diff --git a/configs/yolov5/clip_backbone/yolov5_s_test_clip_backbone_freeze_bn.py b/configs/yolov5/clip_backbone/yolov5_l_clip_backbone_freeze_bn.py
similarity index 96%
rename from configs/yolov5/clip_backbone/yolov5_s_test_clip_backbone_freeze_bn.py
rename to configs/yolov5/clip_backbone/yolov5_l_clip_backbone_freeze_bn.py
index 8ddfed5cf..71c5b66e0 100644
--- a/configs/yolov5/clip_backbone/yolov5_s_test_clip_backbone_freeze_bn.py
+++ b/configs/yolov5/clip_backbone/yolov5_l_clip_backbone_freeze_bn.py
@@ -1,4 +1,4 @@
-_base_ = '../yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py'
+_base_ = '../yolov5_l-v61_syncbn_fast_8xb16-300e_coco.py'
diff --git a/mmyolo/models/backbones/clip_backbone.py b/mmyolo/models/backbones/clip_backbone.py
index 85d049d17..b0f075f64 100644
--- a/mmyolo/models/backbones/clip_backbone.py
+++ b/mmyolo/models/backbones/clip_backbone.py
@@ -29,7 +29,6 @@ def stem(x):
         x2 = self.layer3(x1)
         x3 = self.layer4(x2)
         # x = self.attnpool(x)
-        # print('!!!', x1.shape, x2.shape, x3.shape)
         return [x1, x2, x3]
 
     def train(self, mode=True):
@@ -37,6 +36,7 @@ def train(self, mode=True):
         if self.freeze_bn:
             for m in self.modules():
                 if isinstance(m, _BatchNorm):
                     m.eval()
-        elif self.freeze_backbone:
-            for m in self.modules():
-                m.eval()
+        elif self.freeze_backbone:  # TODO: clean up this freezing logic
+            for ind, m in enumerate(self.modules()):
+                if ind != 0:
+                    m.eval()

From 392c6e236189ad06fc03fc03054cebd5e83947b4 Mon Sep 17 00:00:00 2001
From: Lingrui Gu
Date: Thu, 20 Apr 2023 10:47:35 +0800
Subject: [PATCH 6/8] add unused param

---
 .../yolov5/clip_backbone/yolov5_l_clip_backbone_freeze_bn.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/configs/yolov5/clip_backbone/yolov5_l_clip_backbone_freeze_bn.py b/configs/yolov5/clip_backbone/yolov5_l_clip_backbone_freeze_bn.py
index 71c5b66e0..ec1acb210 100644
--- a/configs/yolov5/clip_backbone/yolov5_l_clip_backbone_freeze_bn.py
+++ b/configs/yolov5/clip_backbone/yolov5_l_clip_backbone_freeze_bn.py
@@ -1,5 +1,7 @@
 _base_ = '../yolov5_l-v61_syncbn_fast_8xb16-300e_coco.py'
 
+find_unused_parameters = True
+
 model = dict(
     data_preprocessor=dict(
         type='YOLOv5DetDataPreprocessor',

From 0a336393051f5a810286f1a6b52a77ff7136861d Mon Sep 17 00:00:00 2001
From: Lingrui Gu
Date: Thu, 20 Apr 2023 13:15:26 +0800
Subject: [PATCH 7/8] lr

---
 .../yolov5/clip_backbone/yolov5_l_clip_backbone_freeze_bn.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/configs/yolov5/clip_backbone/yolov5_l_clip_backbone_freeze_bn.py b/configs/yolov5/clip_backbone/yolov5_l_clip_backbone_freeze_bn.py
index ec1acb210..ad02bc415 100644
--- a/configs/yolov5/clip_backbone/yolov5_l_clip_backbone_freeze_bn.py
+++ b/configs/yolov5/clip_backbone/yolov5_l_clip_backbone_freeze_bn.py
@@ -31,7 +31,7 @@
         norm_cfg=_base_.norm_cfg,
         act_cfg=dict(type='SiLU', inplace=True)))
 
-base_lr = 0.004
+base_lr = 0.002
 weight_decay = 0.05
 optim_wrapper = dict(
     _delete_=True,

From e5d082b0b8411e524fc4431a5b790d2468c82567 Mon Sep 17 00:00:00 2001
From: Lingrui Gu
Date: Fri, 21 Apr 2023 09:46:57 +0800
Subject: [PATCH 8/8] add v5 optimizer

---
 .../yolov5_l_clip_backbone_v5_optimizer.py    | 44 +++++++++++
 mmyolo/engine/hooks/__init__.py               |  5 +-
 .../hooks/yolov5_param_scheduler_hook.py      | 13 ++++
 mmyolo/engine/optimizers/__init__.py          |  8 +-
 .../optimizers/yolov5_optim_constructor.py    | 74 +++++++++++++++++++
 5 files changed, 140 insertions(+), 4 deletions(-)
 create mode 100644 configs/yolov5/clip_backbone/yolov5_l_clip_backbone_v5_optimizer.py

diff --git a/configs/yolov5/clip_backbone/yolov5_l_clip_backbone_v5_optimizer.py b/configs/yolov5/clip_backbone/yolov5_l_clip_backbone_v5_optimizer.py
new file mode 100644
index 000000000..76cce425d
--- /dev/null
+++ b/configs/yolov5/clip_backbone/yolov5_l_clip_backbone_v5_optimizer.py
@@ -0,0 +1,44 @@
+_base_ = '../yolov5_l-v61_syncbn_fast_8xb16-300e_coco.py'
+
+find_unused_parameters = True
+
+model = dict(
+    data_preprocessor=dict(
+        type='YOLOv5DetDataPreprocessor',
+        # follow the preprocessing used in CLIP
+        mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255],
+        std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255]),
+    backbone=dict(
+        _delete_=True,
+        type='CLIPModifiedResNet',
+        freeze_backbone=False,  # freeze only the BN layers, not the backbone
+        freeze_bn=True,
+        output_dim=1024,
+        layers=[3, 4, 6, 3],
+        width=64,
+        heads=64 * 32 // 64),
+    neck=dict(
+        type='YOLOv5PAFPN',
+        deepen_factor=_base_.deepen_factor,
+        widen_factor=_base_.widen_factor,
+        in_channels=[
+            int(512 / _base_.widen_factor),
+            int(1024 / _base_.widen_factor),
+            int(2048 / _base_.widen_factor)
+        ],
+        out_channels=[256, 512, 1024],
+        num_csp_blocks=3,
+        norm_cfg=_base_.norm_cfg,
+        act_cfg=dict(type='SiLU', inplace=True)))
+
+base_lr = _base_.base_lr
+optim_wrapper = dict(
+    type='OptimWrapper',
+    optimizer=dict(lr=base_lr),
+    constructor='TempCLIPBackboneConstructor')  # separate param group for the backbone
+
+default_hooks = dict(
+    param_scheduler=dict(
+        type='TempCLIPParamSchedulerHook', backbone_lr_scale=0.1))
+
+load_from = 'CLIPResNet50.pth'
diff --git a/mmyolo/engine/hooks/__init__.py b/mmyolo/engine/hooks/__init__.py
index 0b8deebc8..83b9b9217 100644
--- a/mmyolo/engine/hooks/__init__.py
+++ b/mmyolo/engine/hooks/__init__.py
@@ -1,10 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .ppyoloe_param_scheduler_hook import PPYOLOEParamSchedulerHook
 from .switch_to_deploy_hook import SwitchToDeployHook
-from .yolov5_param_scheduler_hook import YOLOv5ParamSchedulerHook
+from .yolov5_param_scheduler_hook import (TempCLIPParamSchedulerHook,
+                                          YOLOv5ParamSchedulerHook)
 from .yolox_mode_switch_hook import YOLOXModeSwitchHook
 
 __all__ = [
     'YOLOv5ParamSchedulerHook', 'YOLOXModeSwitchHook', 'SwitchToDeployHook',
-    'PPYOLOEParamSchedulerHook'
+    'PPYOLOEParamSchedulerHook', 'TempCLIPParamSchedulerHook'
 ]
diff --git a/mmyolo/engine/hooks/yolov5_param_scheduler_hook.py b/mmyolo/engine/hooks/yolov5_param_scheduler_hook.py
index 777bb49d7..7a7848561 100644
--- a/mmyolo/engine/hooks/yolov5_param_scheduler_hook.py
+++ b/mmyolo/engine/hooks/yolov5_param_scheduler_hook.py
@@ -128,3 +128,16 @@ def after_train_epoch(self, runner: Runner):
         for group_idx, param in enumerate(optimizer.param_groups):
             param['lr'] = self._base_lr[group_idx] * self.scheduler_fn(
                 cur_epoch)
+
+
+@HOOKS.register_module()
+class TempCLIPParamSchedulerHook(YOLOv5ParamSchedulerHook):
+
+    def __init__(self, backbone_lr_scale: float = 0.1, *args, **kwargs):
+        self.backbone_lr_scale = backbone_lr_scale
+        super().__init__(*args, **kwargs)
+
+    def before_train(self, runner: Runner):
+        super().before_train(runner)
+        # scale down the lr of the backbone param group
+        self._base_lr[3] *= self.backbone_lr_scale
diff --git a/mmyolo/engine/optimizers/__init__.py b/mmyolo/engine/optimizers/__init__.py
index b598020d0..e19c5ac2c 100644
--- a/mmyolo/engine/optimizers/__init__.py
+++ b/mmyolo/engine/optimizers/__init__.py
@@ -1,5 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .yolov5_optim_constructor import YOLOv5OptimizerConstructor
+from .yolov5_optim_constructor import (TempCLIPBackboneConstructor,
+                                       YOLOv5OptimizerConstructor)
 from .yolov7_optim_wrapper_constructor import YOLOv7OptimWrapperConstructor
 
-__all__ = ['YOLOv5OptimizerConstructor', 'YOLOv7OptimWrapperConstructor']
+__all__ = [
+    'YOLOv5OptimizerConstructor', 'YOLOv7OptimWrapperConstructor',
+    'TempCLIPBackboneConstructor'
+]
diff --git a/mmyolo/engine/optimizers/yolov5_optim_constructor.py b/mmyolo/engine/optimizers/yolov5_optim_constructor.py
index 5e5f42cb5..59e20f71b 100644
--- a/mmyolo/engine/optimizers/yolov5_optim_constructor.py
+++ b/mmyolo/engine/optimizers/yolov5_optim_constructor.py
@@ -130,3 +130,77 @@ def __call__(self, model: nn.Module) -> OptimWrapper:
         optim_wrapper = OPTIM_WRAPPERS.build(
             self.optim_wrapper_cfg, default_args=dict(optimizer=optimizer))
         return optim_wrapper
+
+
+@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
+class TempCLIPBackboneConstructor(YOLOv5OptimizerConstructor):
+
+    def __call__(self, model: nn.Module) -> OptimWrapper:
+        if is_model_wrapper(model):
+            model = model.module
+        optimizer_cfg = self.optimizer_cfg.copy()
+        weight_decay = optimizer_cfg.pop('weight_decay', 0)
+
+        if 'batch_size_per_gpu' in optimizer_cfg:
+            batch_size_per_gpu = optimizer_cfg.pop('batch_size_per_gpu')
+            # No scaling if total_batch_size is less than
+            # base_total_batch_size, otherwise linear scaling.
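+            # Worked example: with base_total_batch_size = 64 and a run on
+            # 3 GPUs x 16 images (total_batch_size = 48), accumulate =
+            # max(round(64 / 48), 1) = 1 and scale_factor = 48 * 1 / 64 =
+            # 0.75, so weight_decay shrinks by 0.75. With 2 GPUs x 16
+            # images, accumulate = 2 and scale_factor = 32 * 2 / 64 = 1:
+            # gradient accumulation restores the effective batch size and
+            # weight_decay is left unchanged.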
+            total_batch_size = get_world_size() * batch_size_per_gpu
+            accumulate = max(
+                round(self.base_total_batch_size / total_batch_size), 1)
+            scale_factor = total_batch_size * \
+                accumulate / self.base_total_batch_size
+
+            if scale_factor != 1:
+                weight_decay *= scale_factor
+                print_log(f'Scaled weight_decay to {weight_decay}', 'current')
+
+        params_groups = [], [], [], []
+
+        # for v in model.modules():
+        #     if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):
+        #         params_groups[2].append(v.bias)
+        #     # Includes SyncBatchNorm
+        #     if isinstance(v, nn.modules.batchnorm._NormBase):
+        #         params_groups[1].append(v.weight)
+        #     elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):
+        #         params_groups[0].append(v.weight)
+
+        params = list(model.named_parameters())
+        for k, v in params:
+            # pull out the backbone params so they get their own lr
+            if 'backbone.' in k:
+                params_groups[3].append(v)
+                continue
+            if 'bias' in k:
+                params_groups[2].append(v)
+                continue
+            if 'bn' in k:
+                params_groups[1].append(v)
+            elif 'weight' in k and isinstance(v, nn.Parameter):
+                params_groups[0].append(v)
+
+        # Note: keep this group order; TempCLIPParamSchedulerHook expects
+        # the backbone params in group 3
+        optimizer_cfg['params'] = []
+        # conv
+        optimizer_cfg['params'].append({
+            'params': params_groups[0],
+            'weight_decay': weight_decay
+        })
+        # bn
+        optimizer_cfg['params'].append({'params': params_groups[1]})
+        # bias
+        optimizer_cfg['params'].append({'params': params_groups[2]})
+        # backbone
+        optimizer_cfg['params'].append({'params': params_groups[3]})
+
+        print_log(
+            'Optimizer groups: %g .bias, %g conv.weight, %g other, %g backbone'
+            % (len(params_groups[2]), len(params_groups[0]),
+               len(params_groups[1]), len(params_groups[3])), 'current')
+        del params_groups
+
+        optimizer = OPTIMIZERS.build(optimizer_cfg)
+        optim_wrapper = OPTIM_WRAPPERS.build(
+            self.optim_wrapper_cfg, default_args=dict(optimizer=optimizer))
+        return optim_wrapper
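
A quick sanity check for the converted checkpoint before training (a minimal sketch, not part of the patch series; it assumes tools/model_converters/clip_to_mmyolo.py has already written CLIPResNet50.pth to the working directory):

    import torch

    ckpt = torch.load('CLIPResNet50.pth', map_location='cpu')
    state_dict = ckpt['state_dict']
    # convert() prefixes every key of the CLIP visual tower with 'backbone.',
    # e.g. 'conv1.weight' -> 'backbone.conv1.weight', matching the parameter
    # names of CLIPModifiedResNet inside the detector.
    assert all(key.startswith('backbone.') for key in state_dict)
    print(f'{len(state_dict)} tensors, e.g. {sorted(state_dict)[0]}')

Training then uses the standard MMYOLO entry points (e.g. tools/train.py) with configs/yolov5/clip_backbone/yolov5_l_clip_backbone_freeze_bn.py from patches 5-7, or yolov5_l_clip_backbone_v5_optimizer.py from patch 8.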