【CLUE benchmark】ChID 成语阅读理解填空

资源

⭐ ⭐ ⭐ 欢迎点个小小的Star支持!⭐ ⭐ ⭐

开源不易,希望大家多多支持~

一、背景介绍

这是Paddle版本的CLUE benchmark,旨在为用户提供Paddle版本的benchmark进行学习和交流,该版本提供了bert,ernie,roberta-wwm三个版本的基线。CLUE官网的链接为:

https://www.cluebenchmarks.com/

二、 数据预处理

在数据预处理前,需要更新一下paddlenlp版本,如果是初次运行,请升级以后,重启内核运行,这样加载的就是最新的paddlenlp了。

# !pip uninstall paddlenlp --yes
# %cd PaddleNLP/
# !python setup.py install
# %cd ..
!pip install --upgrade paddlenlp
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Requirement already up-to-date: paddlenlp in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (2.2.3)
Requirement already satisfied, skipping upgrade: multiprocess in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlenlp) (0.70.11.1)
Requirement already satisfied, skipping upgrade: jieba in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlenlp) (0.42.1)
Requirement already satisfied, skipping upgrade: colorama in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlenlp) (0.4.4)
Requirement already satisfied, skipping upgrade: colorlog in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlenlp) (4.1.0)
Requirement already satisfied, skipping upgrade: seqeval in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlenlp) (1.2.2)
Requirement already satisfied, skipping upgrade: h5py in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlenlp) (2.9.0)
Requirement already satisfied, skipping upgrade: dill>=0.3.3 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from multiprocess->paddlenlp) (0.3.3)
Requirement already satisfied, skipping upgrade: scikit-learn>=0.21.3 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from seqeval->paddlenlp) (0.24.2)
Requirement already satisfied, skipping upgrade: numpy>=1.14.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from seqeval->paddlenlp) (1.20.3)
Requirement already satisfied, skipping upgrade: six in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from h5py->paddlenlp) (1.16.0)
Requirement already satisfied, skipping upgrade: threadpoolctl>=2.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-learn>=0.21.3->seqeval->paddlenlp) (2.1.0)
Requirement already satisfied, skipping upgrade: joblib>=0.11 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-learn>=0.21.3->seqeval->paddlenlp) (0.14.1)
Requirement already satisfied, skipping upgrade: scipy>=0.19.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-learn>=0.21.3->seqeval->paddlenlp) (1.6.3)

解压CHID数据集,案例使用的数据集是CHID数据集,数据集的论文出处是ChID: A Large-scale Chinese IDiom Dataset for Cloze Test,ChID是一个用于完形填空的大规模汉语成语数据集。ChID包含581K文章和729K空白区域,并覆盖多个领域(新闻,小说,文章)。在ChID中,一段话中的习语用空白符号代替。对于每一个空白,包括黄金成语在内的候选成语列表被提供作为选择。

!unzip -o data/data116839/chid_public.zip -d data/
Archive:  data/data116839/chid_public.zip
  inflating: data/train_answer.json  
  inflating: data/train.json         
  inflating: data/dev.json           
  inflating: data/dev_answer.json    
  inflating: data/README.md          
  inflating: data/test1.0.json       
  inflating: data/test1.1.json       
  inflating: data/idiomDict.json     

下面展示一条dev.json里面的数据:

{"candidates": ["祸不单行", "急功近利", "瓜熟蒂落", "画蛇添足", "本末倒置", "因噎废食", "约定俗成", "不伦不类", "芒刺在背", "不合时宜"], "content": ["中国青年报:篮协改革联赛切莫#idiom577157#羊城晚报:信兰成想干什么?都市时报:中国篮协该歇歇了!东方体育日报:CBA要扫帚不要钉耙别把接力棒变成杀威棒东方早报:玩不起就不玩了?这是在扼杀CBA改革新浪网友:先把CBA联赛办好再说其他的,别一天到处狂搞面子工程!虚荣心咋就这么强?", "本土PE如山创投的丁世平告诉记者,那些私募性质的VC和PE由于比较#idiom577158#,往往投资比较快、退出和套现也比较快,所以,退出环节出问题就会影响到他们的运行。和通管理国际有限公司的投资总监杜伟明也认为,金融风暴会对创投行业的每一个环节都会产生影响,因此也将导致创投策略的改变。", "对消费者而言,保险是一种较为特殊的金融理财产品,其最基本原则是对风险的保障。虽然目前保险公司开发了诸多强调投资收益的险种,但无论任何种类的产品都不能偏离保险的本质。尤其是对于注重保障的人身保险而言,强调投资价值实属#idiom577159#。消费者在投资保险理财产品时,需端正心态,树立科学、健康的保险理念,避免走进购买误区。", "从上季腰带开始被搭配在西装外后,很多款式的腰带便被广泛应用在针织外套和套装外,不仅没有#idiom577160#,还体现了搭配的层次感。长开衫+连衣裙:层次感的穿衣是今年的流行趋势,开衫+连衣裙很有混搭的感觉,一条宽腰带更加表现整体的层次,黑色和裙子花纹呼应,也让腰身更加紧致。", "但省社科院社会学研究所所长陈颐也委婉地提醒有些媒体“不要盲目跟风”。他指出,语言是一种交流的工具,是#idiom577161#的,要让人看得明白、听得懂。“给力”这个意思的词,在传统的语言词库里,其实还是有现成的、规范的词汇可以表达。“不分场合,硬是要用网络上流行的‘给力’,并不一定就是代表贴近读者。”陈颐说,就他个人感觉,如果“给力”没用好,有的标题反而会有点#idiom577162#,令人费解。", "股市持续震荡下行,农行为何选择此时上市?4月中旬以来,银行股伴随着大盘的节节下挫出现大幅下跌。在股市持续震荡下行之时,农行上市也进入了“倒计时”。有人认为,农行选择此时上市,有点#idiom577163#。怎样看待这一观点?"]}

其中的选项对应数据中的candidates:

"candidates": ["祸不单行", "急功近利", "瓜熟蒂落", "画蛇添足", "本末倒置", "因噎废食", "约定俗成", "不伦不类", "芒刺在背", "不合时宜"]

上下文对应数据中的content:

"content": ["中国青年报:篮协改革联赛切莫#idiom577157#羊城晚报:信兰成想干什么?都市时报:中国篮协该歇歇了!东方体育日报:CBA要扫帚不要钉耙别把接力棒变成杀威棒东方早报:玩不起就不玩了?这是在扼杀CBA改革新浪网友:先把CBA联赛办好再说其他的,别一天到处狂搞面子工程!虚荣心咋就这么强?", "本土PE如山创投的丁世平告诉记者,那些私募性质的VC和PE由于比较#idiom577158#,往往投资比较快、退出和套现也比较快,所以,退出环节出问题就会影响到他们的运行。和通管理国际有限公司的投资总监杜伟明也认为,金融风暴会对创投行业的每一个环节都会产生影响,因此也将导致创投策略的改变。", "对消费者而言,保险是一种较为特殊的金融理财产品,其最基本原则是对风险的保障。虽然目前保险公司开发了诸多强调投资收益的险种,但无论任何种类的产品都不能偏离保险的本质。尤其是对于注重保障的人身保险而言,强调投资价值实属#idiom577159#。消费者在投资保险理财产品时,需端正心态,树立科学、健康的保险理念,避免走进购买误区。", "从上季腰带开始被搭配在西装外后,很多款式的腰带便被广泛应用在针织外套和套装外,不仅没有#idiom577160#,还体现了搭配的层次感。长开衫+连衣裙:层次感的穿衣是今年的流行趋势,开衫+连衣裙很有混搭的感觉,一条宽腰带更加表现整体的层次,黑色和裙子花纹呼应,也让腰身更加紧致。", "但省社科院社会学研究所所长陈颐也委婉地提醒有些媒体“不要盲目跟风”。他指出,语言是一种交流的工具,是#idiom577161#的,要让人看得明白、听得懂。“给力”这个意思的词,在传统的语言词库里,其实还是有现成的、规范的词汇可以表达。“不分场合,硬是要用网络上流行的‘给力’,并不一定就是代表贴近读者。”陈颐说,就他个人感觉,如果“给力”没用好,有的标题反而会有点#idiom577162#,令人费解。", "股市持续震荡下行,农行为何选择此时上市?4月中旬以来,银行股伴随着大盘的节节下挫出现大幅下跌。在股市持续震荡下行之时,农行上市也进入了“倒计时”。有人认为,农行选择此时上市,有点#idiom577163#。怎样看待这一观点?"]

#idiom577161#等这些地方表示需要填写成语的地方。每一个content包含多条上下文,每个上下文包含一个或者多个需要填空的地方。

对应的dev_answer.json的数据为:

