保险文本视觉认知问答
保险文本视觉认知问答
1.项目介绍
1.1背景
随着人工智能技术的逐渐成熟,计算机视觉、语音、自然语言处理等技术在金融行业的应用从广度和深度上都在加速,这不仅降低了金融机构的运营和风险成本,而且有助于提升客户的满意度,比如利用NLP 技术实现智能问答解决方案,帮助用户即使没有复杂的金融背景知识也能快速找到自己需要的信息,而在寿险、产险、健康险等保险的理赔流程和客户服务环节中,存在大量扫描文档,例如医疗票据、费用清单、病例等。对这些扫描文档进行文字检测与识别,并且提取出结构化信息,可以用于极速理赔、个人健康管理等业务场景。
在保险领域,用户常见的问题占了60%~70%,这部分重复性工作费时费力,需要更有效率的处理方式。智能问答能够准确理解用户的意图,并直接给出精确的答案,极大节省了用户及工作人员的时间。
1.2.项目任务分析
本次任务需要将提供面向保险场景的扫描图片数据集,利用OCR技术自动识别影像资料后,再通过AI智能判断所识别文字的内在逻辑,回答关于图片的自然语言问题。问题的答案是可以从图片中提取的任何文本/标记。
输入: 保险场景的扫描文档(例如:医疗票据)+ 自然语言提问(例如:病人服用的药品清单有什么?)
输出: 对应自然语言提问的事实性答案.
-
分析: 根据不同文档图片数据集进行OCR识别,对问题和回答进行建模,保证一定准确率,技术涉及到OCR+NLP。
-
难点: 照片拍摄角度不同,字体混合手写,一张图片可能由多张票据混合,背景噪声影响识别效果,考虑使用多种模型对比。
PaddleOCR流程
1.3 参考资料:
基于Paddle实现baseline项目参考:AIWN保险文本视觉认知问答
优秀方案参考:DocVQA冠军方案分享
PaddleOCR官方资料库
https://github.com/PaddlePaddle/PaddleOCR
PaddleNLP官方资料库
https://github.com/PaddlePaddle/PaddleNLP
2.数据集
2.1数据简介
- 本次大赛提供的数据集使用的扫描文件类型包括票据、说明、报告等20 多种。混合了印刷、打字和手写的内容。
- 训练集有5000余张左右原始扫描文件及对应的 4万余个自然语言问答标注。提供的数据均已做了标注及脱敏。
2.1.1 训练集
-
训练集数据包括:
- image:包含所有原始扫描文件图像
- train.csv:问答训练库,包含序号(index)、问题 ID(quesiton_id)、图片名称(filename)、问题(question_text)、答案(answer_text),共 5 列
- readme:数据说明文档
字段说明:
-
训练集用于模型训练,数据字段包括以下内容:
1、index:序号
2、question_id:问题的唯一id标识
3、filename:问题对应的唯一图片名称
4、question_text:问题描述
5、answer_text:问题对应的唯一答案
2.1.2 测试集
-
测试集数据规模为1000张左右原始扫描文件及对应的7000个自然语言问题,数据内容样例同训练集。
-
测试集包含以下3个文件:
- image:包含所有原始扫描文件图像
- test1.csv:问答测试库,包含序号(index)、问题 ID(quesiton_id)、图片路径(filename)、问题(question_text),共 4 列
- readme:数据说明文档
-
测试集用于模型验证,需提交问题对应答案结果,数据字段包括以下内容:
1、index:序号
2、question_id:问题的唯一id标识
3、filename:问题对应的唯一图片名称
4、question_text:问题描述
2.2数据展示
-
样例一:
- 提问: 西药费的金额是多少?
回答: 140.16
提问: 140.16元购买了什么药品?
回答: {甲}缘沙坦胶囊{基}
- 提问: 西药费的金额是多少?
-
样例二:
- 提问: 这是一份关于什么药品的说明?
回答: 十三味疏肝胶囊
提问: 药品的有效期是多久?
回答: 1.5年
- 提问: 这是一份关于什么药品的说明?
3.项目代码
使用PaddleOCR+PaddleNLP实现代码
参考项目原地址: https://github.com/datawhalechina/competition-baseline/tree/master/competition/AIWIN2021
3.1安装环境依赖包
# 安装paddleocr和paddlenlp
!pip install --user paddleocr==2.0.4 paddlenlp==2.0.0rc18
!pip list
!pip install pandas pillow matplotlib Ipython
#解压数据集
!tar -xf data/data83016/dataset.tar -C data
import pandas as pd
from PIL import Image
import codecs
import os
import matplotlib.pyplot as plt
# from IPython.display import set_matplotlib_formats
# %matplotlib inline
# set_matplotlib_formats('svg') # 输出为svg
df = pd.read_csv('data/train-utf8.csv')
df['filename'] = 'data/image/' + df['filename'] # 改为本地路径
3.2 OCR阶段
ocr阶段生成位置及内容:
注:Paddleocr目前支持中英文、英文、法语、德语、韩语、日语,可以通过修改lang参数进行切换参数依次为ch
, en
, french
, german
, korean
, japan
。
from paddleocr import PaddleOCR
import pandas as pd
from PIL import Image
import codecs
import os
import matplotlib.pyplot as plt
ocr = PaddleOCR(use_angle_cls=True, lang="ch", enable_mkldnn=True) # need to run only once to download and load model into memory
df = pd.read_csv('data/train-utf8.csv')
df['filename'] = 'data/image/' + df['filename'] # 改为本地路径
for path in df['filename'].unique():
print(path)
if os.path.exists('result/' + os.path.basename(path)[:-4] + '.txt'):
continue
result = ocr.ocr(path, cls=True)
boxes = [line[0] for line in result]
txts = [line[1][0] for line in result]
scores = [line[1][1] for line in result]
for box, txt in zip(boxes, txts):
with codecs.open('result/' + os.path.basename(path)[:-4] + '.txt', 'a') as up:
up.write('{0}\t{1}\n'.format(box, txt))
#创建结果路径
!mkdir result
#运行上述脚本,时间比较长,耐心等待即可
!python ocr.py
输出结果展示:
[[394.0, 51.0], [459.0, 51.0], [459.0, 75.0], [394.0, 75.0]] 橘红丸
[[34.0, 125.0], [343.0, 125.0], [343.0, 137.0], [34.0, 137.0]] 说明书来源:广东宏兴集团股份有限公司宏兴制药厂
[[77.0, 169.0], [150.0, 169.0], [150.0, 186.0], [77.0, 186.0]] 【药品名称】
[[178.0, 169.0], [249.0, 169.0], [249.0, 186.0], [178.0, 186.0]] 【通用名称】
[[300.0, 169.0], [343.0, 169.0], [343.0, 186.0], [300.0, 186.0]] 橘红丸
[[178.0, 210.0], [248.0, 210.0], [248.0, 223.0], [178.0, 223.0]] 【汉语拼音】
[[300.0, 210.0], [379.0, 210.0], [379.0, 223.0], [300.0, 223.0]] JuhongWan
结果分析示例:
df.head(10)
index | question_id | filename | question_text | answer_text | |
---|---|---|---|---|---|
0 | 1 | Q00001 | data/image/c850b0d7018d127989d1b20d0f7118d66f5... | 这是什么药品? | 茶碱缓释片 |
1 | 2 | Q00002 | data/image/c850b0d7018d127989d1b20d0f7118d66f5... | 本说明书来源于哪里? | 黑龙江鼎恒升药业有限公司 |
2 | 3 | Q00003 | data/image/c850b0d7018d127989d1b20d0f7118d66f5... | 本品可通过什么屏障? | 胎盘 |
3 | 4 | Q00004 | data/image/c850b0d7018d127989d1b20d0f7118d66f5... | 说明书上方正中是什么字? | 茶碱缓释片 |
4 | 5 | Q00005 | data/image/c850b0d7018d127989d1b20d0f7118d66f5... | 左上角是什么字? | 说明书来源:黑龙江鼎恒升药业有限公司 |
5 | 6 | Q00006 | data/image/c850b0d7018d127989d1b20d0f7118d66f5... | 老年用药是下一项是什么? | 药物相互作用 |
6 | 7 | Q00007 | data/image/c850b0d7018d127989d1b20d0f7118d66f5... | Theophylline Sustainde-release Tablets是药品的什么? | 英文名 |
7 | 8 | Q00008 | data/image/c850b0d7018d127989d1b20d0f7118d66f5... | 茶碱是指什么? | 主要成份 |
8 | 9 | Q00009 | data/image/c850b0d7018d127989d1b20d0f7118d66f5... | 198.18是指什么数? | 分子量 |
9 | 10 | Q00010 | data/image/AHEFGLB18921EAAA75R7_20210301111254... | 太平洋产险全国统一保险消费投诉电话是哪个号码? | 95500-3-4 |
Image.open(df['filename'].iloc[0])
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-ZUY4abis-1672287898385)(output_20_0.png)]
codecs.open('result/' + os.path.basename(df['filename'].iloc[0])[:-4] + '.txt').readlines()[:10]
['[[374.0, 51.0], [476.0, 51.0], [476.0, 74.0], [374.0, 74.0]]\t茶碱缓释片\n',
'[[33.0, 124.0], [274.0, 124.0], [274.0, 137.0], [33.0, 137.0]]\t说明书来源:黑龙江鼎恒升药业有限公司\n',
'[[231.0, 167.0], [304.0, 170.0], [303.0, 187.0], [231.0, 184.0]]\t【通用名称】\n',
'[[133.0, 170.0], [202.0, 170.0], [202.0, 185.0], [133.0, 185.0]]\t【药品名称】\n',
'[[354.0, 169.0], [424.0, 169.0], [424.0, 187.0], [354.0, 187.0]]\t茶碱缓释片\n',
'[[232.0, 209.0], [292.0, 209.0], [292.0, 226.0], [232.0, 226.0]]\t【商品名】\n',
'[[231.0, 245.0], [291.0, 248.0], [290.0, 266.0], [231.0, 264.0]]\t【英文名】\n',
'[[355.0, 248.0], [600.0, 248.0], [600.0, 264.0], [355.0, 264.0]]\tTheophylline Sustained-release Tablets\n',
'[[233.0, 286.0], [304.0, 286.0], [304.0, 303.0], [233.0, 303.0]]\t【汉语拼音】\n',
'[[356.0, 288.0], [488.0, 288.0], [488.0, 301.0], [356.0, 301.0]]\tChajian HuanshiPian\n']
3.3 自然语言处理
3.3.1 模型1:规则匹配
codecs.open('result/' + os.path.basename(df['filename'].iloc[80])[:-4] + '.txt').readlines()[:15]
['[[150.0, 57.0], [470.0, 57.0], [470.0, 83.0], [150.0, 83.0]]\tPERSDNAL\n',
'[[456.0, 57.0], [694.0, 57.0], [694.0, 83.0], [456.0, 83.0]]\tRESUME\n',
'[[205.0, 276.0], [293.0, 276.0], [293.0, 304.0], [205.0, 304.0]]\t五百丁\n',
'[[576.0, 276.0], [696.0, 276.0], [696.0, 307.0], [576.0, 307.0]]\t基本信息\n',
'[[212.0, 328.0], [519.0, 328.0], [519.0, 341.0], [212.0, 341.0]]\t一句话介绍自己,告诉HR为什么选择你而不是别人\n',
'[[575.0, 332.0], [633.0, 332.0], [633.0, 355.0], [575.0, 355.0]]\t画24岁\n',
'[[576.0, 367.0], [629.0, 367.0], [629.0, 389.0], [576.0, 389.0]]\t国汉族\n',
'[[577.0, 402.0], [657.0, 402.0], [657.0, 422.0], [577.0, 422.0]]\t广东广州\n',
'[[35.0, 416.0], [154.0, 416.0], [154.0, 443.0], [35.0, 443.0]]\t(国)求职意向\n',
'[[574.0, 436.0], [699.0, 430.0], [700.0, 453.0], [575.0, 459.0]]\tC13800138000\n',
'[[36.0, 474.0], [132.0, 474.0], [132.0, 494.0], [36.0, 494.0]]\t幼儿园老师\n',
'[[574.0, 471.0], [694.0, 466.0], [695.0, 486.0], [575.0, 490.0]]\t区bd@500d.me\n',
'[[575.0, 533.0], [695.0, 533.0], [695.0, 562.0], [575.0, 562.0]]\t()技能特长\n',
'[[34.0, 561.0], [155.0, 561.0], [155.0, 591.0], [34.0, 591.0]]\t()教育背景\n',
'[[35.0, 619.0], [122.0, 619.0], [122.0, 633.0], [35.0, 633.0]]\t2013.9-至今\n']
df[df['question_text'].apply(lambda x: '邮箱是多少' in x)]
# 1、ocr识别图片
# 2、问题的意图
# 3、问题和ocr的结果进行匹配
index | question_id | filename | question_text | answer_text | |
---|---|---|---|---|---|
96 | 97 | Q00097 | data/image/e506b03f95cfc0b0649e4edcdb2076300a9... | 邮箱是多少? | bd@500d.me |
376 | 377 | Q00377 | data/image/a11b4e53ee1b706c0a01c626d4b53ee7712... | 五百丁邮箱是多少? | bd@500d.me |
1954 | 1955 | Q01955 | data/image/e09b52455b9e351cf62b8537f4d06208a9c... | 五百丁的邮箱是多少? | bd@500d.me |
2088 | 2089 | Q02089 | data/image/word_1145.png | 这张简历的邮箱是多少? | 666666@qq.com |
2812 | 2813 | Q02813 | data/image/03d0ce91ee87f4939e64470c700d69a9058... | 该证券中吴立的邮箱是多少? | wuli1@tfzq.com |
2850 | 2851 | Q02851 | data/image/d326457bd0d87670c10e232ceef5a0ffecc... | 五百丁的邮箱是多少? | bd@500d.me |
2919 | 2920 | Q02920 | data/image/03d0ce91ee87f4939e64470c700d69a9058... | 图中杨烨辉的邮箱是多少? | yangyehui@tfzq.com |
3153 | 3154 | Q03154 | data/image/fa964e762d3d2ab7595931c1d9bdd628475... | 五百丁的邮箱是多少? | bd@500d.me |
3169 | 3170 | Q03170 | data/image/e8c8044dd0ba4c1b7665be4005b6835f314... | 五百丁的邮箱是多少? | bd@500d.me |
3283 | 3284 | Q03284 | data/image/ecd226c3b1db5dec169dad321465287ffad... | 五百丁的邮箱是多少? | bd@500d.me |
5002 | 5003 | Q05003 | data/image/054260010acde733be26cd74ad7fff4b77b... | 研究助理:薛绍阳的邮箱是多少? | sueshaoyang@mszq.com |
7032 | 7033 | Q07033 | data/image/c4b40cc2dc55ad0ce1909db20ab1d29fa36... | 五百丁在简历上留的邮箱是多少? | bd@500d.me |
7471 | 7472 | Q07472 | data/image/ba12ed9453422ec07f2866a3e69d7701af5... | 五百丁的邮箱是多少? | bd@500d.me |
7538 | 7539 | Q07539 | data/image/cb3eb2eb4f42bdac18dd9634c08687422b8... | 五百丁邮箱是多少? | bd@500.me |
8862 | 8863 | Q08863 | data/image/word_1117.png | 此人的邮箱是多少? | qmjianli@163.com |
9725 | 9726 | Q09726 | data/image/d3dfd339afdfd79102cd5dc3508ef106dfc... | 五百丁邮箱是多少? | bd@500d.me |
12592 | 12593 | Q12593 | data/image/d792e6f57fc699e729122b938777eee60b2... | 五百丁的邮箱是多少? | bd@500d.me |
13625 | 13626 | Q13626 | data/image/3266144112911b1370cbe9b0ebb78bce86c... | 五百丁的邮箱是多少? | bd@500d.me |
13767 | 13768 | Q13768 | data/image/f25e0d56fbe2f2cd7227bcadb30f3c5baca... | 五百丁的电子邮箱是多少? | bd@500d.me |
30950 | 30951 | Q30951 | data/image/AHEFBZ1Y2021M010251ATEMP_5249414_1.jpg | 投保人的电子邮箱是多少? | 54564560134.com |
31092 | 31093 | Q31093 | data/image/11525741dee150c477b4cc598d934afa5d4... | 分析师金敏的邮箱是多少? | jinm@ctsec.com |
31379 | 31380 | Q31380 | data/image/201907_71922e65-a8da-4b09-844d-8e55... | 锤子的电子邮箱是多少? | 6464646@qq.com |
34193 | 34194 | Q34194 | data/image/4c42b2a394fea24bb8e41a2e6a0f35376bf... | 五百丁的邮箱是多少? | bd@500d.me |
34369 | 34370 | Q34370 | data/image/79dbb85c53a538fc9fe1d38dcc22d949920... | 五百丁的邮箱是多少? | bd@500d.me |
35070 | 35071 | Q35071 | data/image/201907_f0116bfd-948c-4b3b-8bb9-6f06... | 这张简历的邮箱是多少? | 1234@qq.com |
35241 | 35242 | Q35242 | data/image/57cedc752dfc0f0037e2fc3771e51abe687... | 五百丁的邮箱是多少? | bd@500d.me |
35249 | 35250 | Q35250 | data/image/3dce2f8ef58180c531cbc8c9a271be4d30e... | 五百丁的邮箱是多少? | bd@500d,me |
35771 | 35772 | Q35772 | data/image/201907_f1d0da38-c6dc-426a-b4ee-a712... | 邮箱是多少? | 13800642@qq.com |
40771 | 40772 | Q40772 | data/image/201907_f4b4c7c1-0c3e-41ed-afa3-1ddf... | 求职者锤子的邮箱是多少? | docer @qq.com |
import re
# 对于所有的数据集,迭代每一行
# 步骤1:判断OCR是否识别
for row in df.iloc[:].iterrows():
qs = row[1].question_text
# ocr是否识别成功
if not os.path.exists('ocr_result/'+os.path.basename(row[1]['filename'])[:-4] + '.txt'):
continue
# 读取ocr识别结果
ocrs = codecs.open('ocr_result/'+os.path.basename(row[1]['filename'])[:-4] + '.txt').readlines()[:]
# 文字
ocr_text = [x.split('\t')[1].strip() for x in ocrs]
# 文本框
ocr_box = [x.split('\t')[0].strip() for x in ocrs]
if re.findall('什么药品', qs):
# pass
print(row[1].answer_text, '\t', ocr_text[0])
elif re.findall('说明书来源于哪里', qs):
candicate_text = list(set([x for x in ocr_text if '说明书' in x]))
candicate_text = [x for x in candicate_text if '说明书' in x][0]
candicate_text = candicate_text.replace('说明书', '').replace('来源', '').replace(':', '')
print(row[1].answer_text, candicate_text)
pass
elif re.findall('什么大学什么专业', qs):
candicate_text = list(set([x for x in ocr_text if '大学' in x and '专业' in x]))
print(row[1].answer_text, candicate_text[0])
elif re.findall('什么大学', qs):
candicate_text = list(set([x for x in ocr_text if re.findall('大学', x)]))
if len(candicate_text) == 0:
continue
# print(row[1].answer_text, candicate_text[0])
elif re.findall('什么专业', qs):
candicate_text = list(set([x for x in ocr_text if re.findall('本科', x)]))
if len(candicate_text) == 0:
continue
# print(row[1].answer_text, candicate_text[0])
# elif re.findall('电话是多少', qs):
# continue
# # break
elif re.findall('邮箱', qs):
candicate_text = list(set([x for x in ocr_text if re.findall('@', x)]))
if len(candicate_text) == 0:
continue
print(row[1].answer_text, candicate_text[0])
# 没有匹配成功怎么办
# XX之后是什么?,最近的框里面的文本进行回答
# box信息,字的大小信息,字号
# ocr结果
# XX
# YY
# break
# break
模型3.3.2 Bert
LayoutLM: Pre-training of Text and Layout for Document Image Understanding
介绍:LayoutLM利用文本分布的板式信息和识别到的文字信息,基于bert进行大规模预训练,然后在SER和RE任务进行微调;LayoutLMv2在LayoutLM的基础上,将图像视觉信息引入预训练阶段,对多模态信息进行更好的融合;LayoutXLM将LayoutLMv2扩展到多语言。
适用场景:针对卡证、票据等场景的信息提取、关系抽取、文档视觉问答任务。
参考资料:
论文:https://arxiv.org/pdf/1912.13318.pdf
https://huggingface.co/transformers/model_doc/layoutlm.html
更多模型选择请参考:https://github.com/PaddlePaddle/PaddleNLP/blob/develop/docs/transformers.md
Bert模型训练
可以直接使用本项目中转换后的文本数据训练,也可先执行下面命令生成自己的数据,训练( 注意:生成数据时间较长,请耐心等待 )
# !python gen_dataset.py
import paddle
import paddlenlp as ppnlp
from functools import partial
from paddlenlp.data import Stack, Dict, Pad
from utils import prepare_train_features, prepare_validation_features, evaluate
############参数配置###############
# 模型名称
MODEL_NAME = "bert-wwm-chinese"
# 根据官方文档可使用更多的模型,例如:BERT,ERNIE, RoBERTa等,之后可考虑集成的方法
# MODEL_NAME = "roberta-wwm-ext"
# 最大文本长度
max_seq_length = 512
# 文本滑动窗口步幅
doc_stride = 128
# 训练过程中的最大学习率
learning_rate = 2e-5
# 训练轮次
epochs = 12
# 数据批次大小
batch_size = 32 # 根据显存大小更改
# 学习率预热比例
warmup_proportion = 0.1
# 权重衰减系数,类似模型正则项策略,避免模型过拟合
weight_decay = 0.01
#############模型################
# 加载模型
# 请根据模型名称查看官方文档文档更换接口
model = ppnlp.transformers.BertForQuestionAnswering.from_pretrained(MODEL_NAME)
# model = ppnlp.transformers.RobertaForQuestionAnswering.from_pretrained(MODEL_NAME)
# 加载 tokenizer
# 请根据文档更换接口
tokenizer = ppnlp.transformers.BertTokenizer.from_pretrained(MODEL_NAME)
# tokenizer = ppnlp.transformers.RobertaTokenizer.from_pretrained(MODEL_NAME)
#############数据###############
# 加载数据集
# 如果是自己生成的数据请更换为自己数据的路径
train_ds = ppnlp.datasets.load_dataset('dureader_robust', data_files='data/data83268/train.json')
dev_ds = ppnlp.datasets.load_dataset('dureader_robust', data_files='data/data83268/dev.json')
# 数据滑窗处理
train_trans_func = partial(prepare_train_features,
max_seq_length=max_seq_length,
doc_stride=doc_stride,
tokenizer=tokenizer)
train_ds.map(train_trans_func, batched=True)
dev_trans_func = partial(prepare_validation_features,
max_seq_length=max_seq_length,
doc_stride=doc_stride,
tokenizer=tokenizer)
dev_ds.map(dev_trans_func, batched=True)
# 数据读取器配置
train_batch_sampler = paddle.io.DistributedBatchSampler(
train_ds, batch_size=batch_size, shuffle=True)
train_batchify_fn = lambda samples, fn=Dict({
"input_ids": Pad(axis=0, pad_val=tokenizer.pad_token_id),
"token_type_ids": Pad(axis=0, pad_val=tokenizer.pad_token_type_id),
"start_positions": Stack(dtype="int64"),
"end_positions": Stack(dtype="int64")
}): fn(samples)
train_data_loader = paddle.io.DataLoader(
dataset=train_ds,
batch_sampler=train_batch_sampler,
collate_fn=train_batchify_fn,
return_list=True)
dev_batch_sampler = paddle.io.BatchSampler(
dev_ds, batch_size=batch_size, shuffle=False)
dev_batchify_fn = lambda samples, fn=Dict({
"input_ids": Pad(axis=0, pad_val=tokenizer.pad_token_id),
"token_type_ids": Pad(axis=0, pad_val=tokenizer.pad_token_type_id)
}): fn(samples)
dev_data_loader = paddle.io.DataLoader(
dataset=dev_ds,
batch_sampler=dev_batch_sampler,
collate_fn=dev_batchify_fn,
return_list=True)
#############优化器配置#############
# 学习率策略
num_training_steps = len(train_data_loader) * epochs
lr_scheduler = ppnlp.transformers.LinearDecayWithWarmup(learning_rate, num_training_steps, warmup_proportion)
# Generate parameter names needed to perform weight decay.
# All bias and LayerNorm parameters are excluded.
decay_params = [
p.name for n, p in model.named_parameters()
if not any(nd in n for nd in ["bias", "norm"])
]
# 设置优化器
optimizer = paddle.optimizer.AdamW(
learning_rate=lr_scheduler,
parameters=model.parameters(),
weight_decay=weight_decay,
apply_decay_param_fun=lambda x: x in decay_params)
#############损失函数################
class CrossEntropyLossForSQuAD(paddle.nn.Layer):
def __init__(self):
super(CrossEntropyLossForSQuAD, self).__init__()
def forward(self, y, label):
start_logits, end_logits = y # both shape are [batch_size, seq_len]
start_position, end_position = label
start_position = paddle.unsqueeze(start_position, axis=-1)
end_position = paddle.unsqueeze(end_position, axis=-1)
start_loss = paddle.nn.functional.softmax_with_cross_entropy(
logits=start_logits, label=start_position, soft_label=False)
start_loss = paddle.mean(start_loss)
end_loss = paddle.nn.functional.softmax_with_cross_entropy(
logits=end_logits, label=end_position, soft_label=False)
end_loss = paddle.mean(end_loss)
loss = (start_loss + end_loss) / 2
return loss
#############模型训练################
# 实例化 loss
criterion = CrossEntropyLossForSQuAD()
global_step = 0
# 训练
for epoch in range(1, epochs + 1):
for step, batch in enumerate(train_data_loader, start=1):
global_step += 1
input_ids, segment_ids, start_positions, end_positions = batch
logits = model(input_ids=input_ids, token_type_ids=segment_ids)
loss = criterion(logits, (start_positions, end_positions))
if global_step % 100 == 0 :
print("global step %d, epoch: %d, batch: %d, loss: %.5f" % (global_step, epoch, step, loss))
loss.backward()
optimizer.step()
lr_scheduler.step()
optimizer.clear_grad()
evaluate(model=model, data_loader=dev_data_loader)
# 保存
model.save_pretrained('/home/aistudio/checkpoint')
tokenizer.save_pretrained('/home/aistudio/checkpoint')
#运行上训练述代码,可以用四卡跑,若内存溢出可以调低batchsize,训练时间较长可以提前终止。
!python train.py
OCR获得测试集文本数据
#注意:如果使用aistudio平台这里需要自行去官网下载数据并上传
#!python gen_test.py
Bert模型预测结果
import paddle
import paddlenlp as ppnlp
from functools import partial
from paddlenlp.data import Dict, Pad
from utils import prepare_validation_features, predict
############参数配置###############
# 模型名称
MODEL_NAME = "bert-wwm-chinese"
# 根据官方文档可使用更多的模型,例如:BERT,ERNIE, RoBERTa等,之后可考虑集成的方法
# MODEL_NAME = "roberta-wwm-ext"
# 最大文本长度
max_seq_length = 512
# 文本滑动窗口步幅
doc_stride = 128
# 训练过程中的最大学习率
learning_rate = 3e-5
# 训练轮次
epochs = 4
# 数据批次大小
batch_size = 32
# 学习率预热比例
warmup_proportion = 0.1
# 权重衰减系数,类似模型正则项策略,避免模型过拟合
weight_decay = 0.01
#############模型################
# 加载模型
model = ppnlp.transformers.BertForQuestionAnswering.from_pretrained("训练得到的checkpoint文件夹")
# model = ppnlp.transformers.RobertaForQuestionAnswering.from_pretrained(MODEL_NAME)
# 更新参数
# state_dict = paddle.load('checkpoints/model_state.pdparams')
# model.state_dict(state_dict)
# 加载 tokenizer
# 请根据文档更换接口
tokenizer = ppnlp.transformers.BertTokenizer.from_pretrained("训练得到的checkpoint文件夹")
# tokenizer = ppnlp.transformers.RobertaTokenizer.from_pretrained(MODEL_NAME)
#############数据###############
# 加载数据集
dev_ds = ppnlp.datasets.load_dataset('dureader_robust', data_files='ocr_result/test.json')
dev_trans_func = partial(prepare_validation_features,
max_seq_length=max_seq_length,
doc_stride=doc_stride,
tokenizer=tokenizer)
dev_ds.map(dev_trans_func, batched=True)
# 数据读取器配置
dev_batch_sampler = paddle.io.BatchSampler(
dev_ds, batch_size=batch_size, shuffle=False)
dev_batchify_fn = lambda samples, fn=Dict({
"input_ids": Pad(axis=0, pad_val=tokenizer.pad_token_id),
"token_type_ids": Pad(axis=0, pad_val=tokenizer.pad_token_type_id)
}): fn(samples)
dev_data_loader = paddle.io.DataLoader(
dataset=dev_ds,
batch_sampler=dev_batch_sampler,
collate_fn=dev_batchify_fn,
return_list=True)
# Generate parameter names needed to perform weight decay.
# All bias and LayerNorm parameters are excluded.
decay_params = [
p.name for n, p in model.named_parameters()
if not any(nd in n for nd in ["bias", "norm"])
]
# 预测
predict(model=model, data_loader=dev_data_loader)
#运行上述推理预测代码
预测
predict(model=model, data_loader=dev_data_loader)
#运行上述推理预测代码
!python infer.py
4.项目总结
本文的保险文本视觉认知问答项目,使用paddle框架将OCR与NLP技术相结合,根据不同文档图片数据集进行OCR识别,对问题和回答进行建模 。针对本次学习收获主要是对语言处理有进一步了解,熟悉规则匹配和nlp的Bert语言模型,后面会在ENRIE、GPT-2等模型做对比实验,结合模型参数和训练策略,选择准确率最高的模型作为最终项目应用。
参考项目:https://aistudio.baidu.com/aistudio/projectdetail/1842470
此文章为搬运
原项目链接
更多推荐
所有评论(0)