基于BERT预训练模型的阅读理解任务

机器阅读理解是自然语言处理中的一个重要的任务,最常见的有单篇章的抽取式阅读理解。机器阅读理解的应用范围很广,比如客服机器人,通过文字或者语音与用户进行沟通交流,然后获取相关的信息并提供准确可靠的回答。搜索引擎中精确返回用户所给定问题的答案。在医疗领域中自动阅读病人的资料来找到相应的病因。在教育领域中,利用阅读理解模型自动为学生的作文给出改进意见等等。


学习资源


⭐ ⭐ ⭐ 欢迎点个小小的Star,开源不易,希望大家多多支持~⭐ ⭐ ⭐

一、方案设计

阅读理解的方案如上图,首先是query表示的是问句,一般是用户的提问,passage表示的是文章,表示的是query的答案要从passage里面抽取出来,query和passage经过数据预处理,得到id形式的输入,然后把query,passage的id形式输入到BERT模型里面,BERT模型经过处理会输出答案的位置,输出位置以后就可以得到相应的答案了。

二、 数据处理

具体的任务定义为:对于一个给定的问题q和一个篇章p,根据篇章内容,给出该问题的答案a。数据集中的每个样本,是一个三元组<q, p, a>,例如:

问题 q: 乔丹打了多少个赛季

篇章 p: 迈克尔.乔丹在NBA打了15个赛季。他在84年进入nba,期间在1993年10月6日第一次退役改打棒球,95年3月18日重新回归,在99年1月13日第二次退役,后于2001年10月31日复出,在03年最终退役…

参考答案 a: [‘15个’,‘15个赛季’]

阅读理解模型的鲁棒性是衡量该技术能否在实际应用中大规模落地的重要指标之一。随着当前技术的进步,模型虽然能够在一些阅读理解测试集上取得较好的性能,但在实际应用中,这些模型所表现出的鲁棒性仍然难以令人满意。本示例使用的DuReader-robust数据集作为首个关注阅读理解模型鲁棒性的中文数据集,旨在考察模型在真实应用场景中的过敏感性、过稳定性以及泛化能力等问题。

关于该数据集的详细内容,可参考数据集论文,或官方比赛链接

首先导入实验所需要用到的库包。

from paddlenlp.datasets import load_dataset
import paddlenlp as ppnlp
from utils import prepare_train_features, prepare_validation_features
from functools import partial
from paddlenlp.metrics.squad import squad_evaluate, compute_prediction

import collections
import time
import json

2.1 数据集加载

PaddleNLP已经内置SQuAD,CMRC等中英文阅读理解数据集,使用paddlenlp.datasets.load_dataset()API即可一键加载。本实例加载的是DuReaderRobust中文阅读理解数据集。由于DuReaderRobust数据集采用SQuAD数据格式,InputFeature使用滑动窗口的方法生成,即一个example可能对应多个InputFeature。

答案抽取任务即根据输入的问题和文章,预测答案在文章中的起始位置和结束位置。

由于文章加问题的文本长度可能大于max_seq_length,答案出现的位置有可能出现在文章最后,所以不能简单的对文章进行截断。

那么对于过长的文章,则采用滑动窗口将文章分成多段,分别与问题组合。再用对应的tokenizer转化为模型可接受的feature。doc_stride参数就是每次滑动的距离。滑动窗口生成InputFeature的过程如下图:


图2:滑动窗口生成InputFeature示意图
train_ds, dev_ds = ppnlp.datasets.load_dataset('dureader_robust', splits=('train', 'dev'))

for idx in range(2):
    print(train_ds[idx]['question'])
    print(train_ds[idx]['context'])
    print(train_ds[idx]['answers'])
    print(train_ds[idx]['answer_starts'])
    print()
2021-10-29 15:21:06,944 - INFO - unique_endpoints {''}


仙剑奇侠传3第几集上天界
第35集雪见缓缓张开眼睛,景天又惊又喜之际,长卿和紫萱的仙船驶至,见众人无恙,也十分高兴。众人登船,用尽合力把自身的真气和水分输给她。雪见终于醒过来了,但却一脸木然,全无反应。众人向常胤求助,却发现人世界竟没有雪见的身世纪录。长卿询问清微的身世,清微语带双关说一切上了天界便有答案。长卿驾驶仙船,众人决定立马动身,往天界而去。众人来到一荒山,长卿指出,魔界和天界相连。由魔界进入通过神魔之井,便可登天。众人至魔界入口,仿若一黑色的蝙蝠洞,但始终无法进入。后来花楹发现只要有翅膀便能飞入。于是景天等人打下许多乌鸦,模仿重楼的翅膀,制作数对翅膀状巨物。刚佩戴在身,便被吸入洞口。众人摔落在地,抬头发现魔界守卫。景天和众魔套交情,自称和魔尊重楼相熟,众魔不理,打了起来。
['第35集']
[0]

