feat: add CAN model #739

Closed
51 changes: 51 additions & 0 deletions configs/rec/can/README_CN.md
@@ -0,0 +1,51 @@
[English]() | 中文

# CAN (Counting-Aware Network)
<!--- Guideline: use url linked to abstract in ArXiv instead of PDF for fast loading. -->

> [CAN: When Counting Meets HMER: Counting-Aware Network for Handwritten Mathematical Expression Recognition](https://arxiv.org/abs/2207.11463)

## 1. Model Description
<!--- Guideline: Introduce the model and architectures. Cite if you use/adopt paper explanation from others. -->

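CAN is an attention-based encoder-decoder algorithm for handwritten mathematical expression recognition (HMER) that includes a weakly supervised counting module. After studying most existing HMER algorithms, the authors of the paper found that they generally adopt an attention-based encoder-decoder structure. This structure lets the model attend to the image region that corresponds to each symbol as it recognizes that symbol. When recognizing regular text, attention moves in a simple pattern (usually left to right or right to left), so the mechanism is reliable in that scenario. When recognizing mathematical expressions, however, attention can move across the image in many more ways. As a result, when decoding more complex expressions the model is prone to inaccurate attention, which causes it to recognize a symbol repeatedly or to miss a symbol.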

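To address this, the authors designed a weakly supervised counting module that predicts the number of occurrences of each symbol class without symbol-level position annotations, and plugged it into a typical attention-based HMER encoder-decoder model. This design rests on two considerations: (1) symbol counting implicitly provides symbol position information, which can make attention more accurate; (2) the symbol counting results can serve as additional global information to improve formula recognition accuracy.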

<p align="center">
<img src="https://temp-data.obs.cn-central-221.ovaijisuan.com/mindocr_material/miss_word.png" width=640 />
</p>
<p align="center">
<em> Figure 1. Comparison of handwritten mathematical expression recognition algorithms [<a href="#参考文献">1</a>] </em>
</p>

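The CAN model consists of a backbone feature extraction network, a multi-scale counting module (MSCM), and a counting-combined attention decoder (CCAD). The backbone uses DenseNet to produce a feature map, which is fed into the MSCM to obtain a counting vector. The counting vector has dimension 1×C, where C is the size of the formula vocabulary. The counting vector and the feature map are then fed together into the CCAD, which outputs the LaTeX of the formula.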

<p align="center">
<img src="https://temp-data.obs.cn-central-221.ovaijisuan.com/mindocr_material/total_process.png" width=640 />
</p>
<p align="center">
<em> Figure 2. Overall model structure [<a href="#参考文献">1</a>] </em>
</p>
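
The shape bookkeeping of this pipeline can be sketched in plain Python. The sketch below assumes the backbone settings that appear in this PR's `rec_can_densenet.py` (a 7×7 stride-2 stem, 2×2 ceil-mode pooling, three dense blocks of 16 layers, `growth_rate=24`, `reduction=0.5`). The helper names and the vocabulary size of 111 are hypothetical and used only for illustration.

```python
import math


def backbone_output_shape(height, width, growth_rate=24, reduction=0.5,
                          layers_per_block=16):
    """Return (channels, out_height, out_width) of the backbone feature map."""

    def conv_stem(size):
        # 7x7 convolution, stride 2, padding 3.
        return (size + 2 * 3 - 7) // 2 + 1

    def pool2(size):
        # 2x2 pooling with stride 2 in ceil mode.
        return math.ceil(size / 2)

    h, w = pool2(conv_stem(height)), pool2(conv_stem(width))
    channels = 2 * growth_rate

    # Dense blocks 1 and 2 are each followed by a transition that
    # compresses the channels by `reduction` and halves the resolution.
    for _ in range(2):
        channels += layers_per_block * growth_rate
        channels = int(math.floor(channels * reduction))
        h, w = pool2(h), pool2(w)

    # Dense block 3 has no transition after it.
    channels += layers_per_block * growth_rate
    return channels, h, w


def counting_vector_shape(vocab_size):
    """The counting vector holds one count per symbol class: 1 x C."""
    return (1, vocab_size)


feature_shape = backbone_output_shape(128, 256)
count_shape = counting_vector_shape(111)  # 111 is a hypothetical vocabulary size
print(feature_shape, count_shape)
```

Under these assumptions the feature map is downsampled by a factor of 16 in each spatial dimension, while the counting vector depends only on the vocabulary size and not on the image size.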

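The multi-scale counting module (MSCM) is designed to predict the number of occurrences of each symbol class. It consists of multi-scale feature extraction, channel attention, and a pooling operator. Because writing habits differ, formula images usually contain symbols of various sizes, and a single convolution kernel size cannot handle these scale variations effectively. Two parallel convolution branches with different kernel sizes (3×3 and 5×5) are therefore used first to extract multi-scale features. After the convolution layers, channel attention is applied to further enhance the feature information.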

<p align="center">
<img src="https://temp-data.obs.cn-central-221.ovaijisuan.com/mindocr_material/MSCM.png" width=640 />
</p>
<p align="center">
<em> Figure 3. Multi-scale counting module (MSCM) [<a href="#参考文献">1</a>] </em>
</p>
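
The final pooling step of the MSCM can be illustrated with a toy example. The text above only names a "pooling operator"; the sketch assumes sum pooling over each per-class map and averaging of the two branch results, which is how the CAN paper is commonly summarized. The convolution branches and channel attention are omitted, and the function names and toy maps are hypothetical.

```python
def sum_pool(counting_map):
    """Collapse a C x H x W map into C per-class counts by summing over space."""
    return [sum(sum(row) for row in class_map) for class_map in counting_map]


def fuse_branches(counts_3x3, counts_5x5):
    """Average the per-class counts predicted by the two branches."""
    return [(a + b) / 2 for a, b in zip(counts_3x3, counts_5x5)]


# Toy maps for C = 2 symbol classes on a 2 x 2 grid (assumed values).
map_3x3 = [[[1.0, 0.0], [0.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]]]
map_5x5 = [[[1.0, 0.0], [0.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]]]

counting_vector = fuse_branches(sum_pool(map_3x3), sum_pool(map_5x5))
print(counting_vector)
```

Because the count is a sum over all positions, training it needs only the number of times each symbol occurs in the label, not where the symbol is located, which is what makes the supervision weak.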

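Counting-combined attention decoder (CCAD): to strengthen the model's awareness of spatial position, positional encoding is used to represent the different spatial positions in the feature map. In addition, unlike most previous formula recognition methods, which use only local features for symbol prediction, CCAD introduces the symbol counting results as additional global information when predicting the symbol class, in order to improve recognition accuracy.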

<p align="center">
<img src="https://temp-data.obs.cn-central-221.ovaijisuan.com/mindocr_material/CCAD.png" width=640 />
</p>
<p align="center">
<em> Figure 4. Counting-combined attention decoder (CCAD) [<a href="#参考文献">1</a>] </em>
</p>
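
One way to read "counting results as additional global information" is that a projection of the counting vector is added to the projections of the local context and the decoder state before the final classification. The sketch below is a minimal illustration of that reading in plain Python, not the decoder implemented in this PR; every name, dimension, and weight value is hypothetical.

```python
import math


def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * v for w, v in zip(row, vector)) for row in matrix]


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict_symbol(context, hidden, counting,
                   w_context, w_hidden, w_counting, w_out):
    """Fuse local context, decoder state, and counting vector, then classify."""
    fused = [
        c + h + g
        for c, h, g in zip(
            matvec(w_context, context),
            matvec(w_hidden, hidden),
            matvec(w_counting, counting),
        )
    ]
    return softmax(matvec(w_out, fused))


# Toy setup: vocabulary of 3 symbols, fused feature size 2 (assumed values).
identity = [[1.0, 0.0], [0.0, 1.0]]
w_counting = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
w_out = [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]

probs = predict_symbol(
    context=[0.0, 0.0],
    hidden=[0.0, 0.0],
    counting=[2.0, 0.0, 1.0],
    w_context=identity,
    w_hidden=identity,
    w_counting=w_counting,
    w_out=w_out,
)
print(probs)
```

In this toy setup the local context and decoder state are zero, so the prediction is driven entirely by the counting vector. This isolates the effect of the global term.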

## References
<a id="参考文献"></a>
<!--- Guideline: Citation format GB/T 7714 is suggested. -->
[1] Bohan Li, Ye Yuan, Dingkang Liang, Xiao Liu, Zhilong Ji, Jinfeng Bai, Wenyu Liu, Xiang Bai. When Counting Meets HMER: Counting-Aware Network for Handwritten Mathematical Expression Recognition. arXiv:2207.11463, ECCV'2022
1 change: 1 addition & 0 deletions mindocr/models/backbones/__init__.py
@@ -19,6 +19,7 @@
from .rec_vgg import *
from .table_master_resnet import *
from .yolov8_backbone import yolov8_backbone
from .rec_can_densenet import *

__all__ = []
__all__.extend(builder.__all__)
197 changes: 197 additions & 0 deletions mindocr/models/backbones/rec_can_densenet.py
@@ -0,0 +1,197 @@
"""
Rec_DenseNet model
"""
import math
import mindspore as ms

from mindspore import nn
from mindspore import ops
from ._registry import register_backbone, register_backbone_class

ms.set_context(pynative_synchronize=True)

__all__ = ['DenseNet']


class Bottleneck(nn.Cell):
    """Bottleneck block of rec_densenet"""
    def __init__(self, n_channels, growth_rate, use_dropout):
        super().__init__()
        # The 1x1 conv projects to 4 * growth_rate channels before the 3x3 conv.
        inter_channels = 4 * growth_rate
        self.bn1 = nn.BatchNorm2d(inter_channels)
        self.conv1 = nn.Conv2d(
            n_channels,
            inter_channels,
            kernel_size=1,
            has_bias=False,
            pad_mode='pad',
            padding=0,
        )
        self.bn2 = nn.BatchNorm2d(growth_rate)
        self.conv2 = nn.Conv2d(
            inter_channels,
            growth_rate,
            kernel_size=3,
            has_bias=False,
            pad_mode='pad',
            padding=1
        )
        self.use_dropout = use_dropout
        self.dropout = nn.Dropout(p=0.2)

    def construct(self, x):
        out = ops.relu(self.bn1(self.conv1(x)))
        if self.use_dropout:
            out = self.dropout(out)
        out = ops.relu(self.bn2(self.conv2(out)))
        if self.use_dropout:
            out = self.dropout(out)
        # Dense connectivity: concatenate input and new features on the channel axis.
        out = ops.concat((x, out), 1)
        return out


class SingleLayer(nn.Cell):
    """SingleLayer block of rec_densenet"""
    def __init__(self, n_channels, growth_rate, use_dropout):
        super().__init__()
        # Note: bn1 is created here but is not applied in construct.
        self.bn1 = nn.BatchNorm2d(n_channels)
        self.conv1 = nn.Conv2d(
            n_channels,
            growth_rate,
            kernel_size=3,
            has_bias=False,
            pad_mode='pad',
            padding=1
        )
        self.use_dropout = use_dropout
        self.dropout = nn.Dropout(p=0.2)

    def construct(self, x):
        out = self.conv1(ops.relu(x))
        if self.use_dropout:
            out = self.dropout(out)
        # Dense connectivity: concatenate input and new features on the channel axis.
        out = ops.concat((x, out), 1)
        return out


class Transition(nn.Cell):
    """Transition Module of rec_densenet"""
    def __init__(self, n_channels, out_channels, use_dropout):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv1 = nn.Conv2d(
            n_channels,
            out_channels,
            kernel_size=1,
            has_bias=False
        )
        self.use_dropout = use_dropout
        self.dropout = nn.Dropout(p=0.2)

    def construct(self, x):
        out = ops.relu(self.bn1(self.conv1(x)))

        if self.use_dropout:
            out = self.dropout(out)
        # Halve the spatial resolution between dense blocks.
        out = ops.avg_pool2d(out, 2, stride=2, ceil_mode=True)
        return out


@register_backbone_class
class DenseNet(nn.Cell):
    r"""Customized DenseNet backbone for Handwritten Mathematical
    Expression Recognition (HMER).
    For example, in the CAN recognition algorithm it is used for
    feature extraction to obtain a feature map of the formula image.
    The network is based on the paper
    `"When Counting Meets HMER: Counting-Aware Network for
    Handwritten Mathematical Expression Recognition"
    <https://arxiv.org/abs/2207.11463>`_.

    Args:
        growth_rate (int): growth rate of DenseNet. Default: 24.
        reduction (float): compression ratio applied in the transition layers. Default: 0.5.
        bottleneck (bool): whether to use bottleneck layers. Default: True.
        use_dropout (bool): whether to use dropout. Default: True.
        input_channels (int): number of channels in the input image. Default: 3.
    Return:
        nn.Cell for backbone module

    Example:
        >>> # init a DenseNet network
        >>> params = {
        ...     'growth_rate': 24,
        ...     'reduction': 0.5,
        ...     'bottleneck': True,
        ...     'use_dropout': True,
        ...     'input_channels': 3,
        ... }
        >>> model = DenseNet(**params)
    """
    def __init__(
        self,
        growth_rate=24,
        reduction=0.5,
        bottleneck=True,
        use_dropout=True,
        input_channels=3,
    ):
        super().__init__()
        n_dense_blocks = 16
        n_channels = 2 * growth_rate

        self.conv1 = nn.Conv2d(
            input_channels,
            n_channels,
            kernel_size=7,
            stride=2,
            has_bias=False,
            pad_mode='pad',
            padding=3,
        )
        self.dense1 = self.make_dense(
            n_channels, growth_rate, n_dense_blocks, bottleneck, use_dropout
        )
        # Each dense layer adds growth_rate channels to its input.
        n_channels += n_dense_blocks * growth_rate
        out_channels = int(math.floor(n_channels * reduction))
        self.trans1 = Transition(n_channels, out_channels, use_dropout)

        n_channels = out_channels
        self.dense2 = self.make_dense(
            n_channels, growth_rate, n_dense_blocks, bottleneck, use_dropout
        )
        n_channels += n_dense_blocks * growth_rate
        out_channels = int(math.floor(n_channels * reduction))
        self.trans2 = Transition(n_channels, out_channels, use_dropout)

        n_channels = out_channels
        self.dense3 = self.make_dense(
            n_channels, growth_rate, n_dense_blocks, bottleneck, use_dropout
        )
        n_channels += n_dense_blocks * growth_rate
        self.out_channels = [n_channels]

    def construct(self, x):
        out = self.conv1(x)
        out = ops.relu(out)
        out = ops.max_pool2d(out, 2, ceil_mode=True)
        out = self.dense1(out)
        out = self.trans1(out)
        out = self.dense2(out)
        out = self.trans2(out)
        out = self.dense3(out)
        return out

    def make_dense(self, n_channels, growth_rate, n_dense_blocks, bottleneck, use_dropout):
        """Create dense_layer of DenseNet"""
        layers = []
        layer_constructor = Bottleneck if bottleneck else SingleLayer
        for _ in range(int(n_dense_blocks)):
            layers.append(layer_constructor(n_channels, growth_rate, use_dropout))
            n_channels += growth_rate
        return nn.SequentialCell(*layers)


@register_backbone
def rec_can_densenet(pretrained: bool = False, **kwargs) -> DenseNet:
    """Create a rec_can_densenet backbone model."""
    if pretrained:
        raise NotImplementedError(
            "The default pretrained checkpoint for `rec_can_densenet` backbone does not exist."
        )

    model = DenseNet(**kwargs)
    return model
2 changes: 2 additions & 0 deletions mindocr/models/heads/builder.py
@@ -18,6 +18,7 @@
'YOLOv8Head',
'MultiHead',
'TableMasterHead',
'CANHead',
]
from .cls_head import MobileNetV3Head
from .conv_head import ConvHead
@@ -36,6 +37,7 @@
from .rec_visionlan_head import VisionLANHead
from .table_master_head import TableMasterHead
from .yolov8_head import YOLOv8Head
from .rec_can_head import CANHead


def build_head(head_name, **kwargs):