基于预训练模型的机器阅读理解

阅读理解是检索问答系统中的重要组成部分,最常见的数据集是单篇章、抽取式阅读理解数据集。

该示例展示了如何使用PaddleNLP快速实现基于预训练模型的机器阅读理解任务。

本示例使用的数据集是Dureaderrobust数据集。对于一个给定的问题q和一个篇章p,根据篇章内容,给出该问题的答案a。数据集中的每个样本,是一个三元组<q, p, a>,例如:

问题 q: 乔丹打了多少个赛季

篇章 p: 迈克尔.乔丹在NBA打了15个赛季。他在84年进入nba,期间在1993年10月6日第一次退役改打棒球,95年3月18日重新回归,在99年1月13日第二次退役,后于2001年10月31日复出,在03年最终退役…

参考答案 a: [‘15个’,‘15个赛季’]

阅读理解模型的鲁棒性是衡量该技术能否在实际应用中大规模落地的重要指标之一。随着当前技术的进步,模型虽然能够在一些阅读理解测试集上取得较好的性能,但在实际应用中,这些模型所表现出的鲁棒性仍然难以令人满意。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-zmhBShdL-1656682591537)(https://ai-studio-static-online.cdn.bcebos.com/67bb14fa9db64908ba03288ea1bcd49fb523f77f8adf43129293411b89a2f2a4)]

本示例使用的Dureaderrobust数据集作为首个关注阅读理解模型鲁棒性的中文数据集,旨在考察模型在真实应用场景中的过敏感性、过稳定性以及泛化能力等问题。

关于该数据集的详细内容,可参考数据集论文,或官方比赛链接

安装说明

  • PaddlePaddle 安装

    本项目依赖于 PaddlePaddle 2.3 及以上版本,请参考 安装指南 进行安装

  • PaddleNLP 安装

    pip install --upgrade paddlenlp -i https://pypi.org/simple
    
  • 环境依赖

    Python的版本要求 3.7+

AI Studio平台默认安装了Paddle和PaddleNLP,并定期更新版本。
如需手动更新Paddle,可参考飞桨安装说明,安装相应环境下最新版飞桨框架。

使用如下命令确保安装最新版PaddleNLP:

!python -m pip install paddlepaddle-gpu==2.3.0.post101 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html
!pip install --upgrade paddlenlp

示例流程

与大多数NLP任务相同,本次机器阅读理解任务的示例展示分为以下四步:

首先我们从数据准备开始。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-BQFnT3dJ-1656682591538)(https://ai-studio-static-online.cdn.bcebos.com/dd30e17318fb48fabb5701fd8a97be8176a1e372dd134cc0826e58cb5401933d)]

数据准备

数据准备流程如下:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-VWFvy2lU-1656682591539)(https://ai-studio-static-online.cdn.bcebos.com/127ff46bf5c342889991438a95aa7158ec06c8852e58410ba7a516db330175f2)]

1. 加载PaddleNLP内置数据集

使用PaddleNLP提供的load_datasetAPI,即可一键完成数据集加载。

from datasets import load_dataset

train_examples = load_dataset('PaddlePaddle/dureader_robust', split="train")
dev_examples = load_dataset('PaddlePaddle/dureader_robust', split="validation")
test_examples = load_dataset('PaddlePaddle/dureader_robust', split="test")


for idx in range(2):
    print(train_examples[idx]['question'])
    print(train_examples[idx]['context'])
    print(train_examples[idx]['answers'])
    print()
Reusing dataset dureader_robust (/home/aistudio/.cache/huggingface/datasets/PaddlePaddle___dureader_robust/plain_text/1.0.0/1cd8a5be26918caf884ea444d2d909d813355d7972530ea00f82fad0962f8e95)
Reusing dataset dureader_robust (/home/aistudio/.cache/huggingface/datasets/PaddlePaddle___dureader_robust/plain_text/1.0.0/1cd8a5be26918caf884ea444d2d909d813355d7972530ea00f82fad0962f8e95)
Reusing dataset dureader_robust (/home/aistudio/.cache/huggingface/datasets/PaddlePaddle___dureader_robust/plain_text/1.0.0/1cd8a5be26918caf884ea444d2d909d813355d7972530ea00f82fad0962f8e95)


仙剑奇侠传3第几集上天界
第35集雪见缓缓张开眼睛,景天又惊又喜之际,长卿和紫萱的仙船驶至,见众人无恙,也十分高兴。众人登船,用尽合力把自身的真气和水分输给她。雪见终于醒过来了,但却一脸木然,全无反应。众人向常胤求助,却发现人世界竟没有雪见的身世纪录。长卿询问清微的身世,清微语带双关说一切上了天界便有答案。长卿驾驶仙船,众人决定立马动身,往天界而去。众人来到一荒山,长卿指出,魔界和天界相连。由魔界进入通过神魔之井,便可登天。众人至魔界入口,仿若一黑色的蝙蝠洞,但始终无法进入。后来花楹发现只要有翅膀便能飞入。于是景天等人打下许多乌鸦,模仿重楼的翅膀,制作数对翅膀状巨物。刚佩戴在身,便被吸入洞口。众人摔落在地,抬头发现魔界守卫。景天和众魔套交情,自称和魔尊重楼相熟,众魔不理,打了起来。
{'text': ['第35集'], 'answer_start': [0]}

燃气热水器哪个牌子好
选择燃气热水器时,一定要关注这几个问题:1、出水稳定性要好,不能出现忽热忽冷的现象2、快速到达设定的需求水温3、操作要智能、方便4、安全性要好,要装有安全报警装置 市场上燃气热水器品牌众多,购买时还需多加对比和仔细鉴别。方太今年主打的磁化恒温热水器在使用体验方面做了全面升级:9秒速热,可快速进入洗浴模式;水温持久稳定,不会出现忽热忽冷的现象,并通过水量伺服技术将出水温度精确控制在±0.5℃,可满足家里宝贝敏感肌肤洗护需求;配备CO和CH4双气体报警装置更安全(市场上一般多为CO单气体报警)。另外,这款热水器还有智能WIFI互联功能,只需下载个手机APP即可用手机远程操作热水器,实现精准调节水温,满足家人多样化的洗浴需求。当然方太的磁化恒温系列主要的是增加磁化功能,可以有效吸附水中的铁锈、铁屑等微小杂质,防止细菌滋生,使沐浴水质更洁净,长期使用磁化水沐浴更利于身体健康。
{'text': ['方太'], 'answer_start': [110]}

关于更多PaddleNLP数据集,请参考数据集列表

如果你想使用自己的数据集文件构建数据集,请参考以内置数据集格式读取本地数据集自定义数据集

2. 加载 paddlenlp.transformers.AutoTokenizer用于数据处理

DuReaderrubust数据集采用SQuAD数据格式,InputFeature使用滑动窗口的方法生成,即一个example可能对应多个InputFeature。

由于文章加问题的文本长度可能大于max_seq_length,答案出现的位置有可能出现在文章最后,所以不能简单的对文章进行截断。

那么对于过长的文章,则采用滑动窗口将文章分成多段,分别与问题组合。再用对应的tokenizer转化为模型可接受的feature。doc_stride参数就是每次滑动的距离。滑动窗口生成InputFeature的过程如下图:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-ZrT21KTY-1656682591541)(https://ai-studio-static-online.cdn.bcebos.com/5776cf9ec00546bca047a0930c8f56a8b64d723e0ff04f269334522954bb7d90)]

本基线中,我们使用的预训练模型是ERNIE,ERNIE对中文数据的处理是以字为单位。PaddleNLP对于各种预训练模型已经内置了相应的tokenizer,指定想要使用的模型名字即可加载对应的tokenizer。

tokenizer的作用是将原始输入文本转化成模型可以接受的输入数据形式。

import paddlenlp
from paddlenlp.transformers import AutoTokenizer

# 设置模型名称
MODEL_NAME = 'ernie-3.0-medium-zh'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

3. 调用map()方法批量处理数据

由于我们传入了lazy=False,所以我们使用load_dataset()自定义的数据集是MapDataset对象。MapDatasetpaddle.io.Dataset的功能增强版本。其内置的map()方法适合用来进行批量数据集处理。

map()方法接受的主要参数是一个用于数据处理的function。正好可以与tokenizer相配合。

以下是本示例中的用法:

from utils import prepare_train_features, prepare_validation_features
from functools import partial

max_seq_length = 512
doc_stride = 128

train_trans_func = partial(prepare_train_features, 
                           max_seq_length=max_seq_length, 
                           doc_stride=doc_stride,
                           tokenizer=tokenizer)


dev_trans_func = partial(prepare_validation_features, 
                           max_seq_length=max_seq_length, 
                           doc_stride=doc_stride,
                           tokenizer=tokenizer)
    
column_names = train_examples.column_names
train_ds = train_examples.map(train_trans_func, batched=True, num_proc=4, remove_columns=column_names)
dev_ds = dev_examples.map(dev_trans_func, batched=True, num_proc=4, remove_columns=column_names)
test_ds = test_examples.map(dev_trans_func, batched=True, num_proc=4, remove_columns=column_names)

dev_ds_for_model = dev_ds.remove_columns(["example_id", "offset_mapping"])
test_ds_for_model = test_ds.remove_columns(["example_id", "offset_mapping"])
for idx in range(2):
    print(train_ds['input_ids'][idx])
    print(train_ds['token_type_ids'][idx])
    print(train_ds['start_positions'][idx])
    print(train_ds['end_positions'][idx])
[1, 1034, 1189, 734, 2003, 241, 284, 131, 553, 271, 28, 125, 280, 2, 131, 1773, 271, 1097, 373, 1427, 1427, 501, 88, 662, 1906, 4, 561, 125, 311, 1168, 311, 692, 46, 430, 4, 84, 2073, 14, 1264, 3967, 5, 1034, 1020, 1829, 268, 4, 373, 539, 8, 154, 5210, 4, 105, 167, 59, 69, 685, 12043, 539, 8, 883, 1020, 4, 29, 720, 95, 90, 427, 67, 262, 5, 384, 266, 14, 101, 59, 789, 416, 237, 12043, 1097, 373, 616, 37, 1519, 93, 61, 15, 4, 255, 535, 7, 1529, 619, 187, 4, 62, 154, 451, 149, 12043, 539, 8, 253, 223, 3679, 323, 523, 4, 535, 34, 87, 8, 203, 280, 1186, 340, 9, 1097, 373, 5, 262, 203, 623, 704, 12043, 84, 2073, 1137, 358, 334, 702, 5, 262, 203, 4, 334, 702, 405, 360, 653, 129, 178, 7, 568, 28, 15, 125, 280, 518, 9, 1179, 487, 12043, 84, 2073, 1621, 1829, 1034, 1020, 4, 539, 8, 448, 91, 202, 466, 70, 262, 4, 638, 125, 280, 83, 299, 12043, 539, 8, 61, 45, 7, 1537, 176, 4, 84, 2073, 288, 39, 4, 889, 280, 14, 125, 280, 156, 538, 12043, 190, 889, 280, 71, 109, 124, 93, 292, 889, 46, 1248, 4, 518, 48, 883, 125, 12043, 539, 8, 268, 889, 280, 109, 270, 4, 1586, 845, 7, 669, 199, 5, 3964, 3740, 1084, 4, 255, 440, 616, 154, 72, 71, 109, 12043, 49, 61, 283, 3591, 34, 87, 297, 41, 9, 1993, 2602, 518, 52, 706, 109, 12043, 37, 10, 561, 125, 43, 8, 445, 86, 576, 65, 1448, 2969, 4, 469, 1586, 118, 776, 5, 1993, 2602, 4, 108, 25, 179, 51, 1993, 2602, 498, 1052, 122, 12043, 1082, 1994, 1616, 11, 262, 4, 518, 171, 813, 109, 1084, 270, 12043, 539, 8, 3006, 580, 11, 31, 4, 2473, 306, 34, 87, 889, 280, 846, 573, 12043, 561, 125, 14, 539, 889, 810, 276, 182, 4, 67, 351, 14, 889, 1182, 118, 776, 156, 952, 4, 539, 889, 16, 38, 4, 445, 15, 200, 61, 12043, 2]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
14
16
[1, 1404, 266, 506, 101, 361, 1256, 27, 664, 85, 170, 2, 352, 790, 1404, 266, 506, 101, 361, 36, 4, 7, 91, 41, 129, 490, 47, 553, 27, 358, 281, 74, 208, 6, 39, 101, 862, 91, 92, 41, 170, 4, 16, 52, 39, 87, 1745, 506, 1745, 888, 5, 87, 528, 249, 6, 532, 537, 45, 302, 94, 91, 5, 413, 323, 101, 565, 284, 6, 868, 25, 41, 826, 52, 6, 58, 518, 397, 6, 204, 62, 92, 41, 170, 4, 41, 371, 9, 204, 62, 337, 1023, 371, 521, 99, 191, 28, 1404, 266, 506, 101, 361, 100, 664, 539, 65, 4, 817, 1042, 36, 201, 413, 65, 120, 51, 277, 14, 2081, 541, 1190, 348, 12043, 58, 512, 508, 17, 57, 445, 5, 1512, 73, 1664, 565, 506, 101, 361, 11, 175, 29, 82, 412, 58, 76, 388, 15, 62, 76, 658, 222, 74, 701, 1866, 537, 506, 4, 48, 532, 537, 71, 109, 1123, 1600, 469, 220, 12048, 101, 565, 303, 876, 862, 91, 4, 16, 32, 39, 87, 1745, 506, 1745, 888, 5, 87, 528, 4, 145, 124, 93, 101, 150, 3466, 231, 164, 133, 174, 39, 101, 565, 130, 326, 524, 586, 108, 11, 18010, 9479, 42, 39979, 4, 48, 596, 581, 50, 155, 707, 1358, 1443, 345, 1455, 1411, 1123, 455, 413, 323, 12048, 483, 366, 4850, 14, 6215, 9488, 653, 266, 82, 337, 1023, 371, 521, 263, 204, 62, 78, 99, 191, 28, 7, 689, 65, 13, 4850, 269, 266, 82, 337, 1023, 77, 12043, 770, 137, 4, 47, 699, 506, 101, 361, 201, 9, 826, 52, 4177, 756, 387, 369, 52, 4, 297, 413, 86, 763, 27, 247, 98, 3887, 444, 48, 29, 247, 98, 629, 163, 868, 25, 506, 101, 361, 4, 79, 87, 326, 378, 290, 377, 101, 565, 4, 596, 581, 50, 8, 65, 314, 73, 5, 1123, 1600, 413, 323, 12043, 153, 187, 58, 512, 5, 1512, 73, 1664, 565, 135, 517, 57, 41, 5, 10, 385, 120, 1512, 73, 369, 52, 4, 48, 22, 9, 344, 813, 912, 101, 12, 5, 754, 2337, 6, 754, 2880, 43, 702, 96, 792, 207, 4, 510, 735, 541, 1101, 1989, 21, 4, 175, 2873, 1600, 101, 207, 263, 1308, 1158, 4, 84, 195, 175, 29, 1512, 73, 101, 2873, 1600, 263, 217, 37, 262, 82, 691, 736, 12043, 2]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
121
122

从以上结果可以看出,数据集中的example已经被转换成了模型可以接收的feature,包括input_ids、token_type_ids、答案的起始位置等信息。
其中:

  • input_ids: 表示输入文本的token ID。
  • token_type_ids: 表示对应的token属于输入的问题还是答案。(Transformer类预训练模型支持单句以及句对输入)。
  • overflow_to_sample: feature对应的example的编号。
  • offset_mapping: 每个token的起始字符和结束字符在原文中对应的index(用于生成答案文本)。
  • start_positions: 答案在这个feature中的开始位置。
  • end_positions: 答案在这个feature中的结束位置。

数据处理的详细过程请参见utils.py

更多有关数据处理的内容,请参考数据处理

4. Batchify和数据读入

使用paddle.io.BatchSamplerpaddlenlp.data中提供的方法把数据组成batch。

然后使用paddle.io.DataLoader接口多线程异步加载数据。

batchify_fn详解:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-XU2g6HWb-1656682591542)(https://ai-studio-static-online.cdn.bcebos.com/30e43d4659384375a2a2c1b890ca5a995c4324d7168e49cebf1d2a1e99161f7d)]

import paddle
from paddlenlp.data import DataCollatorWithPadding

batch_size = 12

# 定义BatchSampler
train_batch_sampler = paddle.io.DistributedBatchSampler(
        train_ds, batch_size=batch_size, shuffle=True)

dev_batch_sampler = paddle.io.BatchSampler(
    dev_ds, batch_size=batch_size, shuffle=False)

test_batch_sampler = paddle.io.BatchSampler(
    test_ds, batch_size=batch_size, shuffle=False)

# 定义batchify_fn

train_batchify_fn = DataCollatorWithPadding(tokenizer)

dev_batchify_fn = DataCollatorWithPadding(tokenizer)

# 构造DataLoader
train_data_loader = paddle.io.DataLoader(
    dataset=train_ds,
    batch_sampler=train_batch_sampler,
    collate_fn=train_batchify_fn,
    return_list=True)

dev_data_loader = paddle.io.DataLoader(
    dataset=dev_ds_for_model,
    batch_sampler=dev_batch_sampler,
    collate_fn=dev_batchify_fn,
    return_list=True)

test_data_loader = paddle.io.DataLoader(
    dataset=test_ds_for_model,
    batch_sampler=test_batch_sampler,
    collate_fn=dev_batchify_fn,
    return_list=True)

更多PaddleNLP内置的batchify相关API,请参考collate

到这里数据集准备就全部完成了,下一步我们需要组网并设计loss function。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-QKjSccis-1656682591543)(https://ai-studio-static-online.cdn.bcebos.com/fdcb44a00ede4ce08ae2652931556fb58cc903f686bf491792489353d2800e7d)]

模型结构

使用PaddleNLP一键加载预训练模型

以下项目以ERNIE为例,介绍如何将预训练模型Fine-tune完成DuReaderrobust阅读理解任务。

DuReaderrobust阅读理解任务的本质是答案抽取任务。根据输入的问题和文章,从预训练模型的sequence_output中预测答案在文章中的起始位置和结束位置。原理如下图所示:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-jhM1dZ2Y-1656682591545)(https://ai-studio-static-online.cdn.bcebos.com/bb1396fc12614dbabcfb4fcfafe9346507d4d65d0a194d75aba04b9d31bace6b)]

目前PaddleNLP已经内置了包括ERNIE在内的多种基于预训练模型的常用任务的下游网络,包括机器阅读理解。

这些网络在paddlenlp.transformers下,均可实现一键调用。

from paddlenlp.transformers import AutoModelForQuestionAnswering

model = AutoModelForQuestionAnswering.from_pretrained(MODEL_NAME)
[2022-06-28 16:37:22,959] [    INFO] - We are using <class 'paddlenlp.transformers.ernie.modeling.ErnieForQuestionAnswering'> to load 'ernie-3.0-base-zh'.
[2022-06-28 16:37:22,960] [    INFO] - Already cached /home/aistudio/.paddlenlp/models/ernie-3.0-base-zh/ernie_3.0_base_zh.pdparams
W0628 16:37:22.963842 20601 gpu_context.cc:278] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 10.1
W0628 16:37:22.967849 20601 gpu_context.cc:306] device: 0, cuDNN Version: 7.6.

