论文复现第六期-OCR识别算法-RobustScanner
转自AI Studio,原文链接:论文复现第六期-OCR识别算法-RobustScanner - 飞桨AI Studio一、简介近年来,针对任意形状文本识别的方法占据了主导地位。这些方法大致可以分为基于纠错的方法、基于分割的方法和基于注意力的编解码器方法。Rectification-based approaches.试图在识别之前将不规则的图像纠正为常规图像。以STN为基础,出现了一系列基于可学习
转自AI Studio,原文链接:论文复现第六期-OCR识别算法-RobustScanner - 飞桨AI Studio
一、简介
近年来,针对任意形状文本识别的方法占据了主导地位。这些方法大致可以分为基于纠错的方法、基于分割的方法和基于注意力的编解码器方法。
- Rectification-based approaches. 试图在识别之前将不规则的图像纠正为常规图像。以STN为基础,出现了一系列基于可学习TPS变换的方法。例如RARE、ASTER、STAR-NET、S-cRN、ESIR等。char-net先检测单个字符,然后对其进行单独纠错。
- Segmentation-based approaches. 对每个字符进行单独的分割,以避免不规则布局的问题。例如采用全卷积网络对单个字符进行检测和识别,然后进行字符组合。然而字符级别的标注并不容易获取。可以采用半监督的方法,先在合成数据集上进行字符检测的学习,然后迁移到真实数据上。所有这些基于分词的方法都不能简单地扩展到具有巨大字符词典的文本识别,例如中文识别,因为它们为每个字符维护一个概率热图。
- Encoder-decoder with attention-based approaches. 现有的大多数不规则文本识别方法都使用具有注意机制的编解码器框架。当然也出现了很多变体。FAN引入了一个焦点网络来抑制注意力漂移问题。SAR使用了定制的2D attention机制,效果显著。也有一些工作引入了transformer结构来捕获长时上下文。DAN采用解耦注意力来缓解对齐漂移问题。
本文也是具有注意力机制的编解码器框架,作者通过对encoder-decoder解码方式的研究,发现在解码过程中,不仅依赖语义信息,还依赖位置信息。而当前的大多数方法在解码过程中都过于依赖语义信息,导致存在较为严重的注意力偏移问题,对于没有语义信息或者弱语义信息的文本识别效果不佳。因此作者认为面对没有多少语义信息的文本串,有必要尝试增加位置信息,增加字符定位的准确性。 论文在规则和不规则文本识别基准测试上取得了当时最先进的结果,在无上下文基准测试上没有太大的性能下降,从而验证了其在上下文和无上下文应用程序场景中的健壮性。其网络结构如下图所示
具体细节请阅读论文 RobustScanner: Dynamically Enhancing Positional Clues for Robust Text Recognition
二、实验数据及精度
基于paddlepaddle深度学习框架,参考mmocr对文献算法进行复现。在此非常感谢mmocr,提高了本repo复现论文的效率。 采用的训练数据集和测试数据集如下表所示,详情可参考mmocr文档
训练集 | instance num | repeat num |
---|---|---|
icdar2013 | 848 | 20 |
icdar2015 | 4468 | 20 |
coco_text | 42142 | 20 |
IIIT5K | 2000 | 20 |
SynthText | 2400000 | 1 |
SynthAdd | 1216889 | 1 |
Syn90k | 2400000 | 1 |
注:SynthText和Syn90k均是随机挑选240w个样本。
测试集 | instance num | type |
---|---|---|
IIT5K | 3000 | 规则文本 |
SVT | 647 | 规则文本 |
IC13 | 1015 | 规则文本 |
IC15 | 2077 | 不规则文本 |
SVTP | 645 | 不规则文本 |
CT80 | 288 | 不规则文本 |
本项目达到的测试精度,如下表所示。
数据集 | IIIT5K | SVT | IC13 | IC15 | SVTP | CT80 | Avg | |
---|---|---|---|---|---|---|---|---|
论文 | MJ(891W) + ST(726W) + Real | 95.4 | 89.3 | 94.1 | 79.2 | 82.9 | 92.4 | 88.88 |
参考 | MJ(240W) + ST(240W) + SynthAdd(121W) + Real | 95.1 | 89.2 | 93.1 | 77.8 | 80.3 | 90.3 | 87.63 |
复现 | MJ(240W) + ST(240W) + SynthAdd(121W) + Real | 95.6 | 90.4 | 93.2 | 77.2 | 81.7 | 88.5 | 87.77 |
三、准备数据与环境
3.1 数据准备
使用的数据集已在AIStudio上公开,地址如下
训练集: 真实数据由ICDAR2013, ICDAR2015, IIIT5K, COCO-Text的训练集组成。 合成数据由Synth90K(240W), SynthAdd(121W), Synth800K(240W), synthadd组成 测试集:包含规则文本(IIIT5K、SVT、ICDAR2013)和不规则文本(ICDAR2015、SVTP、CUTE80)组成
为方便存储,所有数据都已经打包成lmdb格式。
注:由于数据过大,项目加载较为耗时,因此,暂时不放入项目中。若使用时,可修改项目,将数据集添加进来。
3.2 解压数据集
假设已经将数据加载到项目中,采用以下命令解压数据集
In [ ]
# 解压合成数据集
!mkdir /home/aistudio/data/train_data/synth_data
!unzip /home/aistudio/data/data138433/synth90K_shuffle.zip -d /home/aistudio/data/train_data/synth_data
!unzip /home/aistudio/data/data138433/SynthAdd.zip -d /home/aistudio/data/train_data/synth_data
!unzip /home/aistudio/data/data138433/SynthText800K_shuffle_1_40.zip -d /home/aistudio/data/train_data/synth_data
!unzip /home/aistudio/data/data138433/SynthText800K_shuffle_41_80.zip -d /home/aistudio/data/train_data/synth_data
!unzip /home/aistudio/data/data138433/SynthText800K_shuffle_81_160.zip -d /home/aistudio/data/train_data/synth_data
!unzip /home/aistudio/data/data138433/SynthText800K_shuffle_141_160.zip -d /home/aistudio/data/train_data/synth_data
!unzip /home/aistudio/data/data138433/SynthText800K_shuffle_161_200.zip -d /home/aistudio/data/train_data/synth_data
# 解压真实数据集(重复20次)
!mkdir /home/aistudio/data/train_data/real_data/repeat1
unzip /home/aistudio/data/data138433/training_lmdb_real.zip -d /home/aistudio/data/train_data/real_data/repeat1
!mkdir /home/aistudio/data/train_data/real_data/repeat2
unzip /home/aistudio/data/data138433/training_lmdb_real.zip -d /home/aistudio/data/train_data/real_data/repeat2
!mkdir /home/aistudio/data/train_data/real_data/repeat3
unzip /home/aistudio/data/data138433/training_lmdb_real.zip -d /home/aistudio/data/train_data/real_data/repeat3
!mkdir /home/aistudio/data/train_data/real_data/repeat4
unzip /home/aistudio/data/data138433/training_lmdb_real.zip -d /home/aistudio/data/train_data/real_data/repeat4
!mkdir /home/aistudio/data/train_data/real_data/repeat5
unzip /home/aistudio/data/data138433/training_lmdb_real.zip -d /home/aistudio/data/train_data/real_data/repeat5
!mkdir /home/aistudio/data/train_data/real_data/repeat6
unzip /home/aistudio/data/data138433/training_lmdb_real.zip -d /home/aistudio/data/train_data/real_data/repeat6
!mkdir /home/aistudio/data/train_data/real_data/repeat7
unzip /home/aistudio/data/data138433/training_lmdb_real.zip -d /home/aistudio/data/train_data/real_data/repeat7
!mkdir /home/aistudio/data/train_data/real_data/repeat8
unzip /home/aistudio/data/data138433/training_lmdb_real.zip -d /home/aistudio/data/train_data/real_data/repeat8
!mkdir /home/aistudio/data/train_data/real_data/repeat9
unzip /home/aistudio/data/data138433/training_lmdb_real.zip -d /home/aistudio/data/train_data/real_data/repeat9
!mkdir /home/aistudio/data/train_data/real_data/repeat10
unzip /home/aistudio/data/data138433/training_lmdb_real.zip -d /home/aistudio/data/train_data/real_data/repeat10
!mkdir /home/aistudio/data/train_data/real_data/repeat11
unzip /home/aistudio/data/data138433/training_lmdb_real.zip -d /home/aistudio/data/train_data/real_data/repeat11
!mkdir /home/aistudio/data/train_data/real_data/repeat12
unzip /home/aistudio/data/data138433/training_lmdb_real.zip -d /home/aistudio/data/train_data/real_data/repeat12
!mkdir /home/aistudio/data/train_data/real_data/repeat13
unzip /home/aistudio/data/data138433/training_lmdb_real.zip -d /home/aistudio/data/train_data/real_data/repeat13
!mkdir /home/aistudio/data/train_data/real_data/repeat14
unzip /home/aistudio/data/data138433/training_lmdb_real.zip -d /home/aistudio/data/train_data/real_data/repeat14
!mkdir /home/aistudio/data/train_data/real_data/repeat15
unzip /home/aistudio/data/data138433/training_lmdb_real.zip -d /home/aistudio/data/train_data/real_data/repeat15
!mkdir /home/aistudio/data/train_data/real_data/repeat16
unzip /home/aistudio/data/data138433/training_lmdb_real.zip -d /home/aistudio/data/train_data/real_data/repeat16
!mkdir /home/aistudio/data/train_data/real_data/repeat17
unzip /home/aistudio/data/data138433/training_lmdb_real.zip -d /home/aistudio/data/train_data/real_data/repeat17
!mkdir /home/aistudio/data/train_data/real_data/repeat18
unzip /home/aistudio/data/data138433/training_lmdb_real.zip -d /home/aistudio/data/train_data/real_data/repeat18
!mkdir /home/aistudio/data/train_data/real_data/repeat19
unzip /home/aistudio/data/data138433/training_lmdb_real.zip -d /home/aistudio/data/train_data/real_data/repeat19
!mkdir /home/aistudio/data/train_data/real_data/repeat20
unzip /home/aistudio/data/data138433/training_lmdb_real.zip -d /home/aistudio/data/train_data/real_data/repeat20
3.3 获取项目代码
In [ ]
!git clone https://github.com/smilelite/RobustScanner.paddle
3.4 准备环境
- 框架:
- PaddlePaddle == 2.2.2
- 安装方式 直接使用pip进行安装
pip install paddlepaddle-gpu
paddlepaddle安装成功后,使用pip install -r requirements.txt安装依赖。 具体环境配置可参考ppocr
四、开始使用
本复现基于PaddleOCR框架,需要进行部分修改,主要是加入RobustScanner数据读取方式,backbone, RobustScanner_head,以及在训练和评估脚本中加入RobustScanner字段。
- 数据读取 主要是在./ppocr/data/imaug/rec_img_aug.py中加入了
RobustScannerRecResizeImg
- backbone 复用./ppocr/modeling/backbones/rec_resnet_31.py,在参数初始化上做了一些修改,这里没有新建一个rec_resnet_31.py,需要注意。
- RobustScanner_head 见./ppocr/modeling/heads/rec_robustscanner_head.py
- loss 复用SARLoss
整体训练流程与PaddleOCR一致,可参考PaddleOCR的流程,下面进行简述。
4.1 启动训练
开始训练时需要修改configs/rec/rec_r31_robustscanner.yml文件,主要是dataset设置
In [ ]
#单卡训练(训练周期长,不建议)
# python3 tools/train.py -c configs/rec/rec_r31_robustscanner.yml
#多卡训练,通过--gpus参数指定卡号
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_r31_robustscanner.yml
根据配置文件中设置的的 save_model_dir
和 save_epoch_step
字段,会有以下几种参数被保存下来:
output/rec/rec_r31_robustscanner/
├── best_accuracy.pdopt
├── best_accuracy.pdparams
├── best_accuracy.states
├── config.yml
├── iter_epoch_3.pdopt
├── iter_epoch_3.pdparams
├── iter_epoch_3.states
├── latest.pdopt
├── latest.pdparams
├── latest.states
└── train.log
其中 best_accuracy.* 是评估集上的最优模型;iter_epoch_x.* 是以 save_epoch_step
为间隔保存下来的模型;latest.* 是最后一个epoch的模型。
本项目训练好的模型权重及训练日志地址为(链接:https://pan.baidu.com/s/1IXVRqRSuGQFouAMLrgNOXA 提取码:no4x)。
4.2 模型评估
评估数据集可以通过 configs/rec/rec_r31_robustscanner.yml 修改Eval中的 data_dir 设置。
In [ ]
# GPU 评估, Global.checkpoints 为待测权重
python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_r31_robustscanner.yml -o Global.checkpoints=./output/rec/rec_r31_robustscanner/best_accuracy
4.3 模型预测
默认预测图片存储在配置文件的 infer_img
字段里,通过Global.infer_img
进行修改。通过 -o Global.checkpoints
加载训练好的参数文件:
In [ ]
python3 tools/infer_rec.py -c configs/rec/rec_r31_robustscanner.yml -o Global.pretrained_model=./output/rec/rec_r31_robustscanner/best_accuracy Global.load_static_weights=false Global.infer_img=./inference/rec_inference/word_1.png
五、复现问题及心得
- 本项目复现时主要的问题时数据集的对齐,需要按照参考重新整理数据。
- 在复现过程中,存在精度不足的问题,主要原因在于主干网络的参数初始化方式存在差异。在paddlepaddle中,卷积层默认使用Normal初始化,而参考代码中采用kaiming初始化。其次batchnorm层的初始化也需要改成Uniform。下面给出两种不同初始化方式的对比效果(模型一为默认初始化,模型二为修改之后的初始化)
数据集 | IIIT5K | SVT | IC13 | IC15 | SVTP | CT80 | Avg | |
---|---|---|---|---|---|---|---|---|
模型一 | MJ(240W) + ST(240W) + SynthAdd(121W) + Real | 95.1 | 88.4 | 93.0 | 77.9 | 81.1 | 89.9 | 87.4 |
模型二 | MJ(240W) + ST(240W) + SynthAdd(121W) + Real | 95.6 | 90.4 | 93.2 | 77.2 | 81.7 | 88.5 | 87.77 |
- 模型动转静的过程中,需要传入vaid_ratio以产生mask,维护图片的有效长度。
信息 | 说明 |
---|---|
发布者 | smartlite |
时间 | 2022.05 |
框架版本 | Paddle 2.1.2 |
应用场景 | OCR识别 |
支持硬件 | GPU、CPU |
github | https://github.com/smilelite/RobustScanner.paddle |
更多推荐
所有评论(0)