From 90668a46245f5a91f237dc5430bc03d28a443f56 Mon Sep 17 00:00:00 2001
From: Lupin1998 <1070535169@qq.com>
Date: Thu, 6 Jul 2023 10:07:04 +0000
Subject: [PATCH 1/2] add MogaNet
---
configs/_base_/models/moganet/moganet_base.py | 23 +
.../_base_/models/moganet/moganet_large.py | 23 +
.../_base_/models/moganet/moganet_small.py | 23 +
configs/_base_/models/moganet/moganet_tiny.py | 23 +
.../_base_/models/moganet/moganet_xlarge.py | 23 +
.../_base_/models/moganet/moganet_xtiny.py | 23 +
configs/moganet/README.md | 81 +++
configs/moganet/metafile.yml | 130 ++++
configs/moganet/moganet-base_8xb128_in1k.py | 42 ++
configs/moganet/moganet-large_8xb128_in1k.py | 101 +++
configs/moganet/moganet-small_8xb128_in1k.py | 39 ++
configs/moganet/moganet-tiny_8xb128_in1k.py | 95 +++
configs/moganet/moganet-xlarge_16xb32_in1k.py | 98 +++
configs/moganet/moganet-xtiny_8xb128_in1k.py | 95 +++
mmpretrain/__init__.py | 4 +-
mmpretrain/models/backbones/__init__.py | 2 +
mmpretrain/models/backbones/moganet.py | 640 ++++++++++++++++++
model-index.yml | 1 +
18 files changed, 1464 insertions(+), 2 deletions(-)
create mode 100644 configs/_base_/models/moganet/moganet_base.py
create mode 100644 configs/_base_/models/moganet/moganet_large.py
create mode 100644 configs/_base_/models/moganet/moganet_small.py
create mode 100644 configs/_base_/models/moganet/moganet_tiny.py
create mode 100644 configs/_base_/models/moganet/moganet_xlarge.py
create mode 100644 configs/_base_/models/moganet/moganet_xtiny.py
create mode 100644 configs/moganet/README.md
create mode 100644 configs/moganet/metafile.yml
create mode 100644 configs/moganet/moganet-base_8xb128_in1k.py
create mode 100644 configs/moganet/moganet-large_8xb128_in1k.py
create mode 100644 configs/moganet/moganet-small_8xb128_in1k.py
create mode 100644 configs/moganet/moganet-tiny_8xb128_in1k.py
create mode 100644 configs/moganet/moganet-xlarge_16xb32_in1k.py
create mode 100644 configs/moganet/moganet-xtiny_8xb128_in1k.py
create mode 100644 mmpretrain/models/backbones/moganet.py
diff --git a/configs/_base_/models/moganet/moganet_base.py b/configs/_base_/models/moganet/moganet_base.py
new file mode 100644
index 00000000000..4d5b4e12cab
--- /dev/null
+++ b/configs/_base_/models/moganet/moganet_base.py
@@ -0,0 +1,23 @@
+# Model settings
+model = dict(
+ type='ImageClassifier',
+ backbone=dict(
+ type='MogaNet', arch='base', drop_path_rate=0.2,
+ attn_force_fp32=True,
+ ),
+ neck=dict(type='GlobalAveragePooling'),
+ head=dict(
+ type='LinearClsHead',
+ num_classes=1000,
+ in_channels=512,
+ loss=dict(
+ type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
+ init_cfg=None,
+ ),
+ init_cfg=dict(
+ type='TruncNormal', layer=['Conv2d', 'Linear'], std=.02, bias=0.),
+ train_cfg=dict(augments=[
+ dict(type='Mixup', alpha=0.8),
+ dict(type='CutMix', alpha=1.0),
+ ]),
+)
diff --git a/configs/_base_/models/moganet/moganet_large.py b/configs/_base_/models/moganet/moganet_large.py
new file mode 100644
index 00000000000..e1d29cec22e
--- /dev/null
+++ b/configs/_base_/models/moganet/moganet_large.py
@@ -0,0 +1,23 @@
+# Model settings
+model = dict(
+ type='ImageClassifier',
+ backbone=dict(
+ type='MogaNet', arch='large', drop_path_rate=0.3,
+ attn_force_fp32=False,
+ ),
+ neck=dict(type='GlobalAveragePooling'),
+ head=dict(
+ type='LinearClsHead',
+ num_classes=1000,
+ in_channels=640,
+ loss=dict(
+ type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
+ init_cfg=None,
+ ),
+ init_cfg=dict(
+ type='TruncNormal', layer=['Conv2d', 'Linear'], std=.02, bias=0.),
+ train_cfg=dict(augments=[
+ dict(type='Mixup', alpha=0.8),
+ dict(type='CutMix', alpha=1.0),
+ ]),
+)
diff --git a/configs/_base_/models/moganet/moganet_small.py b/configs/_base_/models/moganet/moganet_small.py
new file mode 100644
index 00000000000..1c7ce16f06d
--- /dev/null
+++ b/configs/_base_/models/moganet/moganet_small.py
@@ -0,0 +1,23 @@
+# Model settings
+model = dict(
+ type='ImageClassifier',
+ backbone=dict(
+ type='MogaNet', arch='small', drop_path_rate=0.1,
+ attn_force_fp32=True,
+ ),
+ neck=dict(type='GlobalAveragePooling'),
+ head=dict(
+ type='LinearClsHead',
+ num_classes=1000,
+ in_channels=512,
+ loss=dict(
+ type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
+ init_cfg=None,
+ ),
+ init_cfg=dict(
+ type='TruncNormal', layer=['Conv2d', 'Linear'], std=.02, bias=0.),
+ train_cfg=dict(augments=[
+ dict(type='Mixup', alpha=0.8),
+ dict(type='CutMix', alpha=1.0),
+ ]),
+)
diff --git a/configs/_base_/models/moganet/moganet_tiny.py b/configs/_base_/models/moganet/moganet_tiny.py
new file mode 100644
index 00000000000..de981289e51
--- /dev/null
+++ b/configs/_base_/models/moganet/moganet_tiny.py
@@ -0,0 +1,23 @@
+# Model settings
+model = dict(
+ type='ImageClassifier',
+ backbone=dict(
+ type='MogaNet', arch='tiny', drop_path_rate=0.1,
+ attn_force_fp32=True,
+ ),
+ neck=dict(type='GlobalAveragePooling'),
+ head=dict(
+ type='LinearClsHead',
+ num_classes=1000,
+ in_channels=256,
+ loss=dict(
+ type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
+ init_cfg=None,
+ ),
+ init_cfg=dict(
+ type='TruncNormal', layer=['Conv2d', 'Linear'], std=.02, bias=0.),
+ train_cfg=dict(augments=[
+ dict(type='Mixup', alpha=0.1),
+ dict(type='CutMix', alpha=1.0),
+ ]),
+)
diff --git a/configs/_base_/models/moganet/moganet_xlarge.py b/configs/_base_/models/moganet/moganet_xlarge.py
new file mode 100644
index 00000000000..b893d11032f
--- /dev/null
+++ b/configs/_base_/models/moganet/moganet_xlarge.py
@@ -0,0 +1,23 @@
+# Model settings
+model = dict(
+ type='ImageClassifier',
+ backbone=dict(
+ type='MogaNet', arch='x-large', drop_path_rate=0.4,
+ attn_force_fp32=True,
+ ),
+ neck=dict(type='GlobalAveragePooling'),
+ head=dict(
+ type='LinearClsHead',
+ num_classes=1000,
+ in_channels=960,
+ loss=dict(
+ type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
+ init_cfg=None,
+ ),
+ init_cfg=dict(
+ type='TruncNormal', layer=['Conv2d', 'Linear'], std=.02, bias=0.),
+ train_cfg=dict(augments=[
+ dict(type='Mixup', alpha=0.8),
+ dict(type='CutMix', alpha=1.0),
+ ]),
+)
diff --git a/configs/_base_/models/moganet/moganet_xtiny.py b/configs/_base_/models/moganet/moganet_xtiny.py
new file mode 100644
index 00000000000..1030a71abbf
--- /dev/null
+++ b/configs/_base_/models/moganet/moganet_xtiny.py
@@ -0,0 +1,23 @@
+# Model settings
+model = dict(
+ type='ImageClassifier',
+ backbone=dict(
+ type='MogaNet', arch='x-tiny', drop_path_rate=0.05,
+ attn_force_fp32=True,
+ ),
+ neck=dict(type='GlobalAveragePooling'),
+ head=dict(
+ type='LinearClsHead',
+ num_classes=1000,
+ in_channels=192,
+ loss=dict(
+ type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
+ init_cfg=None,
+ ),
+ init_cfg=dict(
+ type='TruncNormal', layer=['Conv2d', 'Linear'], std=.02, bias=0.),
+ train_cfg=dict(augments=[
+ dict(type='Mixup', alpha=0.1),
+ dict(type='CutMix', alpha=1.0),
+ ]),
+)
diff --git a/configs/moganet/README.md b/configs/moganet/README.md
new file mode 100644
index 00000000000..fbd76178805
--- /dev/null
+++ b/configs/moganet/README.md
@@ -0,0 +1,81 @@
+# Efficient Multi-order Gated Aggregation Network
+
+> [Efficient Multi-order Gated Aggregation Network](https://arxiv.org/abs/2211.03295)
+
+
+
+## Abstract
+
+Since the recent success of Vision Transformers (ViTs), explorations toward ViT-style architectures have triggered the resurgence of ConvNets. In this work, we explore the representation ability of modern ConvNets from a novel view of multi-order game-theoretic interaction, which reflects inter-variable interaction effects w.r.t. contexts of different scales based on game theory. Within the modern ConvNet framework, we tailor the two feature mixers with conceptually simple yet effective depthwise convolutions to facilitate middle-order information across spatial and channel spaces respectively. In this light, a new family of pure ConvNet architecture, dubbed MogaNet, is proposed, which shows excellent scalability and attains competitive results among state-of-the-art models with more efficient use of parameters on ImageNet and multifarious typical vision benchmarks, including COCO object detection, ADE20K semantic segmentation, 2D&3D human pose estimation, and video prediction. Typically, MogaNet hits 80.0% and 87.8% top-1 accuracy with 5.2M and 181M parameters on ImageNet, outperforming ParC-Net-S and ConvNeXt-L while saving 59% FLOPs and 17M parameters. The source code is available at https://github.com/Westlake-AI/MogaNet.
+
+
+
+
+
+## How to use it?
+
+
+
+**Predict image**
+
+```python
+from mmpretrain import inference_model
+
+predict = inference_model('moganet-tiny_3rdparty_8xb128_in1k', 'demo/bird.JPEG')
+print(predict['pred_class'])
+print(predict['pred_score'])
+```
+
+**Use the model**
+
+```python
+import torch
+from mmpretrain import get_model
+
+model = get_model('moganet-tiny_3rdparty_8xb128_in1k', pretrained=True)
+inputs = torch.rand(1, 3, 224, 224)
+out = model(inputs)
+print(type(out))
+# To extract features.
+feats = model.extract_feat(inputs)
+print(type(feats))
+```
+
+**Test Command**
+
+Prepare your dataset according to the [docs](https://mmpretrain.readthedocs.io/en/latest/user_guides/dataset_prepare.html#prepare-dataset).
+
+Test:
+
+```shell
+python tools/test.py configs/moganet/moganet-tiny_8xb128_in1k.py https://github.com/Westlake-AI/openmixup/releases/download/moganet-in1k-weights/moga_tiny_sz224_8xb128_fp16_ep300.pth
+```
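+
+The configs can also be launched for training with the standard entry point (a sketch; we haven't reproduced the reported results with them, see the note below):
+
+```shell
+python tools/train.py configs/moganet/moganet-tiny_8xb128_in1k.py
+```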
+
+
+
+## Models and results
+
+### Image Classification on ImageNet-1k
+
+| Model | Pretrain | Params (M) | Flops (G) | Top-1 (%) | Top-5 (%) | Config | Download |
+| :-------------------------------------- | :----------: | :--------: | :-------: | :-------: | :-------: | :-------: | :-------: |
+| `moganet-xtiny_3rdparty_8xb128_in1k`\* | From scratch | 2.97 | 0.79 | 76.48 | 93.49 | [config](moganet-xtiny_8xb128_in1k.py) | [model](https://github.com/Lupin1998/mmpretrain/releases/download/moganet-in1k-weights/moganet-xtiny_3rdparty_8xb128_in1k.pth) |
+| `moganet-tiny_3rdparty_8xb128_in1k`\* | From scratch | 5.20 | 1.09 | 77.24 | 93.51 | [config](moganet-tiny_8xb128_in1k.py) | [model](https://github.com/Lupin1998/mmpretrain/releases/download/moganet-in1k-weights/moganet-tiny_3rdparty_8xb128_in1k.pth) |
+| `moganet-small_3rdparty_8xb128_in1k`\* | From scratch | 25.35 | 4.94 | 83.38 | 96.58 | [config](moganet-small_8xb128_in1k.py) | [model](https://github.com/Lupin1998/mmpretrain/releases/download/moganet-in1k-weights/moganet-small_3rdparty_8xb128_in1k.pth) |
+| `moganet-base_3rdparty_8xb128_in1k`\* | From scratch | 43.72 | 9.88 | 84.20 | 96.77 | [config](moganet-base_8xb128_in1k.py) | [model](https://github.com/Lupin1998/mmpretrain/releases/download/moganet-in1k-weights/moganet-base_3rdparty_8xb128_in1k.pth) |
+| `moganet-large_3rdparty_8xb128_in1k`\* | From scratch | 82.48 | 15.84 | 84.76 | 97.15 | [config](moganet-large_8xb128_in1k.py) | [model](https://github.com/Lupin1998/mmpretrain/releases/download/moganet-in1k-weights/moganet-large_3rdparty_8xb128_in1k.pth) |
+| `moganet-xlarge_3rdparty_16xb32_in1k`\* | From scratch | 180.8 | 34.43 | 85.11 | 97.38 | [config](moganet-xlarge_16xb32_in1k.py) | [model](https://github.com/Lupin1998/mmpretrain/releases/download/moganet-in1k-weights/moganet-xlarge_3rdparty_16xb32_in1k.pth) |
+
+*Models with * are converted from the [official repo](https://github.com/Westlake-AI/MogaNet). The config files of these models are only for inference. We haven't reproduced the training results.*
+
+## Citation
+
+```bibtex
+@article{Li2022MogaNet,
+ title={Efficient Multi-order Gated Aggregation Network},
+ author={Siyuan Li and Zedong Wang and Zicheng Liu and Cheng Tan and Haitao Lin and Di Wu and Zhiyuan Chen and Jiangbin Zheng and Stan Z. Li},
+ journal={ArXiv},
+ year={2022},
+ volume={abs/2211.03295}
+}
+```
diff --git a/configs/moganet/metafile.yml b/configs/moganet/metafile.yml
new file mode 100644
index 00000000000..0d6925f03a7
--- /dev/null
+++ b/configs/moganet/metafile.yml
@@ -0,0 +1,130 @@
+Collections:
+ - Name: MogaNet
+ Metadata:
+ Training Data: ImageNet-1k
+ Architecture:
+ - Gating
+ - 1x1 Convolution
+ - LayerScale
+ Paper:
+ URL: https://arxiv.org/abs/2211.03295
+ Title: Efficient Multi-order Gated Aggregation Network
+ README: configs/moganet/README.md
+ Code:
+ Version: v1.0.0
+ URL: https://github.com/Lupin1998/mmpretrain/tree/main/mmpretrain/models/backbones/moganet.py
+
+Models:
+ - Name: moganet-xtiny_3rdparty_8xb128_in1k
+ Metadata:
+ FLOPs: 843961073
+ Parameters: 3114270
+ In Collection: MogaNet
+ Results:
+ - Dataset: ImageNet-1k
+ Metrics:
+ Top 1 Accuracy: 76.48
+ Top 5 Accuracy: 93.49
+ Task: Image Classification
+ Weights: https://github.com/Westlake-AI/openmixup/releases/download/moganet-in1k-weights/moga_xtiny_sz224_8xb128_fp16_ep300.pth
+ Config: configs/moganet/moganet-xtiny_8xb128_in1k.py
+ Converted From:
+ Weights: https://github.com/Westlake-AI/openmixup/releases/download/moganet-in1k-weights/moga_xtiny_sz224_8xb128_fp16_ep300.pth
+ Code: https://github.com/Westlake-AI/openmixup
+ - Name: moganet-tiny_3rdparty_8xb128_in1k
+ Metadata:
+ FLOPs: 1168231104
+ Parameters: 5449449
+ In Collection: MogaNet
+ Results:
+ - Dataset: ImageNet-1k
+ Metrics:
+ Top 1 Accuracy: 77.24
+ Top 5 Accuracy: 93.51
+ Task: Image Classification
+ Weights: https://github.com/Westlake-AI/openmixup/releases/download/moganet-in1k-weights/moga_tiny_sz224_8xb128_fp16_ep300.pth
+ Config: configs/moganet/moganet-tiny_8xb128_in1k.py
+ Converted From:
+ Weights: https://github.com/Westlake-AI/openmixup/releases/download/moganet-in1k-weights/moga_tiny_sz224_8xb128_fp16_ep300.pth
+ Code: https://github.com/Westlake-AI/openmixup
+ - Name: moganet-small_3rdparty_8xb128_in1k
+ Metadata:
+ FLOPs: 5304284610
+ Parameters: 26566721
+ In Collection: MogaNet
+ Results:
+ - Dataset: ImageNet-1k
+ Metrics:
+ Top 1 Accuracy: 83.38
+ Top 5 Accuracy: 96.58
+ Task: Image Classification
+ Weights: https://github.com/Westlake-AI/openmixup/releases/download/moganet-in1k-weights/moga_small_sz224_8xb128_fp16_ep300.pth
+ Config: configs/moganet/moganet-small_8xb128_in1k.py
+ Converted From:
+ Weights: https://github.com/Westlake-AI/openmixup/releases/download/moganet-in1k-weights/moga_small_sz224_8xb128_fp16_ep300.pth
+ Code: https://github.com/Westlake-AI/openmixup
+ - Name: moganet-base_3rdparty_8xb128_in1k
+ Metadata:
+ FLOPs: 10608569221
+ Parameters: 45843742
+ In Collection: MogaNet
+ Results:
+ - Dataset: ImageNet-1k
+ Metrics:
+ Top 1 Accuracy: 84.20
+ Top 5 Accuracy: 96.77
+ Task: Image Classification
+ Weights: https://github.com/Westlake-AI/openmixup/releases/download/moganet-in1k-weights/moga_base_sz224_8xb128_fp16_ep300.pth
+ Config: configs/moganet/moganet-base_8xb128_in1k.py
+ Converted From:
+ Weights: https://github.com/Westlake-AI/openmixup/releases/download/moganet-in1k-weights/moga_base_sz224_8xb128_fp16_ep300.pth
+ Code: https://github.com/Westlake-AI/openmixup
+ - Name: moganet-large_3rdparty_8xb128_in1k
+ Metadata:
+ FLOPs: 17008070492
+ Parameters: 86486548
+ In Collection: MogaNet
+ Results:
+ - Dataset: ImageNet-1k
+ Metrics:
+ Top 1 Accuracy: 84.76
+ Top 5 Accuracy: 97.15
+ Task: Image Classification
+ Weights: https://github.com/Westlake-AI/openmixup/releases/download/moganet-in1k-weights/moga_large_sz224_8xb64_accu2_ep300.pth
+ Config: configs/moganet/moganet-large_8xb128_in1k.py
+ Converted From:
+ Weights: https://github.com/Westlake-AI/openmixup/releases/download/moganet-in1k-weights/moga_large_sz224_8xb64_accu2_ep300.pth
+ Code: https://github.com/Westlake-AI/openmixup
+ - Name: moganet-xlarge_3rdparty_16xb32_in1k
+ Metadata:
+ FLOPs: 36968931000
+ Parameters: 189582540
+ In Collection: MogaNet
+ Results:
+ - Dataset: ImageNet-1k
+ Metrics:
+ Top 1 Accuracy: 85.11
+ Top 5 Accuracy: 97.38
+ Task: Image Classification
+ Weights: https://github.com/Westlake-AI/openmixup/releases/download/moganet-in1k-weights/moga_xlarge_ema_sz224_8xb32_accu2_ep300.pth
+ Config: configs/moganet/moganet-xlarge_16xb32_in1k.py
+ Converted From:
+ Weights: https://github.com/Westlake-AI/openmixup/releases/download/moganet-in1k-weights/moga_xlarge_ema_sz224_8xb32_accu2_ep300.pth
+ Code: https://github.com/Westlake-AI/openmixup
+
+ # - Name: poolformer-m48_3rdparty_32xb128_in1k
+ # Metadata:
+ # FLOPs: 11801805696
+ # Parameters: 73473448
+ # In Collection: PoolFormer
+ # Results:
+ # - Dataset: ImageNet-1k
+ # Metrics:
+ # Top 1 Accuracy: 82.51
+ # Top 5 Accuracy: 95.95
+ # Task: Image Classification
+ # Weights: https://download.openmmlab.com/mmclassification/v0/poolformer/poolformer-m48_3rdparty_32xb128_in1k_20220414-9378f3eb.pth
+ # Config: configs/poolformer/poolformer-m48_32xb128_in1k.py
+ # Converted From:
+ # Weights: https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_m48.pth.tar
+ # Code: https://github.com/sail-sg/poolformer
diff --git a/configs/moganet/moganet-base_8xb128_in1k.py b/configs/moganet/moganet-base_8xb128_in1k.py
new file mode 100644
index 00000000000..800f11520eb
--- /dev/null
+++ b/configs/moganet/moganet-base_8xb128_in1k.py
@@ -0,0 +1,42 @@
+_base_ = [
+ '../_base_/models/moganet/moganet_base.py',
+ '../_base_/datasets/imagenet_bs128_poolformer_small_224.py',
+ '../_base_/schedules/imagenet_bs1024_adamw_swin.py',
+ '../_base_/default_runtime.py',
+]
+
+# schedule settings
+optim_wrapper = dict(
+ paramwise_cfg=dict(
+ norm_decay_mult=0.0,
+ bias_decay_mult=0.0,
+ custom_keys={
+ '.layer_scale': dict(decay_mult=0.0),
+ '.scale': dict(decay_mult=0.0),
+ }),
+)
+
+# learning policy
+param_scheduler = [
+ # warm up learning rate scheduler
+ dict(
+ type='LinearLR',
+ start_factor=1e-3,
+ by_epoch=True,
+ end=5,
+ # update by iter
+ convert_to_iter_based=True),
+ # main learning rate scheduler
+ dict(type='CosineAnnealingLR', eta_min=1e-5, by_epoch=False, begin=5)
+]
+
+# runtime setting
+custom_hooks = [
+ dict(type='PreciseBNHook', num_samples=8192, priority='ABOVE_NORMAL'),
+ dict(type='EMAHook', momentum=1e-4, priority='ABOVE_NORMAL')
+]
+
+# NOTE: `auto_scale_lr` is for automatically scaling LR
+# based on the actual training batch size.
+# base_batch_size = (8 GPUs) x (128 samples per GPU)
+auto_scale_lr = dict(base_batch_size=1024)
diff --git a/configs/moganet/moganet-large_8xb128_in1k.py b/configs/moganet/moganet-large_8xb128_in1k.py
new file mode 100644
index 00000000000..95ef60be733
--- /dev/null
+++ b/configs/moganet/moganet-large_8xb128_in1k.py
@@ -0,0 +1,101 @@
+_base_ = [
+ '../_base_/models/moganet/moganet_large.py',
+ '../_base_/datasets/imagenet_bs128_poolformer_small_224.py',
+ '../_base_/schedules/imagenet_bs1024_adamw_swin.py',
+ '../_base_/default_runtime.py',
+]
+
+# model setting
+model = dict(backbone=dict(attn_force_fp32=False))
+
+# dataset setting
+data_preprocessor = dict(
+ mean=[123.675, 116.28, 103.53],
+ std=[58.395, 57.12, 57.375],
+ # convert image from BGR to RGB
+ to_rgb=True,
+)
+
+bgr_mean = data_preprocessor['mean'][::-1]
+bgr_std = data_preprocessor['std'][::-1]
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='RandomResizedCrop',
+ scale=224,
+ backend='pillow',
+ interpolation='bicubic'),
+ dict(type='RandomFlip', prob=0.5, direction='horizontal'),
+ dict(
+ type='RandAugment',
+ policies='timm_increasing',
+ num_policies=2,
+ total_level=10,
+ magnitude_level=9,
+ magnitude_std=0.5,
+ hparams=dict(
+ pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')),
+ dict(type='ColorJitter', brightness=0.4, contrast=0.4, saturation=0.4),
+ dict(
+ type='RandomErasing',
+ erase_prob=0.25,
+ mode='rand',
+ min_area_ratio=0.02,
+ max_area_ratio=1 / 3,
+ fill_color=bgr_mean,
+ fill_std=bgr_std),
+ dict(type='PackInputs'),
+]
+
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='ResizeEdge',
+ scale=248,
+ edge='short',
+ backend='pillow',
+ interpolation='bicubic'),
+ dict(type='CenterCrop', crop_size=224),
+ dict(type='PackInputs'),
+]
+
+train_dataloader = dict(dataset=dict(pipeline=train_pipeline), batch_size=128)
+val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
+test_dataloader = dict(dataset=dict(pipeline=test_pipeline))
+
+# schedule settings
+optim_wrapper = dict(
+ paramwise_cfg=dict(
+ norm_decay_mult=0.0,
+ bias_decay_mult=0.0,
+ custom_keys={
+ '.layer_scale': dict(decay_mult=0.0),
+ '.scale': dict(decay_mult=0.0),
+ }),
+)
+
+# learning policy
+param_scheduler = [
+ # warm up learning rate scheduler
+ dict(
+ type='LinearLR',
+ start_factor=1e-3,
+ by_epoch=True,
+ end=5,
+ # update by iter
+ convert_to_iter_based=True),
+ # main learning rate scheduler
+ dict(type='CosineAnnealingLR', eta_min=1e-5, by_epoch=False, begin=5)
+]
+
+# runtime setting
+custom_hooks = [
+ dict(type='PreciseBNHook', num_samples=8192, priority='ABOVE_NORMAL'),
+ dict(type='EMAHook', momentum=1e-4, priority='ABOVE_NORMAL')
+]
+
+# NOTE: `auto_scale_lr` is for automatically scaling LR
+# based on the actual training batch size.
+# base_batch_size = (8 GPUs) x (128 samples per GPU)
+auto_scale_lr = dict(base_batch_size=1024)
diff --git a/configs/moganet/moganet-small_8xb128_in1k.py b/configs/moganet/moganet-small_8xb128_in1k.py
new file mode 100644
index 00000000000..1743933b357
--- /dev/null
+++ b/configs/moganet/moganet-small_8xb128_in1k.py
@@ -0,0 +1,39 @@
+_base_ = [
+ '../_base_/models/moganet/moganet_small.py',
+ '../_base_/datasets/imagenet_bs128_poolformer_small_224.py',
+ '../_base_/schedules/imagenet_bs1024_adamw_swin.py',
+ '../_base_/default_runtime.py',
+]
+
+# schedule settings
+optim_wrapper = dict(
+ paramwise_cfg=dict(
+ norm_decay_mult=0.0,
+ bias_decay_mult=0.0,
+ custom_keys={
+ '.layer_scale': dict(decay_mult=0.0),
+ '.scale': dict(decay_mult=0.0),
+ }),
+)
+
+# learning policy
+param_scheduler = [
+ # warm up learning rate scheduler
+ dict(
+ type='LinearLR',
+ start_factor=1e-3,
+ by_epoch=True,
+ end=5,
+ # update by iter
+ convert_to_iter_based=True),
+ # main learning rate scheduler
+ dict(type='CosineAnnealingLR', eta_min=1e-5, by_epoch=False, begin=5)
+]
+
+# runtime setting
+custom_hooks = [
+    dict(type='PreciseBNHook', num_samples=8192, priority='ABOVE_NORMAL'),
+]
+
+# NOTE: `auto_scale_lr` is for automatically scaling LR
+# based on the actual training batch size.
+# base_batch_size = (8 GPUs) x (128 samples per GPU)
+auto_scale_lr = dict(base_batch_size=1024)
diff --git a/configs/moganet/moganet-tiny_8xb128_in1k.py b/configs/moganet/moganet-tiny_8xb128_in1k.py
new file mode 100644
index 00000000000..b030ffdbafc
--- /dev/null
+++ b/configs/moganet/moganet-tiny_8xb128_in1k.py
@@ -0,0 +1,95 @@
+_base_ = [
+ '../_base_/models/moganet/moganet_tiny.py',
+ '../_base_/datasets/imagenet_bs128_poolformer_small_224.py',
+ '../_base_/schedules/imagenet_bs1024_adamw_swin.py',
+ '../_base_/default_runtime.py',
+]
+
+# dataset setting
+data_preprocessor = dict(
+ mean=[123.675, 116.28, 103.53],
+ std=[58.395, 57.12, 57.375],
+ # convert image from BGR to RGB
+ to_rgb=True,
+)
+
+bgr_mean = data_preprocessor['mean'][::-1]
+bgr_std = data_preprocessor['std'][::-1]
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='RandomResizedCrop',
+ scale=224,
+ backend='pillow',
+ interpolation='bicubic'),
+ dict(type='RandomFlip', prob=0.5, direction='horizontal'),
+ dict(
+ type='RandAugment',
+ policies='timm_increasing',
+ num_policies=2,
+ total_level=10,
+ magnitude_level=9,
+ magnitude_std=0.5,
+ hparams=dict(
+ pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')),
+ dict(
+ type='RandomErasing',
+ erase_prob=0.25,
+ mode='rand',
+ min_area_ratio=0.02,
+ max_area_ratio=1 / 3,
+ fill_color=bgr_mean,
+ fill_std=bgr_std),
+ dict(type='PackInputs'),
+]
+
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='ResizeEdge',
+ scale=248,
+ edge='short',
+ backend='pillow',
+ interpolation='bicubic'),
+ dict(type='CenterCrop', crop_size=224),
+ dict(type='PackInputs'),
+]
+
+train_dataloader = dict(dataset=dict(pipeline=train_pipeline), batch_size=128)
+val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
+test_dataloader = dict(dataset=dict(pipeline=test_pipeline))
+
+# schedule settings
+optim_wrapper = dict(
+ optimizer=dict(weight_decay=0.04),
+ paramwise_cfg=dict(
+ norm_decay_mult=0.0,
+ bias_decay_mult=0.0,
+ custom_keys={
+ '.layer_scale': dict(decay_mult=0.0),
+ '.scale': dict(decay_mult=0.0),
+ }),
+)
+
+# learning policy
+param_scheduler = [
+ # warm up learning rate scheduler
+ dict(
+ type='LinearLR',
+ start_factor=1e-3,
+ by_epoch=True,
+ end=5,
+ # update by iter
+ convert_to_iter_based=True),
+ # main learning rate scheduler
+ dict(type='CosineAnnealingLR', eta_min=1e-6, by_epoch=False, begin=5)
+]
+
+# runtime setting
+custom_hooks = [
+    dict(type='PreciseBNHook', num_samples=8192, priority='ABOVE_NORMAL'),
+]
+
+# NOTE: `auto_scale_lr` is for automatically scaling LR
+# based on the actual training batch size.
+# base_batch_size = (8 GPUs) x (128 samples per GPU)
+auto_scale_lr = dict(base_batch_size=1024)
diff --git a/configs/moganet/moganet-xlarge_16xb32_in1k.py b/configs/moganet/moganet-xlarge_16xb32_in1k.py
new file mode 100644
index 00000000000..a170aebd900
--- /dev/null
+++ b/configs/moganet/moganet-xlarge_16xb32_in1k.py
@@ -0,0 +1,98 @@
+_base_ = [
+ '../_base_/models/moganet/moganet_xlarge.py',
+ '../_base_/datasets/imagenet_bs128_poolformer_small_224.py',
+ '../_base_/schedules/imagenet_bs1024_adamw_swin.py',
+ '../_base_/default_runtime.py',
+]
+
+# dataset setting
+data_preprocessor = dict(
+ mean=[123.675, 116.28, 103.53],
+ std=[58.395, 57.12, 57.375],
+ # convert image from BGR to RGB
+ to_rgb=True,
+)
+
+bgr_mean = data_preprocessor['mean'][::-1]
+bgr_std = data_preprocessor['std'][::-1]
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='RandomResizedCrop',
+ scale=224,
+ backend='pillow',
+ interpolation='bicubic'),
+ dict(type='RandomFlip', prob=0.5, direction='horizontal'),
+ dict(
+ type='RandAugment',
+ policies='timm_increasing',
+ num_policies=2,
+ total_level=10,
+ magnitude_level=9,
+ magnitude_std=0.5,
+ hparams=dict(
+ pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')),
+ dict(type='ColorJitter', brightness=0.4, contrast=0.4, saturation=0.4),
+ dict(
+ type='RandomErasing',
+ erase_prob=0.25,
+ mode='rand',
+ min_area_ratio=0.02,
+ max_area_ratio=1 / 3,
+ fill_color=bgr_mean,
+ fill_std=bgr_std),
+ dict(type='PackInputs'),
+]
+
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='ResizeEdge',
+ scale=248,
+ edge='short',
+ backend='pillow',
+ interpolation='bicubic'),
+ dict(type='CenterCrop', crop_size=224),
+ dict(type='PackInputs'),
+]
+
+train_dataloader = dict(dataset=dict(pipeline=train_pipeline), batch_size=32)
+val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
+test_dataloader = dict(dataset=dict(pipeline=test_pipeline))
+
+# schedule settings
+optim_wrapper = dict(
+ paramwise_cfg=dict(
+ norm_decay_mult=0.0,
+ bias_decay_mult=0.0,
+ custom_keys={
+ '.layer_scale': dict(decay_mult=0.0),
+ '.scale': dict(decay_mult=0.0),
+ }),
+)
+
+# learning policy
+param_scheduler = [
+ # warm up learning rate scheduler
+ dict(
+ type='LinearLR',
+ start_factor=1e-3,
+ by_epoch=True,
+ end=5,
+ # update by iter
+ convert_to_iter_based=True),
+ # main learning rate scheduler
+ dict(type='CosineAnnealingLR', eta_min=1e-5, by_epoch=False, begin=5)
+]
+
+# runtime setting
+custom_hooks = [
+ dict(type='PreciseBNHook', num_samples=8192, priority='ABOVE_NORMAL'),
+ dict(type='EMAHook', momentum=1e-4, priority='ABOVE_NORMAL')
+]
+
+# NOTE: `auto_scale_lr` is for automatically scaling LR
+# based on the actual training batch size.
+# base_batch_size = (16 GPUs) x (32 samples per GPU)
+auto_scale_lr = dict(base_batch_size=512)
diff --git a/configs/moganet/moganet-xtiny_8xb128_in1k.py b/configs/moganet/moganet-xtiny_8xb128_in1k.py
new file mode 100644
index 00000000000..d2a2937a8a8
--- /dev/null
+++ b/configs/moganet/moganet-xtiny_8xb128_in1k.py
@@ -0,0 +1,95 @@
+_base_ = [
+ '../_base_/models/moganet/moganet_xtiny.py',
+ '../_base_/datasets/imagenet_bs128_poolformer_small_224.py',
+ '../_base_/schedules/imagenet_bs1024_adamw_swin.py',
+ '../_base_/default_runtime.py',
+]
+
+# dataset setting
+data_preprocessor = dict(
+ mean=[123.675, 116.28, 103.53],
+ std=[58.395, 57.12, 57.375],
+ # convert image from BGR to RGB
+ to_rgb=True,
+)
+
+bgr_mean = data_preprocessor['mean'][::-1]
+bgr_std = data_preprocessor['std'][::-1]
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='RandomResizedCrop',
+ scale=224,
+ backend='pillow',
+ interpolation='bicubic'),
+ dict(type='RandomFlip', prob=0.5, direction='horizontal'),
+ dict(
+ type='RandAugment',
+ policies='timm_increasing',
+ num_policies=2,
+ total_level=10,
+ magnitude_level=7,
+ magnitude_std=0.5,
+ hparams=dict(
+ pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')),
+ dict(
+ type='RandomErasing',
+ erase_prob=0.25,
+ mode='rand',
+ min_area_ratio=0.02,
+ max_area_ratio=1 / 3,
+ fill_color=bgr_mean,
+ fill_std=bgr_std),
+ dict(type='PackInputs'),
+]
+
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='ResizeEdge',
+ scale=248,
+ edge='short',
+ backend='pillow',
+ interpolation='bicubic'),
+ dict(type='CenterCrop', crop_size=224),
+ dict(type='PackInputs'),
+]
+
+train_dataloader = dict(dataset=dict(pipeline=train_pipeline), batch_size=128)
+val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
+test_dataloader = dict(dataset=dict(pipeline=test_pipeline))
+
+# schedule settings
+optim_wrapper = dict(
+ optimizer=dict(weight_decay=0.03),
+ paramwise_cfg=dict(
+ norm_decay_mult=0.0,
+ bias_decay_mult=0.0,
+ custom_keys={
+ '.layer_scale': dict(decay_mult=0.0),
+ '.scale': dict(decay_mult=0.0),
+ }),
+)
+
+# learning policy
+param_scheduler = [
+ # warm up learning rate scheduler
+ dict(
+ type='LinearLR',
+ start_factor=1e-3,
+ by_epoch=True,
+ end=5,
+ # update by iter
+ convert_to_iter_based=True),
+ # main learning rate scheduler
+ dict(type='CosineAnnealingLR', eta_min=1e-6, by_epoch=False, begin=5)
+]
+
+# runtime setting
+custom_hooks = [
+    dict(type='PreciseBNHook', num_samples=8192, priority='ABOVE_NORMAL'),
+]
+
+# NOTE: `auto_scale_lr` is for automatically scaling LR
+# based on the actual training batch size.
+# base_batch_size = (8 GPUs) x (128 samples per GPU)
+auto_scale_lr = dict(base_batch_size=1024)
diff --git a/mmpretrain/__init__.py b/mmpretrain/__init__.py
index 1d065db1cd8..3f5424699a1 100644
--- a/mmpretrain/__init__.py
+++ b/mmpretrain/__init__.py
@@ -6,11 +6,11 @@
from .apis import * # noqa: F401, F403
from .version import __version__
-mmcv_minimum_version = '2.0.0'
+mmcv_minimum_version = '1.7.0'
mmcv_maximum_version = '2.1.0'
mmcv_version = digit_version(mmcv.__version__)
-mmengine_minimum_version = '0.8.0'
+mmengine_minimum_version = '0.7.0'
mmengine_maximum_version = '1.0.0'
mmengine_version = digit_version(mmengine.__version__)
diff --git a/mmpretrain/models/backbones/__init__.py b/mmpretrain/models/backbones/__init__.py
index 60e37fb7b6e..c121559ce5d 100644
--- a/mmpretrain/models/backbones/__init__.py
+++ b/mmpretrain/models/backbones/__init__.py
@@ -25,6 +25,7 @@
from .mobilenet_v3 import MobileNetV3
from .mobileone import MobileOne
from .mobilevit import MobileViT
+from .moganet import MogaNet
from .mvit import MViT
from .poolformer import PoolFormer
from .regnet import RegNet
@@ -112,6 +113,7 @@
'DeiT3',
'HorNet',
'MobileViT',
+ 'MogaNet',
'DaViT',
'BEiTViT',
'RevVisionTransformer',
diff --git a/mmpretrain/models/backbones/moganet.py b/mmpretrain/models/backbones/moganet.py
new file mode 100644
index 00000000000..44121efe790
--- /dev/null
+++ b/mmpretrain/models/backbones/moganet.py
@@ -0,0 +1,640 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn.bricks import (Conv2d, DropPath, build_activation_layer,
+                             build_norm_layer)
+from mmcv.cnn.bricks.transformer import PatchEmbed
+from mmengine.model import BaseModule
+from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
+
+from mmpretrain.registry import MODELS
+from .base_backbone import BaseBackbone
+
+
+class ElementScale(nn.Module):
+ """A learnable element-wise scaler."""
+
+ def __init__(self, embed_dims, init_value=0., requires_grad=True):
+ super(ElementScale, self).__init__()
+ self.scale = nn.Parameter(
+ init_value * torch.ones((1, embed_dims, 1, 1)),
+ requires_grad=requires_grad
+ )
+
+ def forward(self, x):
+ return x * self.scale
+
+
+class ChannelAggregationFFN(BaseModule):
+ """An implementation of FFN with Channel Aggregation.
+
+ Args:
+ embed_dims (int): The feature dimension. Same as
+ `MultiheadAttention`.
+ feedforward_channels (int): The hidden dimension of FFNs.
+        kernel_size (int): The kernel size of the depth-wise convolution.
+            Defaults to 3.
+        act_cfg (dict): The activation config for FFNs.
+            Default: dict(type='GELU').
+ ffn_drop (float, optional): Probability of an element to be
+ zeroed in FFN. Default 0.0.
+ init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+ Default: None.
+ """
+
+ def __init__(self,
+ embed_dims,
+ feedforward_channels,
+ kernel_size=3,
+ act_cfg=dict(type='GELU'),
+ ffn_drop=0.,
+ init_cfg=None):
+ super(ChannelAggregationFFN, self).__init__(init_cfg=init_cfg)
+
+ self.embed_dims = embed_dims
+ self.feedforward_channels = feedforward_channels
+ self.act_cfg = act_cfg
+
+ self.fc1 = Conv2d(
+ in_channels=embed_dims,
+ out_channels=self.feedforward_channels,
+ kernel_size=1)
+ self.dwconv = Conv2d(
+ in_channels=self.feedforward_channels,
+ out_channels=self.feedforward_channels,
+ kernel_size=kernel_size,
+ stride=1,
+ padding=kernel_size // 2,
+ bias=True,
+ groups=self.feedforward_channels)
+ self.act = build_activation_layer(act_cfg)
+ self.fc2 = Conv2d(
+ in_channels=feedforward_channels,
+ out_channels=embed_dims,
+ kernel_size=1)
+ self.drop = nn.Dropout(ffn_drop)
+
+ self.decompose = Conv2d(
+ in_channels=self.feedforward_channels, # C -> 1
+ out_channels=1, kernel_size=1,
+ )
+ self.sigma = ElementScale(
+ self.feedforward_channels, init_value=1e-5, requires_grad=True)
+ self.decompose_act = build_activation_layer(act_cfg)
+
+    def feat_decompose(self, x):
+        # channel aggregation: self.decompose squeezes [B, C, H, W] to
+        # [B, 1, H, W]; the residual is rescaled by the learnable sigma.
+        x = x + self.sigma(x - self.decompose_act(self.decompose(x)))
+        return x
+
+ def forward(self, x):
+ # proj 1
+ x = self.fc1(x)
+ x = self.dwconv(x)
+ x = self.act(x)
+ x = self.drop(x)
+ # proj 2
+ x = self.feat_decompose(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class MultiOrderDWConv(BaseModule):
+ """Multi-order Features with Dilated DWConv Kernel.
+
+ Args:
+ embed_dims (int): Number of input channels.
+ dw_dilation (list): Dilations of three DWConv layers.
+        channel_split (list): The relative ratio of the three split channels.
+ init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+ Default: None.
+ """
+
+ def __init__(self,
+ embed_dims,
+ dw_dilation=[1, 2, 3,],
+ channel_split=[1, 3, 4,],
+ init_cfg=None):
+ super(MultiOrderDWConv, self).__init__(init_cfg=init_cfg)
+
+ self.split_ratio = [i / sum(channel_split) for i in channel_split]
+ self.embed_dims_1 = int(self.split_ratio[1] * embed_dims)
+ self.embed_dims_2 = int(self.split_ratio[2] * embed_dims)
+ self.embed_dims_0 = embed_dims - self.embed_dims_1 - self.embed_dims_2
+ self.embed_dims = embed_dims
+ assert len(dw_dilation) == len(channel_split) == 3
+ assert 1 <= min(dw_dilation) and max(dw_dilation) <= 3
+ assert embed_dims % sum(channel_split) == 0
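+        # For example (an illustration, not extra logic): with embed_dims=64
+        # and channel_split=[1, 3, 4], the split ratios are [1/8, 3/8, 4/8],
+        # giving embed_dims_1=24, embed_dims_2=32 and embed_dims_0=8.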
+
+ # basic DW conv
+ self.DW_conv0 = Conv2d(
+ in_channels=self.embed_dims,
+ out_channels=self.embed_dims,
+ kernel_size=5,
+ padding=(1 + 4 * dw_dilation[0]) // 2,
+ groups=self.embed_dims,
+ stride=1, dilation=dw_dilation[0],
+ )
+ # DW conv 1
+ self.DW_conv1 = Conv2d(
+ in_channels=self.embed_dims_1,
+ out_channels=self.embed_dims_1,
+ kernel_size=5,
+ padding=(1 + 4 * dw_dilation[1]) // 2,
+ groups=self.embed_dims_1,
+ stride=1, dilation=dw_dilation[1],
+ )
+ # DW conv 2
+ self.DW_conv2 = Conv2d(
+ in_channels=self.embed_dims_2,
+ out_channels=self.embed_dims_2,
+ kernel_size=7,
+ padding=(1 + 6 * dw_dilation[2]) // 2,
+ groups=self.embed_dims_2,
+ stride=1, dilation=dw_dilation[2],
+ )
+ # a channel convolution
+ self.PW_conv = Conv2d( # point-wise convolution
+ in_channels=embed_dims,
+ out_channels=embed_dims,
+ kernel_size=1)
+
+ def forward(self, x):
+ x_0 = self.DW_conv0(x)
+ x_1 = self.DW_conv1(
+ x_0[:, self.embed_dims_0: self.embed_dims_0+self.embed_dims_1, ...])
+ x_2 = self.DW_conv2(
+ x_0[:, self.embed_dims-self.embed_dims_2:, ...])
+ x = torch.cat([
+ x_0[:, :self.embed_dims_0, ...], x_1, x_2], dim=1)
+ x = self.PW_conv(x)
+ return x
+
+
+class MultiOrderGatedAggregation(BaseModule):
+ """Spatial Block with Multi-order Gated Aggregation.
+
+ Args:
+ embed_dims (int): Number of input channels.
+ attn_dw_dilation (list): Dilations of three DWConv layers.
+        attn_channel_split (list): The relative ratio of split channels.
+ attn_act_cfg (dict, optional): The activation config for Spatial Block.
+ Default: dict(type='SiLU').
+ init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+ Default: None.
+ """
+
+ def __init__(self,
+ embed_dims,
+ attn_dw_dilation=[1, 2, 3],
+ attn_channel_split=[1, 3, 4],
+ attn_act_cfg=dict(type='SiLU'),
+ attn_force_fp32=False,
+ init_cfg=None):
+ super(MultiOrderGatedAggregation, self).__init__(init_cfg=init_cfg)
+
+ self.embed_dims = embed_dims
+ self.attn_force_fp32 = attn_force_fp32
+ self.proj_1 = Conv2d(
+ in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)
+ self.gate = Conv2d(
+ in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)
+ self.value = MultiOrderDWConv(
+ embed_dims=embed_dims,
+ dw_dilation=attn_dw_dilation,
+ channel_split=attn_channel_split,
+ )
+ self.proj_2 = Conv2d(
+ in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)
+
+ # activation for gating and value
+ self.act_value = build_activation_layer(attn_act_cfg)
+ self.act_gate = build_activation_layer(attn_act_cfg)
+
+ # decompose
+ self.sigma = ElementScale(
+ embed_dims, init_value=1e-5, requires_grad=True)
+
+ def feat_decompose(self, x):
+ x = self.proj_1(x)
+ # x_d: [B, C, H, W] -> [B, C, 1, 1]
+ x_d = F.adaptive_avg_pool2d(x, output_size=1)
+ x = x + self.sigma(x - x_d)
+ x = self.act_value(x)
+ return x
+
+ def forward_gating(self, g, v):
+ """ Force to computing gating with fp32
+
+ Warning: If you use `attn_force_fp32=True` during training, you
+ should also keep it during evaluation, because the output results
+ of whether to use `attn_force_fp32` are slightly different.
+ """
+ g = g.to(torch.float32)
+ v = v.to(torch.float32)
+ return self.proj_2(self.act_gate(g) * self.act_gate(v))
+
+ def forward(self, x):
+ shortcut = x.clone()
+ # proj 1x1
+ x = self.feat_decompose(x)
+ # gating and value branch
+ g = self.gate(x)
+ v = self.value(x)
+ # aggregation
+ if not self.attn_force_fp32:
+ x = self.proj_2(self.act_gate(g) * self.act_gate(v))
+ else:
+ x = self.forward_gating(self.act_gate(g), self.act_gate(v))
+ x = x + shortcut
+ return x
+
+
+class MogaBlock(BaseModule):
+ """A block of MogaNet.
+
+ Args:
+ embed_dims (int): Number of input channels.
+ ffn_ratio (float): The expansion ratio of feedforward network hidden
+ layer channels. Defaults to 4.
+ drop_rate (float): Dropout rate after embedding. Defaults to 0.
+ drop_path_rate (float): Stochastic depth rate. Defaults to 0.1.
+ act_cfg (dict, optional): The activation config for projections and FFNs.
+ Default: dict(type='GELU').
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='BN').
+ init_value (float): Init value for Layer Scale. Defaults to 1e-5.
+ attn_dw_dilation (list): Dilations of three DWConv layers.
+        attn_channel_split (list): The relative ratio of split channels.
+ attn_act_cfg (dict): The activation config for the gating branch.
+ Default: dict(type='SiLU').
+ init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+ Default: None.
+ """
+
+ def __init__(self,
+ embed_dims,
+ ffn_ratio=4.,
+ drop_rate=0.,
+ drop_path_rate=0.,
+ act_cfg=dict(type='GELU'),
+ norm_cfg=dict(type='BN', eps=1e-5),
+ init_value=1e-5,
+ attn_dw_dilation=[1, 2, 3],
+ attn_channel_split=[1, 3, 4],
+ attn_act_cfg=dict(type='SiLU'),
+ attn_force_fp32=False,
+ init_cfg=None):
+ super(MogaBlock, self).__init__(init_cfg=init_cfg)
+ self.out_channels = embed_dims
+
+ self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
+
+ # spatial attention
+ self.attn = MultiOrderGatedAggregation(
+ embed_dims,
+ attn_dw_dilation=attn_dw_dilation,
+ attn_channel_split=attn_channel_split,
+ attn_act_cfg=attn_act_cfg,
+ attn_force_fp32=attn_force_fp32,
+ )
+ self.drop_path = DropPath(
+ drop_path_rate) if drop_path_rate > 0. else nn.Identity()
+
+ self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
+
+ # channel MLP
+ mlp_hidden_dim = int(embed_dims * ffn_ratio)
+ self.mlp = ChannelAggregationFFN( # DWConv + Channel Aggregation FFN
+ embed_dims=embed_dims,
+ feedforward_channels=mlp_hidden_dim,
+ act_cfg=act_cfg,
+ ffn_drop=drop_rate,
+ )
+
+ # init layer scale
+ self.layer_scale_1 = nn.Parameter(
+ init_value * torch.ones((1, embed_dims, 1, 1)), requires_grad=True)
+ self.layer_scale_2 = nn.Parameter(
+ init_value * torch.ones((1, embed_dims, 1, 1)), requires_grad=True)
+
+ def forward(self, x):
+ # spatial
+ identity = x
+ x = self.layer_scale_1 * self.attn(self.norm1(x))
+ x = identity + self.drop_path(x)
+ # channel
+ identity = x
+ x = self.layer_scale_2 * self.mlp(self.norm2(x))
+ x = identity + self.drop_path(x)
+ return x
+
+
+class ConvPatchEmbed(PatchEmbed):
+ """An implementation of Conv patch embedding layer.
+
+ The differences between ConvPatchEmbed & ViT PatchEmbed:
+    1. Use BN by default.
+    2. Do not flatten and transpose the output; keep the (B, C, H, W)
+       layout.
+ """
+
+ def __init__(self, *args, norm_cfg=dict(type='BN'), **kwargs):
+ super(ConvPatchEmbed, self).__init__(*args, norm_cfg=norm_cfg, **kwargs)
+
+ def forward(self, x):
+ """
+ Args:
+            x (Tensor): Has shape (B, C, H, W). In most cases, C is 3.
+        Returns:
+            tuple: Contains merged results and its spatial shape.
+            - x (Tensor): Has shape (B, embed_dims, out_h, out_w).
+            - out_size (tuple[int]): Spatial shape of x, arranged as
+                (out_h, out_w).
+ """
+
+ if self.adaptive_padding:
+ x = self.adaptive_padding(x)
+ x = self.projection(x)
+ out_size = (x.shape[2], x.shape[3])
+ if self.norm is not None:
+ x = self.norm(x)
+ return x, out_size
+
+
+class StackConvPatchEmbed(BaseModule):
+ """An implementation of Stack Conv patch embedding layer.
+
+ Args:
+        in_channels (int): The number of input channels.
+ embed_dims (int): The output dimension of PatchEmbed.
+ kernel_size (int): The conv kernel size of stack patch embedding.
+ Defaults to 3.
+ stride (int): The conv stride of stack patch embedding.
+ Defaults to 2.
+ act_cfg (dict, optional): The activation config in PatchEmbed.
+ Default: dict(type='GELU').
+ norm_cfg (dict): Config dict for normalization layer in PatchEmbed.
+            Default: dict(type='BN').
+ """
+
+ def __init__(self,
+ in_channels,
+ embed_dims,
+ kernel_size=3,
+ stride=2,
+ act_cfg=dict(type='GELU'),
+ norm_cfg=dict(type='BN'),
+ init_cfg=None,
+ ):
+ super(StackConvPatchEmbed, self).__init__(init_cfg)
+
+ self.projection = nn.Sequential(
+ Conv2d(in_channels, embed_dims // 2, kernel_size=kernel_size,
+ stride=stride, padding=kernel_size // 2),
+ build_norm_layer(norm_cfg, embed_dims // 2)[1],
+ build_activation_layer(act_cfg),
+ Conv2d(embed_dims // 2, embed_dims, kernel_size=kernel_size,
+ stride=stride, padding=kernel_size // 2),
+ build_norm_layer(norm_cfg, embed_dims)[1],
+ )
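+        # Note: with the default stride of 2 for each conv, the stacked
+        # stem downsamples the input by an overall factor of 4.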
+
+ def forward(self, x):
+ x = self.projection(x)
+ out_size = (x.shape[2], x.shape[3])
+ return x, out_size
+
+
+@MODELS.register_module()
+class MogaNet(BaseBackbone):
+ """MogaNet.
+
+ A PyTorch implementation of MogaNet introduced by:
+    `Efficient Multi-order Gated Aggregation Network
+    <https://arxiv.org/abs/2211.03295>`_
+
+    Modified from the `official repo
+    <https://github.com/Westlake-AI/MogaNet>`_.
+
+ Args:
+ arch (str | dict): MogaNet architecture.
+            If use string, choose from 'x-tiny', 'tiny', 'small', 'base',
+            'large' and 'x-large'.
+ If use dict, it should have below keys:
+
+ - **embed_dims** (List[int]): The dimensions of embedding.
+ - **depths** (List[int]): The number of blocks in each stage.
+ - **ffn_ratios** (List[int]): The number of expansion ratio of
+ feedforward network hidden layer channels.
+
+ Defaults to 'tiny'.
+ patch_sizes (List[int | tuple]): The patch size in patch embeddings.
+ Defaults to [3, 3, 3, 3].
+ in_channels (int): The num of input channels. Defaults to 3.
+ drop_rate (float): Dropout rate after embedding. Defaults to 0.
+ drop_path_rate (float): Stochastic depth rate. Defaults to 0.1.
+ init_value (float): Init value for Layer Scale. Defaults to 1e-5.
+        out_indices (Sequence | int): Output from which stages. Indices
+            0, 1, 2 and 3 respectively correspond to stage1, stage2,
+            stage3 and stage4. Defaults to ``(3, )``, which means the
+            last stage.
+        frozen_stages (int): Stages to be frozen (all param fixed).
+            Defaults to -1, which means not freezing any parameters.
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Defaults to False.
+ stem_norm_cfg (dict): Config dict for normalization layer for stems.
+            Defaults to ``dict(type='BN', eps=1e-5)``.
+        conv_norm_cfg (dict): Config dict for convolution normalization layer.
+            Defaults to ``dict(type='BN', eps=1e-5)``.
+ patchembed_types (list): The type of PatchEmbedding in each stage.
+ Defaults to ``['ConvEmbed', 'Conv', 'Conv', 'Conv',]``.
+ attn_dw_dilation (list): The dilate rate of depth-wise convolutions in
+ Moga Blocks. Defaults to ``[1, 2, 3]``.
+ attn_channel_split (list): The channel split rate of three depth-wise
+ convolutions in Moga Blocks. Defaults to ``[1, 3, 4]``, i.e.,
+ divided into ``[1/8, 3/8, 4/8]``.
+ attn_act_cfg (dict): Config dict for activation of gating in Moga
+ Blocks. Defaults to ``dict(type='SiLU')``.
+ attn_final_dilation (bool): Whether to adopt dilated depth-wise
+ convolutions in the final stage. Defaults to True.
+        attn_force_fp32 (bool): Whether to force the gating to run in fp32.
+            Warning: If you use `attn_force_fp32=True` during training, you
+            should also keep it during evaluation, because the outputs with
+            and without `attn_force_fp32` are slightly different. Defaults
+            to False.
+ block_cfgs (Sequence[dict] | dict): The extra config of each block.
+ Defaults to empty dicts.
+ init_cfg (dict, optional): Initialization config dict
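+
+    Examples:
+        >>> # An illustrative sketch (randomly initialized weights).
+        >>> import torch
+        >>> model = MogaNet(arch='tiny', out_indices=(0, 1, 2, 3))
+        >>> inputs = torch.rand(1, 3, 224, 224)
+        >>> for feat in model(inputs):
+        ...     print(feat.shape)
+        torch.Size([1, 32, 56, 56])
+        torch.Size([1, 64, 28, 28])
+        torch.Size([1, 128, 14, 14])
+        torch.Size([1, 256, 7, 7])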
+ """
+
+ arch_settings = {
+ **dict.fromkeys(['xt', 'x-tiny'],
+ {'embed_dims': [32, 64, 96, 192],
+ 'depths': [3, 3, 10, 2],
+ 'ffn_ratios': [8, 8, 4, 4]}),
+ **dict.fromkeys(['t', 'tiny'],
+ {'embed_dims': [32, 64, 128, 256],
+ 'depths': [3, 3, 12, 2],
+ 'ffn_ratios': [8, 8, 4, 4]}),
+ **dict.fromkeys(['s', 'small'],
+ {'embed_dims': [64, 128, 320, 512],
+ 'depths': [2, 3, 12, 2],
+ 'ffn_ratios': [8, 8, 4, 4]}),
+ **dict.fromkeys(['b', 'base'],
+ {'embed_dims': [64, 160, 320, 512],
+ 'depths': [4, 6, 22, 3],
+ 'ffn_ratios': [8, 8, 4, 4]}),
+ **dict.fromkeys(['l', 'large'],
+ {'embed_dims': [64, 160, 320, 640],
+ 'depths': [4, 6, 44, 4],
+ 'ffn_ratios': [8, 8, 4, 4]}),
+ **dict.fromkeys(['xl', 'x-large'],
+ {'embed_dims': [96, 192, 480, 960],
+ 'depths': [6, 6, 44, 4],
+ 'ffn_ratios': [8, 8, 4, 4]}),
+ } # yapf: disable
+
+ def __init__(self,
+ arch='tiny',
+ patch_sizes=[3, 3, 3, 3],
+ in_channels=3,
+ drop_rate=0.,
+ drop_path_rate=0.,
+ init_value=1e-5,
+ out_indices=(3, ),
+ frozen_stages=-1,
+ norm_eval=False,
+ stem_norm_cfg=dict(type='BN', eps=1e-5),
+ conv_norm_cfg=dict(type='BN', eps=1e-5),
+ patchembed_types=['ConvEmbed', 'Conv', 'Conv', 'Conv',],
+ attn_dw_dilation=[1, 2, 3],
+ attn_channel_split=[1, 3, 4],
+ attn_act_cfg=dict(type='SiLU'),
+ attn_final_dilation=True,
+ attn_force_fp32=False,
+ block_cfgs=dict(),
+ init_cfg=None):
+
+ super().__init__(init_cfg=init_cfg)
+
+ if isinstance(arch, str):
+ arch = arch.lower()
+ assert arch in self.arch_settings, \
+ f'Unavailable arch, please choose from ' \
+ f'({set(self.arch_settings)}) or pass a dict.'
+ self.arch_settings = self.arch_settings[arch]
+ else:
+ essential_keys = {'embed_dims', 'depths', 'ffn_ratios'}
+ assert isinstance(arch, dict) and set(arch) == essential_keys, \
+ f'Custom arch needs a dict with keys {essential_keys}'
+ self.arch_settings = arch
+
+ self.embed_dims = self.arch_settings['embed_dims']
+ self.depths = self.arch_settings['depths']
+ self.ffn_ratios = self.arch_settings['ffn_ratios']
+ self.num_stages = len(self.depths)
+ self.out_indices = out_indices
+ self.norm_eval = norm_eval
+ self.attn_force_fp32 = attn_force_fp32
+ self.use_layer_norm = stem_norm_cfg['type'] == 'LN'
+ assert stem_norm_cfg['type'] in ['BN', 'SyncBN', 'LN', 'LN2d',]
+ assert len(patchembed_types) == self.num_stages
+
+ total_depth = sum(self.depths)
+ dpr = [
+ x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
+ ] # stochastic depth decay rule
+
+ cur_block_idx = 0
+ for i, depth in enumerate(self.depths):
+ if i == 0 and patchembed_types[i] == "ConvEmbed":
+ assert patch_sizes[i] <= 3
+ patch_embed = StackConvPatchEmbed(
+ in_channels=in_channels,
+ embed_dims=self.embed_dims[i],
+ kernel_size=patch_sizes[i],
+ stride=patch_sizes[i] // 2 + 1,
+ norm_cfg=conv_norm_cfg,
+ )
+ else:
+ patch_embed = ConvPatchEmbed(
+ in_channels=in_channels if i == 0 else self.embed_dims[i - 1],
+ input_size=None,
+ embed_dims=self.embed_dims[i],
+ kernel_size=patch_sizes[i],
+ stride=patch_sizes[i] // 2 + 1,
+ padding=(patch_sizes[i] // 2, patch_sizes[i] // 2),
+ norm_cfg=conv_norm_cfg)
+
+ if i == self.num_stages - 1 and not attn_final_dilation:
+ attn_dw_dilation = [1, 2, 1]
+ blocks = nn.ModuleList([
+ MogaBlock(
+ embed_dims=self.embed_dims[i],
+ ffn_ratio=self.ffn_ratios[i],
+ drop_rate=drop_rate,
+ drop_path_rate=dpr[cur_block_idx + j],
+ norm_cfg=conv_norm_cfg,
+ init_value=init_value,
+ attn_dw_dilation=attn_dw_dilation,
+ attn_channel_split=attn_channel_split,
+ attn_act_cfg=attn_act_cfg,
+ attn_force_fp32=attn_force_fp32,
+ **block_cfgs) for j in range(depth)
+ ])
+ cur_block_idx += depth
+ norm = build_norm_layer(stem_norm_cfg, self.embed_dims[i])[1]
+
+ self.add_module(f'patch_embed{i + 1}', patch_embed)
+ self.add_module(f'blocks{i + 1}', blocks)
+ self.add_module(f'norm{i + 1}', norm)
+
+ self.frozen_stages = frozen_stages
+ self._freeze_stages()
+
+ def forward(self, x):
+ outs = []
+ for i in range(self.num_stages):
+ patch_embed = getattr(self, f'patch_embed{i + 1}')
+ blocks = getattr(self, f'blocks{i + 1}')
+ norm = getattr(self, f'norm{i + 1}')
+
+ x, hw_shape = patch_embed(x)
+ for block in blocks:
+ x = block(x)
+ if self.use_layer_norm:
+ x = x.flatten(2).transpose(1, 2)
+ x = norm(x)
+                # use the stage width here; a ModuleList itself has no
+                # `out_channels` attribute
+                x = x.reshape(-1, *hw_shape, self.embed_dims[i]
+                              ).permute(0, 3, 1, 2).contiguous()
+ else:
+ x = norm(x)
+
+ if i in self.out_indices:
+ outs.append(x)
+
+ return tuple(outs)
+
+ def _freeze_stages(self):
+ for i in range(0, self.frozen_stages + 1):
+ # freeze patch embed
+ m = getattr(self, f'patch_embed{i + 1}')
+ m.eval()
+ for param in m.parameters():
+ param.requires_grad = False
+ # freeze blocks
+ m = getattr(self, f'blocks{i + 1}')
+ m.eval()
+ for param in m.parameters():
+ param.requires_grad = False
+ # freeze norm
+ m = getattr(self, f'norm{i + 1}')
+ m.eval()
+ for param in m.parameters():
+ param.requires_grad = False
+
+ def train(self, mode=True):
+ super(MogaNet, self).train(mode)
+ self._freeze_stages()
+ if mode and self.norm_eval:
+ for m in self.modules():
+                # trick: eval() only has an effect on BatchNorm layers
+ if isinstance(m, (_BatchNorm, nn.SyncBatchNorm)):
+ m.eval()
diff --git a/model-index.yml b/model-index.yml
index 3fb3d0457d6..1c903bf9739 100644
--- a/model-index.yml
+++ b/model-index.yml
@@ -33,6 +33,7 @@ Import:
- configs/mvit/metafile.yml
- configs/edgenext/metafile.yml
- configs/mobileone/metafile.yml
+ - configs/moganet/metafile.yml
- configs/efficientformer/metafile.yml
- configs/swin_transformer_v2/metafile.yml
- configs/deit3/metafile.yml
From ce54aa85ea7fb65bb8d2254bcac4b8f52f4bc550 Mon Sep 17 00:00:00 2001
From: Lupin1998 <1070535169@qq.com>
Date: Thu, 6 Jul 2023 10:21:23 +0000
Subject: [PATCH 2/2] fix bug
---
configs/moganet/metafile.yml | 29 ++++++-----------------------
mmpretrain/__init__.py | 4 ++--
2 files changed, 8 insertions(+), 25 deletions(-)
diff --git a/configs/moganet/metafile.yml b/configs/moganet/metafile.yml
index 0d6925f03a7..fd8ac971ccc 100644
--- a/configs/moganet/metafile.yml
+++ b/configs/moganet/metafile.yml
@@ -26,7 +26,7 @@ Models:
Top 1 Accuracy: 76.48
Top 5 Accuracy: 93.49
Task: Image Classification
- Weights: https://github.com/Westlake-AI/openmixup/releases/download/moganet-in1k-weights/moga_xtiny_sz224_8xb128_fp16_ep300.pth
+ Weights: https://github.com/Lupin1998/mmpretrain/releases/download/moganet-in1k-weights/moganet-xtiny_3rdparty_8xb128_in1k.pth
Config: configs/moganet/moganet-xtiny_8xb128_in1k.py
Converted From:
Weights: https://github.com/Westlake-AI/openmixup/releases/download/moganet-in1k-weights/moga_xtiny_sz224_8xb128_fp16_ep300.pth
@@ -42,7 +42,7 @@ Models:
Top 1 Accuracy: 77.24
Top 5 Accuracy: 93.51
Task: Image Classification
- Weights: https://github.com/Westlake-AI/openmixup/releases/download/moganet-in1k-weights/moga_tiny_sz224_8xb128_fp16_ep300.pth
+ Weights: https://github.com/Lupin1998/mmpretrain/releases/download/moganet-in1k-weights/moganet-tiny_3rdparty_8xb128_in1k.pth
Config: configs/moganet/moganet-tiny_8xb128_in1k.py
Converted From:
Weights: https://github.com/Westlake-AI/openmixup/releases/download/moganet-in1k-weights/moga_tiny_sz224_8xb128_fp16_ep300.pth
@@ -58,7 +58,7 @@ Models:
Top 1 Accuracy: 83.38
Top 5 Accuracy: 96.58
Task: Image Classification
- Weights: https://github.com/Westlake-AI/openmixup/releases/download/moganet-in1k-weights/moga_small_sz224_8xb128_fp16_ep300.pth
+ Weights: https://github.com/Lupin1998/mmpretrain/releases/download/moganet-in1k-weights/moganet-small_3rdparty_8xb128_in1k.pth
Config: configs/moganet/moganet-small_8xb128_in1k.py
Converted From:
Weights: https://github.com/Westlake-AI/openmixup/releases/download/moganet-in1k-weights/moga_small_sz224_8xb128_fp16_ep300.pth
@@ -74,7 +74,7 @@ Models:
Top 1 Accuracy: 84.20
Top 5 Accuracy: 96.77
Task: Image Classification
- Weights: https://github.com/Westlake-AI/openmixup/releases/download/moganet-in1k-weights/moga_base_sz224_8xb128_fp16_ep300.pth
+ Weights: https://github.com/Lupin1998/mmpretrain/releases/download/moganet-in1k-weights/moganet-base_3rdparty_8xb128_in1k.pth
Config: configs/moganet/moganet-base_8xb128_in1k.py
Converted From:
Weights: https://github.com/Westlake-AI/openmixup/releases/download/moganet-in1k-weights/moga_base_sz224_8xb128_fp16_ep300.pth
@@ -90,7 +90,7 @@ Models:
Top 1 Accuracy: 84.76
Top 5 Accuracy: 97.15
Task: Image Classification
- Weights: https://github.com/Westlake-AI/openmixup/releases/download/moganet-in1k-weights/moga_large_sz224_8xb64_accu2_ep300.pth
+ Weights: https://github.com/Lupin1998/mmpretrain/releases/download/moganet-in1k-weights/moganet-large_3rdparty_8xb128_in1k.pth
Config: configs/moganet/moganet-large_8xb128_in1k.py
Converted From:
Weights: https://github.com/Westlake-AI/openmixup/releases/download/moganet-in1k-weights/moga_large_sz224_8xb64_accu2_ep300.pth
@@ -106,25 +106,8 @@ Models:
Top 1 Accuracy: 85.11
Top 5 Accuracy: 97.38
Task: Image Classification
- Weights: https://github.com/Westlake-AI/openmixup/releases/download/moganet-in1k-weights/moga_xlarge_ema_sz224_8xb32_accu2_ep300.pth
+ Weights: https://github.com/Lupin1998/mmpretrain/releases/download/moganet-in1k-weights/moganet-xlarge_3rdparty_16xb32_in1k.pth
Config: configs/moganet/moganet-xlarge_16xb32_in1k.py
Converted From:
Weights: https://github.com/Westlake-AI/openmixup/releases/download/moganet-in1k-weights/moga_xlarge_ema_sz224_8xb32_accu2_ep300.pth
Code: https://github.com/Westlake-AI/openmixup
-
- # - Name: poolformer-m48_3rdparty_32xb128_in1k
- # Metadata:
- # FLOPs: 11801805696
- # Parameters: 73473448
- # In Collection: PoolFormer
- # Results:
- # - Dataset: ImageNet-1k
- # Metrics:
- # Top 1 Accuracy: 82.51
- # Top 5 Accuracy: 95.95
- # Task: Image Classification
- # Weights: https://download.openmmlab.com/mmclassification/v0/poolformer/poolformer-m48_3rdparty_32xb128_in1k_20220414-9378f3eb.pth
- # Config: configs/poolformer/poolformer-m48_32xb128_in1k.py
- # Converted From:
- # Weights: https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_m48.pth.tar
- # Code: https://github.com/sail-sg/poolformer
diff --git a/mmpretrain/__init__.py b/mmpretrain/__init__.py
index 3f5424699a1..1d065db1cd8 100644
--- a/mmpretrain/__init__.py
+++ b/mmpretrain/__init__.py
@@ -6,11 +6,11 @@
from .apis import * # noqa: F401, F403
from .version import __version__
-mmcv_minimum_version = '1.7.0'
+mmcv_minimum_version = '2.0.0'
mmcv_maximum_version = '2.1.0'
mmcv_version = digit_version(mmcv.__version__)
-mmengine_minimum_version = '0.7.0'
+mmengine_minimum_version = '0.8.0'
mmengine_maximum_version = '1.0.0'
mmengine_version = digit_version(mmengine.__version__)