Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP][Feature] Support RTMDet-Ins fast training #649

Open
wants to merge 14 commits into
base: dev
Choose a base branch
from
340 changes: 340 additions & 0 deletions configs/rtmdet_ins/rtmdet-ins_l_syncbn_fast_8xb32-300e_coco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,340 @@
_base_ = ['../_base_/default_runtime.py']
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

配置应该是放到 config/rtmdet/ins_seg 下比较好?更好管理?你觉得呢


# ========================Frequently modified parameters======================
# -----data related-----
data_root = 'data/coco/'
# Path of train annotation file
train_ann_file = 'annotations/instances_train2017.json'
train_data_prefix = 'train2017/' # Prefix of train image path
# Path of val annotation file
val_ann_file = 'annotations/instances_val2017.json'
val_data_prefix = 'val2017/' # Prefix of val image path

num_classes = 80 # Number of classes for classification
# Batch size of a single GPU during training
train_batch_size_per_gpu = 32
# Worker to pre-fetch data for each single GPU during training
train_num_workers = 10
# persistent_workers must be False if num_workers is 0.
persistent_workers = True

# -----train val related-----
# Base learning rate for optim_wrapper. Corresponding to 8xb16=64 bs
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Base learning rate for optim_wrapper. Corresponding to 8xb16=64 bs
# Base learning rate for optim_wrapper. Corresponding to 8xb32=256 bs

base_lr = 0.004
max_epochs = 300 # Maximum training epochs
# Change train_pipeline for final 20 epochs (stage 2)
num_epochs_stage2 = 20

model_test_cfg = dict(
# The config of multi-label for multi-class prediction.
multi_label=True,
# The number of boxes before NMS
nms_pre=1000,
score_thr=0.05, # Threshold to filter out boxes.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里可以加个注释说明:实例分割任务相比目标检测后处理速度更慢,因此需要加大 score_thr 和减少 nms_pre 和 max_per_img 等参数

nms=dict(type='nms', iou_threshold=0.6), # NMS type and threshold
max_per_img=100, # Max number of detections of each image
mask_thr_binary=0.5) # Threshold of binary mask

# ========================Possible modified parameters========================
# -----data related-----
img_scale = (640, 640) # width, height
# ratio range for random resize
random_resize_ratio_range = (0.1, 2.0)
# Cached images number in mosaic
mosaic_max_cached_images = 40
# Number of cached images in mixup
mixup_max_cached_images = 20
# Dataset type, this will be used to define the dataset
dataset_type = 'YOLOv5CocoDataset'
# Batch size of a single GPU during validation
val_batch_size_per_gpu = 32
# Worker to pre-fetch data for each single GPU during validation
val_num_workers = 10
use_mask2refine = True
copypaste_prob = 0.3
mask_downsample_stride = 4

# Config of batch shapes. Only on val.
batch_shapes_cfg = dict(
type='BatchShapePolicy',
batch_size=val_batch_size_per_gpu,
img_size=img_scale[0],
size_divisor=32,
extra_pad_ratio=0.5)

# -----model related-----
# The scaling factor that controls the depth of the network structure
deepen_factor = 1.0
# The scaling factor that controls the width of the network structure
widen_factor = 1.0
# Strides of multi-scale prior box
strides = [8, 16, 32]

norm_cfg = dict(type='BN') # Normalization config

# -----train val related-----
lr_start_factor = 1.0e-5
dsl_topk = 13 # Number of bbox selected in each level
loss_cls_weight = 1.0
loss_bbox_weight = 2.0
loss_mask_weight = 2.0
qfl_beta = 2.0 # beta of QualityFocalLoss
weight_decay = 0.05

# Save model checkpoint and validation intervals
save_checkpoint_intervals = 10
# validation intervals in stage 2
val_interval_stage2 = 1
# The maximum checkpoints to keep.
max_keep_ckpts = 3
# single-scale training is recommended to
# be turned on, which can speed up training.
env_cfg = dict(cudnn_benchmark=True)

# ===============================Unmodified in most cases====================
model = dict(
type='YOLODetector',
data_preprocessor=dict(
type='YOLOv5DetDataPreprocessor',
mean=[103.53, 116.28, 123.675],
std=[57.375, 57.12, 58.395],
bgr_to_rgb=False),
backbone=dict(
type='CSPNeXt',
arch='P5',
expand_ratio=0.5,
deepen_factor=deepen_factor,
widen_factor=widen_factor,
channel_attention=True,
norm_cfg=norm_cfg,
act_cfg=dict(type='SiLU', inplace=True)),
neck=dict(
type='CSPNeXtPAFPN',
deepen_factor=deepen_factor,
widen_factor=widen_factor,
in_channels=[256, 512, 1024],
out_channels=256,
num_csp_blocks=3,
expand_ratio=0.5,
norm_cfg=norm_cfg,
act_cfg=dict(type='SiLU', inplace=True)),
bbox_head=dict(
type='RTMDetInsHead',
head_module=dict(
type='RTMDetInsSepBNHeadModule',
num_classes=num_classes,
in_channels=256,
stacked_convs=2,
feat_channels=256,
norm_cfg=norm_cfg,
act_cfg=dict(type='SiLU', inplace=True),
share_conv=True,
pred_kernel_size=1,
featmap_strides=strides),
prior_generator=dict(
type='mmdet.MlvlPointGenerator', offset=0, strides=strides),
bbox_coder=dict(type='DistancePointBBoxCoder'),
loss_cls=dict(
type='mmdet.QualityFocalLoss',
use_sigmoid=True,
beta=qfl_beta,
loss_weight=loss_cls_weight),
loss_bbox=dict(type='mmdet.GIoULoss', loss_weight=loss_bbox_weight),
loss_mask=dict(
type='mmdet.DiceLoss',
loss_weight=loss_mask_weight,
eps=5e-6,
reduction='mean'),
mask_loss_stride=mask_downsample_stride),
train_cfg=dict(
assigner=dict(
type='BatchDynamicSoftLabelAssigner',
num_classes=num_classes,
topk=dsl_topk,
iou_calculator=dict(type='mmdet.BboxOverlaps2D')),
allowed_border=-1,
pos_weight=-1,
debug=False),
test_cfg=model_test_cfg,
)

