【AI达人特训营】法律领域篇章级多事件检测
本项目用于判断法律案件中各个事件对应的多种事件类型,对数据进行EDA后基于ernie3.0预训练模型,使用了权重衰减、梯度裁剪、对抗训练等方法,对抗前后模型F1值分别为0.916和0.781。
法律领域篇章级多事件检测
本项目针对法律案件中存在触发词不明显或者不包含触发词的事件,试图建立稳健的事件检测模型,用于判断法律案件中所包含的各个事件对应的事件类型,进而对后续的事件元素抽取任务提供支持。
# 在V100 32G及以上环境中运行
!/opt/conda/envs/python35-paddle120-env/bin/python -m pip install --upgrade pip --user
!pip install pyzmq==18.1.1
!pip install --upgrade paddlenlp
1 准备数据
数据格式如下:
{
"id": 1,
"text": "赵四与妻子王五通过相亲认识,2011年登记结婚,婚后共生育三个孩子,后双方因感情不和,于2020年协议离婚,协议约定,离婚后,三个孩子在一年内跟随王五生活,赵四每月每个孩子支付2000元抚养费,2021年三个孩子向法院提起诉讼,要求赵四按照协议约定支付抚养费。",
"classname": "婚姻家庭纠纷",
"eventchain": [
{"trigger": "结婚", "eventtype": "Marry", "argument": [{"husband": "赵四", "wife": "王五", "time": "2011年", "loc": ""}]},
{"trigger": "生育", "eventtype": "BeBorn", "argument": [{"per": "三个孩子", "time": "婚后", "loc": ""}]},
{"trigger": "离婚", "eventtype": "Other", "argument": [{"subjec": "赵四", "object": "王五", "context": "双方感情不和", "time": "2020年", "loc": ""}]},
{"trigger": "诉讼", "eventtype": "Prosecute", "argument": [{"prosecutor": "孩子", "defendant": "赵四", "reason": "", "demand": "按照协议约定支付抚养费", "time": "2021年", "court": "法院"}]}],
"caseresult": ["双方达成调解,被告同意按照协议约定支付抚养费。"
]
}
由于实际上目标任务是一个多分类任务,因此只保留了数据中id
,text
,event_type
属性(其中event_type
用label
代替),分别生成训练集、验证集和测试集。
from utils import id_label, data_prepare
datapath = 'data/data143875'
id2label, label2id = id_label(datapath)
train_dataset = data_prepare(datapath, label2id, mode='train')
dev_dataset = data_prepare(datapath, label2id, mode='dev')
test_dataset = data_prepare(datapath, label2id, mode='test')
print(train_dataset[0], '\n\n', id2label, '\n\n', label2id)
{'id': 1, 'text': '赵四与妻子王五通过相亲认识,2011年登记结婚,婚后共生育三个孩子,后双方因感情不和,于2020年协议离婚,协议约定,离婚后,三个孩子在一年内跟随王五生活,赵四每月每个孩子支付2000元抚养费,2021年三个孩子向法院提起诉讼,要求赵四按照协议约定支付抚养费。', 'labels': [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0]}
{0: 'Be_Born', 1: 'Cohabit', 2: 'Derailment', 3: 'Domestic_Violence', 4: 'Gamble', 5: 'Marry', 6: 'Other', 7: 'Prosecute', 8: 'Raise', 9: 'Separation', 10: 'Support'}
{'Be_Born': 0, 'Cohabit': 1, 'Derailment': 2, 'Domestic_Violence': 3, 'Gamble': 4, 'Marry': 5, 'Other': 6, 'Prosecute': 7, 'Raise': 8, 'Separation': 9, 'Support': 10}
2 数据探索
EDA部分主要参考了https://github.com/kangyishuai/NEWS-TEXT-CLASSIFICATION/blob/master/EDA.ipynb ,将单标签改成了多标签
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import scipy
from collections import Counter
from matplotlib.font_manager import FontProperties
2.1 简单查看数据
df_train = pd.DataFrame.from_dict(train_dataset)
df_train.head()
id | text | labels | |
---|---|---|---|
0 | 1 | 赵四与妻子王五通过相亲认识,2011年登记结婚,婚后共生育三个孩子,后双方因感情不和,于20... | [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, ... |
1 | 2 | 张先生与李女士于2010年登记结婚,2011年育有一子,但因两人婚前感情基础薄弱,婚后两人开... | [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, ... |
2 | 3 | 原告范文与妻子李云系夫妻关系,两人自2000年2月19日登记结婚。在范文与妻子李云婚姻关系存... | [0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, ... |
3 | 4 | 1999年,朱三与妻子倪静经人介绍认识,两人于2000年登记结婚。2014年,双方因感情不和... | [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, ... |
4 | 5 | 黄云和丈夫张山在1997年开始以夫妻名义同居生活在一起。双方于1998年9月6日生育大女儿张... | [1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, ... |
df_train.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 971 entries, 0 to 970
Data columns (total 3 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 id 971 non-null int64
1 text 971 non-null object
2 labels 971 non-null object
dtypes: int64(1), object(2)
memory usage: 22.9+ KB
2.2 文本长度分布
2.2.1 文本长度描述
# 训练集的文本长度描述
df_train['lenth'] = df_train['text'].apply(lambda x:len(x))
df_train['lenth'].describe()
count 971.000000
mean 282.075180
std 453.396933
min 3.000000
25% 149.000000
50% 199.000000
75% 261.000000
max 5057.000000
Name: lenth, dtype: float64
# 测试集的文本长度描述
df_test = pd.DataFrame.from_dict(test_dataset)
df_test['lenth'] = df_test['text'].apply(lambda x:len(x))
df_test['lenth'].describe()
count 485.000000
mean 190.096907
std 94.097040
min 28.000000
25% 122.000000
50% 180.000000
75% 237.000000
max 597.000000
Name: lenth, dtype: float64
2.2.2 文本长度频数分布
fig, ax = plt.subplots(1,1,figsize=(12,4))
ax = plt.hist(x=df_train['lenth'], bins=100)
ax = plt.hist(x=df_test['lenth'], bins=100)
plt.xlim([0, max(max(df_train['lenth']), max(df_test['lenth']))])
plt.xlabel("length of sample")
plt.ylabel("number of sample")
plt.legend(['train_len','test_len'])
plt.show()
2.2.3 文本长度频率分布
plt.figure(figsize=(12,4))
ax = sns.distplot(df_train['lenth'], bins=100)
ax = sns.distplot(df_test['lenth'], bins=100)
plt.xlim([0, max(max(df_train['lenth']), max(df_test['lenth']))])
plt.xlabel("length of sample")
plt.ylabel("prob of sample")
plt.legend(['train_len','test_len'])
<matplotlib.legend.Legend at 0x7f7835469890>
2.2.4 同分布验证
# https://blog.csdn.net/weixin_30230009/article/details/122872228
import scipy
scipy.stats.ks_2samp(df_train['lenth'], df_test['lenth'])
KstestResult(statistic=0.1258836145115568, pvalue=6.357808192014591e-05)
可以看到,训练集和验证集均与测试集的文本长度分布存在偏移,这个地方可以进行考虑优化。
2.3 截断位置选择
2.3.1 正态性验证
log_train_len = np.log(1+df_train['lenth'])
log_test_len = np.log(1+df_test['lenth'])
_, lognormal_ks_pvalue = scipy.stats.kstest(rvs=log_train_len, cdf='norm')
print(lognormal_ks_pvalue)
trans_data, lam = scipy.stats.boxcox(df_train['lenth']+1)
print(scipy.stats.normaltest(trans_data))
0.0
NormaltestResult(statistic=288.1385347098726, pvalue=2.7009228305021675e-63)
两个方法得到的p值都很小,说明不符合正态分布,不能直接使用3σ原则来截断
2.3.2 看图猜截断长度
plt.figure(figsize=(12,4))
ax = sns.distplot(log_train_len)
ax = sns.distplot(log_test_len)
plt.xlabel("log length of sample")
plt.ylabel("prob of log")
plt.legend(['train_len','test_len'])
<matplotlib.legend.Legend at 0x7f78351aa610>
看log图可以看到训练集和测试集的频率在6.5左右的位置基本为0,因此根据np.exp(6.5)≈665来选择截断长度为640。
2.4 类别分析
2.4.1 类别分布
num_label = []
for j in range(len(id2label)):
m = 0
for i in range(len(df_train['labels'])):
m += df_train['labels'][i][j]
num_label.append(m)
plt.figure()
plt.bar(x=range(len(id2label)), height=num_label)
plt.xlabel("label")
plt.ylabel("number of sample")
plt.xticks(range(len(id2label)), list(id2label.values()), rotation=45)
plt.show()
2.4.2 类别长度
objs = [df_train[['id', 'lenth']], pd.DataFrame(df_train['labels'].tolist())]
ans1 = pd.concat(objs, axis=1)
ans2 = pd.melt(ans1, var_name='id_label', value_name='labels', id_vars=['id', 'lenth'])
ans2 = ans2[ans2['labels']!=0.0].reset_index(drop=True).drop('labels', axis=1)
plt.figure()
ax = sns.catplot(x='id_label', y='lenth', data=ans2, kind='strip')
plt.xticks(range(len(id2label)), list(id2label.values()), rotation=45)
([<matplotlib.axis.XTick at 0x7f7832d79090>,
<matplotlib.axis.XTick at 0x7f7832d73710>,
<matplotlib.axis.XTick at 0x7f7832d13fd0>,
<matplotlib.axis.XTick at 0x7f7832d1c290>,
<matplotlib.axis.XTick at 0x7f7832d1cc10>,
<matplotlib.axis.XTick at 0x7f7832d24090>,
<matplotlib.axis.XTick at 0x7f7832d73410>,
<matplotlib.axis.XTick at 0x7f7832ce3ad0>,
<matplotlib.axis.XTick at 0x7f7832ce3c50>,
<matplotlib.axis.XTick at 0x7f7832ce7610>,
<matplotlib.axis.XTick at 0x7f7832ce7b50>],
<a list of 11 Text xticklabel objects>)
<Figure size 432x288 with 0 Axes>
可以看到该数据类别失衡,而参考来源中发现不同类别的文本长度不同,因而使用文本长度作为特征之一来应对失衡问题,但该数据不符合这种特征,要想一些别的办法来优化。
小结
对数据进行探索得到一些不绝对的结论:
- 训练集和测试集的文本长度分布存在差异,训练集文本长度的最大值和最小值差距较大。
- 训练集文本长度分布是非正态分布,无法使用3σ原则来对文本长度进行截断,最终结合图粗略选择了640作为截断长度。
- 训练集中文本类别不均衡,存在数据偏移现象,且与文本长度关联不大,因此训练中尝试加入了对抗训练。
- 训练集和验证集在文本长度、类型分布较测试集更相似,因此训练中由于数据不多而把验证集加入到训练集中时,采取了剪枝策略。
3 事件检测分类
3.1 从本地文件创建数据集
3.1和3.2是两种封装数据集的方法,选择一个即可。
from paddlenlp.datasets import load_dataset
from utils import id_label, data_prepare, data_split
datapath = 'data/data143875'
id2label, label2id = id_label(datapath)
train_dataset = data_prepare(datapath, label2id, mode='train')
test_dataset = data_prepare(datapath, label2id, mode='dev')
dataset = train_dataset + test_dataset # 训练集和验证集共同组成新的训练集
train_dataset, test_dataset = data_split(dataset) # 划分数据集
def read(dataset):
for data in dataset:
text, labels = data['text'], data['labels']
yield {'text': text, 'labels': labels}
train_dataset = load_dataset(read, dataset=train_dataset,lazy=False) # dataset是read的参数
test_dataset = load_dataset(read, dataset=test_dataset,lazy=False)
print(id2label, '\n\n', label2id, '\n\n')
print("训练集样例:", train_dataset[291])
print("测试集样例:", test_dataset[291])
{0: 'Be_Born', 1: 'Cohabit', 2: 'Derailment', 3: 'Domestic_Violence', 4: 'Gamble', 5: 'Marry', 6: 'Other', 7: 'Prosecute', 8: 'Raise', 9: 'Separation', 10: 'Support'}
{'Be_Born': 0, 'Cohabit': 1, 'Derailment': 2, 'Domestic_Violence': 3, 'Gamble': 4, 'Marry': 5, 'Other': 6, 'Prosecute': 7, 'Raise': 8, 'Separation': 9, 'Support': 10}
训练集样例: {‘text’: ‘原告李晓晓(女方)与被告陈武强(男方)于2007年10月26日办理结婚登记手续,于2009年2月24日生育一女陈安安,被告对原告态度野蛮,其父母唆使被告对原告冷淡,双方无法共同生活,双方曾协商离婚,其所达成离婚协议,被告不履行,现双方夫妻感情彻底破裂,为此,原告于2009年11月4日起诉要求离婚,婚生女跟随原告共同生活,被告按离婚协议承担子女抚养费每月500元,婚前嫁妆归原告所有,另合理分割夫妻共同财产,被告遗弃虐待家庭成员存在过错,应向原告支付50000元赔偿,现原告无固定住所,无经济来源,需被告给付20000元经济帮助金,诉讼费由被告承担’, ‘labels’: [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0]}
测试集样例: {‘text’: ‘原告黄楠(女方)、被告李志斌(男方)经人介绍于1997年12月18日在邵阳县民政局登记结婚,婚后被告经常对原告进行殴打和人身诋毁,夜不归宿,对家庭不负责任,被告既吸毒又有外遇,与被告有不正当男女关系的女人罗静经常打电话骚扰原告,原、被告2015年7月因争吵打架,分居至今,原、被告婚后生育有一女李文文,生日1998年7月14日,一儿李磊,生日2003年9月4日,综上所述,原、被告的感情已完全破裂,再无和好的可能,故于2015年9月8日诉至湖南省邵阳市大祥区人民法院,请求:1、判决原、被告离婚;2、依法平均分割夫妻存续期间的共同财产;3、判决婚生小孩李文文(女儿)李磊(儿子)由原告直接抚养,被告每月支付抚养费2000元’, ‘labels’: [1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0]}
3.2 从 paddle.io.Dataset创建数据集
from paddle.io import Dataset
from paddlenlp.datasets import MapDataset
from utils import id_label, data_prepare, data_split
import json, os
datapath = 'data/data143875'
class LawDataset(Dataset):
def __init__(self, path, *modes):
self.id2label, self.label2id = id_label(datapath)
self.data = []
for mode in modes:
data = data_prepare(datapath, self.label2id, mode=mode)
self.data.extend(data)
def __getitem__(self, idx):
return self.data[idx]
def __len__(self):
return len(self.data)
dataset = LawDataset(datapath, 'train', 'dev')
train_dataset, test_dataset = data_split(dataset) # 划分数据集
print(train_dataset[0])
print(test_dataset[0])
train_dataset = MapDataset(train_dataset)
test_dataset = MapDataset(test_dataset)
{'text': '原告陈红(女方)、被告陆磊(男方)于2004年正月经人介绍相识,2004年12月9日办理结婚登记,2005年5月10日婚生女孩陆小莉,由于原、被告脾气性格差异太大,特别是被告心胸狭隘又性情暴躁,稍不如意便要打骂原告,以致原、被告婚后感情一直不和,2011年农历12月,原、被告协议离婚之事发生口角,被告不由分说便将原告毒打一顿,原告父亲见状,便批评指责被告,被告恼羞成怒,又将原告的父亲毒打一顿,此后原、被告因夫妻感情不和而持续分居至今,期间,原告于2012年5月向法院起诉离婚未果,现夫妻感情已完全破裂,现于2015年6月16日再次起诉请求湖南省隆回县人民法院判令原、被告离婚;婚生小孩由原告抚养成年,由被告负担抚养费4万元', 'labels': [1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0]}
{'text': '原告江俊与被告徐丽于1991年相识,1992年确立恋爱关系,1994年12月31日登记结婚,1995年5月2日生育一女江敏。婚后初期双方感情尚可,后因原告长期在外地工作,双方聚少离多,沟通不够,逐渐产生矛盾。原告于2013年6月4日向本院起诉离婚。2013年8月5日,本院判决驳回原告的诉讼请求。之后原告搬至其父母处居住。2014年3月4日,原告再次向本院起诉要求离婚,经调解,双方未达成一致协议。', 'labels': [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0]}
3.3 将文本数据处理成模型可以接受的 feature
from paddlenlp.transformers import LinearDecayWithWarmup
from paddlenlp.metrics import ChunkEvaluator
from paddlenlp.datasets import load_dataset
from paddlenlp.transformers import AutoModelForSequenceClassification, AutoTokenizer
from paddlenlp.data import Stack, Tuple, Pad, Dict
from functools import partial
model_name = "ernie-3.0-base-zh"
tokenizer = AutoTokenizer.from_pretrained(model_name)
[2022-06-24 10:35:35,644] [ INFO] - We are using <class 'paddlenlp.transformers.ernie.tokenizer.ErnieTokenizer'> to load 'ernie-3.0-base-zh'.
[2022-06-24 10:35:35,646] [ INFO] - Downloading https://bj.bcebos.com/paddlenlp/models/transformers/ernie_3.0/ernie_3.0_base_zh_vocab.txt and saved to /home/aistudio/.paddlenlp/models/ernie-3.0-base-zh
[2022-06-24 10:35:35,647] [ INFO] - Downloading ernie_3.0_base_zh_vocab.txt from https://bj.bcebos.com/paddlenlp/models/transformers/ernie_3.0/ernie_3.0_base_zh_vocab.txt
100%|██████████| 182k/182k [00:00<00:00, 22.2MB/s]
def convert_example(example, tokenizer, max_seq_length=640, is_test=False): # 数据预处理函数
# tokenizer.encode方法能够完成切分token,映射token ID以及拼接特殊token
tokenized_example = tokenizer.encode(text=example['text'], max_seq_len=max_seq_length, truncation=True)
if not is_test:
tokenized_example['labels'] = example['labels'] # 加上labels用于训练
else:
tokenized_example['ids'] = example['id']
return tokenized_example
trans_func = partial( # 给convert_example传入参数
convert_example,
tokenizer=tokenizer,
max_seq_length=640,
is_test=False
)
train_dataset = train_dataset.map(trans_func)
test_dataset = test_dataset.map(trans_func)
print(train_dataset[0])
print(test_dataset[0])
{'input_ids': [1, 250, 612, 830, 536, 78, 291, 58, 77, 6, 171, 612, 891, 3548, 78, 654, 58, 77, 37, 1540, 17, 243, 136, 60, 8, 769, 965, 156, 474, 4, 1540, 17, 768, 136, 701, 139, 315, 38, 215, 1059, 883, 374, 4, 1464, 17, 317, 136, 530, 139, 1059, 21, 291, 751, 891, 96, 2113, 4, 190, 37, 250, 6, 171, 612, 2178, 266, 92, 318, 859, 712, 512, 19, 4, 169, 348, 10, 171, 612, 111, 1553, 2035, 3977, 311, 92, 182, 1386, 3004, 4, 1632, 16, 142, 221, 518, 41, 445, 2636, 250, 612, 4, 22, 600, 250, 6, 171, 612, 1059, 49, 345, 182, 7, 339, 16, 14, 4, 1490, 17, 194, 382, 768, 136, 4, 250, 6, 171, 612, 443, 454, 417, 1059, 46, 104, 34, 21, 270, 591, 4, 171, 612, 16, 190, 59, 178, 518, 174, 250, 612, 886, 445, 7, 1410, 4, 250, 612, 795, 601, 373, 498, 4, 518, 650, 480, 288, 319, 171, 612, 4, 171, 612, 2532, 2836, 33, 1758, 4, 311, 174, 250, 612, 5, 795, 601, 886, 445, 7, 1410, 4, 198, 49, 250, 6, 171, 612, 196, 752, 1478, 345, 182, 16, 14, 83, 303, 607, 59, 529, 268, 508, 4, 195, 143, 4, 250, 612, 37, 1541, 17, 317, 136, 253, 72, 245, 200, 1005, 417, 1059, 556, 228, 4, 87, 752, 1478, 345, 182, 265, 328, 62, 727, 1142, 4, 87, 37, 2016, 17, 515, 136, 1227, 139, 486, 218, 200, 1005, 647, 323, 677, 219, 244, 1310, 381, 308, 8, 119, 72, 245, 1079, 708, 250, 6, 171, 612, 417, 1059, 12048, 1059, 21, 96, 751, 190, 250, 612, 1815, 423, 33, 17, 4, 190, 171, 612, 383, 675, 1815, 423, 453, 397, 211, 183, 2], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'labels': [1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0]}
{'input_ids': [1, 250, 612, 409, 1897, 54, 171, 612, 1416, 856, 37, 2304, 17, 156, 474, 4, 2117, 17, 524, 202, 1419, 329, 129, 135, 4, 2129, 17, 768, 136, 1962, 139, 883, 374, 215, 1059, 4, 2029, 17, 317, 136, 249, 139, 21, 227, 7, 291, 409, 1443, 12043, 1059, 49, 590, 195, 653, 58, 345, 182, 854, 48, 4, 49, 196, 250, 612, 84, 195, 11, 137, 31, 35, 25, 4, 653, 58, 968, 332, 417, 65, 4, 1195, 124, 16, 824, 4, 913, 956, 66, 21, 1894, 1707, 12043, 250, 612, 37, 1648, 17, 515, 136, 397, 139, 253, 89, 245, 200, 1005, 417, 1059, 12043, 1648, 17, 585, 136, 317, 139, 4, 89, 245, 1079, 448, 3059, 381, 250, 612, 5, 1005, 2065, 647, 323, 12043, 46, 49, 250, 612, 2136, 268, 63, 795, 746, 239, 529, 520, 12043, 1838, 17, 284, 136, 397, 139, 4, 250, 612, 486, 218, 253, 89, 245, 200, 1005, 41, 323, 417, 1059, 4, 60, 290, 273, 4, 653, 58, 556, 302, 33, 7, 600, 443, 454, 12043, 2], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'labels': [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0]}
3.4 组成batch
# collate_fn函数构造,将不同长度序列充到批中数据的最大长度,再将数据堆叠
collate_fn = lambda samples, fn=Dict({
'input_ids': Pad(axis=0, pad_val=tokenizer.pad_token_id),
'token_type_ids': Pad(axis=0, pad_val=tokenizer.pad_token_type_id),
'labels': Stack(dtype="float32")
}): fn(samples)
from paddle.io import DataLoader, BatchSampler
train_batch_sampler = BatchSampler(train_dataset, batch_size=16, shuffle=True)
train_data_loader = DataLoader(dataset=train_dataset, batch_sampler=train_batch_sampler, collate_fn=collate_fn)
test_batch_sampler = BatchSampler(test_dataset, batch_size=16, shuffle=False)
test_data_loader = DataLoader(dataset=test_dataset, batch_sampler=test_batch_sampler, collate_fn=collate_fn)
3.5 定义模型网络和损失函数
from metric import MultiLabelReport
import paddle
max_steps = -1
epochs = 10
learning_rate = 2e-5
warmup_steps = 1000
id2label, label2id = id_label(datapath)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_classes=len(id2label))
num_training_steps = max_steps if max_steps > 0 else len(train_data_loader) * epochs
lr_scheduler = LinearDecayWithWarmup(learning_rate, num_training_steps, warmup_steps)
# 梯度裁剪
clip = paddle.nn.ClipGradByNorm(clip_norm=1.0)
# 生成执行权重衰减所需的参数名称,所有偏差和分层参数被被排除
decay_params = [
p.name for n, p in model.named_parameters()
if not any(nd in n for nd in ["bias", "norm"])
]
optimizer = paddle.optimizer.AdamW(
learning_rate=lr_scheduler,
epsilon=1e-8,
parameters=model.parameters(),
grad_clip=clip,
weight_decay=0.0,
apply_decay_param_fun=lambda x: x in decay_params)
criterion = paddle.nn.BCEWithLogitsLoss()
metric = MultiLabelReport()
[2022-06-24 10:35:35,775] [ INFO] - We are using <class 'paddlenlp.transformers.ernie.modeling.ErnieForSequenceClassification'> to load 'ernie-3.0-base-zh'.
[2022-06-24 10:35:35,777] [ INFO] - Downloading https://bj.bcebos.com/paddlenlp/models/transformers/ernie_3.0/ernie_3.0_base_zh.pdparams and saved to /home/aistudio/.paddlenlp/models/ernie-3.0-base-zh
[2022-06-24 10:35:35,780] [ INFO] - Downloading ernie_3.0_base_zh.pdparams from https://bj.bcebos.com/paddlenlp/models/transformers/ernie_3.0/ernie_3.0_base_zh.pdparams
100%|██████████| 452M/452M [00:06<00:00, 74.9MB/s]
W0624 10:35:42.158780 1265 gpu_context.cc:278] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 10.1
W0624 10:35:42.163237 1265 gpu_context.cc:306] device: 0, cuDNN Version: 7.6.
3.6 开始训练
import time
from eval import evaluate
import paddle.nn.functional as F
ckpt_dir = "ernie_ckpt" # 训练过程中保存模型参数的文件夹
def train(epochs, save_dir=ckpt_dir):
model.train()
best_f1_score = 0
for epoch in range(1, epochs + 1):
global_step = 0 # 迭代次数
for step, batch in enumerate(train_data_loader, start=1):
tic_train = time.time()
length = len(train_data_loader)
input_ids, token_type_ids, labels = batch
# 计算模型输出、损失函数值、分类概率值、准确率、f1分数
logits = model(input_ids, token_type_ids)
loss = criterion(logits, labels)
probs = F.sigmoid(logits)
metric.update(probs, labels)
auc, f1_score, _, _ = metric.accumulate()
# 每迭代40次或batch训练完毕,打印损失函数值、准确率、f1分数、计算速度
global_step += 1
if global_step % 40 == 0 or global_step == length:
print(
"epoch: %d, batch: %d, loss: %.5f, auc: %.5f, f1 score: %.5f, time: %.2f s"
% (epoch, step, loss, auc, f1_score, (time.time() - tic_train))) # 每个batch用时
tic_train = time.time()
# 梯度回传,更新参数
loss.backward()
optimizer.step()
lr_scheduler.step()
optimizer.clear_grad()
# 每个epoch保存一次最佳模型参数
if not os.path.exists(save_dir):
os.makedirs(save_dir)
eval_f1_score = evaluate(model, criterion, metric, test_data_loader, id2label, if_return_results=False)
if eval_f1_score > best_f1_score:
best_f1_score = eval_f1_score
model.save_pretrained(save_dir)
tokenizer.save_pretrained(save_dir)
train(epochs=epochs, save_dir=ckpt_dir)
epoch: 1, batch: 40, loss: 0.66966, auc: 0.55707, f1 score: 0.53233, time: 0.55 s
epoch: 1, batch: 73, loss: 0.61689, auc: 0.60806, f1 score: 0.53390, time: 0.52 s
eval loss: 0.61039, auc: 0.78400, f1 score: 0.65799, precison: 0.64456, recall: 0.67199
[2022-06-24 10:36:50,808] [ INFO] - tokenizer config file saved in ernie_ckpt/tokenizer_config.json
[2022-06-24 10:36:50,810] [ INFO] - Special tokens file saved in ernie_ckpt/special_tokens_map.json
epoch: 2, batch: 40, loss: 0.51279, auc: 0.82381, f1 score: 0.68527, time: 0.61 s
epoch: 2, batch: 73, loss: 0.43841, auc: 0.84239, f1 score: 0.71003, time: 0.76 s
eval loss: 0.42994, auc: 0.89576, f1 score: 0.77901, precison: 0.78145, recall: 0.77660
[2022-06-24 10:38:04,407] [ INFO] - tokenizer config file saved in ernie_ckpt/tokenizer_config.json
[2022-06-24 10:38:04,409] [ INFO] - Special tokens file saved in ernie_ckpt/special_tokens_map.json
epoch: 3, batch: 40, loss: 0.49173, auc: 0.87730, f1 score: 0.76658, time: 0.71 s
epoch: 3, batch: 73, loss: 0.44703, auc: 0.88049, f1 score: 0.77201, time: 0.81 s
eval loss: 0.37946, auc: 0.90521, f1 score: 0.79335, precison: 0.76374, recall: 0.82535
[2022-06-24 10:39:19,422] [ INFO] - tokenizer config file saved in ernie_ckpt/tokenizer_config.json
[2022-06-24 10:39:19,425] [ INFO] - Special tokens file saved in ernie_ckpt/special_tokens_map.json
epoch: 4, batch: 40, loss: 0.40879, auc: 0.89755, f1 score: 0.78716, time: 0.66 s
epoch: 4, batch: 73, loss: 0.45828, auc: 0.89706, f1 score: 0.78756, time: 0.78 s
eval loss: 0.35623, auc: 0.91176, f1 score: 0.79962, precison: 0.85628, recall: 0.75000
[2022-06-24 10:40:32,164] [ INFO] - tokenizer config file saved in ernie_ckpt/tokenizer_config.json
[2022-06-24 10:40:32,166] [ INFO] - Special tokens file saved in ernie_ckpt/special_tokens_map.json
epoch: 5, batch: 40, loss: 0.44394, auc: 0.90714, f1 score: 0.80619, time: 0.70 s
epoch: 5, batch: 73, loss: 0.34573, auc: 0.90582, f1 score: 0.80204, time: 0.71 s
eval loss: 0.33909, auc: 0.91683, f1 score: 0.81329, precison: 0.86125, recall: 0.77039
[2022-06-24 10:41:46,671] [ INFO] - tokenizer config file saved in ernie_ckpt/tokenizer_config.json
[2022-06-24 10:41:46,674] [ INFO] - Special tokens file saved in ernie_ckpt/special_tokens_map.json
epoch: 6, batch: 40, loss: 0.39112, auc: 0.91125, f1 score: 0.80203, time: 0.60 s
epoch: 6, batch: 73, loss: 0.32801, auc: 0.91551, f1 score: 0.80923, time: 0.74 s
eval loss: 0.31262, auc: 0.92841, f1 score: 0.82943, precison: 0.92100, recall: 0.75443
[2022-06-24 10:43:00,391] [ INFO] - tokenizer config file saved in ernie_ckpt/tokenizer_config.json
[2022-06-24 10:43:00,393] [ INFO] - Special tokens file saved in ernie_ckpt/special_tokens_map.json
epoch: 7, batch: 40, loss: 0.30550, auc: 0.92791, f1 score: 0.81883, time: 0.62 s
epoch: 7, batch: 73, loss: 0.25464, auc: 0.93432, f1 score: 0.83203, time: 0.81 s
eval loss: 0.27542, auc: 0.94808, f1 score: 0.85993, precison: 0.95657, recall: 0.78103
[2022-06-24 10:44:18,270] [ INFO] - tokenizer config file saved in ernie_ckpt/tokenizer_config.json
[2022-06-24 10:44:18,272] [ INFO] - Special tokens file saved in ernie_ckpt/special_tokens_map.json
epoch: 8, batch: 40, loss: 0.27216, auc: 0.94950, f1 score: 0.86048, time: 0.71 s
epoch: 8, batch: 73, loss: 0.31507, auc: 0.95240, f1 score: 0.86804, time: 0.79 s
eval loss: 0.24641, auc: 0.96001, f1 score: 0.88509, precison: 0.92300, recall: 0.85018
[2022-06-24 10:45:37,531] [ INFO] - tokenizer config file saved in ernie_ckpt/tokenizer_config.json
[2022-06-24 10:45:37,534] [ INFO] - Special tokens file saved in ernie_ckpt/special_tokens_map.json
epoch: 9, batch: 40, loss: 0.21307, auc: 0.96411, f1 score: 0.89286, time: 0.69 s
epoch: 9, batch: 73, loss: 0.16695, auc: 0.96628, f1 score: 0.89663, time: 0.76 s
eval loss: 0.21330, auc: 0.97392, f1 score: 0.90639, precison: 0.92048, recall: 0.89273
[2022-06-24 10:46:56,266] [ INFO] - tokenizer config file saved in ernie_ckpt/tokenizer_config.json
[2022-06-24 10:46:56,269] [ INFO] - Special tokens file saved in ernie_ckpt/special_tokens_map.json
epoch: 10, batch: 40, loss: 0.20534, auc: 0.97830, f1 score: 0.91651, time: 0.77 s
epoch: 10, batch: 73, loss: 0.18349, auc: 0.97840, f1 score: 0.91527, time: 0.79 s
eval loss: 0.20442, auc: 0.97672, f1 score: 0.91613, precison: 0.94278, recall: 0.89096
[2022-06-24 10:48:12,789] [ INFO] - tokenizer config file saved in ernie_ckpt/tokenizer_config.json
[2022-06-24 10:48:12,791] [ INFO] - Special tokens file saved in ernie_ckpt/special_tokens_map.json
此时模型在验证集上的最佳F1值表现为:
eval loss: 0.20442, auc: 0.97672, f1 score: 0.91613, precison: 0.94278, recall: 0.89096
# 释放显存分配器中空闲的显存
paddle.device.cuda.empty_cache()
3.7 对抗训练
class FGM(): # 对抗训练
def __init__(self, model):
self.model = model
self.backup = {}
def attack(self, epsilon=1., emb_name='embedding'):
# emb_name这个参数要换成模型中embedding的参数名
for name, param in self.model.named_parameters():
if not param.stop_gradient and emb_name in name:
self.backup[name] = param.clone().numpy()
norm = paddle.norm(param.grad) # 默认为2范数
if norm != 0:
r_at = epsilon * param.grad / norm
param.stop_gradient = True
param.add_(r_at)
param.stop_gradient = False
def restore(self, emb_name='embedding'):
# emb_name这个参数要换成模型中embedding的参数名
for name, param in self.model.named_parameters():
if param.stop_gradient and emb_name in name:
assert name in self.backup
param = self.backup[name]
self.backup = {}
ckpt_dir = "fgm_ernie_ckpt"
def FGM_train(epochs, save_dir=ckpt_dir):
best_f1_score = 0
model.train()
fgm = FGM(model)
for epoch in range(1, epochs + 1):
global_step = 0 # 迭代次数
for step, batch in enumerate(train_data_loader, start=1):
tic_train = time.time()
length = len(train_data_loader)
input_ids, token_type_ids, labels = batch
# 计算模型输出、损失函数值、分类概率值、准确率、f1分数
logits = model(input_ids, token_type_ids)
loss = criterion(logits, labels)
probs = F.sigmoid(logits)
metric.update(probs, labels)
auc, f1_score, _, _ = metric.accumulate()
# 每迭代40次或batch训练完毕,打印损失函数值、准确率、f1分数、计算速度
global_step += 1
if global_step % 40 == 0 or global_step == length:
print(
"epoch: %d, batch: %d, loss: %.5f, auc: %.5f, f1 score: %.5f, time: %.2f s"
% (epoch, step, loss, auc, f1_score, (time.time() - tic_train)))
tic_train = time.time()
loss.backward()
# torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
# 对抗训练
fgm.attack() # embedding被修改了
logits = model(input_ids, token_type_ids)
loss = criterion(logits, labels)
loss.backward() # 反向传播,在正常的grad基础上,累加对抗训练的梯度
fgm.restore() # 恢复Embedding的参数
# 参数更新
optimizer.step()
lr_scheduler.step()
optimizer.clear_grad()
# 每个epoch保存一次最佳模型参数
if not os.path.exists(save_dir):
os.makedirs(save_dir)
eval_f1_score = evaluate(model, criterion, metric, test_data_loader, id2label, if_return_results=False)
if eval_f1_score > best_f1_score:
best_f1_score = eval_f1_score
model.save_pretrained(save_dir)
tokenizer.save_pretrained(save_dir)
FGM_train(epochs=epochs, save_dir=ckpt_dir)
epoch: 1, batch: 40, loss: 0.37059, auc: 0.87029, f1 score: 0.77379, time: 0.69 s
epoch: 1, batch: 73, loss: 0.34276, auc: 0.87343, f1 score: 0.77299, time: 0.92 s
eval loss: 0.37995, auc: 0.88987, f1 score: 0.77613, precison: 0.76284, recall: 0.78989
[2022-06-24 10:50:05,850] [ INFO] - tokenizer config file saved in fgm_ernie_ckpt/tokenizer_config.json
[2022-06-24 10:50:05,852] [ INFO] - Special tokens file saved in fgm_ernie_ckpt/special_tokens_map.json
epoch: 2, batch: 40, loss: 0.29749, auc: 0.87332, f1 score: 0.77661, time: 0.60 s
epoch: 2, batch: 73, loss: 0.37608, auc: 0.87443, f1 score: 0.77543, time: 0.80 s
eval loss: 0.37686, auc: 0.88516, f1 score: 0.77613, precison: 0.76284, recall: 0.78989
epoch: 3, batch: 40, loss: 0.28839, auc: 0.87727, f1 score: 0.77771, time: 0.61 s
epoch: 3, batch: 73, loss: 0.35874, auc: 0.87587, f1 score: 0.77572, time: 0.73 s
eval loss: 0.37785, auc: 0.88685, f1 score: 0.77613, precison: 0.76284, recall: 0.78989
epoch: 4, batch: 40, loss: 0.48132, auc: 0.87239, f1 score: 0.77316, time: 0.71 s
epoch: 4, batch: 73, loss: 0.39473, auc: 0.87462, f1 score: 0.77547, time: 0.73 s
eval loss: 0.37898, auc: 0.88575, f1 score: 0.77613, precison: 0.76284, recall: 0.78989
epoch: 5, batch: 40, loss: 0.49463, auc: 0.87467, f1 score: 0.77549, time: 0.56 s
epoch: 5, batch: 73, loss: 0.38273, auc: 0.87670, f1 score: 0.77555, time: 0.79 s
eval loss: 0.38142, auc: 0.88562, f1 score: 0.78142, precison: 0.78004, recall: 0.78280
[2022-06-24 10:57:18,421] [ INFO] - tokenizer config file saved in fgm_ernie_ckpt/tokenizer_config.json
[2022-06-24 10:57:18,424] [ INFO] - Special tokens file saved in fgm_ernie_ckpt/special_tokens_map.json
epoch: 6, batch: 40, loss: 0.40169, auc: 0.88226, f1 score: 0.78193, time: 0.61 s
epoch: 6, batch: 73, loss: 0.37306, auc: 0.87720, f1 score: 0.77583, time: 0.71 s
eval loss: 0.37932, auc: 0.88370, f1 score: 0.77613, precison: 0.76284, recall: 0.78989
epoch: 7, batch: 40, loss: 0.35651, auc: 0.87507, f1 score: 0.77076, time: 0.67 s
epoch: 7, batch: 73, loss: 0.29711, auc: 0.87667, f1 score: 0.77560, time: 0.70 s
eval loss: 0.37934, auc: 0.88511, f1 score: 0.77613, precison: 0.76284, recall: 0.78989
epoch: 8, batch: 40, loss: 0.27640, auc: 0.88613, f1 score: 0.78743, time: 0.60 s
epoch: 8, batch: 73, loss: 0.46059, auc: 0.87625, f1 score: 0.77547, time: 0.71 s
eval loss: 0.37983, auc: 0.88450, f1 score: 0.77613, precison: 0.76284, recall: 0.78989
epoch: 9, batch: 40, loss: 0.45253, auc: 0.87141, f1 score: 0.76842, time: 0.60 s
epoch: 9, batch: 73, loss: 0.33335, auc: 0.87772, f1 score: 0.77547, time: 0.78 s
eval loss: 0.37906, auc: 0.88502, f1 score: 0.77613, precison: 0.76284, recall: 0.78989
epoch: 10, batch: 40, loss: 0.38910, auc: 0.87572, f1 score: 0.77567, time: 0.66 s
epoch: 10, batch: 73, loss: 0.56978, auc: 0.87773, f1 score: 0.77547, time: 0.79 s
eval loss: 0.37916, auc: 0.88262, f1 score: 0.77613, precison: 0.76284, recall: 0.78989
经过对抗训练后,模型在验证集上的最佳F1值表现为:
eval loss: 0.38142, auc: 0.88562, f1 score: 0.78142, precison: 0.78004, recall: 0.78280
import paddle
paddle.device.cuda.empty_cache()
3.8 提交结果
# 加载已经训练好的模型
model.set_dict(paddle.load('ernie_ckpt/model_state.pdparams'))
# 加载测试集
test_ds0 = LawDataset(datapath, 'test')
test_ds = MapDataset(test_ds0)
test_trans_func = partial(
convert_example,
tokenizer=tokenizer,
max_seq_length=640,
is_test=True
)
test_ds = test_ds.map(test_trans_func)
collate_fn = lambda samples, fn=Dict({
'input_ids': Pad(axis=0, pad_val=tokenizer.pad_token_id),
'token_type_ids': Pad(axis=0, pad_val=tokenizer.pad_token_type_id),
'ids': Stack(dtype="int32")
}): fn(samples)
test_ds_batch_sampler = BatchSampler(test_ds, batch_size=16, shuffle=False)
test_ds_data_loader = DataLoader(dataset=test_ds, batch_sampler=test_ds_batch_sampler, collate_fn=collate_fn)
import paddle.nn.functional as F
def data_reprocess():
# 生成预测结果
ids_ = []
y_prob = None
for _, (input_ids, token_type_ids, ids) in enumerate(test_ds_data_loader,start=1):
model.eval()
logits = model(input_ids, token_type_ids)
probs = F.sigmoid(logits)
if y_prob is not None:
y_prob = np.append(y_prob, probs.numpy(), axis=0)
else:
y_prob = probs.numpy()
ids_.extend(ids)
best_threshold = 0.32
# 参照https://aistudio.baidu.com/aistudio/projectdetail/4201483,使用了0.32作为阈值
# 可以对训练好的模型遍历阈值来找到最佳阈值
y_prob = y_prob > best_threshold
results = []
pos = 0
for event in test_ds0:
assert event['id'] == ids_[pos].item() # 确保是同一条信息
event['event_chain'] = []
for i in range(len(id2label)):
if y_prob[pos][i] == True:
event['event_chain'].append(id2label[i])
pos+=1
results.append(event)
return results
results = data_reprocess()
print(results[:5])
[{'id': 1460, 'text': '原告刘小美(女方)与被告吴京京(男方)于2005年经人介绍相识恋爱,2006年开始同居生活,2007年2月25日办理结婚登记手续,2009年9月16日生育男孩吴勇,原、被告结婚后,常为家庭琐事发生纠纷,2009年,原告刘小美经医院检查确定患精神分裂症,需长期服药治疗,现夫妻感情已经彻底破裂,遂提起离婚诉讼,要求与被告离婚', 'event_chain': ['Be_Born', 'Cohabit', 'Marry', 'Prosecute']}, {'id': 1461, 'text': '2011年12月,谢楠(女方)与被告王雄(男方)介绍相识恋爱,2012年5月10日登记结婚,2012年9月4日生育男孩王应龙;因婚前缺乏了解,未建立起牢固的夫妻感情,双方常为家庭琐事发生争吵,且被告对她和小孩缺少关心和爱护,亦不尽相应的义务,导致夫妻感情不和;2014年4月8日,双方再次发生争吵后,被告赶她和小孩回家,现双方夫妻感情已经破裂向法院起诉,要求与被告离婚,并要求抚养婚生男孩A,由被告支付抚养费', 'event_chain': ['Be_Born', 'Marry', 'Prosecute']}, {'id': 1462, 'text': '原告李文萱与丈夫李铁刚于2008年下半年相识恋爱,2010年5月9日生育一女,取名李小丽,2014年1月21日登记结婚。原、被告婚后因性格不合,常为家庭琐事吵闹,夫妻感情不和。小孩李小丽现随原告李文萱共同生活。原告认为夫妻感情已经彻底破裂,故诉至法院,要求离婚', 'event_chain': ['Be_Born', 'Marry', 'Prosecute']}, {'id': 1463, 'text': '田小妞(女方)与被告汪志诚(男方)于2001年6月14日经人介绍相识,2001年11月19日登记结婚,2002年12月30日生育女儿汪洋,2007年8月8日生育女儿汪果。双方因婚前了解不够,未建立牢固的夫妻感情,婚后不久就常为生活琐事发生争吵。被告对家庭、对小孩不尽义务,将打工所赚的钱用于打牌赌博、负债累累。双方夫妻感情已彻底破裂,且无和好可能,故她向法院起诉,要求离婚、依法分割夫妻共同财产并确定小孩汪洋、汪果的抚养权。', 'event_chain': ['Be_Born', 'Gamble', 'Marry', 'Prosecute']}, {'id': 1464, 'text': '孙建军与妻子田甜于1999年在温州打工时相识、恋爱,2001年办理结婚登记手续,2003年8月1日生育男孩孙翔,婚后,被告田甜性格暴躁,常因小事不如意就摔打家电,曾多次殴打他及小孩,2013年国庆节期间,被告出轨,被告的行为严重伤害了他,双方无法继续共同生活,现夫妻感情已经破裂,故他上述至湖南省桃江县人民法院,要求与原告离婚,并要求抚养小孩,依法分摊共同债务', 'event_chain': ['Be_Born', 'Domestic_Violence', 'Marry', 'Prosecute']}]
# 加载对抗训练得到的模型
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_classes=len(id2label))
model.set_dict(paddle.load('fgm_ernie_ckpt/model_state.pdparams'))
results1 = data_reprocess()
print(results1[:5])
[2022-06-24 11:06:11,572] [ INFO] - We are using <class 'paddlenlp.transformers.ernie.modeling.ErnieForSequenceClassification'> to load 'ernie-3.0-base-zh'.
[2022-06-24 11:06:11,574] [ INFO] - Already cached /home/aistudio/.paddlenlp/models/ernie-3.0-base-zh/ernie_3.0_base_zh.pdparams
[{'id': 1460, 'text': '原告刘小美(女方)与被告吴京京(男方)于2005年经人介绍相识恋爱,2006年开始同居生活,2007年2月25日办理结婚登记手续,2009年9月16日生育男孩吴勇,原、被告结婚后,常为家庭琐事发生纠纷,2009年,原告刘小美经医院检查确定患精神分裂症,需长期服药治疗,现夫妻感情已经彻底破裂,遂提起离婚诉讼,要求与被告离婚', 'event_chain': ['Be_Born', 'Marry', 'Prosecute', 'Separation']}, {'id': 1461, 'text': '2011年12月,谢楠(女方)与被告王雄(男方)介绍相识恋爱,2012年5月10日登记结婚,2012年9月4日生育男孩王应龙;因婚前缺乏了解,未建立起牢固的夫妻感情,双方常为家庭琐事发生争吵,且被告对她和小孩缺少关心和爱护,亦不尽相应的义务,导致夫妻感情不和;2014年4月8日,双方再次发生争吵后,被告赶她和小孩回家,现双方夫妻感情已经破裂向法院起诉,要求与被告离婚,并要求抚养婚生男孩A,由被告支付抚养费', 'event_chain': ['Be_Born', 'Marry', 'Prosecute', 'Separation']}, {'id': 1462, 'text': '原告李文萱与丈夫李铁刚于2008年下半年相识恋爱,2010年5月9日生育一女,取名李小丽,2014年1月21日登记结婚。原、被告婚后因性格不合,常为家庭琐事吵闹,夫妻感情不和。小孩李小丽现随原告李文萱共同生活。原告认为夫妻感情已经彻底破裂,故诉至法院,要求离婚', 'event_chain': ['Be_Born', 'Marry', 'Prosecute', 'Separation']}, {'id': 1463, 'text': '田小妞(女方)与被告汪志诚(男方)于2001年6月14日经人介绍相识,2001年11月19日登记结婚,2002年12月30日生育女儿汪洋,2007年8月8日生育女儿汪果。双方因婚前了解不够,未建立牢固的夫妻感情,婚后不久就常为生活琐事发生争吵。被告对家庭、对小孩不尽义务,将打工所赚的钱用于打牌赌博、负债累累。双方夫妻感情已彻底破裂,且无和好可能,故她向法院起诉,要求离婚、依法分割夫妻共同财产并确定小孩汪洋、汪果的抚养权。', 'event_chain': ['Be_Born', 'Marry', 'Prosecute', 'Separation']}, {'id': 1464, 'text': '孙建军与妻子田甜于1999年在温州打工时相识、恋爱,2001年办理结婚登记手续,2003年8月1日生育男孩孙翔,婚后,被告田甜性格暴躁,常因小事不如意就摔打家电,曾多次殴打他及小孩,2013年国庆节期间,被告出轨,被告的行为严重伤害了他,双方无法继续共同生活,现夫妻感情已经破裂,故他上述至湖南省桃江县人民法院,要求与原告离婚,并要求抚养小孩,依法分摊共同债务', 'event_chain': ['Be_Born', 'Marry', 'Prosecute', 'Separation']}]
# 保存结果文件
with open('submit.json','w') as f:
json.dump(results1,f)
小结
对比直接使用ernie3.0的基线,这里训练出来的两个模型在验证集上的效果相对较差,也不确定是否更合测试集的口味,但也算是有益的尝试。
相关数据集很干净,仍有完成其他想法的空间,如与使用trigger进行事件分类相比,这种多标签分类的稳健程度;event_chain是一个事件链,考虑事件发生的先后顺序和文本的相似度会不会有助于任务,甚至说能不能捣鼓个小型的事理图谱等。
原项目链接:https://aistudio.baidu.com/aistudio/projectdetail/4252788
更多推荐
所有评论(0)