燃气热水器哪个牌子好
选择燃气热水器时,一定要关注这几个问题:1、出水稳定性要好,不能出现忽热忽冷的现象2、快速到达设定的需求水温3、操作要智能、方便4、安全性要好,要装有安全报警装置 市场上燃气热水器品牌众多,购买时还需多加对比和仔细鉴别。方太今年主打的磁化恒温热水器在使用体验方面做了全面升级:9秒速热,可快速进入洗浴模式;水温持久稳定,不会出现忽热忽冷的现象,并通过水量伺服技术将出水温度精确控制在±0.5℃,可满足家里宝贝敏感肌肤洗护需求;配备CO和CH4双气体报警装置更安全(市场上一般多为CO单气体报警)。另外,这款热水器还有智能WIFI互联功能,只需下载个手机APP即可用手机远程操作热水器,实现精准调节水温,满足家人多样化的洗浴需求。当然方太的磁化恒温系列主要的是增加磁化功能,可以有效吸附水中的铁锈、铁屑等微小杂质,防止细菌滋生,使沐浴水质更洁净,长期使用磁化水沐浴更利于身体健康。
['方太']
[110]

ppnlp.transformers.BertTokenizer

调用BertTokenizer进行数据处理。

预训练模型Bert对中文数据的处理是以字为单位。PaddleNLP对于各种预训练模型已经内置了相应的tokenizer,指定想要使用的模型名字即可加载对应的tokenizer。

tokenizer的作用是将原始输入文本转化成模型可以接受的输入数据形式。

MODEL_NAME = "bert-base-chinese"
tokenizer = ppnlp.transformers.BertTokenizer.from_pretrained(MODEL_NAME)
[2021-10-29 15:21:07,358] [    INFO] - Found /home/aistudio/.paddlenlp/models/bert-base-chinese/bert-base-chinese-vocab.txt

2.2 数据处理

使用load_dataset()API默认读取到的数据集是MapDataset对象,MapDatasetpaddle.io.Dataset的功能增强版本。其内置的map()方法适合用来进行批量数据集处理。map()方法传入的是一个用于数据处理的function。
以下是Dureader-Robust中数据转化的用法:

max_seq_length = 512
doc_stride = 128

train_trans_func = partial(prepare_train_features, 
                           max_seq_length=max_seq_length, 
                           doc_stride=doc_stride,
                           tokenizer=tokenizer)

train_ds.map(train_trans_func, batched=True)

dev_trans_func = partial(prepare_validation_features, 
                           max_seq_length=max_seq_length, 
                           doc_stride=doc_stride,
                           tokenizer=tokenizer)
                           
dev_ds.map(dev_trans_func, batched=True)
<paddlenlp.datasets.dataset.MapDataset at 0x7fe07c6d7f10>
for idx in range(2):
    print(train_ds[idx]['input_ids'])
    print(train_ds[idx]['token_type_ids'])
    print(train_ds[idx]['overflow_to_sample'])
    print(train_ds[idx]['offset_mapping'])
    print(train_ds[idx]['start_positions'])
    print(train_ds[idx]['end_positions'])
    print()
