The problem of mmpretrain modifying the backbone network #12270

Open
GG22bond opened this issue Dec 13, 2024 · 1 comment
GG22bond commented Dec 13, 2024

  import torch
  from mmpretrain.models.backbones.timm_backbone import TIMMBackbone
  
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  inputs = torch.rand(1, 3, 640, 640).to(device)
  
  # model = TIMMBackbone(model_name='cspdarknet53', pretrained=False, features_only=True, out_indices=(3,4,5))
  model = TIMMBackbone(model_name='cspresnet50w', pretrained=False, features_only=True, out_indices=[2,3,4])
  # model = TIMMBackbone(model_name='resnet50d', pretrained=False, features_only=True, out_indices=(2,3,4))
  
  # model = mmpretrain.TIMMBackbone(model_name='cspresnet50w', pretrained=False, features_only=True,out_indices=(2,3,4))
  
  model.to(device)
  
  level_outputs = model(inputs)
  
  for level_out in level_outputs:
      print(tuple(level_out.shape))

[screenshot: printed output shapes of the script above]
The `**kwargs` parameter of `TIMMBackbone` is supposed to forward `out_indices` to timm so that specific feature maps can be selected as outputs. However, after modifying my config according to the official tutorial (https://mmdetection.readthedocs.io/zh-cn/latest/advanced_guides/how_to.html), I got the following error:
AttributeError: 'TIMMBackbone' object has no attribute 'out_indices'

[screenshot: error traceback]
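
For reference, `out_indices` is a native argument of `timm.create_model` when `features_only=True`, which is why the standalone script above works. A minimal sketch checking the kwarg directly against timm (assuming `resnet50d` is available in your timm version):

import timm

# select feature levels 2-4; for resnet50d these are strides 8/16/32
m = timm.create_model(
    'resnet50d', features_only=True, pretrained=False, out_indices=(2, 3, 4))
print(m.feature_info.channels())  # channel counts of the selected levels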

My configuration file:

_base_ = [
    '../_base_/datasets/coco_detection.py', '../_base_/default_runtime.py'
]

eval_size = (640, 640)

custom_imports = dict(imports=['mmpretrain.models'], allow_failed_imports=False)

model = dict(
    type='RTDETR',
    num_queries=300,  # num_matching_queries
    with_box_refine=True,
    as_two_stage=True,
    eval_size=eval_size,
    data_preprocessor=dict(
        type='DetDataPreprocessor',
        mean=[0, 0, 0],
        std=[255, 255, 255],
        bgr_to_rgb=True,
        pad_size_divisor=32),
    backbone=dict(
        type='mmpretrain.TIMMBackbone',
        model_name='cspresnet50w',
        pretrained=False,
        features_only=True,
        out_indices=(2,3,4),
    ),

    neck=dict(
        type='HybridEncoder',
        num_encoder_layers=1,
        use_encoder_idx=[2],
        layer_cfg=dict(
            self_attn_cfg=dict(embed_dims=256, num_heads=8,
                               dropout=0.0),  # 0.1 for DeformDETR
            ffn_cfg=dict(
                embed_dims=256,
                feedforward_channels=1024,  # 1024 for DeformDETR
                ffn_drop=0.0,
                act_cfg=dict(type='GELU'))),
        projector=dict(
            type='ChannelMapper',
            in_channels=[256, 256, 256],
            kernel_size=1,
            out_channels=256,
            act_cfg=None,
            norm_cfg=dict(type='BN'),
            num_outs=3)),
    encoder=None,
    decoder=dict(
        num_layers=6,
        eval_idx=-1,
        layer_cfg=dict(
            self_attn_cfg=dict(embed_dims=256, num_heads=8,
                               dropout=0.0),  # 0.1 for DeformDETR
            cross_attn_cfg=dict(
                embed_dims=256,
                num_levels=3,  # 4 for DeformDETR
                dropout=0.0),  # 0.1 for DeformDETR
            ffn_cfg=dict(
                embed_dims=256,
                feedforward_channels=1024,  # 2048 for DINO
                ffn_drop=0.0)),  # 0.1 for DeformDETR
        post_norm_cfg=None),
    positional_encoding=dict(
        num_feats=128,
        normalize=True,
        offset=0.0,  # -0.5 for DeformDETR
        temperature=20),  # 10000 for DeformDETR
    bbox_head=dict(
        type='RTDETRHead',
        num_classes=1,
        sync_cls_avg_factor=True,
        loss_cls=dict(
            type='VarifocalLoss',  
            use_sigmoid=True,
            use_rtdetr=True,
            gamma=2.0,
            alpha=0.75,  # 0.25 in DINO
            loss_weight=1.0),  # 2.0 in DeformDETR
        loss_bbox=dict(type='L1Loss', loss_weight=5.0), 
        loss_iou=dict(type='CIoULoss', loss_weight=2.0)),
    dn_cfg=dict(
        label_noise_scale=0.5,
        box_noise_scale=1.0,  # 0.4 for DN-DETR
        group_cfg=dict(dynamic=True, num_groups=None, num_dn_queries=100)),
    # training and testing settings
    train_cfg=dict(
        assigner=dict(
            type='HungarianAssigner',  
            match_costs=[
                dict(type='FocalLossCost', weight=2.0),
                dict(type='BBoxL1Cost', weight=5.0, box_format='xywh'),
                dict(type='IoUCost', iou_mode='giou', weight=2.0)
            ])),
    test_cfg=dict(max_per_img=300))  # 100 for DeformDETR


dataset_type = 'CocoDataset'
data_root = 'data/ISIC/'
evaluate_type = 'CocoMetric'

metainfo = {
    # NOTE: a single-class tuple needs the trailing comma;
    # ('Lesions') without it is just the string 'Lesions'
    'classes': ('Lesions', ),
    # 'palette': [
    #     (220, 20, 60),
    # ]
}


# train_pipeline, NOTE: the img_scale and the Pad's size_divisor are
# different from the default settings in mmdet.
train_pipeline = [
    dict(type='LoadImageFromFile', backend_args={{_base_.backend_args}}),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(type='PhotoMetricDistortion'),
    dict(
        type='Expand',
        mean=[103.53, 116.28, 123.675],
        to_rgb=True,
        ratio_range=(1, 4),
        ),
    dict(
        type='RandomCrop',
        crop_size=(0.3, 1.0),
        crop_type='relative_range',
        ),
    dict(type='RandomFlip', prob=0.5),
    dict(
        type='RandomChoiceResize',
        scales=[(480, 480), (512, 512), (544, 544), (576, 576), (608, 608),
                (640, 640), (640, 640), (640, 640), (672, 672), (704, 704),
                (736, 736), (768, 768), (800, 800)],
        keep_ratio=False,
        random_interpolation=True),
    dict(type='PackDetInputs')
]

test_pipeline = [
    dict(type='LoadImageFromFile', backend_args={{_base_.backend_args}}),
    dict(
        type='Resize',
        scale=eval_size,
        keep_ratio=False,
        interpolation='bicubic'),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(
        type='PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
                   'scale_factor'))
]

train_dataloader = dict(
    batch_size=4,
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        metainfo=metainfo,
        ann_file='annotations/train.json',
        data_prefix=dict(img='train/images/'),
        filter_cfg=dict(filter_empty_gt=False),
        pipeline=train_pipeline))

val_dataloader = dict(
    batch_size=4,
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        metainfo=metainfo,
        ann_file='annotations/valid.json',
        data_prefix=dict(img='valid/images/'),
        pipeline=test_pipeline))
test_dataloader = val_dataloader

val_evaluator = dict(ann_file=data_root + 'annotations/valid.json')
test_evaluator = val_evaluator

# optimizer
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(
        type='AdamW',
        lr=0.0001,  # 0.0002 for DeformDETR
        weight_decay=0.0001),

    # optimizer=dict(
    #     type='SGD',
    #     lr=0.001,  # SGD for ultralytics
    #     momentum=0.9,
    #     weight_decay=0.0001),

    clip_grad=dict(max_norm=0.1, norm_type=2),
    paramwise_cfg=dict(custom_keys={'backbone': dict(lr_mult=0.1)})
)  # custom_keys contains sampling_offsets and reference_points in DeformDETR  # noqa

# learning policy
max_epochs = 20
train_cfg = dict(
    type='EpochBasedTrainLoop', max_epochs=max_epochs, val_interval=1)

val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

# load_from = ''

param_scheduler = [
    dict(
        type='LinearLR', start_factor=0.001, by_epoch=False, begin=0,
        end=2000),
    dict(
        type='MultiStepLR',
        begin=0,
        end=max_epochs,
        by_epoch=True,
        milestones=[100],
        gamma=1.0)
]

# NOTE: `auto_scale_lr` is for automatically scaling LR,
# USER SHOULD NOT CHANGE ITS VALUES.
# base_batch_size = (8 GPUs) x (2 samples per GPU)
auto_scale_lr = dict(base_batch_size=16)

custom_hooks = [
    dict(
        type='EMAHook',
        ema_type='ExpMomentumEMA',
        momentum=0.0001,
        update_buffers=True,
        priority=49),
]

# python tools/train.py configs/rtdetr/cspresnet53_backbone_rtdetr.py --work-dir work_dirs/cspresnet53_backbone_rtdetr
GG22bond commented Dec 13, 2024

I have since resolved this by modifying the code.

Add an `out_indices` parameter to `class TIMMBackbone(BaseBackbone)` so the indices are stored on the module and applied in `forward()`:

# In mmpretrain/models/backbones/timm_backbone.py (warnings, require,
# BaseBackbone and print_timm_feature_info are already imported there):
class TIMMBackbone(BaseBackbone):
    """Wrapper to use backbones from timm library with optional out_indices.

    More details can be found in
    `timm <https://github.com/rwightman/pytorch-image-models>`_.
    See especially the document for `feature extraction
    <https://rwightman.github.io/pytorch-image-models/feature_extraction/>`_.

    Args:
        model_name (str): Name of timm model to instantiate.
        features_only (bool): Whether to extract feature pyramid (multi-scale
            feature maps from the deepest layer at each stride). For Vision
            Transformer models that do not support this argument,
            set this False. Defaults to False.
        pretrained (bool): Whether to load pretrained weights.
            Defaults to False.
        checkpoint_path (str): Path of checkpoint to load at the last of
            ``timm.create_model``. Defaults to empty string, which means
            not loading.
        in_channels (int): Number of input image channels. Defaults to 3.
        out_indices (list[int]): List of indices of layers to output features.
        init_cfg (dict or list[dict], optional): Initialization config dict of
            OpenMMLab projects. Defaults to None.
        **kwargs: Other timm & model specific arguments.
    """

    @require('timm')
    def __init__(self,
                 model_name,
                 features_only=False,
                 pretrained=False,
                 checkpoint_path='',
                 in_channels=3,
                 out_indices=None,  # Add out_indices to constructor
                 init_cfg=None,
                 **kwargs):
        import timm

        if not isinstance(pretrained, bool):
            raise TypeError('pretrained must be bool, not str for model path')
        if features_only and checkpoint_path:
            warnings.warn(
                'Using both features_only and checkpoint_path will cause error'
                ' in timm. See '
                'https://github.com/rwightman/pytorch-image-models/issues/488')

        super(TIMMBackbone, self).__init__(init_cfg)

        # Store out_indices for later use
        self.out_indices = out_indices or []

        self.timm_model = timm.create_model(
            model_name=model_name,
            features_only=features_only,
            pretrained=pretrained,
            in_chans=in_channels,
            checkpoint_path=checkpoint_path,
            **kwargs)

        # reset classifier
        if hasattr(self.timm_model, 'reset_classifier'):
            self.timm_model.reset_classifier(0, '')

        # Hack to use pretrained weights from timm
        if pretrained or checkpoint_path:
            self._is_init = True

        feature_info = getattr(self.timm_model, 'feature_info', None)
        print_timm_feature_info(feature_info)

    def forward(self, x):
        # Get features from the timm model
        features = self.timm_model(x)

        # If features are returned as a tuple or list, extract the required layers
        if isinstance(features, (list, tuple)):
            if self.out_indices:
                # Select the levels specified by out_indices; return a
                # tuple to match the other branches
                features = tuple(features[i] for i in self.out_indices)
            else:
                features = tuple(features)
        else:
            # Single feature case
            features = (features,)

        return features
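
With this change, `out_indices` is consumed by the constructor and applied in `forward()` instead of being forwarded to `timm.create_model`, so the script at the top of the issue runs unchanged. A quick sanity check of the patched class (output shapes assume `cspresnet50w` exposes five feature levels, strides 2 through 32):

import torch
from mmpretrain.models.backbones.timm_backbone import TIMMBackbone

model = TIMMBackbone(
    model_name='cspresnet50w',
    pretrained=False,
    features_only=True,
    out_indices=[2, 3, 4])
for out in model(torch.rand(1, 3, 640, 640)):
    print(tuple(out.shape))  # expect strides 8/16/32 -> 80x80, 40x40, 20x20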
