diff --git a/README.md b/README.md
index 3a05f560b..fb8e17066 100644
--- a/README.md
+++ b/README.md
@@ -43,12 +43,12 @@ MindOCR is an open-source toolbox for OCR development and application based on [
The following is the corresponding `mindocr` versions and supported
mindspore versions.
-| mindocr | mindspore |
-|:-------:|:---------:|
-| master | master |
-| 0.4 | 2.3.0 |
-| 0.3 | 2.2.10 |
-| 0.1 | 1.8 |
+| mindocr | mindspore |
+|:-------:|:-----------:|
+| main | master |
+| 0.4 | 2.3.0/2.3.1 |
+| 0.3 | 2.2.10 |
+| 0.1 | 1.8 |
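+
+To double-check that your environment matches one of the rows above, a quick sanity check (assuming both packages were installed via pip):
+
+```shell
+pip show mindspore mindocr | grep -E "^(Name|Version)"
+```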
## Installation
diff --git a/README_CN.md b/README_CN.md
index 4f63714cd..9721efe0f 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -43,12 +43,12 @@ MindOCR是一个基于[MindSpore](https://www.mindspore.cn/en) 框架开发的OC
以下是对应的“mindocr”版本和支持 Mindspore 版本。
-| mindocr | mindspore |
-|:-------:|:---------:|
-| master | master |
-| 0.4 | 2.3.0 |
-| 0.3 | 2.2.10 |
-| 0.1 | 1.8 |
+| mindocr | mindspore |
+|:-------:|:-----------:|
+| main | master |
+| 0.4 | 2.3.0/2.3.1 |
+| 0.3 | 2.2.10 |
+| 0.1 | 1.8 |
## 安装教程
diff --git a/configs/det/dbnet/README.md b/configs/det/dbnet/README.md
index 062c94898..eb5ddb8ed 100644
--- a/configs/det/dbnet/README.md
+++ b/configs/det/dbnet/README.md
@@ -7,7 +7,7 @@ English | [中文](README_CN.md)
> DBNet: [Real-time Scene Text Detection with Differentiable Binarization](https://arxiv.org/abs/1911.08947)
> DBNet++: [Real-Time Scene Text Detection with Differentiable Binarization and Adaptive Scale Fusion](https://arxiv.org/abs/2202.10304)
-## 1. Introduction
+## Introduction
### DBNet
@@ -57,137 +57,22 @@ DBNet++ performs better in detecting text instances of diverse scales, especiall
DBNet may generate inaccurate or discrete bounding boxes.
-## 2. General purpose models
+## Requirements
-Here we present general purpose models that were trained on wide variety of tasks (real-world photos, street views, documents, etc.) and challenges (straight texts, curved texts, long text lines, etc.) with two primary languages: Chinese and English. These models can be used right off-the-shelf in your applications or for initialization of your models.
-
-The models were trained on 12 public datasets (CTW, LSVT, RCTW-17, TextOCR, etc.) that contain wide range of images. The training set has 153,511 images and the validation set has 9,786 images.
-The test set consists of 598 images manually selected from the above-mentioned datasets.
-
-Performance tested on ascend 910 with graph mode
-
-
-
-| **Model** | **Device Card** | **Backbone** | **Languages** | **F-score** | **Batch Size** | **Step Time** | **Recipe** | **Download** |
-|-----------|------------|--------------|-------------------|:---------------------------:|----------------|-----------------|----------------|----------------------------------------------------------------------------------------------------------|
-| DBNet | 8p | ResNet-50 | Chinese + English | 83.41% | 10 | 312.48 ms/step | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet50_ch_en_general-a5dbb141.ckpt) | [mindir](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet50_ch_en_general-a5dbb141-912f0a90.mindir) |
-| DBNet++ | 4p | ResNet-50 | Chinese + English | 84.30% | 32 | 1230.76 ms/step | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50_ch_en_general-884ba5b9.ckpt) | [mindir](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50_ch_en_general-884ba5b9-b3f52398.mindir) |
-
-
-> The input_shape for exported DBNet MindIR and DBNet++ MindIR in the links are `(1,3,736,1280)` and `(1,3,1152,2048)`, respectively.
-
-
-## 3. Results
-
-DBNet and DBNet++ were trained on the ICDAR2015, MSRA-TD500, SCUT-CTW1500, Total-Text, and MLT2017 datasets. In addition, we conducted pre-training on the SynthText dataset and provided a URL to download pretrained weights. All training results are as follows:
-
-
- Performance tested on ascend 910 with graph mode
- ### ICDAR2015
-
-
-
- | **Model** | **Device Card** | **Backbone** | **Pretrained** | **Recall** | **Precision** | **F-score** | **Batch Size** | **Step Time** | **Recipe** | **Download** |
- |---------------------|------------|---------------|----------------|------------|---------------|-------------|----------------|----------------|-------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
- | DBNet | 1p | MobileNetV3 | ImageNet | 76.31% | 78.27% | 77.28% | 10 | 100.00 ms/step | [yaml](db_mobilenetv3_icdar15.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_mobilenetv3-62c44539.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_mobilenetv3-62c44539-f14c6a13.mindir) |
- | DBNet | 8p | MobileNetV3 | ImageNet | 76.22% | 77.98% | 77.09% | 8 | 66.64 ms/step | [yaml](db_mobilenetv3_icdar15_8p.yaml) | Coming soon |
- | DBNet | 1p | ResNet-18 | ImageNet | 80.12% | 83.41% | 81.73% | 20 | 185.19 ms/step | [yaml](db_r18_icdar15.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet18-0c0c4cfa.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet18-0c0c4cfa-cf46eb8b.mindir) |
- | DBNet | 1p | ResNet-50 | ImageNet | 83.53% | 86.62% | 85.05% | 10 | 132.98 ms/step | [yaml](db_r50_icdar15.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet50-c3a4aa24.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet50-c3a4aa24-fbf95c82.mindir) |
- | DBNet | 8p | ResNet-50 | ImageNet | 82.62% | 88.54% | 85.48% | 10 | 183.92 ms/step | [yaml](db_r50_icdar15_8p.yaml) | Coming soon |
- | | | | | | | | | | | |
- | DBNet++ | 1p | ResNet-50 | SynthText | 86.81% | 86.85% | 86.86% | 32 | 409.21 ms/step | [yaml](dbpp_r50_icdar15_910.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50_910-35dc71f2.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50_910-35dc71f2-e61a9c37.mindir) |
-
-
- > The input_shape for exported DBNet MindIR and DBNet++ MindIR in the links are `(1,3,736,1280)` and `(1,3,1152,2048)`, respectively.
-
- ### MSRA-TD500
-
-
-
- | **Model** | **Device Card** | **Backbone** | **Pretrained** | **Recall** | **Precision** | **F-score** | **Batch Size** | **Step Time** | **Recipe** | **Download** |
- |-----------|------------|--------------|----------------|------------|---------------|-------------|----------------|----------------|---------------------------|-------------------------------------------------------------------------------------------------|
- | DBNet | 1p | ResNet-18 | SynthText | 79.90% | 88.07% | 83.78% | 20 | 164.34 ms/step | [yaml](db_r18_td500.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet18_td500-b5abff68.ckpt) |
- | DBNet | 1p | ResNet-50 | SynthText | 84.02% | 87.48% | 85.71% | 20 | 280.90 ms/step | [yaml](db_r50_td500.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet50_td500-0d12b5e8.ckpt) |
-
-
- > MSRA-TD500 dataset has 300 training images and 200 testing images, reference paper [Real-time Scene Text Detection with Differentiable Binarization](https://arxiv.org/abs/1911.08947), we trained using an extra 400 traning images from HUST-TR400. You can down all [dataset](https://paddleocr.bj.bcebos.com/dataset/TD_TR.tar) for training.
-
- ### SCUT-CTW1500
-
-
-
- | **Model** | **Device Card** | **Backbone** | **Pretrained** | **Recall** | **Precision** | **F-score** | **Batch Size** | **Step Time** | **Recipe** | **Download** |
- |-----------|------------|--------------|----------------|------------|---------------|-------------|----------------|----------------|-----------------------------|---------------------------------------------------------------------------------------------------|
- | DBNet | 1p | ResNet-18 | SynthText | 85.68% | 85.33% | 85.50% | 20 | 163.80 ms/step | [yaml](db_r18_ctw1500.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet18_ctw1500-0864b040.ckpt) |
- | DBNet | 1p | ResNet-50 | SynthText | 87.83% | 84.71% | 86.25% | 20 | 180.11 ms/step | [yaml](db_r50_ctw1500.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet50_ctw1500-f637e3d3.ckpt) |
-
-
- ### Total-Text
-
-
-
- | **Model** | **Device Card** | **Backbone** | **Pretrained** | **Recall** | **Precision** | **F-score** | **Batch Size** | **Step Time** | **Recipe** | **Download** |
- |-----------|------------|--------------|----------------|------------|---------------|-------------|----------------|----------------|-------------------------------|-----------------------------------------------------------------------------------------------------|
- | DBNet | 1p | ResNet-18 | SynthText | 83.66% | 87.61% | 85.59% | 20 | 206.40 ms/step | [yaml](db_r18_totaltext.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet18_totaltext-fb456ff4.ckpt) |
- | DBNet | 1p | ResNet-50 | SynthText | 84.79% | 87.07% | 85.91% | 20 | 289.44 ms/step | [yaml](db_r50_totaltext.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet50_totaltext-76d6f421.ckpt) |
-
-
- ### MLT2017
-
-
-
- | **Model** | **Device Card** | **Backbone** | **Pretrained** | **Recall** | **Precision** | **F-score** | **Batch Size** | **Step Time** | **Recipe** | **Download** |
- |-----------|------------|--------------|----------------|------------|---------------|-------------|----------------|----------------|-----------------------------|---------------------------------------------------------------------------------------------------|
- | DBNet | 8p | ResNet-18 | SynthText | 73.62% | 83.93% | 78.44% | 20 | 464.00 ms/step | [yaml](db_r18_mlt2017.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet18_mlt2017-5af33809.ckpt) |
- | DBNet | 8p | ResNet-50 | SynthText | 76.04% | 84.51% | 80.05% | 20 | 523.6 ms/step | [yaml](db_r50_mlt2017.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet50_mlt2017-3bd6e569.ckpt) |
-
-
- ### SynthText
-
-
-
- | **Model** | **Device Card** | **Backbone** | **Pretrained** | **Train Loss**| **Batch Size** | **Step Time** | **Recipe** | **Download** |
- |-------------------|------------|--------------|----------------|-------------|----------------|----------------|-------------|--------------|
- | DBNet | 1p | ResNet-18 | ImageNet | 2.41 | 16 | 131.83 ms/step | [yaml](db_r18_synthtext.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet18_synthtext-251ef3dd.ckpt) |
- | DBNet | 1p | ResNet-50 | ImageNet | 2.25 | 16 | 195.07 ms/step | [yaml](db_r50_synthtext.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet50_synthtext-40655acb.ckpt) |
-
-
-
-
-
- Performance tested on ascend 910* with graph mode
- ### ICDAR2015
-
-
-
- | **Model** | **Device Card** | **Backbone** | **Pretrained** | **Recall** | **Precision** | **F-score** | **Batch Size** | **Step Time** | **Recipe** | **Download** |
- |---------------------|-----------------|---------------|----------------|------------|---------------|-------------|----------------|----------------|-------------------------------------|------------------------------------------------------------------------------------------------------------|
- | DBNet | 1p | MobileNetV3 | ImageNet | 74.68% | 79.38% | 76.95% | 10 | 65.69 ms/step | [yaml](db_mobilenetv3_icdar15.yaml) | [ckpt](https://download-mindspore.osinfra.cn/toolkits/mindocr/dbnet/dbnet_mobilenetv3-e72f9b8b-910v2.ckpt) |
- | DBNet | 8p | MobileNetV3 | ImageNet | 76.27% | 76.06% | 76.17% | 8 | 54.46 ms/step | [yaml](db_mobilenetv3_icdar15_8p.yaml) | [ckpt](https://download-mindspore.osinfra.cn/toolkits/mindocr/dbnet/dbnet_mobilenetv3-7e89e1df-910v2.ckpt) |
- | DBNet | 1p | ResNet-50 | ImageNet | 84.50% | 85.36% | 84.93% | 10 | 155.62 ms/step | [yaml](db_r50_icdar15.yaml) | [ckpt](https://download-mindspore.osinfra.cn/toolkits/mindocr/dbnet/dbnet_resnet50-48153c3b-910v2.ckpt) |
- | DBNet | 8p | ResNet-50 | ImageNet | 81.15% | 87.63% | 84.26% | 10 | 159.22 ms/step | [yaml](db_r50_icdar15_8p.yaml) | [ckpt](https://download-mindspore.osinfra.cn/toolkits/mindocr/dbnet/dbnet_resnet50-e10bad35-910v2.ckpt) |
-
-
-
- > The input_shape for exported DBNet MindIR and DBNet++ MindIR in the links are `(1,3,736,1280)` and `(1,3,1152,2048)`, respectively.
-
-
-
-
-#### Notes
-- Note that the training time of DBNet is highly affected by data processing and varies on different machines.
+| mindspore | ascend driver | firmware | cann toolkit/kernel |
+|:----------:|:--------------:|:--------------:|:-------------------:|
+| 2.3.1 | 24.1.RC2 | 7.3.0.1.231 | 8.0.RC2.beta1 |
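+
+A quick way to inspect the Ascend driver and firmware versions on the host is the `npu-smi` tool shipped with the Ascend software stack (assuming it is on the `PATH`):
+
+```shell
+npu-smi info
+```
+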
+## Quick Start
-## 4. Quick Start
-
-### 4.1 Installation
+### Installation
Please refer to the [installation instruction](https://github.com/mindspore-lab/mindocr#installation) in MindOCR.
-### 4.2 Dataset preparation
+### Dataset preparation
-#### 4.2.1 ICDAR2015 dataset
+#### ICDAR2015 dataset
Please download [ICDAR2015](https://rrc.cvc.uab.es/?ch=4&com=downloads) dataset, and convert the labels to the desired format referring to [dataset_converters](../../../tools/dataset_converters/README.md).
@@ -209,7 +94,7 @@ The prepared dataset file struture should be:
└── train_det_gt.txt
```
-#### 4.2.2 MSRA-TD500 dataset
+#### MSRA-TD500 dataset
Please download [MSRA-TD500](http://www.iapr-tc11.org/mediawiki/index.php/MSRA_Text_Detection_500_Database_(MSRA-TD500)) dataset,and convert the labels to the desired format referring to [dataset_converters](../../../tools/dataset_converters/README.md).
@@ -233,7 +118,7 @@ MSRA-TD500
│ ├── test_det_gt.txt
```
-#### 4.2.3 SCUT-CTW1500 dataset
+#### SCUT-CTW1500 dataset
Please download [SCUT-CTW1500](https://github.com/Yuliang-Liu/Curve-Text-Detector) dataset,and convert the labels to the desired format referring to [dataset_converters](https://github.com/mindspore-lab/mindocr/blob/main/tools/dataset_converters/README.md).
@@ -253,7 +138,7 @@ ctw1500
├── train_det_gt.txt
```
-#### 4.2.4 Total-Text dataset
+#### Total-Text dataset
Please download [Total-Text](https://github.com/cs-chan/Total-Text-Dataset/tree/master/Dataset) dataset,and convert the labels to the desired format referring to [dataset_converters](https://github.com/mindspore-lab/mindocr/blob/main/tools/dataset_converters/README.md).
@@ -275,7 +160,7 @@ totaltext
├── train_det_gt.txt
```
-#### 4.2.5 MLT2017 dataset
+#### MLT2017 dataset
The MLT2017 dataset is a multilingual text detection and recognition dataset that includes nine languages: Chinese, Japanese, Korean, English, French, Arabic, Italian, German, and Hindi. Please download [MLT2017](https://rrc.cvc.uab.es/?ch=8&com=downloads) and extract the dataset. Then convert the .gif format images in the data to .jpg or .png format, and convert the labels to the desired format referring to [dataset_converters](https://github.com/mindspore-lab/mindocr/blob/main/tools/dataset_converters/README.md).
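+
+For the .gif conversion step, one possible approach (assuming ImageMagick is installed; the directory below is a placeholder for wherever the extracted images live):
+
+```shell
+# writes a .jpg copy next to every .gif
+mogrify -format jpg /path/to/MLT_2017/images/*.gif
+```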
@@ -299,7 +184,7 @@ MLT_2017
> If users want to use their own dataset for training, please convert the labels to the desired format referring to [dataset_converters](https://github.com/mindspore-lab/mindocr/blob/main/tools/dataset_converters/README.md). Then configure the yaml file, and use a single or multiple devices to run train.py for training. For detailed information, please refer to the following tutorials.
-#### 4.2.6 SynthText dataset
+#### SynthText dataset
Please download [SynthText](https://academictorrents.com/details/2dba9518166cbd141534cbf381aa3e99a087e83c) dataset and process it as described in [dataset_converters](../../../tools/dataset_converters/README.md)
@@ -329,7 +214,7 @@ Please download [SynthText](https://academictorrents.com/details/2dba9518166cbd1
> ```
> This operation will generate a filtered output in the same format as the original `SynthText`.
-### 4.3 Update yaml config file
+### Update yaml config file
Update `configs/det/dbnet/db_r50_icdar15.yaml` configuration file with data paths,
specifically the following parts. The `dataset_root` will be concatenated with `data_dir` and `label_file` respectively to be the complete dataset directory and label file path.
@@ -382,7 +267,7 @@ model:
[comment]: <> (The only difference between _DBNet_ and _DBNet++_ is in the _Adaptive Scale Fusion_ module, which is controlled by the `use_asf` parameter in the `neck` module.)
-### 4.4 Training
+### Training
* Standalone training
@@ -404,7 +289,7 @@ mpirun --allow-run-as-root -n 2 python tools/train.py --config configs/det/dbnet
The training result (including checkpoints, per-epoch performance and curves) will be saved in the directory parsed by the arg `ckpt_save_dir` in yaml config file. The default directory is `./tmp_det`.
-### 4.5 Evaluation
+### Evaluation
To evaluate the accuracy of the trained model, you can use `eval.py`. Please set the checkpoint path to the arg `ckpt_load_path` in the `eval` section of yaml config file, set `distribute` to be False, and then run:
@@ -412,42 +297,108 @@ To evaluate the accuracy of the trained model, you can use `eval.py`. Please set
python tools/eval.py -c=configs/det/dbnet/db_r50_icdar15.yaml
```
-## 5. MindSpore Lite Inference
+## Performance
-Please refer to the tutorial [MindOCR Inference](../../../docs/en/inference/inference_tutorial.md) for model inference based on MindSpot Lite on Ascend 310, including the following steps:
+### General Purpose Models
-- Model Export
+Here we present general purpose models that were trained on a wide variety of tasks (real-world photos, street views, documents, etc.) and challenges (straight texts, curved texts, long text lines, etc.) with two primary languages: Chinese and English. These models can be used off-the-shelf in your applications or as initialization for your own models.
+
-Please [download](#3-results) the exported MindIR file first, or refer to the [Model Export](../../../docs/en/inference/convert_tutorial.md#1-model-export) tutorial and use the following command to export the trained ckpt model to MindIR file:
+The models were trained on 12 public datasets (CTW, LSVT, RCTW-17, TextOCR, etc.) that contain a wide range of images. The training set has 153,511 images and the validation set has 9,786 images.
+The test set consists of 598 images manually selected from the above-mentioned datasets.
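+
+To try one of these checkpoints off-the-shelf, one option is to run the standard evaluation entry point against your own labeled data. A minimal sketch, assuming the ckpt has been downloaded locally: point `ckpt_load_path` in the `eval` section of a DBNet yaml config to the checkpoint, update the dataset paths, then run:
+
+```shell
+python tools/eval.py -c=configs/det/dbnet/db_r50_icdar15.yaml
+```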
-```shell
-python tools/export.py --model_name_or_config dbnet_resnet50 --data_shape 736 1280 --local_ckpt_path /path/to/local_ckpt.ckpt
-# or
-python tools/export.py --model_name_or_config configs/det/dbnet/db_r50_icdar15.yaml --data_shape 736 1280 --local_ckpt_path /path/to/local_ckpt.ckpt
-```
+
+Experiments are tested on Ascend 910* with MindSpore 2.3.1 in graph mode.
+
-The `data_shape` is the model input shape of height and width for MindIR file. The shape value of MindIR in the download link can be found in [ICDAR2015 Notes](#ICDAR2015).
+*coming soon*
-- Environment Installation
+
+Experiments are tested on Ascend 910 with MindSpore 2.3.1 in graph mode.
+
-Please refer to [Environment Installation](../../../docs/en/inference/environment.md) tutorial to configure the MindSpore Lite inference environment.
+*coming soon*
-- Model Conversion
-Please refer to [Model Conversion](../../../docs/en/inference/convert_tutorial.md#2-mindspore-lite-mindir-convert),
-and use the `converter_lite` tool for offline conversion of the MindIR file.
+### Specific Purpose Models
-- Inference
+DBNet and DBNet++ were trained on the ICDAR2015, MSRA-TD500, SCUT-CTW1500, Total-Text, and MLT2017 datasets. In addition, we conducted pre-training on the SynthText dataset and provided a URL to download pretrained weights. All training results are as follows:
-Assuming that you obtain output.mindir after model conversion, go to the `deploy/py_infer` directory, and use the following command for inference:
+#### ICDAR2015
+
+Experiments are tested on Ascend 910* with MindSpore 2.3.1 in graph mode.
+
+| **model name** | **backbone** | **pretrained** | **cards** | **batch size** | **jit level** | **graph compile** | **ms/step** | **img/s** | **recall** | **precision** | **f-score** | **recipe** | **weight** |
+| :------------: | :----------: | :------------: | :-------: | :------------: | :-----------: | :---------------: | :---------: | :-------: | :--------: | :-----------: | :---------: | :------------------------------------: | :--------------------------------------------------------------------------------------------------------: |
+| DBNet | MobileNetV3 | ImageNet | 1 | 10 | O2 | 403.87 s | 65.69 | 152.23 | 74.68% | 79.38% | 76.95% | [yaml](db_mobilenetv3_icdar15.yaml) | [ckpt](https://download-mindspore.osinfra.cn/toolkits/mindocr/dbnet/dbnet_mobilenetv3-e72f9b8b-910v2.ckpt) |
+| DBNet | MobileNetV3 | ImageNet | 8 | 8 | O2 | 405.35 s | 54.46 | 1175.12 | 76.27% | 76.06% | 76.17% | [yaml](db_mobilenetv3_icdar15_8p.yaml) | [ckpt](https://download-mindspore.osinfra.cn/toolkits/mindocr/dbnet/dbnet_mobilenetv3-7e89e1df-910v2.ckpt) |
+| DBNet | ResNet-50 | ImageNet | 1 | 10 | O2 | 147.81 s | 155.62 | 64.25 | 84.50% | 85.36% | 84.93% | [yaml](db_r50_icdar15.yaml) | [ckpt](https://download-mindspore.osinfra.cn/toolkits/mindocr/dbnet/dbnet_resnet50-48153c3b-910v2.ckpt) |
+| DBNet | ResNet-50 | ImageNet | 8 | 10 | O2 | 151.23 s | 159.22 | 502.4 | 81.15% | 87.63% | 84.26% | [yaml](db_r50_icdar15_8p.yaml) | [ckpt](https://download-mindspore.osinfra.cn/toolkits/mindocr/dbnet/dbnet_resnet50-e10bad35-910v2.ckpt) |
+
+> The input_shape for exported DBNet MindIR and DBNet++ MindIR in the links are `(1,3,736,1280)` and `(1,3,1152,2048)`, respectively.
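+
+These shapes are fixed when the checkpoint is exported to MindIR. For reference, a minimal export sketch using the repository's export tool (the checkpoint path is a placeholder):
+
+```shell
+# export a trained DBNet ResNet-50 checkpoint to MindIR with a 736x1280 input
+python tools/export.py --model_name_or_config dbnet_resnet50 --data_shape 736 1280 --local_ckpt_path /path/to/local_ckpt.ckpt
+```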
+
+Experiments are tested on Ascend 910 with MindSpore 2.3.1 in graph mode.
+
+#### ICDAR2015
+
+| **model name** | **backbone** | **pretrained** | **cards** | **batch size** | **jit level** | **graph compile** | **ms/step** | img/s | **recall** | **precision** | **f-score** | **recipe** | **weight** |
+| :------------: | :----------: | :------------: | :-------: | :------------: | :-----------: | :---------------: | :---------: | :---: | :--------: | :-----------: | :---------: | :------------------------------------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
+| DBNet | MobileNetV3 | ImageNet | 1 | 10 | O2 | 321.15 s | 100 | 100 | 76.31% | 78.27% | 77.28% | [yaml](db_mobilenetv3_icdar15.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_mobilenetv3-62c44539.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_mobilenetv3-62c44539-f14c6a13.mindir) |
+| DBNet | MobileNetV3 | ImageNet | 8 | 8 | O2 | 309.39 s | 66.64 | 960 | 76.22% | 77.98% | 77.09% | [yaml](db_mobilenetv3_icdar15_8p.yaml) | Coming soon |
+| DBNet | ResNet-18 | ImageNet | 1 | 20 | O2 | 75.23 s | 185.19 | 108 | 80.12% | 83.41% | 81.73% | [yaml](db_r18_icdar15.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet18-0c0c4cfa.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet18-0c0c4cfa-cf46eb8b.mindir) |
+| DBNet | ResNet-50 | ImageNet | 1 | 10 | O2 | 110.54 s | 132.98 | 75.2 | 83.53% | 86.62% | 85.05% | [yaml](db_r50_icdar15.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet50-c3a4aa24.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet50-c3a4aa24-fbf95c82.mindir) |
+| DBNet | ResNet-50 | ImageNet | 8 | 10 | O2 | 107.91 s | 183.92 | 435 | 82.62% | 88.54% | 85.48% | [yaml](db_r50_icdar15_8p.yaml) | Coming soon |
+| DBNet++ | ResNet-50 | SynthText | 1 | 32 | O2 | 184.74 s | 409.21 | 78.2 | 86.81% | 86.85% | 86.86% | [yaml](dbpp_r50_icdar15_910.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50_910-35dc71f2.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50_910-35dc71f2-e61a9c37.mindir) |
+
+> The input_shape for exported DBNet MindIR and DBNet++ MindIR in the links are `(1,3,736,1280)` and `(1,3,1152,2048)`, respectively.
+
+#### MSRA-TD500
+
+
+| **model name** | **backbone** | **pretrained** | **cards** | **batch size** | **jit level** | **graph compile** | **ms/step** | **img/s** | **recall** | **precision** | **f-score** | **recipe** | **weight** |
+| :------------: | :----------: | :------------: | :-------: | :------------: | :-----------: | :---------------: | :---------: | :-------: | :--------: | :-----------: | :---------: | :-----------------------: | :---------------------------------------------------------------------------------------------: |
+| DBNet | ResNet-18 | SynthText | 1 | 20 | O2 | 76.18 s | 163.34 | 121.7 | 79.90% | 88.07% | 83.78% | [yaml](db_r18_td500.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet18_td500-b5abff68.ckpt) |
+| DBNet | ResNet-50 | SynthText | 1 | 20 | O2 | 108.45 s | 280.90 | 71.2 | 84.02% | 87.48% | 85.71% | [yaml](db_r50_td500.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet50_td500-0d12b5e8.ckpt) |
+
+
+> The MSRA-TD500 dataset has 300 training images and 200 testing images. Following the reference paper [Real-time Scene Text Detection with Differentiable Binarization](https://arxiv.org/abs/1911.08947), we trained with an extra 400 training images from HUST-TR400. You can download the combined [dataset](https://paddleocr.bj.bcebos.com/dataset/TD_TR.tar) for training.
+
+#### SCUT-CTW1500
+
+| **model name** | **backbone** | **pretrained** | **cards** | **batch size** | **jit level** | **graph compile** | **ms/step** | **img/s** | **recall** | **precision** | **f-score** | **recipe** | **weight** |
+|----------------|--------------|----------------|-----------|----------------|---------------|-------------------|-------------|-----------|------------|---------------|-------------|-----------------------------|---------------------------------------------------------------------------------------------------|
+| DBNet | ResNet-18 | SynthText | 1 | 20 | O2 | 73.18 s | 163.80 | 122.1 | 85.68% | 85.33% | 85.50% | [yaml](db_r18_ctw1500.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet18_ctw1500-0864b040.ckpt) |
+| DBNet | ResNet-50 | SynthText | 1 | 20 | O2 | 110.34 s | 180.11 | 71.4 | 87.83% | 84.71% | 86.25% | [yaml](db_r50_ctw1500.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet50_ctw1500-f637e3d3.ckpt) |
+
+
+#### Total-Text
+
+
+| **model name** | **backbone** | **pretrained** | **cards** | **batch size** | **jit level** | **graph compile** | **ms/step** | **img/s** | **recall** | **precision** | **f-score** | **recipe** | **weight** |
+| :------------: | :----------: | :------------: | :-------: | :------------: | :-----------: | :---------------: | :---------: | :-------: | :--------: | :-----------: | :---------: | :---------------------------: | :-------------------------------------------------------------------------------------------------: |
+| DBNet | ResNet-18 | SynthText | 1 | 20 | O2 | 77.78 s | 206.40 | 96.9 | 83.66% | 87.61% | 85.59% | [yaml](db_r18_totaltext.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet18_totaltext-fb456ff4.ckpt) |
+| DBNet | ResNet-50 | SynthText | 1 | 20 | O2 | 109.15 s | 289.44 | 69.1 | 84.79% | 87.07% | 85.91% | [yaml](db_r50_totaltext.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet50_totaltext-76d6f421.ckpt) |
+
+
+#### MLT2017
+
+
+
+| **model name** | **backbone** | **pretrained** | **cards** | **batch size** | **jit level** | **graph compile** | **ms/step** | **img/s** | **recall** | **precision** | **f-score** | **recipe** | **weight** |
+| :------------: | :----------: | :------------: | :-------: | :------------: | :-----------: | :---------------: | :---------: | :-------: | :--------: | :-----------: | :---------: | :-------------------------: | :-----------------------------------------------------------------------------------------------: |
+| DBNet | ResNet-18 | SynthText | 8 | 20 | O2 | 73.76 s | 464.00 | 344.8 | 73.62% | 83.93% | 78.44% | [yaml](db_r18_mlt2017.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet18_mlt2017-5af33809.ckpt) |
+| DBNet | ResNet-50 | SynthText | 8 | 20 | O2 | 105.12 s | 523.60 | 305.6 | 76.04% | 84.51% | 80.05% | [yaml](db_r50_mlt2017.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet50_mlt2017-3bd6e569.ckpt) |
+
+
+#### SynthText
+
+
+
+| **model name** | **backbone** | **pretrained** | **cards** | **batch size** | **jit level** | **graph compile** | **ms/step** | **img/s** | **train loss** | **recipe** | **weight** |
+| :------------: | :----------: | :------------: | :-------: | :------------: | :-----------: | :---------------: | :---------: | :-------: | :------------: | :---------------------------: | :-------------------------------------------------------------------------------------------------: |
+| DBNet | ResNet-18 | ImageNet | 1 | 16 | O2 | 78.46 s | 131.83 | 121.37 | 2.41 | [yaml](db_r18_synthtext.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet18_synthtext-251ef3dd.ckpt) |
+| DBNet | ResNet-50 | ImageNet | 1 | 16 | O2 | 108.93 s | 195.07 | 82.02 | 2.25 | [yaml](db_r50_synthtext.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet50_synthtext-40655acb.ckpt) |
+
+
+
+### Notes
+- The training time of DBNet is highly affected by data processing and varies across machines.
-```shell
-python infer.py \
- --input_images_dir=/your_path_to/test_images \
- --det_model_path=your_path_to/output.mindir \
- --det_model_name_or_config=../../configs/det/dbnet/db_r50_icdar15.yaml \
- --res_save_dir=results_dir
-```
## References
diff --git a/configs/det/dbnet/README_CN.md b/configs/det/dbnet/README_CN.md
index a5da6b43c..4c5bd0970 100644
--- a/configs/det/dbnet/README_CN.md
+++ b/configs/det/dbnet/README_CN.md
@@ -7,7 +7,7 @@
> DBNet: [Real-time Scene Text Detection with Differentiable Binarization](https://arxiv.org/abs/1911.08947)
> DBNet++: [Real-Time Scene Text Detection with Differentiable Binarization and Adaptive Scale Fusion](https://arxiv.org/abs/2202.10304)
-## 1. 概述
+## 概述
### DBNet
@@ -41,134 +41,22 @@ ASF由两个注意力模块组成——阶段注意力模块(stage-wise attent
阶段注意模块学习不同尺寸的特征图的权重,而空间注意力模块学习跨空间维度的attention。这两个模块的组合使得模型可以获得尺寸(scale)鲁棒性很好的特征融合。
DBNet++在检测不同尺寸的文本方面表现更好,尤其是对于尺寸较大的文本;然而,DBNet在检测尺寸较大的文本时可能会生成不准确或分离的检测框。
-## 2. 通用泛化模型
+### 配套版本
-本节提供了一些通过泛化模型,该模型使用中文和英文两种语言训练,针对各种不同的任务和挑战,包括真实世界图片,街景图片,文档,弯曲文本,长文本等。这些模型可直接用于下游任务,也可直接作为预训练权重。
-
-这些模型在12个公开数据集上训练,包括CTW,LSVT,RCTW-17,TextOCR等,其中训练集包含153511张图片,验证集包含9786张图片。
-从上述数据集中手动选择598张未被训练集和验证集使用的图片构成测试集。
-
-在采用图模式的ascend 910上测试性能
-
-
-| **模型** | **设备卡数** | **骨干网络** | **语言** | **F-score** | **吞吐量** | **模型权重下载** |
-|-----------|----------|--------------|-------------------|:---------------------------:|----------------|----------------------------------------------------------------------------------------------------------|
-| DBNet | 8P | ResNet-50 | Chinese + English | 83.41% | 256 img/s | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet50_ch_en_general-a5dbb141.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet50_ch_en_general-a5dbb141-912f0a90.mindir) |
-| DBNet++ | 4P | ResNet-50 | Chinese + English | 84.30% | 104 img/s | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50_ch_en_general-884ba5b9.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50_ch_en_general-884ba5b9-b3f52398.mindir) |
-
-
-> 链接中模型DBNet的MindIR导出时的输入Shape为`(1,3,736,1280)`,模型DBNet++的MindIR导出时的输入Shape为`(1,3,1152,2048)`。
-
-
-## 3. 实验结果
-
-DBNet和DBNet++在ICDAR2015,MSRA-TD500,SCUT-CTW1500,Total-Text和MLT2017数据集上训练。另外,我们在SynthText数据集上进行了预训练,并提供预训练权重下载链接。所有训练结果如下:
-
-
- 在采用图模式的ascend 910上测试性能
- ### ICDAR2015
-
-
-
- | **模型** | **设备卡数** | **骨干网络** | **预训练数据集** | **Recall** | **Precision** | **F-score** | **训练时间** | **吞吐量** | **配置文件** | **模型权重下载** |
- |---------------------|--------|---------------|------------|------------|---------------|-------------|--------------|------------|----------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
- | DBNet | 1P | MobileNetV3 | ImageNet | 76.26% | 78.22% | 77.28% | 10 s/epoch | 100 img/s | [yaml](db_mobilenetv3_icdar15.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_mobilenetv3-62c44539.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_mobilenetv3-62c44539-f14c6a13.mindir) |
- | DBNet | 8P | MobileNetV3 | ImageNet | 76.22% | 77.98% | 77.09% | 1.1 s/epoch | 960 img/s | [yaml](db_mobilenetv3_icdar15_8p.yaml) | Coming soon |
- | DBNet | 1P | ResNet-18 | ImageNet | 80.12% | 83.41% | 81.73% | 9.3 s/epoch | 108 img/s | [yaml](db_r18_icdar15.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet18-0c0c4cfa.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet18-0c0c4cfa-cf46eb8b.mindir) |
- | DBNet | 1P | ResNet-50 | ImageNet | 83.53% | 86.62% | 85.05% | 13.3 s/epoch | 75.2 img/s | [yaml](db_r50_icdar15.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet50-c3a4aa24.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet50-c3a4aa24-fbf95c82.mindir) |
- | DBNet | 8P | ResNet-50 | ImageNet | 82.62% | 88.54% | 85.48% | 2.3 s/epoch | 435 img/s | [yaml](db_r50_icdar15_8p.yaml) | Coming soon |
- | | | | | | | | | | | |
- | DBNet++ | 1P | ResNet-50 | SynthText | 85.70% | 87.81% | 86.74% | 17.7 s/epoch | 56 img/s | [yaml](dbpp_r50_icdar15.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50-068166c2.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50-068166c2-9934aff0.mindir) |
- | DBNet++ | 8P | ResNet-50 | SynthText | 85.41% | 89.55% | 87.43% | 1.78 s/epoch | 432 img/s | [yaml](dbpp_r50_icdar15_8p.yaml) | Coming soon |
- | DBNet++ | 1P | ResNet-50 | SynthText | 86.81% | 86.85% | 86.86% | 12.7 s/epoch | 78.2 img/s | [yaml](dbpp_r50_icdar15_910.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50_910-35dc71f2.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50_910-35dc71f2-e61a9c37.mindir) |
-
-
- > 链接中模型DBNet的MindIR导出时的输入Shape为`(1,3,736,1280)`,模型DBNet++的MindIR导出时的输入Shape为`(1,3,1152,2048)`。
-
- ### MSRA-TD500
-
-
-
- | **模型** | **设备卡数** | **骨干网络** | **预训练数据集** | **Recall** | **Precision** | **F-score** | **训练时间** | **吞吐量** | **配置文件** | **模型权重下载** |
- |--------|--------|-----------|------------|------------|---------------|-------------|-------------|-------------|---------------------------|-------------------------------------------------------------------------------------------------|
- | DBNet | 1P | ResNet-18 | SynthText | 79.90% | 88.07% | 83.78% | 5.6 s/epoch | 121.7 img/s | [yaml](db_r18_td500.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet18_td500-b5abff68.ckpt) |
- | DBNet | 1P | ResNet-50 | SynthText | 84.02% | 87.48% | 85.71% | 9.6 s/epoch | 71.2 img/s | [yaml](db_r50_td500.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet50_td500-0d12b5e8.ckpt) |
-
-
- > MSRA-TD500数据集有300训练集图片和200测试集图片,参考论文[Real-time Scene Text Detection with Differentiable Binarization](https://arxiv.org/abs/1911.08947),我们训练此权重额外使用了来自HUST-TR400数据集的400训练集图片。可以在此下载全部[数据集](https://paddleocr.bj.bcebos.com/dataset/TD_TR.tar)用于训练。
-
- ### SCUT-CTW1500
-
-
-
- | **模型** | **设备卡数** | **骨干网络** | **预训练数据集** | **Recall** | **Precision** | **F-score** | **训练时间** | **吞吐量** | **配置文件** | **模型权重下载** |
- |--------|--------|-----------|------------|------------|---------------|-------------|--------------|-------------|-----------------------------|---------------------------------------------------------------------------------------------------|
- | DBNet | 1P | ResNet-18 | SynthText | 85.68% | 85.33% | 85.50% | 8.2 s/epoch | 122.1 img/s | [yaml](db_r18_ctw1500.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet18_ctw1500-0864b040.ckpt) |
- | DBNet | 1P | ResNet-50 | SynthText | 87.83% | 84.71% | 86.25% | 14.0 s/epoch | 71.4 img/s | [yaml](db_r50_ctw1500.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet50_ctw1500-f637e3d3.ckpt) |
-
-
- ### Total-Text
-
-
-
- | **模型** | **设备卡数** | **骨干网络** | **预训练数据集** | **Recall** | **Precision** | **F-score** | **训练时间** | **吞吐量** | **配置文件** | **模型权重下载** |
- |--------|--------|-----------|------------|------------|---------------|-------------|--------------|------------|-------------------------------|-----------------------------------------------------------------------------------------------------|
- | DBNet | 1P | ResNet-18 | SynthText | 83.66% | 87.61% | 85.59% | 12.9 s/epoch | 96.9 img/s | [yaml](db_r18_totaltext.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet18_totaltext-fb456ff4.ckpt) |
- | DBNet | 1P | ResNet-50 | SynthText | 84.79% | 87.07% | 85.91% | 18.0 s/epoch | 69.1 img/s | [yaml](db_r50_totaltext.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet50_totaltext-76d6f421.ckpt) |
-
-
- ### MLT2017
-
-
-
- | **模型** | **设备** | **骨干网络** | **预训练数据集** | **Recall** | **Precision** | **F-score** | **训练时间** | **吞吐量** | **配置文件** | **模型权重下载** |
- |--------|--------|-----------|------------|------------|---------------|-------------|--------------|-------------|-----------------------------|---------------------------------------------------------------------------------------------------|
- | DBNet | 8P | ResNet-18 | SynthText | 73.62% | 83.93% | 78.44% | 20.9 s/epoch | 344.8 img/s | [yaml](db_r18_mlt2017.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet18_mlt2017-5af33809.ckpt) |
- | DBNet | 8P | ResNet-50 | SynthText | 76.04% | 84.51% | 80.05% | 23.6 s/epoch | 305.6 img/s | [yaml](db_r50_mlt2017.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet50_mlt2017-3bd6e569.ckpt) |
-
-
- ### SynthText
-
-
-
- | **模型** | **设备卡数** | **骨干网络** | **预训练数据集** | **训练Loss**| **训练时间** | **吞吐量** | **配置文件** | **模型权重下载** |
- |-----------------|--------|--------------|----------------|---------|---------|---------------|-------------|--------------|
- | DBNet | 1P | ResNet-18 | ImageNet | 2.41 |7075 s/epoch | 121.37 img/s | [yaml](db_r18_synthtext.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet18_synthtext-251ef3dd.ckpt) |
- | DBNet | 1P | ResNet-50 | ImageNet | 2.25 |10470 s/epoch | 82.02 img/s | [yaml](db_r50_synthtext.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet50_synthtext-40655acb.ckpt) |
-
-
-
-
-
- 在采用图模式的ascend 910*上测试性能
- ### ICDAR2015
-
-
-
- | **模型** | **设备卡数** | **骨干网络** | **预训练数据集** | **Recall** | **Precision** | **F-score** | **Batch Size** | **Step Time** | **配置文件** | **模型权重下载** |
- |--------|----------|-------------|------------|------------|---------------|-------------|----------------|----------------|----------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
- | DBNet | 1p | MobileNetV3 | ImageNet | 74.68% | 79.38% | 76.95% | 10 | 65.69 ms/step | [yaml](db_mobilenetv3_icdar15.yaml) | [ckpt](https://download-mindspore.osinfra.cn/toolkits/mindocr/dbnet/dbnet_mobilenetv3-e72f9b8b-910v2.ckpt) |
- | DBNet | 8p | MobileNetV3 | ImageNet | 76.27% | 76.06% | 76.17% | 8 | 54.46 ms/step | [yaml](db_mobilenetv3_icdar15_8p.yaml) | [ckpt](https://download-mindspore.osinfra.cn/toolkits/mindocr/dbnet/dbnet_mobilenetv3-7e89e1df-910v2.ckpt) |
- | DBNet | 1p | ResNet-50 | ImageNet | 84.50% | 85.36% | 84.93% | 10 | 155.62 ms/step | [yaml](db_r50_icdar15.yaml) | [ckpt](https://download-mindspore.osinfra.cn/toolkits/mindocr/dbnet/dbnet_resnet50-48153c3b-910v2.ckpt) |
- | DBNet | 8p | ResNet-50 | ImageNet | 81.15% | 87.63% | 84.26% | 10 | 159.22 ms/step | [yaml](db_r50_icdar15_8p.yaml) | [ckpt](https://download-mindspore.osinfra.cn/toolkits/mindocr/dbnet/dbnet_resnet50-e10bad35-910v2.ckpt) |
-
-
-
-
-
-#### 注释:
-- DBNet的训练时长受数据处理部分和不同运行环境的影响非常大。
+| mindspore | ascend driver | firmware | cann toolkit/kernel |
+|:----------:|:--------------:|:--------------:|:-------------------:|
+| 2.3.1 | 24.1.RC2 | 7.3.0.1.231 | 8.0.RC2.beta1 |
-## 4. 快速上手
+## 快速上手
-### 4.1 安装
+### 安装
请参考MindOCR套件的[安装指南](https://github.com/mindspore-lab/mindocr#installation) 。
-### 4.2 数据准备
+### 数据准备
-#### 4.2.1 ICDAR2015 数据集
+#### ICDAR2015 数据集
请从[该网址](https://rrc.cvc.uab.es/?ch=4&com=downloads)下载ICDAR2015数据集,然后参考[数据转换](../../../tools/dataset_converters/README_CN.md)对数据集标注进行格式转换。
@@ -190,7 +78,7 @@ DBNet和DBNet++在ICDAR2015,MSRA-TD500,SCUT-CTW1500,Total-Text和MLT2017
└── train_det_gt.txt
```
-#### 4.2.2 MSRA-TD500 数据集
+#### MSRA-TD500 数据集
请从[该网址](http://www.iapr-tc11.org/mediawiki/index.php/MSRA_Text_Detection_500_Database_(MSRA-TD500))下载MSRA-TD500数据集,然后参考[数据转换](../../../tools/dataset_converters/README_CN.md)对数据集标注进行格式转换。
@@ -214,7 +102,7 @@ MSRA-TD500
│ ├── test_det_gt.txt
```
-#### 4.2.3 SCUT-CTW1500 数据集
+#### SCUT-CTW1500 数据集
请从[该网址](https://github.com/Yuliang-Liu/Curve-Text-Detector)下载SCUT-CTW1500数据集,然后参考[数据转换](https://github.com/mindspore-lab/mindocr/blob/main/tools/dataset_converters/README_CN.md)对数据集标注进行格式转换。
@@ -234,7 +122,7 @@ ctw1500
├── train_det_gt.txt
```
-#### 4.2.4 Total-Text 数据集
+#### Total-Text 数据集
请从[该网址](https://github.com/cs-chan/Total-Text-Dataset/tree/master/Dataset)下载Total-Text数据集,然后参考[数据转换](https://github.com/mindspore-lab/mindocr/blob/main/tools/dataset_converters/README_CN.md)对数据集标注进行格式转换。
@@ -255,7 +143,7 @@ totaltext
├── train_det_gt.txt
```
-#### 4.2.5 MLT2017 数据集
+#### MLT2017 数据集
MLT2017数据集是一个多语言文本检测识别数据集,包含中文、日文、韩文、英文、法文、阿拉伯文、意大利文、德文和印度文共9种语言。请从[该网址](https://rrc.cvc.uab.es/?ch=8&com=downloads)下载MLT2017数据集,解压后请将数据中格式为.gif的图像转化为.jpg或.png格式。然后参考[数据转换](https://github.com/mindspore-lab/mindocr/blob/main/tools/dataset_converters/README_CN.md)对数据集标注进行格式转换。
完成数据准备工作后,数据的目录结构应该如下所示:
@@ -277,7 +165,7 @@ MLT_2017
```
> 用户如果想要使用自己的数据集进行训练,请参考[数据转换](https://github.com/mindspore-lab/mindocr/blob/main/tools/dataset_converters/README_CN.md)对数据集标注进行格式转换。并配置yaml文件,然后使用单卡或者多卡运行train.py进行训练即可,详细信息可参考下面几节教程。
-#### 4.2.6 SynthText 数据集
+#### SynthText 数据集
请从[该网址](https://academictorrents.com/details/2dba9518166cbd141534cbf381aa3e99a087e83c)下载SynthText数据集,解压后的数据的目录结构应该如下所示:
@@ -307,7 +195,7 @@ MLT_2017
> ```
> 以上的操作会产生与`SynthText`原始标注格式相同但是是经过过滤后的标注数据.
-### 4.3 配置说明
+### 配置说明
在配置文件`configs/det/dbnet/db_r50_icdar15.yaml`中更新如下文件路径。其中`dataset_root`会分别和`data_dir`以及`label_file`拼接构成完整的数据集目录和标签文件路径。
@@ -359,7 +247,7 @@ model:
[comment]: <> (_DBNet_和_DBNet++的唯一区别在于_Adaptive Scale Fusion_ module, 在`neck`模块中的 `use_asf`参数进行设置。)
-### 4.4 训练
+### 训练
* 单卡训练
@@ -381,7 +269,7 @@ mpirun --allow-run-as-root -n 2 python tools/train.py --config configs/det/dbnet
训练结果(包括checkpoint、每个epoch的性能和曲线图)将被保存在yaml配置文件的`ckpt_save_dir`参数配置的路径下,默认为`./tmp_det`。
-### 4.5 评估
+### 评估
评估环节,在yaml配置文件中将`ckpt_load_path`参数配置为checkpoint文件的路径,设置`distribute`为False,然后运行:
@@ -389,42 +277,105 @@ mpirun --allow-run-as-root -n 2 python tools/train.py --config configs/det/dbnet
python tools/eval.py --config configs/det/dbnet/db_r50_icdar15.yaml
```
-## 5. MindSpore Lite 推理
+## 性能表现
-请参考[MindOCR 推理](../../../docs/zh/inference/inference_tutorial.md)教程,基于MindSpore Lite在Ascend 310上进行模型的推理,包括以下步骤:
+### 通用泛化模型
-- 模型导出
+本节提供了一些通用泛化模型,这些模型使用中文和英文两种语言训练,针对各种不同的任务和挑战,包括真实世界图片,街景图片,文档,弯曲文本,长文本等。这些模型可直接用于下游任务,也可直接作为预训练权重。
+
-请先[下载](#3-实验结果)已导出的MindIR文件,或者参考[模型导出](../../../docs/zh/inference/convert_tutorial.md#1-模型导出)教程,使用以下命令将训练完成的ckpt导出为MindIR文件:
+这些模型在12个公开数据集上训练,包括CTW,LSVT,RCTW-17,TextOCR等,其中训练集包含153511张图片,验证集包含9786张图片。
+从上述数据集中手动选择598张未被训练集和验证集使用的图片构成测试集。
-```shell
-python tools/export.py --model_name_or_config dbnet_resnet50 --data_shape 736 1280 --local_ckpt_path /path/to/local_ckpt.ckpt
-# or
-python tools/export.py --model_name_or_config configs/det/dbnet/db_r50_icdar15.yaml --data_shape 736 1280 --local_ckpt_path /path/to/local_ckpt.ckpt
-```
+
+在采用图模式的ascend 910*上的实验结果,MindSpore版本为2.3.1。
+
-其中,`data_shape`是导出MindIR时的模型输入Shape的height和width,下载链接中MindIR对应的shape值见[ICDAR2015注释](#ICDAR2015)。
+*即将到来*
-- 环境搭建
+
+在采用图模式的ascend 910上的实验结果,MindSpore版本为2.3.1。
+
-请参考[环境安装](../../../docs/zh/inference/environment.md)教程,配置MindSpore Lite推理运行环境。
+*即将到来*
-- 模型转换
-请参考[模型转换](../../../docs/zh/inference/convert_tutorial.md#2-mindspore-lite-mindir-转换)教程,使用`converter_lite`工具对MindIR模型进行离线转换。
+### 细分领域模型
-- 执行推理
+DBNet和DBNet++在ICDAR2015,MSRA-TD500,SCUT-CTW1500,Total-Text和MLT2017数据集上训练。另外,我们在SynthText数据集上进行了预训练,并提供预训练权重下载链接。所有训练结果如下:
+
+在采用图模式的ascend 910*上的实验结果,MindSpore版本为2.3.1。
+
-假设在模型转换后得到output.mindir文件,在`deploy/py_infer`目录下使用以下命令进行推理:
+#### ICDAR2015
+
+
+| **model name** | **backbone** | **pretrained** | **cards** | **batch size** | **jit level** | **graph compile** | **ms/step** | **img/s** | **recall** | **precision** | **f-score** | **recipe** | **weight** |
+| :------------: | :----------: | :------------: | :-------: | :------------: | :-----------: | :---------------: | :---------: | :-------: | :--------: | :-----------: | :---------: | :------------------------------------: | :--------------------------------------------------------------------------------------------------------: |
+| DBNet | MobileNetV3 | ImageNet | 1 | 10 | O2 | 403.87 s | 65.69 | 152.23 | 74.68% | 79.38% | 76.95% | [yaml](db_mobilenetv3_icdar15.yaml) | [ckpt](https://download-mindspore.osinfra.cn/toolkits/mindocr/dbnet/dbnet_mobilenetv3-e72f9b8b-910v2.ckpt) |
+| DBNet | MobileNetV3 | ImageNet | 8 | 8 | O2 | 405.35 s | 54.46 | 1175.12 | 76.27% | 76.06% | 76.17% | [yaml](db_mobilenetv3_icdar15_8p.yaml) | [ckpt](https://download-mindspore.osinfra.cn/toolkits/mindocr/dbnet/dbnet_mobilenetv3-7e89e1df-910v2.ckpt) |
+| DBNet | ResNet-50 | ImageNet | 1 | 10 | O2 | 147.81 s | 155.62 | 64.25 | 84.50% | 85.36% | 84.93% | [yaml](db_r50_icdar15.yaml) | [ckpt](https://download-mindspore.osinfra.cn/toolkits/mindocr/dbnet/dbnet_resnet50-48153c3b-910v2.ckpt) |
+| DBNet | ResNet-50 | ImageNet | 8 | 10 | O2 | 151.23 s | 159.22 | 502.4 | 81.15% | 87.63% | 84.26% | [yaml](db_r50_icdar15_8p.yaml) | [ckpt](https://download-mindspore.osinfra.cn/toolkits/mindocr/dbnet/dbnet_resnet50-e10bad35-910v2.ckpt) |
+
+> 链接中模型DBNet的MindIR导出时的输入Shape为`(1,3,736,1280)`,模型DBNet++的MindIR导出时的输入Shape为`(1,3,1152,2048)`。
+
+
+在采用图模式的ascend 910上的实验结果,MindSpore版本为2.3.1。
+
+#### ICDAR2015
+
+| **model name** | **backbone** | **pretrained** | **cards** | **batch size** | **jit level** | **graph compile** | **ms/step** | img/s | **recall** | **precision** | **f-score** | **recipe** | **weight** |
+| :------------: | :----------: | :------------: | :-------: | :------------: | :-----------: | :---------------: | :---------: | :---: | :--------: | :-----------: | :---------: | :------------------------------------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
+| DBNet | MobileNetV3 | ImageNet | 1 | 10 | O2 | 321.15 s | 100 | 100 | 76.31% | 78.27% | 77.28% | [yaml](db_mobilenetv3_icdar15.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_mobilenetv3-62c44539.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_mobilenetv3-62c44539-f14c6a13.mindir) |
+| DBNet | MobileNetV3 | ImageNet | 8 | 8 | O2 | 309.39 s | 66.64 | 960 | 76.22% | 77.98% | 77.09% | [yaml](db_mobilenetv3_icdar15_8p.yaml) | Coming soon |
+| DBNet | ResNet-18 | ImageNet | 1 | 20 | O2 | 75.23 s | 185.19 | 108 | 80.12% | 83.41% | 81.73% | [yaml](db_r18_icdar15.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet18-0c0c4cfa.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet18-0c0c4cfa-cf46eb8b.mindir) |
+| DBNet | ResNet-50 | ImageNet | 1 | 10 | O2 | 110.54 s | 132.98 | 75.2 | 83.53% | 86.62% | 85.05% | [yaml](db_r50_icdar15.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet50-c3a4aa24.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet50-c3a4aa24-fbf95c82.mindir) |
+| DBNet | ResNet-50 | ImageNet | 8 | 10 | O2 | 107.91 s | 183.92 | 435 | 82.62% | 88.54% | 85.48% | [yaml](db_r50_icdar15_8p.yaml) | Coming soon |
+| DBNet++ | ResNet-50 | SynthText | 1 | 32 | O2 | 184.74 s | 409.21 | 78.2 | 86.81% | 86.85% | 86.86% | [yaml](dbpp_r50_icdar15_910.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50_910-35dc71f2.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50_910-35dc71f2-e61a9c37.mindir) |
+
+
+> 链接中模型DBNet的MindIR导出时的输入Shape为`(1,3,736,1280)`,模型DBNet++的MindIR导出时的输入Shape为`(1,3,1152,2048)`。
+
+#### MSRA-TD500
+
+
+| **model name** | **backbone** | **pretrained** | **cards** | **batch size** | **jit level** | **graph compile** | **ms/step** | **img/s** | **recall** | **precision** | **f-score** | **recipe** | **weight** |
+| :------------: | :----------: | :------------: | :-------: | :------------: | :-----------: | :---------------: | :---------: | :-------: | :--------: | :-----------: | :---------: | :-----------------------: | :---------------------------------------------------------------------------------------------: |
+| DBNet | ResNet-18 | SynthText | 1 | 20 | O2 | 76.18 s | 163.34 | 121.7 | 79.90% | 88.07% | 83.78% | [yaml](db_r18_td500.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet18_td500-b5abff68.ckpt) |
+| DBNet | ResNet-50 | SynthText | 1 | 20 | O2 | 108.45 s | 280.90 | 71.2 | 84.02% | 87.48% | 85.71% | [yaml](db_r50_td500.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet50_td500-0d12b5e8.ckpt) |
+
+> MSRA-TD500数据集有300训练集图片和200测试集图片,参考论文[Real-time Scene Text Detection with Differentiable Binarization](https://arxiv.org/abs/1911.08947),我们训练此权重额外使用了来自HUST-TR400数据集的400训练集图片。可以在此下载全部[数据集](https://paddleocr.bj.bcebos.com/dataset/TD_TR.tar)用于训练。
+
+#### SCUT-CTW1500
+
+| **model name** | **backbone** | **pretrained** | **cards** | **batch size** | **jit level** | **graph compile** | **ms/step** | **img/s** | **recall** | **precision** | **f-score** | **recipe** | **weight** |
+|----------------|--------------|----------------|-----------|----------------|---------------|-------------------|-------------|-----------|------------|---------------|-------------|-----------------------------|---------------------------------------------------------------------------------------------------|
+| DBNet | ResNet-18 | SynthText | 1 | 20 | O2 | 73.18 s | 163.80 | 122.1 | 85.68% | 85.33% | 85.50% | [yaml](db_r18_ctw1500.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet18_ctw1500-0864b040.ckpt) |
+| DBNet | ResNet-50 | SynthText | 1 | 20 | O2 | 110.34 s | 180.11 | 71.4 | 87.83% | 84.71% | 86.25% | [yaml](db_r50_ctw1500.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet50_ctw1500-f637e3d3.ckpt) |
+
+
+#### Total-Text
+
+
+| **model name** | **backbone** | **pretrained** | **cards** | **batch size** | **jit level** | **graph compile** | **ms/step** | **img/s** | **recall** | **precision** | **f-score** | **recipe** | **weight** |
+| :------------: | :----------: | :------------: | :-------: | :------------: | :-----------: | :---------------: | :---------: | :-------: | :--------: | :-----------: | :---------: | :---------------------------: | :-------------------------------------------------------------------------------------------------: |
+| DBNet | ResNet-18 | SynthText | 1 | 20 | O2 | 77.78 s | 206.40 | 96.9 | 83.66% | 87.61% | 85.59% | [yaml](db_r18_totaltext.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet18_totaltext-fb456ff4.ckpt) |
+| DBNet | ResNet-50 | SynthText | 1 | 20 | O2 | 109.15 s | 289.44 | 69.1 | 84.79% | 87.07% | 85.91% | [yaml](db_r50_totaltext.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet50_totaltext-76d6f421.ckpt) |
+
+
+#### MLT2017
+
+| **model name** | **backbone** | **pretrained** | **cards** | **batch size** | **jit level** | **graph compile** | **ms/step** | **img/s** | **recall** | **precision** | **f-score** | **recipe** | **weight** |
+| :------------: | :----------: | :------------: | :-------: | :------------: | :-----------: | :---------------: | :---------: | :-------: | :--------: | :-----------: | :---------: | :-------------------------: | :-----------------------------------------------------------------------------------------------: |
+| DBNet | ResNet-18 | SynthText | 8 | 20 | O2 | 73.76 s | 464.00 | 344.8 | 73.62% | 83.93% | 78.44% | [yaml](db_r18_mlt2017.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet18_mlt2017-5af33809.ckpt) |
+| DBNet | ResNet-50 | SynthText | 8 | 20 | O2 | 105.12 s | 523.60 | 305.6 | 76.04% | 84.51% | 80.05% | [yaml](db_r50_mlt2017.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet50_mlt2017-3bd6e569.ckpt) |
+
+
+#### SynthText
+
+| **model name** | **backbone** | **pretrained** | **cards** | **batch size** | **jit level** | **graph compile** | **ms/step** | **img/s** | **train loss** | **recipe** | **weight** |
+| :------------: | :----------: | :------------: | :-------: | :------------: | :-----------: | :---------------: | :---------: | :-------: | :------------: | :---------------------------: | :-------------------------------------------------------------------------------------------------: |
+| DBNet | ResNet-18 | ImageNet | 1 | 16 | O2 | 78.46 s | 131.83 | 121.37 | 2.41 | [yaml](db_r18_synthtext.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet18_synthtext-251ef3dd.ckpt) |
+| DBNet | ResNet-50 | ImageNet | 1 | 16 | O2 | 108.93 s | 195.07 | 82.02 | 2.25 | [yaml](db_r50_synthtext.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet50_synthtext-40655acb.ckpt) |
+
+
+### 注释
+- DBNet的训练时长受数据处理部分和不同运行环境的影响非常大。
-```shell
-python infer.py \
- --input_images_dir=/your_path_to/test_images \
- --det_model_path=your_path_to/output.mindir \
- --det_model_name_or_config=../../configs/det/dbnet/db_r50_icdar15.yaml \
- --res_save_dir=results_dir
-```
## 参考文献
diff --git a/configs/layout/layoutlmv3/README.md b/configs/layout/layoutlmv3/README.md
new file mode 100644
index 000000000..57da17a01
--- /dev/null
+++ b/configs/layout/layoutlmv3/README.md
@@ -0,0 +1,104 @@
+English | [中文](README_CN.md)
+
+# LayoutLMv3
+
+
+> [LayoutLMv3: Pre-training for Document AI with Unified Text and Image Masking](https://arxiv.org/abs/2204.08387)
+
+> [Original Repo](https://github.com/microsoft/unilm/tree/master/layoutlmv3)
+
+## 1. Introduction
+Unlike previous LayoutLM series models, LayoutLMv3 does not rely on complex CNN or Faster R-CNN networks to represent images in its model architecture. Instead, it directly utilizes image blocks of document images, thereby greatly reducing parameters and avoiding complex document preprocessing such as manual annotation of target region boxes and document object detection. Its simple unified architecture and training objectives make LayoutLMv3 a versatile pretraining model suitable for both text-centric and image-centric document AI tasks.
+
+The experimental results demonstrate that LayoutLMv3 achieves better performance with fewer parameters on the following datasets:
+
+- Text-centric datasets: Form Understanding FUNSD dataset, Receipt Understanding CORD dataset, and Document Visual Question Answering DocVQA dataset.
+- Image-centric datasets: Document Image Classification RVL-CDIP dataset and Document Layout Analysis PubLayNet dataset.
+
+LayoutLMv3 also employs a text-image multimodal Transformer architecture to learn cross-modal representations. Text vectors are obtained by adding word vectors, one-dimensional positional vectors, and two-dimensional positional vectors of words. Text from document images and its corresponding two-dimensional positional information (layout information) are extracted using optical character recognition (OCR) tools. Since words in the same text segment usually convey similar semantics, LayoutLMv3 shares one two-dimensional positional vector among the words of a segment, while each word in LayoutLM and LayoutLMv2 has its own two-dimensional positional vector.
+
+The representation of image vectors typically relies on CNN-extracted feature grid features or Faster R-CNN-extracted region features, which increase computational costs or depend on region annotations. Therefore, the authors obtain image features by linearly mapping image blocks, a representation method initially proposed in ViT, which incurs minimal computational cost and does not rely on region annotations, effectively addressing the aforementioned issues. Specifically, the image is first resized to a uniform size (e.g., 224x224), then divided into fixed-size blocks (e.g., 16x16), and image features are obtained through linear mapping to form an image feature sequence, followed by addition of a learnable one-dimensional positional vector to obtain the image vector.[[1](#references)]
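+
+To make the patch-embedding step concrete, here is a minimal NumPy sketch that assumes a 224x224 input, cuts it into 16x16 patches, linearly maps each patch, and adds a 1-D positional vector. The projection and positional weights are random placeholders, not trained LayoutLMv3 parameters.
+
+```python
+import numpy as np
+
+rng = np.random.default_rng(0)
+image = rng.random((224, 224, 3))        # resized document image, HWC
+
+patch = 16
+patches = image.reshape(224 // patch, patch, 224 // patch, patch, 3)
+patches = patches.transpose(0, 2, 1, 3, 4).reshape(-1, patch * patch * 3)  # (196, 768)
+
+proj = rng.normal(size=(patch * patch * 3, 768))   # linear mapping of each flattened patch
+pos = rng.normal(size=(patches.shape[0], 768))     # learnable 1-D positional vectors
+
+image_tokens = patches @ proj + pos                # (196, 768) image vector sequence
+print(image_tokens.shape)
+```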
+
+
+
+
+
+ Figure 1. LayoutLMv3 architecture [1]
+
+
+## 2. Quick Start
+
+### 2.1 Preparation
+
+| mindspore | ascend driver | firmware | cann toolkit/kernel |
+|:----------:|:---------------:|:------------:|:--------------------:|
+| 2.3.1 | 24.1.RC2 | 7.3.0.1.231 | 8.0.RC2.beta1 |
+
+#### 2.1.1 Installation
+Please refer to the [installation instruction](https://github.com/mindspore-lab/mindocr#installation) in MindOCR.
+
+#### 2.1.2 PubLayNet Dataset Preparation
+
+PubLayNet is a dataset for document layout analysis. It contains images of research papers and articles, together with annotations for the various elements on each page, such as "text", "list", "figure", etc. The dataset was obtained by automatically matching the XML representations and the content of over 1 million PDF articles that are publicly available on PubMed Central.
+
+The training and validation datasets for PubLayNet can be downloaded [here](https://dax-cdn.cdn.appdomain.cloud/dax-publaynet/1.0.0/publaynet.tar.gz).
+
+Once the download is complete, the data can be converted to the input format required by LayoutLMv3 with the conversion script provided by MindOCR:
+
+```bash
+python tools/dataset_converters/convert.py \
+    --dataset_name publaynet \
+    --image_dir publaynet/ \
+    --output_path publaynet/
+```
+
+### 2.2 Model Conversion
+
+Note: Please install torch before starting the conversion script
+```bash
+pip install torch
+```
+
+Download the [layoutlmv3-base-finetuned-publaynet](https://huggingface.co/HYPJUDY/layoutlmv3-base-finetuned-publaynet) model to /path/to/layoutlmv3-base-finetuned-publaynet, and run:
+
+```bash
+python tools/param_converter_from_torch.py \
+ --input_path /path/to/layoutlmv3-base-finetuned-publaynet/model_final.pt \
+ --json_path configs/layout/layoutlmv3/layoutlmv3_publaynet_param_map.json \
+ --output_path /path/to/layoutlmv3-base-finetuned-publaynet/from_torch.ckpt
+```
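+
+For reference, the sketch below shows roughly what such a conversion does: load the PyTorch state dict, rename each parameter according to the `convert_map` in the JSON file, and save a MindSpore checkpoint. It is an illustrative approximation of `tools/param_converter_from_torch.py`, not its actual code; it assumes the weights sit in a flat state dict (possibly nested under a "model" key) and skips any dtype or transpose handling.
+
+```python
+import json
+
+import mindspore as ms
+import numpy as np
+import torch
+
+
+def convert_params(input_path, json_path, output_path):
+    """Rename torch parameters to their MindSpore names and save them as a .ckpt file."""
+    with open(json_path, "r") as f:
+        convert_map = json.load(f)["convert_map"]
+
+    state = torch.load(input_path, map_location="cpu")
+    state = state.get("model", state)  # detectron2-style checkpoints nest weights under "model"
+
+    params = []
+    for torch_name, ms_name in convert_map.items():
+        array = state[torch_name].numpy().astype(np.float32)
+        params.append({"name": ms_name, "data": ms.Tensor(array)})
+
+    ms.save_checkpoint(params, output_path)
+```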
+
+### 2.3 Model Evaluation
+The evaluation results on the public benchmark dataset (PubLayNet) are as follows:
+
+Experiments are tested on ascend 910* with mindspore 2.3.1 pynative mode.
+
+
+| **model name** | **cards** | **batch size** | **img/s** | **map** | **config** |
+|----------------|-----------|----------------|-----------|---------|----------------------------------------------------------------------------------------------------------------|
+| LayoutLMv3 | 1 | 1 | 2.7 | 94.3% | [yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/layout/layoutlmv3/layoutlmv3_publaynet.yaml) |
+
+
+### 2.4 Model Inference
+
+```bash
+python tools/infer/text/predict_layout.py \
+ --mode 1 \
+ --image_dir {path_to_img} \
+ --layout_algorithm LAYOUTLMV3 \
+ --config {config_path}
+```
+By default, model inference results are saved in the `inference_results` folder:
+
+- `layout_res.png`: visualization of the predicted layout
+- `layout_results.txt`: the predicted layout results in text form
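+
+If you want to redraw the predictions yourself (for example, with a different color scheme), the sketch below overlays category-colored boxes on the page image, reusing the category and color mapping from `layoutlmv3_publaynet.yaml`. The per-region dictionary format (`category_id`, `bbox` in pixel xyxy order, `score`) is an assumption for illustration; adapt it to whatever your `layout_results.txt` actually contains.
+
+```python
+from PIL import Image, ImageDraw
+
+# category and color mapping taken from configs/layout/layoutlmv3/layoutlmv3_publaynet.yaml
+CATEGORIES = {1: "text", 2: "title", 3: "list", 4: "table", 5: "figure"}
+COLORS = {1: (255, 0, 0), 2: (0, 0, 255), 3: (0, 255, 0), 4: (0, 255, 255), 5: (255, 0, 255)}
+
+
+def draw_layout(image_path, regions, out_path="layout_res_custom.png"):
+    """Draw one rectangle per predicted region and save the annotated page."""
+    img = Image.open(image_path).convert("RGB")
+    draw = ImageDraw.Draw(img)
+    for region in regions:
+        x0, y0, x1, y1 = region["bbox"]
+        color = COLORS[region["category_id"]]
+        draw.rectangle([x0, y0, x1, y1], outline=color, width=3)
+        label = f'{CATEGORIES[region["category_id"]]} {region["score"]:.2f}'
+        draw.text((x0, max(0, y0 - 12)), label, fill=color)
+    img.save(out_path)
+
+
+# example with a single made-up "table" region
+draw_layout("page.png", [{"category_id": 4, "bbox": [60, 80, 520, 400], "score": 0.98}])
+```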
+
+### 2.5 Model Training
+
+coming soon
+
+## References
+
+
+[1] Yupan Huang, Tengchao Lv, Lei Cui, Yutong Lu, Furu Wei. LayoutLMv3: Pre-training for Document AI with Unified Text and Image Masking. arXiv preprint arXiv:2204.08387, 2022.
diff --git a/configs/layout/layoutlmv3/README_CN.md b/configs/layout/layoutlmv3/README_CN.md
new file mode 100644
index 000000000..a58f34c3e
--- /dev/null
+++ b/configs/layout/layoutlmv3/README_CN.md
@@ -0,0 +1,108 @@
+[English](README.md) | 中文
+
+# LayoutLMv3
+
+
+> [LayoutLMv3: Pre-training for Document AI with Unified Text and Image Masking](https://arxiv.org/abs/2204.08387)
+
+> [Original Repo](https://github.com/microsoft/unilm/tree/master/layoutlmv3)
+
+## 1. 模型描述
+
+
+不同于以往的LayoutLM系列模型,在模型架构设计上,LayoutLMv3 不依赖复杂的 CNN 或 Faster R-CNN 网络来表征图像,而是直接利用文档图像的图像块,从而大大节省了参数并避免了复杂的文档预处理(如人工标注目标区域框和文档目标检测)。简单的统一架构和训练目标使 LayoutLMv3 成为通用的预训练模型,可适用于以文本为中心和以图像为中心的文档 AI 任务。
+
+实验结果表明,LayoutLMv3在以下数据集以更少的参数量达到了更优的性能:
+- 以文本为中心的数据集:表单理解FUNSD数据集、票据理解CORD数据集以及文档视觉问答DocVQA数据集。
+- 以图像为中心的数据集:文档图像分类RVL-CDIP数据集以及文档布局分析PubLayNet数据集。
+
+LayoutLMv3 还应用了文本——图像多模态 Transformer 架构来学习跨模态表征。文本向量由词向量、词的一维位置向量和二维位置向量相加得到。文档图像的文本和其相应的二维位置信息(布局信息)则利用光学字符识别(OCR)工具抽取。因为文本的邻接词通常表达了相似的语义,LayoutLMv3 共享了邻接词的二维位置向量,而 LayoutLM 和 LayoutLMv2 的每个词则用了不同的二维位置向量。
+
+图像向量的表示通常依赖于 CNN 抽取特征图网格特征或 Faster R-CNN 提取区域特征,这些方式增加了计算开销或依赖于区域标注。因此,作者将图像块经过线性映射获得图像特征,这种图像表示方式最早在 ViT 中被提出,计算开销极小且不依赖于区域标注,有效解决了以上问题。具体来说,首先将图像缩放为统一的大小(例如224x224),然后将图像切分成固定大小的块(例如16x16),并通过线性映射获得图像特征序列,再加上可学习的一维位置向量后得到图像向量。[1]
+
+
+
+
+
+
+
+ 图1. LayoutLMv3架构图 [1]
+
+
+
+## 2. 快速开始
+
+### 2.1 环境及数据准备
+
+| mindspore | ascend driver | firmware | cann toolkit/kernel |
+|:----------:|:---------------:|:------------:|:--------------------:|
+| 2.3.1 | 24.1.RC2 | 7.3.0.1.231 | 8.0.RC2.beta1 |
+
+#### 2.1.1 安装
+环境安装教程请参考MindOCR的 [installation instruction](https://github.com/mindspore-lab/mindocr#installation).
+
+#### 2.1.2 PubLayNet数据集准备
+
+PubLayNet是一个用于文档布局分析的数据集。它包含研究论文和文章的图像,以及页面中各种元素的注释,如这些研究论文图像中的“文本”、“列表”、“图形”等。该数据集是通过自动匹配PubMed Central上公开的100多万篇PDF文章的XML表示和内容而获得的。
+
+PubLayNet的训练及验证数据集可以从 [这里](https://dax-cdn.cdn.appdomain.cloud/dax-publaynet/1.0.0/publaynet.tar.gz) 下载。
+
+下载完成后,可以使用MindOCR提供的转换脚本将数据转换为layoutlmv3所需的输入格式:
+
+```bash
+python tools/dataset_converters/convert.py \
+    --dataset_name publaynet \
+    --image_dir publaynet/ \
+    --output_path publaynet/
+```
+
+### 2.2 模型转换
+
+注:启动转换脚本前请安装torch
+```bash
+pip install torch
+```
+
+请下载 [layoutlmv3-base-finetuned-publaynet](https://huggingface.co/HYPJUDY/layoutlmv3-base-finetuned-publaynet) 模型到 /path/to/layoutlmv3-base-finetuned-publaynet, 然后运行:
+
+```bash
+python tools/param_converter_from_torch.py \
+ --input_path /path/to/layoutlmv3-base-finetuned-publaynet/model_final.pt \
+ --json_path configs/layout/layoutlmv3/layoutlmv3_publaynet_param_map.json \
+ --output_path /path/to/layoutlmv3-base-finetuned-publaynet/from_torch.ckpt
+```
+
+### 2.3 模型评估
+在公开基准数据集(PubLayNet)上的评估结果如下:
+
+实验在采用动态图模式的ascend 910*上测试,mindspore版本为2.3.1。
+
+
+| **model name** | **cards** | **batch size** | **img/s** | **map** | **config** |
+|----------------|-----------|----------------|-----------|---------|----------------------------------------------------------------------------------------------------------------|
+| LayoutLMv3 | 1 | 1 | 2.7 | 94.3% | [yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/layout/layoutlmv3/layoutlmv3_publaynet.yaml) |
+
+
+### 2.4 模型推理
+
+```bash
+python tools/infer/text/predict_layout.py \
+ --mode 1 \
+ --image_dir {path_to_img} \
+ --layout_algorithm LAYOUTLMV3 \
+ --config {config_path}
+```
+模型推理结果默认保存在 `inference_results` 文件夹下:
+
+- `layout_res.png`:模型推理可视化结果
+- `layout_results.txt`:模型推理文本结果
+
+### 2.5 模型训练
+
+coming soon
+
+## 参考文献
+
+
+[1] Yupan Huang, Tengchao Lv, Lei Cui, Yutong Lu, Furu Wei. LayoutLMv3: Pre-training for Document AI with Unified Text and Image Masking. arXiv preprint arXiv:2204.08387, 2022.
diff --git a/configs/layout/layoutlmv3/layoutlmv3_publaynet.yaml b/configs/layout/layoutlmv3/layoutlmv3_publaynet.yaml
new file mode 100644
index 000000000..9676da5ce
--- /dev/null
+++ b/configs/layout/layoutlmv3/layoutlmv3_publaynet.yaml
@@ -0,0 +1,136 @@
+system:
+ mode: 1 # 0 for graph mode, 1 for pynative mode in MindSpore
+ distribute: False
+ amp_level: "O0"
+ seed: 42
+ log_interval: 10
+ val_start_epoch: 50
+ val_while_train: True
+ drop_overflow_update: False
+
+model:
+ type: layout
+ transform: null
+ backbone:
+ name: build_layoutlmv3_fpn_backbone
+ out_features: ["layer3", "layer5", "layer7", "layer11"]
+ fpn:
+ in_features: ["layer3", "layer5", "layer7", "layer11"]
+ norm: ""
+ out_channels: 256
+ fuse_type: sum
+ neck:
+ name: RPN
+ in_features: ["p2", "p3", "p4", "p5", "p6"]
+ pre_nms_topk_train: 2000
+ pre_nms_topk_test: 1000
+ feat_channel: 256
+ anchor_generator:
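+      # one anchor size per FPN level (p2-p6); strides are the corresponding feature-map downsampling factors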
+ aspect_ratios: [0.5, 1.0, 2.0]
+ anchor_sizes: [[32], [64], [128], [256], [512]]
+ strides: [4, 8, 16, 32, 64]
+ rpn_label_assignment:
+ rpn_sample_batch: 256
+ fg_fraction: 0.5
+ negative_overlap: 0.3
+ positive_overlap: 0.7
+ use_random: True
+ train_proposal:
+ min_size: 0
+ nms_thresh: 0.7
+ pre_nms_top_n: 2000
+ post_nms_top_n: 1000
+ test_proposal:
+ min_size: 0
+ nms_thresh: 0.7
+ pre_nms_top_n: 1000
+ post_nms_top_n: 1000
+ head:
+ name: CascadeROIHeads
+ mask_on: True
+ in_features: ["p2", "p3", "p4", "p5"]
+ num_classes: 5
+ bbox_loss: None
+ add_gt_as_proposals: True
+ roi_extractor:
+ featmap_strides: [4, 8, 16, 32]
+ roi_box_head:
+ cls_agnostic_bbox_reg: True
+ name: FastRCNNConvFCHead
+ conv_dims: []
+ fc_dims: [1024, 1024]
+ pooler_resolution: 7
+ pooler_sampling_ratio: 0
+ pooler_type: ROIAlignV2
+ in_channel: 256
+ out_channel: 1024
+ roi_mask_head:
+ name: MaskRCNNConvUpsampleHead
+ conv_dims: [256, 256, 256, 256, 256]
+ pooler_resolution: 14
+ pooler_sampling_ratio: 0
+ pooler_type: ROIAlignV2
+ in_channel: 256
+ roi_box_cascade_head:
+ bbox_reg_weights: [[10.0, 10.0, 5.0, 5.0], [20.0, 20.0, 10.0, 10.0], [30.0, 30.0, 15.0, 15.0]]
+ ious: [0.5, 0.6, 0.7]
+ bbox_assigner:
+ name: BBoxAssigner
+ rois_per_batch: 512
+ bg_thresh: 0.5
+ fg_thresh: 0.5
+ fg_fraction: 0.25
+ pretrained:
+
+postprocess:
+ name: Layoutlmv3Postprocess
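+  # conf_thres filters out low-score boxes; iou_thres is the IoU threshold used for NMS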
+ conf_thres: 0.05
+ iou_thres: 0.5
+ conf_free: False
+ multi_label: True
+ time_limit: 100
+
+metric:
+ name: Layoutlmv3Metric
+ annotations_path: &annotations_path publaynet/val.json
+
+eval:
+ ckpt_load_path: "from_torch.ckpt"
+ dataset_sink_mode: False
+ dataset:
+ type: PublayNetDataset
+ dataset_path: publaynet/val.txt
+ annotations_path: *annotations_path
+ img_size: 800
+ model_name: layoutlmv3
+ transform_pipeline:
+ - func_name: letterbox
+ - func_name: label_norm
+ xyxy2xywh_: True
+ - func_name: label_pad
+ padding_size: 160
+ padding_value: -1
+ - func_name: image_normal
+ mean: [ 127.5, 127.5, 127.5 ]
+ std: [ 127.5, 127.5, 127.5 ]
+ - func_name: image_transpose
+ bgr2rgb: True
+ hwc2chw: True
+ - func_name: image_batch_pad
+ max_size: 1333
+ batch_size: &refine_batch_size 1
+ stride: 64
+ output_columns: ["image", "labels", "image_ids", "hw_ori", "hw_scale", "pad"]
+ net_input_column_index: [0, 3, 4] # input indices for network forward func in output_columns
+ meta_data_column_index: [2, 3, 4, 5] # input indices marked as label
+ loader:
+ shuffle: False
+ batch_size: *refine_batch_size
+ drop_remainder: False
+ max_rowsize: 12
+ num_workers: 1
+
+predict:
+ ckpt_load_path: "from_torch.ckpt"
+ category_dict: {1: 'text', 2: 'title', 3: 'list', 4: 'table', 5: 'figure'}
+ color_dict: {1: [255, 0, 0], 2: [0, 0, 255], 3: [0, 255, 0], 4: [0, 255, 255], 5: [255, 0, 255]}
diff --git a/configs/layout/layoutlmv3/layoutlmv3_publaynet_param_map.json b/configs/layout/layoutlmv3/layoutlmv3_publaynet_param_map.json
new file mode 100644
index 000000000..5a72fbcea
--- /dev/null
+++ b/configs/layout/layoutlmv3/layoutlmv3_publaynet_param_map.json
@@ -0,0 +1,274 @@
+{
+ "convert_map": {
+ "backbone.fpn_lateral2.weight": "backbone.fpn_lateral2.weight",
+ "backbone.fpn_lateral2.bias": "backbone.fpn_lateral2.bias",
+ "backbone.fpn_output2.weight": "backbone.fpn_output2.weight",
+ "backbone.fpn_output2.bias": "backbone.fpn_output2.bias",
+ "backbone.fpn_lateral3.weight": "backbone.fpn_lateral3.weight",
+ "backbone.fpn_lateral3.bias": "backbone.fpn_lateral3.bias",
+ "backbone.fpn_output3.weight": "backbone.fpn_output3.weight",
+ "backbone.fpn_output3.bias": "backbone.fpn_output3.bias",
+ "backbone.fpn_lateral4.weight": "backbone.fpn_lateral4.weight",
+ "backbone.fpn_lateral4.bias": "backbone.fpn_lateral4.bias",
+ "backbone.fpn_output4.weight": "backbone.fpn_output4.weight",
+ "backbone.fpn_output4.bias": "backbone.fpn_output4.bias",
+ "backbone.fpn_lateral5.weight": "backbone.fpn_lateral5.weight",
+ "backbone.fpn_lateral5.bias": "backbone.fpn_lateral5.bias",
+ "backbone.fpn_output5.weight": "backbone.fpn_output5.weight",
+ "backbone.fpn_output5.bias": "backbone.fpn_output5.bias",
+ "backbone.bottom_up.backbone.cls_token": "backbone.bottom_up.cls_token",
+ "backbone.bottom_up.backbone.pos_embed": "backbone.bottom_up.pos_embed",
+ "backbone.bottom_up.backbone.encoder.layer.0.attention.self.query.weight": "backbone.bottom_up.encoder.layer.0.attention.self_attention.query.weight",
+ "backbone.bottom_up.backbone.encoder.layer.0.attention.self.query.bias": "backbone.bottom_up.encoder.layer.0.attention.self_attention.query.bias",
+ "backbone.bottom_up.backbone.encoder.layer.0.attention.self.key.weight": "backbone.bottom_up.encoder.layer.0.attention.self_attention.key.weight",
+ "backbone.bottom_up.backbone.encoder.layer.0.attention.self.key.bias": "backbone.bottom_up.encoder.layer.0.attention.self_attention.key.bias",
+ "backbone.bottom_up.backbone.encoder.layer.0.attention.self.value.weight": "backbone.bottom_up.encoder.layer.0.attention.self_attention.value.weight",
+ "backbone.bottom_up.backbone.encoder.layer.0.attention.self.value.bias": "backbone.bottom_up.encoder.layer.0.attention.self_attention.value.bias",
+ "backbone.bottom_up.backbone.encoder.layer.0.attention.output.dense.weight": "backbone.bottom_up.encoder.layer.0.attention.output.dense.weight",
+ "backbone.bottom_up.backbone.encoder.layer.0.attention.output.dense.bias": "backbone.bottom_up.encoder.layer.0.attention.output.dense.bias",
+ "backbone.bottom_up.backbone.encoder.layer.0.attention.output.LayerNorm.weight": "backbone.bottom_up.encoder.layer.0.attention.output.LayerNorm.gamma",
+ "backbone.bottom_up.backbone.encoder.layer.0.attention.output.LayerNorm.bias": "backbone.bottom_up.encoder.layer.0.attention.output.LayerNorm.beta",
+ "backbone.bottom_up.backbone.encoder.layer.0.intermediate.dense.weight": "backbone.bottom_up.encoder.layer.0.intermediate.dense.weight",
+ "backbone.bottom_up.backbone.encoder.layer.0.intermediate.dense.bias": "backbone.bottom_up.encoder.layer.0.intermediate.dense.bias",
+ "backbone.bottom_up.backbone.encoder.layer.0.output.dense.weight": "backbone.bottom_up.encoder.layer.0.output.dense.weight",
+ "backbone.bottom_up.backbone.encoder.layer.0.output.dense.bias": "backbone.bottom_up.encoder.layer.0.output.dense.bias",
+ "backbone.bottom_up.backbone.encoder.layer.0.output.LayerNorm.weight": "backbone.bottom_up.encoder.layer.0.output.LayerNorm.gamma",
+ "backbone.bottom_up.backbone.encoder.layer.0.output.LayerNorm.bias": "backbone.bottom_up.encoder.layer.0.output.LayerNorm.beta",
+ "backbone.bottom_up.backbone.encoder.layer.1.attention.self.query.weight": "backbone.bottom_up.encoder.layer.1.attention.self_attention.query.weight",
+ "backbone.bottom_up.backbone.encoder.layer.1.attention.self.query.bias": "backbone.bottom_up.encoder.layer.1.attention.self_attention.query.bias",
+ "backbone.bottom_up.backbone.encoder.layer.1.attention.self.key.weight": "backbone.bottom_up.encoder.layer.1.attention.self_attention.key.weight",
+ "backbone.bottom_up.backbone.encoder.layer.1.attention.self.key.bias": "backbone.bottom_up.encoder.layer.1.attention.self_attention.key.bias",
+ "backbone.bottom_up.backbone.encoder.layer.1.attention.self.value.weight": "backbone.bottom_up.encoder.layer.1.attention.self_attention.value.weight",
+ "backbone.bottom_up.backbone.encoder.layer.1.attention.self.value.bias": "backbone.bottom_up.encoder.layer.1.attention.self_attention.value.bias",
+ "backbone.bottom_up.backbone.encoder.layer.1.attention.output.dense.weight": "backbone.bottom_up.encoder.layer.1.attention.output.dense.weight",
+ "backbone.bottom_up.backbone.encoder.layer.1.attention.output.dense.bias": "backbone.bottom_up.encoder.layer.1.attention.output.dense.bias",
+ "backbone.bottom_up.backbone.encoder.layer.1.attention.output.LayerNorm.weight": "backbone.bottom_up.encoder.layer.1.attention.output.LayerNorm.gamma",
+ "backbone.bottom_up.backbone.encoder.layer.1.attention.output.LayerNorm.bias": "backbone.bottom_up.encoder.layer.1.attention.output.LayerNorm.beta",
+ "backbone.bottom_up.backbone.encoder.layer.1.intermediate.dense.weight": "backbone.bottom_up.encoder.layer.1.intermediate.dense.weight",
+ "backbone.bottom_up.backbone.encoder.layer.1.intermediate.dense.bias": "backbone.bottom_up.encoder.layer.1.intermediate.dense.bias",
+ "backbone.bottom_up.backbone.encoder.layer.1.output.dense.weight": "backbone.bottom_up.encoder.layer.1.output.dense.weight",
+ "backbone.bottom_up.backbone.encoder.layer.1.output.dense.bias": "backbone.bottom_up.encoder.layer.1.output.dense.bias",
+ "backbone.bottom_up.backbone.encoder.layer.1.output.LayerNorm.weight": "backbone.bottom_up.encoder.layer.1.output.LayerNorm.gamma",
+ "backbone.bottom_up.backbone.encoder.layer.1.output.LayerNorm.bias": "backbone.bottom_up.encoder.layer.1.output.LayerNorm.beta",
+ "backbone.bottom_up.backbone.encoder.layer.2.attention.self.query.weight": "backbone.bottom_up.encoder.layer.2.attention.self_attention.query.weight",
+ "backbone.bottom_up.backbone.encoder.layer.2.attention.self.query.bias": "backbone.bottom_up.encoder.layer.2.attention.self_attention.query.bias",
+ "backbone.bottom_up.backbone.encoder.layer.2.attention.self.key.weight": "backbone.bottom_up.encoder.layer.2.attention.self_attention.key.weight",
+ "backbone.bottom_up.backbone.encoder.layer.2.attention.self.key.bias": "backbone.bottom_up.encoder.layer.2.attention.self_attention.key.bias",
+ "backbone.bottom_up.backbone.encoder.layer.2.attention.self.value.weight": "backbone.bottom_up.encoder.layer.2.attention.self_attention.value.weight",
+ "backbone.bottom_up.backbone.encoder.layer.2.attention.self.value.bias": "backbone.bottom_up.encoder.layer.2.attention.self_attention.value.bias",
+ "backbone.bottom_up.backbone.encoder.layer.2.attention.output.dense.weight": "backbone.bottom_up.encoder.layer.2.attention.output.dense.weight",
+ "backbone.bottom_up.backbone.encoder.layer.2.attention.output.dense.bias": "backbone.bottom_up.encoder.layer.2.attention.output.dense.bias",
+ "backbone.bottom_up.backbone.encoder.layer.2.attention.output.LayerNorm.weight": "backbone.bottom_up.encoder.layer.2.attention.output.LayerNorm.gamma",
+ "backbone.bottom_up.backbone.encoder.layer.2.attention.output.LayerNorm.bias": "backbone.bottom_up.encoder.layer.2.attention.output.LayerNorm.beta",
+ "backbone.bottom_up.backbone.encoder.layer.2.intermediate.dense.weight": "backbone.bottom_up.encoder.layer.2.intermediate.dense.weight",
+ "backbone.bottom_up.backbone.encoder.layer.2.intermediate.dense.bias": "backbone.bottom_up.encoder.layer.2.intermediate.dense.bias",
+ "backbone.bottom_up.backbone.encoder.layer.2.output.dense.weight": "backbone.bottom_up.encoder.layer.2.output.dense.weight",
+ "backbone.bottom_up.backbone.encoder.layer.2.output.dense.bias": "backbone.bottom_up.encoder.layer.2.output.dense.bias",
+ "backbone.bottom_up.backbone.encoder.layer.2.output.LayerNorm.weight": "backbone.bottom_up.encoder.layer.2.output.LayerNorm.gamma",
+ "backbone.bottom_up.backbone.encoder.layer.2.output.LayerNorm.bias": "backbone.bottom_up.encoder.layer.2.output.LayerNorm.beta",
+ "backbone.bottom_up.backbone.encoder.layer.3.attention.self.query.weight": "backbone.bottom_up.encoder.layer.3.attention.self_attention.query.weight",
+ "backbone.bottom_up.backbone.encoder.layer.3.attention.self.query.bias": "backbone.bottom_up.encoder.layer.3.attention.self_attention.query.bias",
+ "backbone.bottom_up.backbone.encoder.layer.3.attention.self.key.weight": "backbone.bottom_up.encoder.layer.3.attention.self_attention.key.weight",
+ "backbone.bottom_up.backbone.encoder.layer.3.attention.self.key.bias": "backbone.bottom_up.encoder.layer.3.attention.self_attention.key.bias",
+ "backbone.bottom_up.backbone.encoder.layer.3.attention.self.value.weight": "backbone.bottom_up.encoder.layer.3.attention.self_attention.value.weight",
+ "backbone.bottom_up.backbone.encoder.layer.3.attention.self.value.bias": "backbone.bottom_up.encoder.layer.3.attention.self_attention.value.bias",
+ "backbone.bottom_up.backbone.encoder.layer.3.attention.output.dense.weight": "backbone.bottom_up.encoder.layer.3.attention.output.dense.weight",
+ "backbone.bottom_up.backbone.encoder.layer.3.attention.output.dense.bias": "backbone.bottom_up.encoder.layer.3.attention.output.dense.bias",
+ "backbone.bottom_up.backbone.encoder.layer.3.attention.output.LayerNorm.weight": "backbone.bottom_up.encoder.layer.3.attention.output.LayerNorm.gamma",
+ "backbone.bottom_up.backbone.encoder.layer.3.attention.output.LayerNorm.bias": "backbone.bottom_up.encoder.layer.3.attention.output.LayerNorm.beta",
+ "backbone.bottom_up.backbone.encoder.layer.3.intermediate.dense.weight": "backbone.bottom_up.encoder.layer.3.intermediate.dense.weight",
+ "backbone.bottom_up.backbone.encoder.layer.3.intermediate.dense.bias": "backbone.bottom_up.encoder.layer.3.intermediate.dense.bias",
+ "backbone.bottom_up.backbone.encoder.layer.3.output.dense.weight": "backbone.bottom_up.encoder.layer.3.output.dense.weight",
+ "backbone.bottom_up.backbone.encoder.layer.3.output.dense.bias": "backbone.bottom_up.encoder.layer.3.output.dense.bias",
+ "backbone.bottom_up.backbone.encoder.layer.3.output.LayerNorm.weight": "backbone.bottom_up.encoder.layer.3.output.LayerNorm.gamma",
+ "backbone.bottom_up.backbone.encoder.layer.3.output.LayerNorm.bias": "backbone.bottom_up.encoder.layer.3.output.LayerNorm.beta",
+ "backbone.bottom_up.backbone.encoder.layer.4.attention.self.query.weight": "backbone.bottom_up.encoder.layer.4.attention.self_attention.query.weight",
+ "backbone.bottom_up.backbone.encoder.layer.4.attention.self.query.bias": "backbone.bottom_up.encoder.layer.4.attention.self_attention.query.bias",
+ "backbone.bottom_up.backbone.encoder.layer.4.attention.self.key.weight": "backbone.bottom_up.encoder.layer.4.attention.self_attention.key.weight",
+ "backbone.bottom_up.backbone.encoder.layer.4.attention.self.key.bias": "backbone.bottom_up.encoder.layer.4.attention.self_attention.key.bias",
+ "backbone.bottom_up.backbone.encoder.layer.4.attention.self.value.weight": "backbone.bottom_up.encoder.layer.4.attention.self_attention.value.weight",
+ "backbone.bottom_up.backbone.encoder.layer.4.attention.self.value.bias": "backbone.bottom_up.encoder.layer.4.attention.self_attention.value.bias",
+ "backbone.bottom_up.backbone.encoder.layer.4.attention.output.dense.weight": "backbone.bottom_up.encoder.layer.4.attention.output.dense.weight",
+ "backbone.bottom_up.backbone.encoder.layer.4.attention.output.dense.bias": "backbone.bottom_up.encoder.layer.4.attention.output.dense.bias",
+ "backbone.bottom_up.backbone.encoder.layer.4.attention.output.LayerNorm.weight": "backbone.bottom_up.encoder.layer.4.attention.output.LayerNorm.gamma",
+ "backbone.bottom_up.backbone.encoder.layer.4.attention.output.LayerNorm.bias": "backbone.bottom_up.encoder.layer.4.attention.output.LayerNorm.beta",
+ "backbone.bottom_up.backbone.encoder.layer.4.intermediate.dense.weight": "backbone.bottom_up.encoder.layer.4.intermediate.dense.weight",
+ "backbone.bottom_up.backbone.encoder.layer.4.intermediate.dense.bias": "backbone.bottom_up.encoder.layer.4.intermediate.dense.bias",
+ "backbone.bottom_up.backbone.encoder.layer.4.output.dense.weight": "backbone.bottom_up.encoder.layer.4.output.dense.weight",
+ "backbone.bottom_up.backbone.encoder.layer.4.output.dense.bias": "backbone.bottom_up.encoder.layer.4.output.dense.bias",
+ "backbone.bottom_up.backbone.encoder.layer.4.output.LayerNorm.weight": "backbone.bottom_up.encoder.layer.4.output.LayerNorm.gamma",
+ "backbone.bottom_up.backbone.encoder.layer.4.output.LayerNorm.bias": "backbone.bottom_up.encoder.layer.4.output.LayerNorm.beta",
+ "backbone.bottom_up.backbone.encoder.layer.5.attention.self.query.weight": "backbone.bottom_up.encoder.layer.5.attention.self_attention.query.weight",
+ "backbone.bottom_up.backbone.encoder.layer.5.attention.self.query.bias": "backbone.bottom_up.encoder.layer.5.attention.self_attention.query.bias",
+ "backbone.bottom_up.backbone.encoder.layer.5.attention.self.key.weight": "backbone.bottom_up.encoder.layer.5.attention.self_attention.key.weight",
+ "backbone.bottom_up.backbone.encoder.layer.5.attention.self.key.bias": "backbone.bottom_up.encoder.layer.5.attention.self_attention.key.bias",
+ "backbone.bottom_up.backbone.encoder.layer.5.attention.self.value.weight": "backbone.bottom_up.encoder.layer.5.attention.self_attention.value.weight",
+ "backbone.bottom_up.backbone.encoder.layer.5.attention.self.value.bias": "backbone.bottom_up.encoder.layer.5.attention.self_attention.value.bias",
+ "backbone.bottom_up.backbone.encoder.layer.5.attention.output.dense.weight": "backbone.bottom_up.encoder.layer.5.attention.output.dense.weight",
+ "backbone.bottom_up.backbone.encoder.layer.5.attention.output.dense.bias": "backbone.bottom_up.encoder.layer.5.attention.output.dense.bias",
+ "backbone.bottom_up.backbone.encoder.layer.5.attention.output.LayerNorm.weight": "backbone.bottom_up.encoder.layer.5.attention.output.LayerNorm.gamma",
+ "backbone.bottom_up.backbone.encoder.layer.5.attention.output.LayerNorm.bias": "backbone.bottom_up.encoder.layer.5.attention.output.LayerNorm.beta",
+ "backbone.bottom_up.backbone.encoder.layer.5.intermediate.dense.weight": "backbone.bottom_up.encoder.layer.5.intermediate.dense.weight",
+ "backbone.bottom_up.backbone.encoder.layer.5.intermediate.dense.bias": "backbone.bottom_up.encoder.layer.5.intermediate.dense.bias",
+ "backbone.bottom_up.backbone.encoder.layer.5.output.dense.weight": "backbone.bottom_up.encoder.layer.5.output.dense.weight",
+ "backbone.bottom_up.backbone.encoder.layer.5.output.dense.bias": "backbone.bottom_up.encoder.layer.5.output.dense.bias",
+ "backbone.bottom_up.backbone.encoder.layer.5.output.LayerNorm.weight": "backbone.bottom_up.encoder.layer.5.output.LayerNorm.gamma",
+ "backbone.bottom_up.backbone.encoder.layer.5.output.LayerNorm.bias": "backbone.bottom_up.encoder.layer.5.output.LayerNorm.beta",
+ "backbone.bottom_up.backbone.encoder.layer.6.attention.self.query.weight": "backbone.bottom_up.encoder.layer.6.attention.self_attention.query.weight",
+ "backbone.bottom_up.backbone.encoder.layer.6.attention.self.query.bias": "backbone.bottom_up.encoder.layer.6.attention.self_attention.query.bias",
+ "backbone.bottom_up.backbone.encoder.layer.6.attention.self.key.weight": "backbone.bottom_up.encoder.layer.6.attention.self_attention.key.weight",
+ "backbone.bottom_up.backbone.encoder.layer.6.attention.self.key.bias": "backbone.bottom_up.encoder.layer.6.attention.self_attention.key.bias",
+ "backbone.bottom_up.backbone.encoder.layer.6.attention.self.value.weight": "backbone.bottom_up.encoder.layer.6.attention.self_attention.value.weight",
+ "backbone.bottom_up.backbone.encoder.layer.6.attention.self.value.bias": "backbone.bottom_up.encoder.layer.6.attention.self_attention.value.bias",
+ "backbone.bottom_up.backbone.encoder.layer.6.attention.output.dense.weight": "backbone.bottom_up.encoder.layer.6.attention.output.dense.weight",
+ "backbone.bottom_up.backbone.encoder.layer.6.attention.output.dense.bias": "backbone.bottom_up.encoder.layer.6.attention.output.dense.bias",
+ "backbone.bottom_up.backbone.encoder.layer.6.attention.output.LayerNorm.weight": "backbone.bottom_up.encoder.layer.6.attention.output.LayerNorm.gamma",
+ "backbone.bottom_up.backbone.encoder.layer.6.attention.output.LayerNorm.bias": "backbone.bottom_up.encoder.layer.6.attention.output.LayerNorm.beta",
+ "backbone.bottom_up.backbone.encoder.layer.6.intermediate.dense.weight": "backbone.bottom_up.encoder.layer.6.intermediate.dense.weight",
+ "backbone.bottom_up.backbone.encoder.layer.6.intermediate.dense.bias": "backbone.bottom_up.encoder.layer.6.intermediate.dense.bias",
+ "backbone.bottom_up.backbone.encoder.layer.6.output.dense.weight": "backbone.bottom_up.encoder.layer.6.output.dense.weight",
+ "backbone.bottom_up.backbone.encoder.layer.6.output.dense.bias": "backbone.bottom_up.encoder.layer.6.output.dense.bias",
+ "backbone.bottom_up.backbone.encoder.layer.6.output.LayerNorm.weight": "backbone.bottom_up.encoder.layer.6.output.LayerNorm.gamma",
+ "backbone.bottom_up.backbone.encoder.layer.6.output.LayerNorm.bias": "backbone.bottom_up.encoder.layer.6.output.LayerNorm.beta",
+ "backbone.bottom_up.backbone.encoder.layer.7.attention.self.query.weight": "backbone.bottom_up.encoder.layer.7.attention.self_attention.query.weight",
+ "backbone.bottom_up.backbone.encoder.layer.7.attention.self.query.bias": "backbone.bottom_up.encoder.layer.7.attention.self_attention.query.bias",
+ "backbone.bottom_up.backbone.encoder.layer.7.attention.self.key.weight": "backbone.bottom_up.encoder.layer.7.attention.self_attention.key.weight",
+ "backbone.bottom_up.backbone.encoder.layer.7.attention.self.key.bias": "backbone.bottom_up.encoder.layer.7.attention.self_attention.key.bias",
+ "backbone.bottom_up.backbone.encoder.layer.7.attention.self.value.weight": "backbone.bottom_up.encoder.layer.7.attention.self_attention.value.weight",
+ "backbone.bottom_up.backbone.encoder.layer.7.attention.self.value.bias": "backbone.bottom_up.encoder.layer.7.attention.self_attention.value.bias",
+ "backbone.bottom_up.backbone.encoder.layer.7.attention.output.dense.weight": "backbone.bottom_up.encoder.layer.7.attention.output.dense.weight",
+ "backbone.bottom_up.backbone.encoder.layer.7.attention.output.dense.bias": "backbone.bottom_up.encoder.layer.7.attention.output.dense.bias",
+ "backbone.bottom_up.backbone.encoder.layer.7.attention.output.LayerNorm.weight": "backbone.bottom_up.encoder.layer.7.attention.output.LayerNorm.gamma",
+ "backbone.bottom_up.backbone.encoder.layer.7.attention.output.LayerNorm.bias": "backbone.bottom_up.encoder.layer.7.attention.output.LayerNorm.beta",
+ "backbone.bottom_up.backbone.encoder.layer.7.intermediate.dense.weight": "backbone.bottom_up.encoder.layer.7.intermediate.dense.weight",
+ "backbone.bottom_up.backbone.encoder.layer.7.intermediate.dense.bias": "backbone.bottom_up.encoder.layer.7.intermediate.dense.bias",
+ "backbone.bottom_up.backbone.encoder.layer.7.output.dense.weight": "backbone.bottom_up.encoder.layer.7.output.dense.weight",
+ "backbone.bottom_up.backbone.encoder.layer.7.output.dense.bias": "backbone.bottom_up.encoder.layer.7.output.dense.bias",
+ "backbone.bottom_up.backbone.encoder.layer.7.output.LayerNorm.weight": "backbone.bottom_up.encoder.layer.7.output.LayerNorm.gamma",
+ "backbone.bottom_up.backbone.encoder.layer.7.output.LayerNorm.bias": "backbone.bottom_up.encoder.layer.7.output.LayerNorm.beta",
+ "backbone.bottom_up.backbone.encoder.layer.8.attention.self.query.weight": "backbone.bottom_up.encoder.layer.8.attention.self_attention.query.weight",
+ "backbone.bottom_up.backbone.encoder.layer.8.attention.self.query.bias": "backbone.bottom_up.encoder.layer.8.attention.self_attention.query.bias",
+ "backbone.bottom_up.backbone.encoder.layer.8.attention.self.key.weight": "backbone.bottom_up.encoder.layer.8.attention.self_attention.key.weight",
+ "backbone.bottom_up.backbone.encoder.layer.8.attention.self.key.bias": "backbone.bottom_up.encoder.layer.8.attention.self_attention.key.bias",
+ "backbone.bottom_up.backbone.encoder.layer.8.attention.self.value.weight": "backbone.bottom_up.encoder.layer.8.attention.self_attention.value.weight",
+ "backbone.bottom_up.backbone.encoder.layer.8.attention.self.value.bias": "backbone.bottom_up.encoder.layer.8.attention.self_attention.value.bias",
+ "backbone.bottom_up.backbone.encoder.layer.8.attention.output.dense.weight": "backbone.bottom_up.encoder.layer.8.attention.output.dense.weight",
+ "backbone.bottom_up.backbone.encoder.layer.8.attention.output.dense.bias": "backbone.bottom_up.encoder.layer.8.attention.output.dense.bias",
+ "backbone.bottom_up.backbone.encoder.layer.8.attention.output.LayerNorm.weight": "backbone.bottom_up.encoder.layer.8.attention.output.LayerNorm.gamma",
+ "backbone.bottom_up.backbone.encoder.layer.8.attention.output.LayerNorm.bias": "backbone.bottom_up.encoder.layer.8.attention.output.LayerNorm.beta",
+ "backbone.bottom_up.backbone.encoder.layer.8.intermediate.dense.weight": "backbone.bottom_up.encoder.layer.8.intermediate.dense.weight",
+ "backbone.bottom_up.backbone.encoder.layer.8.intermediate.dense.bias": "backbone.bottom_up.encoder.layer.8.intermediate.dense.bias",
+ "backbone.bottom_up.backbone.encoder.layer.8.output.dense.weight": "backbone.bottom_up.encoder.layer.8.output.dense.weight",
+ "backbone.bottom_up.backbone.encoder.layer.8.output.dense.bias": "backbone.bottom_up.encoder.layer.8.output.dense.bias",
+ "backbone.bottom_up.backbone.encoder.layer.8.output.LayerNorm.weight": "backbone.bottom_up.encoder.layer.8.output.LayerNorm.gamma",
+ "backbone.bottom_up.backbone.encoder.layer.8.output.LayerNorm.bias": "backbone.bottom_up.encoder.layer.8.output.LayerNorm.beta",
+ "backbone.bottom_up.backbone.encoder.layer.9.attention.self.query.weight": "backbone.bottom_up.encoder.layer.9.attention.self_attention.query.weight",
+ "backbone.bottom_up.backbone.encoder.layer.9.attention.self.query.bias": "backbone.bottom_up.encoder.layer.9.attention.self_attention.query.bias",
+ "backbone.bottom_up.backbone.encoder.layer.9.attention.self.key.weight": "backbone.bottom_up.encoder.layer.9.attention.self_attention.key.weight",
+ "backbone.bottom_up.backbone.encoder.layer.9.attention.self.key.bias": "backbone.bottom_up.encoder.layer.9.attention.self_attention.key.bias",
+ "backbone.bottom_up.backbone.encoder.layer.9.attention.self.value.weight": "backbone.bottom_up.encoder.layer.9.attention.self_attention.value.weight",
+ "backbone.bottom_up.backbone.encoder.layer.9.attention.self.value.bias": "backbone.bottom_up.encoder.layer.9.attention.self_attention.value.bias",
+ "backbone.bottom_up.backbone.encoder.layer.9.attention.output.dense.weight": "backbone.bottom_up.encoder.layer.9.attention.output.dense.weight",
+ "backbone.bottom_up.backbone.encoder.layer.9.attention.output.dense.bias": "backbone.bottom_up.encoder.layer.9.attention.output.dense.bias",
+ "backbone.bottom_up.backbone.encoder.layer.9.attention.output.LayerNorm.weight": "backbone.bottom_up.encoder.layer.9.attention.output.LayerNorm.gamma",
+ "backbone.bottom_up.backbone.encoder.layer.9.attention.output.LayerNorm.bias": "backbone.bottom_up.encoder.layer.9.attention.output.LayerNorm.beta",
+ "backbone.bottom_up.backbone.encoder.layer.9.intermediate.dense.weight": "backbone.bottom_up.encoder.layer.9.intermediate.dense.weight",
+ "backbone.bottom_up.backbone.encoder.layer.9.intermediate.dense.bias": "backbone.bottom_up.encoder.layer.9.intermediate.dense.bias",
+ "backbone.bottom_up.backbone.encoder.layer.9.output.dense.weight": "backbone.bottom_up.encoder.layer.9.output.dense.weight",
+ "backbone.bottom_up.backbone.encoder.layer.9.output.dense.bias": "backbone.bottom_up.encoder.layer.9.output.dense.bias",
+ "backbone.bottom_up.backbone.encoder.layer.9.output.LayerNorm.weight": "backbone.bottom_up.encoder.layer.9.output.LayerNorm.gamma",
+ "backbone.bottom_up.backbone.encoder.layer.9.output.LayerNorm.bias": "backbone.bottom_up.encoder.layer.9.output.LayerNorm.beta",
+ "backbone.bottom_up.backbone.encoder.layer.10.attention.self.query.weight": "backbone.bottom_up.encoder.layer.10.attention.self_attention.query.weight",
+ "backbone.bottom_up.backbone.encoder.layer.10.attention.self.query.bias": "backbone.bottom_up.encoder.layer.10.attention.self_attention.query.bias",
+ "backbone.bottom_up.backbone.encoder.layer.10.attention.self.key.weight": "backbone.bottom_up.encoder.layer.10.attention.self_attention.key.weight",
+ "backbone.bottom_up.backbone.encoder.layer.10.attention.self.key.bias": "backbone.bottom_up.encoder.layer.10.attention.self_attention.key.bias",
+ "backbone.bottom_up.backbone.encoder.layer.10.attention.self.value.weight": "backbone.bottom_up.encoder.layer.10.attention.self_attention.value.weight",
+ "backbone.bottom_up.backbone.encoder.layer.10.attention.self.value.bias": "backbone.bottom_up.encoder.layer.10.attention.self_attention.value.bias",
+ "backbone.bottom_up.backbone.encoder.layer.10.attention.output.dense.weight": "backbone.bottom_up.encoder.layer.10.attention.output.dense.weight",
+ "backbone.bottom_up.backbone.encoder.layer.10.attention.output.dense.bias": "backbone.bottom_up.encoder.layer.10.attention.output.dense.bias",
+ "backbone.bottom_up.backbone.encoder.layer.10.attention.output.LayerNorm.weight": "backbone.bottom_up.encoder.layer.10.attention.output.LayerNorm.gamma",
+ "backbone.bottom_up.backbone.encoder.layer.10.attention.output.LayerNorm.bias": "backbone.bottom_up.encoder.layer.10.attention.output.LayerNorm.beta",
+ "backbone.bottom_up.backbone.encoder.layer.10.intermediate.dense.weight": "backbone.bottom_up.encoder.layer.10.intermediate.dense.weight",
+ "backbone.bottom_up.backbone.encoder.layer.10.intermediate.dense.bias": "backbone.bottom_up.encoder.layer.10.intermediate.dense.bias",
+ "backbone.bottom_up.backbone.encoder.layer.10.output.dense.weight": "backbone.bottom_up.encoder.layer.10.output.dense.weight",
+ "backbone.bottom_up.backbone.encoder.layer.10.output.dense.bias": "backbone.bottom_up.encoder.layer.10.output.dense.bias",
+ "backbone.bottom_up.backbone.encoder.layer.10.output.LayerNorm.weight": "backbone.bottom_up.encoder.layer.10.output.LayerNorm.gamma",
+ "backbone.bottom_up.backbone.encoder.layer.10.output.LayerNorm.bias": "backbone.bottom_up.encoder.layer.10.output.LayerNorm.beta",
+ "backbone.bottom_up.backbone.encoder.layer.11.attention.self.query.weight": "backbone.bottom_up.encoder.layer.11.attention.self_attention.query.weight",
+ "backbone.bottom_up.backbone.encoder.layer.11.attention.self.query.bias": "backbone.bottom_up.encoder.layer.11.attention.self_attention.query.bias",
+ "backbone.bottom_up.backbone.encoder.layer.11.attention.self.key.weight": "backbone.bottom_up.encoder.layer.11.attention.self_attention.key.weight",
+ "backbone.bottom_up.backbone.encoder.layer.11.attention.self.key.bias": "backbone.bottom_up.encoder.layer.11.attention.self_attention.key.bias",
+ "backbone.bottom_up.backbone.encoder.layer.11.attention.self.value.weight": "backbone.bottom_up.encoder.layer.11.attention.self_attention.value.weight",
+ "backbone.bottom_up.backbone.encoder.layer.11.attention.self.value.bias": "backbone.bottom_up.encoder.layer.11.attention.self_attention.value.bias",
+ "backbone.bottom_up.backbone.encoder.layer.11.attention.output.dense.weight": "backbone.bottom_up.encoder.layer.11.attention.output.dense.weight",
+ "backbone.bottom_up.backbone.encoder.layer.11.attention.output.dense.bias": "backbone.bottom_up.encoder.layer.11.attention.output.dense.bias",
+ "backbone.bottom_up.backbone.encoder.layer.11.attention.output.LayerNorm.weight": "backbone.bottom_up.encoder.layer.11.attention.output.LayerNorm.gamma",
+ "backbone.bottom_up.backbone.encoder.layer.11.attention.output.LayerNorm.bias": "backbone.bottom_up.encoder.layer.11.attention.output.LayerNorm.beta",
+ "backbone.bottom_up.backbone.encoder.layer.11.intermediate.dense.weight": "backbone.bottom_up.encoder.layer.11.intermediate.dense.weight",
+ "backbone.bottom_up.backbone.encoder.layer.11.intermediate.dense.bias": "backbone.bottom_up.encoder.layer.11.intermediate.dense.bias",
+ "backbone.bottom_up.backbone.encoder.layer.11.output.dense.weight": "backbone.bottom_up.encoder.layer.11.output.dense.weight",
+ "backbone.bottom_up.backbone.encoder.layer.11.output.dense.bias": "backbone.bottom_up.encoder.layer.11.output.dense.bias",
+ "backbone.bottom_up.backbone.encoder.layer.11.output.LayerNorm.weight": "backbone.bottom_up.encoder.layer.11.output.LayerNorm.gamma",
+ "backbone.bottom_up.backbone.encoder.layer.11.output.LayerNorm.bias": "backbone.bottom_up.encoder.layer.11.output.LayerNorm.beta",
+ "backbone.bottom_up.backbone.encoder.fpn1.0.weight": "backbone.bottom_up.encoder.fpn1.0.weight",
+ "backbone.bottom_up.backbone.encoder.fpn1.0.bias": "backbone.bottom_up.encoder.fpn1.0.bias",
+ "backbone.bottom_up.backbone.encoder.fpn1.1.weight": "backbone.bottom_up.encoder.fpn1.1.gamma",
+ "backbone.bottom_up.backbone.encoder.fpn1.1.bias": "backbone.bottom_up.encoder.fpn1.1.beta",
+ "backbone.bottom_up.backbone.encoder.fpn1.1.running_mean": "backbone.bottom_up.encoder.fpn1.1.moving_mean",
+ "backbone.bottom_up.backbone.encoder.fpn1.1.running_var": "backbone.bottom_up.encoder.fpn1.1.moving_variance",
+ "backbone.bottom_up.backbone.encoder.fpn1.3.weight": "backbone.bottom_up.encoder.fpn1.3.weight",
+ "backbone.bottom_up.backbone.encoder.fpn1.3.bias": "backbone.bottom_up.encoder.fpn1.3.bias",
+ "backbone.bottom_up.backbone.encoder.fpn2.0.weight": "backbone.bottom_up.encoder.fpn2.0.weight",
+ "backbone.bottom_up.backbone.encoder.fpn2.0.bias": "backbone.bottom_up.encoder.fpn2.0.bias",
+ "backbone.bottom_up.backbone.patch_embed.proj.weight": "backbone.bottom_up.patch_embed.proj.weight",
+ "backbone.bottom_up.backbone.patch_embed.proj.bias": "backbone.bottom_up.patch_embed.proj.bias",
+ "backbone.bottom_up.backbone.LayerNorm.weight": "backbone.bottom_up.LayerNorm.gamma",
+ "backbone.bottom_up.backbone.LayerNorm.bias": "backbone.bottom_up.LayerNorm.beta",
+ "backbone.bottom_up.backbone.norm.weight": "backbone.bottom_up.norm.gamma",
+ "backbone.bottom_up.backbone.norm.bias": "backbone.bottom_up.norm.beta",
+ "proposal_generator.rpn_head.conv.weight": "neck.rpn_feat.rpn_conv.weight",
+ "proposal_generator.rpn_head.conv.bias": "neck.rpn_feat.rpn_conv.bias",
+ "proposal_generator.rpn_head.objectness_logits.weight": "neck.rpn_feat.rpn_rois_score.weight",
+ "proposal_generator.rpn_head.objectness_logits.bias": "neck.rpn_feat.rpn_rois_score.bias",
+ "proposal_generator.rpn_head.anchor_deltas.weight": "neck.rpn_feat.rpn_rois_delta.weight",
+ "proposal_generator.rpn_head.anchor_deltas.bias": "neck.rpn_feat.rpn_rois_delta.bias",
+ "roi_heads.box_head.0.fc1.weight": "head.box_head.0.fc1.weight",
+ "roi_heads.box_head.0.fc1.bias": "head.box_head.0.fc1.bias",
+ "roi_heads.box_head.0.fc2.weight": "head.box_head.0.fc2.weight",
+ "roi_heads.box_head.0.fc2.bias": "head.box_head.0.fc2.bias",
+ "roi_heads.box_head.1.fc1.weight": "head.box_head.1.fc1.weight",
+ "roi_heads.box_head.1.fc1.bias": "head.box_head.1.fc1.bias",
+ "roi_heads.box_head.1.fc2.weight": "head.box_head.1.fc2.weight",
+ "roi_heads.box_head.1.fc2.bias": "head.box_head.1.fc2.bias",
+ "roi_heads.box_head.2.fc1.weight": "head.box_head.2.fc1.weight",
+ "roi_heads.box_head.2.fc1.bias": "head.box_head.2.fc1.bias",
+ "roi_heads.box_head.2.fc2.weight": "head.box_head.2.fc2.weight",
+ "roi_heads.box_head.2.fc2.bias": "head.box_head.2.fc2.bias",
+ "roi_heads.box_predictor.0.cls_score.weight": "head.box_predictor.0.cls_score.weight",
+ "roi_heads.box_predictor.0.cls_score.bias": "head.box_predictor.0.cls_score.bias",
+ "roi_heads.box_predictor.0.bbox_pred.weight": "head.box_predictor.0.bbox_pred.weight",
+ "roi_heads.box_predictor.0.bbox_pred.bias": "head.box_predictor.0.bbox_pred.bias",
+ "roi_heads.box_predictor.1.cls_score.weight": "head.box_predictor.1.cls_score.weight",
+ "roi_heads.box_predictor.1.cls_score.bias": "head.box_predictor.1.cls_score.bias",
+ "roi_heads.box_predictor.1.bbox_pred.weight": "head.box_predictor.1.bbox_pred.weight",
+ "roi_heads.box_predictor.1.bbox_pred.bias": "head.box_predictor.1.bbox_pred.bias",
+ "roi_heads.box_predictor.2.cls_score.weight": "head.box_predictor.2.cls_score.weight",
+ "roi_heads.box_predictor.2.cls_score.bias": "head.box_predictor.2.cls_score.bias",
+ "roi_heads.box_predictor.2.bbox_pred.weight": "head.box_predictor.2.bbox_pred.weight",
+ "roi_heads.box_predictor.2.bbox_pred.bias": "head.box_predictor.2.bbox_pred.bias",
+ "roi_heads.mask_head.mask_fcn1.weight": "head.mask_head.mask_fcn1.weight",
+ "roi_heads.mask_head.mask_fcn1.bias": "head.mask_head.mask_fcn1.bias",
+ "roi_heads.mask_head.mask_fcn2.weight": "head.mask_head.mask_fcn2.weight",
+ "roi_heads.mask_head.mask_fcn2.bias": "head.mask_head.mask_fcn2.bias",
+ "roi_heads.mask_head.mask_fcn3.weight": "head.mask_head.mask_fcn3.weight",
+ "roi_heads.mask_head.mask_fcn3.bias": "head.mask_head.mask_fcn3.bias",
+ "roi_heads.mask_head.mask_fcn4.weight": "head.mask_head.mask_fcn4.weight",
+ "roi_heads.mask_head.mask_fcn4.bias": "head.mask_head.mask_fcn4.bias",
+ "roi_heads.mask_head.deconv.weight": "head.mask_head.deconv.weight",
+ "roi_heads.mask_head.deconv.bias": "head.mask_head.deconv.bias",
+ "roi_heads.mask_head.predictor.weight": "head.mask_head.predictor.weight",
+ "roi_heads.mask_head.predictor.bias": "head.mask_head.predictor.bias"
+ },
+ "transpose_map": {
+ }
+}
diff --git a/configs/layout/yolov8/images/example_docx.png b/configs/layout/yolov8/images/example_docx.png
new file mode 100644
index 000000000..a63d21330
Binary files /dev/null and b/configs/layout/yolov8/images/example_docx.png differ
diff --git a/configs/rec/crnn/README.md b/configs/rec/crnn/README.md
index df938a673..c98301c46 100644
--- a/configs/rec/crnn/README.md
+++ b/configs/rec/crnn/README.md
@@ -3,10 +3,9 @@ English | [中文](https://github.com/mindspore-lab/mindocr/blob/main/configs/re
# CRNN
-> [An End-to-End Trainable Neural Network for Image-based Sequence
-Recognition and Its Application to Scene Text Recognition](https://arxiv.org/abs/1507.05717)
+> [An End-to-End Trainable Neural Network for Image-based Sequence Recognition and Its Application to Scene Text Recognition](https://arxiv.org/abs/1507.05717)
-## 1. Introduction
+## Introduction
Convolutional Recurrent Neural Network (CRNN) integrates CNN feature extraction and RNN sequence modeling as well as transcription into a unified framework.
@@ -22,81 +21,20 @@ As shown in the architecture graph (Figure 1), CRNN firstly extracts a feature s
Figure 1. Architecture of CRNN [1]
-## 2. Results
-
+## Requirements
-### Training Perf.
+| mindspore | ascend driver | firmware | cann toolkit/kernel |
+|:----------:|:--------------:|:-------------:|:-------------------:|
+| 2.3.1 | 24.1.RC2 | 7.3.0.1.231 | 8.0.RC2.beta1 |
-According to our experiments, the training (following the steps in [Model Training](#32-model-training)) performance and evaluation (following the steps in [Model Evaluation](#33-model-evaluation)) accuracy are as follows:
-
- Performance tested on ascend 910 with graph mode
+## Quick Start
+### Preparation
-
-
- | **Model** | **Device Card** | **Backbone** | **Train Dataset** | **Model Params** | **Batch size per card** | **Graph train 8P (s/epoch)** | **Graph train 8P (ms/step)** | **Graph train 8P (FPS)** | **Avg Eval Accuracy** | **Recipe** | **Download** |
- | :-----: |:----------:| :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: |
- | CRNN | 8P | VGG7 | MJ+ST | 8.72 M | 16 | 2488.82 | 22.06 | 5802.71 | 82.03% | [yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/rec/crnn/crnn_vgg7.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_vgg7-ea7e996c.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_vgg7-ea7e996c-573dbd61.mindir) |
- | CRNN | 8P | ResNet34_vd | MJ+ST | 24.48 M | 64 | 2157.18 | 76.48 | 6694.84 | 84.45% | [yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/rec/crnn/crnn_resnet34.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_resnet34-83f37f07.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_resnet34-83f37f07-eb10a0c9.mindir) |
-
-
- - Detailed accuracy results for each benchmark dataset (IC03, IC13, IC15, IIIT, SVT, SVTP, CUTE):
-
-
- | **Model** | **Backbone** | **IC03_860** | **IC03_867** | **IC13_857** | **IC13_1015** | **IC15_1811** | **IC15_2077** | **IIIT5k_3000** | **SVT** | **SVTP** | **CUTE80** | **Average** |
- | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: |
- | CRNN | VGG7 | 94.53% | 94.00% | 92.18% | 90.74% | 71.95% | 66.06% | 84.10% | 83.93% | 73.33% | 69.44% | 82.03% |
- | CRNN | ResNet34_vd | 94.42% | 94.23% | 93.35% | 92.02% | 75.92% | 70.15% | 87.73% | 86.40% | 76.28% | 73.96% | 84.45% |
-
-
-
-
-
- Performance tested on ascend 910* with graph mode
-
-
- | **Model** | **Device Card** | **Backbone** | **Train Dataset** | **Model Params** | **Batch size per card** | **Graph train 8P (s/epoch)** | **Graph train 8P (ms/step)** | **Graph train 8P (FPS)** | **Avg Eval Accuracy** | **Recipe** | **Download** |
- | :-----: |:---------------:| :-----: | :-----: | :-----: | :-----: | :-----: |:----------------------------:|:------------------------:|:---------------------:| :-----: | :-----: |
- | CRNN | 8P | VGG7 | MJ+ST | 8.72 M | 16 | 2488.82 | 14.76 | 8672.09 | 81.31% | [yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/rec/crnn/crnn_vgg7.yaml) | [ckpt](https://download-mindspore.osinfra.cn/toolkits/mindocr/crnn/crnn_vgg7-6faf1b2d-910v2.ckpt)|
-
-
-
-
-### Inference Perf.
-
-The inference performance is tested on Mindspore Lite, please take a look at [Mindpore Lite Inference](#6-mindspore-lite-inference) for more details.
-
-
-
-| Device | Env | Model | Backbone | Params | Test Dataset | Batch size | Graph infer 1P (FPS) |
-| :----: | :----: | :----: | :----: | :----: | :----: | :----: | :----: |
-| Ascend310P | Lite2.0 | CRNN | ResNet34_vd | 24.48 M | IC15 | 1 | 361.09 |
-| Ascend310P | Lite2.0 | CRNN | ResNet34_vd | 24.48 M | SVT | 1 | 274.67 |
-
-
-
-**Notes:**
-- To reproduce the result on other contexts, please ensure the global batch size is the same.
-- The characters supported by model are lowercase English characters from a to z and numbers from 0 to 9. More explanation on dictionary, please refer to [4. Character Dictionary](#4-character-dictionary).
-- The models are trained from scratch without any pre-training. For more dataset details of training and evaluation, please refer to [Dataset Download & Dataset Usage](#312-dataset-download) section.
-- The input Shapes of MindIR of CRNN_VGG7 and CRNN_ResNet34_vd are both (1, 3, 32, 100).
-
-
-## 3. Quick Start
-### 3.1 Preparation
-
-#### 3.1.1 Installation
+#### Installation
Please refer to the [installation instruction](https://github.com/mindspore-lab/mindocr#installation) in MindOCR.
-#### 3.1.2 Dataset Download
+#### Dataset Download
Please download lmdb dataset for traininig and evaluation from [here](https://www.dropbox.com/sh/i39abvnefllx2si/AAAbAYRvxzRp3cIE5HzqUw3ra?dl=0) (ref: [deep-text-recognition-benchmark](https://github.com/clovaai/deep-text-recognition-benchmark#download-lmdb-dataset-for-traininig-and-evaluation-from-here)). There're several zip files:
- `data_lmdb_release.zip` contains the **entire** datasets including training data, validation data and evaluation data.
- `training/` contains two datasets: [MJSynth (MJ)](http://www.robots.ox.ac.uk/~vgg/data/text/) and [SynthText (ST)](https://academictorrents.com/details/2dba9518166cbd141534cbf381aa3e99a087e83c)
@@ -143,7 +81,7 @@ data_lmdb_release/
└── lock.mdb
```
-#### 3.1.3 Dataset Usage
+#### Dataset Usage
Here we used the datasets under `training/` folders for **training**, and the union dataset `validation/` for validation. After training, we used the datasets under `evaluation/` to evluation model accuracy.
@@ -262,7 +200,7 @@ eval:
...
```
-#### 3.1.4 Check YAML Config Files
+#### Check YAML Config Files
Apart from the dataset setting, please also check the following important args: `system.distribute`, `system.val_while_train`, `common.batch_size`, `train.ckpt_save_dir`, `train.dataset.dataset_root`, `train.dataset.data_dir`, `train.dataset.label_file`,
`eval.ckpt_load_path`, `eval.dataset.dataset_root`, `eval.dataset.data_dir`, `eval.dataset.label_file`, `eval.loader.batch_size`. Explanations of these important args:
@@ -305,7 +243,7 @@ eval:
- As the global batch size (batch_size x num_devices) is important for reproducing the result, please adjust `batch_size` accordingly to keep the global batch size unchanged for a different number of NPUs, or adjust the learning rate linearly to a new global batch size.
-### 3.2 Model Training
+### Model Training
* Distributed Training
@@ -329,7 +267,7 @@ python tools/train.py --config configs/rec/crnn/crnn_resnet34.yaml
The training result (including checkpoints, per-epoch performance and curves) will be saved in the directory parsed by the arg `train.ckpt_save_dir`. The default directory is `./tmp_rec`.
-### 3.3 Model Evaluation
+### Model Evaluation
To evaluate the accuracy of the trained model, you can use `eval.py`. Please set the checkpoint path to the arg `eval.ckpt_load_path` in the yaml config file, set the evaluation dataset path to the arg `eval.dataset.data_dir`, set `system.distribute` to be False, and then run:
@@ -343,7 +281,7 @@ Similarly, the accuracy of the trained model can be evaluated using multiple eva
python tools/benchmarking/multi_dataset_eval.py --config configs/rec/crnn/crnn_resnet34.yaml
```
-## 4. Character Dictionary
+## Character Dictionary
### Default Setting
@@ -370,7 +308,7 @@ To use a specific dictionary, set the parameter `common.character_dict_path` to
- Remember to check the value of `dataset->transform_pipeline->RecCTCLabelEncode->lower` in the configuration yaml. Set it to False if you prefer case-sensitive encoding.
-## 5. Chinese Text Recognition Model Training
+## Chinese Text Recognition Model Training
Currently, this model supports multilingual recognition and provides pre-trained models for different languages. Details are as follows:
@@ -388,61 +326,72 @@ To train with the prepared datsets and config file, please run:
mpirun --allow-run-as-root -n 4 python tools/train.py --config configs/rec/crnn/crnn_resnet34_ch.yaml
```
-### Results and Pretrained Weights
+### Training with Custom Datasets
+You can train models for different languages with your own custom datasets. Loading the pretrained Chinese model to finetune on your own dataset usually yields better results than training from scratch. Please refer to the tutorial [Training Recognition Network with Custom Datasets](../../../docs/en/tutorials/training_recognition_custom_dataset.md).
-After training, evaluation results on the benchmark test set are as follows, where we also provide the model config and pretrained weights.
-
+## Performance
-| **Model** | **Language** | **Context** |**Backbone** | **Scene** | **Web** | **Document** | **Train T.** | **FPS** | **Recipe** | **Download** |
-| :-----: | :-----: | :--------: | :--------: | :--------: | :--------: | :--------: | :---------: | :--------: | :---------: | :-----------: |
-| CRNN | Chinese | D910x4-MS1.10-G | ResNet34_vd | 60.45% | 65.95% | 97.68% | 647 s/epoch | 1180 | [crnn_resnet34_ch.yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/rec/crnn/crnn_resnet34_ch.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_resnet34_ch-7a342e3c.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_resnet34_ch-7a342e3c-105bccb2.mindir) |
-
+### General Purpose Chinese Models
-**Notes:**
-- The input shape for exported MindIR file in the download link is (1, 3, 32, 320).
+Experiments are tested on ascend 910* with mindspore 2.3.1 graph mode.
-### Training with Custom Datasets
-You can train models for different languages with your own custom datasets. Loading the pretrained Chinese model to finetune on your own dataset usually yields better results than training from scratch. Please refer to the tutorial [Training Recognition Network with Custom Datasets](../../../docs/en/tutorials/training_recognition_custom_dataset.md).
+
+*coming soon*
+
+Experiments are tested on ascend 910 with mindspore 2.3.1 graph mode.
+
-## 6. MindSpore Lite Inference
+| **model name** | **backbone** | **cards** | **batch size** | **language** | **jit level** | **graph compile** | **ms/step** | **img/s** | **scene** | **web** | **document** | **recipe** | **weight** |
+| :------------: | :----------: | :-------: | :------------: | :----------: | :-----------: | :---------------: | :---------: | :-------: | :-------: | :-----: | :----------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
+| CRNN | ResNet34_vd | 4 | 256 | Chinese | O2 | 203.48 s | 38.01 | 1180 | 60.45% | 65.95% | 97.68% | [yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/rec/crnn/crnn_resnet34_ch.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_resnet34_ch-7a342e3c.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_resnet34_ch-7a342e3c-105bccb2.mindir) |
-To inference with MindSpot Lite on Ascend 310, please refer to the tutorial [MindOCR Inference](../../../docs/en/inference/inference_tutorial.md). In short, the whole process consists of the following steps:
-**1. Model Export**
+> The input shape for the exported MindIR file in the download link is (1, 3, 32, 320).
-Please [download](#2-results) the exported MindIR file first, or refer to the [Model Export](../../../docs/en/inference/convert_tutorial.md#1-model-export) tutorial and use the following command to export the trained ckpt model to MindIR file:
+### Specific Purpose Models
-```shell
-python tools/export.py --model_name_or_config crnn_resnet34 --data_shape 32 100 --local_ckpt_path /path/to/local_ckpt.ckpt
-# or
-python tools/export.py --model_name_or_config configs/rec/crnn/crnn_resnet34.yaml --data_shape 32 100 --local_ckpt_path /path/to/local_ckpt.ckpt
-```
+#### Training Performance
-The `data_shape` is the model input shape of height and width for MindIR file. The shape value of MindIR in the download link can be found in [Notes](#2-results) under results table.
+Experiments are tested on Ascend 910* with MindSpore 2.3.1 in graph mode.
+| **model name** | **backbone** | **train dataset** | **params(M)** | **cards** | **batch size** | **jit level** | **graph compile** | **ms/step** | **img/s** | **accuracy** | **recipe** | **weight** |
+| :------------: | :----------: | :---------------: | :-----------: | :-------: | :------------: | :-----------: | :---------------: | :---------: | :-------: | :----------: | :----------------------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------------: |
+| CRNN | VGG7 | MJ+ST | 8.72 | 8 | 16 | O2 | 94.36 s | 14.76 | 8672.09 | 81.31% | [yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/rec/crnn/crnn_vgg7.yaml) | [ckpt](https://download-mindspore.osinfra.cn/toolkits/mindocr/crnn/crnn_vgg7-6faf1b2d-910v2.ckpt) |
-**2. Environment Installation**
-Please refer to [Environment Installation](../../../docs/en/inference/environment.md) tutorial to configure the MindSpore Lite inference environment.
+Experiments are tested on Ascend 910 with MindSpore 2.3.1 in graph mode.
-**3. Model Conversion**
-Please refer to [Model Conversion](../../../docs/en/inference/convert_tutorial.md#2-mindspore-lite-mindir-convert),
-and use the `converter_lite` tool for offline conversion of the MindIR file.
+| **model name** | **backbone** | **train dataset** | **params(M)** | **cards** | **batch size** | **jit level** | **graph compile** | **ms/step** | **img/s** | **accuracy** | **recipe** | **weight** |
+| :------------: | :----------: | :---------------: | :-----------: | :-------: | :------------: | :-----------: | :---------------: | :---------: | :-------: | :----------: | :--------------------------------------------------------------------------------------------: | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
+| CRNN | VGG7 | MJ+ST | 8.72 | 8 | 16 | O2 | 67.18 s | 22.06 | 5802.71 | 82.03% | [yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/rec/crnn/crnn_vgg7.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_vgg7-ea7e996c.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_vgg7-ea7e996c-573dbd61.mindir) |
+| CRNN | ResNet34_vd | MJ+ST | 24.48 | 8 | 64 | O2 | 201.54 s | 76.48 | 6694.84 | 84.45% | [yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/rec/crnn/crnn_resnet34.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_resnet34-83f37f07.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_resnet34-83f37f07-eb10a0c9.mindir) |
-**4. Inference**
+Detailed accuracy results for each benchmark dataset (IC03, IC13, IC15, IIIT, SVT, SVTP, CUTE):
-Assuming that you obtain output.mindir after model conversion, go to the `deploy/py_infer` directory, and use the following command for inference:
-```shell
-python infer.py \
- --input_images_dir=/your_path_to/test_images \
- --rec_model_path=your_path_to/output.mindir \
- --rec_model_name_or_config=../../configs/rec/crnn/crnn_resnet34.yaml \
- --res_save_dir=results_dir
-```
+| **model name** | **backbone** | **cards** | **IC03_860** | **IC03_867** | **IC13_857** | **IC13_1015** | **IC15_1811** | **IC15_2077** | **IIIT5k_3000** | **SVT** | **SVTP** | **CUTE80** | **average** |
+| :------------: | :----------: | :-------: | :----------: | :----------: | :----------: | :-----------: | :-----------: | :-----------: | :-------------: | :-----: | :------: | :--------: | :---------: |
+| CRNN | VGG7 | 1 | 94.53% | 94.00% | 92.18% | 90.74% | 71.95% | 66.06% | 84.10% | 83.93% | 73.33% | 69.44% | 82.03% |
+| CRNN | ResNet34_vd | 1 | 94.42% | 94.23% | 93.35% | 92.02% | 75.92% | 70.15% | 87.73% | 86.40% | 76.28% | 73.96% | 84.45% |
+
+
+#### Inference Performance
+
+Experiments are tested on Ascend 310P with MindSpore Lite 2.3.1 in graph mode.
+
+| **model name** | **backbone** | **test dataset** | **params(M)** | **cards** | **batch size** | **jit level** | **graph compile** | **img/s** |
+| :--------: | :---------: | :----------: | :-------: | :---: | :--------: | :-----------: | :---------------: | :----: |
+| CRNN | ResNet34_vd | IC15 | 24.48 | 1 | 1 | O2 | 10.46 s | 361.09 |
+| CRNN | ResNet34_vd | SVT | 24.48 | 1 | 1 | O2 | 10.31 s | 274.67 |
+
+
+### Notes
+
+- To reproduce the result in other contexts, please ensure the global batch size is the same (a short sketch of this arithmetic follows these notes).
+- The characters supported by the model are lowercase English characters from a to z and numbers from 0 to 9. For more explanation on the dictionary, please refer to [Character Dictionary](#character-dictionary).
+- The models are trained from scratch without any pre-training. For more details on the training and evaluation datasets, please refer to the [Dataset Download](#dataset-download) and [Dataset Usage](#dataset-usage) sections.
+- The input shapes of the exported MindIR files of CRNN_VGG7 and CRNN_ResNet34_vd are both (1, 3, 32, 100).
+
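+A quick way to make the batch-size note concrete: the global batch size is simply the card count times the per-card batch size, and a common convention is to scale the learning rate linearly with it. The sketch below is illustrative only; the variable names and the reference learning rate are placeholders, not values taken from the MindOCR configs.
+
+```python
+# Illustrative sketch, not part of MindOCR: the global-batch-size arithmetic
+# described in the notes above.
+REF_CARDS, REF_BATCH_SIZE = 8, 64  # reference setting from the CRNN-ResNet34_vd table
+REF_LR = 1e-3                      # placeholder value, not the repository default
+
+
+def scaled_lr(cards: int, batch_size: int) -> float:
+    """Rescale the learning rate linearly with the global batch size."""
+    global_bs = cards * batch_size
+    ref_global_bs = REF_CARDS * REF_BATCH_SIZE
+    return REF_LR * global_bs / ref_global_bs
+
+
+# Moving to 4 cards with a per-card batch size of 128 keeps the global batch size
+# at 512, so the learning rate stays unchanged.
+print(scaled_lr(4, 128))  # == REF_LR
+```
+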
## References
diff --git a/configs/rec/crnn/README_CN.md b/configs/rec/crnn/README_CN.md
index 26406bb9c..f565f3eab 100644
--- a/configs/rec/crnn/README_CN.md
+++ b/configs/rec/crnn/README_CN.md
@@ -3,10 +3,9 @@
# CRNN
-> [An End-to-End Trainable Neural Network for Image-based Sequence
-Recognition and Its Application to Scene Text Recognition](https://arxiv.org/abs/1507.05717)
+> [An End-to-End Trainable Neural Network for Image-based Sequence Recognition and Its Application to Scene Text Recognition](https://arxiv.org/abs/1507.05717)
-## 1. 模型描述
+## 模型描述
卷积递归神经网络 (CRNN) 将 CNN 特征提取和 RNN 序列建模以及转录集成到一个统一的框架中。
@@ -22,82 +21,22 @@ Recognition and Its Application to Scene Text Recognition](https://arxiv.org/abs
图1. CRNN架构图 [1]
-## 2. 评估结果
-
-### 训练端
+## 配套版本
-根据我们的实验,训练([模型训练](#32-模型训练))性能和精度评估([模型评估](#33-模型评估))结果如下:
+| mindspore | ascend driver | firmware | cann toolkit/kernel |
+|:----------:|:--------------:|:--------------:|:-------------------:|
+| 2.3.1 | 24.1.RC2 | 7.3.0.1.231 | 8.0.RC2.beta1 |
-
- 在采用图模式的ascend 910上测试性能
-
- | **模型** | **环境配置** | **骨干网络** | **训练集** | **参数量** | **单卡批量** | **图模式8卡训练 (s/epoch)** | **图模式8卡训练 (ms/step)** | **图模式8卡训练 (FPS)** | **平均评估精度** | **配置文件** | **模型权重下载** |
- | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: |
- | CRNN | D910x8-MS1.8-G | VGG7 | MJ+ST | 8.72 M | 16 | 2488.82 | 22.06 | 5802.71 | 82.03% | [yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/rec/crnn/crnn_vgg7.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_vgg7-ea7e996c.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_vgg7-ea7e996c-573dbd61.mindir) |
- | CRNN | D910x8-MS1.8-G | ResNet34_vd | MJ+ST | 24.48 M | 64 | 2157.18 | 76.48 | 6694.84 | 84.45% | [yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/rec/crnn/crnn_resnet34.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_resnet34-83f37f07.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_resnet34-83f37f07-eb10a0c9.mindir) |
-
+## 快速开始
+### 环境及数据准备
- - 在各个基准数据集(IC03,IC13,IC15,IIIT,SVT,SVTP,CUTE)上的准确率:
-
-
-
- | **模型** | **骨干网络** | **IC03_860** | **IC03_867** | **IC13_857** | **IC13_1015** | **IC15_1811** | **IC15_2077** | **IIIT5k_3000** | **SVT** | **SVTP** | **CUTE80** | **平均准确率** |
- | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: |
- | CRNN | VGG7 | 94.53% | 94.00% | 92.18% | 90.74% | 71.95% | 66.06% | 84.10% | 83.93% | 73.33% | 69.44% | 82.03% |
- | CRNN | ResNet34_vd | 94.42% | 94.23% | 93.35% | 92.02% | 75.92% | 70.15% | 87.73% | 86.40% | 76.28% | 73.96% | 84.45% |
-
-
-
-
- 在采用图模式的ascend 910*上测试性能
-
-
-
- | **模型** | **设备卡数** | **骨干网络** | **训练集** | **参数量** | **单卡批量** | **图模式8卡训练 (s/epoch)** | **Graph train 8P (ms/step)** | **图模式8卡训练 (FPS)** | **平均评估精度** | **配置文件** | **模型权重下载** |
- |:------:|:--------:|:--------:|:-------:|:-------:|:--------:| :-----: |:----------------------------:|:------------------------:|:---------------------:| :-----: | :-----: |
- | CRNN | 8P | VGG7 | MJ+ST | 8.72 M | 16 | 2488.82 | 14.76 | 8672.09 | 81.31% | [yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/rec/crnn/crnn_vgg7.yaml) | [ckpt](https://download-mindspore.osinfra.cn/toolkits/mindocr/crnn/crnn_vgg7-6faf1b2d-910v2.ckpt) |
-
-
-
-### 推理端
-
-推理端的性能测试主要是基于Mindspore Lite,详细的操作介绍可参考 [Mindspore Lite推理](#6-mindspore-lite-推理)。
-
-
-
-| 设备 | 编译环境 | 模型 | 骨干网络 | 参数量 | 测试集 | 批量大小 | 图模式单卡推理 (FPS) |
-| :----: | :----: | :----: | :----: | :----: | :----: | :----: | :----: |
-| Ascend310P | Lite2.0 | CRNN | ResNet34_vd | 24.48 M | IC15 | 1 | 361.09 |
-| Ascend310P | Lite2.0 | CRNN | ResNet34_vd | 24.48 M | SVT | 1 | 274.67 |
-
-
-
-
-**注意:**
-- 如需在其他环境配置重现训练结果,请确保全局批量大小与原配置文件保持一致。
-- 模型所能识别的字符都是默认的设置,即所有英文小写字母a至z及数字0至9,详细请看[4. 字符词典](#4-字符词典)
-- 模型都是从头开始训练的,无需任何预训练。关于训练和测试数据集的详细介绍,请参考[数据集下载及使用](#312-数据集下载)章节。
-- CRNN_VGG7和CRNN_ResNet34_vd的MindIR导出时的输入Shape均为(1, 3, 32, 100)。
-
-
-## 3. 快速开始
-### 3.1 环境及数据准备
-
-#### 3.1.1 安装
+#### 安装
环境安装教程请参考MindOCR的 [installation instruction](https://github.com/mindspore-lab/mindocr#installation).
-#### 3.1.2 数据集下载
+#### 数据集下载
LMDB格式的训练及验证数据集可以从[这里](https://www.dropbox.com/sh/i39abvnefllx2si/AAAbAYRvxzRp3cIE5HzqUw3ra?dl=0) (出处: [deep-text-recognition-benchmark](https://github.com/clovaai/deep-text-recognition-benchmark#download-lmdb-dataset-for-traininig-and-evaluation-from-here))下载。连接中的文件包含多个压缩文件,其中:
- `data_lmdb_release.zip` 包含了**完整**的一套数据集,有训练集(training/),验证集(validation/)以及测试集(evaluation)。
- `training.zip` 包括两个数据集,分别是 [MJSynth (MJ)](http://www.robots.ox.ac.uk/~vgg/data/text/) 和 [SynthText (ST)](https://academictorrents.com/details/2dba9518166cbd141534cbf381aa3e99a087e83c)
@@ -106,7 +45,7 @@ LMDB格式的训练及验证数据集可以从[这里](https://www.dropbox.com/s
- `validation.zip`: 与 data_lmdb_release.zip 中的validation/ 一样。
- `evaluation.zip`: 与 data_lmdb_release.zip 中的evaluation/ 一样。
-#### 3.1.3 数据集使用
+#### 数据集使用
解压文件后,数据文件夹结构如下:
@@ -261,7 +200,7 @@ eval:
...
```
-#### 3.1.4 检查配置文件
+#### 检查配置文件
除了数据集的设置,请同时重点关注以下变量的配置:`system.distribute`, `system.val_while_train`, `common.batch_size`, `train.ckpt_save_dir`, `train.dataset.dataset_root`, `train.dataset.data_dir`, `train.dataset.label_file`,
`eval.ckpt_load_path`, `eval.dataset.dataset_root`, `eval.dataset.data_dir`, `eval.dataset.label_file`, `eval.loader.batch_size`。说明如下:
@@ -304,7 +243,7 @@ eval:
- 由于全局批大小 (batch_size x num_devices) 是对结果复现很重要,因此当NPU卡数发生变化时,调整`batch_size`以保持全局批大小不变,或根据新的全局批大小线性调整学习率。
-### 3.2 模型训练
+### 模型训练
* 分布式训练
@@ -328,7 +267,7 @@ python tools/train.py --config configs/rec/crnn/crnn_resnet34.yaml
训练结果(包括checkpoint、每个epoch的性能和曲线图)将被保存在yaml配置文件的`train.ckpt_save_dir`参数配置的目录下,默认为`./tmp_rec`。
-### 3.3 模型评估
+### 模型评估
若要评估已训练模型的准确性,可以使用`eval.py`。请将yaml配置文件的参数`eval.ckpt_load_path`设置为模型checkpoint的文件路径,参数`eval.dataset.data_dir`设置为评估数据集目录,参数`system.distribute`设置为False,然后运行:
@@ -342,7 +281,7 @@ python tools/eval.py --config configs/rec/crnn/crnn_resnet34.yaml
python tools/benchmarking/multi_dataset_eval.py --config configs/rec/crnn/crnn_resnet34.yaml
```
-## 4. 字符词典
+## 字符词典
### 默认设置
@@ -370,7 +309,7 @@ Mindocr内置了一部分字典,均放在了 `mindocr/utils/dict/` 位置,
- 请记住检查配置文件中的 `dataset->transform_pipeline->RecCTCLabelEncode->lower` 参数的值。如果词典中有大小写字母而且想区分大小写的话,请将其设置为 False。
-## 5. 中文识别模型训练
+## 中文识别模型训练
目前,CRNN模型支持中英文字识别并提供相应的预训练权重。详细内容如下
@@ -388,60 +327,74 @@ Mindocr内置了一部分字典,均放在了 `mindocr/utils/dict/` 位置,
mpirun --allow-run-as-root -n 4 python tools/train.py --config configs/rec/crnn/crnn_resnet34_ch.yaml
```
-### 评估结果和预训练权重
-模型训练完成后,在测试集不同场景上的准确率评估结果如下。相应的模型配置和预训练权重可通过表中链接下载。
+### 使用自定义数据集进行训练
+您可以在自定义的数据集基于提供的预训练权重进行微调训练, 以在特定场景获得更高的识别准确率,具体步骤请参考文档 [使用自定义数据集训练识别网络](../../../docs/zh/tutorials/training_recognition_custom_dataset_CN.md)。
+
+## 性能表现
-
+### 通用泛化中文模型
-| **模型** | **语种** | **环境配置** | **骨干网络** | **街景类** | **网页类** | **文档类** | **训练时间** | **FPS** | **配置文件** | **模型权重下载** |
-| :-----: | :-----: | :-------: |:--------: | :--------: | :--------: | :--------: | :---------: |:--------: | :---------: | :-----------: |
-| CRNN | 中文 | D910x4-MS1.10-G | ResNet34_vd | 60.45% | 65.95% | 97.68% | 647 s/epoch | 1180 | [crnn_resnet34_ch.yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/rec/crnn/crnn_resnet34_ch.yaml) |[ckpt](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_resnet34_ch-7a342e3c.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_resnet34_ch-7a342e3c-105bccb2.mindir) |
-
+在采用图模式的ascend 910*上实验结果,mindspore版本为2.3.1
-**注释:**
+*即将到来*
-- MindIR导出时的输入Shape为(1, 3, 32, 320).
+在采用图模式的ascend 910上实验结果,mindspore版本为2.3.1
+| **model name** | **backbone** | **cards** | **batch size** | **language** | **jit level** | **graph compile** | **ms/step** | **img/s** | **scene** | **web** | **document** | **recipe** | **weight** |
+|:--------------:|:------------:|:--------------:|:-----------------:|:------------:|:---------:|:-----------------:|:---------:|:-------:|:------------:|:-----------:|:---------:|:------------------------------------------------------------------------------------------------------------------:|:-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|
+| CRNN | ResNet34_vd | 4 | 256 | Chinese | O2 | 203.48 s | 38.01 | 1180 | 60.45% | 65.95% | 97.68% | [crnn_resnet34_ch.yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/rec/crnn/crnn_resnet34_ch.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_resnet34_ch-7a342e3c.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_resnet34_ch-7a342e3c-105bccb2.mindir) |
-### 使用自定义数据集进行训练
-您可以在自定义的数据集基于提供的预训练权重进行微调训练, 以在特定场景获得更高的识别准确率,具体步骤请参考文档 [使用自定义数据集训练识别网络](../../../docs/zh/tutorials/training_recognition_custom_dataset_CN.md)。
+> 链接中模型的MindIR导出时的输入Shape为`(1, 3, 32, 320)`。
+### 细分领域模型
-## 6. MindSpore Lite 推理
+#### 训练性能表现
-请参考[MindOCR 推理](../../../docs/zh/inference/inference_tutorial.md)教程,基于MindSpore Lite在Ascend 310上进行模型的推理,包括以下步骤:
+在采用图模式的ascend 910*上实验结果,mindspore版本为2.3.1
-**1. 模型导出**
-请先[下载](#2-评估结果)已导出的MindIR文件,或者参考[模型导出](../../../docs/zh/inference/convert_tutorial.md#1-模型导出)教程,使用以下命令将训练完成的ckpt导出为MindIR文件:
+| **model name** | **backbone** | **train dataset** | **params(M)** | **cards** | **batch size** | **jit level** | **graph compile** | **ms/step** | **img/s** | **accuracy** | **recipe** | **weight** |
+| :------------: | :----------: | :---------------: | :-----------: | :-------: | :------------: | :-----------: | :---------------: | :---------: | :-------: | :----------: | :----------------------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------------: |
+| CRNN | VGG7 | MJ+ST | 8.72 | 8 | 16 | O2 | 94.36 s | 14.76 | 8672.09 | 81.31% | [yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/rec/crnn/crnn_vgg7.yaml) | [ckpt](https://download-mindspore.osinfra.cn/toolkits/mindocr/crnn/crnn_vgg7-6faf1b2d-910v2.ckpt) |
-```shell
-python tools/export.py --model_name_or_config crnn_resnet34 --data_shape 32 100 --local_ckpt_path /path/to/local_ckpt.ckpt
-# or
-python tools/export.py --model_name_or_config configs/rec/crnn/crnn_resnet34.yaml --data_shape 32 100 --local_ckpt_path /path/to/local_ckpt.ckpt
-```
-其中,`data_shape`是导出MindIR时的模型输入Shape的height和width,下载链接中MindIR对应的shape值见[注释](#2-评估结果)。
+在采用图模式的ascend 910上实验结果,mindspore版本为2.3.1
-**2. 环境搭建**
-请参考[环境安装](../../../docs/zh/inference/environment.md)教程,配置MindSpore Lite推理运行环境。
+| **model name** | **backbone** | **train dataset** | **params(M)** | **cards** | **batch size** | **jit level** | **graph compile** | **ms/step** | **img/s** | **accuracy** | **recipe** | **weight** |
+| :------------: | :----------: | :---------------: | :-----------: | :-------: | :------------: | :-----------: | :---------------: | :---------: | :-------: | :----------: | :--------------------------------------------------------------------------------------------: | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
+| CRNN | VGG7 | MJ+ST | 8.72 | 8 | 16 | O2 | 67.18 s | 22.06 | 5802.71 | 82.03% | [yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/rec/crnn/crnn_vgg7.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_vgg7-ea7e996c.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_vgg7-ea7e996c-573dbd61.mindir) |
+| CRNN | ResNet34_vd | MJ+ST | 24.48 | 8 | 64 | O2 | 201.54 s | 76.48 | 6694.84 | 84.45% | [yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/rec/crnn/crnn_resnet34.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_resnet34-83f37f07.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_resnet34-83f37f07-eb10a0c9.mindir) |
-**3. 模型转换**
-请参考[模型转换](../../../docs/zh/inference/convert_tutorial.md#2-mindspore-lite-mindir-转换)教程,使用`converter_lite`工具对MindIR模型进行离线转换。
+在各个基准数据集(IC03,IC13,IC15,IIIT,SVT,SVTP,CUTE)上的准确率:
-**4. 执行推理**
+| **model name** | **backbone** | **cards** | **IC03_860** | **IC03_867** | **IC13_857** | **IC13_1015** | **IC15_1811** | **IC15_2077** | **IIIT5k_3000** | **SVT** | **SVTP** | **CUTE80** | **average** |
+| :------------: | :----------: | :-------: | :----------: | :----------: | :----------: | :-----------: | :-----------: | :-----------: | :-------------: | :-----: | :------: | :--------: | :---------: |
+| CRNN | VGG7 | 1 | 94.53% | 94.00% | 92.18% | 90.74% | 71.95% | 66.06% | 84.10% | 83.93% | 73.33% | 69.44% | 82.03% |
+| CRNN | ResNet34_vd | 1 | 94.42% | 94.23% | 93.35% | 92.02% | 75.92% | 70.15% | 87.73% | 86.40% | 76.28% | 73.96% | 84.45% |
-假设在模型转换后得到output.mindir文件,在`deploy/py_infer`目录下使用以下命令进行推理:
-```shell
-python infer.py \
- --input_images_dir=/your_path_to/test_images \
- --rec_model_path=your_path_to/output.mindir \
- --rec_model_name_or_config=../../configs/rec/crnn/crnn_resnet34.yaml \
- --res_save_dir=results_dir
-```
+
+
+
+#### 推理性能表现
+
+在采用图模式的ascend 310P上实验结果,mindspore lite版本为2.3.1
+
+| **model name** | **backbone** | **test dataset** | **params(M)** | **cards** | **batch size** | **jit level** | **graph compile** | **img/s** |
+| :--------: | :---------: | :----------: | :-------: | :---: | :--------: | :-----------: | :---------------: | :----: |
+| CRNN | ResNet34_vd | IC15 | 24.48 | 1 | 1 | O2 | 10.46 s | 361.09 |
+| CRNN | ResNet34_vd | SVT | 24.48 | 1 | 1 | O2 | 10.31 s | 274.67 |
+
+
+
+### 注意
+- 如需在其他环境配置重现训练结果,请确保全局批量大小与原配置文件保持一致。
+- 模型所能识别的字符都是默认的设置,即所有英文小写字母a至z及数字0至9,详细请看[字符词典](#字符词典)
+- 模型都是从头开始训练的,无需任何预训练。关于训练和测试数据集的详细介绍,请参考[数据集下载](#数据集下载)及[数据集使用](#数据集使用)章节。
+- CRNN_VGG7和CRNN_ResNet34_vd的MindIR导出时的输入Shape均为(1, 3, 32, 100)。
+
## 参考文献
diff --git a/configs/rec/svtr/README.md b/configs/rec/svtr/README.md
index bf6cacb2a..3dd8142f3 100644
--- a/configs/rec/svtr/README.md
+++ b/configs/rec/svtr/README.md
@@ -5,7 +5,7 @@ English | [中文](https://github.com/mindspore-lab/mindocr/blob/main/configs/re
> [SVTR: Scene Text Recognition with a Single Visual Model](https://arxiv.org/abs/2205.00159)
-## 1. Introduction
+## Introduction
Dominant scene text recognition models commonly contain two building blocks, a visual model for feature extraction and a sequence model for text transcription. This hybrid architecture, although accurate, is complex and less efficient. This paper proposes a Single Visual model for Scene Text recognition within the patch-wise image tokenization framework, which dispenses with the sequential modeling entirely. The method, termed SVTR, firstly decomposes an image text into small patches named character components. Afterward, hierarchical stages are recurrently carried out by component-level mixing, merging and/or combining. Global and local mixing blocks are devised to perceive the inter-character and intra-character patterns, leading to a multi-grained character component perception. Thus, characters are recognized by a simple linear prediction. Experimental results on both English and Chinese scene text recognition tasks demonstrate the effectiveness of SVTR. SVTR-L (Large) achieves highly competitive accuracy in English and outperforms existing methods by a large margin in Chinese, while running faster. In addition, SVTR-T (Tiny) is an effective and much smaller model, which shows appealing speed at inference. [1]
@@ -19,59 +19,25 @@ Dominant scene text recognition models commonly contain two building blocks, a v
-## 2. Results
-
-### Accuracy
+## Requirements
-According to our experiments, the evaluation results on public benchmark datasets (IC03, IC13, IC15, IIIT, SVT, SVTP, CUTE) is as follow:
+| mindspore | ascend driver | firmware | cann toolkit/kernel |
+|:----------:|:--------------:|:-------------:|:-------------------:|
+| 2.3.1 | 24.1.RC2 | 7.3.0.1.231 | 8.0.RC2.beta1 |
-
- Performance tested on ascend 910 with graph mode
-
- | **Model** | **Device Card** | **Avg Accuracy** | **Train T.** | **FPS** | **Recipe** | **Download** |
- | :-----: |:---------------:| :--------------: | :----------: | :--------: | :--------: |:----------: |
- | SVTR-Tiny | 4P | 90.23% | 3638 s/epoch | 4560 | [yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/rec/svtr/svtr_tiny.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/svtr/svtr_tiny-950be1c3.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/svtr/svtr_tiny-950be1c3-86ece8c8.mindir) |
- | SVTR-Tiny-8P | 8P | 90.32% | 1646 s/epoch | 9840 | [yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/rec/svtr/svtr_tiny_8p.yaml) | [ckpt](https://download-mindspore.osinfra.cn/toolkits/mindocr/svtr/svtr_tiny_8p-0afc75d6.ckpt) \| [mindir](https://download-mindspore.osinfra.cn/toolkits/mindocr/svtr/svtr_tiny_8p-0afc75d6-255191ef.mindir) |
-
- Detailed accuracy results for each benchmark dataset
-
+## Quick Start
+### Preparation
- | **Model** | **IC03_860** | **IC03_867** | **IC13_857** | **IC13_1015** | **IC15_1811** | **IC15_2077** | **IIIT5k_3000** | **SVT** | **SVTP** | **CUTE80** | **Average** |
- | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: |
- | SVTR-Tiny | 95.70% | 95.50% | 95.33% | 93.99% | 83.60% | 79.83% | 94.70% | 91.96% | 85.58% | 86.11% | 90.23% |
- | SVTR-Tiny-8P | 95.93% | 95.62% | 95.33% | 93.89% | 84.32% | 80.55% | 94.33% | 90.57% | 86.20% | 86.46% | 90.32% |
-
-
-
-**Notes:**
-- Context: Training context denoted as {device}x{pieces}-{MS mode}, where mindspore mode can be G-graph mode or F-pynative mode with ms function. For example, D910x4-MS1.10-G is for training on 4 pieces of Ascend 910 NPU using graph mode based on Minspore version 1.10.
-- To reproduce the result on other contexts, please ensure the global batch size is the same.
-- The characters supported by model are lowercase English characters from a to z and numbers from 0 to 9. More explanation on dictionary, please refer to [4. Character Dictionary](#4-character-dictionary).
-- The models are trained from scratch without any pre-training. For more dataset details of training and evaluation, please refer to [Dataset Download & Dataset Usage](#312-dataset-download) section.
-- The input Shapes of MindIR of RARE is (1, 3, 64, 256).
-
-
-## 3. Quick Start
-### 3.1 Preparation
-
-#### 3.1.1 Installation
+#### Installation
Please refer to the [installation instruction](https://github.com/mindspore-lab/mindocr#installation) in MindOCR.
-#### 3.1.2 Dataset Preparation
+#### Dataset Preparation
-##### 3.1.2.1 MJSynth, validation and evaluation dataset
+##### MJSynth, validation and evaluation dataset
Part of the lmdb dataset for training and evaluation can be downloaded from [here](https://www.dropbox.com/sh/i39abvnefllx2si/AAAbAYRvxzRp3cIE5HzqUw3ra?dl=0) (ref: [deep-text-recognition-benchmark](https://github.com/clovaai/deep-text-recognition-benchmark#download-lmdb-dataset-for-traininig-and-evaluation-from-here)). There're several zip files:
- `data_lmdb_release.zip` contains the datasets including training data, validation data and evaluation data.
- `training/` contains two datasets: [MJSynth (MJ)](http://www.robots.ox.ac.uk/~vgg/data/text/) and [SynthText (ST)](https://academictorrents.com/details/2dba9518166cbd141534cbf381aa3e99a087e83c). *Here we use **MJSynth only**.*
@@ -80,7 +46,7 @@ Part of the lmdb dataset for training and evaluation can be downloaded from [her
- `validation.zip`: same as the validation/ within data_lmdb_release.zip
- `evaluation.zip`: same as the evaluation/ within data_lmdb_release.zip
-##### 3.1.2.2 SynthText dataset
+##### SynthText dataset
For `SynthText`, we do not use the given LMDB dataset in `data_lmdb_release.zip`, since it only contains part of the cropped images. Please download the raw dataset from [here](https://academictorrents.com/details/2dba9518166cbd141534cbf381aa3e99a087e83c) and prepare the LMDB dataset using the following command
@@ -94,7 +60,7 @@ python tools/dataset_converters/convert.py \
```
the `ST_full` contained the full cropped images of SynthText in LMDB data format. Please replace the `ST` folder with the `ST_full` folder.
-#### 3.1.3 Dataset Usage
+#### Dataset Usage
Finally, the data structure should like this.
@@ -246,7 +212,7 @@ eval:
...
```
-#### 3.1.4 Check YAML Config Files
+#### Check YAML Config Files
Apart from the dataset setting, please also check the following important args: `system.distribute`, `system.val_while_train`, `common.batch_size`, `train.ckpt_save_dir`, `train.dataset.dataset_root`, `train.dataset.data_dir`, `train.dataset.label_file`,
`eval.ckpt_load_path`, `eval.dataset.dataset_root`, `eval.dataset.data_dir`, `eval.dataset.label_file`, `eval.loader.batch_size`. Explanations of these important args:
@@ -288,7 +254,7 @@ eval:
- As the global batch size (batch_size x num_devices) is important for reproducing the result, please adjust `batch_size` accordingly to keep the global batch size unchanged for a different number of NPUs, or adjust the learning rate linearly to a new global batch size.
-### 3.2 Model Training
+### Model Training
* Distributed Training
@@ -312,7 +278,7 @@ python tools/train.py --config configs/rec/svtr/svtr_tiny.yaml
The training result (including checkpoints, per-epoch performance and curves) will be saved in the directory parsed by the arg `ckpt_save_dir`. The default directory is `./tmp_rec`.
-### 3.3 Model Evaluation
+### Model Evaluation
To evaluate the accuracy of the trained model, you can use `eval.py`. Please set the checkpoint path to the arg `ckpt_load_path` in the `eval` section of yaml config file, set `distribute` to be False, and then run:
@@ -320,7 +286,7 @@ To evaluate the accuracy of the trained model, you can use `eval.py`. Please set
python tools/eval.py --config configs/rec/svtr/svtr_tiny.yaml
```
-## 4. Character Dictionary
+## Character Dictionary
### Default Setting
@@ -347,7 +313,7 @@ To use a specific dictionary, set the parameter `character_dict_path` to the pat
- Remember to check the value of `dataset->transform_pipeline->RecAttnLabelEncode->lower` in the configuration yaml. Set it to False if you prefer case-sensitive encoding.
-## 5. Chinese Text Recognition Model Training
+## Chinese Text Recognition Model Training
Currently, this model supports multilingual recognition and provides pre-trained models for different languages. Details are as follows:
@@ -365,58 +331,54 @@ To train with the prepared datsets and config file, please run:
mpirun --allow-run-as-root -n 4 python tools/train.py --config configs/rec/svtr/svtr_tiny_ch.yaml
```
-### Results and Pretrained Weights
-
-After training, evaluation results on the benchmark test set are as follows, where we also provide the model config and pretrained weights.
-
-
-| **Model** | **Language** | **Context** | **Scene** | **Web** | **Document** | **Train T.** | **FPS** | **Recipe** | **Download** |
-| :-----: | :-----: | :--------: | :--------: | :--------: | :--------: | :---------: | :--------: | :---------: | :-----------: |
-| SVTR-Tiny | Chinese | D910x4-MS1.10-G | 65.93% | 69.64% | 98.01% | 647 s/epoch | 1580 | [svtr_tiny_ch.yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/rec/svtr/svtr_tiny_ch.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/svtr/svtr_tiny_ch-2ee6ade4.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/svtr/svtr_tiny_ch-2ee6ade4-3e495768.mindir) |
-
### Training with Custom Datasets
You can train models for different languages with your own custom datasets. Loading the pretrained Chinese model to finetune on your own dataset usually yields better results than training from scratch. Please refer to the tutorial [Training Recognition Network with Custom Datasets](../../../docs/en/tutorials/training_recognition_custom_dataset.md).
-## 6. MindSpore Lite Inference
+## Performance
-To inference with MindSpot Lite on Ascend 310, please refer to the tutorial [MindOCR Inference](../../../docs/en/inference/inference_tutorial.md). In short, the whole process consists of the following steps:
+### General Purpose Chinese Models
-**1. Model Export**
+Experiments are tested on Ascend 910* with MindSpore 2.3.1 in graph mode.
-Please [download](#2-results) the exported MindIR file first, or refer to the [Model Export](../../../docs/en/inference/convert_tutorial.md#1-model-export) tutorial and use the following command to export the trained ckpt model to MindIR file:
+*coming soon*
-```shell
-python tools/export.py --model_name_or_config svtr_tiny --data_shape 64 256 --local_ckpt_path /path/to/local_ckpt.ckpt
-# or
-python tools/export.py --model_name_or_config configs/rec/svtr/svtr_tiny.yaml --data_shape 64 256 --local_ckpt_path /path/to/local_ckpt.ckpt
-```
+Experiments are tested on Ascend 910 with MindSpore 2.3.1 in graph mode.
-The `data_shape` is the model input shape of height and width for MindIR file. The shape value of MindIR in the download link can be found in [Notes](#2-results) under results table.
+| **model name** | **cards** | **batch size** | **languages** | **jit level** | **graph compile** | **ms/step** | **img/s** | **scene** | **web** | **document** | **recipe** | **weight** |
+| :------------: | :-------: | :------------: | :-----------: | :-----------: | :---------------: | :---------: | :-------: | :-------: | :-----: | :----------: | :--------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
+| SVTR-Tiny | 4 | 256 | Chinese | O2 | 235.1 s | 37.75 | 1580 | 65.93% | 69.64% | 98.01% | [svtr_tiny_ch.yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/rec/svtr/svtr_tiny_ch.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/svtr/svtr_tiny_ch-2ee6ade4.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/svtr/svtr_tiny_ch-2ee6ade4-3e495768.mindir) |
+### Specific Purpose Models
-**2. Environment Installation**
+Experiments are tested on Ascend 910* with MindSpore 2.3.1 in graph mode.
-Please refer to [Environment Installation](../../../docs/en/inference/environment.md) tutorial to configure the MindSpore Lite inference environment.
+*coming soon*
-**3. Model Conversion**
+Experiments are tested on Ascend 910 with MindSpore 2.3.1 in graph mode.
-Please refer to [Model Conversion](../../../docs/en/inference/convert_tutorial.md#2-mindspore-lite-mindir-convert),
-and use the `converter_lite` tool for offline conversion of the MindIR file.
+| **model name** | **cards** | **batch size** | **jit level** | **graph compile** | **ms/step** | **img/s** | **accuracy** | **recipe** | **weight** |
+| :------------: | :-------: | :------------: | :-----------: | :---------------: | :---------: | :-------: | :----------: | :-------------------------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
+| SVTR-Tiny | 4 | 512 | O2 | 226.86 s | 49.38 | 4560 | 90.23% | [yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/rec/svtr/svtr_tiny.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/svtr/svtr_tiny-950be1c3.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/svtr/svtr_tiny-950be1c3-86ece8c8.mindir) |
+| SVTR-Tiny-8P | 8 | 512 | O2 | 230.74 s | 55.16 | 9840 | 90.32% | [yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/rec/svtr/svtr_tiny_8p.yaml) | [ckpt](https://download-mindspore.osinfra.cn/toolkits/mindocr/svtr/svtr_tiny_8p-0afc75d6.ckpt) \| [mindir](https://download-mindspore.osinfra.cn/toolkits/mindocr/svtr/svtr_tiny_8p-0afc75d6-255191ef.mindir) |
-**4. Inference**
-Assuming that you obtain output.mindir after model conversion, go to the `deploy/py_infer` directory, and use the following command for inference:
+Detailed accuracy results for each benchmark dataset:
-```shell
-python infer.py \
- --input_images_dir=/your_path_to/test_images \
- --rec_model_path=your_path_to/output.mindir \
- --rec_model_name_or_config=../../configs/rec/svtr/svtr_tiny.yaml \
- --res_save_dir=results_dir
-```
+
+| **model name** | **IC03_860** | **IC03_867** | **IC13_857** | **IC13_1015** | **IC15_1811** | **IC15_2077** | **IIIT5k_3000** | **SVT** | **SVTP** | **CUTE80** | **average** |
+| :------------: | :----------: | :----------: | :----------: | :-----------: | :-----------: | :-----------: | :-------------: | :-----: | :------: | :--------: | :---------: |
+| SVTR-Tiny | 95.70% | 95.50% | 95.33% | 93.99% | 83.60% | 79.83% | 94.70% | 91.96% | 85.58% | 86.11% | 90.23% |
+| SVTR-Tiny-8P | 95.93% | 95.62% | 95.33% | 93.89% | 84.32% | 80.55% | 94.33% | 90.57% | 86.20% | 86.46% | 90.32% |
+
+
+### Notes
+- To reproduce the result in other contexts, please ensure the global batch size is the same.
+- The characters supported by the model are lowercase English characters from a to z and numbers from 0 to 9. For more explanation on the dictionary, please refer to [Character Dictionary](#character-dictionary).
+- The models are trained from scratch without any pre-training. For more details on the training and evaluation datasets, please refer to the [Dataset Preparation](#dataset-preparation) and [Dataset Usage](#dataset-usage) sections.
+- The input shape of the exported SVTR MindIR is (1, 3, 64, 256).
## References
diff --git a/configs/rec/svtr/README_CN.md b/configs/rec/svtr/README_CN.md
index cc306abce..cea591b23 100644
--- a/configs/rec/svtr/README_CN.md
+++ b/configs/rec/svtr/README_CN.md
@@ -5,7 +5,7 @@
> [SVTR: Scene Text Recognition with a Single Visual Model](https://arxiv.org/abs/2205.00159)
-## 1. 模型描述
+## 模型描述
主流的场景文字识别模型通常包含两个基本构建部分,一个视觉模型用于特征提取和一个序列模型用于文本转换。虽然这种混合架构非常准确,但也相对复杂和低效。因此,作者提出了一种新的方法:单一视觉模型。这种方法在图形标记化(image tokenization)框架下建立,完全抛弃了顺序的建模方式。作者的方法将图像划分成小的补丁,并通过逐层组件级别的混合、合并和/或组合进行操作以实现层级。作者还设计了全局和局部混合块以识别多颗粒度的字符组件模式,从而进行字符识别。作者实验了英文和中文场景文本识别任务,结果表明作者的模型SVTR是有效的。作者的大型模型SVTR-L在英文方面能提供高准确度的性能,在中文方面也表现优越且速度更快。作者的小型模型SVTR-T在推断方面也有很好的表现。[1]
@@ -19,58 +19,23 @@
图1. SVTR结构 [1]
-## 2. 评估结果
-
+## 配套版本
-### 精度结果
+| mindspore | ascend driver | firmware | cann toolkit/kernel |
+|:----------:|:--------------:|:-------------:|:-------------------:|
+| 2.3.1 | 24.1.RC2 | 7.3.0.1.231 | 8.0.RC2.beta1 |
-根据我们的实验,在公开基准数据集(IC03,IC13,IC15,IIIT,SVT,SVTP,CUTE)上的评估结果如下:
-
-| **模型** | **环境配置** | **平均准确率** | **训练时间** | **FPS** | **配置文件** | **模型权重下载** |
-|:------------:|:---------------:|:---------:|:------------:|:-------:|:---------------------------------------------------------------------------------------------:|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|
-| SVTR-Tiny | D910x4-MS1.10-G | 90.23% | 3638 s/epoch | 4560 | [yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/rec/svtr/svtr_tiny.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/svtr/svtr_tiny-950be1c3.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/svtr/svtr_tiny-950be1c3-86ece8c8.mindir) |
-| SVTR-Tiny-8P | D910x8-MS2.2-G | 90.32% | 1646 s/epoch | 9840 | [yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/rec/svtr/svtr_tiny_8p.yaml) | [ckpt](https://download-mindspore.osinfra.cn/toolkits/mindocr/svtr/svtr_tiny_8p-0afc75d6.ckpt) \| [mindir](https://download-mindspore.osinfra.cn/toolkits/mindocr/svtr/svtr_tiny_8p-0afc75d6-255191ef.mindir |
+## 快速开始
+### 环境及数据准备
-
-
-
-
- 在各个基准数据集上的准确率
-
- | **模型** | **IC03_860** | **IC03_867** | **IC13_857** | **IC13_1015** | **IC15_1811** | **IC15_2077** | **IIIT5k_3000** | **SVT** | **SVTP** | **CUTE80** | **平均准确率** |
- |:------------:|:------------:|:------------:|:------------:|:-------------:|:-------------:|:-------------:|:---------------:|:-------:|:--------:|:----------:|:---------:|
- | SVTR-Tiny | 95.70% | 95.50% | 95.33% | 93.99% | 83.60% | 79.83% | 94.70% | 91.96% | 85.58% | 86.11% | 90.23% |
- | SVTR-Tiny-8P | 95.93% | 95.62% | 95.33% | 93.89% | 84.32% | 80.55% | 94.33% | 90.57% | 86.20% | 86.46% | 90.32% |
-
-
-
-
-**注意:**
-- 环境配置:训练的环境配置表示为 {处理器}x{处理器数量}-{MS模式},其中 Mindspore 模式可以是 G-graph 模式或 F-pynative 模式。例如,D910x4-MS1.10-G 用于使用图形模式在4张昇腾910 NPU上依赖Mindspore1.10版本进行训练。
-- 如需在其他环境配置重现训练结果,请确保全局批量大小与原配置文件保持一致。
-- 模型所能识别的字符都是默认的设置,即所有英文小写字母a至z及数字0至9,详细请看[4. 字符词典](#4-字符词典)
-- 模型都是从头开始训练的,无需任何预训练。关于训练和测试数据集的详细介绍,请参考[数据集下载及使用](#312-数据集下载)章节。
-- SVTR的MindIR导出时的输入Shape均为(1, 3, 64, 256)。
-
-## 3. 快速开始
-### 3.1 环境及数据准备
-
-#### 3.1.1 安装
+#### 安装
环境安装教程请参考MindOCR的 [installation instruction](https://github.com/mindspore-lab/mindocr#installation).
-#### 3.1.2 数据集准备
+#### 数据集准备
-##### 3.1.2.1 MJSynth, 验证集和测试集
+##### MJSynth, 验证集和测试集
部分LMDB格式的训练及验证数据集可以从[这里](https://www.dropbox.com/sh/i39abvnefllx2si/AAAbAYRvxzRp3cIE5HzqUw3ra?dl=0) (出处: [deep-text-recognition-benchmark](https://github.com/clovaai/deep-text-recognition-benchmark#download-lmdb-dataset-for-traininig-and-evaluation-from-here))下载。连接中的文件包含多个压缩文件,其中:
- `data_lmdb_release.zip` 包含了了部分数据集,有训练集(training/),验证集(validation/)以及测试集(evaluation)。
- `training.zip` 包括两个数据集,分别是 [MJSynth (MJ)](http://www.robots.ox.ac.uk/~vgg/data/text/) 和 [SynthText (ST)](https://academictorrents.com/details/2dba9518166cbd141534cbf381aa3e99a087e83c)。 这里我们只使用**MJSynth**。
@@ -79,7 +44,7 @@ Table Format:
- `validation.zip`: 与 data_lmdb_release.zip 中的validation/ 一样。
- `evaluation.zip`: 与 data_lmdb_release.zip 中的evaluation/ 一样。
-##### 3.1.2.2 SynthText dataset
+##### SynthText数据集
我们不使用`data_lmdb_release.zip`提供的`SynthText`数据, 因为它只包含部分切割下来的图片。请从[此处](https://academictorrents.com/details/2dba9518166cbd141534cbf381aa3e99a087e83c)下载原始数据, 并使用以下命令转换成LMDB格式
@@ -93,7 +58,7 @@ python tools/dataset_converters/convert.py \
```
`ST_full` 包含了所有已切割的图片,以LMDB格式储存。 请将 `ST` 文件夹换成 `ST_full` 文件夹。
-#### 3.1.3 数据集使用
+#### 数据集使用
最终数据文件夹结构如下:
@@ -286,7 +251,7 @@ eval:
- 由于全局批大小 (batch_size x num_devices) 是对结果复现很重要,因此当NPU卡数发生变化时,调整`batch_size`以保持全局批大小不变,或将学习率线性调整为新的全局批大小。
-### 3.2 模型训练
+### 模型训练
* 分布式训练
@@ -310,7 +275,7 @@ python tools/train.py --config configs/rec/svtr/svtr_tiny.yaml
训练结果(包括checkpoint、每个epoch的性能和曲线图)将被保存在yaml配置文件的`ckpt_save_dir`参数配置的目录下,默认为`./tmp_rec`。
-### 3.3 模型评估
+### 模型评估
若要评估已训练模型的准确性,可以使用`eval.py`。请在yaml配置文件的`eval`部分将参数`ckpt_load_path`设置为模型checkpoint的文件路径,设置`distribute`为False,然后运行:
@@ -318,7 +283,7 @@ python tools/train.py --config configs/rec/svtr/svtr_tiny.yaml
python tools/eval.py --config configs/rec/svtr/svtr_tiny.yaml
```
-## 4. 字符词典
+## 字符词典
### 默认设置
@@ -345,7 +310,7 @@ Mindocr内置了一部分字典,均放在了 `mindocr/utils/dict/` 位置,
- 您可以通过将配置文件中的参数 `use_space_char` 设置为 True 来包含空格字符。
- 请记住检查配置文件中的 `dataset->transform_pipeline->RecAttnLabelEncode->lower` 参数的值。如果词典中有大小写字母而且想区分大小写的话,请将其设置为 False。
-## 5. 中文识别模型训练
+## 中文识别模型训练
目前,SVTR模型支持中英文字识别并提供相应的预训练权重。详细内容如下
@@ -363,55 +328,55 @@ Mindocr内置了一部分字典,均放在了 `mindocr/utils/dict/` 位置,
mpirun --allow-run-as-root -n 4 python tools/train.py --config configs/rec/svtr/svtr_tiny_ch.yaml
```
-### 评估结果和预训练权重
-模型训练完成后,在测试集不同场景上的准确率评估结果如下。相应的模型配置和预训练权重可通过表中链接下载。
-
-
-| **模型** | **语种** | **环境配置** | **街景类** | **网页类** | **文档类** | **训练时间** | **FPS** | **配置文件** | **模型权重下载** |
-| :-----: | :-----: | :-------: | :--------: | :--------: | :--------: | :---------: |:--------: | :---------: | :-----------: |
-| SVTR-Tiny | 中文 | D910x4-MS1.10-G | 65.93% | 69.64% | 98.01% | 647 s/epoch | 1580 | [svtr_tiny_ch.yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/rec/svtr/svtr_tiny_ch.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/svtr/svtr_tiny_ch-2ee6ade4.ckpt) \| [mindir]() |
-
### 使用自定义数据集进行训练
您可以在自定义的数据集基于提供的预训练权重进行微调训练, 以在特定场景获得更高的识别准确率,具体步骤请参考文档 [使用自定义数据集训练识别网络](../../../docs/zh/tutorials/training_recognition_custom_dataset_CN.md)。
+## 性能表现
-## 6. MindSpore Lite 推理
+### 通用泛化中文模型
-请参考[MindOCR 推理](../../../docs/zh/inference/inference_tutorial.md)教程,基于MindSpore Lite在Ascend 310上进行模型的推理,包括以下步骤:
+在采用图模式的ascend 910*上实验结果,mindspore版本为2.3.1
-**1. 模型导出**
+*即将到来*
-请先[下载](#2-评估结果)已导出的MindIR文件,或者参考[模型导出](../../../docs/zh/inference/convert_tutorial.md#1-模型导出)教程,使用以下命令将训练完成的ckpt导出为MindIR文件:
+在采用图模式的ascend 910上实验结果,mindspore版本为2.3.1
-```shell
-python tools/export.py --model_name_or_config svtr_tiny --data_shape 64 256 --local_ckpt_path /path/to/local_ckpt.ckpt
-# or
-python tools/export.py --model_name_or_config configs/rec/svtr/svtr_tiny.yaml --data_shape 64 256 --local_ckpt_path /path/to/local_ckpt.ckpt
-```
-其中,`data_shape`是导出MindIR时的模型输入Shape的height和width,下载链接中MindIR对应的shape值见[注释](#2-评估结果)。
+| **model name** | **cards** | **batch size** | **languages** | **jit level** | **graph compile** | **ms/step** | **img/s** | **scene** | **web** | **document** | **recipe** | **weight** |
+| :------------: | :-------: | :------------: | :-----------: | :-----------: | :---------------: | :---------: | :-------: | :-------: | :-----: | :----------: | :--------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
+| SVTR-Tiny | 4 | 256 | Chinese | O2 | 235.1 s | 37.75 | 1580 | 65.93% | 69.64% | 98.01% | [svtr_tiny_ch.yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/rec/svtr/svtr_tiny_ch.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/svtr/svtr_tiny_ch-2ee6ade4.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/svtr/svtr_tiny_ch-2ee6ade4-3e495768.mindir) |
-**2. 环境搭建**
+### 细分领域模型
-请参考[环境安装](../../../docs/zh/inference/environment.md)教程,配置MindSpore Lite推理运行环境。
+在采用图模式的ascend 910*上实验结果,mindspore版本为2.3.1
-**3. 模型转换**
+*即将到来*
-请参考[模型转换](../../../docs/zh/inference/convert_tutorial.md#2-mindspore-lite-mindir-转换)教程,使用`converter_lite`工具对MindIR模型进行离线转换。
+在采用图模式的ascend 910上实验结果,mindspore版本为2.3.1
-**4. 执行推理**
+| **model name** | **cards** | **batch size** | **jit level** | **graph compile** | **ms/step** | **img/s** | **accuracy** | **recipe** | **weight** |
+| :------------: | :-------: | :------------: | :-----------: | :---------------: | :---------: | :-------: | :----------: | :-------------------------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
+| SVTR-Tiny | 4 | 512 | O2 | 226.86 s | 49.38 | 4560 | 90.23% | [yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/rec/svtr/svtr_tiny.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/svtr/svtr_tiny-950be1c3.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/svtr/svtr_tiny-950be1c3-86ece8c8.mindir) |
+| SVTR-Tiny-8P | 8 | 512 | O2 | 230.74 s | 55.16 | 9840 | 90.32% | [yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/rec/svtr/svtr_tiny_8p.yaml) | [ckpt](https://download-mindspore.osinfra.cn/toolkits/mindocr/svtr/svtr_tiny_8p-0afc75d6.ckpt) \| [mindir](https://download-mindspore.osinfra.cn/toolkits/mindocr/svtr/svtr_tiny_8p-0afc75d6-255191ef.mindir) |
-假设在模型转换后得到output.mindir文件,在`deploy/py_infer`目录下使用以下命令进行推理:
-```shell
-python infer.py \
- --input_images_dir=/your_path_to/test_images \
- --rec_model_path=your_path_to/output.mindir \
- --rec_model_name_or_config=../../configs/rec/svtr/svtr_tiny.yaml \
- --res_save_dir=results_dir
-```
+在各个基准数据集上的准确率
+
+| **model name** | **IC03_860** | **IC03_867** | **IC13_857** | **IC13_1015** | **IC15_1811** | **IC15_2077** | **IIIT5k_3000** | **SVT** | **SVTP** | **CUTE80** | **average** |
+| :------------: | :----------: | :----------: | :----------: | :-----------: | :-----------: | :-----------: | :-------------: | :-----: | :------: | :--------: | :---------: |
+| SVTR-Tiny | 95.70% | 95.50% | 95.33% | 93.99% | 83.60% | 79.83% | 94.70% | 91.96% | 85.58% | 86.11% | 90.23% |
+| SVTR-Tiny-8P | 95.93% | 95.62% | 95.33% | 93.89% | 84.32% | 80.55% | 94.33% | 90.57% | 86.20% | 86.46% | 90.32% |
+
+
+
+**注意:**
+- 如需在其他环境配置重现训练结果,请确保全局批量大小与原配置文件保持一致。
+- 模型所能识别的字符都是默认的设置,即所有英文小写字母a至z及数字0至9,详细请看[字符词典](#字符词典)
+- 模型都是从头开始训练的,无需任何预训练。关于训练和测试数据集的详细介绍,请参考[数据集准备](#数据集准备)及[数据集使用](#数据集使用)章节。
+- SVTR的MindIR导出时的输入Shape均为(1, 3, 64, 256)。
## 参考文献
diff --git a/deploy/py_infer/src/data_process/preprocess/transforms/layout_transforms.py b/deploy/py_infer/src/data_process/preprocess/transforms/layout_transforms.py
index 898619c3c..a3ebd2aed 100644
--- a/deploy/py_infer/src/data_process/preprocess/transforms/layout_transforms.py
+++ b/deploy/py_infer/src/data_process/preprocess/transforms/layout_transforms.py
@@ -2,7 +2,7 @@
import numpy as np
-def letterbox(scaleup):
+def letterbox(scaleup, model_name=""):
def func(data):
image = data["image"]
hw_ori = data["raw_img_shape"]
@@ -17,7 +17,10 @@ def func(data):
new_shape = (new_shape, new_shape)
# Scale ratio (new / old)
- r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
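+    # layoutlmv3 uses the larger ratio so the scaled image covers new_shape;
+    # other models use the smaller ratio so the scaled image fits inside new_shape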
+ if model_name == "layoutlmv3":
+ r = max(new_shape[0] / shape[0], new_shape[1] / shape[1])
+ else:
+ r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
if not scaleup: # only scale down, do not scale up (for better test mAP)
r = min(r, 1.0)
@@ -28,11 +31,12 @@ def func(data):
dw, dh = dw / 2, dh / 2 # divide padding into 2 sides
hw_pad = np.array([dh, dw])
- if shape[::-1] != new_unpad: # resize
- image = cv2.resize(image, new_unpad, interpolation=cv2.INTER_LINEAR)
- top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
- left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
- image = cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border
+ if model_name != "layoutlmv3":
+ if shape[::-1] != new_unpad: # resize
+ image = cv2.resize(image, new_unpad, interpolation=cv2.INTER_LINEAR)
+ top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
+ left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
+ image = cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border
data["image"] = image
data["image_ids"] = 0
diff --git a/mindocr/data/layout_dataset.py b/mindocr/data/layout_dataset.py
index 54abf3f2e..49fc72014 100644
--- a/mindocr/data/layout_dataset.py
+++ b/mindocr/data/layout_dataset.py
@@ -70,6 +70,7 @@ def __init__(
self.img_formats = ["bmp", "jpg", "jpeg", "png", "tif", "tiff", "dng", "webp", "mpo"]
self.help_url = "https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data"
+ self.model_name = kwargs.get("model_name", "")
self.img_size = img_size
self.augment = augment
self.rect = rect
@@ -295,10 +296,14 @@ def load_image(self, index):
img = cv2.imread(path) # BGR
assert img is not None, "Image Not Found " + path
h_ori, w_ori = img.shape[:2] # orig hw
- r = self.img_size / max(h_ori, w_ori) # resize image to img_size
- if r != 1: # always resize down, only resize up if training with augmentation
- interp = cv2.INTER_AREA if r < 1 and not self.augment else cv2.INTER_LINEAR
- img = cv2.resize(img, (int(w_ori * r), int(h_ori * r)), interpolation=interp)
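+        # layoutlmv3 scales the shorter side to img_size; other models scale the
+        # longer side to img_size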
+ if self.model_name == "layoutlmv3":
+ r = self.img_size / min(h_ori, w_ori)
+ img = cv2.resize(img, (int(round(w_ori * r)), int(round(h_ori * r))), interpolation=cv2.INTER_LINEAR)
+ else:
+ r = self.img_size / max(h_ori, w_ori) # resize image to img_size
+ if r != 1: # always resize down, only resize up if training with augmentation
+ interp = cv2.INTER_AREA if r < 1 and not self.augment else cv2.INTER_LINEAR
+ img = cv2.resize(img, (int(w_ori * r), int(h_ori * r)), interpolation=interp)
return img, np.array([h_ori, w_ori]) # img, hw_original
else:
@@ -363,23 +368,28 @@ def letterbox(self, image, labels, hw_ori, new_shape, scaleup=False, color=(114,
new_shape = (new_shape, new_shape)
# Scale ratio (new / old)
- r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
+ if self.model_name == "layoutlmv3":
+ r = max(new_shape[0] / shape[0], new_shape[1] / shape[1])
+ else:
+ r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
if not scaleup: # only scale down, do not scale up (for better test mAP)
r = min(r, 1.0)
# Compute padding
new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
- dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
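+        # with the max-ratio branch above, new_unpad can exceed new_shape, so take
+        # absolute values to keep the recorded padding non-negative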
+ dw, dh = abs(new_shape[1] - new_unpad[0]), abs(new_shape[0] - new_unpad[1]) # wh padding
dw /= 2 # divide padding into 2 sides
dh /= 2
hw_pad = np.array([dh, dw])
- if shape[::-1] != new_unpad: # resize
- image = cv2.resize(image, new_unpad, interpolation=cv2.INTER_LINEAR)
- top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
- left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
- image = cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border
+ if self.model_name != "layoutlmv3":
+ if shape[::-1] != new_unpad: # resize
+ image = cv2.resize(image, new_unpad, interpolation=cv2.INTER_LINEAR)
+
+ top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
+ left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
+ image = cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border
# convert labels
if labels.size: # normalized xywh to pixel xyxy format
@@ -413,6 +423,25 @@ def image_norm(self, image, labels, scale=255.0):
image /= scale
return image, labels
+    def image_normal(self, image, labels, mean, std):
+        # standardize the image with the given mean and std (element-wise broadcast)
+        image = (image - mean) / std
+        image = image.astype(np.float32, copy=False)
+        return image, labels
+
+    def image_batch_pad(self, image, labels, max_size=None):
+        # zero-pad a (C, H, W) image on the bottom/right up to a square max_size,
+        # or keep its original spatial size when max_size is None
+        image_size = image.shape[-2:]
+ if max_size is None:
+ max_size = np.array(image_size)
+ else:
+ max_size = np.array([max_size, max_size])
+
+ h_pad = int(max_size[0] - image_size[0])
+ w_pad = int(max_size[1] - image_size[1])
+
+ padding_size = ((0, 0), (0, h_pad), (0, w_pad))
+ batched_imgs = np.pad(image, padding_size, mode="constant", constant_values=0)
+ return batched_imgs, labels
+
def image_transpose(self, image, labels, bgr2rgb=True, hwc2chw=True):
if bgr2rgb:
image = image[:, :, ::-1]
diff --git a/mindocr/data/transforms/layoutlm_transforms.py b/mindocr/data/transforms/layoutlm_transforms.py
index 4b1095a9e..de3aa951a 100644
--- a/mindocr/data/transforms/layoutlm_transforms.py
+++ b/mindocr/data/transforms/layoutlm_transforms.py
@@ -46,6 +46,33 @@ def __call__(self, data):
return data
+class ImageStridePad:
+    """
+    Pad an image of shape (C, H, W) with zeros on the bottom/right so that its height
+    and width become multiples of `stride`. If `max_size` (a (H, W) pair) is given,
+    each side is padded up to `max_size` rounded up to the next multiple of the stride.
+    """
+
+ def __init__(self, stride=32, max_size=None, **kwargs):
+ self.stride = stride
+ self.max_size = max_size
+
+ def __call__(self, data):
+ img = data["image"]
+ image_size = img.shape[-2:]
+ if self.max_size is None:
+ max_size = np.array(image_size)
+ else:
+ max_size = np.array(self.max_size)
+
+ max_size = (max_size + (self.stride - 1)) // self.stride * self.stride
+ h_pad = int(max_size[0] - image_size[0])
+ w_pad = int(max_size[1] - image_size[1])
+
+ padding_size = ((0, 0), (0, h_pad), (0, w_pad))
+ img = np.pad(img, padding_size, mode="constant", constant_values=0)
+ data["image"] = img
+ return data
+
+
class VQATokenLabelEncode:
"""
Label encode for NLP VQA methods
diff --git a/mindocr/metrics/builder.py b/mindocr/metrics/builder.py
index fd765547a..38563b71d 100644
--- a/mindocr/metrics/builder.py
+++ b/mindocr/metrics/builder.py
@@ -2,7 +2,7 @@
from .cls_metrics import *
from .det_metrics import *
from .kie_metrics import VQAReTokenMetric, VQASerTokenMetric
-from .layout_metrics import YOLOv8Metric
+from .layout_metrics import *
from .rec_metrics import *
from .table_metrics import *
diff --git a/mindocr/metrics/layout_metrics.py b/mindocr/metrics/layout_metrics.py
index 58fc33138..42ca23d0d 100644
--- a/mindocr/metrics/layout_metrics.py
+++ b/mindocr/metrics/layout_metrics.py
@@ -1,7 +1,7 @@
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
-__all__ = ["YOLOv8Metric"]
+__all__ = ["YOLOv8Metric", "Layoutlmv3Metric"]
class YOLOv8Metric(object):
@@ -27,3 +27,7 @@ def eval(self):
def clear(self):
self.result_dicts = list()
+
+
+class Layoutlmv3Metric(YOLOv8Metric):
+    """Compute the mean average precision (identical behavior to YOLOv8Metric)."""
diff --git a/mindocr/models/backbones/layoutlmv3/configuration.py b/mindocr/models/backbones/layoutlmv3/configuration.py
index 93243ddb5..75ee9e604 100644
--- a/mindocr/models/backbones/layoutlmv3/configuration.py
+++ b/mindocr/models/backbones/layoutlmv3/configuration.py
@@ -3,11 +3,11 @@
@dataclass
class LayoutLMv3PretrainedConfig:
- def __init__(self, use_float16=False):
+ def __init__(self, use_float16=False, **kwargs):
pretrained_config = {
"use_float16": use_float16,
"fast_qkv": False,
- "vocab_size": 250002,
+ "vocab_size": kwargs.get("vocab_size", 250002),
"hidden_size": 768,
"num_hidden_layers": 12,
"num_attention_heads": 12,
diff --git a/mindocr/models/backbones/layoutlmv3/layoutlmv3.py b/mindocr/models/backbones/layoutlmv3/layoutlmv3.py
index 1e1bc1f9b..1a1f52516 100644
--- a/mindocr/models/backbones/layoutlmv3/layoutlmv3.py
+++ b/mindocr/models/backbones/layoutlmv3/layoutlmv3.py
@@ -1,12 +1,16 @@
import collections
+import math
import numpy as np
+from addict import Dict
from mindspore import Parameter, Tensor, nn, ops
from mindspore.common import dtype as mstype
+from mindspore.common.initializer import HeUniform
from mindocr.models.backbones._registry import register_backbone, register_backbone_class
+from ..layoutxlm.visual_backbone import FPN, LastLevelMaxPool, ShapeSpec
from ..transformer_common.layer import (
LayoutXLMAttention,
LayoutXLMEmbeddings,
@@ -50,6 +54,8 @@ def construct(self, pixel_values: Tensor, position_embedding: Tensor = None):
position_embedding = position_embedding.view(1, self.patch_shape[0], self.patch_shape[1], -1)
position_embedding = position_embedding.transpose(0, 3, 1, 2)
patch_height, patch_width = embeddings.shape[2], embeddings.shape[3]
+ # There is a difference in accuracy between MindSpore's Bicubic mode and Torch,
+ # and the interface needs to be updated
position_embedding = ops.interpolate(position_embedding, size=(patch_height, patch_width), mode="bicubic")
embeddings = embeddings + position_embedding
@@ -213,22 +219,114 @@ def __init__(self, config):
class LayoutLMv3Encoder(LayoutXLMEncoder):
- def __init__(self, config):
+ def __init__(self, config, detection=False, out_features=None):
super().__init__(config)
+ self.detection = detection
+ self.out_features = out_features
self.layer = nn.CellList([LayoutLMv3Layer(config) for _ in range(config.num_hidden_layers)])
+ if self.detection:
+ self.gradient_checkpointing = True
+ embed_dim = self.config.hidden_size
+ self.out_indices = [int(name[5:]) for name in self.out_features]
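+            # simple feature pyramid over the ViT patch features (patch stride 16):
+            # fpn1 upsamples 4x (stride 4), fpn2 upsamples 2x (stride 8),
+            # fpn3 keeps stride 16, fpn4 downsamples to stride 32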
+ self.fpn1 = nn.SequentialCell(
+ nn.Conv2dTranspose(embed_dim, embed_dim, kernel_size=2, stride=2, has_bias=True),
+ # nn.SyncBatchNorm(embed_dim),
+ nn.BatchNorm2d(embed_dim),
+ nn.GELU(),
+ nn.Conv2dTranspose(embed_dim, embed_dim, kernel_size=2, stride=2, has_bias=True)
+ )
+
+ self.fpn2 = nn.SequentialCell(
+ nn.Conv2dTranspose(embed_dim, embed_dim, kernel_size=2, stride=2, has_bias=True)
+ )
+
+ self.fpn3 = nn.Identity()
+
+ self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)
+ self.ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
+
+ def construct(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_values=None,
+ output_attentions=False,
+ output_hidden_states=False,
+ bbox=None,
+ position_ids=None,
+ Hp=None,
+ Wp=None
+ ):
+ all_hidden_states = () if output_hidden_states else None
+
+ rel_pos = self._cal_1d_pos_emb(hidden_states, position_ids) if self.has_relative_attention_bias else None
+ rel_2d_pos = self._cal_2d_pos_emb(hidden_states, bbox) if self.has_spatial_attention_bias else None
+
+ if self.detection:
+ feat_out = {}
+ j = 0
+
+ hidden_save = dict()
+ hidden_save["input_hidden_states"] = hidden_states
+
+ for i, layer_module in enumerate(self.layer):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_head_mask = None
+ past_key_value = None
+            # gradient checkpointing is not used here, so the related code paths are omitted
+ hidden_save["input_attention_mask"] = attention_mask
+ layer_outputs = layer_module(
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ past_key_value,
+ output_attentions,
+ rel_pos=rel_pos,
+ rel_2d_pos=rel_2d_pos,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ hidden_save["{}_data".format(i)] = hidden_states
+
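+            # for detection, reshape the visual patch tokens (last Hp*Wp positions) back into a
+            # (B, C, Hp, Wp) feature map and apply the matching pyramid branch for this tap point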
+ if self.detection and i in self.out_indices:
+ xp = hidden_states[:, -Hp * Wp:, :].permute(0, 2, 1).reshape(len(hidden_states), -1, Hp, Wp)
+ feat_out[self.out_features[j]] = self.ops[j](xp.contiguous())
+ j += 1
+
+ if self.detection:
+ return feat_out
+
+ return hidden_states, hidden_save
+
@register_backbone_class
class LayoutLMv3Model(nn.Cell):
- def __init__(self, config):
+ def __init__(self, config, detection=False, out_features=None):
super().__init__(config)
self.config = config
+ self.detection = detection
+ self.out_features = out_features
self.num_hidden_layers = config.num_hidden_layers
self.has_relative_attention_bias = config.has_relative_attention_bias
self.has_spatial_attention_bias = config.has_spatial_attention_bias
self.patch_size = config.patch_size
self.use_float16 = config.use_float16
self.dense_dtype = mstype.float32
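+        # detection taps intermediate transformer layers as multi-scale outputs:
+        # 12-layer (base) models use layers 3/5/7/11, deeper (large) models use 7/11/15/23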
+ if self.num_hidden_layers <= 12:
+ self._out_feature_strides = {"layer3": 4, "layer5": 8, "layer7": 16, "layer11": 32}
+ self._out_feature_channels = {"layer3": 768, "layer5": 768, "layer7": 768, "layer11": 768}
+ else:
+ self._out_feature_strides = {"layer7": 4, "layer11": 8, "layer15": 16, "layer23": 32}
+ self._out_feature_channels = {"layer7": 1024, "layer11": 1024, "layer15": 1024, "layer23": 1024}
if self.use_float16 is True:
self.dense_dtype = mstype.float16
self.min = finfo(self.dense_dtype)
@@ -256,7 +354,7 @@ def __init__(self, config):
self.norm = nn.LayerNorm((config.hidden_size,), epsilon=1e-6)
- self.encoder = LayoutLMv3Encoder(config)
+ self.encoder = LayoutLMv3Encoder(config, detection=detection, out_features=out_features)
def get_input_embeddings(self):
return self.embeddings.word_embeddings
@@ -298,15 +396,21 @@ def calculate_visual_bbox(self, dtype, batch_size):
return visual_bbox
def visual_embeddings(self, pixel_values):
- embeddings = self.patch_embed(pixel_values)
+ if self.detection:
+ embeddings = self.patch_embed(pixel_values,
+ self.pos_embed[:, 1:, :] if self.pos_embed is not None else None)
+ else:
+ embeddings = self.patch_embed(pixel_values)
# add [CLS] token
batch_size, seq_len, _ = embeddings.shape
cls_tokens = self.cls_token.broadcast_to((batch_size, -1, -1))
+ if self.pos_embed is not None and self.detection:
+ cls_tokens = cls_tokens + self.pos_embed[:, :1, :]
embeddings = ops.cat((cls_tokens, embeddings), axis=1)
# add position embeddings
- if self.pos_embed is not None:
+ if self.pos_embed is not None and not self.detection:
embeddings = embeddings + self.pos_embed
embeddings = self.pos_drop(embeddings)
@@ -382,6 +486,14 @@ def _convert_head_mask_to_5d(self, head_mask: Tensor, num_hidden_layers: int):
head_mask = head_mask.to(dtype=self.dtype) # switch to float if need + fp16 compatibility
return head_mask
+ def output_shape(self):
+ return {
+ name: ShapeSpec(
+ channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
+ )
+ for name in self.out_features
+ }
+
def construct(
self,
input_ids=None, # input_ids
@@ -450,7 +562,10 @@ def construct(
inputs_embeds=inputs_embeds,
)
final_bbox = final_position_ids = None
+ Hp = Wp = None
if pixel_values is not None:
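+            # Hp/Wp: patch-grid height/width of the visual tokens (a ViT patch stride of 16 is assumed here)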
+ patch_size = 16
+ Hp, Wp = int(pixel_values.shape[2] / patch_size), int(pixel_values.shape[3] / patch_size)
visual_embeddings = self.visual_embeddings(pixel_values)
visual_embeddings_shape = visual_embeddings.shape
visual_attention_mask = ops.ones((batch_size, visual_embeddings_shape[1]), dtype=mstype.int64)
@@ -510,15 +625,133 @@ def construct(
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
+ Hp=Hp,
+ Wp=Wp
)
+ if self.detection:
+ return encoder_outputs
+
sequence_output = encoder_outputs[0]
return (sequence_output,) + encoder_outputs[1:]
+class FPNForLayout(FPN):
+ def __init__(self,
+ bottom_up,
+ in_features,
+ out_channels,
+ norm="",
+ top_block=None,
+ fuse_type="sum",
+ square_pad=0):
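+        # call nn.Cell.__init__ directly (skipping FPN.__init__) and rebuild the lateral/output
+        # convs below from the bottom-up backbone's output_shape()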
+ super(FPN, self).__init__()
+ assert in_features, in_features
+
+ input_shapes = bottom_up.output_shape()
+ strides = [input_shapes[f].stride for f in in_features]
+ in_channels_per_feature = [input_shapes[f].channels for f in in_features]
+
+ lateral_convs = []
+ output_convs = []
+
+ use_bias = norm == ""
+ for idx, in_channels in enumerate(in_channels_per_feature):
+ lateral_conv = nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=1,
+ weight_init=HeUniform(negative_slope=1),
+ has_bias=use_bias,
+ bias_init="zeros")
+ output_conv = nn.Conv2d(out_channels,
+ out_channels,
+ kernel_size=3,
+ padding=1,
+ weight_init=HeUniform(negative_slope=1),
+ has_bias=use_bias,
+ bias_init="zeros",
+ pad_mode='pad')
+ stage = int(math.log2(strides[idx]))
+ self.insert_child_to_cell("fpn_lateral{}".format(stage), lateral_conv)
+ self.insert_child_to_cell("fpn_output{}".format(stage), output_conv)
+
+ lateral_convs.append(lateral_conv)
+ output_convs.append(output_conv)
+
+ self.lateral_convs = nn.CellList(lateral_convs[::-1])
+ self.output_convs = nn.CellList(output_convs[::-1])
+
+ self.top_block = top_block
+ self.in_features = tuple(in_features)
+ self.bottom_up = bottom_up
+
+ self._out_feature_strides = {"p{}".format(int(math.log2(s))): s for s in strides}
+ if self.top_block is not None:
+ for s in range(stage, stage + self.top_block.num_levels):
+ self._out_feature_strides["p{}".format(s + 1)] = 2 ** (s + 1)
+
+ self._out_features = list(self._out_feature_strides.keys())
+ self._out_feature_channels = {k: out_channels for k in self._out_features}
+ self._size_divisibility = strides[-1]
+ self._square_pad = square_pad
+ self._fuse_type = fuse_type
+
+ self.out_channels = out_channels
+
+ def construct(self, **x):
+ bottom_up_features = self.bottom_up(**x)
+
+ results = []
+ bottom_up_feature = bottom_up_features.get(self.in_features[-1])
+ prev_features = self.lateral_convs[0](bottom_up_feature)
+ results.append(self.output_convs[0](prev_features))
+
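+        # standard FPN top-down pass: upsample the coarser map 2x (nearest neighbor),
+        # add the lateral 1x1 conv of the finer backbone feature, then refine with a 3x3 conv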
+ for idx, (lateral_conv, output_conv) in enumerate(zip(self.lateral_convs, self.output_convs)):
+ if idx > 0:
+ features = self.in_features[-idx - 1]
+ features = bottom_up_features[features]
+ old_shape = list(prev_features.shape)[2:]
+ new_size = tuple([2 * i for i in old_shape])
+ top_down_features = ops.ResizeNearestNeighbor(size=new_size)(prev_features)
+ lateral_features = lateral_conv(features)
+ prev_features = lateral_features + top_down_features
+ if self._fuse_type == "avg":
+ prev_features /= 2
+ results.insert(0, output_conv(prev_features))
+ if self.top_block is not None:
+ if self.top_block.in_feature in bottom_up_features:
+ top_block_in_feature = bottom_up_features[self.top_block.in_feature]
+ else:
+ top_block_in_feature = results[self._out_features.index(self.top_block.in_feature)]
+ results.extend(self.top_block(top_block_in_feature))
+
+ assert len(self._out_features) == len(results)
+
+ return {f: res for f, res in zip(self._out_features, results)}
+
+
@register_backbone
def layoutlmv3(use_float16: bool = True, **kwargs):
- pretrained_config = LayoutLMv3PretrainedConfig(use_float16)
+ pretrained_config = LayoutLMv3PretrainedConfig(use_float16, **kwargs)
model = LayoutLMv3Model(pretrained_config)
return model
+
+
+@register_backbone
+def build_layoutlmv3_fpn_backbone(use_float16: bool = False, **kwargs):
+ pretrained_config = LayoutLMv3PretrainedConfig(use_float16, **kwargs)
+ pretrained_config.has_spatial_attention_bias = False
+ pretrained_config.has_relative_attention_bias = False
+ pretrained_config.text_embed = False
+ cfg = Dict(kwargs)
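+    # cfg is assumed to carry the detection settings from the YAML config, e.g. cfg.out_features
+    # and cfg.fpn.{in_features, out_channels, norm, fuse_type}; exact keys depend on the config file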
+ bottom_up = LayoutLMv3Model(pretrained_config, detection=True, out_features=cfg.out_features)
+ backbone = FPNForLayout(
+ bottom_up=bottom_up,
+ in_features=cfg.fpn.in_features,
+ out_channels=cfg.fpn.out_channels,
+ norm=cfg.fpn.norm,
+ top_block=LastLevelMaxPool(),
+ fuse_type=cfg.fpn.fuse_type
+ )
+ return backbone
diff --git a/mindocr/models/backbones/layoutxlm/visual_backbone.py b/mindocr/models/backbones/layoutxlm/visual_backbone.py
index 763824227..4155c6212 100644
--- a/mindocr/models/backbones/layoutxlm/visual_backbone.py
+++ b/mindocr/models/backbones/layoutxlm/visual_backbone.py
@@ -40,7 +40,8 @@ def __init__(self):
self.in_feature = "p5"
def construct(self, x):
- return [ops.max_pool2d(x, kernel_size=1, stride=2, padding=0)]
+ dtype = x.dtype
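+        # cast to float16 for the pooling op and back; assumed to be a device/dtype workaround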
+ return [ops.max_pool2d(x.astype(ms.float16), kernel_size=1, stride=2, padding=0).astype(dtype)]
class FPN(nn.Cell):
diff --git a/mindocr/models/base_model.py b/mindocr/models/base_model.py
index de3e56691..9494817d5 100644
--- a/mindocr/models/base_model.py
+++ b/mindocr/models/base_model.py
@@ -1,6 +1,7 @@
+import numpy as np
from addict import Dict
-from mindspore import nn
+from mindspore import nn, ops
from .backbones import build_backbone
from .heads import build_head
@@ -28,6 +29,11 @@ def __init__(self, config: dict):
else:
self.is_kie = False
+ if config.type == "layout":
+ self.is_layout = True
+ else:
+ self.is_layout = False
+
if config.transform:
transform_name = config.transform.pop("name")
self.transform = build_trans(transform_name, **config.transform)
@@ -77,12 +83,32 @@ def re(self, *inputs):
x = self.head(x, input_ids, question, question_label, answer, answer_label)
return x
+ def layout(self, *inputs):
+ pixel_values = inputs[0]
+ hw_ori = inputs[1]
+ hw_scale = inputs[2]
+ pixel_values_unpad_shape = hw_ori * hw_scale
+
+        # fix batch: crop the batch-padded images to the tightest common size,
+        # rounded up to the backbone's size divisibility, to avoid redundant compute
+ max_size = ops.reduce_max(pixel_values_unpad_shape, axis=0)
+ stride = self.backbone._size_divisibility
+ max_size = (max_size + (stride - 1)) // stride * stride
+ pixel_values = pixel_values[:, :, :int(max_size[0]), :int(max_size[1])]
+
+ features = self.backbone(pixel_values=pixel_values)
+ proposals, rois_mask = self.neck.predict(features, pixel_values_unpad_shape)
+ res = self.head.predict(features, proposals, rois_mask, pixel_values_unpad_shape)
+ return res
+
def construct(self, *args):
if self.is_kie is True:
if self.head_name == "TokenClassificationHead":
return self.ser(*args)
elif self.head_name == "RelationExtractionHead":
return self.re(*args)
+ elif self.is_layout:
+ if self.head_name == "CascadeROIHeads":
+ return self.layout(*args)
x = args[0]
if self.transform is not None:
@@ -118,8 +144,6 @@ def construct(self, *args):
import time
- import numpy as np
-
import mindspore as ms
bs = 8
diff --git a/mindocr/models/heads/builder.py b/mindocr/models/heads/builder.py
index 286fc997f..d60313708 100644
--- a/mindocr/models/heads/builder.py
+++ b/mindocr/models/heads/builder.py
@@ -18,6 +18,7 @@
'YOLOv8Head',
'MultiHead',
'TableMasterHead',
+ 'CascadeROIHeads'
]
from .cls_head import MobileNetV3Head
from .conv_head import ConvHead
@@ -34,6 +35,7 @@
from .rec_multi_head import MultiHead
from .rec_robustscanner_head import RobustScannerHead
from .rec_visionlan_head import VisionLANHead
+from .roi_head.box_head import CascadeROIHeads
from .table_master_head import TableMasterHead
from .yolov8_head import YOLOv8Head
diff --git a/mindocr/models/heads/roi_head/__init__.py b/mindocr/models/heads/roi_head/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/mindocr/models/heads/roi_head/box_head.py b/mindocr/models/heads/roi_head/box_head.py
new file mode 100644
index 000000000..a715ce562
--- /dev/null
+++ b/mindocr/models/heads/roi_head/box_head.py
@@ -0,0 +1,222 @@
+from addict import Dict
+
+from mindspore import nn, ops
+from mindspore.common.initializer import HeNormal, HeUniform, Normal
+
+from ...label_assignment import BBoxAssigner
+from ...utils.box_utils import delta2bbox
+from .mask_head import MaskRCNNConvUpSampleHead
+from .roi_extractor import RoIExtractor
+
+
+class FastRCNNConvFCHead(nn.SequentialCell):
+
+ def __init__(self, in_channel=256, out_channel=1024, resolution=7, conv_dims=[], fc_dims=[1024, 1024]):
+ super().__init__()
+
+ for k, conv_dim in enumerate(conv_dims):
+ conv = nn.Conv2d(in_channel,
+ conv_dim,
+ kernel_size=3,
+ padding=1,
+ pad_mode='pad',
+                             weight_init=HeNormal(mode="fan_out", nonlinearity="relu"),
+ has_bias=True,
+ bias_init="zeros")
+ self.insert_child_to_cell("conv{}".format(k + 1), conv)
+ self.insert_child_to_cell("conv_relu{}".format(k + 1), nn.ReLU())
+
+ self._output_size = in_channel * resolution * resolution
+ for k, fc_dim in enumerate(fc_dims):
+ if k == 0:
+ self.insert_child_to_cell("flatten", nn.Flatten())
+ fc = nn.Dense(self._output_size,
+ fc_dim,
+ weight_init=HeUniform(negative_slope=1),
+ has_bias=True,
+ bias_init="zeros")
+ self.insert_child_to_cell("fc{}".format(k + 1), fc)
+ self.insert_child_to_cell("fc_relu{}".format(k + 1), nn.ReLU())
+ self._output_size = fc_dim
+
+ def construct(self, x):
+ b, n, c, _, _ = x.shape
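+        # flatten (batch, num_rois, C, H, W) RoI features into (batch*num_rois, C*H*W)
+        # before the fully connected layers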
+ x = x.reshape(b * n, -1)
+ for layer in self:
+ x = layer(x)
+ return x
+
+
+class FastRCNNOutputLayers(nn.Cell):
+ """
+ Two linear layers for predicting Fast R-CNN outputs:
+
+ 1. proposal-to-detection box regression deltas
+ 2. classification scores
+ """
+
+ def __init__(self, out_channel, num_classes, cls_agnostic_bbox_reg=True, box_dim=4):
+ super().__init__()
+
+ self.num_classes = num_classes
+
+ self.cls_score = nn.Dense(out_channel,
+ num_classes + 1,
+ weight_init=Normal(sigma=0.01),
+ has_bias=True,
+ bias_init="zeros")
+ self.num_bbox_reg_classes = 1 if cls_agnostic_bbox_reg else num_classes
+ self.bbox_pred = nn.Dense(out_channel,
+ self.num_bbox_reg_classes * box_dim,
+ weight_init=Normal(sigma=0.001),
+ has_bias=True,
+ bias_init="zeros")
+
+ def construct(self, x):
+ if x.dim() > 2:
+ x = ops.function.flatten(x, start_dim=1)
+ scores = self.cls_score(x)
+ proposal_deltas = self.bbox_pred(x)
+ return scores, proposal_deltas
+
+ def predict_boxes(self, predictions, proposals):
+ if not len(proposals):
+ return []
+ batch_size, rois_num, _ = proposals.shape
+ _, proposal_deltas = predictions
+ rois = ops.tile(proposals[:, :, :4].reshape((batch_size, rois_num, 1, 4)), (1, 1, self.num_bbox_reg_classes, 1))
+ # rois = rois.reshape((-1, rois.shape[-1]))[:, :4]
+ pred_loc = delta2bbox(proposal_deltas.reshape((-1, 4)), rois.reshape((-1, 4))) # true box xyxy
+ pred_loc = pred_loc.reshape((batch_size, rois_num, self.num_bbox_reg_classes * 4))
+ return pred_loc
+
+ def predict_probs(self, predictions, proposals):
+ batch_size, rois_num, _ = proposals.shape
+ scores, _ = predictions
+ pred_cls = scores.reshape((batch_size, rois_num, -1))
+ pred_cls = ops.softmax(pred_cls, axis=-1)
+ return pred_cls
+
+
+def get_head(cfg):
+ if cfg.name == "FastRCNNConvFCHead":
+ return FastRCNNConvFCHead(in_channel=cfg.in_channel,
+ out_channel=cfg.out_channel,
+ resolution=cfg.pooler_resolution,
+ conv_dims=cfg.conv_dims,
+ fc_dims=cfg.fc_dims)
+ else:
+        raise ValueError(f"Unsupported bbox_head: {cfg.name}")
+
+
+class CascadeROIHeads(nn.Cell):
+ """Cascade RCNN bbox head"""
+
+ def __init__(self, in_channels, with_mask=False, **cfg):
+ super(CascadeROIHeads, self).__init__()
+ cfg = Dict(cfg)
+ cascade_bbox_reg_weights = cfg.roi_box_cascade_head.bbox_reg_weights
+ cascade_ious = cfg.roi_box_cascade_head.ious
+
+ self.box_pooler = RoIExtractor(resolution=cfg.roi_box_head.pooler_resolution,
+ featmap_strides=cfg.roi_extractor.featmap_strides,
+ pooler_sampling_ratio=cfg.roi_box_head.pooler_sampling_ratio,
+ pooler_type=cfg.roi_box_head.pooler_type)
+
+ self.box_in_features = cfg.in_features
+ self.num_classes = cfg.num_classes
+ self.cls_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="none")
+ self.loc_loss = nn.SmoothL1Loss(reduction="none")
+ self.with_mask = with_mask
+
+ box_heads, box_predictors, proposal_matchers = [], [], []
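+        # build one box head, predictor, and label assigner per cascade stage
+        # (one stage per entry in cascade_ious)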
+ for match_iou, bbox_reg_weights in zip(cascade_ious, cascade_bbox_reg_weights):
+ box_head = get_head(cfg.roi_box_head)
+ box_heads.append(box_head)
+ box_predictors.append(FastRCNNOutputLayers(cfg.roi_box_head.out_channel, self.num_classes,
+ cfg.roi_box_head.cls_agnostic_bbox_reg))
+ proposal_matchers.append(BBoxAssigner(
+ rois_per_batch=cfg.bbox_assigner.rois_per_batch,
+ bg_thresh=cfg.bbox_assigner.bg_thresh,
+ fg_thresh=cfg.bbox_assigner.fg_thresh,
+ fg_fraction=cfg.bbox_assigner.fg_fraction,
+ num_classes=cfg.num_classes,
+ with_mask=with_mask
+ ))
+
+ self.box_head = nn.CellList(box_heads)
+ self.box_predictor = nn.CellList(box_predictors)
+ self.proposal_matchers = nn.CellList(proposal_matchers)
+
+ self.num_cascade_stages = len(box_heads)
+
+ if cfg.mask_on:
+ self.mask_head = MaskRCNNConvUpSampleHead(in_channels=cfg.roi_mask_head.in_channel,
+ num_classes=self.num_classes,
+ conv_dims=cfg.roi_mask_head.conv_dims)
+ self.mask_pooler = RoIExtractor(resolution=cfg.roi_mask_head.pooler_resolution,
+ featmap_strides=cfg.roi_extractor.featmap_strides,
+ pooler_sampling_ratio=cfg.roi_mask_head.pooler_sampling_ratio,
+ pooler_type=cfg.roi_mask_head.pooler_type)
+
+ def construct(self, feats, rois, rois_mask, gts, gt_masks=None):
+ """
+ feats (list[Tensor]): Feature maps from backbone
+ rois (list[Tensor]): RoIs generated from RPN module
+ rois_mask (Tensor): The number of RoIs in each image
+ gts (Tensor): The ground-truth
+ """
+ pass
+
+ def _run_stage(self, features, proposals, proposals_mask, stage):
+ box_features = self.box_pooler(features, proposals, proposals_mask)
+ if self.training:
+ pass
+ box_features = self.box_head[stage](box_features)
+ return self.box_predictor[stage](box_features)
+
+ def clip_boxes(self, boxes, im_shape):
+ h, w = im_shape
+ x1 = ops.clip_by_value(boxes[..., 0], 0, w)
+ y1 = ops.clip_by_value(boxes[..., 1], 0, h)
+ x2 = ops.clip_by_value(boxes[..., 2], 0, w)
+ y2 = ops.clip_by_value(boxes[..., 3], 0, h)
+ boxes = ops.stack((x1, y1, x2, y2), -1)
+ return boxes
+
+ def _create_proposals_from_boxes(self, boxes, image_sizes):
+ proposals = []
+ for boxes_per_image, image_size in zip(boxes, image_sizes):
+ boxes_per_image = self.clip_boxes(boxes_per_image, image_size)
+ if self.training:
+ pass
+ proposals.append(boxes_per_image)
+ return ops.stack(proposals, axis=0)
+
+ def predict(self, features, proposals, proposals_mask, image_sizes):
+ features = [features[f] for f in self.box_in_features]
+ head_outputs = [] # (predictor, predictions, proposals)
+ prev_pred_boxes = None
+
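+        # cascade inference: each stage refines the boxes predicted by the previous stage;
+        # class scores are averaged over stages and the boxes of the last stage are kept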
+ for k in range(self.num_cascade_stages):
+ if k > 0:
+ proposals = self._create_proposals_from_boxes(prev_pred_boxes, image_sizes)
+ predictions = self._run_stage(features, proposals, proposals_mask, k)
+ prev_pred_boxes = self.box_predictor[k].predict_boxes(predictions, proposals)
+ head_outputs.append((self.box_predictor[k], predictions, proposals))
+
+ scores_per_stage = [h[0].predict_probs(h[1], h[2]) for h in head_outputs]
+
+ # Average the scores across heads
+ scores = [
+ sum(list(scores_per_image)) * (1.0 / self.num_cascade_stages)
+ for scores_per_image in zip(*scores_per_stage)
+ ]
+ scores = ops.stack(scores, axis=0)
+ # Use the boxes of the last head
+ predictor, predictions, proposals = head_outputs[-1]
+ boxes = predictor.predict_boxes(predictions, proposals)
+ boxes = self._create_proposals_from_boxes(boxes, image_sizes)
+
+ res = ops.concat((boxes, scores), axis=-1)
+ return res
diff --git a/mindocr/models/heads/roi_head/mask_head.py b/mindocr/models/heads/roi_head/mask_head.py
new file mode 100644
index 000000000..3e00b2229
--- /dev/null
+++ b/mindocr/models/heads/roi_head/mask_head.py
@@ -0,0 +1,53 @@
+from mindspore import nn
+from mindspore.common.initializer import HeNormal, Normal
+
+
+class MaskRCNNConvUpSampleHead(nn.SequentialCell):
+
+ def __init__(self, in_channels, num_classes=5, conv_dims=[]):
+ super().__init__()
+
+ cur_channels = in_channels
+ for k, conv_dim in enumerate(conv_dims[:-1]):
+ conv = nn.Conv2d(
+ cur_channels,
+ conv_dim,
+ kernel_size=3,
+ padding=1,
+ pad_mode='pad',
+ weight_init=HeNormal(mode="fan_out", nonlinearity="relu"),
+ has_bias=True,
+ bias_init="zeros"
+ )
+ self.insert_child_to_cell("mask_fcn{}".format(k + 1), conv)
+ cur_channels = conv_dim
+
+ self.deconv = nn.Conv2dTranspose(
+ in_channels=cur_channels,
+ out_channels=conv_dims[-1],
+ kernel_size=2,
+ stride=2,
+ pad_mode="valid",
+ weight_init=HeNormal(mode="fan_out", nonlinearity="relu"),
+ has_bias=True,
+ bias_init="zeros"
+ )
+ self.insert_child_to_cell("deconv_relu", nn.ReLU())
+ cur_channels = conv_dims[-1]
+
+ self.predictor = nn.Conv2d(
+ cur_channels,
+ num_classes,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ pad_mode='valid',
+ weight_init=Normal(sigma=0.001),
+ has_bias=True,
+ bias_init="zeros"
+ )
+
+ def construct(self, x):
+ for layer in self:
+ x = layer(x)
+ return x
diff --git a/mindocr/models/heads/roi_head/patch.py b/mindocr/models/heads/roi_head/patch.py
new file mode 100644
index 000000000..44d7465d6
--- /dev/null
+++ b/mindocr/models/heads/roi_head/patch.py
@@ -0,0 +1,23 @@
+from mindspore import _checkparam as validator
+from mindspore.ops import ROIAlign
+from mindspore.ops.primitive import prim_attr_register
+
+
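+# Patch ROIAlign.__init__ so that roi_end_mode=2 (used for the "ROIAlignV2" pooler type)
+# passes validation; the stock primitive only accepts 0 or 1.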
+@prim_attr_register
+def __init__(self, pooled_height, pooled_width, spatial_scale, sample_num=2, roi_end_mode=1):
+ """Initialize ROIAlign"""
+ validator.check_value_type("pooled_height", pooled_height, [int], self.name)
+ validator.check_value_type("pooled_width", pooled_width, [int], self.name)
+ validator.check_value_type("spatial_scale", spatial_scale, [float], self.name)
+ validator.check_value_type("sample_num", sample_num, [int], self.name)
+ validator.check_value_type("roi_end_mode", roi_end_mode, [int], self.name)
+ validator.check_int_range(roi_end_mode, 0, 2, validator.INC_BOTH, "roi_end_mode", self.name)
+ self.pooled_height = pooled_height
+ self.pooled_width = pooled_width
+ self.spatial_scale = spatial_scale
+ self.sample_num = sample_num
+ self.roi_end_mode = roi_end_mode
+
+
+def patch_roialign():
+ ROIAlign.__init__ = __init__
diff --git a/mindocr/models/heads/roi_head/roi_extractor.py b/mindocr/models/heads/roi_head/roi_extractor.py
new file mode 100644
index 000000000..4d130d3a6
--- /dev/null
+++ b/mindocr/models/heads/roi_head/roi_extractor.py
@@ -0,0 +1,98 @@
+import math
+
+import numpy as np
+
+import mindspore as ms
+from mindspore import nn, ops
+
+from ...utils.box_utils import tensor
+from .patch import patch_roialign
+
+patch_roialign()
+
+pooler_type_map = {
+ "ROIAlign": 0,
+ "ROIAlignV2": 2
+}
+
+
+class RoIExtractor(nn.Cell):
+ """
+ Extract RoI features from multiple feature map.
+
+ Args:
+        resolution (int): RoI output resolution.
+        featmap_strides (List[int]): Strides of the input feature maps.
+        pooler_sampling_ratio (int): Sampling ratio passed to ROIAlign. Default: 0.
+        pooler_type (str): "ROIAlign" or "ROIAlignV2". Default: "ROIAlign".
+        finest_scale (int): Canonical box scale used for level assignment. Default: 224.
+        canonical_level (int): Pyramid level assigned to boxes of scale ``finest_scale``. Default: 4.
+ """
+
+ def __init__(self,
+ resolution,
+ featmap_strides,
+ pooler_sampling_ratio=0,
+ pooler_type="ROIAlign",
+ finest_scale=224,
+ canonical_level=4):
+ super(RoIExtractor, self).__init__()
+ self.finest_scale = finest_scale
+ self.canonical_level = canonical_level
+ self.roi_layers = []
+ self.num_levels = len(featmap_strides)
+ self.min_level = int(-(math.log2(1 / featmap_strides[0])))
+ self.max_level = int(-(math.log2(1 / featmap_strides[-1])))
+ self.resolution = resolution
+ for s in featmap_strides:
+ self.roi_layers.append(
+ ops.ROIAlign(pooled_height=resolution,
+ pooled_width=resolution,
+ spatial_scale=1 / s,
+ sample_num=pooler_sampling_ratio,
+ roi_end_mode=pooler_type_map[pooler_type])
+ )
+ self.featmap_strides = featmap_strides
+ self.temp_roi = ms.Tensor(
+ np.array([0, 0, featmap_strides[-1] + 1, featmap_strides[-1] + 1]).astype(np.float32).reshape(1, 4)
+ )
+
+ def log2(self, value):
+ return ops.log(value + 1e-4) / ops.log(tensor(2, value.dtype))
+
+ def map_roi_levels(self, rois, num_levels):
+ """Map rois to corresponding feature levels by scales.
+
+        The level of each RoI is ``floor(canonical_level + log2(scale / finest_scale))``,
+        clamped to the available pyramid levels and shifted so that the finest level is 0.
+        With the defaults (finest_scale=224, canonical_level=4) and strides of e.g. 4/8/16/32:
+
+        - scale < 112: level 0
+        - 112 <= scale < 224: level 1
+        - 224 <= scale < 448: level 2
+        - scale >= 448: level 3
+
+ Args:
+ rois (Tensor): Input RoIs, shape (k, 5).
+ num_levels (int): Total level number.
+
+ Returns:
+ Tensor: Level index (0-based) of each RoI, shape (k, )
+ """
+ scale = ops.sqrt((rois[:, 3] - rois[:, 1]) * (rois[:, 4] - rois[:, 2]))
+ target_lvls = ops.floor(self.canonical_level + ops.log2(scale / self.finest_scale + 1e-8))
+ target_lvls = target_lvls.clamp(min=self.min_level, max=self.max_level) - self.min_level
+ return target_lvls
+
+ def construct(self, features, rois, rois_mask):
+ # rois shape is [batch_size, self.post_nms_top_n, 4] 4 is (x0, y0, x1, y1)
+ batch_size, num_sample, _ = rois.shape
+ batch_c = ops.repeat_elements(ops.arange(batch_size).astype(rois.dtype), num_sample, axis=0).reshape(-1, 1)
+ rois = rois.reshape(batch_size * num_sample, 4)
+ rois_mask = rois_mask.reshape(batch_size * num_sample, 1).astype(ms.bool_)
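+        # replace padded/invalid RoIs with a small placeholder box so ROIAlign always sees
+        # well-formed inputs; their features are masked out again below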
+ rois = ops.select(
+ ops.tile(rois_mask, (1, 4)), rois, ops.tile(self.temp_roi.astype(rois.dtype), (batch_size * num_sample, 1))
+ )
+ rois = ops.concat((batch_c, rois), 1)
+ out_channel = features[0].shape[1]
+ target_lvls = self.map_roi_levels(rois, self.num_levels).reshape(batch_size * num_sample, 1)
+ res = ops.zeros((batch_size * num_sample, out_channel, self.resolution, self.resolution), features[0].dtype)
+ for i in range(self.num_levels):
+ mask = ops.logical_and(target_lvls == i, rois_mask)
+ mask = ops.tile(mask.reshape((-1, 1, 1, 1)), (1, out_channel, self.resolution, self.resolution))
+ roi_feats_t = self.roi_layers[i](features[i], rois)
+ res = ops.select(mask, roi_feats_t, res)
+ return res.reshape(batch_size, num_sample, out_channel, self.resolution, self.resolution)
diff --git a/mindocr/models/label_assignment.py b/mindocr/models/label_assignment.py
new file mode 100644
index 000000000..6b231804b
--- /dev/null
+++ b/mindocr/models/label_assignment.py
@@ -0,0 +1,367 @@
+import numpy as np
+
+import mindspore as ms
+from mindspore import nn, ops
+
+from .utils.box_utils import bbox2delta
+
+
+class RPNLabelAssignment(nn.Cell):
+ """
+ RPN targets assignment module
+
+ The assignment consists of three steps:
+    1. Match each anchor to a ground-truth box and label it as a foreground
+       or background sample
+    2. Sample anchors to keep the desired ratio between foreground and
+       background
+    3. Generate the targets for the classification and regression branches
+ """
+
+ def __init__(
+ self,
+ rnp_sample_batch=256,
+ fg_fraction=0.5,
+ positive_overlap=0.7,
+ negative_overlap=0.3,
+ ignore_thresh=-1,
+ use_random=False,
+ ):
+ super(RPNLabelAssignment, self).__init__()
+ self.rnp_sample_batch = rnp_sample_batch
+ self.fg_fraction = fg_fraction
+ self.positive_overlap = positive_overlap
+ self.negative_overlap = negative_overlap
+ self.ignore_thresh = ignore_thresh
+ self.use_random = use_random
+
+ def construct(self, gts, anchors):
+ """
+ gts: ground-truth instances. [batch_size, max_gt, 5], 5 is cls_id, x, y, x, y
+ anchors (Tensor): [[num_anchors_i, 4]*5], num_anchors_i are all anchors in all feature maps.
+ """
+ batch_size, max_gt, _ = gts.shape
+ anchors = ops.concat(anchors, 0)
+ tgt_labels, tgt_bboxes, tgt_deltas = rpn_anchor_target(
+ anchors,
+ gts,
+ self.rnp_sample_batch,
+ self.positive_overlap,
+ self.negative_overlap,
+ self.fg_fraction,
+ self.use_random,
+ batch_size,
+ )
+ return tgt_labels, tgt_bboxes, tgt_deltas
+
+
+class BBoxAssigner(nn.Cell):
+ """
+ RCNN targets assignment module
+
+ The assignment consists of three steps:
+    1. Match each RoI to a ground-truth box and label it as a foreground
+        or background sample
+    2. Sample RoIs to keep the desired ratio between foreground and
+        background
+    3. Generate the targets for the classification and regression branches
+
+ Args:
+ rois_per_batch (int): Total number of RoIs per image.
+ default 512
+        fg_fraction (float): Fraction of RoIs that are labeled
+            foreground, default 0.25
+ fg_thresh (float): Minimum overlap required between a RoI
+ and ground-truth box for the (roi, gt box) pair to be
+ a foreground sample. default 0.5
+ bg_thresh (float): Maximum overlap allowed between a RoI
+ and ground-truth box for the (roi, gt box) pair to be
+ a background sample. default 0.5
+ ignore_thresh(float): Threshold for ignoring the is_crowd ground-truth
+ if the value is larger than zero.
+        num_classes (int): The number of classes. default 80
+ """
+
+ def __init__(
+ self,
+ rois_per_batch=512,
+ fg_fraction=0.25,
+ fg_thresh=0.5,
+ bg_thresh=0.5,
+ ignore_thresh=-1.0,
+ num_classes=80,
+ with_mask=False,
+ ):
+ super(BBoxAssigner, self).__init__()
+ self.rois_per_batch = rois_per_batch
+ self.fg_fraction = fg_fraction
+ self.fg_thresh = fg_thresh
+ self.bg_thresh = bg_thresh
+ self.ignore_thresh = ignore_thresh
+ self.num_classes = num_classes
+ self.with_mask = with_mask
+
+ def construct(self, rois, rois_mask, gts, masks=None):
+ if self.with_mask:
+ return generate_proposal_target_with_mask(
+ rois,
+ rois_mask,
+ gts,
+ masks,
+ self.rois_per_batch,
+ self.fg_fraction,
+ self.fg_thresh,
+ self.bg_thresh,
+ self.num_classes,
+ )
+ return generate_proposal_target(
+ rois,
+ rois_mask,
+ gts,
+ self.rois_per_batch,
+ self.fg_fraction,
+ self.fg_thresh,
+ self.bg_thresh,
+ self.num_classes,
+ )
+
+
+def rpn_anchor_target(
+ anchors,
+ gt_boxes,
+ rnp_sample_batch,
+ rpn_positive_overlap,
+ rpn_negative_overlap,
+ rpn_fg_fraction,
+ use_random=True,
+ batch_size=1,
+):
+ """
+ return:
+ tgt_labels(Tensor): 0 or 1, indicates whether it is a positive sample
+ tgt_bboxes(Tensor): matched boxes, shape is (num_samples, 5)
+ tgt_deltas(Tensor): matched encoding boxes, shape is (num_samples, 4), 4 is encoding xywh.
+ """
+ tgt_labels = []
+ tgt_bboxes = []
+ tgt_deltas = []
+ for i in range(batch_size):
+ gt_bbox = gt_boxes[i]
+ # Step1: match anchor and gt_bbox
+ # matches is the matched box index of anchors
+ # match_labels is the matched label of anchors, -1 is ignore label, 0 is background label.
+ matches, match_labels = label_box(anchors, gt_bbox, rpn_positive_overlap, rpn_negative_overlap, True)
+ # Step2: sample anchor
+ fg_mask = ops.logical_and(match_labels != -1, match_labels != 0) # nonzero
+ bg_mask = match_labels == 0
+ if use_random:
+ fg_num = int(rnp_sample_batch * rpn_fg_fraction)
+ fg_sampler = ops.RandomChoiceWithMask(count=fg_num)
+ fg_idx, fg_s_mask = fg_sampler(fg_mask)
+ fg_mask = ops.zeros_like(fg_mask)
+ fg_mask[fg_idx.reshape(-1)] = fg_s_mask
+
+ bg_num = rnp_sample_batch - fg_num
+ bg_num_mask = ms.numpy.arange(int(rnp_sample_batch)) < bg_num
+ bg_sampler = ops.RandomChoiceWithMask(count=int(rnp_sample_batch))
+ bg_idx, bg_s_mask = bg_sampler(bg_mask)
+ bg_mask = ops.zeros_like(bg_mask)
+ bg_mask[bg_idx.reshape(-1)] = ops.logical_and(bg_s_mask, bg_num_mask)
+ else:
+ fg_num = rnp_sample_batch * rpn_fg_fraction
+ fg_num = min(fg_num, fg_mask.astype(ms.float32).sum().astype(ms.int32))
+ bg_num = rnp_sample_batch - fg_num
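+            # deterministic sampling: cap the number of kept foreground/background anchors
+            # with a cumulative-count cutoff instead of random choice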
+ fg_mask = ops.logical_and(ops.cumsum(fg_mask.astype(ms.float32), 0) < fg_num, fg_mask)
+ bg_mask = ops.logical_and(ops.cumsum(bg_mask.astype(ms.float32), 0) < bg_num, bg_mask)
+
+ # Fill with the ignore label (-1), then set positive and negative labels
+ labels = ops.ones(match_labels.shape, ms.int32) * -1
+ labels = ops.select(bg_mask, ops.zeros_like(labels), labels)
+ labels = ops.select(fg_mask, ops.ones_like(labels), labels)
+
+ # Step3: make output
+ matched_gt_boxes = gt_bbox[matches]
+ tgt_delta = bbox2delta(anchors, matched_gt_boxes[:, 1:], weights=(1.0, 1.0, 1.0, 1.0))
+ tgt_labels.append(labels)
+ tgt_bboxes.append(matched_gt_boxes)
+ tgt_deltas.append(tgt_delta)
+ tgt_labels = ops.stop_gradient(ops.stack(tgt_labels, 0))
+ tgt_bboxes = ops.stop_gradient(ops.stack(tgt_bboxes, 0))
+ tgt_deltas = ops.stop_gradient(ops.stack(tgt_deltas, 0))
+    # tgt_labels: 1 = positive, 0 = negative, -1 = ignored
+ return tgt_labels, tgt_bboxes, tgt_deltas
+
+
+# TODO mask
+def label_box(anchors, gt_boxes, positive_overlap, negative_overlap, allow_low_quality):
+ iou = ops.iou(anchors, gt_boxes[:, 1:])
+ # when invalid gt, iou is -1
+ iou = ops.select(ops.tile(gt_boxes[:, 0:1] >= 0, (1, anchors.shape[0])), iou, -ops.ones_like(iou))
+
+ # select best matched gt per anchor
+ matches, matched_vals = ops.ArgMaxWithValue(axis=0, keep_dims=False)(iou)
+
+ # set ignored anchor with match_labels = -1
+ match_labels = ops.ones(matches.shape, ms.int32) * -1
+
+ # ignored is -1, positive is 1, negative is 0
+ neg_cond = ops.logical_and(matched_vals >= 0, matched_vals < negative_overlap)
+ match_labels = ops.select(neg_cond, ops.zeros_like(match_labels), match_labels)
+ match_labels = ops.select(matched_vals >= positive_overlap, ops.ones_like(match_labels), match_labels)
+
+ if allow_low_quality:
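+        # low-quality matching: anchors that achieve the highest IoU with some gt box are
+        # also marked positive, even if that IoU is below positive_overlap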
+ highest_quality_foreach_gt = ops.ReduceMax(True)(iou, 1)
+ pred_inds_with_highest_quality = (
+ ops.logical_and(iou > 0, iou == highest_quality_foreach_gt).astype(ms.float32).sum(axis=0, keepdims=False)
+ )
+ match_labels = ops.select(pred_inds_with_highest_quality > 0, ops.ones_like(match_labels), match_labels)
+
+ match_labels = match_labels.reshape((-1,))
+ return matches, match_labels
+
+
+def generate_proposal_target(rois, rois_mask, gts, rois_per_batch, fg_fraction, fg_thresh, bg_thresh, num_classes):
+ gt_classes, gt_bboxes, valid_rois, fg_masks, valid_masks = [], [], [], [], []
+ batch_size = len(rois)
+ for i in range(batch_size):
+ roi = rois[i]
+ gt = gts[i]
+ roi_mask = rois_mask[i]
+
+ # Step1: label bbox
+ # matches is the matched box index of roi
+ # match_labels is the matched label of roi, -1 is ignore label, 0 is background label.
+ roi = ops.concat((roi, gt[:, 1:]), 0)
+ roi_mask = ops.concat((roi_mask, gt[:, 0] >= 0), 0)
+ matches, match_labels = label_box(roi, gt, fg_thresh, bg_thresh, False)
+ match_labels = ops.select(roi_mask.astype(ms.bool_), match_labels, ops.ones_like(match_labels) * -1)
+
+ # Step2: sample bbox
+ # structure gt_classes
+ gt_class = gt[:, 0][matches].astype(ms.int32)
+ gt_class = ops.select(match_labels == 0, ops.ones_like(gt_class) * num_classes, gt_class)
+ gt_class = ops.select(match_labels == -1, ops.ones_like(gt_class) * -1, gt_class)
+
+ # structure gt_box
+ fg_mask = ops.logical_and(gt_class > -1, gt_class != num_classes) # nonzero
+ fg_num = int(rois_per_batch * fg_fraction)
+ fg_sampler = ops.RandomChoiceWithMask(count=fg_num)
+ fg_idx, fg_s_mask = fg_sampler(fg_mask)
+
+ bg_mask = gt_class == num_classes
+ bg_sampler = ops.RandomChoiceWithMask(count=int(rois_per_batch))
+ bg_idx, bg_s_mask = bg_sampler(bg_mask)
+ bg_num = int(rois_per_batch - fg_num)
+ bg_num_mask = ms.numpy.arange(int(rois_per_batch)) < bg_num
+ bg_s_mask = ops.logical_and(bg_s_mask, bg_num_mask)
+
+        valid_idx = ops.concat((fg_idx, bg_idx), 0).reshape(-1)
+        valid_mask = ops.concat((fg_s_mask, bg_s_mask), 0).reshape(-1)
+        fg_s_mask = ops.concat((fg_s_mask, ops.zeros_like(bg_s_mask)), 0).reshape(-1)
+
+        # Step3: get result
+        # set ignore cls to 0
+        gt_class = gt_class[valid_idx]
+        gt_class = ops.select(valid_mask, gt_class, ops.zeros_like(gt_class))
+        gt_classes.append(gt_class)
+        gt_bboxes.append(gt[:, 1:][matches][valid_idx])
+        fg_masks.append(fg_s_mask)
+        valid_masks.append(valid_mask)
+        valid_rois.append(roi[valid_idx])
+
+ gt_classes = ops.stop_gradient(ops.stack(gt_classes, 0))
+ gt_bboxes = ops.stop_gradient(ops.stack(gt_bboxes, 0))
+ fg_masks = ops.stop_gradient(ops.stack(fg_masks, 0))
+ valid_masks = ops.stop_gradient(ops.stack(valid_masks, 0))
+ valid_rois = ops.stop_gradient(ops.stack(valid_rois, 0))
+ return gt_classes, gt_bboxes, fg_masks, valid_masks, valid_rois
+
+
+def generate_proposal_target_with_mask(
+ rois, rois_mask, gts, masks, rois_per_batch, fg_fraction, fg_thresh, bg_thresh, num_classes
+):
+ gt_classes, gt_bboxes, valid_rois, pos_rois, fg_masks, valid_masks, gt_masks = [], [], [], [], [], [], []
+ batch_size = len(rois)
+ for i in range(batch_size):
+ roi = rois[i]
+ gt = gts[i]
+ roi_mask = rois_mask[i]
+ mask = masks[i]
+ # Step1: label bbox
+ # matches is the matched box index of roi
+ # match_labels is the matched label of roi, -1 is ignore label, 0 is background label.
+ roi = ops.concat((roi, gt[:, 1:]), 0)
+ roi_mask = ops.concat((roi_mask, gt[:, 0] >= 0), 0)
+ matches, match_labels = label_box(roi, gt, fg_thresh, bg_thresh, False)
+ match_labels = ops.select(roi_mask.astype(ms.bool_), match_labels, ops.ones_like(match_labels) * -1)
+
+ # Step2: sample bbox
+ # structure gt_classes
+ gt_class = gt[:, 0][matches].astype(ms.int32)
+ gt_class = ops.select(match_labels == 0, ops.ones_like(gt_class) * num_classes, gt_class)
+ gt_class = ops.select(match_labels == -1, ops.ones_like(gt_class) * -1, gt_class)
+
+        # structure gt_box
+ fg_mask = ops.logical_and(gt_class > -1, gt_class != num_classes) # nonzero
+ fg_num = int(rois_per_batch * fg_fraction)
+ fg_sampler = ops.RandomChoiceWithMask(count=fg_num)
+ fg_idx, fg_s_mask = fg_sampler(fg_mask)
+
+ bg_mask = gt_class == num_classes
+ bg_sampler = ops.RandomChoiceWithMask(count=int(rois_per_batch))
+ bg_idx, bg_s_mask = bg_sampler(bg_mask)
+ bg_num = int(rois_per_batch - fg_num)
+ bg_num_mask = ops.arange(int(rois_per_batch)) < bg_num
+ bg_s_mask = ops.logical_and(bg_s_mask, bg_num_mask)
+
+ fg_idx = fg_idx.reshape(-1)
+ bg_idx = bg_idx.reshape(-1)
+        valid_idx = ops.concat((fg_idx, bg_idx), 0)
+        valid_mask = ops.concat((fg_s_mask, bg_s_mask), 0).reshape(-1)
+        fg_s_mask = ops.concat((fg_s_mask, ops.zeros_like(bg_s_mask)), 0).reshape(-1)
+
+        # Step3: get result
+        # set ignore cls to 0
+        gt_class = gt_class[valid_idx]
+        gt_class = ops.select(valid_mask, gt_class, ops.zeros_like(gt_class))
+        gt_classes.append(gt_class)
+        gt_bboxes.append(gt[:, 1:][matches][valid_idx])
+        gt_masks.append(mask[matches][fg_idx])
+        fg_masks.append(fg_s_mask)
+        valid_masks.append(valid_mask)
+        valid_rois.append(roi[valid_idx])
+ pos_rois.append(roi[fg_idx])
+
+ gt_classes = ops.stop_gradient(ops.stack(gt_classes, 0))
+ gt_bboxes = ops.stop_gradient(ops.stack(gt_bboxes, 0))
+ fg_masks = ops.stop_gradient(ops.stack(fg_masks, 0))
+ gt_masks = ops.stop_gradient(ops.stack(gt_masks, 0))
+ valid_masks = ops.stop_gradient(ops.stack(valid_masks, 0))
+ valid_rois = ops.stop_gradient(ops.stack(valid_rois, 0))
+ pos_rois = ops.stop_gradient(ops.stack(pos_rois, 0))
+ return gt_classes, gt_bboxes, gt_masks, fg_masks, valid_masks, valid_rois, pos_rois
+
+
+if __name__ == "__main__":
+ bbox_a = BBoxAssigner()
+ rois = ms.Tensor(
+ np.concatenate(
+ (
+ np.random.uniform(0, 640, (2, 1000, 4)).astype(np.float32),
+ np.random.uniform(0, 5, (2, 1000, 1)).astype(np.int32).astype(np.float32),
+ ),
+ -1,
+ )
+ )
+ rois_mask = ms.Tensor(np.random.random((2, 1000)) > 0.5)
+ xy1 = np.random.uniform(0, 440, (2, 120, 2)).astype(np.float32)
+ wh = np.random.uniform(0, 200, (2, 120, 2)).astype(np.float32)
+ c = np.random.uniform(0, 80, (2, 120, 1)).astype(np.int32).astype(np.float32)
+ xy2 = xy1 + wh
+ gts = np.concatenate((c, xy1, xy2), -1)
+ gts[:, 50:] = np.ones((2, 70, 5)).astype(np.float32) * -1
+ gts = ms.Tensor(gts)
+ out = bbox_a(rois, rois_mask, gts)
+ for o in out:
+ print(o.shape)
diff --git a/mindocr/models/necks/builder.py b/mindocr/models/necks/builder.py
index f7f7e9678..1fe7e3d72 100644
--- a/mindocr/models/necks/builder.py
+++ b/mindocr/models/necks/builder.py
@@ -11,13 +11,15 @@
'MasterEncoder',
'RSEFPN',
'YOLOv8Neck',
- 'Identity'
+ 'Identity',
+ 'RPN'
]
from .fpn import DBFPN, EASTFPN, FCEFPN, FPN, PSEFPN, RSEFPN
from .identity import Identity
from .img2seq import Img2Seq
from .master_encoder import MasterEncoder
from .rnn import RNNEncoder
+from .rpn.rpn import RPN
from .select import Select
from .yolov8_neck import YOLOv8Neck
diff --git a/mindocr/models/necks/rpn/__init__.py b/mindocr/models/necks/rpn/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/mindocr/models/necks/rpn/anchor_generator.py b/mindocr/models/necks/rpn/anchor_generator.py
new file mode 100644
index 000000000..495ef8540
--- /dev/null
+++ b/mindocr/models/necks/rpn/anchor_generator.py
@@ -0,0 +1,103 @@
+# The code is based on
+# https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/anchor_generator.py
+
+import math
+
+import mindspore as ms
+from mindspore import nn, ops
+
+
+class AnchorGenerator(nn.Cell):
+ """
+ Generate anchors for RCNN
+
+ Args:
+ anchor_sizes (list[float]): The anchor sizes at each feature point.
+ aspect_ratios (list[float]): The aspect ratios at each feature point.
+ strides (list[float]): The strides of feature maps of anchors.
+ offset (float): The offset of anchors.
+ """
+
+ def __init__(
+ self,
+ anchor_sizes=[[64], [128], [256], [512]],
+ aspect_ratios=[0.5, 1.0, 2.0],
+ strides=[8, 16, 32, 64],
+ variance=[1.0, 1.0, 1.0, 1.0],
+ offset=0.0,
+ ):
+ super(AnchorGenerator, self).__init__()
+ self.anchor_sizes = anchor_sizes
+ self.aspect_ratios = aspect_ratios
+ self.strides = strides
+ self.variance = variance
+ self.cell_anchors = self.calculate_anchors(len(strides))
+ self.offset = offset
+
+ def broadcast_params(self, params, num_features):
+ if not isinstance(params[0], (list, tuple)):
+ return [params] * num_features
+ if len(params) == 1:
+ return list(params) * num_features
+ return params
+
+ def generate_cell_anchors(self, sizes, aspect_ratios):
+ anchors = []
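+        # each cell anchor is a zero-centered (x0, y0, x1, y1) box with
+        # w * h == size**2 and h / w == aspect_ratio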
+ for size in sizes:
+ area = size**2.0
+ for aspect_ratio in aspect_ratios:
+ w = math.sqrt(area / aspect_ratio)
+ h = aspect_ratio * w
+ x0, y0, x1, y1 = -w / 2.0, -h / 2.0, w / 2.0, h / 2.0
+ anchors.append([x0, y0, x1, y1])
+ return ms.Tensor(anchors, ms.float32)
+
+ def calculate_anchors(self, num_features):
+ sizes = self.broadcast_params(self.anchor_sizes, num_features)
+ aspect_ratios = self.broadcast_params(self.aspect_ratios, num_features)
+ cell_anchors = [self.generate_cell_anchors(s, a) for s, a in zip(sizes, aspect_ratios)]
+ return cell_anchors
+
+ def create_grid_offsets(self, size, stride, offset):
+ grid_height, grid_width = size[0], size[1]
+ shifts_x = ms.ops.arange(offset * stride, grid_width * stride, step=stride, dtype=ms.float32)
+ shifts_y = ms.ops.arange(offset * stride, grid_height * stride, step=stride, dtype=ms.float32)
+ # shift_x, shift_y = ops.meshgrid((shifts_x, shifts_y))
+ shift_x, shift_y = ops.meshgrid(shifts_x, shifts_y)
+ shift_x = ops.reshape(shift_x, (-1,))
+ shift_y = ops.reshape(shift_y, (-1,))
+ return shift_x, shift_y
+
+ def grid_anchors(self, grid_sizes):
+ anchors = []
+ for size, stride, base_anchors in zip(grid_sizes, self.strides, self.cell_anchors):
+ shift_x, shift_y = self.create_grid_offsets(size, stride, self.offset)
+ shifts = ops.stack((shift_x, shift_y, shift_x, shift_y), axis=1)
+ shifts = ops.reshape(shifts, (-1, 1, 4))
+ base_anchors = ops.reshape(base_anchors, (1, -1, 4))
+ anchor = ops.reshape(shifts + base_anchors, (-1, 4))
+ anchors.append(anchor)
+
+ return anchors
+
+ def construct(self, grid_sizes):
+ anchors_over_all_feature_maps = self.grid_anchors(grid_sizes)
+ return anchors_over_all_feature_maps
+
+ @property
+ def num_anchors(self):
+ """
+ Returns:
+            int: number of anchors at every pixel location on each feature map.
+                For example, if at every pixel we use anchors of 3 aspect
+                ratios and 5 sizes, the number of anchors is 15.
+                For FPN models, `num_anchors` is the same on every feature map.
+ """
+ return len(self.cell_anchors[0])
+
+
+if __name__ == "__main__":
+ anchors = AnchorGenerator()(((192, 320), (96, 160), (48, 80), (24, 40)))
+ for a in anchors:
+ print(a.shape, a[100:120])
diff --git a/mindocr/models/necks/rpn/proposal_generator.py b/mindocr/models/necks/rpn/proposal_generator.py
new file mode 100644
index 000000000..0acd0f88b
--- /dev/null
+++ b/mindocr/models/necks/rpn/proposal_generator.py
@@ -0,0 +1,139 @@
+from mindspore import nn, ops
+
+from ...utils.box_utils import delta2bbox
+
+
+def nonempty(box, threshold):
+ widths = box[:, 2] - box[:, 0]
+ heights = box[:, 3] - box[:, 1]
+ valid = ops.logical_and((widths > threshold), (heights > threshold))
+ return valid
+
+
+def batch_nms(boxes, score, idxs, threshold):
+ max_coordinate = boxes.max()
+ offsets = idxs * (max_coordinate + 1)
+ boxes_for_nms = (boxes + ops.expand_dims(offsets, 1)).astype(score.dtype)
+ boxes_for_nms = ops.concat((boxes_for_nms, ops.expand_dims(score, -1)), axis=-1)
+ output_boxes, output_idx, selected_mask = ops.NMSWithMask(threshold)(boxes_for_nms)
+ return output_boxes, output_idx, selected_mask
+
+
+def clip_boxes(boxes, im_shape):
+ h, w = im_shape
+ x1 = ops.clip_by_value(boxes[..., 0], 0, w)
+ y1 = ops.clip_by_value(boxes[..., 1], 0, h)
+ x2 = ops.clip_by_value(boxes[..., 2], 0, w)
+ y2 = ops.clip_by_value(boxes[..., 3], 0, h)
+ boxes = ops.stack((x1, y1, x2, y2, boxes[..., 4]), -1)
+ return boxes
+
+
+class ProposalGenerator(nn.Cell):
+ """
+ Proposal generation module
+ Args:
+        pre_nms_top_n (int): Number of total bboxes to be kept per
+            image before NMS. default 12000
+        post_nms_top_n (int): Number of total bboxes to be kept per
+            image after NMS. default 2000
+        nms_thresh (float): Threshold in NMS. default 0.5
+        min_size (float): Remove predicted boxes with either height or
+            width < min_size. default 1.0
+ """
+
+ def __init__(self, pre_nms_top_n=12000, post_nms_top_n=2000, nms_thresh=0.5, min_size=1.0):
+ super(ProposalGenerator, self).__init__()
+ self.pre_nms_top_n = pre_nms_top_n
+ self.post_nms_top_n = post_nms_top_n
+ self.min_size = min_size
+ self.topk = ops.TopK()
+ self.nms = ops.NMSWithMask(nms_thresh)
+
+ def construct(self, scores, bbox_deltas, anchors, im_shape):
+ N = scores[0].shape[0]
+ B = anchors[0].shape[-1]
+ pred_objectness_logits, pred_anchor_deltas = (), ()
+
+ for score in scores:
+ pred_objectness_logits = pred_objectness_logits + (score.transpose((0, 2, 3, 1)).reshape(N, -1),)
+ for delta in bbox_deltas:
+ pred_anchor_deltas = pred_anchor_deltas + (delta.transpose((0, 2, 3, 1)).reshape((N, -1, B)),)
+ rpn_rois, rpn_rois_mask = self.predict_proposals(anchors, pred_objectness_logits, pred_anchor_deltas, im_shape)
+ rpn_rois = ops.stop_gradient(rpn_rois)
+ rpn_rois_mask = ops.stop_gradient(rpn_rois_mask)
+ return rpn_rois, rpn_rois_mask
+
+ def predict_proposals(self, anchors, pred_objectness_logits, pred_anchor_deltas, image_sizes):
+ pred_proposals = self.decode_proposals(anchors, pred_anchor_deltas)
+ return self.find_top_rpn_proposals(pred_proposals, pred_objectness_logits, image_sizes)
+
+ def decode_proposals(self, anchors, pred_anchor_deltas):
+ """decode pred_anchor_deltas to true box xyxy"""
+ proposals = ()
+ for anchors_i, pred_anchor_deltas_i in zip(anchors, pred_anchor_deltas):
+ N = pred_anchor_deltas_i.shape[0]
+ B = anchors_i.shape[-1]
+ pred_anchor_deltas_i = pred_anchor_deltas_i.reshape(-1, B)
+ anchors_i = ops.tile(ops.expand_dims(anchors_i, 0), (N, 1, 1)).reshape(-1, B)
+ proposals_i = delta2bbox(
+ pred_anchor_deltas_i, anchors_i, weights=(1.0, 1.0, 1.0, 1.0)
+ )
+ proposals = proposals + (proposals_i.reshape(N, -1, B),)
+ return proposals
+
+ def find_top_rpn_proposals(self, proposals, pred_objectness_logits, image_sizes):
+ """get top post_nms_top_n proposals"""
+        # NOTE: this NMS implementation differs slightly in accuracy from PyTorch's; replace it once an aligned interface is available
+ # 1. Select top-k anchor after nms for every level and every image
+ boxes = []
+ for level_id, (proposals_i, logits_i) in enumerate(zip(proposals, pred_objectness_logits)):
+ batch_size, Hi_Wi_A, B = proposals_i.shape
+ # temp_proposals = ops.concat(
+ # (ops.zeros((Hi_Wi_A, 2), proposals_i.dtype), ops.ones((Hi_Wi_A, 2), proposals_i.dtype)), axis=-1
+ # )
+ batch_boxes = []
+ # logits_i = ops.sigmoid(logits_i)
+ for b in range(batch_size):
+ proposals_ib = proposals_i[b]
+ logits_ib = logits_i[b]
+ # vaild = nonempty(proposals_ib, self.min_size)
+ # logits_ib = ops.select(vaild, logits_i[b], ops.zeros_like(logits_i[b]))
+ # proposals_ib = ops.select(ops.tile(vaild.reshape(-1, 1), (1, 4)), proposals_ib, temp_proposals)
+ num_proposals_i = min(Hi_Wi_A, self.pre_nms_top_n)
+
+ # select top num_proposals_i proposals
+ _, idx = self.topk(logits_ib, num_proposals_i)
+ boxes_for_nms = ops.concat((proposals_ib, ops.expand_dims(logits_ib, -1)), axis=-1)
+ boxes_for_nms = boxes_for_nms[idx]
+ boxes_for_nms = clip_boxes(boxes_for_nms, image_sizes[b])
+ nms_box, _, nms_mask = self.nms(boxes_for_nms)
+ nms_box_logits = ops.select(
+ nms_mask, boxes_for_nms[:, 4], ops.full_like(boxes_for_nms[:, 4], -1000).astype(nms_box.dtype)
+ )
+ boxes_for_nms = ops.concat((boxes_for_nms[:, :4], ops.expand_dims(nms_box_logits, -1)), axis=-1)
+ batch_boxes.append(boxes_for_nms)
+ boxes.append(ops.stack(batch_boxes, 0))
+
+ # 2. Concat all levels together
+ boxes = ops.concat(boxes, axis=1)
+
+ # 3. For each image choose topk results.
+ proposal_boxes = []
+ proposal_masks = []
+ for b in range(boxes.shape[0]):
+ box = boxes[b]
+ nms_box_logits = box[:, 4]
+ _, idx = self.topk(nms_box_logits, self.post_nms_top_n)
+ box_keep = box[idx]
+ nms_box, _, nms_mask = self.nms(box_keep)
+ proposal_boxes.append(box_keep[:, :4])
+ mask = ops.logical_and(box_keep[:, 4] > 0, nms_mask)
+ proposal_masks.append(mask)
+
+ proposal_boxes = ops.stack(proposal_boxes, 0)
+ proposal_masks = ops.stack(proposal_masks, 0)
+ proposal_boxes = ops.stop_gradient(proposal_boxes)
+ proposal_masks = ops.stop_gradient(proposal_masks)
+ # proposal_boxes shape is [self.batch_size, post_nms_top_n, 4] 4 is (x0, y0, x1, y1)
+ return proposal_boxes, proposal_masks
diff --git a/mindocr/models/necks/rpn/rpn.py b/mindocr/models/necks/rpn/rpn.py
new file mode 100644
index 000000000..9ffa6caa6
--- /dev/null
+++ b/mindocr/models/necks/rpn/rpn.py
@@ -0,0 +1,140 @@
+import math
+
+from addict import Dict
+
+from mindspore import nn, ops
+from mindspore.common.initializer import HeUniform
+
+from ...label_assignment import RPNLabelAssignment
+from .anchor_generator import AnchorGenerator
+from .proposal_generator import ProposalGenerator
+
+
+class RPNFeat(nn.Cell):
+ """
+ Feature extraction in RPN head
+
+ Args:
+        num_layers (int): Number of input feature maps (currently unused)
+        num_anchors (int): Number of anchors per feature-map location
+        in_channel (int): Input channel
+        out_channel (int): Output channel
+ """
+
+ def __init__(self, num_layers=1, num_anchors=3, in_channel=1024, out_channel=1024):
+ super(RPNFeat, self).__init__()
+ self.rpn_conv = nn.Conv2d(in_channel,
+ out_channel,
+ kernel_size=3,
+ padding=1,
+ pad_mode="pad",
+ weight_init=HeUniform(math.sqrt(5)),
+ has_bias=True,
+ bias_init="zeros")
+ self.rpn_rois_score = nn.Conv2d(
+ out_channel, num_anchors, 1, weight_init=HeUniform(math.sqrt(5)),
+ has_bias=True, bias_init="zeros")
+ self.rpn_rois_delta = nn.Conv2d(
+ out_channel, 4 * num_anchors, 1, weight_init=HeUniform(math.sqrt(5)),
+ has_bias=True, bias_init="zeros")
+ self.relu = nn.ReLU()
+
+ def construct(self, feats):
+ scores = ()
+ deltas = ()
+ for i, feat in enumerate(feats):
+ x = self.relu(self.rpn_conv(feat))
+ scores = scores + (self.rpn_rois_score(x),)
+ deltas = deltas + (self.rpn_rois_delta(x),)
+ return scores, deltas
+
+
+class RPN(nn.Cell):
+ """
+ Region Proposal Network
+
+ Args:
+        in_channels(int): number of channels of the input feature maps
+        backbone_feat_nums(int): number of backbone feature maps
+        loss_rpn_bbox(Cell): bbox loss function Cell, defaults to SmoothL1Loss
+        **config: rpn_head config (anchor generator, proposal generators, label assignment)
+ """
+
+ def __init__(self, in_channels, backbone_feat_nums=1, loss_rpn_bbox=None, **config):
+ super(RPN, self).__init__()
+ cfg = Dict(config)
+ self.in_features = cfg.in_features
+ self.out_channels = in_channels
+ acfg = cfg.anchor_generator
+ self.anchor_generator = AnchorGenerator(
+ aspect_ratios=acfg.aspect_ratios, anchor_sizes=acfg.anchor_sizes, strides=acfg.strides
+ )
+ self.num_anchors = self.anchor_generator.num_anchors
+ self.rpn_feat = RPNFeat(backbone_feat_nums, self.num_anchors, in_channels, cfg.feat_channel)
+ tr_pcfg = cfg.train_proposal
+ self.train_gen_proposal = ProposalGenerator(
+ min_size=tr_pcfg.min_size,
+ nms_thresh=tr_pcfg.nms_thresh,
+ pre_nms_top_n=tr_pcfg.pre_nms_top_n,
+ post_nms_top_n=tr_pcfg.post_nms_top_n,
+ )
+ te_pcfg = cfg.test_proposal
+ self.test_gen_proposal = ProposalGenerator(
+ min_size=te_pcfg.min_size,
+ nms_thresh=te_pcfg.nms_thresh,
+ pre_nms_top_n=te_pcfg.pre_nms_top_n,
+ post_nms_top_n=te_pcfg.post_nms_top_n,
+ )
+ rcfg = cfg.rpn_label_assignment
+ self.rpn_target_assign = RPNLabelAssignment(
+ rnp_sample_batch=rcfg.rnp_sample_batch,
+ fg_fraction=rcfg.fg_fraction,
+ positive_overlap=rcfg.positive_overlap,
+ negative_overlap=rcfg.negative_overlap,
+ use_random=rcfg.use_random,
+ )
+ self.loss_rpn_bbox = loss_rpn_bbox
+ if self.loss_rpn_bbox is None:
+ self.loss_rpn_bbox = nn.SmoothL1Loss(reduction="none")
+
+ def construct(self, feats, gts, image_shape):
+ scores, deltas = self.rpn_feat(feats)
+ shapes = ()
+ for feat in feats:
+ shapes += (feat.shape[-2:],)
+ anchors = self.anchor_generator(shapes)
+ rois, rois_mask = self.train_gen_proposal(scores, deltas, anchors, image_shape)
+ tgt_labels, tgt_bboxes, tgt_deltas = self.rpn_target_assign(gts, anchors)
+
+ # cls loss
+ score_pred = ()
+ batch_size = scores[0].shape[0]
+ for score in scores:
+ score_pred = score_pred + (ops.transpose(score, (0, 2, 3, 1)).reshape((batch_size, -1)),)
+ score_pred = ops.concat(score_pred, 1)
+ valid_mask = tgt_labels >= 0
+ fg_mask = tgt_labels > 0
+
+ loss_rpn_cls = ops.SigmoidCrossEntropyWithLogits()(score_pred, fg_mask.astype(score_pred.dtype))
+ loss_rpn_cls = ops.select(valid_mask, loss_rpn_cls, ops.zeros_like(loss_rpn_cls))
+
+ # reg loss
+ delta_pred = ()
+ for delta in deltas:
+ delta_pred = delta_pred + (ops.transpose(delta, (0, 2, 3, 1)).reshape((batch_size, -1, 4)),)
+ delta_pred = ops.concat(delta_pred, 1)
+ loss_rpn_reg = self.loss_rpn_bbox(delta_pred, tgt_deltas)
+ fg_mask = ops.tile(ops.expand_dims(fg_mask, -1), (1, 1, 4))
+ loss_rpn_reg = ops.select(fg_mask, loss_rpn_reg, ops.zeros_like(loss_rpn_reg))
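+        # normalize both losses by the number of sampled (non-ignored) anchors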
+ loss_rpn_cls = loss_rpn_cls.sum() / (valid_mask.astype(loss_rpn_cls.dtype).sum() + 1e-4)
+ loss_rpn_reg = loss_rpn_reg.sum() / (valid_mask.astype(loss_rpn_reg.dtype).sum() + 1e-4)
+ return rois, rois_mask, loss_rpn_cls, loss_rpn_reg
+
+ def predict(self, features, image_shape):
+ features = [features[f] for f in self.in_features]
+ scores, deltas = self.rpn_feat(features)
+ shapes = ()
+ for feat in features:
+ shapes += (feat.shape[-2:],)
+ anchors = self.anchor_generator(shapes)
+ rois, rois_mask = self.test_gen_proposal(scores, deltas, anchors, image_shape)
+ return rois, rois_mask
diff --git a/mindocr/models/utils/box_utils.py b/mindocr/models/utils/box_utils.py
new file mode 100644
index 000000000..9c658bca3
--- /dev/null
+++ b/mindocr/models/utils/box_utils.py
@@ -0,0 +1,105 @@
+import mindspore as ms
+from mindspore import ops
+
+
+@ops.constexpr
+def tensor(x, dtype=ms.float32):
+ return ms.Tensor(x, dtype)
+
+
+@ops.constexpr
+def shape_prod(shape):
+ size = 1
+ for i in shape:
+ size *= i
+ return size
+
+
+def delta2bbox(deltas, boxes, weights=(10.0, 10.0, 5.0, 5.0), max_shape=None):
+ """Decode deltas to boxes.
+ Note: return tensor shape [n,1,4]
+ """
+ clip_scale = 4
+
+ widths = boxes[:, 2] - boxes[:, 0]
+ heights = boxes[:, 3] - boxes[:, 1]
+ ctr_x = boxes[:, 0] + 0.5 * widths
+ ctr_y = boxes[:, 1] + 0.5 * heights
+
+ wx, wy, ww, wh = weights
+ dx = deltas[:, 0:1] / wx
+ dy = deltas[:, 1:2] / wy
+ dw = deltas[:, 2:3] / ww
+ dh = deltas[:, 3:4] / wh
+ # Prevent sending too large values into ops.exp()
+ dw = ops.minimum(dw, clip_scale)
+ dh = ops.minimum(dh, clip_scale)
+
+ pred_ctr_x = dx * ops.expand_dims(widths, 1) + ops.expand_dims(ctr_x, 1)
+ pred_ctr_y = dy * ops.expand_dims(heights, 1) + ops.expand_dims(ctr_y, 1)
+ pred_w = ops.exp(dw) * ops.expand_dims(widths, 1)
+ pred_h = ops.exp(dh) * ops.expand_dims(heights, 1)
+
+ pred_boxes = []
+ pred_boxes.append(pred_ctr_x - 0.5 * pred_w)
+ pred_boxes.append(pred_ctr_y - 0.5 * pred_h)
+ pred_boxes.append(pred_ctr_x + 0.5 * pred_w)
+ pred_boxes.append(pred_ctr_y + 0.5 * pred_h)
+ pred_boxes = ops.stack(pred_boxes, axis=-1)
+
+ if max_shape is not None:
+ h, w = max_shape
+ x1 = ops.clip_by_value(pred_boxes[..., 0], 0, w)
+ y1 = ops.clip_by_value(pred_boxes[..., 1], 0, h)
+ x2 = ops.clip_by_value(pred_boxes[..., 2], 0, w)
+ y2 = ops.clip_by_value(pred_boxes[..., 3], 0, h)
+ pred_boxes = ops.stack((x1, y1, x2, y2), -1)
+
+ return pred_boxes
+
+
+def bbox2delta(src_boxes, tgt_boxes, weights=(10.0, 10.0, 5.0, 5.0)):
+ """Encode bboxes to deltas."""
+ src_w = src_boxes[:, 2] - src_boxes[:, 0]
+ src_h = src_boxes[:, 3] - src_boxes[:, 1]
+ src_ctr_x = src_boxes[:, 0] + 0.5 * src_w
+ src_ctr_y = src_boxes[:, 1] + 0.5 * src_h
+
+ tgt_w = tgt_boxes[:, 2] - tgt_boxes[:, 0]
+ tgt_h = tgt_boxes[:, 3] - tgt_boxes[:, 1]
+ tgt_valid = ops.logical_and(tgt_w > 0, tgt_h > 0)
+ tgt_ctr_x = tgt_boxes[:, 0] + 0.5 * tgt_w
+ tgt_ctr_y = tgt_boxes[:, 1] + 0.5 * tgt_h
+ tgt_w = ops.select(tgt_valid, tgt_w, src_w)
+ tgt_h = ops.select(tgt_valid, tgt_h, src_h)
+ tgt_ctr_x = ops.select(tgt_valid, tgt_ctr_x, src_ctr_x)
+ tgt_ctr_y = ops.select(tgt_valid, tgt_ctr_y, src_ctr_y)
+
+ wx, wy, ww, wh = weights
+ dx = wx * (tgt_ctr_x - src_ctr_x) / src_w
+ dy = wy * (tgt_ctr_y - src_ctr_y) / src_h
+ dw = ww * ops.log(tgt_w / src_w)
+ dh = wh * ops.log(tgt_h / src_h)
+
+ deltas = ops.stack((dx, dy, dw, dh), axis=1)
+ return deltas
+
+
+def xywh2xyxy(box):
+ """box shape is (N, 4), format is xywh"""
+ x, y, w, h = box[:, 0], box[:, 1], box[:, 2], box[:, 3]
+ x1 = x - w // 2
+ y1 = y - h // 2
+ x2 = x1 + w
+ y2 = y1 + h
+ return ops.stack((x1, y1, x2, y2), -1)
+
+
+def xyxy2xywh(box):
+ """box shape is (N, 4), format is xyxy"""
+ x1, y1, x2, y2 = box[:, 0], box[:, 1], box[:, 2], box[:, 3]
+ w = x2 - x1
+ h = y2 - y1
+ x = x1 + w // 2
+ y = y1 + h // 2
+ return ops.stack((x, y, w, h), -1)
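+
+
+ if __name__ == "__main__":
+ # Illustrative self-check only (not part of the library API): encoding boxes with
+ # bbox2delta and then decoding with delta2bbox should recover the valid target boxes.
+ anchors = ms.Tensor([[0.0, 0.0, 10.0, 10.0], [5.0, 5.0, 25.0, 15.0]], ms.float32)
+ targets = ms.Tensor([[1.0, 1.0, 11.0, 9.0], [6.0, 4.0, 26.0, 18.0]], ms.float32)
+ deltas = bbox2delta(anchors, targets)
+ decoded = delta2bbox(deltas, anchors)  # shape (n, 1, 4)
+ print(decoded.reshape(-1, 4))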
diff --git a/mindocr/postprocess/builder.py b/mindocr/postprocess/builder.py
index 24be022f6..27d076cef 100644
--- a/mindocr/postprocess/builder.py
+++ b/mindocr/postprocess/builder.py
@@ -18,7 +18,7 @@
from .det_pse_postprocess import *
from .kie_re_postprocess import VQAReTokenLayoutLMPostProcess
from .kie_ser_postprocess import VQASerTokenLayoutLMPostProcess
-from .layout_postprocess import YOLOv8Postprocess
+from .layout_postprocess import *
from .rec_abinet_postprocess import *
from .rec_postprocess import *
from .table_postprocess import *
diff --git a/mindocr/postprocess/det_db_postprocess.py b/mindocr/postprocess/det_db_postprocess.py
index ccac15110..59f19b433 100644
--- a/mindocr/postprocess/det_db_postprocess.py
+++ b/mindocr/postprocess/det_db_postprocess.py
@@ -110,7 +110,11 @@ def _extract_preds(self, pred: np.ndarray, bitmap: np.ndarray):
continue
poly = Polygon(points)
- poly = np.array(expand_poly(points, distance=poly.area * self._expand_ratio / poly.length))
+ poly_list = expand_poly(points, distance=poly.area * self._expand_ratio / poly.length)
+ if self._is_uneven_nested_list(poly_list):
+ poly = np.array(poly_list, dtype=object)
+ else:
+ poly = np.array(poly_list)
if self._out_poly and len(poly) > 1:
continue
poly = poly.reshape(-1, 2)
@@ -134,6 +138,18 @@ def _extract_preds(self, pred: np.ndarray, bitmap: np.ndarray):
return polys, scores
return np.array(polys), np.array(scores).astype(np.float32)
+ def _is_uneven_nested_list(self, arr_list):
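+ """Return True if `arr_list` is a list whose sublists have differing lengths
+ (e.g. expanded polygons with different vertex counts), in which case numpy
+ can only wrap it into an object array."""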
+ if not isinstance(arr_list, list):
+ return False
+
+ first_length = len(arr_list[0]) if isinstance(arr_list[0], list) else None
+
+ for sublist in arr_list:
+ if not isinstance(sublist, list) or len(sublist) != first_length:
+ return True
+
+ return False
+
@staticmethod
def _fit_box(contour):
"""
diff --git a/mindocr/postprocess/layout_postprocess.py b/mindocr/postprocess/layout_postprocess.py
index 6c219a285..2a1d85a50 100644
--- a/mindocr/postprocess/layout_postprocess.py
+++ b/mindocr/postprocess/layout_postprocess.py
@@ -4,7 +4,7 @@
from mindspore import Tensor
-__all__ = ["YOLOv8Postprocess"]
+__all__ = ["YOLOv8Postprocess", "Layoutlmv3Postprocess"]
class YOLOv8Postprocess(object):
@@ -63,6 +63,47 @@ def __call__(self, preds, img_shape, meta_info, **kwargs):
return result_dicts
+class Layoutlmv3Postprocess(YOLOv8Postprocess):
+ """return image_id, category_id, bbox and scores."""
+
+ def __call__(self, preds, img_shape, meta_info, **kwargs):
+ meta_info = [_.numpy() if isinstance(_, Tensor) else _ for _ in meta_info]
+ image_ids, ori_shape, hw_scale, pad = meta_info
+ preds = preds if isinstance(preds, np.ndarray) else preds.numpy()
+ preds = non_max_suppression_for_layoutlmv3(
+ preds,
+ conf_thres=self.conf_thres,
+ iou_thres=self.iou_thres,
+ conf_free=self.conf_free,
+ multi_label=True,
+ time_limit=self.time_limit,
+ )
+ # Statistics pred
+ result_dicts = list()
+ for si, pred in enumerate(preds):
+ if len(pred) == 0:
+ continue
+
+ # Predictions
+ predn = np.copy(pred)
+ scale_coords_for_layoutlmv3(
+ img_shape[-2:], predn[:, :4], ori_shape[si], ratio=hw_scale[si], pad=None
+ ) # native-space pred
+
+ box = xyxy2xywh(predn[:, :4]) # xywh
+ box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner
+ for p, b in zip(pred.tolist(), box.tolist()):
+ result_dicts.append(
+ {
+ "image_id": image_ids[si],
+ "category_id": int(p[5]) + 1,
+ "bbox": [round(x, 3) for x in b],
+ "score": round(p[4], 5),
+ }
+ )
+ return result_dicts
+
+
def _nms(xyxys, scores, threshold):
"""Calculate NMS"""
x1 = xyxys[:, 0]
@@ -156,12 +197,12 @@ def non_max_suppression(
prediction on (bs, N, 4+nc) ndarray each point, the last dimension meaning
[center_x, center_y, width, height, cls0, ...].
conf_free (bool): Whether the prediction result include conf.
- time_limit:
- multi_label:
- agnostic:
- classes:
- iou_thres:
- conf_thres:
+ time_limit (float): Maximum time budget (in seconds) for NMS over the whole batch
+ multi_label (bool): Whether to allow multiple labels per box
+ agnostic (bool): Whether to run class-agnostic NMS
+ classes (list[int]): If given, keep only detections of these classes
+ iou_thres (float): IoU threshold for NMS
+ conf_thres (float): Confidence threshold for filtering detections
Returns:
list of detections, on (n,6) ndarray per image, the last dimension meaning [xyxy, conf, cls].
@@ -254,6 +295,108 @@ def non_max_suppression(
return output
+def non_max_suppression_for_layoutlmv3(
+ prediction,
+ conf_thres=0.25,
+ iou_thres=0.45,
+ conf_free=False,
+ classes=None,
+ agnostic=True,
+ multi_label=False,
+ time_limit=20.0,
+):
+ """Runs Non-Maximum Suppression (NMS) on inference results
+
+ Args:
+ prediction (ndarray): Prediction. If conf_free is False, prediction on (bs, N, 5+nc) ndarray each point,
+ the last dimension meaning [center_x, center_y, width, height, conf, cls0, ...]; If conf_free is True,
+ prediction on (bs, N, 4+nc) ndarray each point, the last dimension meaning
+ [center_x, center_y, width, height, cls0, ...].
+ conf_free (bool): Whether the prediction result include conf.
+ time_limit (float): Maximum time budget (in seconds) for NMS over the whole batch
+ multi_label (bool): Whether to allow multiple labels per box
+ agnostic (bool): Whether to run class-agnostic NMS
+ classes (list[int]): If given, keep only detections of these classes
+ iou_thres (float): IoU threshold for NMS
+ conf_thres (float): Confidence threshold for filtering detections
+
+ Returns:
+ list of detections, on (n,6) ndarray per image, the last dimension meaning [xyxy, conf, cls].
+ """
+
+ if not conf_free:
+ nc = prediction.shape[2] - 5 # number of classes
+ else:
+ nc = prediction.shape[2] - 4 # number of classes
+ prediction = np.concatenate(
+ (prediction[..., :4], prediction[..., 4:].max(-1, keepdims=True), prediction[..., 4:]), axis=-1
+ )
+
+ # Settings
+ max_wh = 4096  # (pixels) maximum box width and height, used as the per-class offset
+ max_det = 300  # maximum number of detections per image
+ max_nms = 30000  # maximum number of boxes fed into NMS
+ time_limit = time_limit if time_limit > 0 else 1e3 # seconds to quit after
+ redundant = True # require redundant detections
+ multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
+ merge = False # use merge-NMS
+
+ t = time.time()
+ output = [np.zeros((0, 6))] * prediction.shape[0]
+ for xi, x in enumerate(prediction): # image index, image inference
+ # If none remain process next image
+ if not x.shape[0]:
+ continue
+
+ box = x[:, :4]
+
+ # Detections matrix nx6 (xyxy, conf, cls)
+ if multi_label:
+ i, j = (x[:, 4:-1] > conf_thres).nonzero()
+ x = np.concatenate((box[i], x[i, j + 4, None], j[:, None].astype(np.float32)), 1)
+ else:  # best class only
+ conf = x[:, 4:-1].max(1, keepdims=True)
+ j = x[:, 4:-1].argmax(1).reshape(-1, 1)
+ x = np.concatenate((box, conf, j.astype(np.float32)), 1)[conf.reshape(-1) > conf_thres]
+
+ # Filter by class
+ if classes is not None:
+ x = x[(x[:, 5:6] == np.array(classes)).any(1)]
+
+ # Check shape
+ n = x.shape[0] # number of boxes
+ if not n: # no boxes
+ continue
+ elif n > max_nms: # excess boxes
+ x = x[x[:, 4].argsort()[-max_nms:]] # sort by confidence
+
+ # Batched NMS
+ c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
+ boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
+
+ i = _nms(boxes, scores, iou_thres) # NMS for per sample
+
+ if i.shape[0] > max_det: # limit detections
+ i = i[:max_det]
+ if merge and (1 < n < 3e3): # Merge NMS (boxes merged using weighted mean)
+ # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
+ iou = _box_iou(boxes[i], boxes) > iou_thres # iou matrix # (N, M)
+ weights = iou * scores[None] # box weights
+ # (N, M) @ (M, 4) / (N, 1)
+ x[i, :4] = np.matmul(weights, x[:, :4]) / weights.sum(1, keepdims=True)  # merged boxes
+ if redundant:
+ i = i[iou.sum(1) > 1] # require redundancy
+
+ output[xi] = x[i]
+ if (time.time() - t) > time_limit:
+ print(
+ f"WARNING: Batch NMS time limit {time_limit}s exceeded, stopping after "
+ f"processing {xi + 1}/{prediction.shape[0]} samples."
+ )
+ break # time limit exceeded
+
+ return output
+
+
def scale_coords(img1_shape, coords, img0_shape, ratio=None, pad=None):
# Rescale coords (xyxy) from img1_shape to img0_shape
@@ -275,6 +418,15 @@ def scale_coords(img1_shape, coords, img0_shape, ratio=None, pad=None):
return coords
+def scale_coords_for_layoutlmv3(img1_shape, coords, img0_shape, ratio=None, pad=None):
+ # Rescale coords (xyxy) from img1_shape to img0_shape
+
+ coords[:, [0, 2]] /= ratio[1] # x rescale
+ coords[:, [1, 3]] /= ratio[0] # y rescale
+ coords = _clip_coords(coords, img0_shape)
+ return coords
+
+
def _clip_coords(boxes, img_shape):
# Clip bounding xyxy bounding boxes to image shape (height, width)
boxes[:, 0] = boxes[:, 0].clip(0, img_shape[1]) # x1
diff --git a/mindocr/utils/dict/layout_category_dict.txt b/mindocr/utils/dict/layout_category_dict.txt
new file mode 100644
index 000000000..d9c92dacb
--- /dev/null
+++ b/mindocr/utils/dict/layout_category_dict.txt
@@ -0,0 +1,5 @@
+text
+title
+list
+table
+figure
diff --git a/requirements.txt b/requirements.txt
index 7bbd2fadb..3ac404261 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -25,3 +25,4 @@ beautifulsoup4
pandas
tablepyxl
lxml
+python-docx
diff --git a/tools/arg_parser.py b/tools/arg_parser.py
index abba6092d..e58f9cd1c 100644
--- a/tools/arg_parser.py
+++ b/tools/arg_parser.py
@@ -84,14 +84,14 @@ def _merge_options(config, options):
return config
-def parse_args_and_config():
+def parse_args_and_config(args=None):
"""
Return:
args: command line argments
cfg: train/eval config dict
"""
parser = create_parser()
- args = parser.parse_args() # CLI args
+ args = parser.parse_args(args) # CLI args
modelarts_setup(args)
diff --git a/tools/infer/text/README.md b/tools/infer/text/README.md
index 02bfb537a..3266eddc1 100644
--- a/tools/infer/text/README.md
+++ b/tools/infer/text/README.md
@@ -191,49 +191,6 @@ web_cvpr.png [{"transcription": "canada", "points": [[430, 148], [540, 148], [54
**Notes:**
1. For more argument illustrations and usage, please run `python tools/infer/text/predict_system.py -h` or view `tools/infer/text/config.py`
-## Layout Analysis
-
-To run layout analysis on an input image or a directory containing multiple images, please execute
-```shell
-python tools/infer/text/predict_layout.py --image_dir {path_to_img or dir_to_imgs} --layout_algorithm YOLOv8 --visualize_output True
-```
-After running, the inference results will be saved in `{args.draw_img_save_dir}/det_results.txt`, where `--draw_img_save_dir` is the directory for saving results and is set to `./inference_results` by default Here are some results for examples.
-
-Example 1:
-
-
-
-
- Visualization of layout analysis result on PMC4958442_00003.jpg
-
-
-, where the saved layout_result.txt file is as follows
-```
-{"image_id": 0, "category_id": 1, "bbox": [308.649, 559.189, 240.211, 81.412], "score": 0.98431}
-{"image_id": 0, "category_id": 1, "bbox": [50.435, 673.018, 240.232, 70.262], "score": 0.98414}
-{"image_id": 0, "category_id": 3, "bbox": [322.805, 348.831, 225.949, 203.302], "score": 0.98019}
-{"image_id": 0, "category_id": 1, "bbox": [308.658, 638.657, 240.31, 70.583], "score": 0.97986}
-{"image_id": 0, "category_id": 1, "bbox": [50.616, 604.736, 240.044, 70.086], "score": 0.9797}
-{"image_id": 0, "category_id": 1, "bbox": [50.409, 423.237, 240.132, 183.652], "score": 0.97805}
-{"image_id": 0, "category_id": 1, "bbox": [308.66, 293.918, 240.181, 47.497], "score": 0.97471}
-{"image_id": 0, "category_id": 1, "bbox": [308.64, 707.13, 240.271, 36.028], "score": 0.97427}
-{"image_id": 0, "category_id": 1, "bbox": [308.697, 230.568, 240.062, 43.545], "score": 0.96921}
-{"image_id": 0, "category_id": 4, "bbox": [51.787, 100.444, 240.267, 273.653], "score": 0.96839}
-{"image_id": 0, "category_id": 5, "bbox": [308.637, 74.439, 237.878, 149.174], "score": 0.96707}
-{"image_id": 0, "category_id": 1, "bbox": [50.615, 70.667, 240.068, 22.0], "score": 0.94156}
-{"image_id": 0, "category_id": 2, "bbox": [50.549, 403.5, 67.392, 12.85], "score": 0.92577}
-{"image_id": 0, "category_id": 1, "bbox": [51.384, 374.84, 171.939, 10.736], "score": 0.76692}
-```
-In this file, `image_id` is the image ID, `bbox` is the detected bounding box `[x-coordinate of the top-left corner, y-coordinate of the bottom-right corner, width, height]`, `score` is the detection confidence, and `category_id` has the following meanings:
-- `1: text`
-- `2: title`
-- `3: list`
-- `4: table`
-- `5: figure`
-
-**Notes:**
-- For more argument illustrations and usage, please run `python tools/infer/text/predict_layout.py -h` or view `tools/infer/text/config.py`
-
### Supported Detection Algorithms and Networks
@@ -291,6 +248,53 @@ Evaluation of the text spotting inference results on Ascend 910 with MindSpore 2
2. Unless extra inidication, all experiments are run with `--det_limit_type`="min" and `--det_limit_side`=720.
3. SVTR is run in mixed precision mode (amp_level=O2) since it is optimized for O2.
+### Text direction classification
+
+If an image contains non-upright text, its orientation can be classified and corrected with a text direction classifier after detection. To run text direction classification and correction on an input image, please execute
+```shell
+python tools/infer/text/predict_system.py --image_dir {path_to_img or dir_to_imgs} \
+ --det_algorithm DB++ \
+ --rec_algorithm CRNN \
+ --cls_algorithm M3
+```
+The parameter `--cls_algorithm` defaults to None, which means text direction classification is not performed. Setting `--cls_algorithm` enables text direction classification in the text detection and recognition pipeline: the classifier predicts the orientation of each text image cropped by the detector and corrects the non-upright ones. Some example results are shown below.
+
+- Text direction classification
+
+
+
+
+
+ word_01.png
+
+
+
+
+
+
+ word_02.png
+
+
+Classification results:
+```text
+word_01.png 0 1.0
+word_02.png 180 1.0
+```
+
+The currently supported text direction classification network is `mobilenet_v3`, selected by setting `--cls_algorithm` to `M3`. The classifier's auto mixed precision level and weight file can be set via `--cls_amp_level` and `--cls_model_dir`. A default weight file is already configured, the default precision level is `O0`, and the default configuration classifies `0` and `180` degrees; other orientations will be supported in the future.
+
+
+
+ |**Algorithm Name**|**Network Name**|**Language**|
+ | :------: | :------: | :------: |
+ | M3 | mobilenet_v3 | CH/EN|
+
+
+
+In addition, by setting `--save_cls_result` to `True`, the text direction classification results can be saved to `{args.crop_res_save_dir}/cls_results.txt`, where `--crop_res_save_dir` is the directory for saving results.
+
+For more parameter descriptions and usage information, please refer to `tools/infer/text/config.py`.
+
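+The classifier can also be called programmatically. The snippet below is a minimal, illustrative sketch (the image paths are placeholders; run it from `tools/infer/text/` so that `config.py` and `predict_system.py` are importable):
+
+```python
+from config import create_parser
+from predict_system import TextClassifier
+
+# "M3" (mobilenet_v3) is currently the only supported cls_algorithm
+args = create_parser().parse_args(["--image_dir", "path/to/imgs", "--cls_algorithm", "M3"])
+classifier = TextClassifier(args)
+results = classifier(["path/to/word_01.png", "path/to/word_02.png"])
+print(results)  # e.g. [("0", 1.0), ("180", 1.0)]
+```
+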
## Table Structure Recognition
To run table structure recognition on an input image or multiple images in a directory, please run:
@@ -373,6 +377,105 @@ HDL Cholesterol (mg/dL),42 ± 11.1,46 ± 11.4
**Notes:**
1. For more argument illustrations and usage, please run `python tools/infer/text/predict_table_recognition.py -h` or view `tools/infer/text/config.py`
+## Layout Analysis
+
+To run layout analysis on an input image or a directory containing multiple images, please execute
+```shell
+python tools/infer/text/predict_layout.py --image_dir {path_to_img or dir_to_imgs} --layout_algorithm YOLOv8 --visualize_output True
+```
+After running, the inference results will be saved in `{args.draw_img_save_dir}/det_results.txt`, where `--draw_img_save_dir` is the directory for saving results and is set to `./inference_results` by default. Below are some example results.
+
+Example 1:
+
+
+
+
+ Visualization of layout analysis result on PMC4958442_00003.jpg
+
+
+The saved layout_result.txt file is as follows:
+```
+{"image_id": 0, "category_id": 1, "bbox": [308.649, 559.189, 240.211, 81.412], "score": 0.98431}
+{"image_id": 0, "category_id": 1, "bbox": [50.435, 673.018, 240.232, 70.262], "score": 0.98414}
+{"image_id": 0, "category_id": 3, "bbox": [322.805, 348.831, 225.949, 203.302], "score": 0.98019}
+{"image_id": 0, "category_id": 1, "bbox": [308.658, 638.657, 240.31, 70.583], "score": 0.97986}
+{"image_id": 0, "category_id": 1, "bbox": [50.616, 604.736, 240.044, 70.086], "score": 0.9797}
+{"image_id": 0, "category_id": 1, "bbox": [50.409, 423.237, 240.132, 183.652], "score": 0.97805}
+{"image_id": 0, "category_id": 1, "bbox": [308.66, 293.918, 240.181, 47.497], "score": 0.97471}
+{"image_id": 0, "category_id": 1, "bbox": [308.64, 707.13, 240.271, 36.028], "score": 0.97427}
+{"image_id": 0, "category_id": 1, "bbox": [308.697, 230.568, 240.062, 43.545], "score": 0.96921}
+{"image_id": 0, "category_id": 4, "bbox": [51.787, 100.444, 240.267, 273.653], "score": 0.96839}
+{"image_id": 0, "category_id": 5, "bbox": [308.637, 74.439, 237.878, 149.174], "score": 0.96707}
+{"image_id": 0, "category_id": 1, "bbox": [50.615, 70.667, 240.068, 22.0], "score": 0.94156}
+{"image_id": 0, "category_id": 2, "bbox": [50.549, 403.5, 67.392, 12.85], "score": 0.92577}
+{"image_id": 0, "category_id": 1, "bbox": [51.384, 374.84, 171.939, 10.736], "score": 0.76692}
+```
+In this file, `image_id` is the image ID, `bbox` is the detected bounding box `[x-coordinate of the top-left corner, y-coordinate of the top-left corner, width, height]`, `score` is the detection confidence, and `category_id` has the following meanings:
+- `1: text`
+- `2: title`
+- `3: list`
+- `4: table`
+- `5: figure`
+
+**Notes:**
+- For more argument illustrations and usage, please run `python tools/infer/text/predict_layout.py -h` or view `tools/infer/text/config.py`
+
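+For downstream processing, the saved results can be read back with a few lines of Python. The snippet below is an illustrative sketch: it assumes the one-JSON-object-per-line format shown above, the default output path, and the default category dictionary.
+
+```python
+import json
+
+# ids start at 1 and follow the order in mindocr/utils/dict/layout_category_dict.txt
+with open("mindocr/utils/dict/layout_category_dict.txt", encoding="utf-8") as f:
+    id2name = {i + 1: line.strip() for i, line in enumerate(f) if line.strip()}
+
+with open("./inference_results/det_results.txt", encoding="utf-8") as f:
+    for line in f:
+        res = json.loads(line)
+        x, y, w, h = res["bbox"]  # top-left corner plus width/height
+        print(id2name[res["category_id"]], round(res["score"], 3), (x, y, x + w, y + h))
+```
+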
+## End-to-end Document Analysis and Recovery
+
+To run end-to-end document analysis and recovery on an input image or multiple images in a directory (detecting all the text, table, and figure regions, recognizing words in these regions, and putting everything into docx files according to the original layout), please run:
+
+```shell
+python tools/infer/text/predict_table_e2e.py --image_dir {path_to_img or dir_to_imgs} \
+ --det_algorithm {DET_ALGO} \
+ --rec_algorithm {REC_ALGO}
+```
+>Note: To visualize the outputs of layout, table and ocr, please set `--visualize_output True`.
+
+After running, the inference results will be saved in `{args.draw_img_save_dir}/{img_name}_e2e_result.txt`, where `--draw_img_save_dir` is the directory for saving results and is set to `./inference_results` by default. Below are some example results.
+
+Example 1:
+
+
+
+
+
+ PMC4958442_00003.jpg converted into docx
+
+
+The saved txt file is as follows:
+```text
+{"type": "text", "bbox": [50.615, 70.667, 290.683, 92.667], "res": "tabley predictive value ofbasic clinical laboratory and suciode variables surney anc yea after tramphenins", "layout": "double"}
+{"type": "table", "bbox": [51.787, 100.444, 292.054, 374.09700000000004], "res": "sign factor | prediction valucofthe the | from difereness significance levelaf the |
gender | 0027 0021 | o442 |
| 00z44 | 0480 |
cause | tooza 0017 | o547 |
cadaverieilizing donorst | 0013 aont | 0740 |
induction transplantation before dialysis | doattoos | 0125 |
depleting antibodies monoclomalor cn immunosuppression with | doista09 | 0230 |
ititis | 0029 | aaso |
status itional | 0047 toots | |
townfrillage | non | |
transplantations number | toos 0017 | o5s1 |
creatinine | 02400g | caoor |
pressure bload systolic | aidaloloss | aoz |
pressure diastolic blood | dobetods | ass |
hemoglobin | 0044 0255t | caoor |
| 004 | caoor |
", "layout": "double"}
+{"type": "text", "bbox": [51.384, 374.84, 223.32299999999998, 385.57599999999996], "res": "nanc rmeans more significant forecasting factor sign", "layout": "double"}
+{"type": "title", "bbox": [50.549, 403.5, 117.941, 416.35], "res": "discussion", "layout": "double"}
+{"type": "text", "bbox": [50.409, 423.237, 290.541, 606.889], "res": "determination of creatinine and hemoglobin level in the blood well aetho concentration of protein in the urine in one year atter kidney transplantation with the calculation of prognostic criterion predics the loss of renal allotransplant function in years fafter surgery advantages ff the method are the possibility oof quantitative forecasting of renal allotransplant losser which based not only its excretory function assessment but also on assessment other characteristics that may have important prognostic value and does not always directly correlate with changes in its excretors function in order the riskof death with transplant sfunctioning returntothe program hemodialysis the predictive model was implemented cabular processor excel forthe useofthe model litisquite enough the value ethel given indices calculation and prognosis will be automatically done in the electronic table figure 31", "layout": "double"}
+{"type": "text", "bbox": [50.616, 604.736, 290.66, 674.822], "res": "the calculator designed by us has been patented chttpell napatentscomy 68339 sposib prognozuvannys vtrati funk caniskovogo transplanatchti and disnvailable on the in ternet chitpsolivad skillwond the accuract ot prediction of renal transplant function loss three years after transplantation was 92x", "layout": "double"}
+{"type": "text", "bbox": [50.435, 673.018, 290.66700000000003, 743.28], "res": "progression of chronic renal dysfunctional the transplant accompanied the simultaneous losa the benefits of successful transplantation and the growth of problems due to immunosuppresson bosed on retrospective analysis nt resultsof treatment tofkidney transplantof the recipients with blood creatinine higher than d3 immold we adhere to the", "layout": "double"}
+{"type": "figure", "bbox": [308.637, 74.439, 546.515, 223.613], "res": "./inference_results/example_figure_10.png", "layout": "double"}
+{"type": "text", "bbox": [308.697, 230.568, 548.759, 274.113], "res": "figures the cnerhecadfmuthrnatical modeltor prognostication ofkidaey transplant function during the periodal three years after thetransplantation according oletectercipiolgaps after theoperation", "layout": "double"}
+{"type": "text", "bbox": [308.66, 293.918, 548.841, 341.415], "res": "following principles in thecorrectionod immunisuppresion which allow decreasing the rateofs chronic dysfunctionof the transplant development or edecreasing the risk fof compliea tions incaeoflasof function", "layout": "double"}
+{"type": "list", "bbox": [322.805, 348.831, 548.754, 552.133], "res": "wdo not prescribe hish doses steroids and do have the steroid pulse therapy cy do not increase the dose of received cyclosporine tacrolimus and stop medication ifthere isan increase in nephropathy tj continue immunosuppression with medicines ofmy cophenolic acid which are not nephrotoxic k4 enhance amonitoring of immunosuppression andpe vention infectious com cancel immunosuppression atreturning hemodi alysis treatment cancellation of steroids should done egradually sometimes for several months when thediscomfort eassociated transplant tempera ture main in the projection the transplanted kidney and hematurial short course of low doses of steroids administered orally of intravenously can be effective", "layout": "double"}
+{"type": "text", "bbox": [308.649, 559.189, 548.86, 640.601], "res": "according to plasma concentration of creatinine the return hemodialvsis the patients were divided into groups ln the first group the creatinine concentration in blood plasma waso mmoly in the 2nd groun con centration in blood plasma was azlommaty and in the third group concentration in blood plasma was more than commolt", "layout": "double"}
+{"type": "text", "bbox": [308.658, 638.657, 548.9680000000001, 709.24], "res": "dates or the return of transplant recipients with delaved rena transplant disfunction are largely dependent ion the psychological state ofthe patient severity of depression the desire to ensure the irreversibility the transplanted kidney dysfunction and fear that the dialysis will contribute to the deterioration of renal transplant function", "layout": "double"}
+{"type": "text", "bbox": [308.64, 707.13, 548.911, 743.158], "res": "the survival rateof patients ofthe first group after return in hemodialysis was years and in the second and third groups respectively 53132 and28426 years", "layout": "double"}
+
+```
+In this file, `type` is the category of the detected region, `bbox` is the detected bounding box `[x-coordinate of the top-left corner, y-coordinate of the top-left corner, x-coordinate of the bottom-right corner, y-coordinate of the bottom-right corner]`, and `res` is the recognition result for that region.
+
+**Notes:**
+1. For more argument illustrations and usage, please run `python tools/infer/text/predict_table_e2e.py -h` or view `tools/infer/text/config.py`
+2. Besides the parameters in `config.py`, `predict_table_e2e.py` also accepts the following parameters:
+
+
+ | **Parameter** | **Description** | **Default** |
+ |:------------:| :------: |:------:|
+ | layout | Enable layout analysis | True |
+ | ocr | Enable text detection and recognition | True |
+ | table | Enable table recognition | True |
+ | recovery | Convert the results into docx | True |
+
+
+
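+For further post-processing, the saved `{img_name}_e2e_result.txt` can be read back as JSON lines. The snippet below is an illustrative sketch; the file name is a placeholder and one JSON object per line is assumed, as written by `save_e2e_res`:
+
+```python
+import json
+
+regions = []
+with open("./inference_results/PMC4958442_00003_e2e_result.txt", encoding="utf-8") as f:
+    for line in f:
+        if line.strip():
+            regions.append(json.loads(line))
+
+# group the recovered regions by type (text / title / list / table / figure)
+by_type = {}
+for region in regions:
+    by_type.setdefault(region["type"], []).append(region)
+print({k: len(v) for k, v in by_type.items()})
+```
+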
## Argument List
All CLI argument definition can be viewed via `python tools/infer/text/predict_system.py -h` or reading `tools/infer/text/config.py`.
diff --git a/tools/infer/text/README_CN.md b/tools/infer/text/README_CN.md
index 5a673f45b..e61e8b1d9 100644
--- a/tools/infer/text/README_CN.md
+++ b/tools/infer/text/README_CN.md
@@ -230,6 +230,52 @@ python deploy/eval_utils/eval_pipeline.py --gt_path path/to/gt.txt --pred_path p
3、SVTR在混合精度模式下运行(amp_level=O2),因为它针对O2进行了优化。
+### 文本方向分类
+
+若图像中存在非正向的文字,可通过文本方向分类器对检测后的图像进行方向分类与矫正。若对输入图像运行文本方向分类与矫正,请执行
+```shell
+python tools/infer/text/predict_system.py --image_dir {path_to_img or dir_to_imgs} \
+ --det_algorithm DB++ \
+ --rec_algorithm CRNN \
+ --cls_algorithm M3
+```
+其中,参数`--cls_algorithm`默认配置为None,表示不执行文本方向分类,通过设置`--cls_algorithm`即可在文本检测识别流程中进行文本方向分类。执行过程中,文本方向分类器将对文本检测所得图像列表进行方向分类,并对非正向的图像进行方向矫正。以下为部分结果示例。
+
+- 文本方向分类
+
+
+
+
+
+ word_01.png
+
+
+
+
+
+
+ word_02.png
+
+
+分类结果:
+```text
+word_01.png 0 1.0
+word_02.png 180 1.0
+```
+当前支持的文本方向分类网络为`mobilenet_v3`,可通过配置`--cls_algorithm`为`M3`进行设置,并通过`--cls_amp_level`与`--cls_model_dir`来设置文本方向分类器的自动混合精度与权重文件。当前已配置默认权重文件,该网络默认混合精度为`O0`,默认配置下方向分类支持`0`与`180`度,对于其他方向的分类我们将在未来予以支持。
+
+
+
+ |**算法名称**|**网络名称**|**语言**|
+ | :------: | :------: | :------: |
+ | M3 | mobilenet_v3 | 中/英|
+
+
+
+此外,可通过设置`--save_cls_result`为`True`可将文本方向分类结果保存至`{args.crop_res_save_dir}/cls_results.txt`中,其中`--crop_res_save_dir`是保存结果的目录。
+
+有关更多参数说明和用法,请查看`tools/infer/text/config.py`
+
## 表格结构识别
要对输入图像或包含多个图像的目录运行表格结构识别,请执行
@@ -311,10 +357,6 @@ HDL Cholesterol (mg/dL),42 ± 11.1,46 ± 11.4
**注意事项:**
1、如需更多参数说明和用法,请运行`python tools/infer/text/predict_table_recognition.py -h`或查看`tools/infer/text/config.py`
-## 参数列表
-
-所有CLI参数定义都可以通过`python tools/infer/text/predict_system.py -h`或`tools/infer/text/config.py`查看。
-
## 版面分析
要对输入图像或包含多个图像的目录运行版面分析,请执行
@@ -370,6 +412,66 @@ python tools/infer/text/predict_layout.py --image_dir {path_to_img or dir_to_im
算法网络在`tools/infer/text/predict_layout.py`中定义。
+## 端到端文档分析及恢复
+
+要对输入图像或目录中的多个图像运行文档分析(即检测所有文本区域、表格区域、图像区域,并对这些区域进行文字识别,最终将结果按照图像原来的排版方式转换成Docx文件),请运行:
+
+```shell
+python tools/infer/text/predict_table_e2e.py --image_dir {path_to_img or dir_to_imgs} \
+ --det_algorithm {DET_ALGO} \
+ --rec_algorithm {REC_ALGO}
+```
+>注意:如果要可视化版面分析、表格识别和文字识别的结果,请设置`--visualize_output True`。
+
+运行后,推理结果保存在`{args.draw_img_save_dir}/{img_name}_e2e_result.txt`中,其中`--draw_img_save_dir`是保存结果的目录,默认为`./inference_results`。下面是一些结果示例。
+
+示例1:
+
+
+
+
+
+ PMC4958442_00003.jpg转换成docx文件的效果
+
+
+其中保存的txt文件如下
+```text
+{"type": "text", "bbox": [50.615, 70.667, 290.683, 92.667], "res": "tabley predictive value ofbasic clinical laboratory and suciode variables surney anc yea after tramphenins", "layout": "double"}
+{"type": "table", "bbox": [51.787, 100.444, 292.054, 374.09700000000004], "res": "sign factor | prediction valucofthe the | from difereness significance levelaf the |
gender | 0027 0021 | o442 |
| 00z44 | 0480 |
cause | tooza 0017 | o547 |
cadaverieilizing donorst | 0013 aont | 0740 |
induction transplantation before dialysis | doattoos | 0125 |
depleting antibodies monoclomalor cn immunosuppression with | doista09 | 0230 |
ititis | 0029 | aaso |
status itional | 0047 toots | |
townfrillage | non | |
transplantations number | toos 0017 | o5s1 |
creatinine | 02400g | caoor |
pressure bload systolic | aidaloloss | aoz |
pressure diastolic blood | dobetods | ass |
hemoglobin | 0044 0255t | caoor |
| 004 | caoor |
", "layout": "double"}
+{"type": "text", "bbox": [51.384, 374.84, 223.32299999999998, 385.57599999999996], "res": "nanc rmeans more significant forecasting factor sign", "layout": "double"}
+{"type": "title", "bbox": [50.549, 403.5, 117.941, 416.35], "res": "discussion", "layout": "double"}
+{"type": "text", "bbox": [50.409, 423.237, 290.541, 606.889], "res": "determination of creatinine and hemoglobin level in the blood well aetho concentration of protein in the urine in one year atter kidney transplantation with the calculation of prognostic criterion predics the loss of renal allotransplant function in years fafter surgery advantages ff the method are the possibility oof quantitative forecasting of renal allotransplant losser which based not only its excretory function assessment but also on assessment other characteristics that may have important prognostic value and does not always directly correlate with changes in its excretors function in order the riskof death with transplant sfunctioning returntothe program hemodialysis the predictive model was implemented cabular processor excel forthe useofthe model litisquite enough the value ethel given indices calculation and prognosis will be automatically done in the electronic table figure 31", "layout": "double"}
+{"type": "text", "bbox": [50.616, 604.736, 290.66, 674.822], "res": "the calculator designed by us has been patented chttpell napatentscomy 68339 sposib prognozuvannys vtrati funk caniskovogo transplanatchti and disnvailable on the in ternet chitpsolivad skillwond the accuract ot prediction of renal transplant function loss three years after transplantation was 92x", "layout": "double"}
+{"type": "text", "bbox": [50.435, 673.018, 290.66700000000003, 743.28], "res": "progression of chronic renal dysfunctional the transplant accompanied the simultaneous losa the benefits of successful transplantation and the growth of problems due to immunosuppresson bosed on retrospective analysis nt resultsof treatment tofkidney transplantof the recipients with blood creatinine higher than d3 immold we adhere to the", "layout": "double"}
+{"type": "figure", "bbox": [308.637, 74.439, 546.515, 223.613], "res": "./inference_results/example_figure_10.png", "layout": "double"}
+{"type": "text", "bbox": [308.697, 230.568, 548.759, 274.113], "res": "figures the cnerhecadfmuthrnatical modeltor prognostication ofkidaey transplant function during the periodal three years after thetransplantation according oletectercipiolgaps after theoperation", "layout": "double"}
+{"type": "text", "bbox": [308.66, 293.918, 548.841, 341.415], "res": "following principles in thecorrectionod immunisuppresion which allow decreasing the rateofs chronic dysfunctionof the transplant development or edecreasing the risk fof compliea tions incaeoflasof function", "layout": "double"}
+{"type": "list", "bbox": [322.805, 348.831, 548.754, 552.133], "res": "wdo not prescribe hish doses steroids and do have the steroid pulse therapy cy do not increase the dose of received cyclosporine tacrolimus and stop medication ifthere isan increase in nephropathy tj continue immunosuppression with medicines ofmy cophenolic acid which are not nephrotoxic k4 enhance amonitoring of immunosuppression andpe vention infectious com cancel immunosuppression atreturning hemodi alysis treatment cancellation of steroids should done egradually sometimes for several months when thediscomfort eassociated transplant tempera ture main in the projection the transplanted kidney and hematurial short course of low doses of steroids administered orally of intravenously can be effective", "layout": "double"}
+{"type": "text", "bbox": [308.649, 559.189, 548.86, 640.601], "res": "according to plasma concentration of creatinine the return hemodialvsis the patients were divided into groups ln the first group the creatinine concentration in blood plasma waso mmoly in the 2nd groun con centration in blood plasma was azlommaty and in the third group concentration in blood plasma was more than commolt", "layout": "double"}
+{"type": "text", "bbox": [308.658, 638.657, 548.9680000000001, 709.24], "res": "dates or the return of transplant recipients with delaved rena transplant disfunction are largely dependent ion the psychological state ofthe patient severity of depression the desire to ensure the irreversibility the transplanted kidney dysfunction and fear that the dialysis will contribute to the deterioration of renal transplant function", "layout": "double"}
+{"type": "text", "bbox": [308.64, 707.13, 548.911, 743.158], "res": "the survival rateof patients ofthe first group after return in hemodialysis was years and in the second and third groups respectively 53132 and28426 years", "layout": "double"}
+
+```
+其中,`type`为检测区域的类型,`bbox`为检测出的边界框`[左上角的x坐标, 左上角的y坐标, 右下角的x坐标, 右下角的y坐标]`,`res`是检测的结果内容。
+
+**注意事项:**
+1. 如需更多参数说明和用法,请运行`python tools/infer/text/predict_table_e2e.py -h`或查看`tools/infer/text/config.py`
+2. 除了config.py中的参数,predict_table_e2e.py还接受如下参数:
+
+
+ | **参数名** |**描述**| **默认值** |
+ |:------------:| :------: |:------:|
+ | layout | 版面分析任务 | True |
+ | ocr | 文字识别任务 | True |
+ | table | 表格识别任务 | True |
+ | recovery | 转换成Docx任务 | True |
+
+
+
+## 参数列表
+
+所有CLI参数定义都可以通过`python tools/infer/text/predict_system.py -h`或`tools/infer/text/config.py`查看。
+
## 开发人员指南-如何添加新的推断模型
### 预处理
diff --git a/tools/infer/text/config.py b/tools/infer/text/config.py
index 292166662..47979760e 100644
--- a/tools/infer/text/config.py
+++ b/tools/infer/text/config.py
@@ -126,7 +126,7 @@ def create_parser():
"--draw_img_save_dir",
type=str,
default="./inference_results",
- help="Dir to save visualization and detection/recogintion/system prediction results",
+ help="Dir to save visualization and detection/recognition/system prediction results",
)
parser.add_argument(
"--save_crop_res",
@@ -166,6 +166,15 @@ def create_parser():
"due to padding or resizing to the same shape.",
)
parser.add_argument("--kie_batch_num", type=int, default=8)
+
+ parser.add_argument(
+ "-c",
+ "--config",
+ type=str,
+ default="",
+ help="YAML config file specifying default arguments (default=" ")",
+ )
+
parser.add_argument(
"--table_algorithm",
type=str,
@@ -198,15 +207,24 @@ def create_parser():
)
parser.add_argument(
- "--layout_algorithm", type=str, default="YOLOv8", choices=["YOLOv8"], help="layout analyzer algorithm"
+ "--layout_algorithm",
+ type=str,
+ default="YOLOv8",
+ choices=["YOLOv8", "LAYOUTLMV3"],
+ help="layout analyzer algorithm",
)
-
parser.add_argument(
"--layout_model_dir",
type=str,
help="directory containing the layout model checkpoint best.ckpt, or path to a specific checkpoint file.",
) # determine the network weights
-
+ parser.add_argument(
+ "--layout_category_dict_path",
+ type=str,
+ default="./mindocr/utils/dict/layout_category_dict.txt",
+ help="path to category dictionary for layout recognition. "
+ "If None, will pick according to layout_algorithm and layout_model_dir.",
+ )
parser.add_argument(
"--layout_amp_level",
type=str,
@@ -215,6 +233,35 @@ def create_parser():
help="Auto Mixed Precision level. This setting only works on GPU and Ascend",
)
+ parser.add_argument(
+ "--cls_algorithm",
+ type=str,
+ default=None,
+ choices=["M3"],
+ help="classification algorithm. The default is None,"
+ "meaning that text orientation classification is not performed",
+ )
+ parser.add_argument(
+ "--cls_amp_level",
+ type=str,
+ default="O0",
+ choices=["O0", "O1", "O2", "O3"],
+ help="Auto Mixed Precision level. This setting only works on GPU and Ascend",
+ )
+ parser.add_argument(
+ "--cls_model_dir",
+ type=str,
+ help="directory containing the classification model checkpoint best.ckpt"
+ "or path to a specific checkpoint file.",
+ )
+ parser.add_argument("--cls_batch_num", type=int, default=8)
+ parser.add_argument(
+ "--save_cls_result",
+ type=str2bool,
+ default=True,
+ help="whether to save the text direction classification result",
+ )
+
return parser
diff --git a/tools/infer/text/postprocess.py b/tools/infer/text/postprocess.py
index 940317539..fa90dd44a 100644
--- a/tools/infer/text/postprocess.py
+++ b/tools/infer/text/postprocess.py
@@ -81,6 +81,20 @@ def __init__(self, task="det", algo="DB", rec_char_dict_path=None, **kwargs):
elif task == "ser":
class_path = "mindocr/utils/dict/class_list_xfun.txt"
postproc_cfg = dict(name="VQASerTokenLayoutLMPostProcess", class_path=class_path)
+ elif task == "layout":
+ if algo == "LAYOUTLMV3":
+ postproc_cfg = dict(
+ name="Layoutlmv3Postprocess",
+ conf_thres=0.05,
+ iou_thres=0.5,
+ conf_free=False,
+ multi_label=True,
+ time_limit=100,
+ )
+ elif algo == "YOLOv8":
+ postproc_cfg = dict(name="YOLOv8Postprocess", conf_thres=0.5, iou_thres=0.7, conf_free=True)
+ else:
+ raise ValueError(f"No postprocess config defined for {algo}. Please check the algorithm name.")
elif task == "table":
table_char_dict_path = kwargs.get(
"table_char_dict_path", "mindocr/utils/dict/table_master_structure_dict.txt"
@@ -91,8 +105,8 @@ def __init__(self, task="det", algo="DB", rec_char_dict_path=None, **kwargs):
merge_no_span_structure=True,
box_shape="pad",
)
- elif task == "layout":
- postproc_cfg = dict(name="YOLOv8Postprocess", conf_thres=0.5, iou_thres=0.7, conf_free=True)
+ elif task == "cls":
+ postproc_cfg = dict(name="ClsPostprocess", label_list=["0", "180"])
postproc_cfg.update(kwargs)
self.task = task
@@ -156,6 +170,10 @@ def __call__(self, pred, data=None, **kwargs):
return output
elif self.task == "table":
output = self.postprocess(pred, labels=kwargs.get("labels"))
+ return output
elif self.task == "layout":
output = self.postprocess(pred, img_shape=kwargs.get("img_shape"), meta_info=kwargs.get("meta_info"))
return output
+ elif self.task == "cls":
+ output = self.postprocess(pred)
+ return output
diff --git a/tools/infer/text/predict_layout.py b/tools/infer/text/predict_layout.py
index a9fb47205..b616342ed 100644
--- a/tools/infer/text/predict_layout.py
+++ b/tools/infer/text/predict_layout.py
@@ -1,5 +1,5 @@
"""
-Infer layout from images using yolov8 model.
+Layout analyzer inference
Example:
$ python tools/infer/text/predict_layout.py --image_dir {path_to_img} --layout_algorithm YOLOv8
@@ -7,10 +7,12 @@
import json
import logging
import os
-from typing import Dict, List
+from typing import List
import cv2
import numpy as np
+import yaml
+from addict import Dict
from postprocess import Postprocessor
from preprocess import Preprocessor
from utils import get_ckpt_file, get_image_paths
@@ -20,9 +22,7 @@
from mindocr import build_model
from mindocr.utils.logger import set_logger
-algo_to_model_name = {
- "YOLOv8": "yolov8",
-}
+algo_to_model_name = {"YOLOv8": "yolov8", "LAYOUTLMV3": "layoutlmv3"}
logger = logging.getLogger("mindocr")
@@ -35,8 +35,16 @@ def __init__(self, args):
self.img_dir = os.path.dirname(args.image_dir)
self.vis_dir = args.draw_img_save_dir
+ cfg = None
+ if args.config:
+ with open(args.config, "r") as f:
+ cfg = yaml.safe_load(f)
+ self.cfg = Dict(cfg)
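+ # Illustrative structure of the optional YAML config (keys mirror the accesses below;
+ # actual values depend on the chosen model):
+ #   model: layoutlmv3
+ #   predict:
+ #     ckpt_load_path: /path/to/best.ckpt
+ #     category_dict: {1: text, 2: title, 3: list, 4: table, 5: figure}
+ #     color_dict: {1: [255, 0, 0], 2: [0, 0, 255], ...}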
+
# build model for algorithm with pretrained weights or local checkpoint
ckpt_dir = args.layout_model_dir
+ if self.cfg.predict.ckpt_load_path:
+ ckpt_dir = self.cfg.predict.ckpt_load_path
if ckpt_dir is None:
pretrained = True
ckpt_load_path = None
@@ -51,6 +59,9 @@ def __init__(self, args):
f"Supported layout algorithms are {list(algo_to_model_name.keys())}"
)
model_name = algo_to_model_name[args.layout_algorithm]
+ self.model_name = model_name
+ if self.cfg:
+ model_name = self.cfg.model
self.model = build_model(
model_name,
pretrained=pretrained,
@@ -60,8 +71,8 @@ def __init__(self, args):
)
self.model.set_train(False)
- self.preprocess = Preprocessor(task="layout")
- self.postprocess = Postprocessor(task="layout")
+ self.preprocess = Preprocessor(task="layout", algo=args.layout_algorithm)
+ self.postprocess = Postprocessor(task="layout", algo=args.layout_algorithm)
def __call__(self, img_path: str, do_visualize: bool = False) -> List:
"""
@@ -89,7 +100,11 @@ def __call__(self, img_path: str, do_visualize: bool = False) -> List:
self.img_shape = net_input.shape
# infer
- preds = self.model(net_input)
+ if self.model_name == "layoutlmv3":
+ input = [net_input, ms.Tensor(np.array([self.hw_ori])), ms.Tensor(np.array([self.hw_scale]))]
+ preds = self.model(*input)
+ else:
+ preds = self.model(net_input)
# postprocess
results = self.postprocess(
@@ -98,11 +113,13 @@ def __call__(self, img_path: str, do_visualize: bool = False) -> List:
if do_visualize:
img_name = os.path.basename(img_path).rsplit(".", 1)[0]
- visualize_layout(img_path, results, save_path=os.path.join(self.vis_dir, img_name + "_layout_result.png"))
+ self.visualize_layout(
+ img_path, results, save_path=os.path.join(self.vis_dir, img_name + "_layout_result.png")
+ )
return results
- def _load_image(self, img_path: str) -> Dict:
+ def _load_image(self, img_path: str):
"""
Load image from path
"""
@@ -110,65 +127,72 @@ def _load_image(self, img_path: str) -> Dict:
h_ori, w_ori = image.shape[:2] # orig hw
hw_ori = np.array([h_ori, w_ori])
target_size = 800
- r = target_size / max(h_ori, w_ori) # resize image to img_size
- if r != 1: # always resize down, only resize up if training with augmentation
- interp = cv2.INTER_AREA if r < 1 else cv2.INTER_LINEAR
- image = cv2.resize(image, (int(w_ori * r), int(h_ori * r)), interpolation=interp)
+ if self.model_name == "layoutlmv3":
+ r = target_size / min(h_ori, w_ori)
+ image = cv2.resize(image, (int(round(w_ori * r)), int(round(h_ori * r))), interpolation=cv2.INTER_LINEAR)
+ else:
+ r = target_size / max(h_ori, w_ori) # resize image to img_size
+ if r != 1: # always resize down, only resize up if training with augmentation
+ interp = cv2.INTER_AREA if r < 1 else cv2.INTER_LINEAR
+ image = cv2.resize(image, (int(w_ori * r), int(h_ori * r)), interpolation=interp)
data = {"image": image, "raw_img_shape": hw_ori, "target_size": target_size}
return data
+ def visualize_layout(self, image_path, results, conf_thres=0.8, save_path: str = ""):
+ """
+ Visualize layout analysis results
+ """
+ from matplotlib import pyplot as plt
+ from PIL import Image
+
+ img = Image.open(image_path)
+ img_cv = cv2.imread(image_path)
+
+ fig, ax = plt.subplots()
+ ax.imshow(img)
+
+ category_dict = {1: "text", 2: "title", 3: "list", 4: "table", 5: "figure"}
+ color_dict = {1: (255, 0, 0), 2: (0, 0, 255), 3: (0, 255, 0), 4: (0, 255, 255), 5: (255, 0, 255)}
+ if self.cfg.predict.category_dict:
+ category_dict = self.cfg.predict.category_dict
+ if self.cfg.predict.color_dict:
+ color_dict = self.cfg.predict.color_dict
+
+ for item in results:
+ category_id = item["category_id"]
+ bbox = item["bbox"]
+ score = item["score"]
+
+ if score < conf_thres:
+ continue
+
+ left, bottom, w, h = bbox
+ right = left + w
+ top = bottom + h
+
+ cv2.rectangle(img_cv, (int(left), int(bottom)), (int(right), int(top)), color_dict[category_id], 2)
+
+ label = "{} {:.2f}".format(category_dict[category_id], score)
+ label_size, base_line = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
+ top = max(top, label_size[1])
+ cv2.rectangle(
+ img_cv,
+ (int(left), int(bottom - label_size[1] - base_line)),
+ (int(left + label_size[0]), int(bottom)),
+ color_dict[category_id],
+ cv2.FILLED,
+ )
+ cv2.putText(
+ img_cv, label, (int(left), int(bottom - base_line)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1
+ )
-def visualize_layout(image_path, results, conf_thres=0.8, save_path: str = ""):
- """
- Visualize layout analysis results
- """
- from matplotlib import pyplot as plt
- from PIL import Image
-
- img = Image.open(image_path)
- img_cv = cv2.imread(image_path)
-
- fig, ax = plt.subplots()
- ax.imshow(img)
-
- category_dict = {1: "text", 2: "title", 3: "list", 4: "table", 5: "figure"}
- color_dict = {1: (255, 0, 0), 2: (0, 0, 255), 3: (0, 255, 0), 4: (0, 255, 255), 5: (255, 0, 255)}
-
- for item in results:
- category_id = item["category_id"]
- bbox = item["bbox"]
- score = item["score"]
-
- if score < conf_thres:
- continue
-
- left, bottom, w, h = bbox
- right = left + w
- top = bottom + h
-
- cv2.rectangle(img_cv, (int(left), int(bottom)), (int(right), int(top)), color_dict[category_id], 2)
-
- label = "{} {:.2f}".format(category_dict[category_id], score)
- label_size, base_line = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
- top = max(top, label_size[1])
- cv2.rectangle(
- img_cv,
- (int(left), int(bottom - label_size[1] - base_line)),
- (int(left + label_size[0]), int(bottom)),
- color_dict[category_id],
- cv2.FILLED,
- )
- cv2.putText(
- img_cv, label, (int(left), int(bottom - base_line)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1
- )
-
- if save_path:
- cv2.imwrite(save_path, img_cv)
- else:
- plt.axis("off")
- plt.imshow(cv2.cvtColor(img_cv, cv2.COLOR_BGR2RGB))
- plt.show()
+ if save_path:
+ cv2.imwrite(save_path, img_cv)
+ else:
+ plt.axis("off")
+ plt.imshow(cv2.cvtColor(img_cv, cv2.COLOR_BGR2RGB))
+ plt.show()
def save_layout_res(layout_res: List, img_path: str, save_dir: str):
diff --git a/tools/infer/text/predict_system.py b/tools/infer/text/predict_system.py
index d78c4805f..ee7636c09 100644
--- a/tools/infer/text/predict_system.py
+++ b/tools/infer/text/predict_system.py
@@ -13,31 +13,235 @@
import os
import sys
from time import time
-from typing import Union
+from typing import List, Union
import cv2
import numpy as np
from config import parse_args
+from postprocess import Postprocessor
from predict_det import TextDetector
from predict_rec import TextRecognizer
-from utils import crop_text_region, get_image_paths
+from preprocess import Preprocessor
+from utils import crop_text_region, get_image_paths, img_rotate
import mindspore as ms
+from mindspore import ops
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../../")))
+from mindocr import build_model
from mindocr.utils.logger import set_logger
from mindocr.utils.visualize import visualize # noqa
+from tools.infer.text.utils import get_ckpt_file
logger = logging.getLogger("mindocr")
+class TextClassifier(object):
+ """
+ Infer model for text orientation classification
+ Example:
+ >>> args = parse_args()
+ >>> text_classification = TextClassifier(args)
+ >>> img_path = "path/to/image.jpg"
+ >>> cls_res_all = text_classification(img_path)
+ """
+
+ def __init__(self, args):
+ algo_to_model_name = {
+ "M3": "cls_mobilenet_v3_small_100_model",
+ }
+ self.batch_num = args.cls_batch_num
+ logger.info("classify in {} mode {}".format("batch", "batch_size: " + str(self.batch_num)))
+
+ # build model for algorithm with pretrained weights or local checkpoint
+ ckpt_dir = args.cls_model_dir
+ if ckpt_dir is None:
+ pretrained = True
+ ckpt_load_path = None
+ else:
+ ckpt_load_path = get_ckpt_file(ckpt_dir)
+ pretrained = False
+
+ assert args.cls_algorithm in algo_to_model_name, (
+ f"Invalid cls_algorithm: {args.cls_algorithm}. "
+ f"Supported classification algorithms are {list(algo_to_model_name.keys())}"
+ )
+ model_name = algo_to_model_name[args.cls_algorithm]
+
+ amp_level = args.cls_amp_level
+ if amp_level != "O0" and args.cls_algorithm == "M3":
+ logger.warning("The M3 model supports only amp_level O0")
+ amp_level = "O0"
+
+ self.model = build_model(model_name, pretrained=pretrained, ckpt_load_path=ckpt_load_path, amp_level=amp_level)
+ self.model.set_train(False)
+ self.cast_pred_fp32 = amp_level != "O0"
+ if self.cast_pred_fp32:
+ self.cast = ops.Cast()
+ logger.info(
+ "Init classification model: {} --> {}. Model weights loaded from {}".format(
+ args.cls_algorithm, model_name, "pretrained url" if pretrained else ckpt_load_path
+ )
+ )
+
+ # build preprocess
+ self.preprocess = Preprocessor(
+ task="cls",
+ algo=args.cls_algorithm,
+ )
+
+ # build postprocess
+ self.postprocess = Postprocessor(task="cls", algo=args.cls_algorithm)
+
+ def __call__(self, img_or_path_list: list) -> List:
+ """
+ Run text classification serially for input images
+
+ Args:
+ img_or_path_list: list or str for img path or np.array for RGB image
+
+ Returns:
+ list of tuple, each tuple is (angle, score) for one input image.
+ e.g. [("180", 1.0), ("0", 1.0)]
+ - angle: predicted text angle
+ - score: prediction confidence
+ """
+
+ assert isinstance(
+ img_or_path_list, (list, str)
+ ), "Input for text classification must be list of images or image paths."
+ logger.info(f"num images for cls: {len(img_or_path_list)}")
+
+ if isinstance(img_or_path_list, list):
+ cls_res_all_crops = self.run_batch(img_or_path_list)
+ else:
+ cls_res_all_crops = self.run_single(img_or_path_list)
+
+ return cls_res_all_crops
+
+ def run_batch(self, img_or_path_list: list):
+ """
+ Run text angle classification serially for input images
+
+ Args:
+ img_or_path_list: list of str for img path or np.array for RGB image
+
+ Return:
+ cls_res: list of tuple, where each tuple is (angle, score)
+ - the text angle classification result for each input image, in order,
+ where angle is the predicted angle and score is its confidence.
+ e.g. [("180", 0.9), ("0", 1.0)]
+ """
+
+ cls_res = []
+ num_imgs = len(img_or_path_list)
+
+ for idx in range(0, num_imgs, self.batch_num):
+ batch_begin = idx
+ batch_end = min(idx + self.batch_num, num_imgs)
+ logger.info(f"CLS img idx range: [{batch_begin}, {batch_end})")
+ img_batch = []
+
+ # preprocess
+ for j in range(batch_begin, batch_end):
+ data = self.preprocess(img_or_path_list[j])
+ img_batch.append(data["image"])
+
+ # infer
+ img_batch = np.stack(img_batch) if len(img_batch) > 1 else np.expand_dims(img_batch[0], axis=0)
+
+ net_pred = self.model(ms.Tensor(img_batch))
+ if self.cast_pred_fp32:
+ if isinstance(net_pred, (list, tuple)):
+ net_pred = [self.cast(p, ms.float32) for p in net_pred]
+ else:
+ net_pred = self.cast(net_pred, ms.float32)
+
+ # postprocess
+ batch_res = self.postprocess(net_pred)
+ cls_res.extend(list(zip(batch_res["angles"], batch_res["scores"])))
+
+ return cls_res
+
+ def run_single(self, img_or_path: str):
+ """
+ Text angle classification inference on a single image
+
+ Args:
+ img_or_path: str for image path or np.array for image RGB value
+
+ Return:
+ cls_res: a list containing a single (angle, score) tuple, where
+ - angle: predicted text angle
+ - score: prediction confidence
+ """
+
+ # preprocess
+ data = self.preprocess(img_or_path)
+
+ # infer
+ input_np = data["image"]
+ net_input = input_np
+ if len(input_np.shape) == 3:
+ net_input = np.expand_dims(input_np, axis=0)
+ net_pred = self.model(ms.Tensor(net_input))
+ if self.cast_pred_fp32:
+ if isinstance(net_pred, (list, tuple)):
+ net_pred = [self.cast(p, ms.float32) for p in net_pred]
+ else:
+ net_pred = self.cast(net_pred, ms.float32)
+
+ # postprocess
+ cls_res_raw = self.postprocess(net_pred)
+ cls_res = list(zip(cls_res_raw["angles"], cls_res_raw["scores"]))
+
+ return cls_res
+
+ def save_cls_res(
+ self,
+ cls_res_all,
+ fn="img",
+ save_path="./cls_results.txt",
+ include_score=True,
+ ):
+ """
+ Generate cls_results files that store the angle classification results.
+
+ Args:
+ cls_res_all: list of (angle, score) tuples, one per cropped image.
+ fn: customize the prefix name for image information, default is "img".
+ save_path: file storage path
+ include_score: whether to write prediction confidence
+
+ Return:
+ lines: the content of angle information written to the document
+ """
+
+ lines = []
+ for i, cls_res in enumerate(cls_res_all):
+ if include_score:
+ img_pred = f"{fn}_crop_{i}" + "\t" + cls_res[0] + "\t" + str(cls_res[1]) + "\n"
+ else:
+ img_pred = f"{fn}_crop_{i}" + "\t" + cls_res[0] + "\n"
+ lines.append(img_pred)
+
+ with open(save_path, "a", encoding="utf-8") as f_cls:
+ f_cls.writelines(lines)
+ f_cls.close()
+
+
class TextSystem(object):
def __init__(self, args):
self.text_detect = TextDetector(args)
self.text_recognize = TextRecognizer(args)
+ self.cls_algorithm = args.cls_algorithm
+ if self.cls_algorithm is not None:
+ self.text_classification = TextClassifier(args)
+ self.save_cls_result = args.save_cls_result
+ self.save_cls_dir = args.crop_res_save_dir
+
self.box_type = args.det_box_type
self.drop_score = args.drop_score
self.save_crop_res = args.save_crop_res
@@ -87,6 +291,28 @@ def __call__(self, img_or_path: Union[str, np.ndarray], do_visualize=True):
cv2.imwrite(os.path.join(self.crop_res_save_dir, f"{fn}_crop_{i}.jpg"), cropped_img)
# show_imgs(crops, is_bgr_img=False)
+ if self.cls_algorithm is not None:
+ img_or_path = crops
+ ct = time()
+ cls_res_all = self.text_classification(img_or_path)
+ time_profile["cls"] = time() - ct
+
+ cls_count = 0
+ for i, cls_res in enumerate(cls_res_all):
+ if cls_res[0] != "0":
+ crops[i] = img_rotate(crops[i], -int(cls_res[0]))
+ cls_count = cls_count + 1
+
+ logger.info(
+ f"The number of images corrected by rotation is {cls_count}/{len(cls_res_all)}"
+ f"\nCLS time: {time_profile['cls']}"
+ )
+
+ if self.save_cls_result:
+ os.makedirs(self.crop_res_save_dir, exist_ok=True)
+ save_fp = os.path.join(self.save_cls_dir, "cls_results.txt")
+ self.text_classification.save_cls_res(cls_res_all, fn=fn, save_path=save_fp)
+
# recognize cropped images
rs = time()
rec_res_all_crops = self.text_recognize(crops, do_visualize=False)
diff --git a/tools/infer/text/predict_table_e2e.py b/tools/infer/text/predict_table_e2e.py
new file mode 100644
index 000000000..e2d7c35c9
--- /dev/null
+++ b/tools/infer/text/predict_table_e2e.py
@@ -0,0 +1,265 @@
+"""
+Infer end-to-end from images and convert them into docx.
+
+Example:
+ $ python tools/infer/text/predict_table_e2e.py --image_dir {path_to_img}
+"""
+import json
+import logging
+import os
+import time
+from typing import List
+
+import cv2
+from config import create_parser, str2bool
+from predict_layout import LayoutAnalyzer
+from predict_system import TextSystem
+from predict_table_recognition import TableAnalyzer
+from utils import (
+ add_padding,
+ convert_info_docx,
+ get_dict_from_file,
+ get_image_paths,
+ sort_words_by_poly,
+ sorted_layout_boxes,
+)
+
+logger = logging.getLogger("mindocr")
+
+
+def e2e_parse_args():
+ """
+ Inherit the parser from the config.py file, and add the following arguments:
+ 1. layout: Whether to enable layout analyzer
+ 2. ocr: Whether to enable ocr
+ 3. table: Whether to enable table recognizer
+    4. recovery: Whether to recover the output into a docx file
+ """
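+    # Example invocation using these switches (a sketch; the image path is a placeholder and
+    # the boolean values are parsed by str2bool):
+    #   python tools/infer/text/predict_table_e2e.py --image_dir path/to/imgs \
+    #       --layout True --ocr True --table False --recovery True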
+ parser = create_parser()
+
+ parser.add_argument(
+ "--layout",
+ type=str2bool,
+ default=True,
+ help="Whether to enable layout analyzer. The default layout analysis algorithm is YOLOv8.",
+ )
+
+ parser.add_argument(
+ "--ocr",
+ type=str2bool,
+ default=True,
+ help="Whether to enable ocr. The default ocr detection algorithm is DB++ and recognition algorithm is CRNN.",
+ )
+
+ parser.add_argument(
+ "--table",
+ type=str2bool,
+ default=True,
+        help="Whether to enable the table recognizer. The default table analysis algorithm is TableMaster.",
+ )
+
+ parser.add_argument(
+ "--recovery",
+ type=str2bool,
+ default=True,
+        help="Whether to recover the output into a docx file. The docx will be saved in ./inference_results by default.",
+ )
+
+ args = parser.parse_args()
+ return args
+
+
+def init_ocr(args):
+ """
+ Initialize text detection and recognition system
+
+    Args:
+        args: parsed arguments with the following relevant fields:
+            ocr: whether to enable the text system
+            det_algorithm: detection algorithm
+            rec_algorithm: recognition algorithm
+            det_model_dir: detection model directory
+            rec_model_dir: recognition model directory
+ """
+ if args.ocr:
+ return TextSystem(args)
+
+ return None
+
+
+def init_layout(args):
+ """
+ Initialize layout analysis system
+
+    Args:
+        args: parsed arguments with the following relevant fields:
+            layout: whether to enable the layout module
+            layout_algorithm: layout algorithm
+            layout_model_dir: layout model ckpt path
+            layout_amp_level: Auto Mixed Precision level for layout
+ """
+ if args.layout:
+ return LayoutAnalyzer(args)
+
+ return None
+
+
+def init_table(args):
+ """
+ Initialize table recognition system
+
+    Args:
+        args: parsed arguments with the following relevant fields:
+            table: whether to enable the table recognizer
+            table_algorithm: table algorithm
+            table_model_dir: table model ckpt path
+            table_max_len: max length of the input image
+            table_char_dict_path: path to the character dictionary for table recognition
+            table_amp_level: Auto Mixed Precision level for table
+ """
+ if args.table:
+ return TableAnalyzer(args)
+
+ return None
+
+
+def save_e2e_res(e2e_res: List, img_path: str, save_path: str):
+ """
+ Save the end-to-end results to a txt file
+ """
+ lines = []
+ img_name = os.path.basename(img_path).rsplit(".", 1)[0]
+ save_path = os.path.join(save_path, img_name + "_e2e_result.txt")
+    for res in e2e_res:
+        lines.append(json.dumps(res) + "\n")
+
+    with open(save_path, "w") as f:
+        f.writelines(lines)
+
+
+def predict_table_e2e(
+ img_path, layout_category_dict, layout_analyzer, text_system, table_analyzer, do_visualize, save_folder, recovery
+):
+ """
+ Predict the end-to-end results for the input image
+
+ Args:
+ img_path: path to the input image
+ layout_category_dict: category dictionary for layout recognition
+ layout_analyzer: layout analyzer model, for more details, please refer to predict_layout.py
+ text_system: text system model, for more details, please refer to predict_system.py
+        table_analyzer: table analyzer model, for more details, please refer to predict_table_recognition.py
+ do_visualize: whether to visualize the output
+ save_folder: folder to save the output
+        recovery: whether to recover the output into a docx file
+ """
+ img_name = os.path.basename(img_path).rsplit(".", 1)[0]
+ image = cv2.imread(img_path)
+
+ if text_system is not None and do_visualize:
+ text_system(img_path, do_visualize=do_visualize)
+
+ if layout_analyzer is not None:
+ results = layout_analyzer(img_path, do_visualize=do_visualize)
+ else:
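+        # Layout analysis disabled: treat the whole page as a single region; category_id 1 is
+        # assumed to map to plain text in the layout category dict.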
+ results = [{"category_id": 1, "bbox": [0, 0, image.shape[1], image.shape[0]], "score": 1.0}]
+
+    logger.info(f"Inferring {len(results)} detected regions in {img_path}")
+
+ # crop text regions
+ h_ori, w_ori = image.shape[:2]
+ final_results = []
+ for i in range(len(results)):
+ category_id = results[i]["category_id"]
+ left, top, w, h = results[i]["bbox"]
+ right = left + w
+ bottom = top + h
+ cropped_img = image[int(top) : int(bottom), int(left) : int(right)]
+
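+        # Route each region by its layout category: text/title/list regions go through the OCR
+        # text system, table regions through the table analyzer, and anything else (e.g. figures)
+        # is saved as a cropped image; the ids follow the layout category dict.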
+        if category_id in (1, 2, 3) and text_system is not None:
+ start_time = time.time()
+
+ # only add white padding for text, title and list images for better recognition
+ if layout_analyzer is not None:
+ cropped_img = add_padding(cropped_img, padding_size=10, padding_color=(255, 255, 255))
+
+ rec_res_all_crops = text_system(cropped_img, do_visualize=False)
+ output = sort_words_by_poly(rec_res_all_crops[1], rec_res_all_crops[0])
+ final_results.append(
+ {"type": layout_category_dict[category_id], "bbox": [left, top, right, bottom], "res": " ".join(output)}
+ )
+
+ logger.info(
+ f"Processing {layout_category_dict[category_id]} at [{left}, {top}, {right}, {bottom}]"
+ f" {time.time() - start_time:.2f}s"
+ )
+ elif category_id == 4 and table_analyzer is not None:
+ start_time = time.time()
+ pred_html, _ = table_analyzer(cropped_img, do_visualize=do_visualize)
+ final_results.append(
+ {"type": layout_category_dict[category_id], "bbox": [left, top, right, bottom], "res": pred_html}
+ )
+
+ logger.info(
+ f"Processing {layout_category_dict[category_id]} at [{left}, {top}, {right}, {bottom}]"
+ f" {time.time() - start_time:.2f}s"
+ )
+ else:
+ start_time = time.time()
+ save_path = os.path.join(save_folder, f"{img_name}_figure_{i}.png")
+ cv2.imwrite(save_path, cropped_img)
+ final_results.append(
+ {"type": layout_category_dict[category_id], "bbox": [left, top, right, bottom], "res": save_path}
+ )
+
+ logger.info(
+ f"Processing {layout_category_dict[category_id]} at [{left}, {top}, {right}, {bottom}]"
+ f" {time.time() - start_time:.2f}s"
+ )
+
+ if recovery:
+ final_results = sorted_layout_boxes(final_results, w_ori)
+ convert_info_docx(final_results, save_folder, f"{img_name}_converted_docx")
+
+ return final_results
+
+
+def main():
+ from mindocr.utils.logger import set_logger
+
+ set_logger(name="mindocr")
+
+ first_time = time.time()
+ args = e2e_parse_args()
+ save_folder = args.draw_img_save_dir
+
+ save_folder, _ = os.path.splitext(save_folder)
+ if not os.path.exists(save_folder):
+ os.makedirs(save_folder)
+
+ text_system = init_ocr(args)
+ layout_analyzer = init_layout(args)
+ layout_category_dict = get_dict_from_file(args.layout_category_dict_path)
+ table_analyzer = init_table(args)
+
+ img_paths = get_image_paths(args.image_dir)
+ for i, img_path in enumerate(img_paths):
+        logger.info(f"Inferring [{i+1}/{len(img_paths)}]: {img_path}")
+ final_results = predict_table_e2e(
+ img_path,
+ layout_category_dict,
+ layout_analyzer,
+ text_system,
+ table_analyzer,
+ args.visualize_output,
+ save_folder,
+ args.recovery,
+ )
+
+ save_e2e_res(final_results, img_path, save_folder)
+
+ logger.info(f"Processing e2e total time: {time.time() - first_time:.2f}s")
+    logger.info(f"Done! Predicted {len(img_paths)} image(s); e2e results saved in {save_folder}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tools/infer/text/preprocess.py b/tools/infer/text/preprocess.py
index a89e70bf5..47c6288dd 100644
--- a/tools/infer/text/preprocess.py
+++ b/tools/infer/text/preprocess.py
@@ -191,6 +191,30 @@ def __init__(self, task="det", algo="DB", **kwargs):
},
{"ToCHWImage": None},
]
+ elif task == "layout":
+ if algo == "LAYOUTLMV3":
+ pipeline = [
+ letterbox(scaleup=True, model_name="layoutlmv3"),
+ {
+ "NormalizeImage": {
+ "infer_mode": True,
+ "bgr_to_rgb": True,
+ "is_hwc": True,
+ "mean": [127.5, 127.5, 127.5],
+ "std": [127.5, 127.5, 127.5],
+ }
+ },
+ {"ToCHWImage": None},
+ {"ImageStridePad": {"stride": 32}},
+ ]
+ elif algo == "YOLOv8":
+ pipeline = [
+ letterbox(scaleup=False),
+ image_norm(scale=255.0),
+ image_transpose(bgr2rgb=True, hwc2chw=True),
+ ]
+ else:
+ raise ValueError(f"No preprocess config defined for {algo}. Please check the algorithm name.")
elif task == "table":
table_max_len = kwargs.get("table_max_len", 480)
pipeline = [
@@ -207,11 +231,21 @@ def __init__(self, task="det", algo="DB", **kwargs):
},
{"ToCHWImage": None},
]
- elif task == "layout":
+
+ elif task == "cls":
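+            # Angle classification: rotate clearly vertical crops upright, resize to 48x192 and
+            # normalize to roughly [-1, 1] before feeding the cls network.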
pipeline = [
- letterbox(scaleup=False),
- image_norm(scale=255.0),
- image_transpose(bgr2rgb=True, hwc2chw=True),
+ {"DecodeImage": {"img_mode": "BGR", "to_float32": False}},
+ {"Rotate90IfVertical": {"threshold": 2.0, "direction": "counterclockwise"}},
+ {"RecResizeImg": {"image_shape": [48, 192], "padding": False}},
+ {
+ "NormalizeImage": {
+ "bgr_to_rgb": True,
+ "is_hwc": True,
+ "mean": [127.0, 127.0, 127.0],
+ "std": [127.0, 127.0, 127.0],
+ }
+ },
+ {"ToCHWImage": None},
]
self.pipeline = pipeline
diff --git a/tools/infer/text/utils/__init__.py b/tools/infer/text/utils/__init__.py
index 1e474566d..c923c924d 100644
--- a/tools/infer/text/utils/__init__.py
+++ b/tools/infer/text/utils/__init__.py
@@ -1,2 +1,4 @@
from .matcher import Matcher, TableMasterMatcher
+from .recovery_to_doc import *
+from .table_process import *
from .utils import *
diff --git a/tools/infer/text/utils/recovery_to_doc.py b/tools/infer/text/utils/recovery_to_doc.py
new file mode 100644
index 000000000..49fb0161d
--- /dev/null
+++ b/tools/infer/text/utils/recovery_to_doc.py
@@ -0,0 +1,153 @@
+import os
+from typing import Dict, List
+
+from docx import Document, shared
+from docx.enum.section import WD_SECTION
+from docx.enum.text import WD_ALIGN_PARAGRAPH
+from docx.oxml.ns import qn
+
+from .table_process import HtmlToDocx
+
+
+def set_document_styles(doc: Document) -> None:
+ """
+ Set the styles for the document.
+ Args:
+ doc (Document): The document to set styles for.
+ """
+ doc.styles["Normal"].font.name = "Times New Roman"
+ doc.styles["Normal"]._element.rPr.rFonts.set(qn("w:eastAsia"), "宋体")
+ doc.styles["Normal"].font.size = shared.Pt(6.5)
+
+
+def convert_info_docx(res: List[Dict], save_folder: str, doc_name: str) -> None:
+ """
+ Convert OCR results to a DOCX file.
+ Args:
+ res (List[Dict]): OCR results.
+ save_folder (str): Folder to save the DOCX file.
+ doc_name (str): Name of the DOCX file.
+ Returns:
+ None
+ """
+ doc = Document()
+ set_document_styles(doc)
+
+ flag = 1 # Current layout flag
+ previous_layout = None # To record the previous layout
+
+ for region in res:
+ if not region["res"]:
+ continue
+
+ # Check if the current layout has changed to avoid creating the same layout repeatedly
+ current_layout = region["layout"]
+ if current_layout != previous_layout:
+ section = doc.add_section(WD_SECTION.CONTINUOUS)
+ if current_layout == "single":
+ section._sectPr.xpath("./w:cols")[0].set(qn("w:num"), "1")
+ flag = 1
+ elif current_layout == "double":
+ section._sectPr.xpath("./w:cols")[0].set(qn("w:num"), "2")
+ flag = 2
+ elif current_layout == "triple":
+ section._sectPr.xpath("./w:cols")[0].set(qn("w:num"), "3")
+ flag = 3
+ previous_layout = current_layout # Update the previous layout record
+
+ # Insert content based on the region type
+ if region["type"].lower() == "figure":
+ img_path = region["res"]
+ paragraph_pic = doc.add_paragraph()
+ paragraph_pic.alignment = WD_ALIGN_PARAGRAPH.CENTER
+ run = paragraph_pic.add_run("")
+ # Insert picture, width depends on the column layout
+ if flag == 1:
+ run.add_picture(img_path, width=shared.Inches(5))
+ elif flag == 2:
+ run.add_picture(img_path, width=shared.Inches(2.5))
+ elif flag == 3:
+ run.add_picture(img_path, width=shared.Inches(1.5))
+
+ elif region["type"].lower() == "title":
+ doc.add_heading(region["res"])
+
+ elif region["type"].lower() == "table":
+ parser = HtmlToDocx()
+ parser.table_style = "TableGrid"
+ parser.handle_table(region["res"], doc)
+
+ else: # Default to handling text regions
+ paragraph = doc.add_paragraph()
+ text_run = paragraph.add_run(region["res"])
+ text_run.font.size = shared.Pt(10)
+
+ # Save as DOCX file
+ docx_path = os.path.join(save_folder, f"{doc_name}_ocr.docx")
+ doc.save(docx_path)
+
+
+def sorted_layout_boxes(res: List[Dict], w: int) -> List[Dict]:
+ """
+ Sort boxes based on distribution, supporting single, double, and triple column layouts,
+ considering columns with large spans.
+ Args:
+ res (List[Dict]): Results from layout.
+ w (int): Document width.
+ Returns:
+ List[Dict]: Sorted results.
+ """
+ num_boxes = len(res)
+ if num_boxes == 1:
+ res[0]["layout"] = "single"
+ return res
+
+ # Sort by y-coordinate from top to bottom, then by x-coordinate from right to left
+ sorted_boxes = sorted(res, key=lambda x: (x["bbox"][1], -x["bbox"][0]))
+ _boxes = list(sorted_boxes)
+
+ res_left = []
+ res_center = []
+ res_right = []
+ new_res = []
+
+ column_thresholds = [w / 3, 2 * w / 3]
+ tolerance = 0.02 * w
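+    # The page is split into three equal-width bands; the 2% tolerance absorbs boxes that
+    # slightly cross a band boundary, and boxes wider than two thirds of the page width are
+    # treated as spanning.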
+
+ # First round: classify columns, determine the distribution of boxes in each column
+ for current_box in _boxes:
+ box_left, box_right = current_box["bbox"][0], current_box["bbox"][2]
+ box_width = box_right - box_left
+
+ # Determine column layout, ensuring each box is assigned to only one column
+ if box_width > column_thresholds[1]:
+ current_box["layout"] = "spanning"
+ new_res.append(current_box)
+ elif box_right < column_thresholds[0] + tolerance:
+ res_left.append(current_box)
+ elif box_left > column_thresholds[1] - tolerance:
+ res_right.append(current_box)
+ elif column_thresholds[0] - tolerance <= box_left <= column_thresholds[1] + tolerance:
+ res_center.append(current_box)
+ else:
+ res_left.append(current_box)
+
+ # Second round: determine specific layout based on column distribution
+ for box in res_left:
+ if res_center and res_right:
+ box["layout"] = "triple"
+ elif res_right or res_center:
+ box["layout"] = "double"
+ else:
+ box["layout"] = "single"
+ new_res.append(box)
+
+ for box in res_center:
+ box["layout"] = "triple" if res_left and res_right else "double"
+ new_res.append(box)
+
+ for box in res_right:
+ box["layout"] = "triple" if res_center else "double"
+ new_res.append(box)
+
+ return new_res
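+
+
+# Minimal usage sketch (not executed in this module), assuming `regions` holds layout results
+# with "type", "bbox" and "res" keys and `page_w` is the page width in pixels:
+#
+#     regions = sorted_layout_boxes(regions, page_w)
+#     convert_info_docx(regions, "./inference_results", "demo")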
diff --git a/tools/infer/text/utils/table_process.py b/tools/infer/text/utils/table_process.py
new file mode 100644
index 000000000..045a0bb10
--- /dev/null
+++ b/tools/infer/text/utils/table_process.py
@@ -0,0 +1,272 @@
+import re
+from html.parser import HTMLParser
+
+import docx
+from bs4 import BeautifulSoup
+from docx import Document
+
+
+def get_table_rows(table_soup):
+ """
+ Get all rows for the table.
+ """
+ table_row_selectors = [
+ "table > tr",
+ "table > thead > tr",
+ "table > tbody > tr",
+ "table > tfoot > tr",
+ ]
+ return table_soup.select(", ".join(table_row_selectors), recursive=False)
+
+
+def get_table_columns(row):
+ """
+ Get all columns for the specified row tag.
+ """
+ return row.find_all(["th", "td"], recursive=False) if row else []
+
+
+def get_table_dimensions(table_soup):
+ """
+ Get the number of rows and columns in the table.
+ """
+ rows = get_table_rows(table_soup)
+ cols = get_table_columns(rows[0]) if rows else []
+
+ col_count = 0
+ for col in cols:
+ colspan = col.attrs.get("colspan", 1)
+ col_count += int(colspan)
+
+ return rows, col_count
+
+
+def get_cell_html(soup):
+ """
+ Return the HTML content of a cell without the tags.
+    Return the inner HTML of a cell, without the enclosing cell tag.
+ return " ".join([str(i) for i in soup.contents])
+
+
+def delete_paragraph(paragraph):
+ """
+ Delete a paragraph from a docx document.
+ """
+ p = paragraph._element
+ p.getparent().remove(p)
+ p._p = p._element = None
+
+
+def remove_whitespace(string, leading=False, trailing=False):
+ """
+ Remove white space from a string.
+ """
+ if leading:
+ string = re.sub(r"^\s*\n+\s*", "", string)
+ if trailing:
+ string = re.sub(r"\s*\n+\s*$", "", string)
+ string = re.sub(r"\s*\n\s*", " ", string)
+ return re.sub(r"\s+", " ", string)
+
+
+font_styles = {
+ "b": "bold",
+ "strong": "bold",
+ "em": "italic",
+ "i": "italic",
+ "u": "underline",
+ "s": "strike",
+ "sup": "superscript",
+ "sub": "subscript",
+ "th": "bold",
+}
+
+font_names = {
+ "code": "Courier",
+ "pre": "Courier",
+}
+
+
+class HtmlToDocx(HTMLParser):
+ def __init__(self):
+ super().__init__()
+ self.options = {
+ "fix-html": True,
+ "images": True,
+ "tables": True,
+ "styles": True,
+ }
+ self.table_row_selectors = [
+ "table > tr",
+ "table > thead > tr",
+ "table > tbody > tr",
+ "table > tfoot > tr",
+ ]
+ self.table_style = None
+ self.paragraph_style = None
+
+ def set_initial_attrs(self, document=None):
+ self.tags = {
+ "span": [],
+ "list": [],
+ }
+ if document:
+ self.doc = document
+ else:
+ self.doc = Document()
+ self.bs = self.options["fix-html"]
+ self.document = self.doc
+ self.include_tables = True
+ self.include_images = self.options["images"]
+ self.include_styles = self.options["styles"]
+ self.paragraph = None
+ self.skip = False
+ self.skip_tag = None
+ self.instances_to_skip = 0
+
+ def copy_settings_from(self, other):
+ """
+ Copy settings from another instance of HtmlToDocx
+ """
+ self.table_style = other.table_style
+ self.paragraph_style = other.paragraph_style
+
+ def ignore_nested_tables(self, tables_soup):
+ """
+ Return only the highest level tables.
+ """
+ new_tables = []
+ nest = 0
+ for table in tables_soup:
+ if nest:
+ nest -= 1
+ continue
+ new_tables.append(table)
+ nest = len(table.find_all("table"))
+ return new_tables
+
+ def get_tables(self):
+ """
+ Get all tables from the HTML.
+ """
+ if not hasattr(self, "soup"):
+ self.include_tables = False
+ return
+ # find other way to do it, or require this dependency?
+ self.tables = self.ignore_nested_tables(self.soup.find_all("table"))
+ self.table_no = 0
+
+ def run_process(self, html):
+ """
+ Process the HTML content.
+ """
+ if self.bs and BeautifulSoup:
+ self.soup = BeautifulSoup(html, "html.parser")
+ html = str(self.soup)
+ if self.include_tables:
+ self.get_tables()
+ self.feed(html)
+
+ def add_html_to_cell(self, html, cell):
+ """
+ Add HTML content to a table cell.
+ """
+ if not isinstance(cell, docx.table._Cell):
+ raise ValueError("Second argument needs to be a %s" % docx.table._Cell)
+ unwanted_paragraph = cell.paragraphs[0]
+ if unwanted_paragraph.text == "":
+ delete_paragraph(unwanted_paragraph)
+ self.set_initial_attrs(cell)
+ self.run_process(html)
+ if not self.doc.paragraphs:
+ self.doc.add_paragraph("")
+
+ def apply_paragraph_style(self, style=None):
+ """
+ Apply style to the current paragraph.
+ """
+ try:
+ if style:
+ self.paragraph.style = style
+ elif self.paragraph_style:
+ self.paragraph.style = self.paragraph_style
+ except KeyError as e:
+ raise ValueError(f"Unable to apply style {self.paragraph_style}.") from e
+
+ def handle_table(self, html, doc):
+ """
+ Handle nested tables by parsing them manually.
+ """
+ table_soup = BeautifulSoup(html, "html.parser")
+ rows, cols_len = get_table_dimensions(table_soup)
+ table = doc.add_table(len(rows), cols_len)
+ table.style = doc.styles["Table Grid"]
+
+ num_rows = len(table.rows)
+ num_cols = len(table.columns)
+
+ cell_row = 0
+ for _, row in enumerate(rows):
+ cols = get_table_columns(row)
+ cell_col = 0
+ for col in cols:
+ colspan = int(col.attrs.get("colspan", 1))
+ rowspan = int(col.attrs.get("rowspan", 1))
+
+ cell_html = get_cell_html(col)
+ if col.name == "th":
+                    cell_html = f"<b>{cell_html}</b>"
+
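+                # Merge the docx cells covered by this cell's rowspan/colspan, then fill the
+                # merged block by parsing the cell's inner HTML with a child parser.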
+ if cell_row >= num_rows or cell_col >= num_cols:
+ continue
+
+ docx_cell = table.cell(cell_row, cell_col)
+
+ while docx_cell.text != "": # Skip the merged cell
+ cell_col += 1
+ docx_cell = table.cell(cell_row, cell_col)
+
+ cell_to_merge = table.cell(cell_row + rowspan - 1, cell_col + colspan - 1)
+ if docx_cell != cell_to_merge:
+ docx_cell.merge(cell_to_merge)
+
+ child_parser = HtmlToDocx()
+ child_parser.copy_settings_from(self)
+ child_parser.add_html_to_cell(cell_html or " ", docx_cell)
+
+ cell_col += colspan
+ cell_row += 1
+
+ def handle_data(self, data):
+ """
+ Handle text data within HTML tags.
+ """
+ if self.skip:
+ return
+
+ if "pre" not in self.tags:
+ data = remove_whitespace(data, True, True)
+
+ if not self.paragraph:
+ self.paragraph = self.doc.add_paragraph()
+ self.apply_paragraph_style()
+
+ link = self.tags.get("a")
+ if link:
+ self.handle_link(link["href"], data)
+ else:
+ self.run = self.paragraph.add_run(data)
+ spans = self.tags["span"]
+ for span in spans:
+ if "style" in span:
+ style = self.parse_dict_string(span["style"])
+ self.add_styles_to_run(style)
+
+ for tag in self.tags:
+ if tag in font_styles:
+ font_style = font_styles[tag]
+ setattr(self.run.font, font_style, True)
+
+ if tag in font_names:
+ font_name = font_names[tag]
+ self.run.font.name = font_name
diff --git a/tools/infer/text/utils/utils.py b/tools/infer/text/utils/utils.py
index 1717a9ab4..e75f45494 100644
--- a/tools/infer/text/utils/utils.py
+++ b/tools/infer/text/utils/utils.py
@@ -183,3 +183,75 @@ def get_ocr_result_paths(ocr_result_dir: str) -> List[str]:
"Please check the `image_dir` arg value."
)
return sorted(ocr_result_path)
+
+
+def add_padding(image, padding_size, padding_color=(0, 0, 0)):
+ """
+ Add padding to the image with color
+ """
+ if isinstance(padding_size, int):
+ top, bottom, left, right = padding_size, padding_size, padding_size, padding_size
+ else:
+ top, bottom, left, right = padding_size
+
+ padded_image = cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=padding_color)
+ return padded_image
+
+
+def sort_words_by_poly(words, polys):
+ """
+ Sort detected word-boxes by polygon position in order to create a sentence
+ """
+ from functools import cmp_to_key
+
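+    # Boxes whose vertical centers are close relative to their heights (roughly the same text
+    # line) are ordered left to right; otherwise they are ordered top to bottom.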
+ def compare(x, y):
+ dist1 = y[1][3][1] - x[1][0][1]
+ dist2 = x[1][3][1] - y[1][0][1]
+ if abs(dist1 - dist2) < x[1][3][1] - x[1][0][1] or abs(dist1 - dist2) < y[1][3][1] - y[1][0][1]:
+ if x[1][0][0] < y[1][0][0]:
+ return -1
+ elif x[1][0][0] == y[1][0][0]:
+ return 0
+ else:
+ return 1
+ else:
+ if x[1][0][1] < y[1][0][1]:
+ return -1
+ elif x[1][0][1] == y[1][0][1]:
+ return 0
+ else:
+ return 1
+
+ tmp = sorted(zip(words, polys), key=cmp_to_key(compare))
+ return [item[0][0] for item in tmp]
+
+
+def get_dict_from_file(file_path: str) -> dict:
+ """
+    Read a text file and return a dict mapping the 1-based line index to the stripped line content
+ Args:
+ file_path: path to a file
+ """
+ with open(file_path, "rb") as f:
+ lines = f.readlines()
+ return {i + 1: line.decode("utf-8").strip("\n").strip("\r\n") for i, line in enumerate(lines)}
+
+
+def img_rotate(image, angle):
+ """
+    Rotate the input image by the specified angle.
+
+    Args:
+        image: the decoded image (numpy array) to be rotated.
+        angle: the rotation angle in degrees; positive values rotate counter-clockwise.
+
+    Returns:
+        rotated: the rotated image.
+ """
+
+ (h, w) = image.shape[:2]
+ center = (w / 2, h / 2)
+ M = cv2.getRotationMatrix2D(center, angle, 1.0)
+ rotated = cv2.warpAffine(image, M, (w, h))
+
+ return rotated
diff --git a/tools/param_converter_from_torch.py b/tools/param_converter_from_torch.py
new file mode 100644
index 000000000..770c226ab
--- /dev/null
+++ b/tools/param_converter_from_torch.py
@@ -0,0 +1,57 @@
+import argparse
+import json
+import os
+
+import torch
+
+from mindspore import Parameter, save_checkpoint
+
+
+def convert_helper(input_path: str, json_path: str, output_path: str):
+    if not os.path.exists(input_path):
+        raise ValueError("The input torch checkpoint path doesn't exist.")
+    if not os.path.exists(json_path):
+        raise ValueError("The json path doesn't exist.")
+
+    output_dir, output_filename = os.path.split(os.path.abspath(output_path))
+
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+    real_output_path = os.path.join(output_dir, output_filename)
+
+ pt_ckpt = torch.load(input_path, map_location=torch.device("cpu"), weights_only=False)["model"]
+ ms_ckpt = list()
+
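+    # The helper json is expected to provide a "convert_map" dict that maps each torch parameter
+    # name to the corresponding MindSpore parameter name (e.g. the layoutlmv3 publaynet param map).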
+ with open(json_path, "r") as json_file:
+ helper_json = json.load(json_file)
+ convert_map = helper_json["convert_map"]
+ for pt_name, ms_name in convert_map.items():
+ np_param = pt_ckpt[pt_name].detach().numpy()
+ ms_ckpt.append({"name": ms_name, "data": Parameter(np_param)})
+
+ save_checkpoint(ms_ckpt, real_output_path)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Get the path of a torch checkpoint, and convert it to a mindspore ckpt.")
+ parser.add_argument(
+ "-i",
+ "--input_path",
+ type=str,
+ default="model_final.pt",
+        help="The input path of the torch checkpoint.",
+ )
+ parser.add_argument(
+ "-j",
+ "--json_path",
+ type=str,
+ default="configs/layout/layoutlmv3/layoutlmv3_publaynet_param_map.json",
+        help="The path of the parameter-mapping json.",
+ )
+ parser.add_argument(
+ "-o", "--output_path", type=str, default="from_torch.ckpt", help="The output path of the mindspore ckpt."
+ )
+ args = parser.parse_args()
+
+    convert_helper(input_path=args.input_path, json_path=args.json_path, output_path=args.output_path)