{
  "#idiom577157#": 5,
  "#idiom577158#": 1,
  "#idiom577159#": 4
  ......

5表示的是答案的索引位置,如#idiom577157#的空白填的答案就是candidates的索引位置(从0开始)为5的成语:“因噎废食”,以此类推。

导入实验所需要的库包:

import pickle
from functools import partial
import collections
import time
import json
import inspect
import os
from tqdm import tqdm
import numpy as np
import random 

import paddle
from paddle.io import BatchSampler,SequenceSampler
from paddle.io import TensorDataset,DataLoader
from paddlenlp.transformers import BertForMultipleChoice, BertTokenizer
from paddlenlp.data import Pad, Dict
from paddlenlp.datasets import load_dataset
from paddlenlp.datasets import DatasetBuilder
from paddlenlp.data import Stack, Dict, Pad
from paddlenlp.metrics.squad import squad_evaluate, compute_prediction
from paddlenlp.transformers import LinearDecayWithWarmup
import paddlenlp as ppnlp

from CHID_preprocess import RawResult, get_final_predictions, write_predictions, generate_input, evaluate
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddlenlp/transformers/funnel/modeling.py:31: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  from collections import Iterable

设置随机种子,固定住随机因素,方便用户稳定复现。

def set_seed(seed):
    """sets random seed"""
    random.seed(seed)
    np.random.seed(seed)
    paddle.seed(seed)

处理训练集合的数据,转换成id的形式,然后构造训练集的Dataloader。

CHID数据的预处理分两种情况,第一种是一句话里面只包含一个成语,这个处理方式跟C3数据集一样;第二种是一句话里面包含多个成语,但模型每次只能识别一个空位,所以用[unused0]代替我们要识别的空位,而没被识别的空位(如果有的话)直接用4个[MASK]代替。

第一种情况:

content:"中国青年报:篮协改革联赛切莫#idiom577157#羊城晚报:信兰成想干什么?都市时报:中国篮协该歇歇了!东方体育日报:CBA要扫帚不要钉耙别把接力棒变成杀威棒东方早报:玩不起就不玩了?这是在扼杀CBA改革新浪网友:先把CBA联赛办好再说其他的,别一天到处狂搞面子工程!虚荣心咋就这么强?"
"candidates": ["祸不单行", "急功近利", "瓜熟蒂落", "画蛇添足", "本末倒置", "因噎废食", "约定俗成", "不伦不类", "芒刺在背", "不合时宜"]

构造成数据集:

[CLS]祸不单行[SEP]中国青年报:篮协改革联赛切莫[unused0]羊城晚报:信兰成想干什么?都市时报:中国篮协该歇歇了!东方体育日报:CBA要扫帚不要钉耙别把接力棒变成杀威棒东方早报:玩不起就不玩了?这是在扼杀CBA改革新浪网友:先把CBA联赛办好再说其他的,别一天到处狂搞面子工程!虚荣心咋就这么强?[SEP]

[CLS]急功近利[SEP]中国青年报:篮协改革联赛切莫[unused0]羊城晚报:信兰成想干什么?都市时报:中国篮协该歇歇了!东方体育日报:CBA要扫帚不要钉耙别把接力棒变成杀威棒东方早报:玩不起就不玩了?这是在扼杀CBA改革新浪网友:先把CBA联赛办好再说其他的,别一天到处狂搞面子工程!虚荣心咋就这么强?[SEP]
...

第二种情况:

"content": ["在一些卫浴专卖店里,商家们把质量认证书放在店堂内醒目的位置,仔细一看却发现认证机构#idiom577200#,有“中国轻工业认证中心”、有“中国某某标志认证委员会”、有“保护消费者权益基金会”等。而认证的内容更是#idiom577201#,有的自称是“卫生陶瓷”,有的是“绿色环保产品”,有的是“高温无菌陶瓷”,还有的是“臭氧杀菌合格产品”等。再看证书的有效期,有的为3年,有的是5年,有的甚至是10年有效等。"
"candidates": ["饱经忧患", "千变万化", "掌上明珠", "平分秋色", "五花八门", "惹火烧身", "鱼龙混杂", "粗制滥造", "各有千秋", "别有风味"]

构造成数据集:

[CLS]饱经忧患[SEP]在一些卫浴专卖店里,商家们把质量认证书放在店堂内醒目的位置,仔细一看却发现认证机构[unused0],有“中国轻工业认证中心”、有“中国某某标志认证委员会”、有“保护消费者权益基金会”等。而认证的内容更是[MASK][MASK][MASK][MASK],有的自称是“卫生陶瓷”,有的是“绿色环保产品”,有的是“高温无菌陶瓷”,还有的是“臭氧杀菌合格产品”等。再看证书的有效期,有的为3年,有的是5年,有的甚至是10年有效等。[SEP]
...
def process_train_data(input_dir,tokenizer,max_seq_length,max_num_choices):
    
    train_file='data/train.json'
    train_ans_file='data/train_answer.json'

    train_example_file = os.path.join(input_dir, 'train_examples_{}.pkl'.format(str(max_seq_length)))
    train_feature_file = os.path.join(input_dir, 'train_features_{}.pkl'.format(str(max_seq_length)))
    

    train_features = generate_input(train_file, train_ans_file, train_example_file, train_feature_file,
                                        tokenizer, max_seq_length=max_seq_length,
                                        max_num_choices=max_num_choices,
                                        is_training=True)

    print("loaded train dataset")
    print("Num generate examples = {}".format(len(train_features)))

    all_input_ids = paddle.to_tensor([f.input_ids for f in train_features], dtype='int64')
    all_input_masks = paddle.to_tensor([f.input_masks for f in train_features], dtype='int64')
    all_segment_ids = paddle.to_tensor([f.segment_ids for f in train_features], dtype='int64')
    all_choice_masks = paddle.to_tensor([f.choice_masks for f in train_features], dtype='int64')
    all_labels = paddle.to_tensor([f.label for f in train_features], dtype='int64')

    train_data = TensorDataset([all_input_ids, all_input_masks, all_segment_ids, all_choice_masks, all_labels])
    
    return train_data

处理验证集合的数据,转换成id的形式,然后构造验证集的Dataloader

def process_validation_data(input_dir,tokenizer,max_seq_length,max_num_choices):

    predict_file='data/dev.json'
    dev_example_file = os.path.join(input_dir, 'dev_examples_{}.pkl'.format(str(max_seq_length)))
    dev_feature_file = os.path.join(input_dir, 'dev_features_{}.pkl'.format(str(max_seq_length)))

    eval_features = generate_input(predict_file, None, dev_example_file, dev_feature_file, tokenizer,
                                    max_seq_length=max_seq_length, max_num_choices=max_num_choices,
                                    is_training=False)

    all_example_ids = [f.example_id for f in eval_features]
    all_tags = [f.tag for f in eval_features]
    all_input_ids = paddle.to_tensor([f.input_ids for f in eval_features], dtype="int64")
    all_input_masks = paddle.to_tensor([f.input_masks for f in eval_features], dtype="int64")
    all_segment_ids = paddle.to_tensor([f.segment_ids for f in eval_features], dtype="int64")
    all_choice_masks = paddle.to_tensor([f.choice_masks for f in eval_features], dtype="int64")
    all_example_index = paddle.arange(all_input_ids.shape[0], dtype="int64")

    eval_data = TensorDataset([all_input_ids, all_input_masks, all_segment_ids, all_choice_masks,
                              all_example_index])
    

    return eval_data,all_example_ids,all_tags,eval_features
set_seed(2022)
max_seq_length=64
batch_size=4

print('ready for train dataset')
input_dir='output'

max_seq_length=64
max_num_choices=10

MODEL_NAME = "bert-base-chinese"
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
train_data=process_train_data(input_dir,tokenizer,max_seq_length,max_num_choices)
train_data_loader = DataLoader(dataset=train_data,
                                            batch_size=batch_size,
                                            drop_last=True,
                                            num_workers=0)

ready for train dataset


[2022-01-24 15:22:46,610] [    INFO] - Already cached /home/aistudio/.paddlenlp/models/bert-base-chinese/bert-base-chinese-vocab.txt
  0%|          | 0/577157 [00:00<?, ?it/s]

*** Example ***
unique_id: 0
context_id: #idiom000000#
label: 3
tag_index: 115
tokens: [CLS]人老珠黄[SEP]鹰"德拉·佩纳、哈维、莫雷罗、罗杰·加西亚和贝拉乌桑几乎[unused1]的情况下,他们依然在坚持,最终他们等到了哈维的成熟,等到[SEP]
choice_masks: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
*** Example ***
unique_id: 1
context_id: #idiom000001#
label: 5
tag_index: 97
tokens: [CLS]人老珠黄[SEP],只要我们向着这个方向努力,无论我们采取怎样的措施,其结果都会好过[unused1],只是在那里坐等又一份宣布减记消息的银行报告。[SEP]
choice_masks: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]


  2%|▏         | 11979/577157 [01:00<37:30, 251.15it/s] 

unique_id: 12000


  4%|▍         | 23976/577157 [02:04<2:18:15, 66.68it/s] 

unique_id: 24000


  6%|▌         | 35995/577157 [02:59<34:40, 260.11it/s]  

unique_id: 36000


  8%|▊         | 47996/577157 [04:08<37:18, 236.43it/s]  

unique_id: 48000


 10%|█         | 59978/577157 [04:58<35:22, 243.70it/s]  

unique_id: 60000


 12%|█▏        | 71976/577157 [06:05<33:26, 251.79it/s]  

unique_id: 72000


 15%|█▍        | 83988/577157 [06:53<34:52, 235.68it/s]  

unique_id: 84000


 17%|█▋        | 95973/577157 [08:00<29:50, 268.82it/s]  

unique_id: 96000


 19%|█▊        | 107975/577157 [08:47<29:44, 262.94it/s]

unique_id: 108000


 21%|██        | 119976/577157 [09:39<28:58, 262.93it/s]

unique_id: 120000


 23%|██▎       | 131976/577157 [10:36<28:21, 261.60it/s]  

unique_id: 132000


 25%|██▍       | 143980/577157 [11:22<28:04, 257.15it/s]

unique_id: 144000


 27%|██▋       | 155992/577157 [12:09<26:41, 262.93it/s]

unique_id: 156000


 29%|██▉       | 167981/577157 [12:55<34:10, 199.53it/s]

unique_id: 168000


 31%|███       | 179990/577157 [13:53<24:49, 266.65it/s]  

unique_id: 180000


 33%|███▎      | 191985/577157 [14:39<23:26, 273.85it/s]

unique_id: 192000


 35%|███▌      | 203974/577157 [15:28<23:33, 264.00it/s]

unique_id: 204000


 37%|███▋      | 215979/577157 [16:14<22:25, 268.45it/s]

unique_id: 216000


 40%|███▉      | 227984/577157 [17:05<23:01, 252.74it/s]

unique_id: 228000


 42%|████▏     | 239985/577157 [18:07<21:20, 263.24it/s]  

unique_id: 240000


 44%|████▎     | 251996/577157 [18:54<19:52, 272.78it/s]

unique_id: 252000


 46%|████▌     | 263978/577157 [19:42<19:55, 262.07it/s]

unique_id: 264000


 48%|████▊     | 275994/577157 [20:29<19:44, 254.22it/s]

unique_id: 276000


 50%|████▉     | 287980/577157 [21:16<18:21, 262.61it/s]

unique_id: 288000


 52%|█████▏    | 299970/577157 [22:02<17:42, 260.92it/s]

unique_id: 300000


 54%|█████▍    | 311990/577157 [23:04<16:23, 269.65it/s]  

unique_id: 312000


 56%|█████▌    | 323978/577157 [23:50<16:26, 256.59it/s]

unique_id: 324000


 58%|█████▊    | 335981/577157 [24:35<14:59, 268.14it/s]

unique_id: 336000


 60%|██████    | 347998/577157 [25:22<18:39, 204.67it/s]

unique_id: 348000


 62%|██████▏   | 359974/577157 [26:09<14:10, 255.47it/s]

unique_id: 360000


 64%|██████▍   | 371982/577157 [26:56<13:48, 247.72it/s]

unique_id: 372000


 67%|██████▋   | 383998/577157 [27:44<14:53, 216.24it/s]

unique_id: 384000


 69%|██████▊   | 395977/577157 [28:53<13:23, 225.49it/s]  

unique_id: 396000


 71%|███████   | 407976/577157 [29:44<12:29, 225.63it/s]

unique_id: 408000


 73%|███████▎  | 419979/577157 [30:30<10:24, 251.61it/s]

unique_id: 420000


 75%|███████▍  | 431976/577157 [31:18<09:07, 265.02it/s]

unique_id: 432000


 77%|███████▋  | 443973/577157 [32:04<08:21, 265.53it/s]

unique_id: 444000


 79%|███████▉  | 455972/577157 [32:52<07:17, 276.88it/s]

unique_id: 456000


 81%|████████  | 467997/577157 [33:41<07:15, 250.42it/s]

unique_id: 468000


 83%|████████▎ | 479972/577157 [34:28<05:57, 271.61it/s]

unique_id: 480000


 85%|████████▌ | 491999/577157 [35:15<05:11, 273.17it/s]

unique_id: 492000


 87%|████████▋ | 503999/577157 [36:23<04:42, 258.65it/s]  

unique_id: 504000


 89%|████████▉ | 515979/577157 [37:12<03:44, 272.31it/s]

unique_id: 516000


 91%|█████████▏| 527988/577157 [38:08<03:11, 256.67it/s]

unique_id: 528000


 94%|█████████▎| 539993/577157 [38:54<02:18, 268.65it/s]

unique_id: 540000


 96%|█████████▌| 551973/577157 [39:40<01:28, 284.72it/s]

unique_id: 552000


 98%|█████████▊| 563974/577157 [40:26<00:55, 238.98it/s]

unique_id: 564000


100%|█████████▉| 575976/577157 [41:12<00:04, 259.17it/s]

unique_id: 576000


100%|██████████| 577157/577157 [41:17<00:00, 232.98it/s]


unique_id: 577157
loaded train dataset
Num generate examples = 577157
eval_data,all_example_ids,all_tags,eval_features=process_validation_data(input_dir,tokenizer,max_seq_length,max_num_choices)

# Run prediction for full data
eval_dataloader = DataLoader(eval_data, batch_size=batch_size)
3218it [00:05, 639.83it/s]


原始样本个数:3218
实际生成总样例数:23011


  0%|          | 0/23011 [00:00<?, ?it/s]

*** Example ***
unique_id: 0
context_id: #idiom577157#
label: None
tag_index: 14
tokens: [CLS]不合时宜[SEP]中国青年报:篮协改革联赛切莫[unused1]羊城晚报:信兰成想干什么?都市时报:中国篮协该歇歇了!东方体育日报:[UNK]要扫帚不要钉耙[SEP]
choice_masks: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
*** Example ***
unique_id: 1
context_id: #idiom577158#
label: None
tag_index: 30
tokens: [CLS]不合时宜[SEP][UNK]如山创投的丁世平告诉记者,那些私募性质的[UNK]和[UNK]由于比较[unused1],往往投资比较快、退出和套现也比较快,所以,退出环节出问[SEP]
choice_masks: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]


 52%|█████▏    | 11998/23011 [00:47<00:41, 267.69it/s]

