
I. Project Background

Hanyu Pinyin is a key part of early primary-school education in China, so recognizing handwritten pinyin has considerable research value. Manual recognition of handwritten pinyin is slow and error-prone, which makes marking primary-school test papers difficult; it can no longer keep up with demand. Handwritten pinyin recognition therefore needs to be digitized and automated, using AI techniques to drive the work forward. Automated pinyin recognition can help the education sector speed up homework grading, reduce manual steps, and improve efficiency.

II. Project Overview

This project studies and implements a deep-learning approach to handwritten Hanyu Pinyin recognition. It is built on the PaddleOCR framework and uses the mainstream CRNN+CTC text-recognition algorithm. The workflow covers dataset collection and annotation, algorithm construction, model training, and prediction and evaluation.

III. Dataset Description

Self-built dataset: the handwritten pinyin dataset used in this project was created by randomly generating 500 different pinyin strings and distributing them among as many writers as possible, who copied them onto A4 paper, so that the handwriting styles are diverse. The 500 handwritten samples were then photographed and manually labeled. The dataset could also be generated with a Python script, but automatically generated strings may contain invalid syllables unless they are filtered by rules, so this project builds the dataset from manually photographed samples. Due to time constraints only 500 handwritten pinyin images were produced; more data can be added later to improve the model's generalization.
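For reference, here is a minimal sketch of the script-based alternative mentioned above: it samples real words, converts them to tone-marked pinyin with the pypinyin package, and renders each label to an image with Pillow. The word list, font path, and output directory are illustrative assumptions, not part of the original project.

import os
import random
from pypinyin import lazy_pinyin, Style
from PIL import Image, ImageDraw, ImageFont

words = ["电脑", "培养", "恐怕", "回收站", "过瘾"]   # replace with a larger vocabulary
out_dir = "./gen_data"
os.makedirs(out_dir, exist_ok=True)
font = ImageFont.truetype("DejaVuSans.ttf", 40)      # any font that covers tone marks

with open(os.path.join(out_dir, "labels.txt"), "w", encoding="utf-8") as f:
    for word in random.sample(words, k=len(words)):
        label = "".join(lazy_pinyin(word, style=Style.TONE))   # e.g. "diànnǎo"
        img = Image.new("RGB", (320, 64), "white")
        ImageDraw.Draw(img).text((10, 8), label, fill="black", font=font)
        img_path = os.path.join(out_dir, label + ".png")
        img.save(img_path)
        f.write(img_path + "\t" + label + "\n")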

1. Dataset Annotation

The collected handwritten pinyin images are first renamed and labeled to complete the initial dataset used for model training. The dataset is split 9:1: of the 500 handwritten pinyin images, 50 are randomly selected as the test set to measure the recognition accuracy of the constructed algorithm, and the remaining 450 images form the training set used to train the network. A preview of the final dataset is shown in the figure.

2. Dataset Format

image path    label (note: the two fields are separated by '\t'; write the paths according to your own setup)

D:\Python\PycharmProjects\CRNN\ATT_PY\data\bàocháng.png bàocháng
D:\Python\PycharmProjects\CRNN\ATT_PY\data\diànnǎo.png diànnǎo
D:\Python\PycharmProjects\CRNN\ATT_PY\data\guòyǐn.png guòyǐn
D:\Python\PycharmProjects\CRNN\ATT_PY\data\húshōuzhàn.png húshōuzhàn
D:\Python\PycharmProjects\CRNN\ATT_PY\data\kǒngpà.png kǒngpà
D:\Python\PycharmProjects\CRNN\ATT_PY\data\péiyǎng.png péiyǎng
D:\Python\PycharmProjects\CRNN\ATT_PY\data\qǔdāo.png qǔdāo
D:\Python\PycharmProjects\CRNN\ATT_PY\data\tiānnèi.png tiānnèi


Note: the final label files are plain txt. The test and training label files are at /home/aistudio/test502.txt and /home/aistudio/train502.txt and can be used directly.
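A quick sanity check (an optional sketch, not part of the original scripts) can confirm that every line is '\t'-separated and that every referenced image exists; it assumes the paths stored in the label files are relative to the data_dir used later in the training config.

import os

data_dir = "/home/aistudio/data502/data502"   # data_dir from the training config
for label_file in ["/home/aistudio/train502.txt", "/home/aistudio/test502.txt"]:
    with open(label_file, encoding="utf-8") as f:
        for i, line in enumerate(f, 1):
            parts = line.rstrip("\n").split("\t")
            assert len(parts) == 2, f"{label_file} line {i}: not tab-separated"
            img_path = os.path.join(data_dir, parts[0])
            assert os.path.exists(img_path), f"missing image: {img_path}"
    print(label_file, "OK")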

IV. Data Processing and Environment Setup

1. Install the Environment

# Install PaddleOCR
!git clone https://gitee.com/paddlepaddle/PaddleOCR  
%cd PaddleOCR
!git checkout -b release/2.4 remotes/origin/release/2.4
# Install the dependencies
!pip install -r requirements.txt

2. Download the PaddleOCR Pretrained Model

The pretrained model is the English/number lightweight recognition model (en_number_mobile_v2.0_rec_slim) from the PaddleOCR model zoo:

# Download the pretrained model
!wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/en_number_mobile_v2.0_rec_slim_train.tar 
!tar -xf /home/aistudio/PaddleOCR/pretrain_models/en_number_mobile_v2.0_rec_slim_train.tar -C /home/aistudio/PaddleOCR/pretrain_models

3. Dataset Processing

3.1 Unzip the dataset
!unzip /home/aistudio/data/data191980/py_data.zip -d /home/aistudio/data502/
3.2 Format conversion (image file names to a txt label file)
import os

img_dir = r'./data502'
# Each output line is "<image path>\t<label>"; the label is the file name
# without its extension (the file name is the pinyin itself).
f = open('./txt/all_picnum502.txt', 'wt', encoding='utf-8')

for filename in os.listdir(img_dir):
    name, ext = os.path.splitext(filename)
    if ext == '.png':
        out_line = img_dir + '/' + filename + '\t' + name
        print(out_line)
        f.write(out_line + '\n')

f.close()


An additional script, create_lmdb.py, is provided to convert the txt label file to LMDB, which is useful for large-scale datasets; PaddleOCR also supports LMDB-format data.
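As a rough illustration of what such a conversion does (this is a sketch, not the create_lmdb.py script itself), the snippet below writes the 'num-samples', 'image-%09d', and 'label-%09d' keys that PaddleOCR's LMDBDataSet reader expects; the paths are illustrative.

import os
import lmdb

def txt_to_lmdb(label_file, lmdb_path):
    os.makedirs(lmdb_path, exist_ok=True)
    env = lmdb.open(lmdb_path, map_size=1 << 30)          # up to ~1 GB
    with env.begin(write=True) as txn, open(label_file, encoding="utf-8") as f:
        count = 0
        for line in f:
            img_path, label = line.rstrip("\n").split("\t")
            with open(img_path, "rb") as img_f:
                img_bytes = img_f.read()                   # store the raw encoded image
            count += 1
            txn.put(("image-%09d" % count).encode(), img_bytes)
            txn.put(("label-%09d" % count).encode(), label.encode("utf-8"))
        txn.put(b"num-samples", str(count).encode())
    env.close()

txt_to_lmdb("./txt/train502.txt", "./lmdb/train502")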

!python pic_to_txt_num.py
3.3 Split the dataset
file_name = './txt/all_pic502.txt'
with open('./txt/train502.txt', 'w') as train_txt:
    with open('./txt/test502.txt', 'w') as test_txt:
        # Read the label file line by line; each line is "<path>\t<label>"
        with open(file_name, 'r') as file_txt:
            count = len(file_txt.readlines())   # total number of lines
            print(count)
            file_txt.seek(0)  # go back to the start of the file
            for i in range(count):
                line_data = file_txt.readline()
                if i % 10 == 0:   # every 10th sample goes to the test set (9:1 split)
                    test_txt.write(line_data)
                else:
                    train_txt.write(line_data)
!python split_data.py
# Copy the label files to the paths referenced by the config file; adjust as needed
!cp /home/aistudio/data502/txt/train502.txt /home/aistudio/
!cp /home/aistudio/data502/txt/test502.txt /home/aistudio/
# These may need to be reinstalled each time the environment is restarted
!pip install imgaug
!pip install Levenshtein
!pip install pyclipper
!pip install lmdb

V. Model Selection

This project uses the open-source PaddleOCR project for training and recognition.

1. PaddleOCRv2

CRNN algorithm:

PaddleOCRv2 uses the classic CRNN+CTC algorithm for recognition and covers model construction, training, evaluation, and prediction. The config file (training, data loading, and evaluation parameters) can be edited by hand; by default the Adam optimizer and the CTC loss function are used.

Network structure:

The CRNN network consists of three parts, from bottom to top:

(1) Convolutional layers, which extract a feature sequence from the input image.

(2) Recurrent layers, which predict a label distribution for each frame of the feature sequence produced by the convolutional layers.

(3) A transcription layer, which converts the per-frame label distributions into the final recognition result by collapsing repeats and removing blanks.

CRNN model training:

During training, a standard CNN first extracts features from the text image; a BLSTM then fuses the feature vectors to capture the contextual features of the character sequence and produces a probability distribution for each feature column; finally, the transcription layer (CTC) converts these distributions into the predicted text sequence.

The training pipeline is as follows (see the shape walk-through sketch after the list):

1. Resize the input image to 32 x W x 3 (height 32, width W, 3 channels).

2. Extract convolutional features with a CNN, producing a feature map of size 1 x (W/4) x 512.

3. Feed these features into the LSTM to extract sequence features, yielding a (W/4) x n posterior probability matrix, where n is the size of the character set.

4. Train with the CTC loss, which aligns the network outputs with the label sequence.
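The following sketch walks through these shapes with a toy backbone; it is not the actual MobileNetV3/ResNet used by PaddleOCR, and the batch size, character-set size, and hidden size are illustrative.

import paddle
import paddle.nn as nn

N, W, n_classes, hidden = 4, 320, 63, 48            # batch, image width, charset size (incl. blank), RNN hidden size
x = paddle.randn([N, 3, 32, W])                     # 1. input resized to 3 x 32 x W

conv = nn.Sequential(                               # 2. toy CNN standing in for the backbone
    nn.Conv2D(3, 256, 3, stride=2, padding=1), nn.ReLU(),
    nn.Conv2D(256, 512, 3, stride=(2, 1), padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2D([1, W // 4]),              # collapse height to 1, width to W/4
)
feat = conv(x)                                      # [N, 512, 1, W/4]
seq = feat.squeeze(2).transpose([0, 2, 1])          # [N, W/4, 512]: a sequence of W/4 frames

rnn = nn.LSTM(512, hidden, direction="bidirect")    # 3. BiLSTM over the frame sequence
seq, _ = rnn(seq)                                   # [N, W/4, 2*hidden]
logits = nn.Linear(2 * hidden, n_classes)(seq)      # [N, W/4, n_classes] posterior matrix
probs = paddle.nn.functional.softmax(logits, axis=-1)
print(probs.shape)                                  # 4. these per-frame posteriors feed the CTC loss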


About the CTC loss:

CTC is a loss formulation that replaces a per-frame softmax loss, so training samples do not need to be aligned frame by frame. It introduces a blank symbol to handle positions where no character is present, and the loss and its gradient can be computed efficiently by recursion. A minimal decoding illustration follows below.
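A tiny illustration of the transcription step (a sketch, not PaddleOCR's CTCLabelDecode): greedy best-path decoding collapses repeated characters and drops the blank symbol. The toy character set and frame predictions below are made up.

import numpy as np

charset = ["-", "à", "d", "i", "n", "ǎ", "o"]    # index 0 is the CTC blank "-"

def ctc_greedy_decode(probs):
    """probs: [T, C] per-frame posteriors from the recognition head."""
    best = probs.argmax(axis=1)                  # best class per time step
    out, prev = [], 0
    for idx in best:
        if idx != 0 and idx != prev:             # drop blanks and collapsed repeats
            out.append(charset[idx])
        prev = idx
    return "".join(out)

# frames predicting "d, d, blank, i, à, n, blank, n, ǎ, o" -> "diànnǎo"
frames = [2, 2, 0, 3, 1, 4, 0, 4, 5, 6]
probs = np.eye(len(charset))[frames]             # one-hot toy posteriors
print(ctc_greedy_decode(probs))                  # diànnǎo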

2. PaddleOCRv3

PaddleOCRv3 uses the SVTR algorithm by default. SVTR proposes a single visual model for scene text recognition: its feature extractor is a ViT-like single visual model built on a patch-wise image tokenization framework, with Mixing Blocks that capture multi-granularity features. It abandons sequence modeling entirely, and while its accuracy remains competitive it has fewer parameters and runs faster.

(Figure: SVTR algorithm structure)

A PaddleOCRv3 version of this project will be added in future updates.

VI. Model Training

This project uses the PaddleOCRv2 pipeline with the CRNN+CTC algorithm, comparing the lightweight MobileNetV3 backbone against the classic ResNet backbone. The recognition config files are rec_en_number_lite_train.yml and rec_en_number_ResNet_train.yml; set their paths according to your own environment.

rec_en_number_lite_train.yml

Global:
  use_gpu: True
  epoch_num: 200
  log_smooth_window: 20
  print_batch_step: 10
  save_model_dir: ./output/rec_en_number_lite    # directory where checkpoints are saved
  save_epoch_step: 3
  # evaluation is run every 100 iterations, starting from iteration 0
  eval_batch_step: [0, 100]
  # if pretrained_model is saved in static mode, load_static_weights must be set to True
  cal_metric_during_train: True
  pretrained_model: ./pretrain_models/en_number_mobile_v2.0_rec_slim_train/best_accuracy  # path to the pretrained model
  checkpoints:  
  save_inference_dir:
  use_visualdl: True  # enable VisualDL visualization
  infer_img:
  # for data or label process
  character_dict_path: ppocr/utils/en_pydict.txt   # custom dictionary containing the pinyin character set
  max_text_length: 250
  infer_mode: False
  use_space_char: True


Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Cosine
    learning_rate: 0.005
  regularizer:
    name: 'L2'
    factor: 0.00001

Architecture:
  model_type: rec
  algorithm: CRNN
  Transform:
  Backbone:
    name: MobileNetV3
    scale: 0.5
    model_name: small
    small_stride: [1, 2, 2, 2]
  Neck:
    name: SequenceEncoder
    encoder_type: rnn
    hidden_size: 48
  Head:
    name: CTCHead
    fc_decay: 0.00001

Loss:
  name: CTCLoss

PostProcess:
  name: CTCLabelDecode

Metric:
  name: RecMetric
  main_indicator: acc

Train:
  dataset:
    name: SimpleDataSet
    data_dir: /home/aistudio/data502/data502
    label_file_list: ["/home/aistudio/train502.txt"]   #训练集路径
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - RecAug: 
      - CTCLabelEncode: # Class handling label
      - RecResizeImg:
          image_shape: [3, 32, 320]
      - KeepKeys:
          keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
  loader:
    shuffle: True
    # batch_size_per_card: 1024
    batch_size_per_card: 8  # originally 256
    drop_last: True
    num_workers: 4

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: /home/aistudio/data502/data502
    label_file_list: ["/home/aistudio/test502.txt"]  #测试集路径
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - CTCLabelEncode: # Class handling label
      - RecResizeImg:
          image_shape: [3, 32, 320]
      - KeepKeys:
          keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 8  # originally 256
    num_workers: 8
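Note that character_dict_path above points to en_pydict.txt, a custom dictionary with one character per line covering the letters and tone-marked vowels used in pinyin. A sketch of how such a file could be derived from the label files themselves (using the file names from the previous section):

chars = set()
for label_file in ["/home/aistudio/train502.txt", "/home/aistudio/test502.txt"]:
    with open(label_file, encoding="utf-8") as f:
        for line in f:
            _, label = line.rstrip("\n").split("\t")
            chars.update(label)            # collect every character appearing in the labels

with open("en_pydict.txt", "w", encoding="utf-8") as f:
    for ch in sorted(chars):
        f.write(ch + "\n")                 # one character per line, as PaddleOCR expects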

1. MobileNetV3 Training

%cd /home/aistudio/PaddleOCR
!python tools/train.py -c /home/aistudio/work/rec_en_number_lite_train.yml

2. ResNet Training

The config file is at /work/rec_en_number_ResNet_train.yml; see the file for details. Everything else is the same as above.

%cd /home/aistudio/PaddleOCR
!python tools/train.py -c /home/aistudio/work/rec_en_number_ResNet_train.yml

VII. Monitoring the Training Process

  1. Enable visualization by setting use_visualdl: True in the config file.

  2. Click [Launch VisualDL service], then open it to view the training metrics in real time.

  3. Done.

Resuming training

# Load the last saved checkpoint and continue training
%cd PaddleOCR
!python tools/train.py -c /home/aistudio/work/rec_en_number_lite_train.yml -o Global.checkpoints=/home/aistudio/PaddleOCR/output/rec_en_number_lite/best_accuracy

Verifying the prediction results

# Display an image
import matplotlib.pyplot as plt
import cv2

def imshow(img_path):
    im = cv2.imread(img_path)                        # OpenCV loads images as BGR
    plt.imshow(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))  # convert to RGB for matplotlib
    plt.show()

img = '/home/aistudio/data502/data502/āmò.png'
imshow(img)


!python tools/infer_rec.py -c /home/aistudio/work/rec_en_number_lite_train.yml \
       -o Global.infer_img="/home/aistudio/data502/data502/āmò.png" \
       Global.pretrained_model="/home/aistudio/PaddleOCR/output/rec_en_number_lite/best_accuracy"
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:26: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  def convert_to_list(value, n, name, dtype=np.int):
[2023/02/18 00:27:49] root INFO: Architecture : 
[2023/02/18 00:27:49] root INFO:     Backbone : 
[2023/02/18 00:27:49] root INFO:         model_name : small
[2023/02/18 00:27:49] root INFO:         name : MobileNetV3
[2023/02/18 00:27:49] root INFO:         scale : 0.5
[2023/02/18 00:27:49] root INFO:         small_stride : [1, 2, 2, 2]
[2023/02/18 00:27:49] root INFO:     Head : 
[2023/02/18 00:27:49] root INFO:         fc_decay : 1e-05
[2023/02/18 00:27:49] root INFO:         name : CTCHead
[2023/02/18 00:27:49] root INFO:     Neck : 
[2023/02/18 00:27:49] root INFO:         encoder_type : rnn
[2023/02/18 00:27:49] root INFO:         hidden_size : 48
[2023/02/18 00:27:49] root INFO:         name : SequenceEncoder
[2023/02/18 00:27:49] root INFO:     Transform : None
[2023/02/18 00:27:49] root INFO:     algorithm : CRNN
[2023/02/18 00:27:49] root INFO:     model_type : rec
[2023/02/18 00:27:49] root INFO: Eval : 
[2023/02/18 00:27:49] root INFO:     dataset : 
[2023/02/18 00:27:49] root INFO:         data_dir : /home/aistudio/data502/data502
[2023/02/18 00:27:49] root INFO:         label_file_list : ['/home/aistudio/test502.txt']
[2023/02/18 00:27:49] root INFO:         name : SimpleDataSet
[2023/02/18 00:27:49] root INFO:         transforms : 
[2023/02/18 00:27:49] root INFO:             DecodeImage : 
[2023/02/18 00:27:49] root INFO:                 channel_first : False
[2023/02/18 00:27:49] root INFO:                 img_mode : BGR
[2023/02/18 00:27:49] root INFO:             CTCLabelEncode : None
[2023/02/18 00:27:49] root INFO:             RecResizeImg : 
[2023/02/18 00:27:49] root INFO:                 image_shape : [3, 32, 320]
[2023/02/18 00:27:49] root INFO:             KeepKeys : 
[2023/02/18 00:27:49] root INFO:                 keep_keys : ['image', 'label', 'length']
[2023/02/18 00:27:49] root INFO:     loader : 
[2023/02/18 00:27:49] root INFO:         batch_size_per_card : 8
[2023/02/18 00:27:49] root INFO:         drop_last : False
[2023/02/18 00:27:49] root INFO:         num_workers : 8
[2023/02/18 00:27:49] root INFO:         shuffle : False
[2023/02/18 00:27:49] root INFO: Global : 
[2023/02/18 00:27:49] root INFO:     cal_metric_during_train : True
[2023/02/18 00:27:49] root INFO:     character_dict_path : ppocr/utils/en_pydict.txt
[2023/02/18 00:27:49] root INFO:     checkpoints : None
[2023/02/18 00:27:49] root INFO:     debug : False
[2023/02/18 00:27:49] root INFO:     distributed : False
[2023/02/18 00:27:49] root INFO:     epoch_num : 200
[2023/02/18 00:27:49] root INFO:     eval_batch_step : [0, 100]
[2023/02/18 00:27:49] root INFO:     infer_img : /home/aistudio/data502/data502/āmò.png
[2023/02/18 00:27:49] root INFO:     infer_mode : False
[2023/02/18 00:27:49] root INFO:     log_smooth_window : 20
[2023/02/18 00:27:49] root INFO:     max_text_length : 250
[2023/02/18 00:27:49] root INFO:     pretrained_model : /home/aistudio/PaddleOCR/output/rec_en_number_lite/best_accuracy
[2023/02/18 00:27:49] root INFO:     print_batch_step : 10
[2023/02/18 00:27:49] root INFO:     save_epoch_step : 3
[2023/02/18 00:27:49] root INFO:     save_inference_dir : None
[2023/02/18 00:27:49] root INFO:     save_model_dir : ./output/rec_en_number_lite
[2023/02/18 00:27:49] root INFO:     use_gpu : True
[2023/02/18 00:27:49] root INFO:     use_space_char : True
[2023/02/18 00:27:49] root INFO:     use_visualdl : True
[2023/02/18 00:27:49] root INFO: Loss : 
[2023/02/18 00:27:49] root INFO:     name : CTCLoss
[2023/02/18 00:27:49] root INFO: Metric : 
[2023/02/18 00:27:49] root INFO:     main_indicator : acc
[2023/02/18 00:27:49] root INFO:     name : RecMetric
[2023/02/18 00:27:49] root INFO: Optimizer : 
[2023/02/18 00:27:49] root INFO:     beta1 : 0.9
[2023/02/18 00:27:49] root INFO:     beta2 : 0.999
[2023/02/18 00:27:49] root INFO:     lr : 
[2023/02/18 00:27:49] root INFO:         learning_rate : 0.005
[2023/02/18 00:27:49] root INFO:         name : Cosine
[2023/02/18 00:27:49] root INFO:     name : Adam
[2023/02/18 00:27:49] root INFO:     regularizer : 
[2023/02/18 00:27:49] root INFO:         factor : 1e-05
[2023/02/18 00:27:49] root INFO:         name : L2
[2023/02/18 00:27:49] root INFO: PostProcess : 
[2023/02/18 00:27:49] root INFO:     name : CTCLabelDecode
[2023/02/18 00:27:49] root INFO: Train : 
[2023/02/18 00:27:49] root INFO:     dataset : 
[2023/02/18 00:27:49] root INFO:         data_dir : /home/aistudio/data502/data502
[2023/02/18 00:27:49] root INFO:         label_file_list : ['/home/aistudio/train502.txt']
[2023/02/18 00:27:49] root INFO:         name : SimpleDataSet
[2023/02/18 00:27:49] root INFO:         transforms : 
[2023/02/18 00:27:49] root INFO:             DecodeImage : 
[2023/02/18 00:27:49] root INFO:                 channel_first : False
[2023/02/18 00:27:49] root INFO:                 img_mode : BGR
[2023/02/18 00:27:49] root INFO:             RecAug : None
[2023/02/18 00:27:49] root INFO:             CTCLabelEncode : None
[2023/02/18 00:27:49] root INFO:             RecResizeImg : 
[2023/02/18 00:27:49] root INFO:                 image_shape : [3, 32, 320]
[2023/02/18 00:27:49] root INFO:             KeepKeys : 
[2023/02/18 00:27:49] root INFO:                 keep_keys : ['image', 'label', 'length']
[2023/02/18 00:27:49] root INFO:     loader : 
[2023/02/18 00:27:49] root INFO:         batch_size_per_card : 8
[2023/02/18 00:27:49] root INFO:         drop_last : True
[2023/02/18 00:27:49] root INFO:         num_workers : 4
[2023/02/18 00:27:49] root INFO:         shuffle : True
[2023/02/18 00:27:49] root INFO: profiler_options : None
[2023/02/18 00:27:49] root INFO: train with paddle 2.0.2 and device CUDAPlace(0)
W0218 00:27:49.086216 13407 device_context.cc:362] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 10.1
W0218 00:27:49.090732 13407 device_context.cc:372] device: 0, cuDNN Version: 7.6.
[2023/02/18 00:27:52] root INFO: load pretrain successful from /home/aistudio/PaddleOCR/output/rec_en_number_lite/best_accuracy
[2023/02/18 00:27:52] root INFO: infer_img: /home/aistudio/data502/data502/āmò.png
[2023/02/18 00:27:52] root INFO: 	 result: āmò	0.95821923
[2023/02/18 00:27:52] root INFO: success!


VIII. Model Export

# Export the inference model to ./inference/rec_py
%cd PaddleOCR
!python tools/export_model.py -c /home/aistudio/work/rec_en_number_lite_train.yml \
-o Global.pretrained_model=/home/aistudio/PaddleOCR/output/rec_en_number_lite/best_accuracy \
Global.save_inference_dir=./inference/rec_py
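One possible way to use the exported inference model from Python (a sketch; it assumes the paddleocr pip package is installed, which is not part of the setup above, and parameter names may vary slightly across paddleocr versions):

from paddleocr import PaddleOCR

rec = PaddleOCR(
    use_angle_cls=False,
    rec_model_dir="./inference/rec_py",                 # directory produced by export_model.py
    rec_char_dict_path="ppocr/utils/en_pydict.txt",     # the same pinyin dictionary used for training
    rec_image_shape="3, 32, 320",
    use_space_char=True,
)
result = rec.ocr("/home/aistudio/data502/data502/āmò.png", det=False, cls=False)
print(result)    # recognized text and confidence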

Conclusion

This project recognizes handwritten Hanyu Pinyin and walks through the complete deep-learning text-recognition workflow: building a custom dataset, data processing and format conversion, model construction, model selection and training, and finally prediction and inference. The core algorithm is CRNN+CTC from the PaddleOCR framework; after several rounds of training and tuning, recognition accuracy reaches about 90%. As a follow-up, data augmentation can be applied to further improve the model's generalization.