[101, 803, 1187, 1936, 899, 837, 124, 5018, 1126, 7415, 677, 1921, 4518, 102, 5018, 8198, 7415, 7434, 6224, 5353, 5353, 2476, 2458, 4706, 4714, 8024, 3250, 1921, 1348, 2661, 1348, 1599, 722, 7354, 8024, 7270, 1321, 1469, 5166, 5858, 4638, 803, 5670, 7724, 5635, 8024, 6224, 830, 782, 3187, 2610, 8024, 738, 1282, 1146, 7770, 1069, 511, 830, 782, 4633, 5670, 8024, 4500, 2226, 1394, 1213, 2828, 5632, 6716, 4638, 4696, 3698, 1469, 3717, 1146, 6783, 5314, 1961, 511, 7434, 6224, 5303, 754, 7008, 6814, 3341, 749, 8024, 852, 1316, 671, 5567, 3312, 4197, 8024, 1059, 3187, 1353, 2418, 511, 830, 782, 1403, 2382, 5530, 3724, 1221, 8024, 1316, 1355, 4385, 782, 686, 4518, 4994, 3766, 3300, 7434, 6224, 4638, 6716, 686, 5279, 2497, 511, 7270, 1321, 6418, 7309, 3926, 2544, 4638, 6716, 686, 8024, 3926, 2544, 6427, 2372, 1352, 1068, 6432, 671, 1147, 677, 749, 1921, 4518, 912, 3300, 5031, 3428, 511, 7270, 1321, 7730, 7724, 803, 5670, 8024, 830, 782, 1104, 2137, 4989, 7716, 1220, 6716, 8024, 2518, 1921, 4518, 5445, 1343, 511, 830, 782, 3341, 1168, 671, 5774, 2255, 8024, 7270, 1321, 2900, 1139, 8024, 7795, 4518, 1469, 1921, 4518, 4685, 6825, 511, 4507, 7795, 4518, 6822, 1057, 6858, 6814, 4868, 7795, 722, 759, 8024, 912, 1377, 4633, 1921, 511, 830, 782, 5635, 7795, 4518, 1057, 1366, 8024, 820, 5735, 671, 7946, 5682, 4638, 6073, 6075, 3822, 8024, 852, 1993, 5303, 3187, 3791, 6822, 1057, 511, 1400, 3341, 5709, 3516, 1355, 4385, 1372, 6206, 3300, 5420, 5598, 912, 5543, 7607, 1057, 511, 754, 3221, 3250, 1921, 5023, 782, 2802, 678, 6387, 1914, 723, 7887, 8024, 3563, 820, 7028, 3517, 4638, 5420, 5598, 8024, 1169, 868, 3144, 2190, 5420, 5598, 4307, 2342, 4289, 511, 1157, 877, 2785, 1762, 6716, 8024, 912, 6158, 1429, 1057, 3822, 1366, 511, 830, 782, 3035, 5862, 1762, 1765, 8024, 2848, 1928, 1355, 4385, 7795, 4518, 2127, 1310, 511, 3250, 1921, 1469, 830, 7795, 1947, 769, 2658, 8024, 5632, 4917, 1469, 7795, 2203, 7028, 3517, 4685, 4225, 8024, 830, 7795, 679, 4415, 8024, 2802, 749, 6629, 3341, 511, 102]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
0
[(0, 0), (0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8), (8, 9), (9, 10), (10, 11), (11, 12), (0, 0), (0, 1), (1, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8), (8, 9), (9, 10), (10, 11), (11, 12), (12, 13), (13, 14), (14, 15), (15, 16), (16, 17), (17, 18), (18, 19), (19, 20), (20, 21), (21, 22), (22, 23), (23, 24), (24, 25), (25, 26), (26, 27), (27, 28), (28, 29), (29, 30), (30, 31), (31, 32), (32, 33), (33, 34), (34, 35), (35, 36), (36, 37), (37, 38), (38, 39), (39, 40), (40, 41), (41, 42), (42, 43), (43, 44), (44, 45), (45, 46), (46, 47), (47, 48), (48, 49), (49, 50), (50, 51), (51, 52), (52, 53), (53, 54), (54, 55), (55, 56), (56, 57), (57, 58), (58, 59), (59, 60), (60, 61), (61, 62), (62, 63), (63, 64), (64, 65), (65, 66), (66, 67), (67, 68), (68, 69), (69, 70), (70, 71), (71, 72), (72, 73), (73, 74), (74, 75), (75, 76), (76, 77), (77, 78), (78, 79), (79, 80), (80, 81), (81, 82), (82, 83), (83, 84), (84, 85), (85, 86), (86, 87), (87, 88), (88, 89), (89, 90), (90, 91), (91, 92), (92, 93), (93, 94), (94, 95), (95, 96), (96, 97), (97, 98), (98, 99), (99, 100), (100, 101), (101, 102), (102, 103), (103, 104), (104, 105), (105, 106), (106, 107), (107, 108), (108, 109), (109, 110), (110, 111), (111, 112), (112, 113), (113, 114), (114, 115), (115, 116), (116, 117), (117, 118), (118, 119), (119, 120), (120, 121), (121, 122), (122, 123), (123, 124), (124, 125), (125, 126), (126, 127), (127, 128), (128, 129), (129, 130), (130, 131), (131, 132), (132, 133), (133, 134), (134, 135), (135, 136), (136, 137), (137, 138), (138, 139), (139, 140), (140, 141), (141, 142), (142, 143), (143, 144), (144, 145), (145, 146), (146, 147), (147, 148), (148, 149), (149, 150), (150, 151), (151, 152), (152, 153), (153, 154), (154, 155), (155, 156), (156, 157), (157, 158), (158, 159), (159, 160), (160, 161), (161, 162), (162, 163), (163, 164), (164, 165), (165, 166), (166, 167), (167, 168), (168, 169), (169, 170), (170, 171), (171, 172), (172, 173), (173, 174), (174, 175), (175, 176), (176, 177), (177, 178), (178, 179), (179, 180), (180, 181), (181, 182), (182, 183), (183, 184), (184, 185), (185, 186), (186, 187), (187, 188), (188, 189), (189, 190), (190, 191), (191, 192), (192, 193), (193, 194), (194, 195), (195, 196), (196, 197), (197, 198), (198, 199), (199, 200), (200, 201), (201, 202), (202, 203), (203, 204), (204, 205), (205, 206), (206, 207), (207, 208), (208, 209), (209, 210), (210, 211), (211, 212), (212, 213), (213, 214), (214, 215), (215, 216), (216, 217), (217, 218), (218, 219), (219, 220), (220, 221), (221, 222), (222, 223), (223, 224), (224, 225), (225, 226), (226, 227), (227, 228), (228, 229), (229, 230), (230, 231), (231, 232), (232, 233), (233, 234), (234, 235), (235, 236), (236, 237), (237, 238), (238, 239), (239, 240), (240, 241), (241, 242), (242, 243), (243, 244), (244, 245), (245, 246), (246, 247), (247, 248), (248, 249), (249, 250), (250, 251), (251, 252), (252, 253), (253, 254), (254, 255), (255, 256), (256, 257), (257, 258), (258, 259), (259, 260), (260, 261), (261, 262), (262, 263), (263, 264), (264, 265), (265, 266), (266, 267), (267, 268), (268, 269), (269, 270), (270, 271), (271, 272), (272, 273), (273, 274), (274, 275), (275, 276), (276, 277), (277, 278), (278, 279), (279, 280), (280, 281), (281, 282), (282, 283), (283, 284), (284, 285), (285, 286), (286, 287), (287, 288), (288, 289), (289, 290), (290, 291), (291, 292), (292, 293), (293, 294), (294, 295), (295, 296), (296, 297), (297, 298), (298, 299), (299, 300), (300, 301), (301, 302), (302, 303), (303, 304), (304, 305), (305, 306), (306, 307), (307, 308), (308, 309), (309, 310), (310, 311), (311, 312), (312, 313), (313, 314), (314, 315), (315, 316), (316, 317), (317, 318), (318, 319), (319, 320), (320, 321), (321, 322), (322, 323), (323, 324), (324, 325), (325, 326), (326, 327), (327, 328), (328, 329), (329, 330), (330, 331), (331, 332), (0, 0)]
14
16

