Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add clip backbone #722

Open
wants to merge 9 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
_base_ = 'yolov5_l_clip_backbone_freeze_bn.py'

model = dict(
backbone=dict(
freeze_backbone=True, # 冻结backbone
freeze_bn=False))
55 changes: 55 additions & 0 deletions configs/yolov5/clip_backbone/yolov5_l_clip_backbone_freeze_bn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
_base_ = '../yolov5_l-v61_syncbn_fast_8xb16-300e_coco.py'

find_unused_parameters = True

model = dict(
data_preprocessor=dict(
type='YOLOv5DetDataPreprocessor',
# 按照clip里的预处理方式
mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255],
std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255]),
backbone=dict(
_delete_=True,
type='CLIPModifiedResNet',
freeze_backbone=False, # 只冻结bn,不冻结backbone
freeze_bn=True,
output_dim=1024,
layers=[3, 4, 6, 3],
width=64,
heads=64 * 32 // 64),
neck=dict(
type='YOLOv5PAFPN',
deepen_factor=_base_.deepen_factor,
widen_factor=_base_.widen_factor,
in_channels=[
int(512 / _base_.widen_factor),
int(1024 / _base_.widen_factor),
int(2048 / _base_.widen_factor)
],
out_channels=[256, 512, 1024],
num_csp_blocks=3,
norm_cfg=_base_.norm_cfg,
act_cfg=dict(type='SiLU', inplace=True)))

base_lr = 0.002
weight_decay = 0.05
optim_wrapper = dict(
_delete_=True,
type='OptimWrapper',
optimizer=dict(type='AdamW', lr=base_lr, weight_decay=weight_decay),
paramwise_cfg=dict(
custom_keys=dict(backbone=dict(lr_mult=0.1, decay_mult=0.1)),
norm_decay_mult=0,
bias_decay_mult=0,
bypass_duplicate=True))

save_checkpoint_intervals = 10
max_keep_ckpts = 3
default_hooks = dict(
checkpoint=dict(
type='CheckpointHook',
interval=save_checkpoint_intervals,
max_keep_ckpts=max_keep_ckpts # only keep latest 3 checkpoints
))

load_from = 'CLIPResNet50.pth'
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
_base_ = '../yolov5_l-v61_syncbn_fast_8xb16-300e_coco.py'

find_unused_parameters = True

model = dict(
data_preprocessor=dict(
type='YOLOv5DetDataPreprocessor',
# 按照clip里的预处理方式
mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255],
std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255]),
backbone=dict(
_delete_=True,
type='CLIPModifiedResNet',
freeze_backbone=False, # 只冻结bn,不冻结backbone
freeze_bn=True,
output_dim=1024,
layers=[3, 4, 6, 3],
width=64,
heads=64 * 32 // 64),
neck=dict(
type='YOLOv5PAFPN',
deepen_factor=_base_.deepen_factor,
widen_factor=_base_.widen_factor,
in_channels=[
int(512 / _base_.widen_factor),
int(1024 / _base_.widen_factor),
int(2048 / _base_.widen_factor)
],
out_channels=[256, 512, 1024],
num_csp_blocks=3,
norm_cfg=_base_.norm_cfg,
act_cfg=dict(type='SiLU', inplace=True)))

base_lr = _base_.base_lr
optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(lr=base_lr),
constructor='TempCLIPBackboneConstructor') # 单独拎出来backbone的参数

default_hooks = dict(
param_scheduler=dict(
type='TempCLIPParamSchedulerHook', backbone_lr_scale=0.1))

load_from = 'CLIPResNet50.pth'
5 changes: 3 additions & 2 deletions mmyolo/engine/hooks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .ppyoloe_param_scheduler_hook import PPYOLOEParamSchedulerHook
from .switch_to_deploy_hook import SwitchToDeployHook
from .yolov5_param_scheduler_hook import YOLOv5ParamSchedulerHook
from .yolov5_param_scheduler_hook import (TempCLIPParamSchedulerHook,
YOLOv5ParamSchedulerHook)
from .yolox_mode_switch_hook import YOLOXModeSwitchHook

__all__ = [
'YOLOv5ParamSchedulerHook', 'YOLOXModeSwitchHook', 'SwitchToDeployHook',
'PPYOLOEParamSchedulerHook'
'PPYOLOEParamSchedulerHook', 'TempCLIPParamSchedulerHook'
]
13 changes: 13 additions & 0 deletions mmyolo/engine/hooks/yolov5_param_scheduler_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,16 @@ def after_train_epoch(self, runner: Runner):
for group_idx, param in enumerate(optimizer.param_groups):
param['lr'] = self._base_lr[group_idx] * self.scheduler_fn(
cur_epoch)