unique_id: 12000


100%|██████████| 23011/23011 [01:29<00:00, 256.63it/s]


unique_id: 23011


W0124 16:09:32.974428  4622 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1
W0124 16:09:32.987335  4622 device_context.cc:465] device: 0, cuDNN Version: 7.6.

数据集的处理时间比较长,需要半个小时以上,需要耐心等待。

三、模型构建

实例化BertForMultipleChoice模型,读者也可以尝试ErnieForMultipleChoice,RobertaForMultipleChoice等模型,测试一下效果。

BertForMultipleChoice的原理图如下图所示,首先数据会处理成下面的形式,每个选项拆开都构成一条单独的样本,然后把数据输入到同一个BERT中,得到CLS位置的输出,然后接入全连接层FC,输出每个样本的概率值,概率最大的样本对应的选项即为最终的答案。

max_num_choices=10
MODEL_NAME = "bert-base-chinese"
model = BertForMultipleChoice.from_pretrained(MODEL_NAME,
                                              num_choices=max_num_choices)
[2022-01-24 16:09:35,756] [    INFO] - Downloading http://bj.bcebos.com/paddlenlp/models/transformers/bert/bert-base-chinese.pdparams and saved to /home/aistudio/.paddlenlp/models/bert-base-chinese
[2022-01-24 16:09:35,758] [    INFO] - Downloading bert-base-chinese.pdparams from http://bj.bcebos.com/paddlenlp/models/transformers/bert/bert-base-chinese.pdparams
100%|██████████| 680M/680M [00:15<00:00, 46.5MB/s] 