[101, 4234, 3698, 4178, 3717, 1690, 1525, 702, 4277, 2094, 1962, 102, 6848, 2885, 4234, 3698, 4178, 3717, 1690, 3198, 8024, 671, 2137, 6206, 1068, 3800, 6821, 1126, 702, 7309, 7579, 8038, 122, 510, 1139, 3717, 4937, 2137, 2595, 6206, 1962, 8024, 679, 5543, 1139, 4385, 2575, 4178, 2575, 1107, 4638, 4385, 6496, 123, 510, 2571, 6862, 1168, 6809, 6392, 2137, 4638, 7444, 3724, 3717, 3946, 124, 510, 3082, 868, 6206, 3255, 5543, 510, 3175, 912, 125, 510, 2128, 1059, 2595, 6206, 1962, 8024, 6206, 6163, 3300, 2128, 1059, 2845, 6356, 6163, 5390, 2356, 1767, 677, 4234, 3698, 4178, 3717, 1690, 1501, 4277, 830, 1914, 8024, 6579, 743, 3198, 6820, 7444, 1914, 1217, 2190, 3683, 1469, 798, 5301, 7063, 1166, 511, 3175, 1922, 791, 2399, 712, 2802, 4638, 4828, 1265, 2608, 3946, 4178, 3717, 1690, 1762, 886, 4500, 860, 7741, 3175, 7481, 976, 749, 1059, 7481, 1285, 5277, 8038, 130, 4907, 6862, 4178, 8024, 1377, 2571, 6862, 6822, 1057, 3819, 3861, 3563, 2466, 8039, 3717, 3946, 2898, 719, 4937, 2137, 8024, 679, 833, 1139, 4385, 2575, 4178, 2575, 1107, 4638, 4385, 6496, 8024, 2400, 6858, 6814, 3717, 7030, 848, 3302, 2825, 3318, 2199, 1139, 3717, 3946, 2428, 5125, 4802, 2971, 1169, 1762, 11349, 119, 9687, 8024, 1377, 4007, 6639, 2157, 7027, 2140, 6564, 3130, 2697, 5491, 5502, 3819, 2844, 7444, 3724, 8039, 6981, 1906, 100, 1469, 100, 1352, 3698, 860, 2845, 6356, 6163, 5390, 3291, 2128, 1059, 8020, 2356, 1767, 677, 671, 5663, 1914, 711, 100, 1296, 3698, 860, 2845, 6356, 8021, 511, 1369, 1912, 8024, 6821, 3621, 4178, 3717, 1690, 6820, 3300, 3255, 5543, 100, 757, 5468, 1216, 5543, 8024, 1372, 7444, 678, 6770, 702, 2797, 3322, 100, 1315, 1377, 4500, 2797, 3322, 6823, 4923, 3082, 868, 4178, 3717, 1690, 8024, 2141, 4385, 5125, 1114, 6444, 5688, 3717, 3946, 8024, 4007, 6639, 2157, 782, 1914, 3416, 1265, 4638, 3819, 3861, 7444, 3724, 511, 2496, 4197, 3175, 1922, 4638, 4828, 1265, 2608, 3946, 5143, 1154, 712, 6206, 4638, 3221, 1872, 1217, 4828, 1265, 1216, 5543, 8024, 1377, 809, 3300, 3126, 1429, 7353, 3717, 704, 4638, 7188, 7224, 510, 7188, 2244, 5023, 2544, 2207, 3325, 6574, 8024, 7344, 3632, 5301, 5826, 3996, 4495, 8024, 886, 3759, 3861, 3717, 6574, 3291, 3815, 1112, 8024, 7270, 3309, 886, 4500, 4828, 1265, 3717, 3759, 3861, 3291, 1164, 754, 6716, 860, 978, 2434, 511, 102]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
1
[(0, 0), (0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8), (8, 9), (9, 10), (0, 0), (0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8), (8, 9), (9, 10), (10, 11), (11, 12), (12, 13), (13, 14), (14, 15), (15, 16), (16, 17), (17, 18), (18, 19), (19, 20), (20, 21), (21, 22), (22, 23), (23, 24), (24, 25), (25, 26), (26, 27), (27, 28), (28, 29), (29, 30), (30, 31), (31, 32), (32, 33), (33, 34), (34, 35), (35, 36), (36, 37), (37, 38), (38, 39), (39, 40), (40, 41), (41, 42), (42, 43), (43, 44), (44, 45), (45, 46), (46, 47), (47, 48), (48, 49), (49, 50), (50, 51), (51, 52), (52, 53), (53, 54), (54, 55), (55, 56), (56, 57), (57, 58), (58, 59), (59, 60), (60, 61), (61, 62), (62, 63), (63, 64), (64, 65), (65, 66), (66, 67), (67, 68), (68, 69), (69, 70), (70, 71), (71, 72), (72, 73), (73, 74), (74, 75), (75, 76), (76, 77), (77, 78), (78, 79), (79, 80), (80, 81), (82, 83), (83, 84), (84, 85), (85, 86), (86, 87), (87, 88), (88, 89), (89, 90), (90, 91), (91, 92), (92, 93), (93, 94), (94, 95), (95, 96), (96, 97), (97, 98), (98, 99), (99, 100), (100, 101), (101, 102), (102, 103), (103, 104), (104, 105), (105, 106), (106, 107), (107, 108), (108, 109), (109, 110), (110, 111), (111, 112), (112, 113), (113, 114), (114, 115), (115, 116), (116, 117), (117, 118), (118, 119), (119, 120), (120, 121), (121, 122), (122, 123), (123, 124), (124, 125), (125, 126), (126, 127), (127, 128), (128, 129), (129, 130), (130, 131), (131, 132), (132, 133), (133, 134), (134, 135), (135, 136), (136, 137), (137, 138), (138, 139), (139, 140), (140, 141), (141, 142), (142, 143), (143, 144), (144, 145), (145, 146), (146, 147), (147, 148), (148, 149), (149, 150), (150, 151), (151, 152), (152, 153), (153, 154), (154, 155), (155, 156), (156, 157), (157, 158), (158, 159), (159, 160), (160, 161), (161, 162), (162, 163), (163, 164), (164, 165), (165, 166), (166, 167), (167, 168), (168, 169), (169, 170), (170, 171), (171, 172), (172, 173), (173, 174), (174, 175), (175, 176), (176, 177), (177, 178), (178, 179), (179, 180), (180, 181), (181, 182), (182, 183), (183, 184), (184, 185), (185, 186), (186, 187), (187, 188), (188, 189), (189, 190), (190, 191), (191, 193), (193, 194), (194, 196), (196, 197), (197, 198), (198, 199), (199, 200), (200, 201), (201, 202), (202, 203), (203, 204), (204, 205), (205, 206), (206, 207), (207, 208), (208, 209), (209, 210), (210, 211), (211, 212), (212, 213), (213, 214), (214, 215), (215, 217), (217, 218), (218, 221), (221, 222), (222, 223), (223, 224), (224, 225), (225, 226), (226, 227), (227, 228), (228, 229), (229, 230), (230, 231), (231, 232), (232, 233), (233, 234), (234, 235), (235, 236), (236, 237), (237, 238), (238, 239), (239, 241), (241, 242), (242, 243), (243, 244), (244, 245), (245, 246), (246, 247), (247, 248), (248, 249), (249, 250), (250, 251), (251, 252), (252, 253), (253, 254), (254, 255), (255, 256), (256, 257), (257, 258), (258, 259), (259, 260), (260, 264), (264, 265), (265, 266), (266, 267), (267, 268), (268, 269), (269, 270), (270, 271), (271, 272), (272, 273), (273, 274), (274, 275), (275, 276), (276, 279), (279, 280), (280, 281), (281, 282), (282, 283), (283, 284), (284, 285), (285, 286), (286, 287), (287, 288), (288, 289), (289, 290), (290, 291), (291, 292), (292, 293), (293, 294), (294, 295), (295, 296), (296, 297), (297, 298), (298, 299), (299, 300), (300, 301), (301, 302), (302, 303), (303, 304), (304, 305), (305, 306), (306, 307), (307, 308), (308, 309), (309, 310), (310, 311), (311, 312), (312, 313), (313, 314), (314, 315), (315, 316), (316, 317), (317, 318), (318, 319), (319, 320), (320, 321), (321, 322), (322, 323), (323, 324), (324, 325), (325, 326), (326, 327), (327, 328), (328, 329), (329, 330), (330, 331), (331, 332), (332, 333), (333, 334), (334, 335), (335, 336), (336, 337), (337, 338), (338, 339), (339, 340), (340, 341), (341, 342), (342, 343), (343, 344), (344, 345), (345, 346), (346, 347), (347, 348), (348, 349), (349, 350), (350, 351), (351, 352), (352, 353), (353, 354), (354, 355), (355, 356), (356, 357), (357, 358), (358, 359), (359, 360), (360, 361), (361, 362), (362, 363), (363, 364), (364, 365), (365, 366), (366, 367), (367, 368), (368, 369), (369, 370), (370, 371), (371, 372), (372, 373), (373, 374), (374, 375), (375, 376), (376, 377), (377, 378), (378, 379), (379, 380), (380, 381), (381, 382), (382, 383), (383, 384), (384, 385), (385, 386), (386, 387), (387, 388), (388, 389), (0, 0)]
121
122

