I. Guangdong-Hong Kong-Macao Greater Bay Area (Huangpu) International Algorithm Competition: Introduction to the Ancient Book Document Image Recognition and Analysis Track

1. Background and Significance

  • China's millennia of brilliant civilization have left behind a vast body of ancient books and documents, written records that carry rich historical information and cultural heritage. To answer the national call for protecting ancient-book cultural heritage, digitizing ancient books, and promoting their use, to pass on outstanding traditional Chinese culture, and to mine the rich knowledge these documents contain, thorough digitization of ancient books is imperative.
  • Ancient book document images pose serious technical challenges: complex layouts, large differences in carving and handwriting styles across dynasties, missing text, stains, ink contamination, blur, interference from seal noise, and an abundance of rare and variant characters. Recognizing and understanding these images therefore remains a highly challenging, far-from-solved problem.
  • To address the digitization of China's vast ancient-book collections, this competition solicits advanced AI algorithms for high-accuracy text detection, text-line recognition, and end-to-end recognition of ancient books, aiming to advance ancient-book OCR and to provide AI methods supporting the digital preservation, curation, and use of ancient books.

                                              Figure 1. Examples of ancient book documents

2. Task Description

Task: ancient book document image analysis and recognition

Input: page-level ancient book document images

Output: structured text-line coordinates and recognized content, produced with techniques such as physical and logical layout analysis, text detection, text recognition, and reading-order understanding. Each line's detection box and recognized text must be output in reading order. The model should output detection and recognition results for the main body text only, ignoring non-body content such as the centerfold (banxin) and volume numbers.

Character Code Table

The competition provides a character code table (download: https://pan.baidu.com/s/16wUeSZ4JKD6f1Pj9ZhlKww, extraction code: i53n) containing the character classes that appear in the preliminary-round training set, validation set (preliminary A), and test set (preliminary B leaderboard). **Note: because the competition includes a zero-shot recognition setting, the character classes in the training set do not fully cover those in the code table. The published table fully covers all character classes in the preliminary training set and the A-leaderboard test set; the B-leaderboard table may be slightly adjusted and will be released later, so watch the official competition site for announcements.**

Preliminary B-leaderboard code table:

Download: https://pan.baidu.com/s/1gaNlKHk6lh5FxC2QP4UuDg
Extraction code: umzz
(Release date: 202298)

3. Dataset Description

  • **Preliminary dataset:** the training, validation, and test sets each contain 1,000 ancient book document images (3,000 images in total), drawn from the Siku Quanshu (Complete Library in Four Sections), rare historical editions, the Qianlong Tripitaka, and other sources. The task considers only the main body text of each document, ignoring content outside the text frame such as the centerfold (banxin) and volume numbers.
  • **Final dataset:** the final is run in a **challenge-arena** format. In addition to the organizers' original preliminary data and final data, finalist teams may apply to act as challengers and provide their own datasets for the other finalists to train and test on: at least 1,000 training images and at most 1,000 test images, annotated in the same format as the organizers' data.

Dataset annotation format:

Each image's text lines and their contents are annotated in reading order and stored together in a single JSON file, formatted as follows:

{
  "image_name_1": [{"points": [x1, y1, x2, y2, …, xn, yn], "transcription": text},
                   {"points": [x1, y1, x2, y2, …, xn, yn], "transcription": text},
                   …],
  "image_name_2": [{"points": [x1, y1, x2, y2, …, xn, yn], "transcription": text},
                   {"points": [x1, y1, x2, y2, …, xn, yn], "transcription": text},
                   …],
  …
}
  • x1, y1, x2, y2, …, xn, yn are the vertices of the text polygon.
  • For quadrilateral text, n = 4; the dataset contains a small amount of irregular text annotated with n = 16 (8 points on each long side).
  • text is the content of each text line; characters too blurred to recognize are annotated as #.
  • Detection and recognition labels are given in correct reading order. End-to-end recognition content is annotated in reading order and covers only the main body text, ignoring content outside the frame such as the centerfold (banxin) and volume numbers.
  • Reading-order annotation is illustrated in Figure 2.

                   Figure 2. Visualization of reading-order annotation for end-to-end structured recognition of ancient book document images
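The annotation rules above can be checked mechanically. Below is a minimal sketch (the file name `label.json` and helper name are illustrative) that verifies each line against the stated format: a flat point list of 8 values (quadrilaterals) or 32 values (irregular lines) plus a transcription string, with lines already in reading order.

```python
import json

def validate_labels(label_path):
    """Check every annotation against the documented format."""
    with open(label_path, "r", encoding="utf-8") as f:
        labels = json.load(f)
    for image_name, lines in labels.items():
        for line in lines:
            points = line["points"]
            # 8 coordinates for quadrilaterals, 32 for irregular text
            assert len(points) in (8, 32), f"unexpected polygon size in {image_name}"
            assert isinstance(line["transcription"], str)
    return len(labels)
```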

4. Submission

[Preliminary A Leaderboard]

  • **Submission format:** a zip archive of CSV files named after the test images
  • Submission content: one CSV file per image, containing the text detection box coordinates and the corresponding recognition results, with all lines ordered by the predicted reading order.

The internal format of each CSV file is:

x1, y1, x2, y2, x3, y3,…, xn, yn, transcription_1

x1, y1, x2, y2, x3, y3,…, xn, yn, transcription_2

x1, y1, x2, y2, x3, y3,…, xn, yn, transcription_n

(xn, yn are coordinates, listed clockwise; transcription_n is the recognized text of that line)
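A sketch of writing one such file with the standard `csv` module (the function name and the `results` structure, a list of `(points, transcription)` pairs already in predicted reading order, are assumptions for illustration):

```python
import csv

def write_submission(csv_path, results):
    """Write one submission CSV: clockwise coordinates, then the text."""
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        for points, transcription in results:
            # points is a flat clockwise list [x1, y1, ..., xn, yn]
            writer.writerow(list(points) + [transcription])
```

One CSV per test image, all zipped together, matches the stated format.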

  • Submission sample:

    链接:https://pan.baidu.com/s/1h9smrGBwfJ78IP3WUlkEYQ
    提取码:suzi

  • Submissions: once per day

  • Submission opens: September 15

II. Dataset Processing

1. Unzip the dataset

!unzip -qoa data/data167941/dataset.zip

2. Inspect the data

!head -n30 dataset/train/label.json
{
  "image_0.jpg": [
    {
      "points": [
        1286,
        59,
        1326,
        59,
        1331,
        851,
        1290,
        851
      ],
      "transcription": "\u53ef\ud878\udcce\u4e45\u4e4e\u820e\u5229\u5f17\u563f\u7136\u4e0d\u8345\u25cf\u4e94\u8eab\u5b50\u81ea\u601d\u89e7\u8131\u7121\u4e45\u8fd1\u6545\u9ed9\u5929\u66f0\u5982\u4f55\ud859\udcbf\ud85b\udf94\u5927\u667a"
    },
    {
      "points": [
        1249,
        57,
        1286,
        59,
        1298,
        851,
        1251,
        851
      ],
      "transcription": "\u800c\u563f\u25cb\u516d\u5929\u554f\ud86e\udc26\u4ee5\u8087\u66f0\u4e94\u767e\u82d0\u5b50\u4ec1\u8005\u4f55\u667a\u6075\u82d0\u4e00\u563f\u7136\u4f55\u8036\u8345\u66f0\u89e7\u8131\u8005\u65e0\ud86e\udc26\u8a00\u8aaa"
    },
    {
      "points": [


3. Data format conversion

PaddleOCR detection tasks expect the dataset in the following format:

" Image file name               JSON-encoded annotation (json.dumps)"
ch4_test_images/img_61.jpg    [{"transcription": "MASA", "points": [[310, 104], [416, 141], [418, 216], [312, 179]], ...}]

So the data format needs to be converted.

import json

# Read the source label.json
with open('dataset/train/label.json', 'r') as f:
    y = json.load(f)

# Number of annotated images (1000)
print(len(y))
# Inspect one image's annotations
print(y["image_0.jpg"])
# Number of text lines in that image
print(len(y["image_0.jpg"]))

1000
[{'points': [1286, 59, 1326, 59, 1331, 851, 1290, 851], 'transcription': '可𮃎久乎舎利弗嘿然不荅●五身子自思觧脱無久近故黙天曰如何𦒿𦾔大智'}, {'points': [1249, 57, 1286, 59, 1298, 851, 1251, 851], 'transcription': '而嘿○六天問𫠦以肇曰五百苐子仁者何智恵苐一嘿然何耶荅曰觧脱者无𫠦言說'}, {'points': [1213, 60, 1252, 60, 1252, 784, 1213, 784], 'transcription': '故吾扵是不知𫠦云○七身子已離三𭻃惑得心觧脱永绝言𢿘故言不知𫠦云'}, {'points': [1173, 62, 1214, 62, 1224, 845, 1183, 845], 'transcription': '天曰言説文字皆觧脱𢪷●八天辨不思議觧脱即文字也文三𬼘𥘉摽文字即觧脱'}, {'points': [1135, 61, 1179, 61, 1184, 848, 1140, 848], 'transcription': '肇曰舎利弗以言文為失故黙然无言謂順真觧未䏻悟黙齊𦤺觸𢪱无礙故'}, {'points': [1099, 59, 1143, 59, 1149, 848, 1106, 848], 'transcription': '天說䒭觧以暁其意𫠦以者何觧脱者不内不外不在兩間文字亦不内不外不在兩'}, {'points': [1069, 61, 1111, 61, 1110, 854, 1065, 852], 'transcription': '間是故舎利弗无離文字説觧脱也𬼘二觧𥼶𫠦以肇曰法之𫠦在極扵三⺀𠁅⺀求文字'}, {'points': [1022, 61, 1066, 61, 1066, 851, 1022, 851], 'transcription': '觧脱俱不可淂如何欲離文字别説觧脱乎𫠦以者何一𭃄諸法是觧脱相○三明'}, {'points': [984, 60, 1025, 60, 1021, 850, 980, 850], 'transcription': '諸法䒭觧肇曰万法雖殊无非觧𢪷豈文字之獨異也舎利弗言不復以離媱怒'}, {'points': [946, 60, 985, 60, 978, 850, 938, 850], 'transcription': '𪪧為觧脱乎○𬼘下二明𣂾不𣂾别文二𥘉問也肇曰二乘结𥁞為觧脱聞上䒭觧乖'}, {'points': [905, 59, 951, 59, 942, 850, 895, 850], 'transcription': '其夲趣故𦤺斯問天日仏為増上𢢔人説離媱怒癡為觧脱耳𠰥无上𢢔者佛説'}, {'points': [860, 63, 909, 63, 902, 849, 852, 849], 'transcription': '媱怒癡性即是觧脱二荅也増上𢢔者未淂謂淂也身子𢴃小乘𫠦證非増上𢢔'}, {'points': [822, 62, 865, 62, 862, 850, 819, 850], 'transcription': '自謂共佛同㘴觧脱床名増上𢢔也既未悟缚解平䒭故為説離缚為觧𠰥大士'}, {'points': [779, 63, 822, 63, 822, 848, 779, 848], 'transcription': '非増上𢢔者為説即縛性脱性脱入不二門也舎利弗言善⺀㦲⺀天女汝何𫠦淂以何'}, {'points': [735, 62, 782, 62, 781, 846, 734, 846], 'transcription': '為證辨乃如是○三明證不證別文二𬼘𥘉也肇曰善其𫠦説非已𫠦及故問淂何道證'}, {'points': [693, 60, 736, 60, 745, 848, 703, 848], 'transcription': '阿果辨乃如是乎天曰我无淂无證故辨如是○荅文二𬼘𥘉正荅二乘捨缚求脱'}, {'points': [650, 62, 696, 62, 709, 852, 662, 852], 'transcription': '故有淂證大士悟縛脱平䒭非縛非脱故无淂无證既智窮不二之門故辨無礙'}, {'points': [619, 61, 658, 61, 664, 850, 626, 850], 'transcription': '也𫠦以者何𠰥有淂有證者則扵仏法為増上𢢔○二反厈肇曰𠰥見己有淂必見他'}, {'points': [576, 62, 617, 62, 631, 850, 591, 850], 'transcription': 
'不淂𬼘扵佛平䒭之法猶為増上𢢔人何䏻𦤺无礙之辨乎舎利弗問天汝扵三'}, {'points': [539, 63, 579, 63, 588, 845, 548, 845], 'transcription': '乘為何𢖽求𬼘下三约教明乘无乘別也小乘有法執故有差别乘大乘不二平'}, {'points': [497, 63, 539, 63, 550, 849, 508, 849], 'transcription': '䒭故无乘之乘文二𬼘𥘉問也肇曰上云无淂无證未知何乘故𣸪問也天曰以聲'}, {'points': [459, 63, 502, 63, 509, 853, 467, 853], 'transcription': '聞法化衆生故我為聲聞以因𦄘法化衆生故我為𮝻支仏以大悲法化衆生故我'}, {'points': [422, 65, 462, 65, 466, 851, 426, 851], 'transcription': '為大乘○二荅文二一惣约化𦄘荅二别约𫝆𦄘荅𬼘𥘉也肇曰大乘之道无乘之乘'}, {'points': [379, 65, 423, 65, 430, 827, 386, 827], 'transcription': '爲彼而乘吾何乘也生曰随彼為之我无㝎也又觧法花方便説三意同𬼘也'}, {'points': [342, 65, 382, 65, 396, 851, 356, 851], 'transcription': '舎利弗如人入瞻蔔林唯嗅瞻蔔不嗅餘香如是𠰥入𬼘室但聞仏㓛徳之香不樂'}, {'points': [300, 67, 343, 67, 360, 849, 318, 849], 'transcription': '聞聲聞𮝻支仏㓛徳香也○𬼘二约𫝆𦄘文四一明𫝆𦄘唯一二𫠦化樂大三室无小法四'}, {'points': [263, 64, 302, 64, 323, 849, 284, 849], 'transcription': '约室顕法𬼘𥘉也肇曰元乘不乘乃為大乘故以香林為喻明浄名之室不離二'}, {'points': [226, 64, 268, 64, 286, 849, 243, 849], 'transcription': '乘之香止乘止𬼘室者豈他嗅㢤舎利弗有其四𥼶梵四天王諸天龍神鬼√䒭入'}, {'points': [186, 63, 229, 63, 248, 855, 205, 855], 'transcription': '𬼘室者聞斯上人講说正法𣅜樂佛㓛徳之香𤼲心而出二明𫠦化皆樂大也舎利'}, {'points': [158, 65, 193, 63, 191, 204, 159, 207], 'transcription': '弗吾止𬼘室十'}, {'points': [183, 198, 200, 197, 200, 222, 183, 222], 'transcription': '有'}, {'points': [161, 207, 191, 205, 204, 856, 167, 859], 'transcription': '二年𥘉不聞説聲聞𮝻支仏法但聞菩薩大慈大悲不可思議諸'}, {'points': [121, 62, 169, 62, 172, 855, 125, 855], 'transcription': '佛之法三明深肇曰大乘之法𣅜不可思議上問止室久近欲生淪端故答'}, {'points': [80, 63, 122, 63, 131, 853, 90, 853], 'transcription': '以觧脱𫝆言實𭘾以明𫠦聞之不𮦀也生曰諸天鬼神蹔入室尚无不𤼲大意而出'}, {'points': [44, 62, 84, 62, 100, 849, 60, 849], 'transcription': '㦲况我久聞妙法乎然則不䏻不為大悲䏻為大矣舎利弗𬼘室常現八未曽有'}, {'points': [2, 60, 45, 60, 62, 848, 19, 848], 'transcription': '難淂之法𬼘四明未曽有室不说二乘之法也文三標𥼶结𬼘𥘉標也何謂為八'}]
36
# Format conversion: flat coordinate lists -> nested [x, y] pairs,
# one "<image name>\t<json>" line per image, as PaddleOCR expects
image_info_lists = {}
ff = open("dataset/train/label.txt", 'w')
for i in range(1000):
    old_info = y[f"image_{i}.jpg"]

    new_info = []
    for item in old_info:
        points = item["points"]
        # Keep only the polygon sizes that occur in the data:
        # 8 values (quadrilaterals) and 32 or 34 values (irregular lines)
        if len(points) not in (8, 32, 34):
            continue
        image_info = {
            "transcription": item["transcription"],
            # Pair up [x1, y1, x2, y2, ...] into [[x1, y1], [x2, y2], ...]
            "points": [[points[j], points[j + 1]] for j in range(0, len(points), 2)],
        }
        new_info.append(image_info)
    image_info_lists[f"image_{i}.jpg"] = new_info
    ff.write(f"image_{i}.jpg" + "\t" + json.dumps(new_info) + "\n")
ff.close()

# Inspect the converted data
print(image_info_lists["image_0.jpg"][0])
!head -n1 dataset/train/label.txt
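A quick round-trip check can confirm the converted file is what PaddleOCR's `SimpleDataSet` expects: each line is `<image name>\t<json list>`, with every polygon stored as nested `[x, y]` pairs. The helper name below is illustrative.

```python
import json

def check_label_line(line):
    """Parse one converted label line and verify its structure."""
    name, payload = line.rstrip("\n").split("\t", 1)
    boxes = json.loads(payload)
    for box in boxes:
        # every vertex must be an [x, y] pair
        assert all(len(pt) == 2 for pt in box["points"])
        assert isinstance(box["transcription"], str)
    return name, len(boxes)
```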

4. Split the dataset

The first 800 images form the training set;
the last 200 form the evaluation set.

%cd ~
!wc -l  dataset/train/label.txt
!head -800 dataset/train/label.txt >dataset/train/train.txt
!tail -200 dataset/train/label.txt >dataset/train/eval.txt
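The same 800/200 split can be done in Python. This sketch adds an optional seeded shuffle, which the fixed `head`/`tail` split above does not do; the function name is illustrative.

```python
import random

def split_labels(lines, n_train=800, seed=0):
    """Shuffle label lines deterministically, then split train/eval."""
    lines = list(lines)
    random.Random(seed).shuffle(lines)
    return lines[:n_train], lines[n_train:]
```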

III. PaddleOCR Environment Setup

1. Download PaddleOCR

# !git clone https://gitee.com/paddlepaddle/PaddleOCR.git --depth=1

2. Install PaddleOCR

%cd ~/PaddleOCR/
!python -m pip install -q -U pip --user
!pip install -q -r requirements.txt
/home/aistudio/PaddleOCR
# !mkdir pretrain_models/
# %cd pretrain_models
# !wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_distill_train.tar
# !tar -xvf ch_PP-OCRv3_det_distill_train.tar

IV. Detection Model Training

!pip list|grep opencv
opencv-contrib-python          4.6.0.66
opencv-python                  4.2.0.32

1. Downgrade OpenCV

The installed OpenCV versions are incompatible and must be downgraded; otherwise training fails with the following error:

Traceback (most recent call last):
  File "tools/train.py", line 30, in <module>
    from ppocr.data import build_dataloader
  File "/home/aistudio/PaddleOCR/ppocr/data/__init__.py", line 35, in <module>
    from ppocr.data.imaug import transform, create_operators
  File "/home/aistudio/PaddleOCR/ppocr/data/imaug/__init__.py", line 19, in <module>
    from .iaa_augment import IaaAugment
  File "/home/aistudio/PaddleOCR/ppocr/data/imaug/iaa_augment.py", line 24, in <module>
    import imgaug
  File "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/imgaug/__init__.py", line 7, in <module>
    from imgaug.imgaug import *  # pylint: disable=redefined-builtin
  File "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/imgaug/imgaug.py", line 18, in <module>
    import cv2
  File "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/cv2/__init__.py", line 181, in <module>
    bootstrap()
  File "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/cv2/__init__.py", line 175, in bootstrap
    if __load_extra_py_code_for_module("cv2", submodule, DEBUG):
  File "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/cv2/__init__.py", line 28, in __load_extra_py_code_for_module
    py_module = importlib.import_module(module_name)
  File "/opt/conda/envs/python35-paddle120-env/lib/python3.7/importlib/__init__.py", line 127, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
  File "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/cv2/mat_wrapper/__init__.py", line 33, in <module>
    cv._registerMatType(Mat)
AttributeError: module 'cv2' has no attribute '_registerMatType'
!pip uninstall  opencv-python -y
!pip uninstall opencv-contrib-python -y
!pip install opencv-python==4.2.0.32
Found existing installation: opencv-python 4.2.0.32
Uninstalling opencv-python-4.2.0.32:
  Successfully uninstalled opencv-python-4.2.0.32
Found existing installation: opencv-contrib-python 4.6.0.66
Uninstalling opencv-contrib-python-4.6.0.66:
  Successfully uninstalled opencv-contrib-python-4.6.0.66
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Collecting opencv-python==4.2.0.32
  Using cached https://pypi.tuna.tsinghua.edu.cn/packages/34/a3/403dbaef909fee9f9f6a8eaff51d44085a14e5bb1a1ff7257117d744986a/opencv_python-4.2.0.32-cp37-cp37m-manylinux1_x86_64.whl (28.2 MB)
Requirement already satisfied: numpy>=1.14.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from opencv-python==4.2.0.32) (1.19.5)
Installing collected packages: opencv-python
Successfully installed opencv-python-4.2.0.32
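The traceback above appears when `opencv-python` and `opencv-contrib-python` are installed at mismatched versions. A small guard (a hypothetical helper, not part of PaddleOCR) can verify the reported version matches the pin before launching training:

```python
def opencv_version_ok(installed_version, required="4.2.0.32"):
    """Compare version strings component-wise, e.g. '4.2.0.32' vs '4.6.0.66'."""
    return installed_version.split(".") == required.split(".")

# Usage in the training environment (commented out here):
# import cv2
# assert opencv_version_ok(cv2.__version__), "re-run the downgrade steps above"
```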

2. Training configuration

ch_PP-OCRv3_det_cml.yml

Global:
  character_dict_path: ../mb.txt  # custom character dictionary
  debug: false
  use_gpu: true
  epoch_num: 500
  log_smooth_window: 20
  print_batch_step: 10
  save_model_dir: ./output/ch_PP-OCR_v3_det/
  save_epoch_step: 100
  eval_batch_step:
  - 0
  - 400
  cal_metric_during_train: false
  pretrained_model: null
  checkpoints: null
  save_inference_dir: null
  use_visualdl: false
  infer_img: doc/imgs_en/img_10.jpg
  save_res_path: ./checkpoints/det_db/predicts_db.txt
  distributed: true

Architecture:
  name: DistillationModel
  algorithm: Distillation
  model_type: det
  Models:
    Student:
      pretrained:
      model_type: det
      algorithm: DB
      Transform: null
      Backbone:
        name: MobileNetV3
        scale: 0.5
        model_name: large
        disable_se: true
      Neck:
        name: RSEFPN
        out_channels: 96
        shortcut: True
      Head:
        name: DBHead
        k: 50
    Student2:
      pretrained: 
      model_type: det
      algorithm: DB
      Transform: null
      Backbone:
        name: MobileNetV3
        scale: 0.5
        model_name: large
        disable_se: true
      Neck:
        name: RSEFPN
        out_channels: 96
        shortcut: True
      Head:
        name: DBHead
        k: 50
    Teacher:
      pretrained: 
      freeze_params: true
      return_all_feats: false
      model_type: det
      algorithm: DB
      Backbone:
        name: ResNet_vd
        in_channels: 3
        layers: 50
      Neck:
        name: LKPAN
        out_channels: 256
      Head:
        name: DBHead
        kernel_list: [7,2,2]
        k: 50

Loss:
  name: CombinedLoss
  loss_config_list:
  - DistillationDilaDBLoss:
      weight: 1.0
      model_name_pairs:
      - ["Student", "Teacher"]
      - ["Student2", "Teacher"]
      key: maps
      balance_loss: true
      main_loss_type: DiceLoss
      alpha: 5
      beta: 10
      ohem_ratio: 3
  - DistillationDMLLoss:
      model_name_pairs:
      - ["Student", "Student2"]
      maps_name: "thrink_maps"
      weight: 1.0
      key: maps
  - DistillationDBLoss:
      weight: 1.0
      model_name_list: ["Student", "Student2"]
      balance_loss: true
      main_loss_type: DiceLoss
      alpha: 5
      beta: 10
      ohem_ratio: 3

Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Cosine
    learning_rate: 0.001
    warmup_epoch: 2
  regularizer:
    name: L2
    factor: 5.0e-05

PostProcess:
  name: DistillationDBPostProcess
  model_name: ["Student"]
  key: head_out
  thresh: 0.3
  box_thresh: 0.6
  max_candidates: 1000
  unclip_ratio: 1.5

Metric:
  name: DistillationMetric
  base_metric_name: DetMetric
  main_indicator: hmean
  key: "Student"

# Training dataset
Train:
  dataset:
    name: SimpleDataSet
    data_dir: /home/aistudio/dataset/train/image
    label_file_list:
      - /home/aistudio/dataset/train/label.txt
    ratio_list: [1.0]
    transforms:
    - DecodeImage:
        img_mode: BGR
        channel_first: false
    - DetLabelEncode: null
    - CopyPaste:
    - IaaAugment:
        augmenter_args:
        - type: Fliplr
          args:
            p: 0.5
        - type: Affine
          args:
            rotate:
            - -10
            - 10
        - type: Resize
          args:
            size:
            - 0.5
            - 3
    - EastRandomCropData:
        size:
        - 960
        - 960
        max_tries: 50
        keep_ratio: true
    - MakeBorderMap:
        shrink_ratio: 0.4
        thresh_min: 0.3
        thresh_max: 0.7
    - MakeShrinkMap:
        shrink_ratio: 0.4
        min_text_size: 8
    - NormalizeImage:
        scale: 1./255.
        mean:
        - 0.485
        - 0.456
        - 0.406
        std:
        - 0.229
        - 0.224
        - 0.225
        order: hwc
    - ToCHWImage: null
    - KeepKeys:
        keep_keys:
        - image
        - threshold_map
        - threshold_mask
        - shrink_map
        - shrink_mask
  loader:
    shuffle: true
    drop_last: false
    batch_size_per_card: 12
    num_workers: 4

# Eval dataset
Eval:
  dataset:
    name: SimpleDataSet
    data_dir: /home/aistudio/dataset/train/image
    label_file_list:
      - /home/aistudio/dataset/train/label.txt
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - DetLabelEncode: # Class handling label
      - DetResizeForTest:
      - NormalizeImage:
          scale: 1./255.
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 1 # must be 1
    num_workers: 2

# Copy the config into place
!cp ~/ch_PP-OCRv3_det_cml.yml ~/PaddleOCR/configs/det/ch_PP-OCRv3/
%env CUDA_VISIBLE_DEVICES=0,1,2,3
# Single-GPU alternative:
# !python tools/train.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml -o Optimizer.base_lr=0.0001
%cd ~/PaddleOCR/
!python3 -m paddle.distributed.launch --ips="localhost" --gpus '0,1,2,3' tools/train.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml -o Optimizer.base_lr=0.0001

V. Recognition Dataset Preparation

Convert the detection (det) dataset into a recognition (rec) dataset for training the recognition model.

# ppocr/utils/gen_label.py
# convert the official gt to rec_gt_label.txt
%cd ~/PaddleOCR
!python ppocr/utils/gen_label.py --mode="rec" --input_path="../dataset/train/train.txt" --output_label="../dataset/train/train_rec_gt_label.txt"
!python ppocr/utils/gen_label.py --mode="rec" --input_path="../dataset/train/eval.txt" --output_label="../dataset/train/eval_rec_gt_label.txt"

VI. Recognition Model Training

1. Download the pretrained model

%cd ~/PaddleOCR/pretrain_models
!wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_train.tar
!tar -xvf  ch_PP-OCRv3_rec_train.tar

2. Configure training parameters

Global:
  debug: false
  use_gpu: true
  epoch_num: 800
  log_smooth_window: 20
  print_batch_step: 10
  save_model_dir: ./output/rec_ppocr_v3_distillation
  save_epoch_step: 3
  eval_batch_step: [0, 2000]
  cal_metric_during_train: true
  # pretrained model
  pretrained_model: pretrain_models/ch_PP-OCRv3_rec_train/best_accuracy.pdparams
  checkpoints:
  save_inference_dir:
  use_visualdl: false
  infer_img: doc/imgs_words/ch/word_1.jpg
  # custom character dictionary
  character_dict_path: ../mb.txt
  max_text_length: &max_text_length 25
  infer_mode: false
  use_space_char: true
  distributed: true
  save_res_path: ./output/rec/predicts_ppocrv3_distillation.txt


Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Piecewise
    decay_epochs : [700, 800]
    values : [0.0005, 0.00005]
    warmup_epoch: 5
  regularizer:
    name: L2
    factor: 3.0e-05


Architecture:
  model_type: &model_type "rec"
  name: DistillationModel
  algorithm: Distillation
  Models:
    Teacher:
      pretrained:
      freeze_params: false
      return_all_feats: true
      model_type: *model_type
      algorithm: SVTR
      Transform:
      Backbone:
        name: MobileNetV1Enhance
        scale: 0.5
        last_conv_stride: [1, 2]
        last_pool_type: avg
      Head:
        name: MultiHead
        head_list:
          - CTCHead:
              Neck:
                name: svtr
                dims: 64
                depth: 2
                hidden_dims: 120
                use_guide: True
              Head:
                fc_decay: 0.00001
          - SARHead:
              enc_dim: 512
              max_text_length: *max_text_length
    Student:
      pretrained:
      freeze_params: false
      return_all_feats: true
      model_type: *model_type
      algorithm: SVTR
      Transform:
      Backbone:
        name: MobileNetV1Enhance
        scale: 0.5
        last_conv_stride: [1, 2]
        last_pool_type: avg
      Head:
        name: MultiHead
        head_list:
          - CTCHead:
              Neck:
                name: svtr
                dims: 64
                depth: 2
                hidden_dims: 120
                use_guide: True
              Head:
                fc_decay: 0.00001
          - SARHead:
              enc_dim: 512
              max_text_length: *max_text_length
Loss:
  name: CombinedLoss
  loss_config_list:
  - DistillationDMLLoss:
      weight: 1.0
      act: "softmax"
      use_log: true
      model_name_pairs:
      - ["Student", "Teacher"]
      key: head_out
      multi_head: True
      dis_head: ctc
      name: dml_ctc
  - DistillationDMLLoss:
      weight: 0.5
      act: "softmax"
      use_log: true
      model_name_pairs:
      - ["Student", "Teacher"]
      key: head_out
      multi_head: True
      dis_head: sar
      name: dml_sar
  - DistillationDistanceLoss:
      weight: 1.0
      mode: "l2"
      model_name_pairs:
      - ["Student", "Teacher"]
      key: backbone_out
  - DistillationCTCLoss:
      weight: 1.0
      model_name_list: ["Student", "Teacher"]
      key: head_out
      multi_head: True
  - DistillationSARLoss:
      weight: 1.0
      model_name_list: ["Student", "Teacher"]
      key: head_out
      multi_head: True

PostProcess:
  name: DistillationCTCLabelDecode
  model_name: ["Student", "Teacher"]
  key: head_out
  multi_head: True

Metric:
  name: DistillationMetric
  base_metric_name: RecMetric
  main_indicator: acc
  key: "Student"
  ignore_space: False

# Training dataset
Train:
  dataset:
    name: SimpleDataSet
    data_dir: /home/aistudio/dataset/train/image
    ext_op_transform_idx: 1
    label_file_list:
    - /home/aistudio/dataset/train/train_rec_gt_label.txt
    transforms:
    - DecodeImage:
        img_mode: BGR
        channel_first: false
    - RecConAug:
        prob: 0.5
        ext_data_num: 2
        image_shape: [48, 320, 3]
    - RecAug:
    - MultiLabelEncode:
    - RecResizeImg:
        image_shape: [3, 48, 320]
    - KeepKeys:
        keep_keys:
        - image
        - label_ctc
        - label_sar
        - length
        - valid_ratio
  loader:
    shuffle: true
    batch_size_per_card: 128
    drop_last: true
    num_workers: 4

# Eval dataset
Eval:
  dataset:
    name: SimpleDataSet
    data_dir: /home/aistudio/dataset/train/image
    ext_op_transform_idx: 1
    label_file_list:
    - /home/aistudio/dataset/train/eval_rec_gt_label.txt
    transforms:
    - DecodeImage:
        img_mode: BGR
        channel_first: false
    - MultiLabelEncode:
    - RecResizeImg:
        image_shape: [3, 48, 320]
    - KeepKeys:
        keep_keys:
        - image
        - label_ctc
        - label_sar
        - length
        - valid_ratio
  loader:
    shuffle: false
    drop_last: false
    batch_size_per_card: 128
    num_workers: 4

# Copy the configured file into place
%cd ~
!cp  ~/ch_PP-OCRv3_rec_distillation.yml ~/PaddleOCR/configs/rec/PP-OCRv3/
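One thing worth checking before training: the config above keeps `max_text_length: 25`, yet the vertical lines in this dataset are often much longer (the sample annotation in section II shows lines of roughly 35 characters). A quick scan of the converted `dataset/train/label.txt` reveals the longest transcription; if it exceeds the limit, raising `max_text_length` (and the `&max_text_length` anchor it feeds) may be necessary. A sketch, with an illustrative function name:

```python
import json

def max_label_length(label_txt_path):
    """Return the length of the longest transcription in a converted label file."""
    longest = 0
    with open(label_txt_path, encoding="utf-8") as f:
        for line in f:
            _, payload = line.rstrip("\n").split("\t", 1)
            for box in json.loads(payload):
                longest = max(longest, len(box["transcription"]))
    return longest
```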

3. Train the model

%cd ~/PaddleOCR/

# Multi-GPU training; select cards via the --gpus flag
!python -m paddle.distributed.launch --gpus '0,1,2,3'  tools/train.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml

VII. Pipeline Inference

1. Export the models

# Export the detection model
!python tools/export_model.py -c  configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml  -o Global.pretrained_model=./my_exps/det/best_accuracy Global.save_inference_dir=./inference/det
# Export the recognition model
!python tools/export_model.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml -o Global.pretrained_model=./my_exps/rec/best_accuracy Global.save_inference_dir=./inference/rec

2. Run the detection + recognition pipeline

! python tools/infer/predict_system.py \
    --det_model_dir=inference/det \
    --rec_model_dir=inference/rec \
    --image_dir="/home/aistudio/dataset/train/image/image_0.jpg" \
    --rec_image_shape=3,48,320


# Display the visualized result
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 8))
img = plt.imread("./inference_results/test.jpg")
plt.imshow(img)

Run prediction as described above, then submit the results.

  • Tip: train on 4 GPUs for speed; with fewer, training can take several days.

This article is a repost; see the original project link.