四、训练配置

配置训练所需要的超参数,优化器,损失函数,评估方式等。

num_train_epochs=3
max_grad_norm = 1.0
num_training_steps = len(train_data_loader) * num_train_epochs

# 定义 learning_rate_scheduler,负责在训练过程中对 lr 进行调度
lr_scheduler = LinearDecayWithWarmup(2e-5, num_training_steps, 0)
# Generate parameter names needed to perform weight decay.
# All bias and LayerNorm parameters are excluded.
decay_params = [
        p.name for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    ]
grad_clip = paddle.nn.ClipGradByGlobalNorm(max_grad_norm)

# 定义 Optimizer
optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler,
        parameters=model.parameters(),
        weight_decay=0.01,
        apply_decay_param_fun=lambda x: x in decay_params,
        grad_clip=grad_clip)
# 交叉熵损失
criterion = paddle.nn.loss.CrossEntropyLoss()
# 评估的时候采用准确率指标
metric = paddle.metric.Accuracy()

五、模型训练

接下来是训练模型,在模型训练的过程中需要评估,这里先实现do_evaluate函数,主要用于训练集训练的过程中的评估

@paddle.no_grad()
def do_evaluate(model, dev_data_loader,all_example_ids,all_tags,eval_features):

    all_results = []
    model.eval()
    output_dir='data'
    for step, batch in enumerate(tqdm(dev_data_loader)):

        input_ids, input_masks, segment_ids, choice_masks, example_indices=batch
        batch_logits = model(input_ids=input_ids, token_type_ids=segment_ids,attention_mask=input_masks)
        
        for i, example_index in enumerate(example_indices):
            logits = batch_logits[i].numpy().tolist()
            eval_feature = eval_features[example_index.item()]
            unique_id = int(eval_feature.unique_id)
            all_results.append(RawResult(unique_id=unique_id,
                                             example_id=all_example_ids[unique_id],
                                             tag=all_tags[unique_id],
                                             logit=logits))
                
    predict_file = 'dev_predictions.json'
    predict_ans_file='data/dev_answer.json'
    print('decoder raw results')
    tmp_predict_file = os.path.join(output_dir, "raw_predictions.pkl")
    output_prediction_file = os.path.join(output_dir, predict_file)
    results = get_final_predictions(all_results, tmp_predict_file, g=True)
    write_predictions(results, output_prediction_file)
    print('predictions saved to {}'.format(output_prediction_file))

    acc = evaluate(predict_ans_file, output_prediction_file)
    print(f'{predict_file} 预测精度:{acc}')
    model.train()
    # return np.mean(all_loss), acc
    return acc