从以上结果可以看出,数据集中的example已经被转换成了模型可以接收的feature,包括input_ids、token_type_ids、答案的起始位置等信息。
其中:

  • input_ids: 表示输入文本的token ID。
  • token_type_ids: 表示对应的token属于输入的问题还是答案。(Transformer类预训练模型支持单句以及句对输入)。
  • overflow_to_sample: feature对应的example的编号。
  • offset_mapping: 每个token的起始字符和结束字符在原文中对应的index(用于生成答案文本)。
  • start_positions: 答案在这个feature中的开始位置。
  • end_positions: 答案在这个feature中的结束位置。

2.3 构造Dataloader

使用paddle.io.DataLoader接口多线程异步加载数据。同时使用paddlenlp.data中提供的方法把feature组成batch

import paddle
from paddlenlp.data import Stack, Dict, Pad

batch_size = 8

train_batch_sampler = paddle.io.DistributedBatchSampler(
        train_ds, batch_size=batch_size, shuffle=True)

train_batchify_fn = lambda samples, fn=Dict({
    "input_ids": Pad(axis=0, pad_val=tokenizer.pad_token_id),
    "token_type_ids": Pad(axis=0, pad_val=tokenizer.pad_token_type_id),
    "start_positions": Stack(dtype="int64"),
    "end_positions": Stack(dtype="int64")
}): fn(samples)

