世界人工智能创新大赛AIWIN手写字体OCR识别竞赛任务一baseline方案(基于paddle的实现)

本项目使用飞桨实现世界人工智能创新大赛AIWIN【手写字体OCR识别竞赛】任务的baseline方案,欢迎小伙伴来fork训练及调优,AI Studio提供高级算力资源(Tesla V100)。

一、竞赛介绍

2021世界人工智能创新大赛(AIWIN),由世界人工智能大会组委会主办,AI SPACE承办,是全球范围内初具影响力的人工智能赛事,是2021世界人工智能大会的重要组成部分。

秋季赛将继续围绕“人工智能助力城市数字化转型”的主题,以“开展算法创新、选拔数字人才”为目标,继续秉持“高端化、专业化、国际化、市场化“的原则开展赛事。

今年提供手写字体OCR识别竞赛和心电智能诊断算法竞赛两个赛题。

我们选取【手写字体OCR识别竞赛】任务一进行实验,接下来对赛题背景及任务进行简单介绍。

1.1 赛题背景

银行日常业务中涉及到各类凭证的识别录入,例如身份证录入、支票录入、对账单录入等。以往的录入方式主要是以人工录入为主,效率较低,人力成本较高。近几年来,OCR相关技术以其自动执行、人为干预较少等特点正逐步替代传统的人工录入方式。但OCR技术在实际应用中也存在一些问题,在各类凭证字段的识别中,手写体由于其字体差异性大、字数不固定、语义关联性较低、凭证背景干扰等原因,导致OCR识别率准确率不高,需要大量人工校正,对日常的银行录入业务造成了一定的影响。

1.2 赛题任务

本次赛题将提供手写体图像切片数据集,数据集从真实业务场景中,经过切片脱敏得到,参赛队伍通过识别技术,获得对应的识别结果。即:

  • 输入:手写体图像切片数据集

  • 输出:对应的识别结果

赛题在赛程中分设为两个独立任务,各自设定不同条件的训练集、测试集和建模环境,概述如下:

  • 任务一(本项目选取):提供开放可下载的训练集及测试集,允许线下建模或线上提供 Notebook 环境及 Terminal 容器环境(脱网)建模,输出识别结果完成赛题。

  • 任务二:提供不可下载的训练集,要求线上通过 Terminal 容器环境(脱网)建模后提交模型,由系统输入测试集(即对选手不可见),输出识别结果完成赛题

二、数据处理

2.1 数据下载

大赛使用数据要求如下"参赛人员不得对外以任何形式转载、发布赛题的训练集、验证集的全部或任意部分",因此需要大家自行去官网下载数据集

注:数据量8000,且均是文字区域,下载速度很快。

# 新建文件夹【dataset】
!mkdir dataset

将下载的数据集上传到【dataset】文件夹内,操作流程如下图所示:

然后解压数据集:

!unzip -q dataset/2021A_T1_Task1_Sample_V1106.zip -d ./dataset/
!unzip -q dataset/2021A_T1_Task1_数据集.zip -d ./dataset/

2.2 数据格式

下载的数据标注为json格式且图片在两个文件夹内,我们需要处理为PaddleOCR训练所需要的格式:建议将训练图片放入同一个文件夹,并用一个txt文件(rec_gt_train.txt)记录图片路径和标签,txt文件里的内容如下:

注意: txt文件中默认请将图片路径和图片标签用 \t 分割,如用其他方式分割将造成训练报错。

" 图像文件名                 图像标注信息 "

train_data/rec/train/word_001.jpg   简单可依赖
train_data/rec/train/word_002.jpg   用科技让复杂的世界更简单
...

最终训练集应有如下文件结构:

|-train_data
  |- rec_gt_train.txt
  |- train
    |- 8bb1941c760a2c1d017626c361da6c4d.jpg
    |- 8bb1941c760a2c1d01762b943a624421.jpg
    |- 8bb1941c760a2c1d0176415a9ec807fe.jpg
    | ...

接下来,我们一起看怎么用代码具体实现吧~

import os
import os.path as osp
import json
import shutil
import yaml

定义write_file函数,处理训练集中date和amount中的数据:

def write_file(file, json_file, save_pic):
    # 读取json文件
    data = yaml.load(open(json_file))

    # all_str为了后面统计训练集的字典
    all_str = ''
    
    for pic_name, label_info in data.items():
        # 修改成OCR需要的格式
        line = os.path.join(save_pic, pic_name)+'\t'+label_info+'\n'
        file.write(line)

        all_str+=label_info

        # 将图片移动到save_pic目录下
        ori_path = osp.join(osp.dirname(json_file), 'images', pic_name)
        save_path = osp.join(save_pic, pic_name)
        shutil.copy(ori_path, save_path)

    return set(all_str)
# 处理数据之后的保存路径
!mkdir 'train_data'
# 记录图片和标签的txt
save_txt = 'train_data/rec_gt_train.txt'
# 所有图片放在一个文件夹内
save_pic = 'train_data/train/'
if not os.path.exists(save_pic):
    os.mkdir(save_pic)