下面定义训练的函数 do_train

def do_train(model,train_data_loader,dev_data_loader,all_example_ids,all_tags,eval_features):

    model.train()
    global_step = 0
    tic_train = time.time()
    log_step = 100
    for epoch in range(num_train_epochs):
        metric.reset()
        for step, batch in enumerate(train_data_loader):
            input_ids, input_masks, segment_ids, choice_masks, labels =batch

            logits = model(input_ids=input_ids, token_type_ids=segment_ids,attention_mask=input_masks)

            loss = criterion(logits, labels)
            correct = metric.compute(logits, labels)
            metric.update(correct)
            acc = metric.accumulate()

            global_step += 1

            # 每间隔 log_step 输出训练指标
            if global_step % log_step == 0:
                print(
                    "global step %d, epoch: %d, batch: %d, loss: %.5f, accu: %.5f, speed: %.2f step/s"
                    % (global_step, epoch, step, loss, acc, 10 /
                       (time.time() - tic_train)))
                tic_train = time.time()

            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.clear_grad()

        do_evaluate(model, dev_data_loader,all_example_ids,all_tags,eval_features)
        model.save_pretrained("./checkpoint")

训练的时间有点长,需要至少1天的时间,建议改成多卡跑程序。或者注释掉,用已经跑好的模型。

