法律领域篇章级多事件检测

本项目针对法律案件中存在触发词不明显或者不包含触发词的事件,试图建立稳健的事件检测模型,用于判断法律案件中所包含的各个事件对应的事件类型,进而对后续的事件元素抽取任务提供支持。

# 在V100 32G及以上环境中运行
!/opt/conda/envs/python35-paddle120-env/bin/python -m pip install --upgrade pip --user
!pip install pyzmq==18.1.1
!pip install --upgrade paddlenlp

1 准备数据

数据格式如下:

{

"id": 1, 

"text": "赵四与妻子王五通过相亲认识,2011年登记结婚,婚后共生育三个孩子,后双方因感情不和,于2020年协议离婚,协议约定,离婚后,三个孩子在一年内跟随王五生活,赵四每月每个孩子支付2000元抚养费,2021年三个孩子向法院提起诉讼,要求赵四按照协议约定支付抚养费。", 

"classname": "婚姻家庭纠纷", 

"eventchain": [
        {"trigger": "结婚", "eventtype": "Marry", "argument": [{"husband": "赵四", "wife": "王五", "time": "2011年", "loc": ""}]}, 
    
        {"trigger": "生育", "eventtype": "BeBorn", "argument": [{"per": "三个孩子", "time": "婚后", "loc": ""}]}, 
    
        {"trigger": "离婚", "eventtype": "Other", "argument": [{"subjec": "赵四", "object": "王五", "context": "双方感情不和", "time": "2020年", "loc": ""}]}, 
    
        {"trigger": "诉讼", "eventtype": "Prosecute", "argument": [{"prosecutor": "孩子", "defendant": "赵四", "reason": "", "demand": "按照协议约定支付抚养费", "time": "2021年", "court": "法院"}]}], 
"caseresult": ["双方达成调解,被告同意按照协议约定支付抚养费。"
]

}

由于实际上目标任务是一个多分类任务,因此只保留了数据中id,text,event_type属性(其中event_typelabel代替),分别生成训练集、验证集和测试集。

from utils import id_label, data_prepare

datapath = 'data/data143875'

id2label, label2id = id_label(datapath)
train_dataset = data_prepare(datapath, label2id, mode='train')
dev_dataset = data_prepare(datapath, label2id, mode='dev')
test_dataset = data_prepare(datapath, label2id, mode='test')

print(train_dataset[0], '\n\n', id2label, '\n\n', label2id)

{'id': 1, 'text': '赵四与妻子王五通过相亲认识,2011年登记结婚,婚后共生育三个孩子,后双方因感情不和,于2020年协议离婚,协议约定,离婚后,三个孩子在一年内跟随王五生活,赵四每月每个孩子支付2000元抚养费,2021年三个孩子向法院提起诉讼,要求赵四按照协议约定支付抚养费。', 'labels': [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0]} 

 {0: 'Be_Born', 1: 'Cohabit', 2: 'Derailment', 3: 'Domestic_Violence', 4: 'Gamble', 5: 'Marry', 6: 'Other', 7: 'Prosecute', 8: 'Raise', 9: 'Separation', 10: 'Support'} 

 {'Be_Born': 0, 'Cohabit': 1, 'Derailment': 2, 'Domestic_Violence': 3, 'Gamble': 4, 'Marry': 5, 'Other': 6, 'Prosecute': 7, 'Raise': 8, 'Separation': 9, 'Support': 10}

2 数据探索

EDA部分主要参考了https://github.com/kangyishuai/NEWS-TEXT-CLASSIFICATION/blob/master/EDA.ipynb ,将单标签改成了多标签

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import scipy
from collections import Counter
from matplotlib.font_manager import FontProperties

2.1 简单查看数据