@HOOKS.register_module()
class TempCLIPParamSchedulerHook(YOLOv5ParamSchedulerHook):

def __init__(self, backbone_lr_scale: float = 0.1, *args, **kwargs):
self.backbone_lr_scale = backbone_lr_scale
super().__init__(*args, **kwargs)

def before_train(self, runner: Runner):
super().before_train(runner)
# backbone 进行lr缩小
self._base_lr[3] *= self.backbone_lr_scale
8 changes: 6 additions & 2 deletions mmyolo/engine/optimizers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .yolov5_optim_constructor import YOLOv5OptimizerConstructor
from .yolov5_optim_constructor import (TempCLIPBackboneConstructor,
YOLOv5OptimizerConstructor)
from .yolov7_optim_wrapper_constructor import YOLOv7OptimWrapperConstructor

__all__ = ['YOLOv5OptimizerConstructor', 'YOLOv7OptimWrapperConstructor']
__all__ = [
'YOLOv5OptimizerConstructor', 'YOLOv7OptimWrapperConstructor',
'TempCLIPBackboneConstructor'
]
74 changes: 74 additions & 0 deletions mmyolo/engine/optimizers/yolov5_optim_constructor.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,77 @@ def __call__(self, model: nn.Module) -> OptimWrapper:
optim_wrapper = OPTIM_WRAPPERS.build(
self.optim_wrapper_cfg, default_args=dict(optimizer=optimizer))
return optim_wrapper


@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
class TempCLIPBackboneConstructor(YOLOv5OptimizerConstructor):

def __call__(self, model: nn.Module) -> OptimWrapper:
if is_model_wrapper(model):
model = model.module
optimizer_cfg = self.optimizer_cfg.copy()
weight_decay = optimizer_cfg.pop('weight_decay', 0)

if 'batch_size_per_gpu' in optimizer_cfg:
batch_size_per_gpu = optimizer_cfg.pop('batch_size_per_gpu')
# No scaling if total_batch_size is less than
# base_total_batch_size, otherwise linear scaling.
total_batch_size = get_world_size() * batch_size_per_gpu
accumulate = max(
round(self.base_total_batch_size / total_batch_size), 1)
scale_factor = total_batch_size * \
accumulate / self.base_total_batch_size

if scale_factor != 1:
weight_decay *= scale_factor
print_log(f'Scaled weight_decay to {weight_decay}', 'current')

params_groups = [], [], [], []

# for v in model.modules():
# if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):
# params_groups[2].append(v.bias)
# # Includes SyncBatchNorm
# if isinstance(v, nn.modules.batchnorm._NormBase):
# params_groups[1].append(v.weight)
# elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):
# params_groups[0].append(v.weight)

params = list(model.named_parameters())
for k, v in params:
# 单独出来backbone的para,使用单独的学习率
if 'backbone.' in k:
params_groups[3].append(v)
continue
if 'bias' in k:
params_groups[2].append(v)
continue
if 'bn' in k:
params_groups[1].append(v)
elif 'weight' in k and isinstance(v, nn.Parameter):
params_groups[0].append(v)

# Note: Make sure bias is in the last parameter group
optimizer_cfg['params'] = []
# conv
optimizer_cfg['params'].append({
'params': params_groups[0],
'weight_decay': weight_decay
})
# bn
optimizer_cfg['params'].append({'params': params_groups[1]})
# bias
optimizer_cfg['params'].append({'params': params_groups[2]})
# backbone
optimizer_cfg['params'].append({'params': params_groups[3]})

print_log(
'Optimizer groups: %g .bias, %g conv.weight, %g other' %
(len(params_groups[2]), len(params_groups[0]), len(
params_groups[1])), 'current')
del params_groups