do_train(model,train_data_loader,dev_data_loader,all_example_ids,all_tags,eval_features)
global step 600, epoch: 0, batch: 599, loss: 2.33230, accu: 0.14750, speed: 0.73 step/s
global step 700, epoch: 0, batch: 699, loss: 2.35523, accu: 0.15750, speed: 0.72 step/s
global step 800, epoch: 0, batch: 799, loss: 2.37167, accu: 0.16687, speed: 0.72 step/s
global step 900, epoch: 0, batch: 899, loss: 2.08946, accu: 0.17722, speed: 0.72 step/s
global step 1000, epoch: 0, batch: 999, loss: 2.61547, accu: 0.18975, speed: 0.72 step/s
global step 1100, epoch: 0, batch: 1099, loss: 2.07350, accu: 0.19614, speed: 0.72 step/s
global step 1200, epoch: 0, batch: 1199, loss: 2.19043, accu: 0.20333, speed: 0.72 step/s
...

上图显式的是部分训练日志,从accu来看,模型是慢慢收敛的。

六、模型预测

模型预测的部分主要是处理测试集合,放入模型中进行预测,然后输出json格式的结果,然后就可以传到CLUE网站上进行测试了。

首先处理测试集,把测试文本转换成ID的形式,然后封装到Dataloader中。

