From 90668a46245f5a91f237dc5430bc03d28a443f56 Mon Sep 17 00:00:00 2001 From: Lupin1998 <1070535169@qq.com> Date: Thu, 6 Jul 2023 10:07:04 +0000 Subject: [PATCH 1/2] add MogaNet --- configs/_base_/models/moganet/moganet_base.py | 23 + .../_base_/models/moganet/moganet_large.py | 23 + .../_base_/models/moganet/moganet_small.py | 23 + configs/_base_/models/moganet/moganet_tiny.py | 23 + .../_base_/models/moganet/moganet_xlarge.py | 23 + .../_base_/models/moganet/moganet_xtiny.py | 23 + configs/moganet/README.md | 81 +++ configs/moganet/metafile.yml | 130 ++++ configs/moganet/moganet-base_8xb128_in1k.py | 42 ++ configs/moganet/moganet-large_8xb128_in1k.py | 101 +++ configs/moganet/moganet-small_8xb128_in1k.py | 39 ++ configs/moganet/moganet-tiny_8xb128_in1k.py | 95 +++ configs/moganet/moganet-xlarge_16xb32_in1k.py | 98 +++ configs/moganet/moganet-xtiny_8xb128_in1k.py | 95 +++ mmpretrain/__init__.py | 4 +- mmpretrain/models/backbones/__init__.py | 2 + mmpretrain/models/backbones/moganet.py | 640 ++++++++++++++++++ model-index.yml | 1 + 18 files changed, 1464 insertions(+), 2 deletions(-) create mode 100644 configs/_base_/models/moganet/moganet_base.py create mode 100644 configs/_base_/models/moganet/moganet_large.py create mode 100644 configs/_base_/models/moganet/moganet_small.py create mode 100644 configs/_base_/models/moganet/moganet_tiny.py create mode 100644 configs/_base_/models/moganet/moganet_xlarge.py create mode 100644 configs/_base_/models/moganet/moganet_xtiny.py create mode 100644 configs/moganet/README.md create mode 100644 configs/moganet/metafile.yml create mode 100644 configs/moganet/moganet-base_8xb128_in1k.py create mode 100644 configs/moganet/moganet-large_8xb128_in1k.py create mode 100644 configs/moganet/moganet-small_8xb128_in1k.py create mode 100644 configs/moganet/moganet-tiny_8xb128_in1k.py create mode 100644 configs/moganet/moganet-xlarge_16xb32_in1k.py create mode 100644 configs/moganet/moganet-xtiny_8xb128_in1k.py create mode 100644 mmpretrain/models/backbones/moganet.py diff --git a/configs/_base_/models/moganet/moganet_base.py b/configs/_base_/models/moganet/moganet_base.py new file mode 100644 index 00000000000..4d5b4e12cab --- /dev/null +++ b/configs/_base_/models/moganet/moganet_base.py @@ -0,0 +1,23 @@ +# Model settings +model = dict( + type='ImageClassifier', + backbone=dict( + type='MogaNet', arch='base', drop_path_rate=0.2, + attn_force_fp32=True, + ), + neck=dict(type='GlobalAveragePooling'), + head=dict( + type='LinearClsHead', + num_classes=1000, + in_channels=512, + loss=dict( + type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'), + init_cfg=None, + ), + init_cfg=dict( + type='TruncNormal', layer=['Conv2d', 'Linear'], std=.02, bias=0.), + train_cfg=dict(augments=[ + dict(type='Mixup', alpha=0.8), + dict(type='CutMix', alpha=1.0), + ]), +) diff --git a/configs/_base_/models/moganet/moganet_large.py b/configs/_base_/models/moganet/moganet_large.py new file mode 100644 index 00000000000..e1d29cec22e --- /dev/null +++ b/configs/_base_/models/moganet/moganet_large.py @@ -0,0 +1,23 @@ +# Model settings +model = dict( + type='ImageClassifier', + backbone=dict( + type='MogaNet', arch='large', drop_path_rate=0.3, + attn_force_fp32=False, + ), + neck=dict(type='GlobalAveragePooling'), + head=dict( + type='LinearClsHead', + num_classes=1000, + in_channels=640, + loss=dict( + type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'), + init_cfg=None, + ), + init_cfg=dict( + type='TruncNormal', layer=['Conv2d', 'Linear'], std=.02, bias=0.), + train_cfg=dict(augments=[ + dict(type='Mixup', alpha=0.8), + dict(type='CutMix', alpha=1.0), + ]), +) diff --git a/configs/_base_/models/moganet/moganet_small.py b/configs/_base_/models/moganet/moganet_small.py new file mode 100644 index 00000000000..1c7ce16f06d --- /dev/null +++ b/configs/_base_/models/moganet/moganet_small.py @@ -0,0 +1,23 @@ +# Model settings +model = dict( + type='ImageClassifier', + backbone=dict( + type='MogaNet', arch='small', drop_path_rate=0.1, + attn_force_fp32=True, + ), + neck=dict(type='GlobalAveragePooling'), + head=dict( + type='LinearClsHead', + num_classes=1000, + in_channels=512, + loss=dict( + type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'), + init_cfg=None, + ), + init_cfg=dict( + type='TruncNormal', layer=['Conv2d', 'Linear'], std=.02, bias=0.), + train_cfg=dict(augments=[ + dict(type='Mixup', alpha=0.8), + dict(type='CutMix', alpha=1.0), + ]), +) diff --git a/configs/_base_/models/moganet/moganet_tiny.py b/configs/_base_/models/moganet/moganet_tiny.py new file mode 100644 index 00000000000..de981289e51 --- /dev/null +++ b/configs/_base_/models/moganet/moganet_tiny.py @@ -0,0 +1,23 @@ +# Model settings +model = dict( + type='ImageClassifier', + backbone=dict( + type='MogaNet', arch='tiny', drop_path_rate=0.1, + attn_force_fp32=True, + ), + neck=dict(type='GlobalAveragePooling'), + head=dict( + type='LinearClsHead', + num_classes=1000, + in_channels=256, + loss=dict( + type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'), + init_cfg=None, + ), + init_cfg=dict( + type='TruncNormal', layer=['Conv2d', 'Linear'], std=.02, bias=0.), + train_cfg=dict(augments=[ + dict(type='Mixup', alpha=0.1), + dict(type='CutMix', alpha=1.0), + ]), +) diff --git a/configs/_base_/models/moganet/moganet_xlarge.py b/configs/_base_/models/moganet/moganet_xlarge.py new file mode 100644 index 00000000000..b893d11032f --- /dev/null +++ b/configs/_base_/models/moganet/moganet_xlarge.py @@ -0,0 +1,23 @@ +# Model settings +model = dict( + type='ImageClassifier', + backbone=dict( + type='MogaNet', arch='x-large', drop_path_rate=0.4, + attn_force_fp32=True, + ), + neck=dict(type='GlobalAveragePooling'), + head=dict( + type='LinearClsHead', + num_classes=1000, + in_channels=960, + loss=dict( + type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'), + init_cfg=None, + ), + init_cfg=dict( + type='TruncNormal', layer=['Conv2d', 'Linear'], std=.02, bias=0.), + train_cfg=dict(augments=[ + dict(type='Mixup', alpha=0.8), + dict(type='CutMix', alpha=1.0), + ]), +) diff --git a/configs/_base_/models/moganet/moganet_xtiny.py b/configs/_base_/models/moganet/moganet_xtiny.py new file mode 100644 index 00000000000..1030a71abbf --- /dev/null +++ b/configs/_base_/models/moganet/moganet_xtiny.py @@ -0,0 +1,23 @@ +# Model settings +model = dict( + type='ImageClassifier', + backbone=dict( + type='MogaNet', arch='x-tiny', drop_path_rate=0.05, + attn_force_fp32=True, + ), + neck=dict(type='GlobalAveragePooling'), + head=dict( + type='LinearClsHead', + num_classes=1000, + in_channels=192, + loss=dict( + type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'), + init_cfg=None, + ), + init_cfg=dict( + type='TruncNormal', layer=['Conv2d', 'Linear'], std=.02, bias=0.), + train_cfg=dict(augments=[ + dict(type='Mixup', alpha=0.1), + dict(type='CutMix', alpha=1.0), + ]), +) diff --git a/configs/moganet/README.md b/configs/moganet/README.md new file mode 100644 index 00000000000..fbd76178805 --- /dev/null +++ b/configs/moganet/README.md @@ -0,0 +1,81 @@ +# Efficient Multi-order Gated Aggregation Network + +> [Efficient Multi-order Gated Aggregation Network](https://arxiv.org/abs/2211.03295) + + + +## Abstract + +Since the recent success of Vision Transformers (ViTs), explorations toward ViT-style architectures have triggered the resurgence of ConvNets. In this work, we explore the representation ability of modern ConvNets from a novel view of multi-order game-theoretic interaction, which reflects inter-variable interaction effects w.r.t.~contexts of different scales based on game theory. Within the modern ConvNet framework, we tailor the two feature mixers with conceptually simple yet effective depthwise convolutions to facilitate middle-order information across spatial and channel spaces respectively. In this light, a new family of pure ConvNet architecture, dubbed MogaNet, is proposed, which shows excellent scalability and attains competitive results among state-of-the-art models with more efficient use of parameters on ImageNet and multifarious typical vision benchmarks, including COCO object detection, ADE20K semantic segmentation, 2D\&3D human pose estimation, and video prediction. Typically, MogaNet hits 80.0\% and 87.8\% top-1 accuracy with 5.2M and 181M parameters on ImageNet, outperforming ParC-Net-S and ConvNeXt-L while saving 59\% FLOPs and 17M parameters. The source code is available at https://github.com/Westlake-AI/MogaNet. + +
+ +
+ +## How to use it? + + + +**Predict image** + +```python +from mmpretrain import inference_model + +predict = inference_model('moganet-tiny_3rdparty_8xb128_in1k', 'demo/bird.JPEG') +print(predict['pred_class']) +print(predict['pred_score']) +``` + +**Use the model** + +```python +import torch +from mmpretrain import get_model + +model = get_model('moganet-tiny_3rdparty_8xb128_in1k', pretrained=True) +inputs = torch.rand(1, 3, 224, 224) +out = model(inputs) +print(type(out)) +# To extract features. +feats = model.extract_feat(inputs) +print(type(feats)) +``` + +**Test Command** + +Prepare your dataset according to the [docs](https://mmpretrain.readthedocs.io/en/latest/user_guides/dataset_prepare.html#prepare-dataset). + +Test: + +```shell +python tools/test.py configs/moganet/moganet-tiny_8xb128_in1k.py https://github.com/Westlake-AI/openmixup/releases/download/moganet-in1k-weights/moga_tiny_sz224_8xb128_fp16_ep300.pth +``` + + + +## Models and results + +### Image Classification on ImageNet-1k + +| Model | Pretrain | Params (M) | Flops (G) | Top-1 (%) | Top-5 (%) | Config | Download | +| :-------------------------------------- | :----------: | :--------: | :-------: | :-------: | :-------: | :-------: | :-------: | +| `moganet-xtiny_3rdparty_8xb128_in1k`\* | From scratch | 2.97 | 0.79 | 76.48 | 93.49 | [config](moganet-xtiny_8xb128_in1k.py) | [model](https://github.com/Lupin1998/mmpretrain/releases/download/moganet-in1k-weights/moganet-xtiny_3rdparty_8xb128_in1k.pth) | +| `moganet-tiny_3rdparty_8xb128_in1k`\* | From scratch | 5.20 | 1.09 | 77.24 | 93.51 | [config](moganet-tiny_8xb128_in1k.py) | [model](https://github.com/Lupin1998/mmpretrain/releases/download/moganet-in1k-weights/moganet-tiny_3rdparty_8xb128_in1k.pth) | +| `moganet-small_3rdparty_8xb128_in1k`\* | From scratch | 4.94 | 25.35 | 83.38 | 96.58 | [config](moganet-small_8xb128_in1k.py) | [model](https://github.com/Lupin1998/mmpretrain/releases/download/moganet-in1k-weights/moganet-small_3rdparty_8xb128_in1k.pth) | +| `moganet-base_3rdparty_8xb128_in1k`\* | From scratch | 9.88 | 43.72 | 84.20 | 96.77 | [config](moganet-base_8xb128_in1k.py) | [model](https://github.com/Lupin1998/mmpretrain/releases/download/moganet-in1k-weights/moganet-base_3rdparty_8xb128_in1k.pth) | +| `moganet-large_3rdparty_8xb128_in1k`\* | From scratch | 15.84 | 82.48 | 84.76 | 97.15 | [config](moganet-large_8xb128_in1k.py) | [model](https://github.com/Lupin1998/mmpretrain/releases/download/moganet-in1k-weights/moganet-large_3rdparty_8xb128_in1k.pth) | +| `moganet-xlarge_3rdparty_16xb32_in1k`\* | From scratch | 34.43 | 180.8 | 85.11 | 97.38 | [config](moganet-xlarge_16xb32_in1k.py) | [model](https://github.com/Lupin1998/mmpretrain/releases/download/moganet-in1k-weights/moganet-xlarge_3rdparty_16xb32_in1k.pth) | + +*Models with * are converted from the [official repo](https://github.com/Westlake-AI/MogaNet). The config files of these models are only for inference. We haven't reproduce the training results.* + +## Citation + +```bibtex +@article{Li2022MogaNet, + title={Efficient Multi-order Gated Aggregation Network}, + author={Siyuan Li and Zedong Wang and Zicheng Liu and Cheng Tan and Haitao Lin and Di Wu and Zhiyuan Chen and Jiangbin Zheng and Stan Z. Li}, + journal={ArXiv}, + year={2022}, + volume={abs/2211.03295} +} +``` diff --git a/configs/moganet/metafile.yml b/configs/moganet/metafile.yml new file mode 100644 index 00000000000..0d6925f03a7 --- /dev/null +++ b/configs/moganet/metafile.yml @@ -0,0 +1,130 @@ +Collections: + - Name: MogaNet + Metadata: + Training Data: ImageNet-1k + Architecture: + - Gating + - 1x1 Convolution + - LayerScale + Paper: + URL: https://arxiv.org/abs/2211.03295 + Title: Efficient Multi-order Gated Aggregation Network + README: configs/moganet/README.md + Code: + Version: v1.0.0 + URL: https://github.com/Lupin1998/mmpretrain/tree/main/mmpretrain/models/backbones/moganet.py + +Models: + - Name: moganet-xtiny_3rdparty_8xb128_in1k + Metadata: + FLOPs: 843961073 + Parameters: 3114270 + In Collection: MogaNet + Results: + - Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 76.48 + Top 5 Accuracy: 93.49 + Task: Image Classification + Weights: https://github.com/Westlake-AI/openmixup/releases/download/moganet-in1k-weights/moga_xtiny_sz224_8xb128_fp16_ep300.pth + Config: configs/moganet/moganet-xtiny_8xb128_in1k.py + Converted From: + Weights: https://github.com/Westlake-AI/openmixup/releases/download/moganet-in1k-weights/moga_xtiny_sz224_8xb128_fp16_ep300.pth + Code: https://github.com/Westlake-AI/openmixup + - Name: moganet-tiny_3rdparty_8xb128_in1k + Metadata: + FLOPs: 1168231104 + Parameters: 5449449 + In Collection: MogaNet + Results: + - Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 77.24 + Top 5 Accuracy: 93.51 + Task: Image Classification + Weights: https://github.com/Westlake-AI/openmixup/releases/download/moganet-in1k-weights/moga_tiny_sz224_8xb128_fp16_ep300.pth + Config: configs/moganet/moganet-tiny_8xb128_in1k.py + Converted From: + Weights: https://github.com/Westlake-AI/openmixup/releases/download/moganet-in1k-weights/moga_tiny_sz224_8xb128_fp16_ep300.pth + Code: https://github.com/Westlake-AI/openmixup + - Name: moganet-small_3rdparty_8xb128_in1k + Metadata: + FLOPs: 5304284610 + Parameters: 26566721 + In Collection: MogaNet + Results: + - Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 83.38 + Top 5 Accuracy: 96.58 + Task: Image Classification + Weights: https://github.com/Westlake-AI/openmixup/releases/download/moganet-in1k-weights/moga_small_sz224_8xb128_fp16_ep300.pth + Config: configs/moganet/moganet-small_8xb128_in1k.py + Converted From: + Weights: https://github.com/Westlake-AI/openmixup/releases/download/moganet-in1k-weights/moga_small_sz224_8xb128_fp16_ep300.pth + Code: https://github.com/Westlake-AI/openmixup + - Name: moganet-base_3rdparty_8xb128_in1k + Metadata: + FLOPs: 10608569221 + Parameters: 45843742 + In Collection: MogaNet + Results: + - Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 84.20 + Top 5 Accuracy: 96.77 + Task: Image Classification + Weights: https://github.com/Westlake-AI/openmixup/releases/download/moganet-in1k-weights/moga_base_sz224_8xb128_fp16_ep300.pth + Config: configs/moganet/moganet-base_8xb128_in1k.py + Converted From: + Weights: https://github.com/Westlake-AI/openmixup/releases/download/moganet-in1k-weights/moga_base_sz224_8xb128_fp16_ep300.pth + Code: https://github.com/Westlake-AI/openmixup + - Name: moganet-large_3rdparty_8xb128_in1k + Metadata: + FLOPs: 17008070492 + Parameters: 86486548 + In Collection: MogaNet + Results: + - Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 84.76 + Top 5 Accuracy: 97.15 + Task: Image Classification + Weights: https://github.com/Westlake-AI/openmixup/releases/download/moganet-in1k-weights/moga_large_sz224_8xb64_accu2_ep300.pth + Config: configs/moganet/moganet-large_8xb128_in1k.py + Converted From: + Weights: https://github.com/Westlake-AI/openmixup/releases/download/moganet-in1k-weights/moga_large_sz224_8xb64_accu2_ep300.pth + Code: https://github.com/Westlake-AI/openmixup + - Name: moganet-xlarge_3rdparty_16xb32_in1k + Metadata: + FLOPs: 36968931000 + Parameters: 189582540 + In Collection: MogaNet + Results: + - Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 85.11 + Top 5 Accuracy: 97.38 + Task: Image Classification + Weights: https://github.com/Westlake-AI/openmixup/releases/download/moganet-in1k-weights/moga_xlarge_ema_sz224_8xb32_accu2_ep300.pth + Config: configs/moganet/moganet-xlarge_16xb32_in1k.py + Converted From: + Weights: https://github.com/Westlake-AI/openmixup/releases/download/moganet-in1k-weights/moga_xlarge_ema_sz224_8xb32_accu2_ep300.pth + Code: https://github.com/Westlake-AI/openmixup + + # - Name: poolformer-m48_3rdparty_32xb128_in1k + # Metadata: + # FLOPs: 11801805696 + # Parameters: 73473448 + # In Collection: PoolFormer + # Results: + # - Dataset: ImageNet-1k + # Metrics: + # Top 1 Accuracy: 82.51 + # Top 5 Accuracy: 95.95 + # Task: Image Classification + # Weights: https://download.openmmlab.com/mmclassification/v0/poolformer/poolformer-m48_3rdparty_32xb128_in1k_20220414-9378f3eb.pth + # Config: configs/poolformer/poolformer-m48_32xb128_in1k.py + # Converted From: + # Weights: https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_m48.pth.tar + # Code: https://github.com/sail-sg/poolformer diff --git a/configs/moganet/moganet-base_8xb128_in1k.py b/configs/moganet/moganet-base_8xb128_in1k.py new file mode 100644 index 00000000000..800f11520eb --- /dev/null +++ b/configs/moganet/moganet-base_8xb128_in1k.py @@ -0,0 +1,42 @@ +_base_ = [ + '../_base_/models/moganet/moganet_base.py', + '../_base_/datasets/imagenet_bs128_poolformer_small_224.py', + '../_base_/schedules/imagenet_bs1024_adamw_swin.py', + '../_base_/default_runtime.py', +] + +# schedule settings +optim_wrapper = dict( + paramwise_cfg=dict( + norm_decay_mult=0.0, + bias_decay_mult=0.0, + custom_keys={ + '.layer_scale': dict(decay_mult=0.0), + '.scale': dict(decay_mult=0.0), + }), +) + +# learning policy +param_scheduler = [ + # warm up learning rate scheduler + dict( + type='LinearLR', + start_factor=1e-3, + by_epoch=True, + end=5, + # update by iter + convert_to_iter_based=True), + # main learning rate scheduler + dict(type='CosineAnnealingLR', eta_min=1e-5, by_epoch=False, begin=5) +] + +# runtime setting +custom_hooks = [ + dict(type='PreciseBNHook', num_samples=8192, priority='ABOVE_NORMAL'), + dict(type='EMAHook', momentum=1e-4, priority='ABOVE_NORMAL') +] + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +# base_batch_size = (8 GPUs) x (128 samples per GPU) +auto_scale_lr = dict(base_batch_size=1024) diff --git a/configs/moganet/moganet-large_8xb128_in1k.py b/configs/moganet/moganet-large_8xb128_in1k.py new file mode 100644 index 00000000000..95ef60be733 --- /dev/null +++ b/configs/moganet/moganet-large_8xb128_in1k.py @@ -0,0 +1,101 @@ +_base_ = [ + '../_base_/models/moganet/moganet_large.py', + '../_base_/datasets/imagenet_bs128_poolformer_small_224.py', + '../_base_/schedules/imagenet_bs1024_adamw_swin.py', + '../_base_/default_runtime.py', +] + +# model setting +model = dict(backbone=dict(attn_force_fp32=False)) + +# dataset setting +data_preprocessor = dict( + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True, +) + +bgr_mean = data_preprocessor['mean'][::-1] +bgr_std = data_preprocessor['std'][::-1] + +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='RandomResizedCrop', + scale=224, + backend='pillow', + interpolation='bicubic'), + dict(type='RandomFlip', prob=0.5, direction='horizontal'), + dict( + type='RandAugment', + policies='timm_increasing', + num_policies=2, + total_level=10, + magnitude_level=9, + magnitude_std=0.5, + hparams=dict( + pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')), + dict(type='ColorJitter', brightness=0.4, contrast=0.4, saturation=0.4), + dict( + type='RandomErasing', + erase_prob=0.25, + mode='rand', + min_area_ratio=0.02, + max_area_ratio=1 / 3, + fill_color=bgr_mean, + fill_std=bgr_std), + dict(type='PackInputs'), +] + +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='ResizeEdge', + scale=248, + edge='short', + backend='pillow', + interpolation='bicubic'), + dict(type='CenterCrop', crop_size=224), + dict(type='PackInputs'), +] + +train_dataloader = dict(dataset=dict(pipeline=train_pipeline), batch_size=128) +val_dataloader = dict(dataset=dict(pipeline=test_pipeline)) +test_dataloader = dict(dataset=dict(pipeline=test_pipeline)) + +# schedule settings +optim_wrapper = dict( + paramwise_cfg=dict( + norm_decay_mult=0.0, + bias_decay_mult=0.0, + custom_keys={ + '.layer_scale': dict(decay_mult=0.0), + '.scale': dict(decay_mult=0.0), + }), +) + +# learning policy +param_scheduler = [ + # warm up learning rate scheduler + dict( + type='LinearLR', + start_factor=1e-3, + by_epoch=True, + end=5, + # update by iter + convert_to_iter_based=True), + # main learning rate scheduler + dict(type='CosineAnnealingLR', eta_min=1e-5, by_epoch=False, begin=5) +] + +# runtime setting +custom_hooks = [ + dict(type='PreciseBNHook', num_samples=8192, priority='ABOVE_NORMAL'), + dict(type='EMAHook', momentum=1e-4, priority='ABOVE_NORMAL') +] + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +# base_batch_size = (8 GPUs) x (128 samples per GPU) +auto_scale_lr = dict(base_batch_size=1024) diff --git a/configs/moganet/moganet-small_8xb128_in1k.py b/configs/moganet/moganet-small_8xb128_in1k.py new file mode 100644 index 00000000000..1743933b357 --- /dev/null +++ b/configs/moganet/moganet-small_8xb128_in1k.py @@ -0,0 +1,39 @@ +_base_ = [ + '../_base_/models/moganet/moganet_small.py', + '../_base_/datasets/imagenet_bs128_poolformer_small_224.py', + '../_base_/schedules/imagenet_bs1024_adamw_swin.py', + '../_base_/default_runtime.py', +] + +# schedule settings +optim_wrapper = dict( + paramwise_cfg=dict( + norm_decay_mult=0.0, + bias_decay_mult=0.0, + custom_keys={ + '.layer_scale': dict(decay_mult=0.0), + '.scale': dict(decay_mult=0.0), + }), +) + +# learning policy +param_scheduler = [ + # warm up learning rate scheduler + dict( + type='LinearLR', + start_factor=1e-3, + by_epoch=True, + end=5, + # update by iter + convert_to_iter_based=True), + # main learning rate scheduler + dict(type='CosineAnnealingLR', eta_min=1e-5, by_epoch=False, begin=5) +] + +# runtime setting +custom_hooks = [dict(type='PreciseBNHook', num_samples=8192, priority='ABOVE_NORMAL')] + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +# base_batch_size = (8 GPUs) x (128 samples per GPU) +auto_scale_lr = dict(base_batch_size=1024) diff --git a/configs/moganet/moganet-tiny_8xb128_in1k.py b/configs/moganet/moganet-tiny_8xb128_in1k.py new file mode 100644 index 00000000000..b030ffdbafc --- /dev/null +++ b/configs/moganet/moganet-tiny_8xb128_in1k.py @@ -0,0 +1,95 @@ +_base_ = [ + '../_base_/models/moganet/moganet_tiny.py', + '../_base_/datasets/imagenet_bs128_poolformer_small_224.py', + '../_base_/schedules/imagenet_bs1024_adamw_swin.py', + '../_base_/default_runtime.py', +] + +# dataset setting +data_preprocessor = dict( + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True, +) + +bgr_mean = data_preprocessor['mean'][::-1] +bgr_std = data_preprocessor['std'][::-1] + +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='RandomResizedCrop', + scale=224, + backend='pillow', + interpolation='bicubic'), + dict(type='RandomFlip', prob=0.5, direction='horizontal'), + dict( + type='RandAugment', + policies='timm_increasing', + num_policies=2, + total_level=10, + magnitude_level=9, + magnitude_std=0.5, + hparams=dict( + pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')), + dict( + type='RandomErasing', + erase_prob=0.25, + mode='rand', + min_area_ratio=0.02, + max_area_ratio=1 / 3, + fill_color=bgr_mean, + fill_std=bgr_std), + dict(type='PackInputs'), +] + +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='ResizeEdge', + scale=248, + edge='short', + backend='pillow', + interpolation='bicubic'), + dict(type='CenterCrop', crop_size=224), + dict(type='PackInputs'), +] + +train_dataloader = dict(dataset=dict(pipeline=train_pipeline), batch_size=128) +val_dataloader = dict(dataset=dict(pipeline=test_pipeline)) +test_dataloader = dict(dataset=dict(pipeline=test_pipeline)) + +# schedule settings +optim_wrapper = dict( + optimizer=dict(weight_decay=0.04), + paramwise_cfg=dict( + norm_decay_mult=0.0, + bias_decay_mult=0.0, + custom_keys={ + '.layer_scale': dict(decay_mult=0.0), + '.scale': dict(decay_mult=0.0), + }), +) + +# learning policy +param_scheduler = [ + # warm up learning rate scheduler + dict( + type='LinearLR', + start_factor=1e-3, + by_epoch=True, + end=5, + # update by iter + convert_to_iter_based=True), + # main learning rate scheduler + dict(type='CosineAnnealingLR', eta_min=1e-6, by_epoch=False, begin=5) +] + +# runtime setting +custom_hooks = [dict(type='PreciseBNHook', num_samples=8192, priority='ABOVE_NORMAL')] + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +# base_batch_size = (8 GPUs) x (128 samples per GPU) +auto_scale_lr = dict(base_batch_size=1024) diff --git a/configs/moganet/moganet-xlarge_16xb32_in1k.py b/configs/moganet/moganet-xlarge_16xb32_in1k.py new file mode 100644 index 00000000000..a170aebd900 --- /dev/null +++ b/configs/moganet/moganet-xlarge_16xb32_in1k.py @@ -0,0 +1,98 @@ +_base_ = [ + '../_base_/models/moganet/moganet_xlarge.py', + '../_base_/datasets/imagenet_bs128_poolformer_small_224.py', + '../_base_/schedules/imagenet_bs1024_adamw_swin.py', + '../_base_/default_runtime.py', +] + +# dataset setting +data_preprocessor = dict( + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True, +) + +bgr_mean = data_preprocessor['mean'][::-1] +bgr_std = data_preprocessor['std'][::-1] + +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='RandomResizedCrop', + scale=224, + backend='pillow', + interpolation='bicubic'), + dict(type='RandomFlip', prob=0.5, direction='horizontal'), + dict( + type='RandAugment', + policies='timm_increasing', + num_policies=2, + total_level=10, + magnitude_level=9, + magnitude_std=0.5, + hparams=dict( + pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')), + dict(type='ColorJitter', brightness=0.4, contrast=0.4, saturation=0.4), + dict( + type='RandomErasing', + erase_prob=0.25, + mode='rand', + min_area_ratio=0.02, + max_area_ratio=1 / 3, + fill_color=bgr_mean, + fill_std=bgr_std), + dict(type='PackInputs'), +] + +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='ResizeEdge', + scale=248, + edge='short', + backend='pillow', + interpolation='bicubic'), + dict(type='CenterCrop', crop_size=224), + dict(type='PackInputs'), +] + +train_dataloader = dict(dataset=dict(pipeline=train_pipeline), batch_size=32) +val_dataloader = dict(dataset=dict(pipeline=test_pipeline)) +test_dataloader = dict(dataset=dict(pipeline=test_pipeline)) + +# schedule settings +optim_wrapper = dict( + paramwise_cfg=dict( + norm_decay_mult=0.0, + bias_decay_mult=0.0, + custom_keys={ + '.layer_scale': dict(decay_mult=0.0), + '.scale': dict(decay_mult=0.0), + }), +) + +# learning policy +param_scheduler = [ + # warm up learning rate scheduler + dict( + type='LinearLR', + start_factor=1e-3, + by_epoch=True, + end=5, + # update by iter + convert_to_iter_based=True), + # main learning rate scheduler + dict(type='CosineAnnealingLR', eta_min=1e-5, by_epoch=False, begin=5) +] + +# runtime setting +custom_hooks = [ + dict(type='PreciseBNHook', num_samples=8192, priority='ABOVE_NORMAL'), + dict(type='EMAHook', momentum=1e-4, priority='ABOVE_NORMAL') +] + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +# base_batch_size = (16 GPUs) x (32 samples per GPU) +auto_scale_lr = dict(base_batch_size=512) diff --git a/configs/moganet/moganet-xtiny_8xb128_in1k.py b/configs/moganet/moganet-xtiny_8xb128_in1k.py new file mode 100644 index 00000000000..d2a2937a8a8 --- /dev/null +++ b/configs/moganet/moganet-xtiny_8xb128_in1k.py @@ -0,0 +1,95 @@ +_base_ = [ + '../_base_/models/moganet/moganet_xtiny.py', + '../_base_/datasets/imagenet_bs128_poolformer_small_224.py', + '../_base_/schedules/imagenet_bs1024_adamw_swin.py', + '../_base_/default_runtime.py', +] + +# dataset setting +data_preprocessor = dict( + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True, +) + +bgr_mean = data_preprocessor['mean'][::-1] +bgr_std = data_preprocessor['std'][::-1] + +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='RandomResizedCrop', + scale=224, + backend='pillow', + interpolation='bicubic'), + dict(type='RandomFlip', prob=0.5, direction='horizontal'), + dict( + type='RandAugment', + policies='timm_increasing', + num_policies=2, + total_level=10, + magnitude_level=7, + magnitude_std=0.5, + hparams=dict( + pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')), + dict( + type='RandomErasing', + erase_prob=0.25, + mode='rand', + min_area_ratio=0.02, + max_area_ratio=1 / 3, + fill_color=bgr_mean, + fill_std=bgr_std), + dict(type='PackInputs'), +] + +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='ResizeEdge', + scale=248, + edge='short', + backend='pillow', + interpolation='bicubic'), + dict(type='CenterCrop', crop_size=224), + dict(type='PackInputs'), +] + +train_dataloader = dict(dataset=dict(pipeline=train_pipeline), batch_size=128) +val_dataloader = dict(dataset=dict(pipeline=test_pipeline)) +test_dataloader = dict(dataset=dict(pipeline=test_pipeline)) + +# schedule settings +optim_wrapper = dict( + optimizer=dict(weight_decay=0.03), + paramwise_cfg=dict( + norm_decay_mult=0.0, + bias_decay_mult=0.0, + custom_keys={ + '.layer_scale': dict(decay_mult=0.0), + '.scale': dict(decay_mult=0.0), + }), +) + +# learning policy +param_scheduler = [ + # warm up learning rate scheduler + dict( + type='LinearLR', + start_factor=1e-3, + by_epoch=True, + end=5, + # update by iter + convert_to_iter_based=True), + # main learning rate scheduler + dict(type='CosineAnnealingLR', eta_min=1e-6, by_epoch=False, begin=5) +] + +# runtime setting +custom_hooks = [dict(type='PreciseBNHook', num_samples=8192, priority='ABOVE_NORMAL')] + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +# base_batch_size = (8 GPUs) x (128 samples per GPU) +auto_scale_lr = dict(base_batch_size=1024) diff --git a/mmpretrain/__init__.py b/mmpretrain/__init__.py index 1d065db1cd8..3f5424699a1 100644 --- a/mmpretrain/__init__.py +++ b/mmpretrain/__init__.py @@ -6,11 +6,11 @@ from .apis import * # noqa: F401, F403 from .version import __version__ -mmcv_minimum_version = '2.0.0' +mmcv_minimum_version = '1.7.0' mmcv_maximum_version = '2.1.0' mmcv_version = digit_version(mmcv.__version__) -mmengine_minimum_version = '0.8.0' +mmengine_minimum_version = '0.7.0' mmengine_maximum_version = '1.0.0' mmengine_version = digit_version(mmengine.__version__) diff --git a/mmpretrain/models/backbones/__init__.py b/mmpretrain/models/backbones/__init__.py index 60e37fb7b6e..c121559ce5d 100644 --- a/mmpretrain/models/backbones/__init__.py +++ b/mmpretrain/models/backbones/__init__.py @@ -25,6 +25,7 @@ from .mobilenet_v3 import MobileNetV3 from .mobileone import MobileOne from .mobilevit import MobileViT +from .moganet import MogaNet from .mvit import MViT from .poolformer import PoolFormer from .regnet import RegNet @@ -112,6 +113,7 @@ 'DeiT3', 'HorNet', 'MobileViT', + 'MogaNet', 'DaViT', 'BEiTViT', 'RevVisionTransformer', diff --git a/mmpretrain/models/backbones/moganet.py b/mmpretrain/models/backbones/moganet.py new file mode 100644 index 00000000000..44121efe790 --- /dev/null +++ b/mmpretrain/models/backbones/moganet.py @@ -0,0 +1,640 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn.bricks import Conv2d, DropPath, build_activation_layer, build_norm_layer +from mmcv.cnn.bricks.transformer import PatchEmbed +from mmengine.model import BaseModule +from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm + +from mmpretrain.registry import MODELS +from .base_backbone import BaseBackbone + + +class ElementScale(nn.Module): + """A learnable element-wise scaler.""" + + def __init__(self, embed_dims, init_value=0., requires_grad=True): + super(ElementScale, self).__init__() + self.scale = nn.Parameter( + init_value * torch.ones((1, embed_dims, 1, 1)), + requires_grad=requires_grad + ) + + def forward(self, x): + return x * self.scale + + +class ChannelAggregationFFN(BaseModule): + """An implementation of FFN with Channel Aggregation. + + Args: + embed_dims (int): The feature dimension. Same as + `MultiheadAttention`. + feedforward_channels (int): The hidden dimension of FFNs. + kernel_size (int): The depth-wise conv kernel size as the + depth-wise convolution. Defaults to 3. + ffn_drop (float, optional): Probability of an element to be + zeroed in FFN. Default 0.0. + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. + """ + + def __init__(self, + embed_dims, + feedforward_channels, + kernel_size=3, + act_cfg=dict(type='GELU'), + ffn_drop=0., + init_cfg=None): + super(ChannelAggregationFFN, self).__init__(init_cfg=init_cfg) + + self.embed_dims = embed_dims + self.feedforward_channels = feedforward_channels + self.act_cfg = act_cfg + + self.fc1 = Conv2d( + in_channels=embed_dims, + out_channels=self.feedforward_channels, + kernel_size=1) + self.dwconv = Conv2d( + in_channels=self.feedforward_channels, + out_channels=self.feedforward_channels, + kernel_size=kernel_size, + stride=1, + padding=kernel_size // 2, + bias=True, + groups=self.feedforward_channels) + self.act = build_activation_layer(act_cfg) + self.fc2 = Conv2d( + in_channels=feedforward_channels, + out_channels=embed_dims, + kernel_size=1) + self.drop = nn.Dropout(ffn_drop) + + self.decompose = Conv2d( + in_channels=self.feedforward_channels, # C -> 1 + out_channels=1, kernel_size=1, + ) + self.sigma = ElementScale( + self.feedforward_channels, init_value=1e-5, requires_grad=True) + self.decompose_act = build_activation_layer(act_cfg) + + def feat_decompose(self, x): + # x_d: [B, C, H, W] -> [B, 1, H, W] + x = x + self.sigma(x - self.decompose_act(self.decompose(x))) + return x + + def forward(self, x): + # proj 1 + x = self.fc1(x) + x = self.dwconv(x) + x = self.act(x) + x = self.drop(x) + # proj 2 + x = self.feat_decompose(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class MultiOrderDWConv(BaseModule): + """Multi-order Features with Dilated DWConv Kernel. + + Args: + embed_dims (int): Number of input channels. + dw_dilation (list): Dilations of three DWConv layers. + channel_split (list): The raletive ratio of three splited channels. + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. + """ + + def __init__(self, + embed_dims, + dw_dilation=[1, 2, 3,], + channel_split=[1, 3, 4,], + init_cfg=None): + super(MultiOrderDWConv, self).__init__(init_cfg=init_cfg) + + self.split_ratio = [i / sum(channel_split) for i in channel_split] + self.embed_dims_1 = int(self.split_ratio[1] * embed_dims) + self.embed_dims_2 = int(self.split_ratio[2] * embed_dims) + self.embed_dims_0 = embed_dims - self.embed_dims_1 - self.embed_dims_2 + self.embed_dims = embed_dims + assert len(dw_dilation) == len(channel_split) == 3 + assert 1 <= min(dw_dilation) and max(dw_dilation) <= 3 + assert embed_dims % sum(channel_split) == 0 + + # basic DW conv + self.DW_conv0 = Conv2d( + in_channels=self.embed_dims, + out_channels=self.embed_dims, + kernel_size=5, + padding=(1 + 4 * dw_dilation[0]) // 2, + groups=self.embed_dims, + stride=1, dilation=dw_dilation[0], + ) + # DW conv 1 + self.DW_conv1 = Conv2d( + in_channels=self.embed_dims_1, + out_channels=self.embed_dims_1, + kernel_size=5, + padding=(1 + 4 * dw_dilation[1]) // 2, + groups=self.embed_dims_1, + stride=1, dilation=dw_dilation[1], + ) + # DW conv 2 + self.DW_conv2 = Conv2d( + in_channels=self.embed_dims_2, + out_channels=self.embed_dims_2, + kernel_size=7, + padding=(1 + 6 * dw_dilation[2]) // 2, + groups=self.embed_dims_2, + stride=1, dilation=dw_dilation[2], + ) + # a channel convolution + self.PW_conv = Conv2d( # point-wise convolution + in_channels=embed_dims, + out_channels=embed_dims, + kernel_size=1) + + def forward(self, x): + x_0 = self.DW_conv0(x) + x_1 = self.DW_conv1( + x_0[:, self.embed_dims_0: self.embed_dims_0+self.embed_dims_1, ...]) + x_2 = self.DW_conv2( + x_0[:, self.embed_dims-self.embed_dims_2:, ...]) + x = torch.cat([ + x_0[:, :self.embed_dims_0, ...], x_1, x_2], dim=1) + x = self.PW_conv(x) + return x + + +class MultiOrderGatedAggregation(BaseModule): + """Spatial Block with Multi-order Gated Aggregation. + + Args: + embed_dims (int): Number of input channels. + attn_dw_dilation (list): Dilations of three DWConv layers. + attn_channel_split (list): The raletive ratio of splited channels. + attn_act_cfg (dict, optional): The activation config for Spatial Block. + Default: dict(type='SiLU'). + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. + """ + + def __init__(self, + embed_dims, + attn_dw_dilation=[1, 2, 3], + attn_channel_split=[1, 3, 4], + attn_act_cfg=dict(type='SiLU'), + attn_force_fp32=False, + init_cfg=None): + super(MultiOrderGatedAggregation, self).__init__(init_cfg=init_cfg) + + self.embed_dims = embed_dims + self.attn_force_fp32 = attn_force_fp32 + self.proj_1 = Conv2d( + in_channels=embed_dims, out_channels=embed_dims, kernel_size=1) + self.gate = Conv2d( + in_channels=embed_dims, out_channels=embed_dims, kernel_size=1) + self.value = MultiOrderDWConv( + embed_dims=embed_dims, + dw_dilation=attn_dw_dilation, + channel_split=attn_channel_split, + ) + self.proj_2 = Conv2d( + in_channels=embed_dims, out_channels=embed_dims, kernel_size=1) + + # activation for gating and value + self.act_value = build_activation_layer(attn_act_cfg) + self.act_gate = build_activation_layer(attn_act_cfg) + + # decompose + self.sigma = ElementScale( + embed_dims, init_value=1e-5, requires_grad=True) + + def feat_decompose(self, x): + x = self.proj_1(x) + # x_d: [B, C, H, W] -> [B, C, 1, 1] + x_d = F.adaptive_avg_pool2d(x, output_size=1) + x = x + self.sigma(x - x_d) + x = self.act_value(x) + return x + + def forward_gating(self, g, v): + """ Force to computing gating with fp32 + + Warning: If you use `attn_force_fp32=True` during training, you + should also keep it during evaluation, because the output results + of whether to use `attn_force_fp32` are slightly different. + """ + g = g.to(torch.float32) + v = v.to(torch.float32) + return self.proj_2(self.act_gate(g) * self.act_gate(v)) + + def forward(self, x): + shortcut = x.clone() + # proj 1x1 + x = self.feat_decompose(x) + # gating and value branch + g = self.gate(x) + v = self.value(x) + # aggregation + if not self.attn_force_fp32: + x = self.proj_2(self.act_gate(g) * self.act_gate(v)) + else: + x = self.forward_gating(self.act_gate(g), self.act_gate(v)) + x = x + shortcut + return x + + +class MogaBlock(BaseModule): + """A block of MogaNet. + + Args: + embed_dims (int): Number of input channels. + ffn_ratio (float): The expansion ratio of feedforward network hidden + layer channels. Defaults to 4. + drop_rate (float): Dropout rate after embedding. Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0.1. + act_cfg (dict, optional): The activation config for projections and FFNs. + Default: dict(type='GELU'). + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + init_value (float): Init value for Layer Scale. Defaults to 1e-5. + attn_dw_dilation (list): Dilations of three DWConv layers. + attn_channel_split (list): The raletive ratio of splited channels. + attn_act_cfg (dict): The activation config for the gating branch. + Default: dict(type='SiLU'). + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. + """ + + def __init__(self, + embed_dims, + ffn_ratio=4., + drop_rate=0., + drop_path_rate=0., + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='BN', eps=1e-5), + init_value=1e-5, + attn_dw_dilation=[1, 2, 3], + attn_channel_split=[1, 3, 4], + attn_act_cfg=dict(type='SiLU'), + attn_force_fp32=False, + init_cfg=None): + super(MogaBlock, self).__init__(init_cfg=init_cfg) + self.out_channels = embed_dims + + self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1] + + # spatial attention + self.attn = MultiOrderGatedAggregation( + embed_dims, + attn_dw_dilation=attn_dw_dilation, + attn_channel_split=attn_channel_split, + attn_act_cfg=attn_act_cfg, + attn_force_fp32=attn_force_fp32, + ) + self.drop_path = DropPath( + drop_path_rate) if drop_path_rate > 0. else nn.Identity() + + self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1] + + # channel MLP + mlp_hidden_dim = int(embed_dims * ffn_ratio) + self.mlp = ChannelAggregationFFN( # DWConv + Channel Aggregation FFN + embed_dims=embed_dims, + feedforward_channels=mlp_hidden_dim, + act_cfg=act_cfg, + ffn_drop=drop_rate, + ) + + # init layer scale + self.layer_scale_1 = nn.Parameter( + init_value * torch.ones((1, embed_dims, 1, 1)), requires_grad=True) + self.layer_scale_2 = nn.Parameter( + init_value * torch.ones((1, embed_dims, 1, 1)), requires_grad=True) + + def forward(self, x): + # spatial + identity = x + x = self.layer_scale_1 * self.attn(self.norm1(x)) + x = identity + self.drop_path(x) + # channel + identity = x + x = self.layer_scale_2 * self.mlp(self.norm2(x)) + x = identity + self.drop_path(x) + return x + + +class ConvPatchEmbed(PatchEmbed): + """An implementation of Conv patch embedding layer. + + The differences between ConvPatchEmbed & ViT PatchEmbed: + 1. Use BN. + 2. Do not use 'flatten' and 'transpose'. + """ + + def __init__(self, *args, norm_cfg=dict(type='BN'), **kwargs): + super(ConvPatchEmbed, self).__init__(*args, norm_cfg=norm_cfg, **kwargs) + + def forward(self, x): + """ + Args: + x (Tensor): Has shape (B, C, H, W). In most case, C is 3. + Returns: + tuple: Contains merged results and its spatial shape. + - x (Tensor): Has shape (B, out_h * out_w, embed_dims) + - out_size (tuple[int]): Spatial shape of x, arrange as + (out_h, out_w). + """ + + if self.adaptive_padding: + x = self.adaptive_padding(x) + x = self.projection(x) + out_size = (x.shape[2], x.shape[3]) + if self.norm is not None: + x = self.norm(x) + return x, out_size + + +class StackConvPatchEmbed(BaseModule): + """An implementation of Stack Conv patch embedding layer. + + Args: + in_features (int): The feature dimension. + embed_dims (int): The output dimension of PatchEmbed. + kernel_size (int): The conv kernel size of stack patch embedding. + Defaults to 3. + stride (int): The conv stride of stack patch embedding. + Defaults to 2. + act_cfg (dict, optional): The activation config in PatchEmbed. + Default: dict(type='GELU'). + norm_cfg (dict): Config dict for normalization layer in PatchEmbed. + Defaults: dict(type='BN'). + """ + + def __init__(self, + in_channels, + embed_dims, + kernel_size=3, + stride=2, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='BN'), + init_cfg=None, + ): + super(StackConvPatchEmbed, self).__init__(init_cfg) + + self.projection = nn.Sequential( + Conv2d(in_channels, embed_dims // 2, kernel_size=kernel_size, + stride=stride, padding=kernel_size // 2), + build_norm_layer(norm_cfg, embed_dims // 2)[1], + build_activation_layer(act_cfg), + Conv2d(embed_dims // 2, embed_dims, kernel_size=kernel_size, + stride=stride, padding=kernel_size // 2), + build_norm_layer(norm_cfg, embed_dims)[1], + ) + + def forward(self, x): + x = self.projection(x) + out_size = (x.shape[2], x.shape[3]) + return x, out_size + + +@MODELS.register_module() +class MogaNet(BaseBackbone): + """MogaNet. + + A PyTorch implementation of MogaNet introduced by: + `Efficient Multi-order Gated Aggregation Network `_ + + Modified from the `official repo + `. + + Args: + arch (str | dict): MogaNet architecture. + If use string, choose from 'xtiny', 'tiny', 'small', 'base' and 'large'. + If use dict, it should have below keys: + + - **embed_dims** (List[int]): The dimensions of embedding. + - **depths** (List[int]): The number of blocks in each stage. + - **ffn_ratios** (List[int]): The number of expansion ratio of + feedforward network hidden layer channels. + + Defaults to 'tiny'. + patch_sizes (List[int | tuple]): The patch size in patch embeddings. + Defaults to [3, 3, 3, 3]. + in_channels (int): The num of input channels. Defaults to 3. + drop_rate (float): Dropout rate after embedding. Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0.1. + init_value (float): Init value for Layer Scale. Defaults to 1e-5. + out_indices (Sequence | int): Output from which network position. + Index 0-6 respectively corresponds to + [stage1, downsampling, stage2, downsampling, stage3, downsampling, stage4] + Defaults to -1, means the last stage. + frozen_stages (int): Stages to be frozen (all param fixed). + Defaults to 0, which means not freezing any parameters. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Defaults to False. + stem_norm_cfg (dict): Config dict for normalization layer for stems. + Defaults to ``dict(type='LN')``. + conv_norm_cfg (dict): Config dict for convolution normalization layer. + Defaults to ``dict(type='BN')``. + patchembed_types (list): The type of PatchEmbedding in each stage. + Defaults to ``['ConvEmbed', 'Conv', 'Conv', 'Conv',]``. + attn_dw_dilation (list): The dilate rate of depth-wise convolutions in + Moga Blocks. Defaults to ``[1, 2, 3]``. + attn_channel_split (list): The channel split rate of three depth-wise + convolutions in Moga Blocks. Defaults to ``[1, 3, 4]``, i.e., + divided into ``[1/8, 3/8, 4/8]``. + attn_act_cfg (dict): Config dict for activation of gating in Moga + Blocks. Defaults to ``dict(type='SiLU')``. + attn_final_dilation (bool): Whether to adopt dilated depth-wise + convolutions in the final stage. Defaults to True. + attn_force_fp32 (bool): Whether to force the gating running with fp32. + Warning: If you use `attn_force_fp32=True` during training, you + should also keep it during evaluation, because the output results + of whether to use `attn_force_fp32` are different. Defaults to True. + block_cfgs (Sequence[dict] | dict): The extra config of each block. + Defaults to empty dicts. + init_cfg (dict, optional): Initialization config dict + """ + + arch_settings = { + **dict.fromkeys(['xt', 'x-tiny'], + {'embed_dims': [32, 64, 96, 192], + 'depths': [3, 3, 10, 2], + 'ffn_ratios': [8, 8, 4, 4]}), + **dict.fromkeys(['t', 'tiny'], + {'embed_dims': [32, 64, 128, 256], + 'depths': [3, 3, 12, 2], + 'ffn_ratios': [8, 8, 4, 4]}), + **dict.fromkeys(['s', 'small'], + {'embed_dims': [64, 128, 320, 512], + 'depths': [2, 3, 12, 2], + 'ffn_ratios': [8, 8, 4, 4]}), + **dict.fromkeys(['b', 'base'], + {'embed_dims': [64, 160, 320, 512], + 'depths': [4, 6, 22, 3], + 'ffn_ratios': [8, 8, 4, 4]}), + **dict.fromkeys(['l', 'large'], + {'embed_dims': [64, 160, 320, 640], + 'depths': [4, 6, 44, 4], + 'ffn_ratios': [8, 8, 4, 4]}), + **dict.fromkeys(['xl', 'x-large'], + {'embed_dims': [96, 192, 480, 960], + 'depths': [6, 6, 44, 4], + 'ffn_ratios': [8, 8, 4, 4]}), + } # yapf: disable + + def __init__(self, + arch='tiny', + patch_sizes=[3, 3, 3, 3], + in_channels=3, + drop_rate=0., + drop_path_rate=0., + init_value=1e-5, + out_indices=(3, ), + frozen_stages=-1, + norm_eval=False, + stem_norm_cfg=dict(type='BN', eps=1e-5), + conv_norm_cfg=dict(type='BN', eps=1e-5), + patchembed_types=['ConvEmbed', 'Conv', 'Conv', 'Conv',], + attn_dw_dilation=[1, 2, 3], + attn_channel_split=[1, 3, 4], + attn_act_cfg=dict(type='SiLU'), + attn_final_dilation=True, + attn_force_fp32=False, + block_cfgs=dict(), + init_cfg=None): + + super().__init__(init_cfg=init_cfg) + + if isinstance(arch, str): + arch = arch.lower() + assert arch in self.arch_settings, \ + f'Unavailable arch, please choose from ' \ + f'({set(self.arch_settings)}) or pass a dict.' + self.arch_settings = self.arch_settings[arch] + else: + essential_keys = {'embed_dims', 'depths', 'ffn_ratios'} + assert isinstance(arch, dict) and set(arch) == essential_keys, \ + f'Custom arch needs a dict with keys {essential_keys}' + self.arch_settings = arch + + self.embed_dims = self.arch_settings['embed_dims'] + self.depths = self.arch_settings['depths'] + self.ffn_ratios = self.arch_settings['ffn_ratios'] + self.num_stages = len(self.depths) + self.out_indices = out_indices + self.norm_eval = norm_eval + self.attn_force_fp32 = attn_force_fp32 + self.use_layer_norm = stem_norm_cfg['type'] == 'LN' + assert stem_norm_cfg['type'] in ['BN', 'SyncBN', 'LN', 'LN2d',] + assert len(patchembed_types) == self.num_stages + + total_depth = sum(self.depths) + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, total_depth) + ] # stochastic depth decay rule + + cur_block_idx = 0 + for i, depth in enumerate(self.depths): + if i == 0 and patchembed_types[i] == "ConvEmbed": + assert patch_sizes[i] <= 3 + patch_embed = StackConvPatchEmbed( + in_channels=in_channels, + embed_dims=self.embed_dims[i], + kernel_size=patch_sizes[i], + stride=patch_sizes[i] // 2 + 1, + norm_cfg=conv_norm_cfg, + ) + else: + patch_embed = ConvPatchEmbed( + in_channels=in_channels if i == 0 else self.embed_dims[i - 1], + input_size=None, + embed_dims=self.embed_dims[i], + kernel_size=patch_sizes[i], + stride=patch_sizes[i] // 2 + 1, + padding=(patch_sizes[i] // 2, patch_sizes[i] // 2), + norm_cfg=conv_norm_cfg) + + if i == self.num_stages - 1 and not attn_final_dilation: + attn_dw_dilation = [1, 2, 1] + blocks = nn.ModuleList([ + MogaBlock( + embed_dims=self.embed_dims[i], + ffn_ratio=self.ffn_ratios[i], + drop_rate=drop_rate, + drop_path_rate=dpr[cur_block_idx + j], + norm_cfg=conv_norm_cfg, + init_value=init_value, + attn_dw_dilation=attn_dw_dilation, + attn_channel_split=attn_channel_split, + attn_act_cfg=attn_act_cfg, + attn_force_fp32=attn_force_fp32, + **block_cfgs) for j in range(depth) + ]) + cur_block_idx += depth + norm = build_norm_layer(stem_norm_cfg, self.embed_dims[i])[1] + + self.add_module(f'patch_embed{i + 1}', patch_embed) + self.add_module(f'blocks{i + 1}', blocks) + self.add_module(f'norm{i + 1}', norm) + + self.frozen_stages = frozen_stages + self._freeze_stages() + + def forward(self, x): + outs = [] + for i in range(self.num_stages): + patch_embed = getattr(self, f'patch_embed{i + 1}') + blocks = getattr(self, f'blocks{i + 1}') + norm = getattr(self, f'norm{i + 1}') + + x, hw_shape = patch_embed(x) + for block in blocks: + x = block(x) + if self.use_layer_norm: + x = x.flatten(2).transpose(1, 2) + x = norm(x) + x = x.reshape(-1, *hw_shape, + blocks.out_channels).permute(0, 3, 1, 2).contiguous() + else: + x = norm(x) + + if i in self.out_indices: + outs.append(x) + + return tuple(outs) + + def _freeze_stages(self): + for i in range(0, self.frozen_stages + 1): + # freeze patch embed + m = getattr(self, f'patch_embed{i + 1}') + m.eval() + for param in m.parameters(): + param.requires_grad = False + # freeze blocks + m = getattr(self, f'blocks{i + 1}') + m.eval() + for param in m.parameters(): + param.requires_grad = False + # freeze norm + m = getattr(self, f'norm{i + 1}') + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(MogaNet, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, (_BatchNorm, nn.SyncBatchNorm)): + m.eval() diff --git a/model-index.yml b/model-index.yml index 3fb3d0457d6..1c903bf9739 100644 --- a/model-index.yml +++ b/model-index.yml @@ -33,6 +33,7 @@ Import: - configs/mvit/metafile.yml - configs/edgenext/metafile.yml - configs/mobileone/metafile.yml + - configs/moganet/metafile.yml - configs/efficientformer/metafile.yml - configs/swin_transformer_v2/metafile.yml - configs/deit3/metafile.yml From ce54aa85ea7fb65bb8d2254bcac4b8f52f4bc550 Mon Sep 17 00:00:00 2001 From: Lupin1998 <1070535169@qq.com> Date: Thu, 6 Jul 2023 10:21:23 +0000 Subject: [PATCH 2/2] fix bug --- configs/moganet/metafile.yml | 29 ++++++----------------------- mmpretrain/__init__.py | 4 ++-- 2 files changed, 8 insertions(+), 25 deletions(-) diff --git a/configs/moganet/metafile.yml b/configs/moganet/metafile.yml index 0d6925f03a7..fd8ac971ccc 100644 --- a/configs/moganet/metafile.yml +++ b/configs/moganet/metafile.yml @@ -26,7 +26,7 @@ Models: Top 1 Accuracy: 76.48 Top 5 Accuracy: 93.49 Task: Image Classification - Weights: https://github.com/Westlake-AI/openmixup/releases/download/moganet-in1k-weights/moga_xtiny_sz224_8xb128_fp16_ep300.pth + Weights: https://github.com/Lupin1998/mmpretrain/releases/download/moganet-in1k-weights/moganet-xtiny_3rdparty_8xb128_in1k.pth Config: configs/moganet/moganet-xtiny_8xb128_in1k.py Converted From: Weights: https://github.com/Westlake-AI/openmixup/releases/download/moganet-in1k-weights/moga_xtiny_sz224_8xb128_fp16_ep300.pth @@ -42,7 +42,7 @@ Models: Top 1 Accuracy: 77.24 Top 5 Accuracy: 93.51 Task: Image Classification - Weights: https://github.com/Westlake-AI/openmixup/releases/download/moganet-in1k-weights/moga_tiny_sz224_8xb128_fp16_ep300.pth + Weights: https://github.com/Lupin1998/mmpretrain/releases/download/moganet-in1k-weights/moganet-tiny_3rdparty_8xb128_in1k.pth Config: configs/moganet/moganet-tiny_8xb128_in1k.py Converted From: Weights: https://github.com/Westlake-AI/openmixup/releases/download/moganet-in1k-weights/moga_tiny_sz224_8xb128_fp16_ep300.pth @@ -58,7 +58,7 @@ Models: Top 1 Accuracy: 83.38 Top 5 Accuracy: 96.58 Task: Image Classification - Weights: https://github.com/Westlake-AI/openmixup/releases/download/moganet-in1k-weights/moga_small_sz224_8xb128_fp16_ep300.pth + Weights: https://github.com/Lupin1998/mmpretrain/releases/download/moganet-in1k-weights/moganet-small_3rdparty_8xb128_in1k.pth Config: configs/moganet/moganet-small_8xb128_in1k.py Converted From: Weights: https://github.com/Westlake-AI/openmixup/releases/download/moganet-in1k-weights/moga_small_sz224_8xb128_fp16_ep300.pth @@ -74,7 +74,7 @@ Models: Top 1 Accuracy: 84.20 Top 5 Accuracy: 96.77 Task: Image Classification - Weights: https://github.com/Westlake-AI/openmixup/releases/download/moganet-in1k-weights/moga_base_sz224_8xb128_fp16_ep300.pth + Weights: https://github.com/Lupin1998/mmpretrain/releases/download/moganet-in1k-weights/moganet-base_3rdparty_8xb128_in1k.pth Config: configs/moganet/moganet-base_8xb128_in1k.py Converted From: Weights: https://github.com/Westlake-AI/openmixup/releases/download/moganet-in1k-weights/moga_base_sz224_8xb128_fp16_ep300.pth @@ -90,7 +90,7 @@ Models: Top 1 Accuracy: 84.76 Top 5 Accuracy: 97.15 Task: Image Classification - Weights: https://github.com/Westlake-AI/openmixup/releases/download/moganet-in1k-weights/moga_large_sz224_8xb64_accu2_ep300.pth + Weights: https://github.com/Lupin1998/mmpretrain/releases/download/moganet-in1k-weights/moganet-large_3rdparty_8xb128_in1k.pth Config: configs/moganet/moganet-large_8xb128_in1k.py Converted From: Weights: https://github.com/Westlake-AI/openmixup/releases/download/moganet-in1k-weights/moga_large_sz224_8xb64_accu2_ep300.pth @@ -106,25 +106,8 @@ Models: Top 1 Accuracy: 85.11 Top 5 Accuracy: 97.38 Task: Image Classification - Weights: https://github.com/Westlake-AI/openmixup/releases/download/moganet-in1k-weights/moga_xlarge_ema_sz224_8xb32_accu2_ep300.pth + Weights: https://github.com/Lupin1998/mmpretrain/releases/download/moganet-in1k-weights/moganet-xlarge_3rdparty_16xb32_in1k.pth Config: configs/moganet/moganet-xlarge_16xb32_in1k.py Converted From: Weights: https://github.com/Westlake-AI/openmixup/releases/download/moganet-in1k-weights/moga_xlarge_ema_sz224_8xb32_accu2_ep300.pth Code: https://github.com/Westlake-AI/openmixup - - # - Name: poolformer-m48_3rdparty_32xb128_in1k - # Metadata: - # FLOPs: 11801805696 - # Parameters: 73473448 - # In Collection: PoolFormer - # Results: - # - Dataset: ImageNet-1k - # Metrics: - # Top 1 Accuracy: 82.51 - # Top 5 Accuracy: 95.95 - # Task: Image Classification - # Weights: https://download.openmmlab.com/mmclassification/v0/poolformer/poolformer-m48_3rdparty_32xb128_in1k_20220414-9378f3eb.pth - # Config: configs/poolformer/poolformer-m48_32xb128_in1k.py - # Converted From: - # Weights: https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_m48.pth.tar - # Code: https://github.com/sail-sg/poolformer diff --git a/mmpretrain/__init__.py b/mmpretrain/__init__.py index 3f5424699a1..1d065db1cd8 100644 --- a/mmpretrain/__init__.py +++ b/mmpretrain/__init__.py @@ -6,11 +6,11 @@ from .apis import * # noqa: F401, F403 from .version import __version__ -mmcv_minimum_version = '1.7.0' +mmcv_minimum_version = '2.0.0' mmcv_maximum_version = '2.1.0' mmcv_version = digit_version(mmcv.__version__) -mmengine_minimum_version = '0.7.0' +mmengine_minimum_version = '0.8.0' mmengine_maximum_version = '1.0.0' mmengine_version = digit_version(mmengine.__version__)