Reading Comprehension of Traditional Chinese Medicine Literature with RoBERTa

Machine reading comprehension (MRC) is an important task in natural language processing; the most common variant is single-passage extractive reading comprehension. MRC has a wide range of applications: customer-service bots that communicate with users via text or speech, retrieve relevant information, and return accurate, reliable answers; search engines that return precise answers to a user's question; medical systems that read a patient's records to locate the likely cause of disease; and educational tools that automatically suggest improvements to student essays. Traditional Chinese medicine (TCM) is a treasure of Chinese culture. This project combines machine reading comprehension with TCM-domain text: using the PaddleNLP library, we let the computer find the desired answer within a large body of TCM text, reducing the cost of information access.

This is my first project write-up and I don't have much experience, so suggestions are welcome. Feel free to like it and fork it to play around; references are at the end. Thanks for your support!
Also happy to follow each other!
Author: jjyaoao

Final results preview:

1. Solution Design

The reading-comprehension pipeline is shown in the figure above. The query is the question, typically posed by a user; the passage is the document from which the query's answer must be extracted. The query and passage go through data preprocessing to obtain ID-form inputs, which are fed into the RoBERTa model. RoBERTa outputs the answer's position in the passage, from which the corresponding answer text is recovered.

2. Data Processing

The task is defined as follows: given a question q and a passage p, produce the answer a to the question based on the passage content. Each sample in the dataset is a triple <q, p, a>, for example:

Question q: 草菇有什么功效?

Passage p: 草菇荠菜汤鲜嫩清香、色味搭配,具有清热和脾、益气平肝、降糖降压等功效,是夏季解暑祛热的良食佳品…

Reference answer a: 草菇荠菜汤鲜嫩清香、色味搭配,具有清热和脾、益气平肝、降糖降压等功效,是夏季解暑祛热的良食佳品

The dataset is provided in JSON format and includes:

  • id: paragraph id
  • text: paragraph text
  • annotations: the (question, answer) pairs for this paragraph, each containing
    • Q: the question
    • A: the answer

After light data cleaning, the data above is converted to the SQuAD format for convenient reading; the target format is as follows:

{
    'id': 'xx', 'title': 'xxx', 
    'context': 'xxxx', 
    'question': 'xxxxx', 
    'answers': ['xxxx'], 
    'answer_starts': [xxx]
}

For details on this dataset, see https://aistudio.baidu.com/aistudio/datasetdetail/181560 or the Tianchi competition page.

# First, import the libraries needed for this experiment.
# base
import paddlenlp as ppnlp
from utils import prepare_train_features, prepare_validation_features
from functools import partial
from paddlenlp.metrics.squad import squad_evaluate, compute_prediction

import collections
import time
import json

# data preprocess: 
from paddlenlp.datasets import load_dataset, MapDataset
from sklearn.model_selection import train_test_split
from paddle.io import Dataset

# Build the dataloader
import paddle
from paddlenlp.data import Stack, Dict, Pad

2.1 Dataset Loading and Processing

PaddleNLP ships with Chinese and English reading-comprehension datasets such as SQuAD and CMRC, loadable in one line with the paddlenlp.datasets.load_dataset() API. This project instead loads our own assembled TCM reading-comprehension dataset, read in SQuAD format. InputFeatures are generated with a sliding window, so one example may map to several InputFeatures.

The answer-extraction task is to predict, from the input question and passage, the answer's start and end positions within the passage.

Because the combined length of passage plus question may exceed max_seq_length, and the answer may appear near the end of the passage, we cannot simply truncate the passage.

Instead, an over-long passage is split into several overlapping segments with a sliding window, each segment paired with the question; the corresponding tokenizer then converts each pair into model-ready features. The doc_stride parameter is the distance the window slides each step. The process of generating InputFeatures with a sliding window is illustrated in the figure below (and sketched in code after it):


Figure: generating InputFeatures with a sliding window
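A minimal sketch (not PaddleNLP's actual implementation) of how doc_stride produces overlapping windows; window_size here stands for the number of context tokens that fit in one feature once the question and special tokens are accounted for:

def sliding_windows(n_context_tokens, window_size, doc_stride):
    # Yield (start, end) token spans covering the whole context with overlap.
    start = 0
    while True:
        end = min(start + window_size, n_context_tokens)
        yield (start, end)
        if end == n_context_tokens:
            break
        start += doc_stride

# e.g. ~470 context tokens, ~380 context tokens per feature, stride 128:
print(list(sliding_windows(470, 380, 128)))  # [(0, 380), (128, 470)] -> 2 features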
# data preprocess:
def read(data_path):
    # The raw file is pretty-printed JSON with fixed indentation, so rather than
    # json.load we scan it line by line; the hard-coded offsets below (5, 8, 9, ...)
    # rely on that fixed indentation to locate the "id"/"text"/"Q"/"A" fields.
    with open(data_path, 'r', encoding='utf-8') as f:
        for line in f:
            raw = line.strip('\n')

            # paragraph id
            id_Index = raw.find("id")
            if id_Index == 5 and raw[5] == 'i':
                id_raw = raw[id_Index + 5:]
                id = id_raw[:len(id_raw) - 1].replace('\"', "")

            # paragraph text (the context)
            text_Index = raw.find("text")
            if text_Index == 5 and raw[5] == 't':
                text_raw = raw[text_Index + 8:]
                text = text_raw[:len(text_raw) - 2]

            # question of the current (Q, A) pair
            ques_Index = raw.find("Q")
            if ques_Index == 9 and raw[9] == 'Q':
                ques_raw = raw[ques_Index + 5:]
                ques = ques_raw[:len(ques_raw) - 2]

            # answer: once found, emit one SQuAD-format sample
            ans_Index = raw.find("A")
            if ans_Index == 9 and raw[9] == 'A':
                ans_raw = raw[ans_Index + 5:]
                ans = ans_raw[:len(ans_raw) - 1]

                # SQuAD format records the answer's character offset in the context
                ans_start = [text.find(ans)]
                ans = [ans]

                yield {'id': id, 'title': '', 'context': text, 'question': ques, 'answers': ans, 'answer_starts': ans_start}

map_ds = load_dataset(read, data_path='/home/aistudio/data/data181560/中医数据集.json', lazy=False)

ds1, ds2 = train_test_split(map_ds, test_size=0.08, random_state=42, shuffle=True)
train_ds = MapDataset(ds1)
dev_ds = MapDataset(ds2)

# Show two samples from the dataset
print(map_ds[0])
print()
print(map_ds[18000])
{'id': '1240', 'title': '', 'context': '\\"胆石症的治疗应区别不同情况分别处理,无症状胆囊结石可不作治疗,但应定期观察并注意良好的饮食习惯。有症状的胆囊结石仍以胆囊切除术为较安全有效的疗法,此外,尚可采用体外震波碎石。胆管结石宜采用以手术为主的综合治疗。胆石症的家庭治疗可采用以下方法:\\\\n(1)一般治疗    预防和治疗肠道寄生虫病和肠道感染,以降低胆石症的发病率。胆绞痛发作期应禁食脂肪等食物,采用高碳水化合物流质饮食;缓解期应忌食富含胆固醇的食物如脑、肝、肾、蛋黄等。\\\\n(2)增进胆汁排泄    可选用50%硫酸镁10~15毫升,餐后口服,每日3次;胆盐每次口服0.5~1克,每日3次;去氢胆酸0.25克,每日3次,餐后服用。\\\\n(3)消除胆绞痛    轻者可卧床休息,右上腹热敷,用硝酸甘油酯0.6毫克,每3~4小时一次,含于舌下;或阿托品0.5毫克,每3~4小时肌肉注射一次。重者应住院治疗。\\\\n(4)排石疗法以中药治疗为主,若右上腹疼痛有间歇期,无明显发热及黄疸,苔薄白,脉弦,属气滞者,用生大黄6克、木香9克、枳壳9克、金钱草30克、川楝子9克、黄苓9克,水煎服。右上腹痛为持续性,且阵发性加剧,有明显发热及黄疸,舌红苔黄,', 'question': '什么类型的胆囊结石可不作治疗?', 'answers': ['无症状胆囊结'], 'answer_starts': [20]}

{'id': '13462', 'title': '', 'context': '古人言“养生贵在养神”,经常排除杂念、静养心神、闭目休息,是一种调养精神的简便方法。闭目养神能养阴去燥,特别是饭后闭目20分钟,可护肝消食。饭后,身体内的血液都集中到消化道内,参与食物的消化吸收,如果此时行走、运动,就会有一部分血液流向手足,肝脏则会出现供血量不足的情况,影响其正常的新陈代谢。饭后闭目静坐20分钟,能使血液更多地流向肝脏,供给肝细胞氧和营养成分。患有肝病的人,饭后更应该闭目养神。', 'question': '调养精神的简便方法?', 'answers': ['经常排除杂念、静养心神、闭目休息'], 'answer_starts': [12]}

ppnlp.transformers.RobertaTokenizer

Call RobertaTokenizer to convert the raw data into model inputs.

A note on tokenization: the original (English) RoBERTa processes text as byte-level BPE. Its official vocabulary contains 50k+ byte-level tokens; merges.txt stores the BPE merge rules, while vocab.json maps each token to an index (more frequent tokens generally get smaller indices). Conversion therefore proceeds in two steps: the input is first split into the tokens given by merges.txt, then mapped to indices through the vocab.json dictionary. The Chinese roberta-wwm-ext model used in this project instead keeps BERT's WordPiece-style tokenization over a single vocab.txt (see the loading log below), which splits Chinese text essentially character by character.

The tokenizer's job is to convert raw input text into the input form the model accepts. Taking the English RoBERTa as an example, given the input text

What's up with the tokenizer?

merges.txt is first used to convert it into the corresponding byte-level tokens (similar to a normalization step):

['What', "'s", 'Ġup', 'Ġwith', 'Ġthe', 'Ġtoken', 'izer', '?']

which are then mapped to the corresponding indices via the mapping stored in vocab.json:

[   'What',     "'s",    'Ġup',  'Ġwith',   'Ġthe', 'Ġtoken',   'izer',      '?']
---- becomes ----
[     2061,      338,      510,      351,      262,    11241,     7509,       30]
# More models to choose from:
# ['bert-base-uncased', 'bert-large-uncased', 'bert-base-multilingual-uncased', 'bert-base-cased', 'bert-base-chinese', 'bert-base-multilingual-cased'
# , 'bert-large-cased', 'bert-wwm-chinese', 'bert-wwm-ext-chinese', 'macbert-base-chinese', 'macbert-large-chinese', 'simbert-base-chinese']

# Use PaddleNLP's built-in Chinese RoBERTa pre-trained model
MODEL_NAME = 'roberta-wwm-ext-large'
tokenizer = ppnlp.transformers.RobertaTokenizer.from_pretrained(MODEL_NAME)
[2022-12-07 00:36:50,934] [    INFO] - Found /home/aistudio/.paddlenlp/models/roberta-wwm-ext-large/vocab.txt
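As a quick sanity check (a sketch, assuming the tokenizer object loaded above), encoding a question/passage pair shows the BERT-style [CLS] question [SEP] passage [SEP] layout that the features in the next section follow:

encoded = tokenizer(text="草菇有什么功效?", text_pair="草菇是营养丰富的食物")
print(encoded["input_ids"])       # [CLS] + question ids + [SEP] + passage ids + [SEP]
print(encoded["token_type_ids"])  # 0 for the question segment, 1 for the passage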

2.2 Data Conversion

The load_dataset() API returns a MapDataset object by default. MapDataset is a feature-enhanced version of paddle.io.Dataset; its built-in map() method is well suited to batch processing of the dataset and takes a function that transforms the data. A toy sketch of map() follows, and then the actual conversion used in this project.
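A toy illustration (a sketch; it assumes that with batched=True MapDataset passes the full list of examples to the transform at once, which is what prepare_train_features relies on) of how the number of entries can change:

from paddlenlp.datasets import MapDataset

toy = MapDataset([{'x': 'ab'}, {'x': 'cd'}])
# the transform receives a list of examples and may return a list of a
# different length -- here one entry per character, 4 features from 2 examples
toy.map(lambda examples: [{'ch': c} for ex in examples for c in ex['x']], batched=True)
print(len(toy), toy[0])

The actual conversion: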

max_seq_length = 512
doc_stride = 128

train_trans_func = partial(prepare_train_features, 
                           max_seq_length=max_seq_length, 
                           doc_stride=doc_stride,
                           tokenizer=tokenizer)

train_ds.map(train_trans_func, batched=True)

dev_trans_func = partial(prepare_validation_features, 
                           max_seq_length=max_seq_length, 
                           doc_stride=doc_stride,
                           tokenizer=tokenizer)
                           
dev_ds.map(dev_trans_func, batched=True)

<paddlenlp.datasets.dataset.MapDataset at 0x7fc25a6cd410>
# Inspect the processed features
for idx in range(2):
    print(train_ds[idx]['input_ids'])
    print(train_ds[idx]['token_type_ids'])
    print(train_ds[idx]['overflow_to_sample'])
    print(train_ds[idx]['offset_mapping'])
    print(train_ds[idx]['start_positions'])
    print(train_ds[idx]['end_positions'])
    print()
[101, 7946, 5698, 7937, 4638, 5852, 1075, 817, 966, 680, 4635, 5698, 7937, 1525, 702, 3291, 7770, 8043, 102, 5698, 7937, 3300, 7946, 4635, 697, 4905, 8024, 794, 5852, 1075, 4906, 2110, 4692, 8024, 3187, 6389, 7946, 5698, 7937, 510, 4635, 5698, 7937, 6963, 3221, 5852, 1075, 705, 2168, 4638, 7608, 4289, 8024, 5735, 671, 2137, 6206, 1146, 1139, 7770, 678, 8024, 6929, 720, 7946, 5698, 7937, 4638, 5852, 1075, 817, 966, 4526, 7770, 754, 4635, 5698, 7937, 511, 2792, 809, 8024, 4500, 754, 6133, 4660, 1075, 4495, 4638, 1914, 711, 7946, 5698, 7937, 8024, 5445, 4635, 5698, 7937, 1728, 711, 5682, 3813, 4023, 778, 8024, 1156, 3291, 1914, 1765, 4500, 1762, 3189, 2382, 7650, 7608, 704, 976, 4157, 5345, 722, 4500, 511, 7608, 4545, 868, 4500, 7946, 5698, 7937, 2595, 1456, 4491, 510, 2398, 8024, 1072, 3300, 3996, 1075, 5498, 5513, 510, 1075, 6117, 3883, 4246, 4638, 868, 4500, 511, 7370, 749, 1920, 2157, 4225, 4761, 4638, 723, 1355, 1216, 3126, 722, 1912, 8024, 7946, 5698, 7937, 6820, 3300, 2523, 1962, 4638, 2834, 3709, 1265, 1216, 5543, 8024, 1728, 711, 2124, 705, 2168, 4638, 3779, 5544, 1469, 5335, 4495, 5162, 147, 8024, 5543, 3996, 3883, 4649, 5502, 510, 6133, 6117, 6858, 912, 8024, 3221, 1075, 7582, 7728, 7582, 4638, 924, 1075, 881, 1501, 511, 2124, 6820, 1419, 3300, 7344, 3632, 782, 860, 1355, 5523, 4638, 4289, 6574, 6028, 7942, 5162, 510, 5519, 4822, 510, 5491, 5131, 8024, 1728, 3634, 5698, 7937, 1391, 1914, 749, 738, 679, 833, 1355, 5523, 511, 1762, 5688, 7608, 1121, 5503, 4638, 1398, 3198, 8024, 5735, 6981, 1394, 5698, 7937, 4638, 7608, 4500, 8024, 5110, 5133, 4638, 4649, 5502, 1377, 5815, 2533, 3121, 1587, 511, 3634, 1912, 8024, 7946, 5698, 7937, 6820, 1072, 3300, 4660, 5554, 1856, 7767, 4638, 1216, 3126, 8024, 2190, 754, 671, 763, 5498, 5513, 679, 6639, 2792, 5636, 1928, 3238, 4680, 4700, 510, 6381, 2554, 1213, 1121, 6842, 4638, 782, 3341, 6432, 8024, 1914, 1391, 671, 763, 7946, 5698, 7937, 833, 3353, 3300, 4660, 1905, 511, 7608, 4500, 722, 3791, 2769, 812, 2398, 3189, 7027, 1391, 1168, 4638, 5698, 7937, 1169, 1501, 1914, 711, 5698, 7937, 6996, 1469, 5698, 7937, 7676, 3779, 511, 1391, 3146, 5108, 4638, 5698, 7937, 2190, 754, 5852, 1075, 4638, 1429, 3119, 3341, 6432, 2400, 679, 3221, 3297, 1962, 4638, 8024, 1728, 711, 5698, 7937, 4638, 1912, 7481, 3300, 671, 2231, 4924, 4801, 4638, 5606, 8024, 1372, 3300, 2828, 2124, 4827, 4810, 8024, 1071, 704, 4638, 5852, 1075, 5162, 2798, 5543, 6158, 1429, 3119, 511, 2792, 809, 8024, 5698, 7937, 3297, 1962, 4827, 4810, 749, 1086, 1391, 511, 996, 2100, 3300, 6887, 743, 3341, 4638, 5698, 7937, 6206, 4500, 2166, 2196, 2595, 1962, 4638, 2159, 1690, 3341, 996, 2100, 8024, 2400, 2100, 3123, 1762, 7346, 1117, 2397, 4246, 4638, 1765, 3175, 8024, 6912, 1048, 7345, 1045, 4684, 2198, 511, 1963, 3362, 2199, 5698, 7937, 4143, 4225, 3256, 2397, 1156, 3291, 1217, 1164, 754, 2100, 3123, 511, 102]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
0
[(0, 0), (0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8), (8, 9), (9, 10), (10, 11), (11, 12), (12, 13), (13, 14), (14, 15), (15, 16), (16, 17), (0, 0), (0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8), (8, 9), (9, 10), (10, 11), (11, 12), (12, 13), (13, 14), (14, 15), (15, 16), (16, 17), (17, 18), (18, 19), (19, 20), (20, 21), (21, 22), (22, 23), (23, 24), (24, 25), (25, 26), (26, 27), (27, 28), (28, 29), (29, 30), (30, 31), (31, 32), (32, 33), (33, 34), (34, 35), (35, 36), (36, 37), (37, 38), (38, 39), (39, 40), (40, 41), (41, 42), (42, 43), (43, 44), (44, 45), (45, 46), (46, 47), (47, 48), (48, 49), (49, 50), (50, 51), (51, 52), (52, 53), (53, 54), (54, 55), (55, 56), (56, 57), (57, 58), (58, 59), (59, 60), (60, 61), (61, 62), (62, 63), (63, 64), (64, 65), (65, 66), (66, 67), (67, 68), (68, 69), (69, 70), (70, 71), (71, 72), (72, 73), (73, 74), (74, 75), (75, 76), (76, 77), (77, 78), (78, 79), (79, 80), (80, 81), (81, 82), (82, 83), (83, 84), (84, 85), (85, 86), (86, 87), (87, 88), (88, 89), (89, 90), (90, 91), (91, 92), (92, 93), (93, 94), (94, 95), (95, 96), (96, 97), (97, 98), (98, 99), (99, 100), (100, 101), (101, 102), (102, 103), (103, 104), (104, 105), (105, 106), (106, 107), (107, 108), (108, 109), (109, 110), (110, 111), (111, 112), (112, 113), (113, 114), (114, 115), (115, 116), (116, 117), (117, 118), (118, 119), (119, 120), (120, 121), (121, 122), (122, 123), (123, 124), (124, 125), (125, 126), (126, 127), (127, 128), (128, 129), (129, 130), (130, 131), (131, 132), (132, 133), (133, 134), (134, 135), (135, 136), (136, 137), (137, 138), (138, 139), (139, 140), (140, 141), (141, 142), (142, 143), (143, 144), (144, 145), (145, 146), (146, 147), (147, 148), (148, 149), (149, 150), (150, 151), (151, 152), (152, 153), (153, 154), (154, 155), (155, 156), (156, 157), (157, 158), (158, 159), (159, 160), (160, 161), (161, 162), (162, 163), (163, 164), (164, 165), (165, 166), (166, 167), (167, 168), (168, 169), (169, 170), (170, 171), (171, 172), (172, 173), (173, 174), (174, 175), (175, 176), (176, 177), (177, 178), (178, 179), (179, 180), (180, 181), (181, 182), (182, 183), (183, 184), (184, 185), (185, 186), (186, 187), (187, 188), (188, 189), (189, 190), (190, 191), (191, 192), (192, 193), (193, 194), (194, 195), (195, 196), (196, 197), (197, 198), (198, 199), (199, 200), (200, 201), (201, 202), (202, 203), (203, 204), (204, 205), (205, 206), (206, 207), (207, 208), (208, 209), (209, 210), (210, 211), (211, 212), (212, 213), (213, 214), (214, 215), (215, 216), (216, 217), (217, 218), (218, 219), (219, 220), (220, 221), (221, 222), (222, 223), (223, 224), (224, 225), (225, 226), (226, 227), (227, 228), (228, 229), (229, 230), (230, 231), (231, 232), (232, 233), (233, 234), (234, 235), (235, 236), (236, 237), (237, 238), (238, 239), (239, 240), (240, 241), (241, 242), (242, 243), (243, 244), (244, 245), (245, 246), (246, 247), (247, 248), (248, 249), (249, 250), (250, 251), (251, 252), (252, 253), (253, 254), (254, 255), (255, 256), (256, 257), (257, 258), (258, 259), (259, 260), (260, 261), (261, 262), (262, 263), (263, 264), (264, 265), (265, 266), (266, 267), (267, 268), (268, 269), (269, 270), (270, 271), (271, 272), (272, 273), (273, 274), (274, 275), (275, 276), (276, 277), (277, 278), (278, 279), (279, 280), (280, 281), (281, 282), (282, 283), (283, 284), (284, 285), (285, 286), (286, 287), (287, 288), (288, 289), (289, 290), (290, 291), (291, 292), (292, 293), (293, 294), (294, 295), (295, 296), (296, 297), (297, 298), (298, 299), (299, 300), 
(300, 301), (301, 302), (302, 303), (303, 304), (304, 305), (305, 306), (306, 307), (307, 308), (308, 309), (309, 310), (310, 311), (311, 312), (312, 313), (313, 314), (314, 315), (315, 316), (316, 317), (317, 318), (318, 319), (319, 320), (320, 321), (321, 322), (322, 323), (323, 324), (324, 325), (325, 326), (326, 327), (327, 328), (328, 329), (329, 330), (330, 331), (331, 332), (332, 333), (333, 334), (334, 335), (335, 336), (336, 337), (337, 338), (338, 339), (339, 340), (340, 341), (341, 342), (342, 343), (343, 344), (344, 345), (345, 346), (346, 347), (347, 348), (348, 349), (349, 350), (350, 351), (351, 352), (352, 353), (353, 354), (354, 355), (355, 356), (356, 357), (357, 358), (358, 359), (359, 360), (360, 361), (361, 362), (362, 363), (363, 364), (364, 365), (365, 366), (366, 367), (367, 368), (368, 369), (369, 370), (370, 371), (371, 372), (372, 373), (373, 374), (374, 375), (375, 376), (376, 377), (377, 378), (378, 379), (379, 380), (380, 381), (381, 382), (382, 383), (383, 384), (384, 385), (385, 386), (386, 387), (387, 388), (388, 389), (389, 390), (390, 391), (391, 392), (392, 393), (393, 394), (394, 395), (395, 396), (396, 397), (397, 398), (398, 399), (399, 400), (400, 401), (401, 402), (402, 403), (403, 404), (404, 405), (405, 406), (406, 407), (407, 408), (408, 409), (409, 410), (410, 411), (411, 412), (412, 413), (413, 414), (414, 415), (415, 416), (416, 417), (417, 418), (418, 419), (419, 420), (420, 421), (421, 422), (422, 423), (423, 424), (424, 425), (425, 426), (426, 427), (427, 428), (428, 429), (429, 430), (430, 431), (431, 432), (432, 433), (433, 434), (434, 435), (435, 436), (436, 437), (437, 438), (438, 439), (439, 440), (440, 441), (441, 442), (442, 443), (443, 444), (444, 445), (445, 446), (446, 447), (447, 448), (448, 449), (449, 450), (450, 451), (451, 452), (452, 453), (453, 454), (454, 455), (455, 456), (456, 457), (457, 458), (458, 459), (459, 460), (460, 461), (461, 462), (462, 463), (463, 464), (464, 465), (465, 466), (466, 467), (467, 468), (468, 469), (469, 470), (0, 0)]
64
78

[101, 7370, 749, 3580, 2094, 1912, 8024, 6820, 3300, 1525, 763, 1072, 3300, 5710, 7676, 1456, 4638, 3717, 3362, 8043, 102, 100, 3379, 3580, 1072, 3300, 5710, 7676, 3698, 1456, 8024, 1762, 671, 2137, 4923, 2428, 677, 1377, 809, 2458, 4964, 6237, 4920, 511, 100, 1266, 776, 704, 1278, 5790, 1920, 2110, 2382, 4995, 2168, 3136, 2956, 2456, 6379, 8024, 1963, 3362, 1762, 1309, 2147, 7027, 3123, 671, 702, 3580, 2094, 8024, 3926, 3173, 4638, 3698, 1456, 8024, 5543, 1916, 1173, 4080, 4868, 5307, 5143, 5320, 4638, 1069, 1939, 8024, 6375, 782, 4868, 3926, 3698, 4272, 8024, 3926, 7370, 3738, 3843, 4638, 4958, 3698, 8024, 4685, 2190, 5401, 1265, 2147, 1079, 4638, 4384, 1862, 511, 2382, 4995, 2168, 3136, 2956, 2900, 1139, 8024, 794, 704, 1278, 6235, 2428, 3341, 6432, 8024, 3580, 2094, 1072, 3300, 4638, 5710, 7676, 1456, 8024, 1377, 809, 1265, 3969, 510, 7008, 5569, 510, 6912, 4920, 510, 2458, 4964, 8024, 7370, 749, 7008, 5554, 2458, 4964, 1912, 8024, 2496, 2697, 6230, 726, 1213, 510, 5517, 5499, 7653, 5515, 8024, 679, 2682, 1391, 691, 6205, 3198, 8024, 6844, 2496, 7319, 7319, 3580, 2094, 4638, 3926, 7676, 8024, 1377, 809, 5353, 6237, 679, 6844, 4638, 4568, 4307, 511, 5710, 7676, 4638, 3698, 1456, 6820, 5543, 1916, 886, 782, 7252, 7474, 2128, 4868, 8024, 2828, 3580, 2094, 3123, 1762, 2414, 1928, 8024, 738, 3300, 1164, 754, 4717, 4697, 511, 5445, 3580, 2094, 3382, 1469, 4638, 5682, 2506, 8024, 833, 5314, 782, 3946, 3265, 4638, 2697, 6230, 8024, 1728, 3634, 8024, 1377, 809, 6432, 3580, 2094, 3221, 2147, 1079, 4638, 100, 2341, 3030, 6392, 100, 511, 7370, 749, 3580, 2094, 1912, 8024, 6820, 3300, 2523, 1914, 1072, 3300, 5710, 7676, 1456, 4638, 3717, 3362, 8024, 1963, 5741, 3362, 510, 7676, 5933, 8024, 2190, 2147, 1079, 4384, 1862, 4638, 1112, 1265, 738, 833, 3300, 671, 2137, 4638, 2512, 1510, 511, 852, 6821, 763, 3698, 1456, 2190, 924, 2898, 978, 2434, 3341, 6432, 8024, 1372, 5543, 6629, 1168, 6774, 1221, 3780, 4545, 4638, 2797, 3667, 8024, 2400, 684, 7444, 6206, 6809, 1168, 671, 2137, 4638, 3849, 2428, 511, 102]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
1
[(0, 0), (0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8), (8, 9), (9, 10), (10, 11), (11, 12), (12, 13), (13, 14), (14, 15), (15, 16), (16, 17), (17, 18), (18, 19), (0, 0), (0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8), (8, 9), (9, 10), (10, 11), (11, 12), (12, 13), (13, 14), (14, 15), (15, 16), (16, 17), (17, 18), (18, 19), (19, 20), (20, 21), (21, 22), (22, 23), (23, 24), (24, 25), (25, 26), (26, 27), (27, 28), (28, 29), (29, 30), (30, 31), (31, 32), (32, 33), (33, 34), (34, 35), (35, 36), (36, 37), (37, 38), (38, 39), (39, 40), (40, 41), (41, 42), (42, 43), (43, 44), (44, 45), (45, 46), (46, 47), (47, 48), (48, 49), (49, 50), (50, 51), (51, 52), (52, 53), (53, 54), (54, 55), (55, 56), (56, 57), (57, 58), (58, 59), (59, 60), (60, 61), (61, 62), (62, 63), (63, 64), (64, 65), (65, 66), (66, 67), (67, 68), (68, 69), (69, 70), (70, 71), (71, 72), (72, 73), (73, 74), (74, 75), (75, 76), (76, 77), (77, 78), (78, 79), (79, 80), (80, 81), (81, 82), (82, 83), (83, 84), (84, 85), (85, 86), (86, 87), (87, 88), (88, 89), (89, 90), (90, 91), (91, 92), (92, 93), (93, 94), (94, 95), (95, 96), (96, 97), (97, 98), (98, 99), (99, 100), (100, 101), (101, 102), (102, 103), (103, 104), (104, 105), (105, 106), (106, 107), (107, 108), (108, 109), (109, 110), (110, 111), (111, 112), (112, 113), (113, 114), (114, 115), (115, 116), (116, 117), (117, 118), (118, 119), (119, 120), (120, 121), (121, 122), (122, 123), (123, 124), (124, 125), (125, 126), (126, 127), (127, 128), (128, 129), (129, 130), (130, 131), (131, 132), (132, 133), (133, 134), (134, 135), (135, 136), (136, 137), (137, 138), (138, 139), (139, 140), (140, 141), (141, 142), (142, 143), (143, 144), (144, 145), (145, 146), (146, 147), (147, 148), (148, 149), (149, 150), (150, 151), (151, 152), (152, 153), (153, 154), (154, 155), (155, 156), (156, 157), (157, 158), (158, 159), (159, 160), (160, 161), (161, 162), (162, 163), (163, 164), (164, 165), (165, 166), (166, 167), (167, 168), (168, 169), (169, 170), (170, 171), (171, 172), (172, 173), (173, 174), (174, 175), (175, 176), (176, 177), (177, 178), (178, 179), (179, 180), (180, 181), (181, 182), (182, 183), (183, 184), (184, 185), (185, 186), (186, 187), (187, 188), (188, 189), (189, 190), (190, 191), (191, 192), (192, 193), (193, 194), (194, 195), (195, 196), (196, 197), (197, 198), (198, 199), (199, 200), (200, 201), (201, 202), (202, 203), (203, 204), (204, 205), (205, 206), (206, 207), (207, 208), (208, 209), (209, 210), (210, 211), (211, 212), (212, 213), (213, 214), (214, 215), (215, 216), (216, 217), (217, 218), (218, 219), (219, 220), (220, 221), (221, 222), (222, 223), (223, 224), (224, 225), (225, 226), (226, 227), (227, 228), (228, 229), (229, 230), (230, 231), (231, 232), (232, 233), (233, 234), (234, 235), (235, 236), (236, 237), (237, 238), (238, 239), (239, 240), (240, 241), (241, 242), (242, 243), (243, 244), (244, 245), (245, 246), (246, 247), (247, 248), (248, 249), (249, 250), (250, 251), (251, 252), (252, 253), (253, 254), (254, 255), (255, 256), (256, 257), (257, 258), (258, 259), (259, 260), (260, 261), (261, 262), (262, 263), (263, 264), (264, 265), (265, 266), (266, 267), (267, 268), (268, 269), (269, 270), (270, 271), (271, 272), (272, 273), (273, 274), (274, 275), (275, 276), (276, 277), (277, 278), (278, 279), (279, 280), (280, 281), (281, 282), (282, 283), (283, 284), (284, 285), (285, 286), (286, 287), (287, 288), (288, 289), (289, 290), (290, 291), (291, 292), (292, 293), (293, 294), (294, 295), (295, 296), (296, 297), (297, 298), (298, 
299), (299, 300), (300, 301), (301, 302), (302, 303), (303, 304), (304, 305), (305, 306), (306, 307), (307, 308), (308, 309), (309, 310), (310, 311), (311, 312), (312, 313), (313, 314), (314, 315), (315, 316), (316, 317), (317, 318), (318, 319), (319, 320), (320, 321), (321, 322), (322, 323), (323, 324), (324, 325), (0, 0)]
286
290

The results above show that each example in the dataset has been converted into features the model can accept, including input_ids, token_type_ids, the answer's start position, and so on.
Specifically:

  • input_ids: the token IDs of the input text.
  • token_type_ids: whether each token belongs to the input question or the passage (Transformer-style pre-trained models support both single-sentence and sentence-pair inputs).
  • overflow_to_sample: the index of the example this feature comes from.
  • offset_mapping: for each token, the start and end character indices in the original text (used to produce the answer text).
  • start_positions: the answer's start position within this feature.
  • end_positions: the answer's end position within this feature. The sketch below shows how these fields map back to the answer text.
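A minimal sketch (decode_span is a hypothetical helper; it assumes, as the evaluate() function later does, that after map() the dataset's .data attribute still holds the original examples while indexing returns features):

def decode_span(context, offset_mapping, start_pos, end_pos):
    # offset_mapping holds (char_start, char_end) in the original text per token
    char_start = offset_mapping[start_pos][0]
    char_end = offset_mapping[end_pos][1]
    return context[char_start:char_end]

feat = train_ds[0]                                   # one InputFeature
example = train_ds.data[feat['overflow_to_sample']]  # its originating example
print(decode_span(example['context'], feat['offset_mapping'],
                  feat['start_positions'], feat['end_positions']))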

2.3 Building the DataLoader

Use the paddle.io.DataLoader interface to load data asynchronously with multiple workers, and use the utilities in paddlenlp.data to collate features into batches (a toy demo of these collate utilities follows).
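A quick sketch of what Pad, Stack and Dict do on toy inputs: Pad right-pads ragged sequences to the batch maximum, Stack stacks same-shape items into an array, and Dict routes each feature key to its collate function:

from paddlenlp.data import Pad, Stack
print(Pad(axis=0, pad_val=0)([[1, 2, 3], [4, 5]]))  # [[1 2 3] [4 5 0]]
print(Stack()([1, 2]))                              # [1 2]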

# Build the dataloader
batch_size = 8

train_batch_sampler = paddle.io.DistributedBatchSampler(
        train_ds, batch_size=batch_size, shuffle=True)

train_batchify_fn = lambda samples, fn=Dict({
    "input_ids": Pad(axis=0, pad_val=tokenizer.pad_token_id),
    "token_type_ids": Pad(axis=0, pad_val=tokenizer.pad_token_type_id),
    "start_positions": Stack(dtype="int64"),
    "end_positions": Stack(dtype="int64")
}): fn(samples)

train_data_loader = paddle.io.DataLoader(
    dataset=train_ds,
    batch_sampler=train_batch_sampler,
    collate_fn=train_batchify_fn,
    return_list=True)

dev_batch_sampler = paddle.io.BatchSampler(
    dev_ds, batch_size=batch_size, shuffle=False)

dev_batchify_fn = lambda samples, fn=Dict({
    "input_ids": Pad(axis=0, pad_val=tokenizer.pad_token_id),
    "token_type_ids": Pad(axis=0, pad_val=tokenizer.pad_token_type_id)
}): fn(samples)

dev_data_loader = paddle.io.DataLoader(
    dataset=dev_ds,
    batch_sampler=dev_batch_sampler,
    collate_fn=dev_batchify_fn,
    return_list=True)

3. Model Construction

Reading comprehension is essentially an answer-extraction task. For its various pre-trained models, PaddleNLP provides built-in fine-tuning networks for the downstream answer-extraction task.

This project takes RoBERTa as the example and shows how to fine-tune the pre-trained model for answer extraction.

Answer extraction boils down to predicting, from the input question and passage, the answer's start and end positions in the passage. The principle of BERT-based answer extraction is illustrated below (with a toy decoding sketch after the figure):



Figure 1: BERT-based answer extraction
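A toy sketch of the idea (greatly simplified: the real decoding, compute_prediction below, scores top-k start/end pairs rather than taking independent argmaxes):

import numpy as np

start_logits = np.array([0.1, 2.5, 0.3, 0.2])  # per-token start scores (toy)
end_logits   = np.array([0.0, 0.1, 0.2, 3.0])  # per-token end scores (toy)
start, end = int(start_logits.argmax()), int(end_logits.argmax())
print("predicted span: tokens", start, "to", end)  # answer = tokens start..end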

RoBERTa makes several adjustments on top of BERT: 1) longer training, larger batch sizes, and more training data; 2) removal of the next-sentence-prediction loss; 3) longer training sequences; 4) dynamic masking instead of static masking; 5) byte-level BPE (in the English model).

paddlenlp.transformers.RobertaForQuestionAnswering()

One line of code loads the pre-trained RoBERTa model with the fine-tuning network for answer extraction.

paddlenlp.transformers.RobertaForQuestionAnswering.from_pretrained()

Specify the name of the model you want to use, and the network is constructed in one line.

# Load the pre-trained model specified by MODEL_NAME above
model = ppnlp.transformers.RobertaForQuestionAnswering.from_pretrained(MODEL_NAME)
[2022-12-07 00:38:46,166] [    INFO] - Downloading https://paddlenlp.bj.bcebos.com/models/transformers/roberta_large/roberta_chn_large.pdparams and saved to /home/aistudio/.paddlenlp/models/roberta-wwm-ext-large
[2022-12-07 00:38:46,169] [    INFO] - Downloading roberta_chn_large.pdparams from https://paddlenlp.bj.bcebos.com/models/transformers/roberta_large/roberta_chn_large.pdparams
100%|██████████| 1271615/1271615 [00:17<00:00, 71811.19it/s]
W1207 00:39:05.382441    98 device_context.cc:362] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 10.1
W1207 00:39:05.390488    98 device_context.cc:372] device: 0, cuDNN Version: 7.6.

4. Model Configuration

4.1 Fine-tuning Optimization Strategy

For Transformer models such as ERNIE/BERT, the usual choice is a dynamic learning rate with warmup.



Figure 3: learning-rate schedule with warmup
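Concretely, LinearDecayWithWarmup used below increases the learning rate linearly from 0 to learning_rate over the first warmup_proportion of training steps, then decays it linearly back to 0: with T total steps and t_w = warmup_proportion * T,

lr(t) = learning_rate * t / t_w                for t <= t_w
lr(t) = learning_rate * (T - t) / (T - t_w)    for t > t_w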
# Maximum learning rate during training
learning_rate = 3e-5 
# Number of training epochs
epochs = 2
# Proportion of steps used for learning-rate warmup
warmup_proportion = 0.1
# Weight-decay coefficient, a regularization-like strategy against overfitting
weight_decay = 0.01

num_training_steps = len(train_data_loader) * epochs
lr_scheduler = ppnlp.transformers.LinearDecayWithWarmup(learning_rate, num_training_steps, warmup_proportion)

# Generate parameter names needed to perform weight decay.
# All bias and LayerNorm parameters are excluded.
decay_params = [
    p.name for n, p in model.named_parameters()
    if not any(nd in n for nd in ["bias", "norm"])
]
optimizer = paddle.optimizer.AdamW(
    learning_rate=lr_scheduler,
    parameters=model.parameters(),
    weight_decay=weight_decay,
    apply_decay_param_fun=lambda x: x in decay_params)

4.2 Designing the Loss Function

Because BertForQuestionAnswering-style models split BertModel's sequence_output into start_logits and end_logits, the reading-comprehension loss consists of a start_loss and an end_loss, and we need to define the loss function ourselves. Predicting the answer's start position and its end position can be treated as two separate classification tasks, which leads to the loss function below.
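As implemented below, the loss is the average of two softmax cross-entropies, one per boundary:

loss = ( CE(start_logits, start_position) + CE(end_logits, end_position) ) / 2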


# Attempt to add R-Drop regularization; not working yet
# from paddlenlp.losses import RDropLoss
# class CrossEntropyWithRdrop(paddle.nn.Layer):
#     """
#     """
#     def __init__(self,
#                  label_smoothing=None,
#                  pad_idx=1,
#                  alpha=5):
#         super(CrossEntropyWithRdrop,self).__init__(label_smoothing,pad_idx)
#         self.pad_idx =pad_idx
#         self.alpha= alpha
#         self.rdrop_loss = RDropLoss()

#     def forward(self, model, sample, need_attn=False):
#         '''
#         return : loss,sample_size,log
#         '''
#         # 1.loss ce
#         logits, sum_cost, avg_cost, token_num = super().forward(model, sample, need_attn=False)

#         # 2.rdrop loss
#         if model.training:
#             avg_cost = self.get_rdrop_loss(model, sample, logits, avg_cost)
#             print(avg_cost)

#         return logits, sum_cost, avg_cost, token_num

#     def get_rdrop_loss(self, model, sample, logits1 ,ce_loss1):
#         if self.alpha > 0:
#             logits2, sum_cost2, ce_loss2, token_num2 = super().forward(model, sample)
#             pad_mask = (sample["prev_tokens"] != self.pad_idx).unsqueeze(-1).tile([1,1,logits1.shape[-1]])
#             kl_loss = self.rdrop_loss(logits1,logits2,pad_mask)
#             kl_loss = kl_loss/token_num2
#             loss = 0.5 * (ce_loss1 + ce_loss2) + self.alpha * kl_loss
#         else:
#             loss = ce_loss1

#         return loss
class CrossEntropyLossForSQuAD(paddle.nn.Layer):
    def __init__(self):
        super(CrossEntropyLossForSQuAD, self).__init__()

    def forward(self, y, label):
        start_logits, end_logits = y   # both shapes are [batch_size, seq_len]
        start_position, end_position = label
        start_position = paddle.unsqueeze(start_position, axis=-1)
        end_position = paddle.unsqueeze(end_position, axis=-1)
        start_loss = paddle.nn.functional.softmax_with_cross_entropy(
            logits=start_logits, label=start_position, soft_label=False)
        start_loss = paddle.mean(start_loss)
        end_loss = paddle.nn.functional.softmax_with_cross_entropy(
            logits=end_logits, label=end_position, soft_label=False)
        end_loss = paddle.mean(end_loss)

        loss = (start_loss + end_loss) / 2
        return loss

5. Model Training

Model training usually proceeds as follows:

  1. Take a batch of data from the dataloader.
  2. Feed the batch to the model for a forward pass.
  3. Pass the forward results to the loss function and compute the loss.
  4. Back-propagate the loss and update the parameters. Repeat.

After each training epoch, evaluate() calls squad_evaluate() and compute_prediction() from paddlenlp.metrics.squad to assess how well the model is training, where:

  • compute_prediction() generates submittable answers;

  • squad_evaluate() returns the evaluation metrics.

Both apply to any answer-extraction task in SQuAD data format. As the logs below show, prediction quality is scored by exact match (EM) and F1, which measure how closely the predicted answer matches the reference answer; a rough character-level sketch of the idea follows.
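A rough sketch (character-level, matching is_whitespace_splited=False in the call below; not the exact squad_evaluate implementation):

def char_f1(pred, gold):
    # bag-of-characters overlap F1 between prediction and reference
    common = sum(min(pred.count(c), gold.count(c)) for c in set(pred))
    if common == 0:
        return 0.0
    precision, recall = common / len(pred), common / len(gold)
    return 2 * precision * recall / (precision + recall)

print(char_f1('无症状胆囊结', '无症状胆囊结石'))  # ~0.92 F1, while EM = 0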

@paddle.no_grad()
def evaluate(model, data_loader):
    model.eval()

    all_start_logits = []
    all_end_logits = []
    tic_eval = time.time()

    for batch in data_loader:
        input_ids, token_type_ids = batch
        start_logits_tensor, end_logits_tensor = model(input_ids,
                                                       token_type_ids)

        for idx in range(start_logits_tensor.shape[0]):
            if len(all_start_logits) % 1000 == 0 and len(all_start_logits):
                print("Processing example: %d" % len(all_start_logits))
                print('time per 1000:', time.time() - tic_eval)
                tic_eval = time.time()

            all_start_logits.append(start_logits_tensor.numpy()[idx])
            all_end_logits.append(end_logits_tensor.numpy()[idx])

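    # Arguments to compute_prediction below (names per paddlenlp.metrics.squad;
    # worth verifying against your installed version): version_2_with_negative=False,
    # n_best_size=20, max_answer_length=30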
    all_predictions, _, _ = compute_prediction(
        data_loader.dataset.data, data_loader.dataset.new_data,
        (all_start_logits, all_end_logits), False, 20, 30)
    squad_evaluate(
        examples=data_loader.dataset.data,
        preds=all_predictions,
        is_whitespace_splited=False)
    
    model.train()
# train

criterion = CrossEntropyLossForSQuAD()
global_step = 0
for epoch in range(1, epochs + 1):
    for step, batch in enumerate(train_data_loader, start=1):
        global_step += 1
        input_ids, segment_ids, start_positions, end_positions = batch
        logits = model(input_ids=input_ids, token_type_ids=segment_ids)
        loss = criterion(logits, (start_positions, end_positions))

        if global_step % 100 == 0 :
            print("global step %d, epoch: %d, batch: %d, loss: %.5f" % (global_step, epoch, step, loss))
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.clear_grad()

    evaluate(model=model, data_loader=dev_data_loader) 

# save model
model.save_pretrained('./checkpoint')
tokenizer.save_pretrained('./checkpoint')

# load model
# tokenizer = ppnlp.transformers.RobertaTokenizer.from_pretrained("./checkpoint")
# model = ppnlp.transformers.RobertaForQuestionAnswering.from_pretrained("./checkpoint")
global step 100, epoch: 1, batch: 100, loss: 4.68508
global step 200, epoch: 1, batch: 200, loss: 1.93787
global step 300, epoch: 1, batch: 300, loss: 2.50312
global step 400, epoch: 1, batch: 400, loss: 2.06680
.......................................................

global step 2000, epoch: 1, batch: 2000, loss: 1.67050
global step 2100, epoch: 1, batch: 2100, loss: 1.03209
global step 2200, epoch: 1, batch: 2200, loss: 0.87983
Processing example: 1000
time per 1000: 33.68374156951904
{
  "exact": 21.974758723088346,
  "f1": 62.55876388815933,
  "total": 1347,
  "HasAns_exact": 21.974758723088346,
  "HasAns_f1": 62.55876388815933,
  "HasAns_total": 1347
}
global step 2300, epoch: 2, batch: 59, loss: 1.05687
global step 2400, epoch: 2, batch: 159, loss: 0.93989
....................................................

global step 4200, epoch: 2, batch: 1959, loss: 0.75131
global step 4300, epoch: 2, batch: 2059, loss: 1.06529
global step 4400, epoch: 2, batch: 2159, loss: 0.97709
Processing example: 1000
time per 1000: 33.705798625946045
{
  "exact": 24.35040831477357,
  "f1": 63.29065134552712,
  "total": 1347,
  "HasAns_exact": 24.35040831477357,
  "HasAns_f1": 63.29065134552712,
  "HasAns_total": 1347
}

6. Model Prediction

Run prediction on the validation set and print the first two results.

# model predict
import paddle



@paddle.no_grad()
def do_predict(model, data_loader):
    model.eval()

    all_start_logits = []
    all_end_logits = []
    tic_eval = time.time()

    for batch in data_loader:
        input_ids, token_type_ids = batch
        start_logits_tensor, end_logits_tensor = model(input_ids,
                                                       token_type_ids)

        for idx in range(start_logits_tensor.shape[0]):
            if len(all_start_logits) % 1000 == 0 and len(all_start_logits):
                print("Processing example: %d" % len(all_start_logits))
                print('time per 1000:', time.time() - tic_eval)
                tic_eval = time.time()

            all_start_logits.append(start_logits_tensor.numpy()[idx])
            all_end_logits.append(end_logits_tensor.numpy()[idx])

    all_predictions, _, _ = compute_prediction(
        data_loader.dataset.data, data_loader.dataset.new_data,
        (all_start_logits, all_end_logits), False, 20, 30)


    count = 0
    for example in data_loader.dataset.data:
        count += 1
        print()
        print('问题:',example['question'])
        print('原文:',''.join(example['context']))
        print('答案:',all_predictions[example['id']])
        if count >= 2:
            break
    
    model.train()
do_predict(model, dev_data_loader)
Processing example: 1000
time per 1000: 37.292165994644165

问题: 无名指与食指长度差可能辅助用于预测什么疾病?
原文: 大屁股记忆差根据美国西北大学医学院的研究,看起来身形笨拙的“苹果形”身材的女性,其实记忆力远比“梨形”身材的女性要好。研究人员发现,腰腹部稍稍有一点儿脂肪堆积,有助于女性分泌身体必需的雌激素,尤其是对于绝经期妇女而言。而适量的脂肪有助于保护大脑各项功能,特别是记忆力。食指长更抑郁韩国一项最新研究发现,无名指长的男人患前列腺癌的几率会增加3倍。研究负责人表示,在不久的将来,无名指与食指长度差可能辅助用于预测前列腺癌。研究者指出,较高水平的雄性激素会导致右手无名指较长,而雄性激素正是刺激前列腺肿瘤生长的激素。无名指比食指长的人易患孤独症,食指较长的人易患抑郁症。此外,利物浦大学的研究人员发现,无名指较短的男性年轻时更易发生心梗,因为他们体内可以预防心梗的睾丸激素水平较低。大腿细心脏差来自哥本哈根大学医院的贝丽特·海特曼和佩德·弗雷泽里克森梳理了医院在上世纪80年代后期获取的数千个体测数据后发现,大腿围不足60厘米者比大腿围达到60厘米的人更容易患心脏病,患病比例是后者的两倍左右。
答案: 前列腺癌。

问题: 终日无所事事时我们可以做些什么?
原文: 凡事都有度,养生这件事也不例外。“不过度”就意味着节制,国家级名老中医、湖北省中医院涂晋文教授告诉《生命时报》记者,日常生活中的一些细节,只要长期坚持并且把握好度,就能强身健体、延年益寿。大家不妨对照试试↓生命时报特约记者喻朝晖衣不过暖:穿衣戴帽不要过于暖和,也不可过于单薄,过暖容易感冒,过冷容易受寒。食不过饱:吃饭不要过饱,粗细都吃,荤素相兼。住不过奢:要随遇而安,居室富丽堂皇易夺心志而变质。行不过富:身体健康允许,尽量以步代车。如出门必乘车,日久腿脚就要失去灵便。劳不过累:劳动的强度是有限的,超过负荷量容易造成身体伤害。每日工作8小时,8小时外适当地休闲,劳逸结合很必要。逸不过安:终日无所事事,会丧失对生活的情趣而心灰意懒,所以即使退休在家,也应勤于动脑,散步聊天、写字作画、下棋看戏等,心情由此舒畅,益于延年增寿。喜不过欢:人逢喜事精神爽。但是喜不能过头,“过喜则伤心”,古人范进中举后变疯,即为过喜所致。怒不可暴:有不顺心和烦恼的事,不要生气恼怒。怒则伤肝,伤肝就要发病。要有涵养,乐观处世。名不过求:名不过求、利不过贪,能给人带来平淡的幸福感。
答案: 散步聊天、写字作画、下棋看戏等,心情由此舒畅,益于延年增寿。

7. Inference Deployment

With training complete, we now deploy the model for inference. The dynamic-graph mode used during training has many advantages, including a Pythonic programming experience (especially for networks with control flow such as RNNs) and friendly interactive debugging, but Python dynamic-graph mode cannot fully meet the performance requirements of inference deployment and also constrains the deployment environment.

Static graphs are the usual choice for inference deployment. With the network structure defined ahead of time, a static graph avoids the comparatively expensive Python execution of dynamic graphs; moreover, the fixed graph structure enables graph-level optimization, and both effectively improve inference performance. Common graph optimizations include memory reuse and operator fusion, which require support from the inference engine. One example of operator fusion replaces the matmul -> add-bias -> relu sequence in a Transformer block's FFN with a single fused operator.

High-performance inference deployment requires both static-graph model export and an inference engine; we cover each in turn.

7.1 Exporting the Model from Dynamic to Static Graph

Static-graph-based deployment requires converting the dynamic-graph model into a static-graph model (network structure plus parameter weights).

A Paddle static-graph model, i.e. a network made of variables and operators, is stored in a Program. How a Program is built can be illustrated in Paddle's static mode: in static mode, each network-construction API adds its input/output variables and the operators it uses to the Program.

import paddle
# Paddle defaults to dynamic-graph mode; switch to static-graph mode here
paddle.enable_static()
# Define an input variable; in static mode a variable is only a symbolic
# placeholder and does not hold data the way a dynamic-graph Tensor does
x = paddle.static.data(shape=[None, 128], dtype='float32', name='x')
linear = paddle.nn.Linear(128, 256, bias_attr=False)
# Define the network; inputs and outputs are symbolic as well
y = linear(x)
# Print the program
print(paddle.static.default_main_program())
# Switch back to dynamic-graph mode
paddle.disable_static()
{ // block 0
    var x : paddle.VarType.LOD_TENSOR.shape(-1, 128).astype(VarType.FP32)
    persist trainable param linear_0.w_0 : paddle.VarType.LOD_TENSOR.shape(128, 256).astype(VarType.FP32)
    var linear_1.tmp_0 : paddle.VarType.LOD_TENSOR.shape(-1, 256).astype(VarType.FP32)

    {Out=['linear_1.tmp_0']} = matmul(inputs={X=['x'], Y=['linear_0.w_0']}, Scale_out = 1.0, Scale_x = 1.0, Scale_y = 1.0, alpha = 1.0, force_fp32_output = False, fused_reshape_Out = [], fused_reshape_X = [], fused_reshape_Y = [], fused_transpose_Out = [], fused_transpose_X = [], fused_transpose_Y = [], mkldnn_data_type = float32, op_device = , op_namescope = /, op_role = 0, op_role_var = [], transpose_X = False, transpose_Y = False, use_mkldnn = False, use_quantizer = False)
}

Building on this static-graph machinery, Paddle can convert a dynamic-graph model and export it as a static-graph model (network structure and parameter weights) via jit.to_static and jit.save.

  1. paddle.jit.to_static converts the dynamic-graph model into a static-graph model.
  • Network structure: the model's forward function is transcribed (chiefly rewriting Python control flow into the corresponding Paddle API calls) and then executed in static mode to generate the Program.
  • Parameter weights: the dynamic-graph parameters are bound to the corresponding variables while the Program is generated.

During conversion, InputSpec must describe the model inputs (shape, dtype and name) so that shapes and data types are correct while the Program is built.

# Log the transcribed forward code if desired
# paddle.jit.set_code_level(100)

# Load the fine-tuned dynamic-graph weights
param_state_dict = paddle.load("checkpoint/model_state.pdparams")
model.set_state_dict(param_state_dict)

# Dynamic-to-static conversion. `input_spec` describes the inputs the model needs;
# None in a shape means that dimension is variable, analogous to `paddle.static.data`
# in the static-graph example above
model_static = paddle.jit.to_static(
    model,
    input_spec=[
        paddle.static.InputSpec(
            shape=[None, None], dtype="int64"),  # input_ids: [batch_size, max_seq_len]
        paddle.static.InputSpec(
            shape=[None], dtype="int64")  # token_type_ids (a 2-D spec [None, None] would
                                          # match the model's real input more closely)
    ])
# Print the Program produced by the conversion and the input variables
print(model_static.forward.concrete_program.main_program)
print(model_static.forward.inputs)
# Print one of the model's parameter weights
print(model_static.forward.concrete_program.parameters[0].name)
print(model_static.forward.concrete_program.parameters[0].value)
persist trainable param linear_127.b_0 : paddle.VarType.LOD_TENSOR.shape(1024,).astype(VarType.FP32)
    var linear_127.tmp_0 : paddle.VarType.LOD_TENSOR.shape(-1, -1, 1024).astype(VarType.FP32)
    var linear_127.tmp_1 : paddle.VarType.LOD_TENSOR.shape(-1, -1, 1024).astype(VarType.FP32)
    persist trainable param linear_128.w_0 : paddle.VarType.LOD_TENSOR.shape(1024, 1024).astype(VarType.FP32)
    persist trainable param linear_128.b_0 : paddle.VarType.LOD_TENSOR.shape(1024,).astype(VarType.FP32)
    var linear_128.tmp_0 : paddle.VarType.LOD_TENSOR.shape(-1, -1, 1024).astype(VarType.FP32)
    ...........................................
    
  [var input_ids : paddle.VarType.LOD_TENSOR.shape(-1, -1).astype(VarType.INT64), var token_type_ids : paddle.VarType.LOD_TENSOR.shape(-1,).astype(VarType.INT64)]
embedding_0.w_0
<bound method PyCapsule.value of Parameter containing:
Tensor(shape=[21128, 1024], dtype=float32, place=CUDAPlace(0), stop_gradient=False,
       [[ 0.02093766,  0.04043975, -0.02231590, ..., -0.00104699,  0.01650023, -0.00473674],
        [-0.03206097,  0.06535473, -0.05432026, ...,  0.00993754,  0.02322754, -0.02759956],
        [ 0.04299722,  0.06158005, -0.00912949, ..., -0.01705220, -0.02060436, -0.04537297],
        ...,
        [-0.00458803,  0.05845100, -0.03022355, ...,  0.02722553,  0.05777356,  0.01021087],
        [ 0.05257823,  0.03553688, -0.00079630, ..., -0.00281990,  0.01666584, -0.08521493],
        [-0.02762268,  0.08472827,  0.00651097, ...,  0.00811364, -0.00589278, -0.00263454]])>
  2. paddle.jit.save serializes and saves the static-graph model (network structure and parameter weights).
  • Network structure: a file with the .pdmodel extension, which can be visualized with visualdl.
  • Parameter weights: a file with the .pdiparams extension.
import os

# Save the converted model, producing infer_model/model.pdmodel and infer_model/model.pdiparams
paddle.jit.save(model_static, "infer_model/model")
os.listdir("infer_model/")
['model.pdiparams', 'model.pdmodel', 'model.pdiparams.info']
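Incidentally, the exported files are not tied to the inference library: paddle.jit.load can bring the serialized model back into Python as a callable Layer. A minimal sketch (the commented call assumes the input layout from the InputSpec above):

import paddle

# Load the serialized static-graph model back as a callable Layer
loaded_model = paddle.jit.load("infer_model/model")
loaded_model.eval()
# It can then be invoked like the original model, e.g. (I/O layout assumed):
# start_logits, end_logits = loaded_model(input_ids, token_type_ids)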

7.2 Inference with the Paddle Inference Library

With the static-graph model in hand, we use Paddle Inference for deployment. Paddle Inference is PaddlePaddle's native inference library for server and cloud environments, providing high-performance inference.

Paddle Inference performs prediction through a Predictor, a high-performance engine that analyzes the computation graph and applies a series of optimizations to it (operator fusion, memory/GPU-memory optimization, and integration with acceleration libraries such as MKL-DNN and TensorRT), greatly improving inference performance. Paddle Inference offers APIs in Python, C++, Go and other languages, so you can choose whichever fits your environment. For this demonstration we use the Python API, which ships with the installed Paddle package. Writing a Python inference program with Paddle Inference takes only the following steps:


import paddle.inference as paddle_infer

# 1. Create the config object and point it at the exported model files
config = paddle_infer.Config("infer_model/model.pdmodel", "infer_model/model.pdiparams")
# To predict on GPU instead: reserve 100 MB of GPU memory on device 0
# config.enable_use_gpu(100, 0)
config.disable_gpu()
# 2. Create the inference engine from the config
predictor = paddle_infer.create_predictor(config)
# 3. Set the input data
# Get the input handles
input_handles = [
            predictor.get_input_handle(name)
            for name in predictor.get_input_names()
        ]
# Prepare the input data
data = dev_batchify_fn([dev_ds[0]])
# Copy the input data to the handles
for input_field, input_handle in zip(data, input_handles):
    input_handle.copy_from_cpu(input_field)

# 4. Run inference
predictor.run()

# 5. Fetch the predictions
# Get the output handles
output_handles = [
            predictor.get_output_handle(name)
            for name in predictor.get_output_names()
        ]
# Copy the predictions back from the output handles
output = [output_handle.copy_to_cpu() for output_handle in output_handles]
# Print the predictions
# print(output[:1000])

# Print the result of predicting directly with the dynamic-graph model
# print(model(*data[:1000]))

# Compare inference speed: Predictor vs. dynamic- and static-graph models.
# Note that config.disable_gpu() puts the Predictor on CPU here, while the
# models below run on GPU, so the timings are not directly comparable.
import time
start_time = time.time()
for i in range(1):
    for input_field, input_handle in zip(data, input_handles):
        input_handle.copy_from_cpu(input_field)
    predictor.run()
    output = [output_handle.copy_to_cpu() for output_handle in output_handles]
print("Predictor inference time: ", time.time() - start_time)

start_time = time.time()
for i in range(1):
    output = model(*data[:1000])
print("Dygraph model inference time: ", time.time() - start_time)


start_time = time.time()
for i in range(1):
    output = model_static(*data[:1000])
print("Static model inference time: ", time.time() - start_time)
print("Dynamic-to-static conversion verified")
print()
I1207 01:41:12.262806    98 analysis_predictor.cc:155] Profiler is deactivated, and no profiling report will be generated.


Predictor inference time:  4.781304359436035


--- Running analysis [ir_graph_build_pass]
--- Running analysis [ir_graph_clean_pass]
--- Running analysis [ir_analysis_pass]
--- Running IR pass [simplify_with_basic_ops_pass]
--- Running IR pass [attention_lstm_fuse_pass]
--- Running IR pass [seqconv_eltadd_relu_fuse_pass]
--- Running IR pass [seqpool_cvm_concat_fuse_pass]
--- Running IR pass [mul_lstm_fuse_pass]
--- Running IR pass [fc_gru_fuse_pass]
--- Running IR pass [mul_gru_fuse_pass]
--- Running IR pass [seq_concat_fc_fuse_pass]
--- Running IR pass [squeeze2_matmul_fuse_pass]
--- Running IR pass [reshape2_matmul_fuse_pass]
I1207 01:41:13.372135    98 graph_pattern_detector.cc:101] ---  detected 24 subgraphs
--- Running IR pass [flatten2_matmul_fuse_pass]
--- Running IR pass [map_matmul_to_mul_pass]
I1207 01:41:13.428436    98 graph_pattern_detector.cc:101] ---  detected 169 subgraphs
--- Running IR pass [fc_fuse_pass]
I1207 01:41:13.643700    98 graph_pattern_detector.cc:101] ---  detected 145 subgraphs
--- Running IR pass [repeated_fc_relu_fuse_pass]
--- Running IR pass [squared_mat_sub_fuse_pass]
--- Running IR pass [conv_bn_fuse_pass]
--- Running IR pass [conv_eltwiseadd_bn_fuse_pass]
--- Running IR pass [conv_transpose_bn_fuse_pass]
--- Running IR pass [conv_transpose_eltwiseadd_bn_fuse_pass]
--- Running IR pass [is_test_pass]
--- Running IR pass [runtime_context_cache_pass]
--- Running analysis [ir_params_sync_among_devices_pass]
--- Running analysis [adjust_cudnn_workspace_size_pass]
--- Running analysis [inference_op_replace_pass]
--- Running analysis [ir_graph_to_program_pass]
I1207 01:41:15.062129    98 analysis_predictor.cc:598] ======= optimize end =======
I1207 01:41:15.063323    98 naive_executor.cc:107] ---  skip [feed], feed -> token_type_ids
I1207 01:41:15.063338    98 naive_executor.cc:107] ---  skip [feed], feed -> input_ids
I1207 01:41:15.072453    98 naive_executor.cc:107] ---  skip [save_infer_model/scale_0.tmp_1], fetch -> fetch
I1207 01:41:15.072474    98 naive_executor.cc:107] ---  skip [save_infer_model/scale_1.tmp_1], fetch -> fetch


Dygraph model inference time:  1.7076599597930908
Static model inference time:  0.09106826782226562
Dynamic-to-static conversion verified

8. Interactive Deployment with Gradio

Deploying with gradio involves the following steps:

  1. Copy the model to the local machine and wrap it in a function that matches gradio's interface requirements (a complete sketch of this function follows the demo in Section 8.1)
import gradio as gr
def question_answer(context, question):
    pass  # Implement your question-answering model here...

gr.Interface(fn=question_answer, inputs=["text", "text"], outputs=["textbox", "text"]).launch(share=True)
  2. Load the user-supplied context and question and have the model return the answer

  3. Return the answer to the gradio page for display

8.1 Command-Line Demo on a Single Example

from paddlenlp.datasets import load_dataset, MapDataset
def test_preprocess(context : str, question : str):
    test_raw = [{'id':'', 'title':'', 'context':context, 'question':question, 'answers':[''], 'answer_starts':['']}]
    test_example = MapDataset(test_raw)
    dev_trans_func = partial(prepare_validation_features, 
                           max_seq_length=max_seq_length, 
                           doc_stride=doc_stride,
                           tokenizer=tokenizer)
    test_ds = test_example.map(dev_trans_func, batched=True)

    test_batch_sampler = paddle.io.BatchSampler(
        test_ds, batch_size=batch_size, shuffle=False)
        
    # test_ds_for_model = test_ds.remove_columns(["example_id", "offset_mapping"])
    test_data_loader = paddle.io.DataLoader(
        dataset=test_ds,
        batch_sampler=test_batch_sampler,
        collate_fn=dev_batchify_fn,
        return_list=True)    
    return test_example, test_ds, test_data_loader
import paddle
test_example, test_ds, test_data_loader = test_preprocess('草菇荠菜汤鲜嫩清香、色味搭配,具有清热和脾、益气平肝、降糖降压等功效,是夏季解暑祛热的良食佳品。具体做法如下,取新鲜草菇200克,荠菜50克,瘦猪肉50克,生姜、食盐、橄榄油适量。草菇洗净切片,芥菜洗净切段,生姜切丝,瘦肉剁成末。坐砂锅放入适量清水,倒入姜丝,大火烧开,滴加少量橄榄油,倒入草菇片,加盖焖煮2分钟;开盖加入肉末和荠菜,煮熟即可,加少量食盐调味。中医养生认为,草菇具有消食祛热,补脾益气,清暑热,降血压等功效。草菇含有丰富的维生素C,可预防牙龈出血,促进创伤愈合,而且促进铁的吸收,养血益气;草菇中钾、镁、硒等矿物元素较高,能促进人体新陈代谢,提高机体免疫力;另外草菇中的膳食纤维可以减慢人体对糖类的吸收,抑制餐后血糖的快速升高,是糖尿病患者的良好食品。荠菜有清热和脾、利水消肿等功效。荠菜中维C的含量比柑橘还要高,钙、钾、铁等含量均是大白菜的5倍以上,远高于普通蔬菜,丰富的矿物质能够壮骨养血,增强身体机能,保持年轻活力。另外,荠菜还含有大量的粗纤维和黄酮类物质,能够降血脂、保护心血管、抑癌抗癌,而且抑制毒物的吸收,促进毒物在肝脏的代谢排出,保肝护肝。',
 '草菇有什么功效?')
from utils import evaluate
# Command-line demo; the prediction is also saved locally to prediction.json.
# (exact/F1 print as 0 below because test_preprocess passes an empty gold
# answer as a placeholder; the extracted answer printed afterwards is what matters)
evaluate(model=model, data_loader=test_data_loader)
{
  "exact": 0.0,
  "f1": 0.0,
  "total": 1,
  "HasAns_exact": 0.0,
  "HasAns_f1": 0.0,
  "HasAns_total": 1
}

问题: 草菇有什么功效?
原文: 草菇荠菜汤鲜嫩清香、色味搭配,具有清热和脾、益气平肝、降糖降压等功效,是夏季解暑祛热的良食佳品。具体做法如下,取新鲜草菇200克,荠菜50克,瘦猪肉50克,生姜、食盐、橄榄油适量。草菇洗净切片,芥菜洗净切段,生姜切丝,瘦肉剁成末。坐砂锅放入适量清水,倒入姜丝,大火烧开,滴加少量橄榄油,倒入草菇片,加盖焖煮2分钟;开盖加入肉末和荠菜,煮熟即可,加少量食盐调味。中医养生认为,草菇具有消食祛热,补脾益气,清暑热,降血压等功效。草菇含有丰富的维生素C,可预防牙龈出血,促进创伤愈合,而且促进铁的吸收,养血益气;草菇中钾、镁、硒等矿物元素较高,能促进人体新陈代谢,提高机体免疫力;另外草菇中的膳食纤维可以减慢人体对糖类的吸收,抑制餐后血糖的快速升高,是糖尿病患者的良好食品。荠菜有清热和脾、利水消肿等功效。荠菜中维C的含量比柑橘还要高,钙、钾、铁等含量均是大白菜的5倍以上,远高于普通蔬菜,丰富的矿物质能够壮骨养血,增强身体机能,保持年轻活力。另外,荠菜还含有大量的粗纤维和黄酮类物质,能够降血脂、保护心血管、抑癌抗癌,而且抑制毒物的吸收,促进毒物在肝脏的代谢排出,保肝护肝。
答案: 草菇具有消食祛热,补脾益气,清暑热,降血压等功效。
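To wire this demo into the gradio scaffold from the start of Section 8, the question_answer stub can be filled in along the following lines. This is a minimal sketch, not the project's exact utils.evaluate: it assumes the notebook objects model, tokenizer, dev_batchify_fn, max_seq_length, doc_stride and batch_size are in scope, that model returns (start_logits, end_logits), and that, as in PaddleNLP's SQuAD example, the raw examples live in dataset.data and the sliding-window features in dataset.new_data:

import paddle
from paddlenlp.metrics.squad import compute_prediction

def question_answer(context, question):
    # Reuse the preprocessing from Section 8.1
    _, _, test_data_loader = test_preprocess(context, question)

    # Collect per-feature start/end logits from the model
    all_start_logits, all_end_logits = [], []
    model.eval()
    with paddle.no_grad():
        for batch in test_data_loader:
            input_ids, token_type_ids = batch  # assumed layout from dev_batchify_fn
            start_logits, end_logits = model(input_ids, token_type_ids)
            for idx in range(start_logits.shape[0]):
                all_start_logits.append(start_logits.numpy()[idx])
                all_end_logits.append(end_logits.numpy()[idx])

    # Decode the best answer span from the logits
    all_predictions, _, _ = compute_prediction(
        test_data_loader.dataset.data, test_data_loader.dataset.new_data,
        (all_start_logits, all_end_logits))
    return list(all_predictions.values())[0]

Since this function returns a single string, the Interface's outputs should be a single "text" component; the two-output signature in the scaffold above would need a second return value (e.g. a confidence score).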

8.2 Deploying on gradio's Public Platform

Gradio: painless interactive deployment of AI models

Quickly sharing your AI model with others so they can try it out has always been a hassle.

Most people run their code locally; to let someone else use your model, there used to be three options:

  • Put it on GitHub
  • Package the code (or wrap it in Docker) and send it via QQ/Baidu Netdisk/USB drive
  • Learn front-end and back-end development, build a UI, buy a domain, deploy quickly with a micro-framework such as Flask, possibly combined with intranet tunneling

The problem with these options: the first two require the recipient to be able (and willing) to code and configure environments, which few of the people we want to share with are; the third requires the algorithm engineer to become a full-stack developer, far too steep a learning curve.

In short: mismatched scenarios, mismatched needs, and a waste of time and effort!

Is there a better way? Yes: the open-source Python library Gradio.

Gradio is an open-source project from MIT with 2k+ stars on GitHub.

With gradio, adding just a few lines to your existing code automatically generates an interactive web page, supporting many input and output formats, such as image in, label out for image classification, or image in, image out for super-resolution.

It can also generate links accessible from the external network, letting your friends and colleagues try your model right away.

To sum up, its advantages are:

  • automatically generates an interactive page
  • takes only a few changed lines of code
  • supports a variety of custom inputs and outputs
  • can generate an externally accessible link for sharing

To learn more, see the official gradio website.

Run the model locally and open the link

Open the public URL to verify

9. Summary and Outlook

9.1 Model Comparison

| Epoch | ernie-2.0-large-en | bert-base-chinese | bert-wwm-ext-chinese | roberta-wwm-ext-large |
|-------|--------------------|-------------------|----------------------|-----------------------|
| 1     | 34.64              | 56.62             | 59.04                | 62.56                 |
| 2     | 35.68              | 57.84             | 59.94                | 63.29                 |
| 3     | 34.98              | 57.12             | 59.67                | 62.97                 |
  • The scores in the table are F1 values.
  • We first tried ernie-2.0-large-en, only to discover that it is not a Chinese pretrained model, which explains its poor results.
  • The basic BERT model performs reasonably well and can essentially complete the task.
  • bert-wwm-ext-chinese, which extends the BERT baseline, and roberta-wwm-ext-large, an improved variant of BERT, achieve clearly better results; the trade-off is a larger model and slower training.
  • Because we use the pretrain + fine-tune approach, too many epochs easily cause overfitting (visible in the epoch-3 dips above); we should avoid this in the future and tune carefully by hand.
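For reference, the F1 above is the token-overlap F1 of SQuAD-style evaluation, not a classification F1. A simplified character-level sketch of the idea (squad_evaluate applies additional normalization that is omitted here):

from collections import Counter

def char_f1(pred: str, gold: str) -> float:
    # Overlap counted per character, order-insensitively
    common = Counter(pred) & Counter(gold)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred)
    recall = num_same / len(gold)
    return 2 * precision * recall / (precision + recall)

# A partially correct span scores between 0 and 1
print(char_f1("草菇具有消食祛热等功效", "草菇具有消食祛热,补脾益气,清暑热,降血压等功效。"))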

9.2 Future Work

  1. Data:

    • Find more high-quality TCM corpora and apply simple augmentation

    • Use back-translation and similar augmentation methods to build a text-similarity dataset from scratch

  2. Deployment:

    • Further polish the gradio app

    • Package the project as an AI Studio sandbox

    • Expose an API with Flask and use Docker to package it into a small application

  3. Model:

    • Save the checkpoints from several fine-tuning epochs and average them (see the sketch after this list)
    • Adopt a model bagging strategy, fusing the better-trained models
    • Explore more strong models and, for this task, look for pretrained models trained on larger medical corpora
  4. More:

    • Study end-to-end QA models such as RocketQA and add a retrieval stage, building a complete retrieval-based QA system on top of machine reading comprehension; this also lays the groundwork for later work on generative QA models (PS: aspiring to build a general-purpose QA model like ChatGPT, just kidding)
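As a concrete starting point for the checkpoint-averaging idea in item 3 above, here is a minimal sketch. The checkpoint paths are hypothetical placeholders, and it assumes all saved state dicts come from the same architecture as the in-scope model:

import paddle

# Hypothetical checkpoints saved after different fine-tuning epochs
ckpt_paths = ["checkpoint/epoch1.pdparams",
              "checkpoint/epoch2.pdparams",
              "checkpoint/epoch3.pdparams"]
state_dicts = [paddle.load(p) for p in ckpt_paths]

# Average each parameter tensor across the checkpoints
avg_state = {}
for name in state_dicts[0]:
    stacked = paddle.stack(
        [paddle.to_tensor(sd[name]) for sd in state_dicts], axis=0)
    avg_state[name] = paddle.mean(stacked.astype("float32"), axis=0)

# Load the averaged weights back into the model and save them for inference
model.set_state_dict(avg_state)
paddle.save(avg_state, "checkpoint/model_state_avg.pdparams")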

9.3 Project Summary

Through this project I experienced the full workflow of a machine reading comprehension task: building my own dataset; assembling, training and tuning the MRC model; converting it from dynamic to static graph and running static-graph inference; and finally standing up an interactive demo with gradio. It was a valuable end-to-end project experience. Thanks to PaddlePaddle for providing enough compute to let me keep experimenting until the task was done, and thanks to 角灰大帝 for the guidance on this task, which gave me the confidence to keep exploring and find better approaches.

References:

  • https://zhuanlan.zhihu.com/p/121787628
  • https://tianchi.aliyun.com/competition/entrance/531826/introduction
  • https://aistudio.baidu.com/aistudio/projectdetail/4113678?channelType=0&channel=0
  • https://aistudio.baidu.com/aistudio/projectdetail/5204743?forkThirdPart=1
  • https://zhuanlan.zhihu.com/p/374238080