def process_test_data(input_dir,tokenizer,max_seq_length,max_num_choices):

    # 测试文件有两个,一个是test1.0.json,另一个是test1.1.json。
    # 可以根据情况使用test的json文件,本案例使用的是test1.0.json
    predict_file='data/test1.0.json'

    test_example_file = os.path.join(input_dir, 'test_examples_{}.pkl'.format(str(max_seq_length)))
    test_feature_file = os.path.join(input_dir, 'test_features_{}.pkl'.format(str(max_seq_length)))

    test_features = generate_input(predict_file, None, test_example_file, test_feature_file, tokenizer,
                                   max_seq_length=max_seq_length, max_num_choices=max_num_choices,
                                   is_training=False)

    all_example_ids = [f.example_id for f in test_features]
    all_tags = [f.tag for f in test_features]
    all_input_ids = paddle.to_tensor([f.input_ids for f in test_features], dtype="int64")
    all_input_masks = paddle.to_tensor([f.input_masks for f in test_features], dtype="int64")
    all_segment_ids = paddle.to_tensor([f.segment_ids for f in test_features], dtype="int64")
    all_choice_masks = paddle.to_tensor([f.choice_masks for f in test_features], dtype="int64")
    all_example_index = paddle.arange(all_input_ids.shape[0], dtype="int64")

    test_data = TensorDataset([all_input_ids, all_input_masks, all_segment_ids, all_choice_masks,
                              all_example_index])
    # Run prediction for full data

    return test_data,all_example_ids,all_tags,test_features


@paddle.no_grad()
def do_test(model, dev_data_loader,all_example_ids,all_tags,eval_features):

    all_results = []
    model.eval()
    output_dir='work'
    for step, batch in enumerate(tqdm(dev_data_loader)):

        input_ids, input_masks, segment_ids, choice_masks, example_indices=batch
        batch_logits = model(input_ids=input_ids, token_type_ids=segment_ids,attention_mask=input_masks)
        # loss = criterion(batch_logits, labels)

        # all_loss.append(loss.numpy())
        for i, example_index in enumerate(example_indices):
            logits = batch_logits[i].numpy().tolist()
            eval_feature = eval_features[example_index.item()]
            unique_id = int(eval_feature.unique_id)
            all_results.append(RawResult(unique_id=unique_id,
                                             example_id=all_example_ids[unique_id],
                                             tag=all_tags[unique_id],
                                             logit=logits))
                
    output_file = 'chid10_predict.json'
    print('decoder raw results')
    tmp_predict_file = os.path.join(output_dir, "test_raw_predictions.pkl")
    output_prediction_file = os.path.join(output_dir, output_file)
    results = get_final_predictions(all_results, tmp_predict_file, g=True)
    write_predictions(results, output_prediction_file)
    print('predictions saved to {}'.format(output_prediction_file))



加载训练好的模型,进行预测,运行结束后会生成json文件,下载到本地,压缩成zip文件,然后上传到CLUE评测网站上就能进行评测了。

由于训练的时间过长,这里给了一个训练好的模型,下面解压这个模型。

!unzip -o data/data125854/checkpoint.zip
Archive:  data/data125854/checkpoint.zip
  inflating: __MACOSX/._checkpoint   
  inflating: checkpoint/model_state.pdparams  
  inflating: __MACOSX/checkpoint/._model_state.pdparams  
  inflating: checkpoint/model_config.json  
  inflating: __MACOSX/checkpoint/._model_config.json  

加载训练好的模型,然后进行预测。

# 为了防止切换测试集后,复用以前的临时文件,这里在每次测试前,删除生成的临时文件
# 这样就能够保证每次测试的数据集都是最新指定的
!rm -rf output/test_examples_64.pkl
!rm -rf output/test_features_64.pkl
max_num_choices=10
model = BertForMultipleChoice.from_pretrained('checkpoint',
                                              num_choices=max_num_choices)
test_data,all_example_ids,all_tags,test_features=process_test_data(input_dir,tokenizer,max_seq_length,max_num_choices)

test_data_loader = DataLoader(eval_data, batch_size=batch_size)

do_test(model, test_data_loader,all_example_ids,all_tags,test_features)

100%|██████████| 5753/5753 [04:34<00:00, 20.99it/s]


decoder raw results
Writing predictions to: work/chid10_predict.json
predictions saved to work/chid10_predict.json

运行结束以后,在work目录下会生成chid10_predict.json,下载到本地,然后压缩上传到CLUE官网就可以评测了。

七、更多PaddleEdu信息内容

1. PaddleEdu一站式深度学习在线百科awesome-DeepLearning中还有其他的能力,大家可以敬请期待:

  • 深度学习入门课
  • 深度学习百问
  • 特色课
  • 产业实践

PaddleEdu使用过程中有任何问题欢迎在awesome-DeepLearning提issue,同时更多深度学习资料请参阅飞桨深度学习平台

记得点个Star⭐收藏噢~~

2. 飞桨PaddleEdu技术交流群(QQ)

目前QQ群已有2000+同学一起学习,欢迎扫码加入

Logo

学大模型,用大模型上飞桨星河社区!每天8点V100G算力免费领!免费领取ERNIE 4.0 100w Token >>>

更多推荐