train_data_loader = paddle.io.DataLoader(
    dataset=train_ds,
    batch_sampler=train_batch_sampler,
    collate_fn=train_batchify_fn,
    return_list=True)

dev_batch_sampler = paddle.io.BatchSampler(
    dev_ds, batch_size=batch_size, shuffle=False)

dev_batchify_fn = lambda samples, fn=Dict({
    "input_ids": Pad(axis=0, pad_val=tokenizer.pad_token_id),
    "token_type_ids": Pad(axis=0, pad_val=tokenizer.pad_token_type_id)
}): fn(samples)

dev_data_loader = paddle.io.DataLoader(
    dataset=dev_ds,
    batch_sampler=dev_batch_sampler,
    collate_fn=dev_batchify_fn,
    return_list=True)

三、模型构建

阅读理解本质是一个答案抽取任务,PaddleNLP对于各种预训练模型已经内置了对于下游任务-答案抽取的Fine-tune网络

以下项目以BERT为例,介绍如何将预训练模型Fine-tune完成答案抽取任务。

答案抽取任务的本质就是根据输入的问题和文章,预测答案在文章中的起始位置和结束位置。基于BERT的答案抽取原理如下图所示:



图1:基于BERT的答案抽取原理示意图

paddlenlp.transformers.BertForQuestionAnswering()