optimizer = OPTIMIZERS.build(optimizer_cfg)
optim_wrapper = OPTIM_WRAPPERS.build(
self.optim_wrapper_cfg, default_args=dict(optimizer=optimizer))
return optim_wrapper
3 changes: 2 additions & 1 deletion mmyolo/models/backbones/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .base_backbone import BaseBackbone
from .clip_backbone import CLIPModifiedResNet
from .csp_darknet import YOLOv5CSPDarknet, YOLOv8CSPDarknet, YOLOXCSPDarknet
from .csp_resnet import PPYOLOECSPResNet
from .cspnext import CSPNeXt
Expand All @@ -9,5 +10,5 @@
__all__ = [
'YOLOv5CSPDarknet', 'BaseBackbone', 'YOLOv6EfficientRep', 'YOLOv6CSPBep',
'YOLOXCSPDarknet', 'CSPNeXt', 'YOLOv7Backbone', 'PPYOLOECSPResNet',
'YOLOv8CSPDarknet'
'YOLOv8CSPDarknet', 'CLIPModifiedResNet'
]
42 changes: 42 additions & 0 deletions mmyolo/models/backbones/clip_backbone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright (c) OpenMMLab. All rights reserved.
from clip.model import ModifiedResNet
from torch.nn.modules.batchnorm import _BatchNorm

from mmyolo.registry import MODELS


@MODELS.register_module()
class CLIPModifiedResNet(ModifiedResNet):

def __init__(self, freeze_backbone, freeze_bn, *args, **kwargs):
super().__init__(*args, **kwargs)
self.freeze_backbone = freeze_backbone
self.freeze_bn = freeze_bn

def forward(self, x):

def stem(x):
x = self.relu1(self.bn1(self.conv1(x)))
x = self.relu2(self.bn2(self.conv2(x)))
x = self.relu3(self.bn3(self.conv3(x)))
x = self.avgpool(x)
return x

x = x.type(self.conv1.weight.dtype)
x = stem(x)
x = self.layer1(x)
x1 = self.layer2(x)
x2 = self.layer3(x1)
x3 = self.layer4(x2)
# x = self.attnpool(x)
return [x1, x2, x3]

def train(self, mode=True):
if self.freeze_bn:
for m in self.modules():
if isinstance(m, _BatchNorm):
m.eval()
elif self.freeze_backbone: # 写法要优化
for ind, m in enumerate(self.modules()):
if ind != 0:
m.eval()
11 changes: 6 additions & 5 deletions mmyolo/models/necks/yolov5_pafpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def build_top_down_layer(self, idx: int):
return CSPLayer(
make_divisible(self.in_channels[idx - 1] * 2,
self.widen_factor),
make_divisible(self.in_channels[idx - 1], self.widen_factor),
make_divisible(self.out_channels[idx - 1], self.widen_factor),
num_blocks=make_round(self.num_csp_blocks, self.deepen_factor),
add_identity=False,
norm_cfg=self.norm_cfg,
Expand Down Expand Up @@ -141,8 +141,8 @@ def build_downsample_layer(self, idx: int) -> nn.Module:
nn.Module: The downsample layer.
"""
return ConvModule(
make_divisible(self.in_channels[idx], self.widen_factor),
make_divisible(self.in_channels[idx], self.widen_factor),
make_divisible(self.out_channels[idx], self.widen_factor),
make_divisible(self.out_channels[idx], self.widen_factor),
kernel_size=3,
stride=2,
padding=1,
Expand All @@ -159,8 +159,9 @@ def build_bottom_up_layer(self, idx: int) -> nn.Module:
nn.Module: The bottom up layer.
"""
return CSPLayer(
make_divisible(self.in_channels[idx] * 2, self.widen_factor),
make_divisible(self.in_channels[idx + 1], self.widen_factor),
make_divisible(self.in_channels[idx] + self.out_channels[idx],
self.widen_factor),
make_divisible(self.out_channels[idx + 1], self.widen_factor),
num_blocks=make_round(self.num_csp_blocks, self.deepen_factor),
add_identity=False,
norm_cfg=self.norm_cfg,
Expand Down
42 changes: 42 additions & 0 deletions tools/model_converters/clip_to_mmyolo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import argparse
from collections import OrderedDict

import clip
import torch

# _MODELS = {
# "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", # noqa
# "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", # noqa
# "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", # noqa
# "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", # noqa
# "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", # noqa
# "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", # noqa
# "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", # noqa
# "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", # noqa
# "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", # noqa
# }


def convert(src, dst):
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model, preprocess = clip.load(src, device=device)
visual_model = model.visual
state_dict = visual_model.state_dict()
new_state_dict = OrderedDict()
for key, value in state_dict.items():
new_state_dict['backbone.' + key] = value
data = {'state_dict': new_state_dict}
torch.save(data, dst)
print('save')


def main():
parser = argparse.ArgumentParser(description='Convert model keys')
parser.add_argument('--src', default='RN50', help='src clip model key')
parser.add_argument('--dst', default='CLIPResNet50.pth', help='save path')
args = parser.parse_args()
convert(args.src, args.dst)


if __name__ == '__main__':
main()