train_pipeline = [
dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),
dict(
type='LoadAnnotations',
with_bbox=True,
with_mask=True,
mask2bbox=use_mask2refine),
dict(
type='Mosaic',
img_scale=img_scale,
use_cached=True,
max_cached_images=mosaic_max_cached_images,
pad_val=114.0),
dict(type='YOLOv5CopyPaste', prob=copypaste_prob),
dict(
type='mmdet.RandomResize',
# img_scale is (width, height)
scale=(img_scale[0] * 2, img_scale[1] * 2),
ratio_range=random_resize_ratio_range,
resize_type='mmdet.Resize',
keep_ratio=True),
dict(
type='mmdet.RandomCrop',
crop_size=img_scale,
recompute_bbox=True,
allow_negative_crop=True),
dict(type='mmdet.FilterAnnotations', min_gt_bbox_wh=(1, 1)),
dict(type='mmdet.YOLOXHSVRandomAug'),
dict(type='mmdet.RandomFlip', prob=0.5),
dict(type='mmdet.Pad', size=img_scale, pad_val=dict(img=(114, 114, 114))),
dict(
type='YOLOv5MixUp',
use_cached=True,
max_cached_images=mixup_max_cached_images),
dict(type='Mask2Tensor', downsample_stride=mask_downsample_stride),
dict(type='mmdet.PackDetInputs')
]

train_pipeline_stage2 = [
dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),
dict(
type='LoadAnnotations',
with_bbox=True,
with_mask=True,
mask2bbox=use_mask2refine),
dict(
type='mmdet.RandomResize',
scale=img_scale,
ratio_range=random_resize_ratio_range,
resize_type='mmdet.Resize',
keep_ratio=True),
dict(
type='mmdet.RandomCrop',
crop_size=img_scale,
recompute_bbox=True,
allow_negative_crop=True),
dict(type='mmdet.FilterAnnotations', min_gt_bbox_wh=(1, 1)),
dict(type='mmdet.YOLOXHSVRandomAug'),
dict(type='mmdet.RandomFlip', prob=0.5),
dict(type='mmdet.Pad', size=img_scale, pad_val=dict(img=(114, 114, 114))),
dict(type='Mask2Tensor', downsample_stride=mask_downsample_stride),
dict(type='mmdet.PackDetInputs')
]

test_pipeline = [
dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),
dict(type='YOLOv5KeepRatioResize', scale=img_scale),
dict(
type='LetterResize',
scale=img_scale,
allow_scale_up=False,
pad_val=dict(img=114)),
dict(
type='LoadAnnotations',
with_bbox=True,
with_mask=True,
_scope_='mmdet'),
dict(
type='mmdet.PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
'scale_factor', 'pad_param'))
]

train_dataloader = dict(
batch_size=train_batch_size_per_gpu,
num_workers=train_num_workers,
persistent_workers=persistent_workers,
pin_memory=True,
collate_fn=dict(type='yolov5_collate'),
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file=train_ann_file,
data_prefix=dict(img=train_data_prefix),
filter_cfg=dict(filter_empty_gt=True, min_size=32),
pipeline=train_pipeline))

val_dataloader = dict(
batch_size=val_batch_size_per_gpu,
num_workers=val_num_workers,
persistent_workers=persistent_workers,
pin_memory=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file=val_ann_file,
data_prefix=dict(img=val_data_prefix),
test_mode=True,
batch_shapes_cfg=batch_shapes_cfg,
pipeline=test_pipeline))

test_dataloader = val_dataloader

# Reduce evaluation time
val_evaluator = dict(
type='mmdet.CocoMetric',
proposal_nums=(100, 1, 10),
ann_file=data_root + val_ann_file,
metric=['bbox', 'segm'])
test_evaluator = val_evaluator

# optimizer
optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(type='AdamW', lr=base_lr, weight_decay=weight_decay),
paramwise_cfg=dict(
norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True))

# learning rate
param_scheduler = [
dict(
type='LinearLR',
start_factor=lr_start_factor,
by_epoch=False,
begin=0,
end=1000),
dict(
# use cosine lr from 150 to 300 epoch
type='CosineAnnealingLR',
eta_min=base_lr * 0.05,
begin=max_epochs // 2,
end=max_epochs,
T_max=max_epochs // 2,
by_epoch=True,
convert_to_iter_based=True),
]

# hooks
default_hooks = dict(
checkpoint=dict(
type='CheckpointHook',
interval=save_checkpoint_intervals,
max_keep_ckpts=max_keep_ckpts # only keep latest 3 checkpoints
))

custom_hooks = [
dict(
type='EMAHook',
ema_type='ExpMomentumEMA',
momentum=0.0002,
update_buffers=True,
strict_load=False,
priority=49),
dict(
type='mmdet.PipelineSwitchHook',
switch_epoch=max_epochs - num_epochs_stage2,
switch_pipeline=train_pipeline_stage2)
]

train_cfg = dict(
type='EpochBasedTrainLoop',
max_epochs=max_epochs,
val_interval=save_checkpoint_intervals,
dynamic_intervals=[(max_epochs - num_epochs_stage2, val_interval_stage2)])

val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
Loading