一行代码即可加载预训练模型BERT用于答案抽取任务的Fine-tune网络。

paddlenlp.transformers.BertForQuestionAnswering.from_pretrained()

指定想要使用的模型名称和文本分类的类别数,一行代码完成网络构建。

# 设置想要使用模型的名称
model = ppnlp.transformers.BertForQuestionAnswering.from_pretrained(MODEL_NAME)
[2021-10-29 15:22:02,890] [    INFO] - Already cached /home/aistudio/.paddlenlp/models/bert-base-chinese/bert-base-chinese.pdparams

四、模型配置

4.1 设置Fine-Tune优化策略

适用于ERNIE/BERT这类Transformer模型的学习率为warmup的动态学习率。



图3:动态学习率示意图
# 训练过程中的最大学习率
learning_rate = 3e-5 
# 训练轮次
epochs = 1
# 学习率预热比例
warmup_proportion = 0.1
# 权重衰减系数,类似模型正则项策略,避免模型过拟合
weight_decay = 0.01

num_training_steps = len(train_data_loader) * epochs
lr_scheduler = ppnlp.transformers.LinearDecayWithWarmup(learning_rate, num_training_steps, warmup_proportion)

# Generate parameter names needed to perform weight decay.
# All bias and LayerNorm parameters are excluded.
decay_params = [
    p.name for n, p in model.named_parameters()
    if not any(nd in n for nd in ["bias", "norm"])
]
optimizer = paddle.optimizer.AdamW(
    learning_rate=lr_scheduler,
    parameters=model.parameters(),
    weight_decay=weight_decay,
    apply_decay_param_fun=lambda x: x in decay_params)

4.2 设计loss function

由于BertForQuestionAnswering模型对将BertModel的sequence_output拆开成start_logits和end_logits进行输出,所以阅读理解任务的loss也由start_loss和end_loss组成,我们需要自己定义loss function。对于答案其实位置和结束位置的预测可以分别成两个分类任务。所以设计的loss function如下:

class CrossEntropyLossForSQuAD(paddle.nn.Layer):
    def __init__(self):
        super(CrossEntropyLossForSQuAD, self).__init__()

    def forward(self, y, label):
        start_logits, end_logits = y   # both shape are [batch_size, seq_len]
        start_position, end_position = label
        start_position = paddle.unsqueeze(start_position, axis=-1)
        end_position = paddle.unsqueeze(end_position, axis=-1)
        start_loss = paddle.nn.functional.softmax_with_cross_entropy(
            logits=start_logits, label=start_position, soft_label=False)
        start_loss = paddle.mean(start_loss)
        end_loss = paddle.nn.functional.softmax_with_cross_entropy(
            logits=end_logits, label=end_position, soft_label=False)
        end_loss = paddle.mean(end_loss)

        loss = (start_loss + end_loss) / 2
        return loss

五、模型训练

模型训练的过程通常有以下步骤:

  1. 从dataloader中取出一个batch data
  2. 将batch data喂给model,做前向计算
  3. 将前向计算结果传给损失函数,计算loss。
  4. loss反向回传,更新梯度。重复以上步骤。

每训练一个epoch时,程序通过evaluate()调用paddlenlp.metric.squad中的squad_evaluate(), compute_predictions()评估当前模型训练的效果,其中:

  • compute_predictions()用于生成可提交的答案;

  • squad_evaluate()用于返回评价指标。

二者适用于所有符合squad数据格式的答案抽取任务。这类任务使用Rouge-L和exact来评估预测的答案和真实答案的相似程度。

@paddle.no_grad()
def evaluate(model, data_loader):
    model.eval()

    all_start_logits = []
    all_end_logits = []
    tic_eval = time.time()

    for batch in data_loader:
        input_ids, token_type_ids = batch
        start_logits_tensor, end_logits_tensor = model(input_ids,
                                                       token_type_ids)

        for idx in range(start_logits_tensor.shape[0]):
            if len(all_start_logits) % 1000 == 0 and len(all_start_logits):
                print("Processing example: %d" % len(all_start_logits))
                print('time per 1000:', time.time() - tic_eval)
                tic_eval = time.time()

            all_start_logits.append(start_logits_tensor.numpy()[idx])
            all_end_logits.append(end_logits_tensor.numpy()[idx])

    all_predictions, _, _ = compute_prediction(
        data_loader.dataset.data, data_loader.dataset.new_data,
        (all_start_logits, all_end_logits), False, 20, 30)
    squad_evaluate(
        examples=data_loader.dataset.data,
        preds=all_predictions,
        is_whitespace_splited=False)
    
    model.train()