设计loss function

模型的网络结构确定后我们就可以设计loss function了。

AutoModelForQuestionAnswering模型对将ErnieModel的sequence_output拆开成start_logits和end_logits输出,所以DuReaderrobust的loss由start_loss和end_loss两部分组成,我们需要自己定义loss function。

对于答案起始位置和结束位置的预测可以分别看成两个分类任务。所以设计的loss function如下:

class CrossEntropyLossForRobust(paddle.nn.Layer):
    def __init__(self):
        super(CrossEntropyLossForRobust, self).__init__()

    def forward(self, y, label):
        start_logits, end_logits = y
        start_position, end_position = label
        start_position = paddle.unsqueeze(start_position, axis=-1)
        end_position = paddle.unsqueeze(end_position, axis=-1)
        start_loss = paddle.nn.functional.cross_entropy(
            input=start_logits, label=start_position)
        end_loss = paddle.nn.functional.cross_entropy(
            input=end_logits, label=end_position)
        loss = (start_loss + end_loss) / 2
        return loss

选择网络结构后,我们需要设置Fine-Tune优化策略。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-j3o1h1n7-1656682591546)(https://ai-studio-static-online.cdn.bcebos.com/7eca6595f338409498149cb586c077ba4933739810cf436080a2292be7e0a92d)]

设置Fine-Tune优化策略

适用于ERNIE/BERT这类Transformer模型的学习率为warmup的动态学习率。



图3:动态学习率示意图
# 训练过程中的最大学习率
learning_rate = 3e-5 

# 训练轮次
epochs = 2

# 学习率预热比例
warmup_proportion = 0.1

# 权重衰减系数,类似模型正则项策略,避免模型过拟合
weight_decay = 0.01

num_training_steps = len(train_data_loader) * epochs

# 学习率衰减策略
lr_scheduler = paddlenlp.transformers.LinearDecayWithWarmup(learning_rate, num_training_steps, warmup_proportion)

decay_params = [
    p.name for n, p in model.named_parameters()
    if not any(nd in n for nd in ["bias", "norm"])
]
optimizer = paddle.optimizer.AdamW(
    learning_rate=lr_scheduler,
    parameters=model.parameters(),
    weight_decay=weight_decay,
    apply_decay_param_fun=lambda x: x in decay_params)

现在万事俱备,我们可以开始训练阅读理解模型啦。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Hi54nGKW-1656682591547)(https://ai-studio-static-online.cdn.bcebos.com/6975542d488f4f75b385fe75d574a3aaa8e208f5e99f4acd8a8e8aea3b85c058)]

模型训练与评估

模型训练的过程通常有以下步骤:

  1. 从dataloader中取出一个batch data。
  2. 将batch data喂给model,做前向计算。
  3. 将前向计算结果传给损失函数,计算loss。
  4. loss反向回传,更新梯度。重复以上步骤。

每训练一个epoch时,程序通过evaluate()调用paddlenlp.metric.squad中的squad_evaluate(), compute_predictions()评估当前模型训练的效果,其中:

  • compute_predictions()用于生成可提交的答案;

  • squad_evaluate()用于返回评价指标。

二者适用于所有符合squad数据格式的答案抽取任务。这类任务使用F1和exact来评估预测的答案和真实答案的相似程度。

from utils import evaluate

criterion = CrossEntropyLossForRobust()
global_step = 0
for epoch in range(1, epochs + 1):
    for step, batch in enumerate(train_data_loader, start=1):
        global_step += 1
        input_ids, segment_ids, start_positions, end_positions = batch
        logits = model(input_ids=batch["input_ids"], token_type_ids=batch["token_type_ids"])
        loss = criterion(logits, (batch["start_positions"], batch["end_positions"]))

        if global_step % 100 == 0 :
            print("global step %d, epoch: %d, batch: %d, loss: %.5f" % (global_step, epoch, step, loss))

        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.clear_grad()
global step 100, epoch: 1, batch: 100, loss: 4.31947
global step 200, epoch: 1, batch: 200, loss: 1.50273
global step 300, epoch: 1, batch: 300, loss: 1.38390
global step 400, epoch: 1, batch: 400, loss: 1.27645
global step 500, epoch: 1, batch: 500, loss: 1.43703
global step 600, epoch: 1, batch: 600, loss: 1.30268
global step 700, epoch: 1, batch: 700, loss: 1.06415
global step 800, epoch: 1, batch: 800, loss: 2.09669
global step 900, epoch: 1, batch: 900, loss: 0.87356
global step 1000, epoch: 1, batch: 1000, loss: 1.44776
global step 1100, epoch: 1, batch: 1100, loss: 0.77803
global step 1200, epoch: 1, batch: 1200, loss: 1.50059
global step 1300, epoch: 1, batch: 1300, loss: 1.02779
global step 1400, epoch: 1, batch: 1400, loss: 1.14965
global step 1500, epoch: 2, batch: 29, loss: 1.23338
global step 1600, epoch: 2, batch: 129, loss: 1.45627
global step 1700, epoch: 2, batch: 229, loss: 1.28679
global step 1800, epoch: 2, batch: 329, loss: 0.83966
global step 1900, epoch: 2, batch: 429, loss: 1.11859
global step 2000, epoch: 2, batch: 529, loss: 0.92468
global step 2100, epoch: 2, batch: 629, loss: 0.58391
global step 2200, epoch: 2, batch: 729, loss: 0.90198
global step 2300, epoch: 2, batch: 829, loss: 1.04647
global step 2400, epoch: 2, batch: 929, loss: 1.01556
global step 2500, epoch: 2, batch: 1029, loss: 0.93802
global step 2600, epoch: 2, batch: 1129, loss: 0.69602
global step 2700, epoch: 2, batch: 1229, loss: 1.43215
global step 2800, epoch: 2, batch: 1329, loss: 0.87333
global step 2900, epoch: 2, batch: 1429, loss: 0.64515
# 传入test_data_loader,并将is_test参数设为True,即可生成千言比赛可提交的结果。
evaluate(model=model, raw_dataset=test_examples, dataset=test_ds, data_loader=test_data_loader, is_test=True) 
Processing example: 1000
time per 1000: 10.584674596786499
Processing example: 2000
time per 1000: 10.188016891479492
Processing example: 3000
time per 1000: 9.92471432685852
Processing example: 4000
time per 1000: 10.358511924743652
Processing example: 5000
time per 1000: 10.093605041503906
Processing example: 6000
time per 1000: 10.080277681350708
Processing example: 7000
time per 1000: 9.987485647201538
Processing example: 8000
time per 1000: 10.222150325775146
Processing example: 9000
time per 1000: 10.706721782684326
Processing example: 10000
time per 1000: 10.512943506240845
Processing example: 11000
time per 1000: 10.303147554397583
Processing example: 12000
time per 1000: 10.55568528175354
Processing example: 13000
time per 1000: 10.26531982421875
Processing example: 14000
time per 1000: 10.266725778579712
Processing example: 15000
time per 1000: 10.64447546005249
Processing example: 16000
time per 1000: 9.920329093933105
Processing example: 17000
time per 1000: 10.208895206451416
Processing example: 18000
time per 1000: 10.26574993133545
Processing example: 19000
time per 1000: 9.914969682693481
Processing example: 20000
time per 1000: 10.20920467376709
Processing example: 21000
time per 1000: 10.57719898223877
Processing example: 22000
time per 1000: 10.71744441986084
Processing example: 23000
time per 1000: 10.45024824142456
Processing example: 24000
time per 1000: 10.221151113510132
Processing example: 25000
time per 1000: 10.319378137588501
Processing example: 26000
time per 1000: 10.167399168014526
Processing example: 27000
time per 1000: 10.86082124710083
Processing example: 28000
time per 1000: 10.394505500793457
Processing example: 29000
time per 1000: 10.35581088066101
Processing example: 30000
time per 1000: 10.169897556304932
Processing example: 31000
time per 1000: 9.831788301467896
Processing example: 32000
time per 1000: 10.149449348449707
Processing example: 33000
time per 1000: 10.106878757476807
Processing example: 34000
time per 1000: 10.475741624832153
Processing example: 35000
time per 1000: 9.970959901809692
Processing example: 36000
time per 1000: 10.39727783203125
Processing example: 37000
time per 1000: 9.979295015335083
Processing example: 38000
time per 1000: 10.217779397964478
Processing example: 39000
time per 1000: 10.340791702270508
Processing example: 40000
time per 1000: 10.340667247772217
Processing example: 41000
time per 1000: 10.436401605606079
Processing example: 42000
time per 1000: 10.371085405349731
Processing example: 43000
time per 1000: 10.644593238830566
Processing example: 44000
time per 1000: 10.25355052947998
Processing example: 45000
time per 1000: 10.114908933639526
Processing example: 46000
time per 1000: 10.181347131729126
Processing example: 47000
time per 1000: 10.038010835647583
Processing example: 48000
time per 1000: 10.707454204559326
Processing example: 49000
time per 1000: 10.093716859817505
Processing example: 50000
time per 1000: 9.750082731246948
Processing example: 51000
time per 1000: 10.104126214981079
Processing example: 52000
time per 1000: 10.051070213317871
Processing example: 53000
time per 1000: 9.974929094314575
Processing example: 54000
time per 1000: 10.26790714263916
Processing example: 55000
time per 1000: 10.284163236618042
Processing example: 56000
time per 1000: 10.124592542648315
Processing example: 57000
time per 1000: 10.628040075302124
Processing example: 58000
time per 1000: 10.414206981658936
Processing example: 59000
time per 1000: 10.146691083908081
Processing example: 60000
time per 1000: 10.072739124298096
Processing example: 61000
time per 1000: 9.900139808654785
Processing example: 62000
time per 1000: 10.332890510559082
Processing example: 63000
time per 1000: 10.174994707107544
Processing example: 64000
time per 1000: 10.424129486083984
Processing example: 65000
time per 1000: 10.393558740615845
Processing example: 66000
time per 1000: 10.283146142959595
Processing example: 67000
time per 1000: 9.896865606307983
Processing example: 68000
time per 1000: 9.99543833732605
Processing example: 69000
time per 1000: 10.295023441314697

问题: 220v一安等于多少瓦
原文: 在220交流电的状态下一安等于220瓦.基于32太1.5匹用多大的开关计算方法是:1匹=0.735瓦. 0.735*1.5*32=35.28千瓦  1千瓦=4.5安  35.28*4.5=158.26安  这儿是实际的电流,在现实应用过程中不能用160安的  开关,单个1.5匹启动时有一个较大的启动电流,在实际用用是乘以105倍:158.26*1.5=237.39安. 开关的电流应该是250A的空气开关.
答案: 220瓦

问题: 氧化铜和稀盐酸的离子方程式
原文: 化学方程式:CuO+2HCl=CuCl2+H2O 书写离子方程式时,只有强电解质(强酸、强碱、盐)拆开写成离子形式. 离子方程式: CuO+2H+=Cu^2+ +H2O
答案: CuO+2H+=Cu^2+ +H2O

问题: 刀塔传奇98元英雄排行
原文: 让我们把目光放到刀塔传奇中去,看看那些在竞技场的刀山火海中异军突起的英雄,很多只是得益于一个小小的改动,却改变了竞技场的整个格局。http://www.18183.com/dtcq/syzs/wanjiayc/164981.html|1艾吉奥梦境打架都可以2科学怪人梦境团本高分必备3魔像可以打吸血鬼一个梦境4白银很看好她期待她的觉醒5凹凸曼打猴子,其他可有可无
答案: 艾吉奥

问题: 新款汉兰达什么时候上市
原文: 新一代丰田汉兰达虽然早在2013年的纽约车展上已经正式亮相,但依然迟迟未见在国内出现。据目前获得的最新消息称,现款汉兰达将于2015年1月正式停产,新一代车型预计将于2015年一季度末投产,并于2015年上半年正式亮相。目前广汽丰田新一代汉兰达申报图已经在网上曝光。
答案: 2015年上半年

问题: 一立方分米等于多少立方米
原文: 正方体用米作单位计算:  1*1*1=1立方米 把米化成分米做单位就是10分米,用分米作单位计算:  10*10*10=1000立方分米 所以,1立方米=1000立方分米
答案: 1000立方分米

问题: 橙子和什么一起榨汁好喝
原文: 醇厚:香蕉,木瓜,火龙果,草莓,芒果,杏子。 清新:柠檬,黄瓜,雪梨,西瓜,柚子,芹菜,哈密瓜难以区别:西红柿,苹果,橙子,桃子,猕猴桃,菠萝,柚子,樱桃 很多人都试过:苹果+橙子+雪梨;雪梨+黄瓜;柚子+苹果;西红柿+柠檬+柚子
答案: 雪梨

以上基线实现基于PaddleNLP,开源不易,希望大家多多支持~
记得给PaddleNLP点个小小的Star⭐

GitHub地址:https://github.com/PaddlePaddle/PaddleNLP
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-YialCNdt-1656682591549)(https://ai-studio-static-online.cdn.bcebos.com/5970708b1c584e73bff9f8937685f5477f7fed902a164cd184bffb073307ce61)]

更多使用方法可参考PaddleNLP教程

加入交流群,一起学习吧

现在就加入PaddleNLP的QQ技术交流群,一起交流NLP技术吧!

image

作者仅为AiStudio搬运,原项目链接:https://aistudio.baidu.com/aistudio/projectdetail/2017189
/1329361)

加入交流群,一起学习吧

现在就加入PaddleNLP的QQ技术交流群,一起交流NLP技术吧!

image

作者仅为AiStudio搬运,原项目链接:https://aistudio.baidu.com/aistudio/projectdetail/2017189

Logo

学大模型,用大模型上飞桨星河社区!每天8点V100G算力免费领!免费领取ERNIE 4.0 100w Token >>>

更多推荐