df_train = pd.DataFrame.from_dict(train_dataset)
df_train.head()
idtextlabels
01赵四与妻子王五通过相亲认识,2011年登记结婚,婚后共生育三个孩子,后双方因感情不和,于20...[1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, ...
12张先生与李女士于2010年登记结婚,2011年育有一子,但因两人婚前感情基础薄弱,婚后两人开...[1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, ...
23原告范文与妻子李云系夫妻关系,两人自2000年2月19日登记结婚。在范文与妻子李云婚姻关系存...[0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, ...
341999年,朱三与妻子倪静经人介绍认识,两人于2000年登记结婚。2014年,双方因感情不和...[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, ...
45黄云和丈夫张山在1997年开始以夫妻名义同居生活在一起。双方于1998年9月6日生育大女儿张...[1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, ...
df_train.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 971 entries, 0 to 970
Data columns (total 3 columns):
 #   Column  Non-Null Count  Dtype 
---  ------  --------------  ----- 
 0   id      971 non-null    int64 
 1   text    971 non-null    object
 2   labels  971 non-null    object
dtypes: int64(1), object(2)
memory usage: 22.9+ KB

2.2 文本长度分布

2.2.1 文本长度描述

# 训练集的文本长度描述
df_train['lenth'] = df_train['text'].apply(lambda x:len(x))
df_train['lenth'].describe()
count     971.000000
mean      282.075180
std       453.396933
min         3.000000
25%       149.000000
50%       199.000000
75%       261.000000
max      5057.000000
Name: lenth, dtype: float64
# 测试集的文本长度描述
df_test = pd.DataFrame.from_dict(test_dataset)
df_test['lenth'] = df_test['text'].apply(lambda x:len(x))
df_test['lenth'].describe()
count    485.000000
mean     190.096907
std       94.097040
min       28.000000
25%      122.000000
50%      180.000000
75%      237.000000
max      597.000000
Name: lenth, dtype: float64

2.2.2 文本长度频数分布

fig, ax = plt.subplots(1,1,figsize=(12,4))

ax = plt.hist(x=df_train['lenth'], bins=100)
ax = plt.hist(x=df_test['lenth'], bins=100)

plt.xlim([0, max(max(df_train['lenth']), max(df_test['lenth']))])
plt.xlabel("length of sample")
plt.ylabel("number of sample")
plt.legend(['train_len','test_len'])

plt.show()


在这里插入图片描述

2.2.3 文本长度频率分布

plt.figure(figsize=(12,4))
ax = sns.distplot(df_train['lenth'], bins=100)
ax = sns.distplot(df_test['lenth'], bins=100)
plt.xlim([0, max(max(df_train['lenth']), max(df_test['lenth']))])
plt.xlabel("length of sample")
plt.ylabel("prob of sample")
plt.legend(['train_len','test_len'])
<matplotlib.legend.Legend at 0x7f7835469890>


在这里插入图片描述

2.2.4 同分布验证

# https://blog.csdn.net/weixin_30230009/article/details/122872228

import scipy
scipy.stats.ks_2samp(df_train['lenth'], df_test['lenth'])
KstestResult(statistic=0.1258836145115568, pvalue=6.357808192014591e-05)

可以看到,训练集和验证集均与测试集的文本长度分布存在偏移,这个地方可以进行考虑优化。

2.3 截断位置选择

2.3.1 正态性验证

log_train_len = np.log(1+df_train['lenth'])
log_test_len = np.log(1+df_test['lenth'])
_, lognormal_ks_pvalue = scipy.stats.kstest(rvs=log_train_len, cdf='norm')
print(lognormal_ks_pvalue)
trans_data, lam = scipy.stats.boxcox(df_train['lenth']+1)
print(scipy.stats.normaltest(trans_data))
0.0
NormaltestResult(statistic=288.1385347098726, pvalue=2.7009228305021675e-63)

两个方法得到的p值都很小,说明不符合正态分布,不能直接使用3σ原则来截断

2.3.2 看图猜截断长度

plt.figure(figsize=(12,4))
ax = sns.distplot(log_train_len)
ax = sns.distplot(log_test_len)
plt.xlabel("log length of sample")
plt.ylabel("prob of log")
plt.legend(['train_len','test_len'])
<matplotlib.legend.Legend at 0x7f78351aa610>


在这里插入图片描述

看log图可以看到训练集和测试集的频率在6.5左右的位置基本为0,因此根据np.exp(6.5)≈665来选择截断长度为640。

2.4 类别分析

2.4.1 类别分布


num_label = []
for j in range(len(id2label)):
    m = 0
    for i in range(len(df_train['labels'])):
        m += df_train['labels'][i][j]
    num_label.append(m)

plt.figure()
plt.bar(x=range(len(id2label)), height=num_label)
plt.xlabel("label")
plt.ylabel("number of sample")
plt.xticks(range(len(id2label)), list(id2label.values()), rotation=45)
plt.show()


在这里插入图片描述

2.4.2 类别长度

objs = [df_train[['id', 'lenth']], pd.DataFrame(df_train['labels'].tolist())]
ans1 = pd.concat(objs, axis=1)
ans2 = pd.melt(ans1, var_name='id_label', value_name='labels', id_vars=['id', 'lenth'])
ans2 = ans2[ans2['labels']!=0.0].reset_index(drop=True).drop('labels', axis=1)

plt.figure()
ax = sns.catplot(x='id_label', y='lenth', data=ans2, kind='strip')
plt.xticks(range(len(id2label)), list(id2label.values()), rotation=45)
([<matplotlib.axis.XTick at 0x7f7832d79090>,
  <matplotlib.axis.XTick at 0x7f7832d73710>,
  <matplotlib.axis.XTick at 0x7f7832d13fd0>,
  <matplotlib.axis.XTick at 0x7f7832d1c290>,
  <matplotlib.axis.XTick at 0x7f7832d1cc10>,
  <matplotlib.axis.XTick at 0x7f7832d24090>,
  <matplotlib.axis.XTick at 0x7f7832d73410>,
  <matplotlib.axis.XTick at 0x7f7832ce3ad0>,
  <matplotlib.axis.XTick at 0x7f7832ce3c50>,
  <matplotlib.axis.XTick at 0x7f7832ce7610>,
  <matplotlib.axis.XTick at 0x7f7832ce7b50>],
 <a list of 11 Text xticklabel objects>)




<Figure size 432x288 with 0 Axes>

在这里插入图片描述

可以看到该数据类别失衡,而参考来源中发现不同类别的文本长度不同,因而使用文本长度作为特征之一来应对失衡问题,但该数据不符合这种特征,要想一些别的办法来优化。

小结

对数据进行探索得到一些不绝对的结论:

  1. 训练集和测试集的文本长度分布存在差异,训练集文本长度的最大值和最小值差距较大。
  2. 训练集文本长度分布是非正态分布,无法使用3σ原则来对文本长度进行截断,最终结合图粗略选择了640作为截断长度。
  3. 训练集中文本类别不均衡,存在数据偏移现象,且与文本长度关联不大,因此训练中尝试加入了对抗训练。
  4. 训练集和验证集在文本长度、类型分布较测试集更相似,因此训练中由于数据不多而把验证集加入到训练集中时,采取了剪枝策略。

3 事件检测分类

3.1 从本地文件创建数据集

3.1和3.2是两种封装数据集的方法,选择一个即可。

from paddlenlp.datasets import load_dataset
from utils import id_label, data_prepare, data_split

datapath = 'data/data143875'

id2label, label2id = id_label(datapath)
train_dataset = data_prepare(datapath, label2id, mode='train')
test_dataset = data_prepare(datapath, label2id, mode='dev')

dataset = train_dataset + test_dataset # 训练集和验证集共同组成新的训练集
train_dataset, test_dataset = data_split(dataset) # 划分数据集

def read(dataset):
    for data in dataset:
        text, labels = data['text'], data['labels']
        yield {'text': text, 'labels': labels}

train_dataset = load_dataset(read, dataset=train_dataset,lazy=False) # dataset是read的参数
test_dataset = load_dataset(read, dataset=test_dataset,lazy=False)

print(id2label, '\n\n', label2id, '\n\n')
print("训练集样例:", train_dataset[291])
print("测试集样例:", test_dataset[291])

{0: 'Be_Born', 1: 'Cohabit', 2: 'Derailment', 3: 'Domestic_Violence', 4: 'Gamble', 5: 'Marry', 6: 'Other', 7: 'Prosecute', 8: 'Raise', 9: 'Separation', 10: 'Support'} 

 {'Be_Born': 0, 'Cohabit': 1, 'Derailment': 2, 'Domestic_Violence': 3, 'Gamble': 4, 'Marry': 5, 'Other': 6, 'Prosecute': 7, 'Raise': 8, 'Separation': 9, 'Support': 10} 


训练集样例: {‘text’: ‘原告李晓晓(女方)与被告陈武强(男方)于2007年10月26日办理结婚登记手续,于2009年2月24日生育一女陈安安,被告对原告态度野蛮,其父母唆使被告对原告冷淡,双方无法共同生活,双方曾协商离婚,其所达成离婚协议,被告不履行,现双方夫妻感情彻底破裂,为此,原告于2009年11月4日起诉要求离婚,婚生女跟随原告共同生活,被告按离婚协议承担子女抚养费每月500元,婚前嫁妆归原告所有,另合理分割夫妻共同财产,被告遗弃虐待家庭成员存在过错,应向原告支付50000元赔偿,现原告无固定住所,无经济来源,需被告给付20000元经济帮助金,诉讼费由被告承担’, ‘labels’: [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0]}
测试集样例: {‘text’: ‘原告黄楠(女方)、被告李志斌(男方)经人介绍于1997年12月18日在邵阳县民政局登记结婚,婚后被告经常对原告进行殴打和人身诋毁,夜不归宿,对家庭不负责任,被告既吸毒又有外遇,与被告有不正当男女关系的女人罗静经常打电话骚扰原告,原、被告2015年7月因争吵打架,分居至今,原、被告婚后生育有一女李文文,生日1998年7月14日,一儿李磊,生日2003年9月4日,综上所述,原、被告的感情已完全破裂,再无和好的可能,故于2015年9月8日诉至湖南省邵阳市大祥区人民法院,请求:1、判决原、被告离婚;2、依法平均分割夫妻存续期间的共同财产;3、判决婚生小孩李文文(女儿)李磊(儿子)由原告直接抚养,被告每月支付抚养费2000元’, ‘labels’: [1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0]}

3.2 从 paddle.io.Dataset创建数据集

from paddle.io import Dataset
from paddlenlp.datasets import MapDataset
from utils import id_label, data_prepare, data_split
import json, os

datapath = 'data/data143875'

class LawDataset(Dataset):
    def __init__(self, path, *modes):
        self.id2label, self.label2id = id_label(datapath)
        self.data = []
        for mode in modes:
            data = data_prepare(datapath, self.label2id, mode=mode)
            self.data.extend(data)


    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return len(self.data)

dataset = LawDataset(datapath, 'train', 'dev')
train_dataset, test_dataset = data_split(dataset) # 划分数据集
print(train_dataset[0])
print(test_dataset[0])

train_dataset = MapDataset(train_dataset)
test_dataset = MapDataset(test_dataset)
{'text': '原告陈红(女方)、被告陆磊(男方)于2004年正月经人介绍相识,2004年12月9日办理结婚登记,2005年5月10日婚生女孩陆小莉,由于原、被告脾气性格差异太大,特别是被告心胸狭隘又性情暴躁,稍不如意便要打骂原告,以致原、被告婚后感情一直不和,2011年农历12月,原、被告协议离婚之事发生口角,被告不由分说便将原告毒打一顿,原告父亲见状,便批评指责被告,被告恼羞成怒,又将原告的父亲毒打一顿,此后原、被告因夫妻感情不和而持续分居至今,期间,原告于2012年5月向法院起诉离婚未果,现夫妻感情已完全破裂,现于2015年6月16日再次起诉请求湖南省隆回县人民法院判令原、被告离婚;婚生小孩由原告抚养成年,由被告负担抚养费4万元', 'labels': [1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0]}
{'text': '原告江俊与被告徐丽于1991年相识,1992年确立恋爱关系,1994年12月31日登记结婚,1995年5月2日生育一女江敏。婚后初期双方感情尚可,后因原告长期在外地工作,双方聚少离多,沟通不够,逐渐产生矛盾。原告于2013年6月4日向本院起诉离婚。2013年8月5日,本院判决驳回原告的诉讼请求。之后原告搬至其父母处居住。2014年3月4日,原告再次向本院起诉要求离婚,经调解,双方未达成一致协议。', 'labels': [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0]}

3.3 将文本数据处理成模型可以接受的 feature

from paddlenlp.transformers import LinearDecayWithWarmup
from paddlenlp.metrics import ChunkEvaluator
from paddlenlp.datasets import load_dataset
from paddlenlp.transformers import AutoModelForSequenceClassification, AutoTokenizer
from paddlenlp.data import Stack, Tuple, Pad, Dict
from functools import partial

model_name = "ernie-3.0-base-zh"

tokenizer = AutoTokenizer.from_pretrained(model_name)

[2022-06-24 10:35:35,644] [    INFO] - We are using <class 'paddlenlp.transformers.ernie.tokenizer.ErnieTokenizer'> to load 'ernie-3.0-base-zh'.
[2022-06-24 10:35:35,646] [    INFO] - Downloading https://bj.bcebos.com/paddlenlp/models/transformers/ernie_3.0/ernie_3.0_base_zh_vocab.txt and saved to /home/aistudio/.paddlenlp/models/ernie-3.0-base-zh
[2022-06-24 10:35:35,647] [    INFO] - Downloading ernie_3.0_base_zh_vocab.txt from https://bj.bcebos.com/paddlenlp/models/transformers/ernie_3.0/ernie_3.0_base_zh_vocab.txt
100%|██████████| 182k/182k [00:00<00:00, 22.2MB/s]
def convert_example(example, tokenizer, max_seq_length=640, is_test=False): # 数据预处理函数
    
    # tokenizer.encode方法能够完成切分token,映射token ID以及拼接特殊token
    tokenized_example = tokenizer.encode(text=example['text'], max_seq_len=max_seq_length, truncation=True)
    if not is_test:
        tokenized_example['labels'] = example['labels'] # 加上labels用于训练
    else:
        tokenized_example['ids'] = example['id']

    return tokenized_example

trans_func = partial( # 给convert_example传入参数
    convert_example,
    tokenizer=tokenizer,
    max_seq_length=640,
    is_test=False
    )

train_dataset = train_dataset.map(trans_func)
test_dataset = test_dataset.map(trans_func)

print(train_dataset[0])
print(test_dataset[0])
{'input_ids': [1, 250, 612, 830, 536, 78, 291, 58, 77, 6, 171, 612, 891, 3548, 78, 654, 58, 77, 37, 1540, 17, 243, 136, 60, 8, 769, 965, 156, 474, 4, 1540, 17, 768, 136, 701, 139, 315, 38, 215, 1059, 883, 374, 4, 1464, 17, 317, 136, 530, 139, 1059, 21, 291, 751, 891, 96, 2113, 4, 190, 37, 250, 6, 171, 612, 2178, 266, 92, 318, 859, 712, 512, 19, 4, 169, 348, 10, 171, 612, 111, 1553, 2035, 3977, 311, 92, 182, 1386, 3004, 4, 1632, 16, 142, 221, 518, 41, 445, 2636, 250, 612, 4, 22, 600, 250, 6, 171, 612, 1059, 49, 345, 182, 7, 339, 16, 14, 4, 1490, 17, 194, 382, 768, 136, 4, 250, 6, 171, 612, 443, 454, 417, 1059, 46, 104, 34, 21, 270, 591, 4, 171, 612, 16, 190, 59, 178, 518, 174, 250, 612, 886, 445, 7, 1410, 4, 250, 612, 795, 601, 373, 498, 4, 518, 650, 480, 288, 319, 171, 612, 4, 171, 612, 2532, 2836, 33, 1758, 4, 311, 174, 250, 612, 5, 795, 601, 886, 445, 7, 1410, 4, 198, 49, 250, 6, 171, 612, 196, 752, 1478, 345, 182, 16, 14, 83, 303, 607, 59, 529, 268, 508, 4, 195, 143, 4, 250, 612, 37, 1541, 17, 317, 136, 253, 72, 245, 200, 1005, 417, 1059, 556, 228, 4, 87, 752, 1478, 345, 182, 265, 328, 62, 727, 1142, 4, 87, 37, 2016, 17, 515, 136, 1227, 139, 486, 218, 200, 1005, 647, 323, 677, 219, 244, 1310, 381, 308, 8, 119, 72, 245, 1079, 708, 250, 6, 171, 612, 417, 1059, 12048, 1059, 21, 96, 751, 190, 250, 612, 1815, 423, 33, 17, 4, 190, 171, 612, 383, 675, 1815, 423, 453, 397, 211, 183, 2], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'labels': [1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0]}
{'input_ids': [1, 250, 612, 409, 1897, 54, 171, 612, 1416, 856, 37, 2304, 17, 156, 474, 4, 2117, 17, 524, 202, 1419, 329, 129, 135, 4, 2129, 17, 768, 136, 1962, 139, 883, 374, 215, 1059, 4, 2029, 17, 317, 136, 249, 139, 21, 227, 7, 291, 409, 1443, 12043, 1059, 49, 590, 195, 653, 58, 345, 182, 854, 48, 4, 49, 196, 250, 612, 84, 195, 11, 137, 31, 35, 25, 4, 653, 58, 968, 332, 417, 65, 4, 1195, 124, 16, 824, 4, 913, 956, 66, 21, 1894, 1707, 12043, 250, 612, 37, 1648, 17, 515, 136, 397, 139, 253, 89, 245, 200, 1005, 417, 1059, 12043, 1648, 17, 585, 136, 317, 139, 4, 89, 245, 1079, 448, 3059, 381, 250, 612, 5, 1005, 2065, 647, 323, 12043, 46, 49, 250, 612, 2136, 268, 63, 795, 746, 239, 529, 520, 12043, 1838, 17, 284, 136, 397, 139, 4, 250, 612, 486, 218, 253, 89, 245, 200, 1005, 41, 323, 417, 1059, 4, 60, 290, 273, 4, 653, 58, 556, 302, 33, 7, 600, 443, 454, 12043, 2], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'labels': [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0]}

3.4 组成batch

# collate_fn函数构造,将不同长度序列充到批中数据的最大长度,再将数据堆叠

collate_fn = lambda samples, fn=Dict({
    'input_ids': Pad(axis=0, pad_val=tokenizer.pad_token_id),
    'token_type_ids': Pad(axis=0, pad_val=tokenizer.pad_token_type_id),
    'labels': Stack(dtype="float32")
}): fn(samples)

from paddle.io import DataLoader, BatchSampler

train_batch_sampler = BatchSampler(train_dataset, batch_size=16, shuffle=True)
train_data_loader = DataLoader(dataset=train_dataset, batch_sampler=train_batch_sampler, collate_fn=collate_fn)

test_batch_sampler = BatchSampler(test_dataset, batch_size=16, shuffle=False)
test_data_loader = DataLoader(dataset=test_dataset, batch_sampler=test_batch_sampler, collate_fn=collate_fn)

3.5 定义模型网络和损失函数

from metric import MultiLabelReport
import paddle

max_steps = -1
epochs = 10
learning_rate = 2e-5
warmup_steps = 1000
id2label, label2id = id_label(datapath)

model = AutoModelForSequenceClassification.from_pretrained(model_name, num_classes=len(id2label))

num_training_steps = max_steps if max_steps > 0 else len(train_data_loader) * epochs

lr_scheduler = LinearDecayWithWarmup(learning_rate, num_training_steps, warmup_steps)

# 梯度裁剪
clip = paddle.nn.ClipGradByNorm(clip_norm=1.0)

# 生成执行权重衰减所需的参数名称,所有偏差和分层参数被被排除
decay_params = [
    p.name for n, p in model.named_parameters()
    if not any(nd in n for nd in ["bias", "norm"])
]

optimizer = paddle.optimizer.AdamW(
    learning_rate=lr_scheduler,
    epsilon=1e-8,
    parameters=model.parameters(),
    grad_clip=clip,
    weight_decay=0.0,
    apply_decay_param_fun=lambda x: x in decay_params)

criterion = paddle.nn.BCEWithLogitsLoss()
metric = MultiLabelReport()
[2022-06-24 10:35:35,775] [    INFO] - We are using <class 'paddlenlp.transformers.ernie.modeling.ErnieForSequenceClassification'> to load 'ernie-3.0-base-zh'.
[2022-06-24 10:35:35,777] [    INFO] - Downloading https://bj.bcebos.com/paddlenlp/models/transformers/ernie_3.0/ernie_3.0_base_zh.pdparams and saved to /home/aistudio/.paddlenlp/models/ernie-3.0-base-zh
[2022-06-24 10:35:35,780] [    INFO] - Downloading ernie_3.0_base_zh.pdparams from https://bj.bcebos.com/paddlenlp/models/transformers/ernie_3.0/ernie_3.0_base_zh.pdparams
100%|██████████| 452M/452M [00:06<00:00, 74.9MB/s] 
W0624 10:35:42.158780  1265 gpu_context.cc:278] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 10.1
W0624 10:35:42.163237  1265 gpu_context.cc:306] device: 0, cuDNN Version: 7.6.

3.6 开始训练

import time
from eval import evaluate
import paddle.nn.functional as F

ckpt_dir = "ernie_ckpt" # 训练过程中保存模型参数的文件夹
def train(epochs, save_dir=ckpt_dir):
    model.train()
    best_f1_score = 0
    for epoch in range(1, epochs + 1):
        global_step = 0 # 迭代次数
        for step, batch in enumerate(train_data_loader, start=1):
            tic_train = time.time()
            length = len(train_data_loader)
            input_ids, token_type_ids, labels = batch
            # 计算模型输出、损失函数值、分类概率值、准确率、f1分数
            logits = model(input_ids, token_type_ids)
            loss = criterion(logits, labels)
            probs = F.sigmoid(logits)
            metric.update(probs, labels)
            auc, f1_score, _, _ = metric.accumulate()

            # 每迭代40次或batch训练完毕,打印损失函数值、准确率、f1分数、计算速度
            global_step += 1
            if global_step % 40 == 0 or global_step == length:
                print(
                    "epoch: %d, batch: %d, loss: %.5f, auc: %.5f, f1 score: %.5f, time: %.2f s"
                    % (epoch, step, loss, auc, f1_score, (time.time() - tic_train))) # 每个batch用时
                tic_train = time.time()
            
            # 梯度回传,更新参数
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.clear_grad()

        # 每个epoch保存一次最佳模型参数
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        eval_f1_score = evaluate(model, criterion, metric, test_data_loader, id2label, if_return_results=False)
        if eval_f1_score > best_f1_score:
            best_f1_score = eval_f1_score
            model.save_pretrained(save_dir)
            tokenizer.save_pretrained(save_dir)

train(epochs=epochs, save_dir=ckpt_dir)
epoch: 1, batch: 40, loss: 0.66966, auc: 0.55707, f1 score: 0.53233, time: 0.55 s
epoch: 1, batch: 73, loss: 0.61689, auc: 0.60806, f1 score: 0.53390, time: 0.52 s
eval loss: 0.61039, auc: 0.78400, f1 score: 0.65799, precison: 0.64456, recall: 0.67199


[2022-06-24 10:36:50,808] [    INFO] - tokenizer config file saved in ernie_ckpt/tokenizer_config.json
[2022-06-24 10:36:50,810] [    INFO] - Special tokens file saved in ernie_ckpt/special_tokens_map.json


epoch: 2, batch: 40, loss: 0.51279, auc: 0.82381, f1 score: 0.68527, time: 0.61 s
epoch: 2, batch: 73, loss: 0.43841, auc: 0.84239, f1 score: 0.71003, time: 0.76 s
eval loss: 0.42994, auc: 0.89576, f1 score: 0.77901, precison: 0.78145, recall: 0.77660


[2022-06-24 10:38:04,407] [    INFO] - tokenizer config file saved in ernie_ckpt/tokenizer_config.json
[2022-06-24 10:38:04,409] [    INFO] - Special tokens file saved in ernie_ckpt/special_tokens_map.json


epoch: 3, batch: 40, loss: 0.49173, auc: 0.87730, f1 score: 0.76658, time: 0.71 s
epoch: 3, batch: 73, loss: 0.44703, auc: 0.88049, f1 score: 0.77201, time: 0.81 s
eval loss: 0.37946, auc: 0.90521, f1 score: 0.79335, precison: 0.76374, recall: 0.82535


[2022-06-24 10:39:19,422] [    INFO] - tokenizer config file saved in ernie_ckpt/tokenizer_config.json
[2022-06-24 10:39:19,425] [    INFO] - Special tokens file saved in ernie_ckpt/special_tokens_map.json


epoch: 4, batch: 40, loss: 0.40879, auc: 0.89755, f1 score: 0.78716, time: 0.66 s
epoch: 4, batch: 73, loss: 0.45828, auc: 0.89706, f1 score: 0.78756, time: 0.78 s
eval loss: 0.35623, auc: 0.91176, f1 score: 0.79962, precison: 0.85628, recall: 0.75000


[2022-06-24 10:40:32,164] [    INFO] - tokenizer config file saved in ernie_ckpt/tokenizer_config.json
[2022-06-24 10:40:32,166] [    INFO] - Special tokens file saved in ernie_ckpt/special_tokens_map.json


epoch: 5, batch: 40, loss: 0.44394, auc: 0.90714, f1 score: 0.80619, time: 0.70 s
epoch: 5, batch: 73, loss: 0.34573, auc: 0.90582, f1 score: 0.80204, time: 0.71 s
eval loss: 0.33909, auc: 0.91683, f1 score: 0.81329, precison: 0.86125, recall: 0.77039


[2022-06-24 10:41:46,671] [    INFO] - tokenizer config file saved in ernie_ckpt/tokenizer_config.json
[2022-06-24 10:41:46,674] [    INFO] - Special tokens file saved in ernie_ckpt/special_tokens_map.json


epoch: 6, batch: 40, loss: 0.39112, auc: 0.91125, f1 score: 0.80203, time: 0.60 s
epoch: 6, batch: 73, loss: 0.32801, auc: 0.91551, f1 score: 0.80923, time: 0.74 s
eval loss: 0.31262, auc: 0.92841, f1 score: 0.82943, precison: 0.92100, recall: 0.75443


[2022-06-24 10:43:00,391] [    INFO] - tokenizer config file saved in ernie_ckpt/tokenizer_config.json
[2022-06-24 10:43:00,393] [    INFO] - Special tokens file saved in ernie_ckpt/special_tokens_map.json


epoch: 7, batch: 40, loss: 0.30550, auc: 0.92791, f1 score: 0.81883, time: 0.62 s
epoch: 7, batch: 73, loss: 0.25464, auc: 0.93432, f1 score: 0.83203, time: 0.81 s
eval loss: 0.27542, auc: 0.94808, f1 score: 0.85993, precison: 0.95657, recall: 0.78103


[2022-06-24 10:44:18,270] [    INFO] - tokenizer config file saved in ernie_ckpt/tokenizer_config.json
[2022-06-24 10:44:18,272] [    INFO] - Special tokens file saved in ernie_ckpt/special_tokens_map.json


epoch: 8, batch: 40, loss: 0.27216, auc: 0.94950, f1 score: 0.86048, time: 0.71 s
epoch: 8, batch: 73, loss: 0.31507, auc: 0.95240, f1 score: 0.86804, time: 0.79 s
eval loss: 0.24641, auc: 0.96001, f1 score: 0.88509, precison: 0.92300, recall: 0.85018


[2022-06-24 10:45:37,531] [    INFO] - tokenizer config file saved in ernie_ckpt/tokenizer_config.json
[2022-06-24 10:45:37,534] [    INFO] - Special tokens file saved in ernie_ckpt/special_tokens_map.json


epoch: 9, batch: 40, loss: 0.21307, auc: 0.96411, f1 score: 0.89286, time: 0.69 s
epoch: 9, batch: 73, loss: 0.16695, auc: 0.96628, f1 score: 0.89663, time: 0.76 s
eval loss: 0.21330, auc: 0.97392, f1 score: 0.90639, precison: 0.92048, recall: 0.89273


[2022-06-24 10:46:56,266] [    INFO] - tokenizer config file saved in ernie_ckpt/tokenizer_config.json
[2022-06-24 10:46:56,269] [    INFO] - Special tokens file saved in ernie_ckpt/special_tokens_map.json


epoch: 10, batch: 40, loss: 0.20534, auc: 0.97830, f1 score: 0.91651, time: 0.77 s
epoch: 10, batch: 73, loss: 0.18349, auc: 0.97840, f1 score: 0.91527, time: 0.79 s
eval loss: 0.20442, auc: 0.97672, f1 score: 0.91613, precison: 0.94278, recall: 0.89096


[2022-06-24 10:48:12,789] [    INFO] - tokenizer config file saved in ernie_ckpt/tokenizer_config.json
[2022-06-24 10:48:12,791] [    INFO] - Special tokens file saved in ernie_ckpt/special_tokens_map.json

此时模型在验证集上的最佳F1值表现为:

eval loss: 0.20442, auc: 0.97672, f1 score: 0.91613, precison: 0.94278, recall: 0.89096

# 释放显存分配器中空闲的显存
paddle.device.cuda.empty_cache()

3.7 对抗训练

class FGM(): # 对抗训练
    def __init__(self, model):
        self.model = model
        self.backup = {}

    def attack(self, epsilon=1., emb_name='embedding'):
        # emb_name这个参数要换成模型中embedding的参数名
        for name, param in self.model.named_parameters():
            if not param.stop_gradient and emb_name in name:
                self.backup[name] = param.clone().numpy()
                norm = paddle.norm(param.grad) # 默认为2范数
                if norm != 0:
                    r_at = epsilon * param.grad / norm
                    param.stop_gradient = True
                    param.add_(r_at)
                    param.stop_gradient = False

    def restore(self, emb_name='embedding'):
        # emb_name这个参数要换成模型中embedding的参数名
        for name, param in self.model.named_parameters():
            if param.stop_gradient and emb_name in name: 
                assert name in self.backup
                param = self.backup[name]
        self.backup = {}
ckpt_dir = "fgm_ernie_ckpt"

def FGM_train(epochs, save_dir=ckpt_dir):
    best_f1_score = 0
    model.train()
    fgm = FGM(model)
    for epoch in range(1, epochs + 1):
        global_step = 0 # 迭代次数
        for step, batch in enumerate(train_data_loader, start=1):
            tic_train = time.time()
            length = len(train_data_loader)
            input_ids, token_type_ids, labels = batch
            # 计算模型输出、损失函数值、分类概率值、准确率、f1分数
            logits = model(input_ids, token_type_ids)
            loss = criterion(logits, labels)
            probs = F.sigmoid(logits)
            metric.update(probs, labels)
            auc, f1_score, _, _ = metric.accumulate()

            # 每迭代40次或batch训练完毕,打印损失函数值、准确率、f1分数、计算速度
            global_step += 1
            if global_step % 40 == 0 or global_step == length:
                print(
                    "epoch: %d, batch: %d, loss: %.5f, auc: %.5f, f1 score: %.5f, time: %.2f s"
                    % (epoch, step, loss, auc, f1_score, (time.time() - tic_train)))
                tic_train = time.time()

            loss.backward()

            # torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            
            # 对抗训练
            fgm.attack() # embedding被修改了
            logits = model(input_ids, token_type_ids)
            loss = criterion(logits, labels)
            loss.backward() # 反向传播,在正常的grad基础上,累加对抗训练的梯度
            fgm.restore() # 恢复Embedding的参数
            
            # 参数更新
            optimizer.step()
            lr_scheduler.step()
            optimizer.clear_grad()

        # 每个epoch保存一次最佳模型参数
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        eval_f1_score = evaluate(model, criterion, metric, test_data_loader, id2label, if_return_results=False)
        if eval_f1_score > best_f1_score:
            best_f1_score = eval_f1_score
            model.save_pretrained(save_dir)
            tokenizer.save_pretrained(save_dir)

FGM_train(epochs=epochs, save_dir=ckpt_dir)
epoch: 1, batch: 40, loss: 0.37059, auc: 0.87029, f1 score: 0.77379, time: 0.69 s
epoch: 1, batch: 73, loss: 0.34276, auc: 0.87343, f1 score: 0.77299, time: 0.92 s
eval loss: 0.37995, auc: 0.88987, f1 score: 0.77613, precison: 0.76284, recall: 0.78989


[2022-06-24 10:50:05,850] [    INFO] - tokenizer config file saved in fgm_ernie_ckpt/tokenizer_config.json
[2022-06-24 10:50:05,852] [    INFO] - Special tokens file saved in fgm_ernie_ckpt/special_tokens_map.json


epoch: 2, batch: 40, loss: 0.29749, auc: 0.87332, f1 score: 0.77661, time: 0.60 s
epoch: 2, batch: 73, loss: 0.37608, auc: 0.87443, f1 score: 0.77543, time: 0.80 s
eval loss: 0.37686, auc: 0.88516, f1 score: 0.77613, precison: 0.76284, recall: 0.78989
epoch: 3, batch: 40, loss: 0.28839, auc: 0.87727, f1 score: 0.77771, time: 0.61 s
epoch: 3, batch: 73, loss: 0.35874, auc: 0.87587, f1 score: 0.77572, time: 0.73 s
eval loss: 0.37785, auc: 0.88685, f1 score: 0.77613, precison: 0.76284, recall: 0.78989
epoch: 4, batch: 40, loss: 0.48132, auc: 0.87239, f1 score: 0.77316, time: 0.71 s
epoch: 4, batch: 73, loss: 0.39473, auc: 0.87462, f1 score: 0.77547, time: 0.73 s
eval loss: 0.37898, auc: 0.88575, f1 score: 0.77613, precison: 0.76284, recall: 0.78989
epoch: 5, batch: 40, loss: 0.49463, auc: 0.87467, f1 score: 0.77549, time: 0.56 s
epoch: 5, batch: 73, loss: 0.38273, auc: 0.87670, f1 score: 0.77555, time: 0.79 s
eval loss: 0.38142, auc: 0.88562, f1 score: 0.78142, precison: 0.78004, recall: 0.78280


[2022-06-24 10:57:18,421] [    INFO] - tokenizer config file saved in fgm_ernie_ckpt/tokenizer_config.json
[2022-06-24 10:57:18,424] [    INFO] - Special tokens file saved in fgm_ernie_ckpt/special_tokens_map.json


epoch: 6, batch: 40, loss: 0.40169, auc: 0.88226, f1 score: 0.78193, time: 0.61 s
epoch: 6, batch: 73, loss: 0.37306, auc: 0.87720, f1 score: 0.77583, time: 0.71 s
eval loss: 0.37932, auc: 0.88370, f1 score: 0.77613, precison: 0.76284, recall: 0.78989
epoch: 7, batch: 40, loss: 0.35651, auc: 0.87507, f1 score: 0.77076, time: 0.67 s
epoch: 7, batch: 73, loss: 0.29711, auc: 0.87667, f1 score: 0.77560, time: 0.70 s
eval loss: 0.37934, auc: 0.88511, f1 score: 0.77613, precison: 0.76284, recall: 0.78989
epoch: 8, batch: 40, loss: 0.27640, auc: 0.88613, f1 score: 0.78743, time: 0.60 s
epoch: 8, batch: 73, loss: 0.46059, auc: 0.87625, f1 score: 0.77547, time: 0.71 s
eval loss: 0.37983, auc: 0.88450, f1 score: 0.77613, precison: 0.76284, recall: 0.78989
epoch: 9, batch: 40, loss: 0.45253, auc: 0.87141, f1 score: 0.76842, time: 0.60 s
epoch: 9, batch: 73, loss: 0.33335, auc: 0.87772, f1 score: 0.77547, time: 0.78 s
eval loss: 0.37906, auc: 0.88502, f1 score: 0.77613, precison: 0.76284, recall: 0.78989
epoch: 10, batch: 40, loss: 0.38910, auc: 0.87572, f1 score: 0.77567, time: 0.66 s
epoch: 10, batch: 73, loss: 0.56978, auc: 0.87773, f1 score: 0.77547, time: 0.79 s
eval loss: 0.37916, auc: 0.88262, f1 score: 0.77613, precison: 0.76284, recall: 0.78989

经过对抗训练后,模型在验证集上的最佳F1值表现为:

eval loss: 0.38142, auc: 0.88562, f1 score: 0.78142, precison: 0.78004, recall: 0.78280

import paddle
paddle.device.cuda.empty_cache()

3.8 提交结果

# 加载已经训练好的模型
model.set_dict(paddle.load('ernie_ckpt/model_state.pdparams'))

# 加载测试集
test_ds0 = LawDataset(datapath, 'test')
test_ds = MapDataset(test_ds0)

test_trans_func = partial(
    convert_example,
    tokenizer=tokenizer,
    max_seq_length=640,
    is_test=True
    )

test_ds = test_ds.map(test_trans_func)

collate_fn = lambda samples, fn=Dict({
    'input_ids': Pad(axis=0, pad_val=tokenizer.pad_token_id),
    'token_type_ids': Pad(axis=0, pad_val=tokenizer.pad_token_type_id),
    'ids': Stack(dtype="int32")
}): fn(samples)

test_ds_batch_sampler = BatchSampler(test_ds, batch_size=16, shuffle=False)
test_ds_data_loader = DataLoader(dataset=test_ds, batch_sampler=test_ds_batch_sampler, collate_fn=collate_fn)
import paddle.nn.functional as F

def data_reprocess():
    # 生成预测结果
    ids_ = []
    y_prob = None
    for _, (input_ids, token_type_ids, ids) in enumerate(test_ds_data_loader,start=1):
        model.eval()
        logits = model(input_ids, token_type_ids)
        probs = F.sigmoid(logits)
        if y_prob is not None:
            y_prob = np.append(y_prob, probs.numpy(), axis=0)
        else:
            y_prob = probs.numpy()
        ids_.extend(ids)

    best_threshold = 0.32 
    # 参照https://aistudio.baidu.com/aistudio/projectdetail/4201483,使用了0.32作为阈值
    # 可以对训练好的模型遍历阈值来找到最佳阈值

    y_prob = y_prob > best_threshold
    results = []
    pos = 0
    for event in test_ds0:
        assert event['id'] == ids_[pos].item() # 确保是同一条信息
        event['event_chain'] = []
        for i in range(len(id2label)):
            if y_prob[pos][i] == True:
                event['event_chain'].append(id2label[i])
                
        pos+=1
        results.append(event)
    return results

results = data_reprocess()
print(results[:5])
[{'id': 1460, 'text': '原告刘小美(女方)与被告吴京京(男方)于2005年经人介绍相识恋爱,2006年开始同居生活,2007年2月25日办理结婚登记手续,2009年9月16日生育男孩吴勇,原、被告结婚后,常为家庭琐事发生纠纷,2009年,原告刘小美经医院检查确定患精神分裂症,需长期服药治疗,现夫妻感情已经彻底破裂,遂提起离婚诉讼,要求与被告离婚', 'event_chain': ['Be_Born', 'Cohabit', 'Marry', 'Prosecute']}, {'id': 1461, 'text': '2011年12月,谢楠(女方)与被告王雄(男方)介绍相识恋爱,2012年5月10日登记结婚,2012年9月4日生育男孩王应龙;因婚前缺乏了解,未建立起牢固的夫妻感情,双方常为家庭琐事发生争吵,且被告对她和小孩缺少关心和爱护,亦不尽相应的义务,导致夫妻感情不和;2014年4月8日,双方再次发生争吵后,被告赶她和小孩回家,现双方夫妻感情已经破裂向法院起诉,要求与被告离婚,并要求抚养婚生男孩A,由被告支付抚养费', 'event_chain': ['Be_Born', 'Marry', 'Prosecute']}, {'id': 1462, 'text': '原告李文萱与丈夫李铁刚于2008年下半年相识恋爱,2010年5月9日生育一女,取名李小丽,2014年1月21日登记结婚。原、被告婚后因性格不合,常为家庭琐事吵闹,夫妻感情不和。小孩李小丽现随原告李文萱共同生活。原告认为夫妻感情已经彻底破裂,故诉至法院,要求离婚', 'event_chain': ['Be_Born', 'Marry', 'Prosecute']}, {'id': 1463, 'text': '田小妞(女方)与被告汪志诚(男方)于2001年6月14日经人介绍相识,2001年11月19日登记结婚,2002年12月30日生育女儿汪洋,2007年8月8日生育女儿汪果。双方因婚前了解不够,未建立牢固的夫妻感情,婚后不久就常为生活琐事发生争吵。被告对家庭、对小孩不尽义务,将打工所赚的钱用于打牌赌博、负债累累。双方夫妻感情已彻底破裂,且无和好可能,故她向法院起诉,要求离婚、依法分割夫妻共同财产并确定小孩汪洋、汪果的抚养权。', 'event_chain': ['Be_Born', 'Gamble', 'Marry', 'Prosecute']}, {'id': 1464, 'text': '孙建军与妻子田甜于1999年在温州打工时相识、恋爱,2001年办理结婚登记手续,2003年8月1日生育男孩孙翔,婚后,被告田甜性格暴躁,常因小事不如意就摔打家电,曾多次殴打他及小孩,2013年国庆节期间,被告出轨,被告的行为严重伤害了他,双方无法继续共同生活,现夫妻感情已经破裂,故他上述至湖南省桃江县人民法院,要求与原告离婚,并要求抚养小孩,依法分摊共同债务', 'event_chain': ['Be_Born', 'Domestic_Violence', 'Marry', 'Prosecute']}]
# 加载对抗训练得到的模型
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_classes=len(id2label))
model.set_dict(paddle.load('fgm_ernie_ckpt/model_state.pdparams'))
results1 = data_reprocess()
print(results1[:5])
[2022-06-24 11:06:11,572] [    INFO] - We are using <class 'paddlenlp.transformers.ernie.modeling.ErnieForSequenceClassification'> to load 'ernie-3.0-base-zh'.
[2022-06-24 11:06:11,574] [    INFO] - Already cached /home/aistudio/.paddlenlp/models/ernie-3.0-base-zh/ernie_3.0_base_zh.pdparams


[{'id': 1460, 'text': '原告刘小美(女方)与被告吴京京(男方)于2005年经人介绍相识恋爱,2006年开始同居生活,2007年2月25日办理结婚登记手续,2009年9月16日生育男孩吴勇,原、被告结婚后,常为家庭琐事发生纠纷,2009年,原告刘小美经医院检查确定患精神分裂症,需长期服药治疗,现夫妻感情已经彻底破裂,遂提起离婚诉讼,要求与被告离婚', 'event_chain': ['Be_Born', 'Marry', 'Prosecute', 'Separation']}, {'id': 1461, 'text': '2011年12月,谢楠(女方)与被告王雄(男方)介绍相识恋爱,2012年5月10日登记结婚,2012年9月4日生育男孩王应龙;因婚前缺乏了解,未建立起牢固的夫妻感情,双方常为家庭琐事发生争吵,且被告对她和小孩缺少关心和爱护,亦不尽相应的义务,导致夫妻感情不和;2014年4月8日,双方再次发生争吵后,被告赶她和小孩回家,现双方夫妻感情已经破裂向法院起诉,要求与被告离婚,并要求抚养婚生男孩A,由被告支付抚养费', 'event_chain': ['Be_Born', 'Marry', 'Prosecute', 'Separation']}, {'id': 1462, 'text': '原告李文萱与丈夫李铁刚于2008年下半年相识恋爱,2010年5月9日生育一女,取名李小丽,2014年1月21日登记结婚。原、被告婚后因性格不合,常为家庭琐事吵闹,夫妻感情不和。小孩李小丽现随原告李文萱共同生活。原告认为夫妻感情已经彻底破裂,故诉至法院,要求离婚', 'event_chain': ['Be_Born', 'Marry', 'Prosecute', 'Separation']}, {'id': 1463, 'text': '田小妞(女方)与被告汪志诚(男方)于2001年6月14日经人介绍相识,2001年11月19日登记结婚,2002年12月30日生育女儿汪洋,2007年8月8日生育女儿汪果。双方因婚前了解不够,未建立牢固的夫妻感情,婚后不久就常为生活琐事发生争吵。被告对家庭、对小孩不尽义务,将打工所赚的钱用于打牌赌博、负债累累。双方夫妻感情已彻底破裂,且无和好可能,故她向法院起诉,要求离婚、依法分割夫妻共同财产并确定小孩汪洋、汪果的抚养权。', 'event_chain': ['Be_Born', 'Marry', 'Prosecute', 'Separation']}, {'id': 1464, 'text': '孙建军与妻子田甜于1999年在温州打工时相识、恋爱,2001年办理结婚登记手续,2003年8月1日生育男孩孙翔,婚后,被告田甜性格暴躁,常因小事不如意就摔打家电,曾多次殴打他及小孩,2013年国庆节期间,被告出轨,被告的行为严重伤害了他,双方无法继续共同生活,现夫妻感情已经破裂,故他上述至湖南省桃江县人民法院,要求与原告离婚,并要求抚养小孩,依法分摊共同债务', 'event_chain': ['Be_Born', 'Marry', 'Prosecute', 'Separation']}]
# 保存结果文件
with open('submit.json','w') as f:
    json.dump(results1,f)

小结

对比直接使用ernie3.0的基线,这里训练出来的两个模型在验证集上的效果相对较差,也不确定是否更合测试集的口味,但也算是有益的尝试。

相关数据集很干净,仍有完成其他想法的空间,如与使用trigger进行事件分类相比,这种多标签分类的稳健程度;event_chain是一个事件链,考虑事件发生的先后顺序和文本的相似度会不会有助于任务,甚至说能不能捣鼓个小型的事理图谱等。

原项目链接:https://aistudio.baidu.com/aistudio/projectdetail/4252788

Logo

学大模型,用大模型上飞桨星河社区!每天8点V100G算力免费领!免费领取ERNIE 4.0 100w Token >>>

更多推荐