# from utils import evaluate

criterion = CrossEntropyLossForSQuAD()
global_step = 0
for epoch in range(1, epochs + 1):
    for step, batch in enumerate(train_data_loader, start=1):
        global_step += 1
        input_ids, segment_ids, start_positions, end_positions = batch
        logits = model(input_ids=input_ids, token_type_ids=segment_ids)
        loss = criterion(logits, (start_positions, end_positions))

        if global_step % 100 == 0 :
            print("global step %d, epoch: %d, batch: %d, loss: %.5f" % (global_step, epoch, step, loss))
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.clear_grad()

    evaluate(model=model, data_loader=dev_data_loader) 

model.save_pretrained('./checkpoint')
tokenizer.save_pretrained('./checkpoint')

更多预训练模型

PaddleNLP不仅支持BERT预训练模型,还支持ERNIE、RoBERTa、Electra等预训练模型。
下表汇总了目前PaddleNLP支持的各类预训练模型。用户可以使用PaddleNLP提供的模型,完成问答、序列分类、token分类等任务。同时我们提供了22种预训练的参数权重供用户使用,其中包含了11种中文语言模型的预训练权重。

ModelTokenizerSupported TaskModel Name
BERTBertTokenizerBertModel
BertForQuestionAnswering
BertForSequenceClassification
BertForTokenClassification
bert-base-uncased
bert-large-uncased
bert-base-multilingual-uncased
bert-base-cased
bert-base-chinese
bert-base-multilingual-cased
bert-large-cased
bert-wwm-chinese
bert-wwm-ext-chinese
ERNIEErnieTokenizer
ErnieTinyTokenizer
ErnieModel
ErnieForQuestionAnswering
ErnieForSequenceClassification
ErnieForTokenClassification
ernie-1.0
ernie-tiny
ernie-2.0-en
ernie-2.0-large-en
RoBERTaRobertaTokenizerRobertaModel
RobertaForQuestionAnswering
RobertaForSequenceClassification
RobertaForTokenClassification
roberta-wwm-ext
roberta-wwm-ext-large
rbt3
rbtl3
ELECTRAElectraTokenizerElectraModel
ElectraForSequenceClassification
ElectraForTokenClassification
electra-small
electra-base
electra-large
chinese-electra-small
chinese-electra-base

注:其中中文的预训练模型有 bert-base-chinese, bert-wwm-chinese, bert-wwm-ext-chinese, ernie-1.0, ernie-tiny, roberta-wwm-ext, roberta-wwm-ext-large, rbt3, rbtl3, chinese-electra-base, chinese-electra-small 等。

更多预训练模型参考:https://github.com/PaddlePaddle/models/blob/develop/PaddleNLP/docs/transformers.md
更多预训练模型fine-tune下游任务使用方法,请参考examples

六、模型预测

@paddle.no_grad()
def do_predict(model, data_loader):
    model.eval()

    all_start_logits = []
    all_end_logits = []
    tic_eval = time.time()

    for batch in data_loader:
        input_ids, token_type_ids = batch
        start_logits_tensor, end_logits_tensor = model(input_ids,
                                                       token_type_ids)

        for idx in range(start_logits_tensor.shape[0]):
            if len(all_start_logits) % 1000 == 0 and len(all_start_logits):
                print("Processing example: %d" % len(all_start_logits))
                print('time per 1000:', time.time() - tic_eval)
                tic_eval = time.time()

            all_start_logits.append(start_logits_tensor.numpy()[idx])
            all_end_logits.append(end_logits_tensor.numpy()[idx])

    all_predictions, _, _ = compute_prediction(
        data_loader.dataset.data, data_loader.dataset.new_data,
        (all_start_logits, all_end_logits), False, 20, 30)


    count = 0
    for example in data_loader.dataset.data:
        count += 1
        print()
        print('问题:',example['question'])
        print('原文:',''.join(example['context']))
        print('答案:',all_predictions[example['id']])
        if count >= 2:
            break
    
    model.train()
do_predict(model, dev_data_loader)

七. 更多深度学习资源

7.1 一站式深度学习平台awesome-DeepLearning

  • 深度学习入门课
  • 深度学习百问
  • 特色课
  • 产业实践

PaddleEdu使用过程中有任何问题欢迎在awesome-DeepLearning提issue,同时更多深度学习资料请参阅飞桨深度学习平台

记得点个Star⭐收藏噢~~

7.2 飞桨技术交流群(QQ)

目前QQ群已有2000+同学一起学习,欢迎扫码加入

Logo

学大模型,用大模型上飞桨星河社区!每天8点V100G算力免费领!免费领取ERNIE 4.0 100w Token >>>

更多推荐