diff --git a/.gitignore b/.gitignore index 04a725b73c7..f37fecf4af6 100644 --- a/.gitignore +++ b/.gitignore @@ -136,3 +136,10 @@ venv.bak/ # slurm *.out + +# tensorflow +*.tar.gz +checkpoint +model_params.txt +*.ckpt* +results.txt diff --git a/configs/_base_/models/efficientnet_l2.py b/configs/_base_/models/efficientnet_l2.py new file mode 100644 index 00000000000..4219c87a81a --- /dev/null +++ b/configs/_base_/models/efficientnet_l2.py @@ -0,0 +1,12 @@ +# model settings +model = dict( + type='ImageClassifier', + backbone=dict(type='EfficientNet', arch='l2'), + neck=dict(type='GlobalAveragePooling'), + head=dict( + type='LinearClsHead', + num_classes=1000, + in_channels=5504, + loss=dict(type='CrossEntropyLoss', loss_weight=1.0), + topk=(1, 5), + )) diff --git a/configs/efficientnet/README.md b/configs/efficientnet/README.md index 5742b6676a0..625b57c5a54 100644 --- a/configs/efficientnet/README.md +++ b/configs/efficientnet/README.md @@ -4,49 +4,129 @@ -## Abstract +## Introduction -Convolutional Neural Networks (ConvNets) are commonly developed at a fixed resource budget, and then scaled up for better accuracy if more resources are available. In this paper, we systematically study model scaling and identify that carefully balancing network depth, width, and resolution can lead to better performance. Based on this observation, we propose a new scaling method that uniformly scales all dimensions of depth/width/resolution using a simple yet highly effective compound coefficient. We demonstrate the effectiveness of this method on scaling up MobileNets and ResNet. To go even further, we use neural architecture search to design a new baseline network and scale it up to obtain a family of models, called EfficientNets, which achieve much better accuracy and efficiency than previous ConvNets. In particular, our EfficientNet-B7 achieves state-of-the-art 84.3% top-1 accuracy on ImageNet, while being 8.4x smaller and 6.1x faster on inference than the best existing ConvNet. Our EfficientNets also transfer well and achieve state-of-the-art accuracy on CIFAR-100 (91.7%), Flowers (98.8%), and 3 other transfer learning datasets, with an order of magnitude fewer parameters. +EfficientNets are a family of image classification models, which achieve state-of-the-art accuracy, yet being an order-of-magnitude smaller and faster than previous models. + +EfficientNets are based on AutoML and Compound Scaling. In particular, we first use [AutoML MNAS Mobile framework](https://ai.googleblog.com/2018/08/mnasnet-towards-automating-design-of.html) to develop a mobile-size baseline network, named as EfficientNet-B0; Then, we use the compound scaling method to scale up this baseline to obtain EfficientNet-B1 to B7.
+## Abstract + +
+ +Click to show the detailed Abstract + +
+Convolutional Neural Networks (ConvNets) are commonly developed at a fixed resource budget, and then scaled up for better accuracy if more resources are available. In this paper, we systematically study model scaling and identify that carefully balancing network depth, width, and resolution can lead to better performance. Based on this observation, we propose a new scaling method that uniformly scales all dimensions of depth/width/resolution using a simple yet highly effective compound coefficient. We demonstrate the effectiveness of this method on scaling up MobileNets and ResNet. To go even further, we use neural architecture search to design a new baseline network and scale it up to obtain a family of models, called EfficientNets, which achieve much better accuracy and efficiency than previous ConvNets. In particular, our EfficientNet-B7 achieves state-of-the-art 84.3% top-1 accuracy on ImageNet, while being 8.4x smaller and 6.1x faster on inference than the best existing ConvNet. Our EfficientNets also transfer well and achieve state-of-the-art accuracy on CIFAR-100 (91.7%), Flowers (98.8%), and 3 other transfer learning datasets, with an order of magnitude fewer parameters. + +
+ ## Results and models ### ImageNet-1k -In the result table, AA means trained with AutoAugment pre-processing, more details can be found in the [paper](https://arxiv.org/abs/1805.09501), and AdvProp is a method to train with adversarial examples, more details can be found in the [paper](https://arxiv.org/abs/1911.09665). +In the result table, AA means trained with AutoAugment pre-processing, more details can be found in the [paper](https://arxiv.org/abs/1805.09501); AdvProp is a method to train with adversarial examples, more details can be found in the [paper](https://arxiv.org/abs/1911.09665); RA means trained with RandAugment pre-processing, more details can be found in the [paper](https://arxiv.org/abs/1909.13719); NoisyStudent means trained with extra JFT-300M unlabeled data, more details can be found in the [paper](https://arxiv.org/abs/1911.04252); L2-475 means the same L2 architecture with input image size 475. Note: In MMClassification, we support training with AutoAugment, don't support AdvProp by now. -| Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download | -| :------------------------------: | :-------: | :------: | :-------: | :-------: | :----------------------------------------------: | :-----------------------------------------------------------------------------------: | -| EfficientNet-B0\* | 5.29 | 0.02 | 76.74 | 93.17 | [config](./efficientnet-b0_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b0_3rdparty_8xb32_in1k_20220119-a7e2a0b1.pth) | -| EfficientNet-B0 (AA)\* | 5.29 | 0.02 | 77.26 | 93.41 | [config](./efficientnet-b0_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b0_3rdparty_8xb32-aa_in1k_20220119-8d939117.pth) | -| EfficientNet-B0 (AA + AdvProp)\* | 5.29 | 0.02 | 77.53 | 93.61 | [config](./efficientnet-b0_8xb32-01norm_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b0_3rdparty_8xb32-aa-advprop_in1k_20220119-26434485.pth) | -| EfficientNet-B1\* | 7.79 | 0.03 | 78.68 | 94.28 | [config](./efficientnet-b1_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b1_3rdparty_8xb32_in1k_20220119-002556d9.pth) | -| EfficientNet-B1 (AA)\* | 7.79 | 0.03 | 79.20 | 94.42 | [config](./efficientnet-b1_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b1_3rdparty_8xb32-aa_in1k_20220119-619d8ae3.pth) | -| EfficientNet-B1 (AA + AdvProp)\* | 7.79 | 0.03 | 79.52 | 94.43 | [config](./efficientnet-b1_8xb32-01norm_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b1_3rdparty_8xb32-aa-advprop_in1k_20220119-5715267d.pth) | -| EfficientNet-B2\* | 9.11 | 0.03 | 79.64 | 94.80 | [config](./efficientnet-b2_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b2_3rdparty_8xb32_in1k_20220119-ea374a30.pth) | -| EfficientNet-B2 (AA)\* | 9.11 | 0.03 | 80.21 | 94.96 | [config](./efficientnet-b2_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b2_3rdparty_8xb32-aa_in1k_20220119-dd61e80b.pth) | -| EfficientNet-B2 (AA + AdvProp)\* | 9.11 | 0.03 | 80.45 | 95.07 | [config](./efficientnet-b2_8xb32-01norm_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b2_3rdparty_8xb32-aa-advprop_in1k_20220119-1655338a.pth) | -| EfficientNet-B3\* | 12.23 | 0.06 | 81.01 | 95.34 | [config](./efficientnet-b3_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b3_3rdparty_8xb32_in1k_20220119-4b4d7487.pth) | -| EfficientNet-B3 (AA)\* | 12.23 | 0.06 | 81.58 | 95.67 | [config](./efficientnet-b3_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b3_3rdparty_8xb32-aa_in1k_20220119-5b4887a0.pth) | -| EfficientNet-B3 (AA + AdvProp)\* | 12.23 | 0.06 | 81.81 | 95.69 | [config](./efficientnet-b3_8xb32-01norm_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b3_3rdparty_8xb32-aa-advprop_in1k_20220119-53b41118.pth) | -| EfficientNet-B4\* | 19.34 | 0.12 | 82.57 | 96.09 | [config](./efficientnet-b4_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b4_3rdparty_8xb32_in1k_20220119-81fd4077.pth) | -| EfficientNet-B4 (AA)\* | 19.34 | 0.12 | 82.95 | 96.26 | [config](./efficientnet-b4_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b4_3rdparty_8xb32-aa_in1k_20220119-45b8bd2b.pth) | -| EfficientNet-B4 (AA + AdvProp)\* | 19.34 | 0.12 | 83.25 | 96.44 | [config](./efficientnet-b4_8xb32-01norm_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b4_3rdparty_8xb32-aa-advprop_in1k_20220119-38c2238c.pth) | -| EfficientNet-B5\* | 30.39 | 0.24 | 83.18 | 96.47 | [config](./efficientnet-b5_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b5_3rdparty_8xb32_in1k_20220119-e9814430.pth) | -| EfficientNet-B5 (AA)\* | 30.39 | 0.24 | 83.82 | 96.76 | [config](./efficientnet-b5_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b5_3rdparty_8xb32-aa_in1k_20220119-2cab8b78.pth) | -| EfficientNet-B5 (AA + AdvProp)\* | 30.39 | 0.24 | 84.21 | 96.98 | [config](./efficientnet-b5_8xb32-01norm_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b5_3rdparty_8xb32-aa-advprop_in1k_20220119-f57a895a.pth) | -| EfficientNet-B6 (AA)\* | 43.04 | 0.41 | 84.05 | 96.82 | [config](./efficientnet-b6_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b6_3rdparty_8xb32-aa_in1k_20220119-45b03310.pth) | -| EfficientNet-B6 (AA + AdvProp)\* | 43.04 | 0.41 | 84.74 | 97.14 | [config](./efficientnet-b6_8xb32-01norm_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b6_3rdparty_8xb32-aa-advprop_in1k_20220119-bfe3485e.pth) | -| EfficientNet-B7 (AA)\* | 66.35 | 0.72 | 84.38 | 96.88 | [config](./efficientnet-b7_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b7_3rdparty_8xb32-aa_in1k_20220119-bf03951c.pth) | -| EfficientNet-B7 (AA + AdvProp)\* | 66.35 | 0.72 | 85.14 | 97.23 | [config](./efficientnet-b7_8xb32-01norm_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b7_3rdparty_8xb32-aa-advprop_in1k_20220119-c6dbff10.pth) | -| EfficientNet-B8 (AA + AdvProp)\* | 87.41 | 1.09 | 85.38 | 97.28 | [config](./efficientnet-b8_8xb32-01norm_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b8_3rdparty_8xb32-aa-advprop_in1k_20220119-297ce1b7.pth) | - -*Models with * are converted from the [official repo](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet). The config files of these models are only for inference. We don't ensure these config files' training accuracy and welcome you to contribute your reproduction results.* +| Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download | +| :---------------------------------------: | :-------: | :------: | :-------: | :-------: | :----------------------------------------------: | :--------------------------------------------------------------------------: | +| EfficientNet-B0\* | 5.29 | 0.42 | 76.74 | 93.17 | [config](./efficientnet-b0_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b0_3rdparty_8xb32_in1k_20220119-a7e2a0b1.pth) | +| EfficientNet-B0 (AA)\* | 5.29 | 0.42 | 77.26 | 93.41 | [config](./efficientnet-b0_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b0_3rdparty_8xb32-aa_in1k_20220119-8d939117.pth) | +| EfficientNet-B0 (AA + AdvProp)\* | 5.29 | 0.42 | 77.53 | 93.61 | [config](./efficientnet-b0_8xb32-01norm_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b0_3rdparty_8xb32-aa-advprop_in1k_20220119-26434485.pth) | +| EfficientNet-B0 (RA + NoisyStudent)\* | 5.29 | 0.42 | 77.63 | 94.00 | [config](./efficientnet-b0_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b0_3rdparty-ra-noisystudent_in1k_20221103-75cd08d3.pth) | +| EfficientNet-B1\* | 7.79 | 0.74 | 78.68 | 94.28 | [config](./efficientnet-b1_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b1_3rdparty_8xb32_in1k_20220119-002556d9.pth) | +| EfficientNet-B1 (AA)\* | 7.79 | 0.74 | 79.20 | 94.42 | [config](./efficientnet-b1_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b1_3rdparty_8xb32-aa_in1k_20220119-619d8ae3.pth) | +| EfficientNet-B1 (AA + AdvProp)\* | 7.79 | 0.74 | 79.52 | 94.43 | [config](./efficientnet-b1_8xb32-01norm_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b1_3rdparty_8xb32-aa-advprop_in1k_20220119-5715267d.pth) | +| EfficientNet-B1 (RA + NoisyStudent)\* | 7.79 | 0.74 | 81.44 | 95.83 | [config](./efficientnet-b1_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b1_3rdparty-ra-noisystudent_in1k_20221103-756bcbc0.pth) | +| EfficientNet-B2\* | 9.11 | 1.07 | 79.64 | 94.80 | [config](./efficientnet-b2_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b2_3rdparty_8xb32_in1k_20220119-ea374a30.pth) | +| EfficientNet-B2 (AA)\* | 9.11 | 1.07 | 80.21 | 94.96 | [config](./efficientnet-b2_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b2_3rdparty_8xb32-aa_in1k_20220119-dd61e80b.pth) | +| EfficientNet-B2 (AA + AdvProp)\* | 9.11 | 1.07 | 80.45 | 95.07 | [config](./efficientnet-b2_8xb32-01norm_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b2_3rdparty_8xb32-aa-advprop_in1k_20220119-1655338a.pth) | +| EfficientNet-B2 (RA + NoisyStudent)\* | 9.11 | 1.07 | 82.47 | 96.23 | [config](./efficientnet-b2_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b1_3rdparty-ra-noisystudent_in1k_20221103-756bcbc0.pth) | +| EfficientNet-B3\* | 12.23 | 1.07 | 81.01 | 95.34 | [config](./efficientnet-b3_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b3_3rdparty_8xb32_in1k_20220119-4b4d7487.pth) | +| EfficientNet-B3 (AA)\* | 12.23 | 1.07 | 81.58 | 95.67 | [config](./efficientnet-b3_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b3_3rdparty_8xb32-aa_in1k_20220119-5b4887a0.pth) | +| EfficientNet-B3 (AA + AdvProp)\* | 12.23 | 1.07 | 81.81 | 95.69 | [config](./efficientnet-b3_8xb32-01norm_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b3_3rdparty_8xb32-aa-advprop_in1k_20220119-53b41118.pth) | +| EfficientNet-B3 (RA + NoisyStudent)\* | 12.23 | 1.07 | 84.02 | 96.89 | [config](./efficientnet-b3_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b3_3rdparty-ra-noisystudent_in1k_20221103-a4ab5fd6.pth) | +| EfficientNet-B4\* | 19.34 | 1.95 | 82.57 | 96.09 | [config](./efficientnet-b4_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b4_3rdparty_8xb32_in1k_20220119-81fd4077.pth) | +| EfficientNet-B4 (AA)\* | 19.34 | 1.95 | 82.95 | 96.26 | [config](./efficientnet-b4_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b4_3rdparty_8xb32-aa_in1k_20220119-45b8bd2b.pth) | +| EfficientNet-B4 (AA + AdvProp)\* | 19.34 | 1.95 | 83.25 | 96.44 | [config](./efficientnet-b4_8xb32-01norm_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b4_3rdparty_8xb32-aa-advprop_in1k_20220119-38c2238c.pth) | +| EfficientNet-B4 (RA + NoisyStudent)\* | 19.34 | 1.95 | 85.25 | 97.52 | [config](./efficientnet-b4_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b4_3rdparty-ra-noisystudent_in1k_20221103-16ba8a2d.pth) | +| EfficientNet-B5\* | 30.39 | 10.1 | 83.18 | 96.47 | [config](./efficientnet-b5_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b5_3rdparty_8xb32_in1k_20220119-e9814430.pth) | +| EfficientNet-B5 (AA)\* | 30.39 | 10.1 | 83.82 | 96.76 | [config](./efficientnet-b5_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b5_3rdparty_8xb32-aa_in1k_20220119-2cab8b78.pth) | +| EfficientNet-B5 (AA + AdvProp)\* | 30.39 | 10.1 | 84.21 | 96.98 | [config](./efficientnet-b5_8xb32-01norm_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b5_3rdparty_8xb32-aa-advprop_in1k_20220119-f57a895a.pth) | +| EfficientNet-B5 (RA + NoisyStudent)\* | 30.39 | 10.1 | 86.08 | 97.75 | [config](./efficientnet-b5_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b5_3rdparty-ra-noisystudent_in1k_20221103-111a185f.pth) | +| EfficientNet-B6 (AA)\* | 43.04 | 20.0 | 84.05 | 96.82 | [config](./efficientnet-b6_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b6_3rdparty_8xb32-aa_in1k_20220119-45b03310.pth) | +| EfficientNet-B6 (AA + AdvProp)\* | 43.04 | 20.0 | 84.74 | 97.14 | [config](./efficientnet-b6_8xb32-01norm_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b6_3rdparty_8xb32-aa-advprop_in1k_20220119-bfe3485e.pth) | +| EfficientNet-B6 (RA + NoisyStudent)\* | 43.04 | 20.0 | 86.47 | 97.87 | [config](./efficientnet-b6_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b6_3rdparty-ra-noisystudent_in1k_20221103-7de7d2cc.pth) | +| EfficientNet-B7 (AA)\* | 66.35 | 39.3 | 84.38 | 96.88 | [config](./efficientnet-b7_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b7_3rdparty_8xb32-aa_in1k_20220119-bf03951c.pth) | +| EfficientNet-B7 (AA + AdvProp)\* | 66.35 | 39.3 | 85.14 | 97.23 | [config](./efficientnet-b7_8xb32-01norm_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b7_3rdparty_8xb32-aa-advprop_in1k_20220119-c6dbff10.pth) | +| EfficientNet-B7 (RA + NoisyStudent)\* | 66.35 | 65.0 | 86.83 | 98.08 | [config](./efficientnet-b7_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b7_3rdparty-ra-noisystudent_in1k_20221103-a82894bc.pth) | +| EfficientNet-B8 (AA + AdvProp)\* | 87.41 | 65.0 | 85.38 | 97.28 | [config](./efficientnet-b8_8xb32-01norm_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b8_3rdparty_8xb32-aa-advprop_in1k_20220119-297ce1b7.pth) | +| EfficientNet-L2-475 (RA + NoisyStudent)\* | 480.30 | 174.20 | 88.18 | 98.55 | [config](./efficientnet-l2-475_8xb8_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-l2_3rdparty-ra-noisystudent_in1k-475px_20221103-5a0d8058.pth) | +| EfficientNet-L2 (RA + NoisyStudent)\* | 480.30 | 484.98 | 88.33 | 98.65 | [config](./efficientnet-l2_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-l2_3rdparty-ra-noisystudent_in1k_20221103-be73be13.pth) | + +*Models with * are converted from the [official repo](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet). The config files of these models +are only for inference. We don't ensure these config files' training accuracy and welcome you to contribute your reproduction results.* + +## How to use it? + + + +**Predict image** + +```python +>>> import torch +>>> from mmcls.apis import init_model, inference_model +>>> +>>> model = init_model('configs/efficientnet/efficientnet-b0_8xb32_in1k.py', "https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b0_3rdparty_8xb32_in1k_20220119-a7e2a0b1.pth") +>>> predict = inference_model(model, 'demo/demo.JPEG') +>>> print(predict['pred_class']) +sea snake +>>> print(predict['pred_score']) +0.6968820691108704 +``` + +**Use the model** + +```python +>>> import torch +>>> from mmcls.apis import init_model +>>> +>>> model = init_model('configs/efficientnet/efficientnet-b0_8xb32_in1k.py', "https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b0_3rdparty_8xb32_in1k_20220119-a7e2a0b1.pth") +>>> inputs = torch.rand(1, 3, 224, 224).to(model.data_preprocessor.device) +>>> # To get classification scores. +>>> out = model(inputs) +>>> print(out.shape) +torch.Size([1, 1000]) +>>> # To extract features. +>>> outs = model.extract_feat(inputs) +>>> print(outs[0].shape) +torch.Size([1, 1280]) +``` + +**Train/Test Command** + +Place the ImageNet dataset to the `data/imagenet/` directory, or prepare datasets according to the [docs](https://mmclassification.readthedocs.io/en/1.x/user_guides/dataset_prepare.html#prepare-dataset). + +Train: + +```shell +python tools/train.py configs/efficientnet/efficientnet-b0_8xb32_in1k.py +``` + +Test: + +```shell +python tools/test.py configs/efficientnet/efficientnet-b0_8xb32_in1k.py https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b0_3rdparty_8xb32_in1k_20220119-a7e2a0b1.pth +``` + + + +For more configurable parameters, please refer to the [API](https://mmclassification.readthedocs.io/en/1.x/api/generated/mmcls.models.backbones.EfficientNet.html#mmcls.models.backbones.EfficientNet). ## Citation diff --git a/configs/efficientnet/efficientnet-l2_8xb32_in1k-475px.py b/configs/efficientnet/efficientnet-l2_8xb32_in1k-475px.py new file mode 100644 index 00000000000..010f808bdd6 --- /dev/null +++ b/configs/efficientnet/efficientnet-l2_8xb32_in1k-475px.py @@ -0,0 +1,24 @@ +_base_ = [ + '../_base_/models/efficientnet_l2.py', + '../_base_/datasets/imagenet_bs32.py', + '../_base_/schedules/imagenet_bs256.py', + '../_base_/default_runtime.py', +] + +# dataset settings +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='EfficientNetRandomCrop', scale=475), + dict(type='RandomFlip', prob=0.5, direction='horizontal'), + dict(type='PackClsInputs'), +] + +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='EfficientNetCenterCrop', crop_size=475), + dict(type='PackClsInputs'), +] + +train_dataloader = dict(dataset=dict(pipeline=train_pipeline)) +val_dataloader = dict(dataset=dict(pipeline=test_pipeline)) +test_dataloader = dict(dataset=dict(pipeline=test_pipeline)) diff --git a/configs/efficientnet/efficientnet-l2_8xb8_in1k.py b/configs/efficientnet/efficientnet-l2_8xb8_in1k.py new file mode 100644 index 00000000000..ac4664c6bf1 --- /dev/null +++ b/configs/efficientnet/efficientnet-l2_8xb8_in1k.py @@ -0,0 +1,24 @@ +_base_ = [ + '../_base_/models/efficientnet_l2.py', + '../_base_/datasets/imagenet_bs32.py', + '../_base_/schedules/imagenet_bs256.py', + '../_base_/default_runtime.py', +] + +# dataset settings +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='EfficientNetRandomCrop', scale=800), + dict(type='RandomFlip', prob=0.5, direction='horizontal'), + dict(type='PackClsInputs'), +] + +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='EfficientNetCenterCrop', crop_size=800), + dict(type='PackClsInputs'), +] + +train_dataloader = dict(batch_size=8, dataset=dict(pipeline=train_pipeline)) +val_dataloader = dict(batch_size=8, dataset=dict(pipeline=test_pipeline)) +test_dataloader = dict(batch_size=8, dataset=dict(pipeline=test_pipeline)) diff --git a/configs/efficientnet/metafile.yml b/configs/efficientnet/metafile.yml index 66e73e41c2d..aeb5488e0d6 100644 --- a/configs/efficientnet/metafile.yml +++ b/configs/efficientnet/metafile.yml @@ -23,7 +23,7 @@ Collections: Models: - Name: efficientnet-b0_3rdparty_8xb32_in1k Metadata: - FLOPs: 16481180 + FLOPs: 420592480 Parameters: 5288548 In Collection: EfficientNet Results: @@ -39,7 +39,7 @@ Models: Code: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet - Name: efficientnet-b0_3rdparty_8xb32-aa_in1k Metadata: - FLOPs: 16481180 + FLOPs: 420592480 Parameters: 5288548 In Collection: EfficientNet Results: @@ -55,7 +55,7 @@ Models: Code: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet - Name: efficientnet-b0_3rdparty_8xb32-aa-advprop_in1k Metadata: - FLOPs: 16481180 + FLOPs: 420592480 Parameters: 5288548 In Collection: EfficientNet Results: @@ -69,9 +69,25 @@ Models: Converted From: Weights: https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/advprop/efficientnet-b0.tar.gz Code: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet + - Name: efficientnet-b0_3rdparty-ra-noisystudent_in1k + Metadata: + FLOPs: 420592480 + Parameters: 5288548 + In Collection: EfficientNet + Results: + - Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 77.63 + Top 5 Accuracy: 94.00 + Task: Image Classification + Weights: https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b0_3rdparty-ra-noisystudent_in1k_20221103-75cd08d3.pth + Config: configs/efficientnet/efficientnet-b0_8xb32_in1k.py + Converted From: + Weights: https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/noisystudent/noisy_student_efficientnet-b0.tar.gz + Code: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet - Name: efficientnet-b1_3rdparty_8xb32_in1k Metadata: - FLOPs: 27052224 + FLOPs: 744059920 Parameters: 7794184 In Collection: EfficientNet Results: @@ -87,7 +103,7 @@ Models: Code: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet - Name: efficientnet-b1_3rdparty_8xb32-aa_in1k Metadata: - FLOPs: 27052224 + FLOPs: 744059920 Parameters: 7794184 In Collection: EfficientNet Results: @@ -103,7 +119,7 @@ Models: Code: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet - Name: efficientnet-b1_3rdparty_8xb32-aa-advprop_in1k Metadata: - FLOPs: 27052224 + FLOPs: 744059920 Parameters: 7794184 In Collection: EfficientNet Results: @@ -117,9 +133,25 @@ Models: Converted From: Weights: https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/advprop/efficientnet-b1.tar.gz Code: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet + - Name: efficientnet-b1_3rdparty-ra-noisystudent_in1k + Metadata: + FLOPs: 744059920 + Parameters: 7794184 + In Collection: EfficientNet + Results: + - Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 81.44 + Top 5 Accuracy: 95.83 + Task: Image Classification + Weights: https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b1_3rdparty-ra-noisystudent_in1k_20221103-756bcbc0.pth + Config: configs/efficientnet/efficientnet-b1_8xb32_in1k.py + Converted From: + Weights: https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/noisystudent/noisy_student_efficientnet-b1.tar.gz + Code: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet - Name: efficientnet-b2_3rdparty_8xb32_in1k Metadata: - FLOPs: 34346386 + FLOPs: 1066620392 Parameters: 9109994 In Collection: EfficientNet Results: @@ -135,7 +167,7 @@ Models: Code: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet - Name: efficientnet-b2_3rdparty_8xb32-aa_in1k Metadata: - FLOPs: 34346386 + FLOPs: 1066620392 Parameters: 9109994 In Collection: EfficientNet Results: @@ -151,7 +183,7 @@ Models: Code: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet - Name: efficientnet-b2_3rdparty_8xb32-aa-advprop_in1k Metadata: - FLOPs: 34346386 + FLOPs: 1066620392 Parameters: 9109994 In Collection: EfficientNet Results: @@ -165,9 +197,25 @@ Models: Converted From: Weights: https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/advprop/efficientnet-b2.tar.gz Code: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet + - Name: efficientnet-b2_3rdparty-ra-noisystudent_in1k + Metadata: + FLOPs: 1066620392 + Parameters: 9109994 + In Collection: EfficientNet + Results: + - Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 82.47 + Top 5 Accuracy: 96.23 + Task: Image Classification + Weights: https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b2_3rdparty-ra-noisystudent_in1k_20221103-301ed299.pth + Config: configs/efficientnet/efficientnet-b2_8xb32_in1k.py + Converted From: + Weights: https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/noisystudent/noisy_student_efficientnet-b2.tar.gz + Code: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet - Name: efficientnet-b3_3rdparty_8xb32_in1k Metadata: - FLOPs: 58641904 + FLOPs: 1953798216 Parameters: 12233232 In Collection: EfficientNet Results: @@ -183,7 +231,7 @@ Models: Code: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet - Name: efficientnet-b3_3rdparty_8xb32-aa_in1k Metadata: - FLOPs: 58641904 + FLOPs: 1953798216 Parameters: 12233232 In Collection: EfficientNet Results: @@ -199,7 +247,7 @@ Models: Code: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet - Name: efficientnet-b3_3rdparty_8xb32-aa-advprop_in1k Metadata: - FLOPs: 58641904 + FLOPs: 1953798216 Parameters: 12233232 In Collection: EfficientNet Results: @@ -213,9 +261,25 @@ Models: Converted From: Weights: https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/advprop/efficientnet-b3.tar.gz Code: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet + - Name: efficientnet-b3_3rdparty-ra-noisystudent_in1k + Metadata: + FLOPs: 1953798216 + Parameters: 12233232 + In Collection: EfficientNet + Results: + - Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 84.02 + Top 5 Accuracy: 96.89 + Task: Image Classification + Weights: https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b3_3rdparty-ra-noisystudent_in1k_20221103-a4ab5fd6.pth + Config: configs/efficientnet/efficientnet-b3_8xb32_in1k.py + Converted From: + Weights: https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/noisystudent/noisy_student_efficientnet-b3.tar.gz + Code: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet - Name: efficientnet-b4_3rdparty_8xb32_in1k Metadata: - FLOPs: 121870624 + FLOPs: 4659080176 Parameters: 19341616 In Collection: EfficientNet Results: @@ -231,7 +295,7 @@ Models: Code: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet - Name: efficientnet-b4_3rdparty_8xb32-aa_in1k Metadata: - FLOPs: 121870624 + FLOPs: 4659080176 Parameters: 19341616 In Collection: EfficientNet Results: @@ -247,7 +311,7 @@ Models: Code: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet - Name: efficientnet-b4_3rdparty_8xb32-aa-advprop_in1k Metadata: - FLOPs: 121870624 + FLOPs: 4659080176 Parameters: 19341616 In Collection: EfficientNet Results: @@ -261,9 +325,25 @@ Models: Converted From: Weights: https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/advprop/efficientnet-b4.tar.gz Code: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet + - Name: efficientnet-b4_3rdparty-ra-noisystudent_in1k + Metadata: + FLOPs: 4659080176 + Parameters: 19341616 + In Collection: EfficientNet + Results: + - Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 85.25 + Top 5 Accuracy: 97.52 + Task: Image Classification + Weights: https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b4_3rdparty-ra-noisystudent_in1k_20221103-16ba8a2d.pth + Config: configs/efficientnet/efficientnet-b4_8xb32_in1k.py + Converted From: + Weights: https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/noisystudent/noisy_student_efficientnet-b4.tar.gz + Code: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet - Name: efficientnet-b5_3rdparty_8xb32_in1k Metadata: - FLOPs: 243879440 + FLOPs: 10799472560 Parameters: 30389784 In Collection: EfficientNet Results: @@ -279,7 +359,7 @@ Models: Code: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet - Name: efficientnet-b5_3rdparty_8xb32-aa_in1k Metadata: - FLOPs: 243879440 + FLOPs: 10799472560 Parameters: 30389784 In Collection: EfficientNet Results: @@ -295,7 +375,7 @@ Models: Code: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet - Name: efficientnet-b5_3rdparty_8xb32-aa-advprop_in1k Metadata: - FLOPs: 243879440 + FLOPs: 10799472560 Parameters: 30389784 In Collection: EfficientNet Results: @@ -309,9 +389,25 @@ Models: Converted From: Weights: https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/advprop/efficientnet-b5.tar.gz Code: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet + - Name: efficientnet-b5_3rdparty-ra-noisystudent_in1k + Metadata: + FLOPs: 10799472560 + Parameters: 30389784 + In Collection: EfficientNet + Results: + - Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 86.08 + Top 5 Accuracy: 97.75 + Task: Image Classification + Weights: https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b5_3rdparty-ra-noisystudent_in1k_20221103-111a185f.pth + Config: configs/efficientnet/efficientnet-b5_8xb32_in1k.py + Converted From: + Weights: https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/noisystudent/noisy_student_efficientnet-b5.tar.gz + Code: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet - Name: efficientnet-b6_3rdparty_8xb32-aa_in1k Metadata: - FLOPs: 412002408 + FLOPs: 19971777560 Parameters: 43040704 In Collection: EfficientNet Results: @@ -327,7 +423,7 @@ Models: Code: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet - Name: efficientnet-b6_3rdparty_8xb32-aa-advprop_in1k Metadata: - FLOPs: 412002408 + FLOPs: 19971777560 Parameters: 43040704 In Collection: EfficientNet Results: @@ -341,9 +437,25 @@ Models: Converted From: Weights: https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/advprop/efficientnet-b6.tar.gz Code: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet + - Name: efficientnet-b6_3rdparty-ra-noisystudent_in1k + Metadata: + FLOPs: 19971777560 + Parameters: 43040704 + In Collection: EfficientNet + Results: + - Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 86.47 + Top 5 Accuracy: 97.87 + Task: Image Classification + Weights: https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b6_3rdparty-ra-noisystudent_in1k_20221103-7de7d2cc.pth + Config: configs/efficientnet/efficientnet-b6_8xb32_in1k.py + Converted From: + Weights: https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/noisystudent/noisy_student_efficientnet-b6.tar.gz + Code: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet - Name: efficientnet-b7_3rdparty_8xb32-aa_in1k Metadata: - FLOPs: 715526512 + FLOPs: 39316473392 Parameters: 66347960 In Collection: EfficientNet Results: @@ -359,7 +471,7 @@ Models: Code: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet - Name: efficientnet-b7_3rdparty_8xb32-aa-advprop_in1k Metadata: - FLOPs: 715526512 + FLOPs: 39316473392 Parameters: 66347960 In Collection: EfficientNet Results: @@ -373,9 +485,25 @@ Models: Converted From: Weights: https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/advprop/efficientnet-b7.tar.gz Code: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet + - Name: efficientnet-b7_3rdparty-ra-noisystudent_in1k + Metadata: + FLOPs: 39316473392 + Parameters: 66347960 + In Collection: EfficientNet + Results: + - Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 86.83 + Top 5 Accuracy: 98.08 + Task: Image Classification + Weights: https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b7_3rdparty-ra-noisystudent_in1k_20221103-a82894bc.pth + Config: configs/efficientnet/efficientnet-b7_8xb32_in1k.py + Converted From: + Weights: https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/noisystudent/noisy_student_efficientnet-b7.tar.gz + Code: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet - Name: efficientnet-b8_3rdparty_8xb32-aa-advprop_in1k Metadata: - FLOPs: 1092755326 + FLOPs: 64999827816 Parameters: 87413142 In Collection: EfficientNet Results: @@ -389,3 +517,35 @@ Models: Converted From: Weights: https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/advprop/efficientnet-b8.tar.gz Code: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet + - Name: efficientnet-l2_3rdparty-ra-noisystudent_in1k + Metadata: + FLOPs: 174203533416 + Parameters: 480309308 + In Collection: EfficientNet + Results: + - Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 88.33 + Top 5 Accuracy: 98.65 + Task: Image Classification + Weights: https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-l2_3rdparty-ra-noisystudent_in1k_20221103-be73be13.pth + Config: configs/efficientnet/efficientnet-l2_8xb8_in1k.py + Converted From: + Weights: https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/noisystudent/noisy_student_efficientnet-l2.tar.gz + Code: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet + - Name: efficientnet-l2_3rdparty-ra-noisystudent_in1k-475px + Metadata: + FLOPs: 484984099280 + Parameters: 480309308 + In Collection: EfficientNet + Results: + - Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 88.18 + Top 5 Accuracy: 98.55 + Task: Image Classification + Weights: https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-l2_3rdparty-ra-noisystudent_in1k-475px_20221103-5a0d8058.pth + Config: configs/efficientnet/efficientnet-l2_8xb32_in1k-475px.py + Converted From: + Weights: https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/noisystudent/noisy_student_efficientnet-l2_475.tar.gz + Code: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet diff --git a/mmcls/models/backbones/efficientnet.py b/mmcls/models/backbones/efficientnet.py index 55cf2346db1..be0b08a218d 100644 --- a/mmcls/models/backbones/efficientnet.py +++ b/mmcls/models/backbones/efficientnet.py @@ -246,6 +246,7 @@ class EfficientNet(BaseBackbone): 'b6': (1.8, 2.6, 528), 'b7': (2.0, 3.1, 600), 'b8': (2.2, 3.6, 672), + 'l2': (4.3, 5.3, 800), 'es': (1.0, 1.0, 224), 'em': (1.0, 1.1, 240), 'el': (1.2, 1.4, 300) @@ -273,7 +274,9 @@ def __init__(self, f'"{arch}" is not one of the arch_settings ' \ f'({", ".join(self.arch_settings.keys())})' self.arch_setting = self.arch_settings[arch] - self.layer_setting = self.layer_settings[arch[:1]] + # layer_settings of arch='l2' is 'b' + self.layer_setting = self.layer_settings['b' if arch == + 'l2' else arch[:1]] for index in out_indices: if index not in range(0, len(self.layer_setting)): raise ValueError('the item in out_indices must in ' diff --git a/tools/model_converters/efficientnet_to_mmcls.py b/tools/model_converters/efficientnet_to_mmcls.py index d1b097bd4ca..c193cc54cac 100644 --- a/tools/model_converters/efficientnet_to_mmcls.py +++ b/tools/model_converters/efficientnet_to_mmcls.py @@ -4,7 +4,7 @@ import numpy as np import torch -from mmcv.runner import Sequential +from mmengine.model import Sequential from tensorflow.python.training import py_checkpoint_reader from mmcls.models.backbones.efficientnet import EfficientNet @@ -27,7 +27,7 @@ def read_ckpt(ckpt): return weights -def map_key(weight): +def map_key(weight, l2_flag): m = dict() has_expand_conv = set() is_MBConv = set() @@ -59,6 +59,8 @@ def map_key(weight): idx2key.append('{}'.format(idx)) for k, v in weight.items(): + if l2_flag: + k = k.replace('/ExponentialMovingAverage', '') if 'Exponential' in k or 'RMS' in k: continue @@ -204,6 +206,11 @@ def map_key(weight): parser = argparse.ArgumentParser() parser.add_argument('infile', type=str, help='Path to the ckpt.') parser.add_argument('outfile', type=str, help='Output file.') + parser.add_argument( + '--l2', + action='store_true', + help='If true convert ExponentialMovingAverage weights. ' + 'l2 arch should use it.') args = parser.parse_args() assert args.outfile @@ -211,5 +218,5 @@ def map_key(weight): if not os.path.exists(outdir): os.makedirs(outdir) weights = read_ckpt(args.infile) - weights = map_key(weights) + weights = map_key(weights, args.l2) torch.save(weights, args.outfile)