# 读取date和amount的json文件
date_json = 'dataset/训练集/date/gt.json'
amount_json = 'dataset/训练集/amount/gt.json'

file = open(save_txt, 'w')

date_set = write_file(file, date_json, save_pic)
amount_set = write_file(file, amount_json, save_pic)

file.close()

处理测试集,将所有图片放在一个文件夹内:

!mkdir /home/aistudio/test_data/
!cp -r /home/aistudio/dataset/测试集/amount/images/* /home/aistudio/test_data/
!cp -r /home/aistudio/dataset/测试集/date/images/* /home/aistudio/test_data/

2.3 字典

最后需要提供一个字典({rec_gt_label}.txt),使模型在训练时,可以将所有出现的字符映射为字典的索引。

因此字典需要包含所有希望被正确识别的字符,{rec_gt_label}.txt需要写成如下格式,并以 utf-8 编码格式保存:

l
d
a
d
r
n
character_dict_path = 'train_data/rec_gt_label.txt'
with open(character_dict_path, 'w', encoding='utf-8') as out_file:
    merge_set = date_set|amount_set
    num_class = len(merge_set)
    print('num_class:',num_class)
    for label in merge_set:
        line = label+'\n'
        out_file.write(line)
num_class: 21

三、模型构建

3.1 识别算法

PaddleOCR中提供了如下文本识别算法列表,以及每个算法在英文公开数据集上的模型和指标,主要用于算法简介和算法性能对比。

文本识别算法:

模型骨干网络Avg Accuracy模型存储命名下载链接
RosettaResnet34_vd80.24%rec_r34_vd_none_none_ctc下载链接
RosettaMobileNetV378.16%rec_mv3_none_none_ctc下载链接
CRNNResnet34_vd82.20%rec_r34_vd_none_bilstm_ctc下载链接
CRNNMobileNetV379.37%rec_mv3_none_bilstm_ctc下载链接
STAR-NetResnet34_vd83.93%rec_r34_vd_tps_bilstm_ctc下载链接
STAR-NetMobileNetV381.56%rec_mv3_tps_bilstm_ctc下载链接
RAREResnet34_vd84.90%rec_r34_vd_tps_bilstm_attn下载链接
RAREMobileNetV383.32%rec_mv3_tps_bilstm_attn下载链接
SRNResnet50_vd_fpn88.33%rec_r50fpn_vd_none_srn下载链接

3.2 安装PaddleOCR

本项目中已经帮大家安装好了最新版的PaddleOCR,且修改好配置文件、后处理代码,无需安装~

如仍需安装or安装更新,可以执行以下步骤(目前支持Clone GitHub【推荐】和Gitee两种方式):

注:码云托管代码可能无法实时同步本github项目更新,存在3~5天延时,请优先使用推荐方式。

!git clone https://github.com/PaddlePaddle/PaddleOCR
# 如果因为网络问题无法pull成功,也可选择使用码云上的托管:
# !git clone https://gitee.com/paddlepaddle/PaddleOCR
# 安装依赖,每次启动项目都需要执行
%cd PaddleOCR
!pip install --upgrade pip
!pip install -r requirements.txt

3.3 下载预训练模型

首先下载模型backbone的pretrain model,您可以根据需求使用PaddleClas中的模型更换backbone,
对应的backbone预训练模型可以从PaddleClas repo 主页中找到下载链接

# 下载MobileNetV3的预训练模型
!wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_bilstm_ctc_v2.0_train.tar
# 解压模型参数
%cd pretrain_models
!tar -xf rec_mv3_none_bilstm_ctc_v2.0_train.tar && rm -rf rec_mv3_none_bilstm_ctc_v2.0_train.tar
%cd ..

3.4 模型训练

这里选择CRNN模型进行训练、MobileNetv3作为backbone,可以在configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml文件里修改训练配置:比如是否使用GPU、模型保存路径、数据集路径、学习率、优化等。

执行命令,启动训练:

!python tools/train.py -c configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml

四、模型预测

训练好模型之后,即可启动测试,Global.pretrained_model表示预测使用的模型,Global.infer_img表示测试的图片路径或着测试图片文件夹路径:

# 预测中文结果
!python3 tools/infer_rec.py -c configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml -o Global.pretrained_model=output/rec_chinese_lite_v2.0/latest Global.load_static_weights=false Global.infer_img=/home/aistudio/dataset/测试集/amount/images/8bb1941c760a2c1d017626c361da6c4d.jpg

同时修改infer_rec.py将结果保存为比赛要求的格式,保存结果的路径由configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.ymlsave_res_path参数控制,结果answer.json效果如下图:

接下来,大家可以尝试不同的优化策略:修改参数(学习率、epoch、优化器、warmup等)、增加数据增强、更换模型等,开动起来啦。快来fork,一键三连!

参考资源

飞桨PaddleEdu技术交流群(QQ)

目前QQ群已有2000+同学一起学习,欢迎扫码加入

Logo

学大模型,用大模型上飞桨星河社区!每天8点V100G算力免费领!免费领取ERNIE 4.0 100w Token >>>

更多推荐