使用PaddleNLP识别垃圾邮件(三):用ELECTRA做英文邮件标题分类
本文是《使用PaddleNLP识别垃圾邮件》系列第三篇,该系列持续更新中……系列背景介绍:《使用PaddleNLP识别垃圾邮件》系列项目,针对当前企业面临的垃圾邮件问题,尝试使用深度学习的方法,探索多语言垃圾邮件的内容、标题提取与分类识别。该系列还有一个姊妹篇,《使用PaddleNLP进行恶意网页识别》,欢迎感兴趣的读者点进来交流评论。系列目录使用PaddleNLP识别垃圾邮件(一):准确率98.
本文是《使用PaddleNLP识别垃圾邮件》系列第三篇,该系列持续更新中……
系列背景介绍:《使用PaddleNLP识别垃圾邮件》系列项目,针对当前企业面临的垃圾邮件问题,尝试使用深度学习的方法,探索多语言垃圾邮件的内容、标题提取与分类识别。
该系列还有一个姊妹篇,《使用PaddleNLP进行恶意网页识别》,欢迎感兴趣的读者点进来交流评论。
系列目录
- 使用PaddleNLP识别垃圾邮件(一):准确率98.5%的垃圾邮件分类器
- 使用PaddleNLP的文本分类LSTM模型,提取中文邮件内容判断邮件是否为垃圾邮件。
- 使用PaddleNLP识别垃圾邮件(二):用BERT做中文邮件内容分类
- 使用PaddleNLP的BERT预训练模型,根据提取的中文邮件内容判断邮件是否为垃圾邮件。
- 使用PaddleNLP识别垃圾邮件(三):用ELECTRA做英文邮件标题分类
- 介绍在Python中解析eml邮件内容的办法:email模块和mmpi库;
- 使用PaddleNLP的ELECTRA预训练模型,根据提取的英文邮件标题判断邮件是否为垃圾邮件。
- 使用PaddleNLP识别垃圾邮件(四):用RoBERTa做中文邮件标题分类
- 升级到最新自定义数据集方法;
- 使用PaddleNLP模型库,大幅简化开发流程;
- 使用PaddleNLP的RoBERTa预训练模型,根据提取的中文邮件标题判断邮件是否为垃圾邮件;
- 完成完整的批量邮件分类部署流程。
背景知识:ELECTRA算法
本项目体验的是——号称在吊打BERT的同时,算力能节约一半以上的ELECTRA算法。
ELECTRA将BERT与类似于GAN的结构相结合,并辅以新的预训练任务来做预训练。结果:在更少的参数量和数据下,效果超越BERT,并且仅用1/4的算力就达到了SOTA模型RoBERTa的效果。
实现
NLP式的Generator-Discriminator
- ELECTRA最主要的贡献是提出了新的预训练任务和框架,把生成式的Masked language model(MLM)预训练任务改成了判别式的Replaced token detection(RTD)任务,判断当前token是否被语言模型替换过。
- 一点思考:随机替换一些输入中的字词,再让BERT去预测是否替换过可以吗?答案:不可以,因为随机替换过于简单,效果不好。
- 为了让替换更加真实,作者提出了利用一个基于MLM的Generator来替换example中的某些个token,然后丢给Discriminator来判别。
目标函数
由于生成器的输入输出都为句子,而句子中的字词都是离散的,因此判别器的梯度无法传给生成器,因此生成器的训练目标仍旧和BERT一样,而RTD的训练就是传统的二分类:
所以整个预训练框架的loss为:
因为判别器的任务相对来说容易些,RTD loss相对MLM loss会很小,因此加上一个系数,作者训练时使用了50
ELECTRA与GAN的区别
总结
关于ELECTRA,知乎上有个非常精辟的总结:
ELECTRA和BERT最大的不同应该是在于两个方面:
- masked(replaced) tokens的选择
- training objective
第一个,token的选择BERT是随机的,这意味着什么呢?比如句子“我想吃苹果”,BERT可以mask为“我想吃苹[MASK]”,这样一来实际上去学它就很简单,如果mask为“我[MASK]吃苹果”,那么去学这个“想”就相对困难了。
换句话说,BERT的mask可能会有很多简单的token,去学这些token就算是简单的bilstm都可以做的。这样一来,一个简单的想法就是,不随机mask,去专门选那些对模型来说学习困难的token。
怎么做呢?这就是ELECTRA非常牛逼的地方了,train一个简单的MLM,当做模型对训练难度的先验,简单的自动过滤(在这里就是sample出来的和原句子一样),复杂的后面再学。
还是举“我想吃苹果”这个例子。比如我这里还是mask为“我想吃苹[MASK]”,MLM这个生成器可以以很高的概率sample到“果”,但是对“我[MASK]想吃苹果”,MLM就很难说大概率采样到“想”了,也可能是“不”、“真”等等……
总的来说,MLM的作用就是为自动选择masked tokens提供了一种非常有效的方法!
第二个,既然MLM选择了一些token,那么该怎么去学呢?当然这个地方也可以像BERT那样,如果MLM采样的保持不变,就相当于原BERT中不mask;如果变了,就mask,然后再用BERT的方法去train。
然而ELECTRA另辟蹊径,用一个二分类去判断每个token是否已经被换过了。这就把一个DAE(或者LM)任务转换为了一个分类任务(或者序列标注)。这有两个好处:
1)每个token都能contribute to some extent
2)缓解distribution的问题
第一点是和MLM联系起来的(这也是ELECTRA精妙的地方了)。如果MLM牛逼,那么discriminate的难度就很大,从而就可以看作是hard example。
第二点是,如果我们像BERT那样去预测真正的token,也即通过一个classifier [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-1vxEU1BO-1635902021512)(https://www.zhihu.com/equation?tex=%5Ctilde%7BC%7D%5Cin%5Cmathbb%7BR%7D%5E%7Bd%5Ctimes+%7CV%7C%7D)] 的话,那么它相比二分类器 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-qCHZsbAv-1635902021520)(https://www.zhihu.com/equation?tex=C%5Cin%5Cmathbb%7BR%7D%5E%7Bd%5Ctimes+2%7D)] 而言就需要更多的计算量,而且还要suffer 由于 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-ZUNQmeES-1635902021529)(https://www.zhihu.com/equation?tex=%7CV%7C)] 较大导致的分布问题。
以上两点总结起来就是:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-aH7Bia5y-1635902021536)(https://www.zhihu.com/equation?tex=%5Cboxed%7B%5B%5Ccolor%7Bblue%7D%7B%5Ctext%7BBERT%7D%7D%5D%3A%5Ctext%7BMasked+tokens%7D%5Crightarrow%5Ctext%7BPredict+them%7D%7D+%5C%5C+%5Ctext%7BFaster%7D%5CDownarrow+%5Ctext%7BStronger%7D+%5C%5C+%5Cboxed%7B%5B%5Ccolor%7Bred%7D%7B%5Ctext%7BELECTRA%7D%7D%5D%3A+%5Ctext%7BReplaced+tokens%7D%5Crightarrow%5Ctext%7BDistinguish+them%7D%7D)]
数据集介绍
TREC 2006 Spam Track Public Corpora是一个公开的垃圾邮件语料库,由国际文本检索会议提供,分为英文数据集(trec06p)和中文数据集(trec06c),其中所含的邮件均来源于真实邮件保留了邮件的原有格式和内容。
除TREC 2006外,还有TREC 2005和TREC 2007的英文垃圾邮件数据集(对,没有中文),本项目中,仅使用TREC 2006提供的英文数据集进行演示。TREC 2005-2007的垃圾邮件数据集,均已整理在项目挂载的数据集中,感兴趣的读者可以自行fork。
文件目录形式:delay和full分别是一种垃圾邮件过滤器的过滤机制,full目录下,是理想的邮件分类结果,我们可以视为研究的标签。
trec06c
│
└───data
│ │ 000
│ │ 001
│ │ ...
│ └───215
└───delay
│ │ index
└───full
│ │ index
一、环境配置
本项目基于Paddle 2.0 编写,如果你的环境不是本版本,请先参考官网安装 Paddle 2.0 。
# 加载开源邮件快速检测工具库mmpi
!pip install mmpi
!pip install yara-python
# 导入相关的模块
import re
import jieba
import os
import random
import paddle
import paddlenlp as ppnlp
from paddlenlp.data import Stack, Pad, Tuple
import paddle.nn.functional as F
import paddle.nn as nn
from visualdl import LogWriter
import numpy as np
from functools import partial #partial()函数可以用来固定某些参数值,并返回一个新的callable对象
from tqdm import tqdm
import codecs
import chardet
import email
from mmpi import mmpi
print(paddle.__version__)
2.0.2
https://bbs.pediy.com/thread-264909.html
二、数据加载
2.1 数据集准备
# 解压数据集
!tar xvf data/data89631/trec06p.tgz
随机选择一个文件,查看邮件文件的具体内容。请注意,一定要设置errors='ignore'
,否则可能因为个别字符无法识别,导致内容无法读取。
f = open('trec06p/data/100/059', 'r', encoding='utf-8', errors='ignore')
text = ''
for line in f:
line = line.strip().strip('\n')
if len(line) > 1:
print(line)
text = text + line
我们可以看到,在邮件的正文开始前,有相当多的信息,比如收件人、发件人、邮件标题、时间等。但是在本项目使用的多语言邮件数据集中,字符编码格式也是个让人头疼的问题,比如下面这封邮件,就属于典型的垃圾邮件,标题是比较混乱的符号。
Received: from 221.201.116.189 (unknown [221.201.116.189])
by jalapeno.cc.columbia.edu (8.13.0/8.13.0) with ESMTP id j1PLKJKZ017426
for <romanianclub@send.columbia.edu>; Sun, 20 Feb 2005 23:23:04 -0500 (EST)
Received: from 60.199.250.207 by ; Sun, 20 Feb 2005 22:18:42 -0100
Message-ID: <IJDNTHVEZCAQUORIDVRANKF@yahoo.com>
From: "" <j90jil0pipojgh@naver.com>
Reply-To: "" <j90jil0pipojgh7@naver.com>
To: romanianclub@send.columbia.edu
Subject: ,, = ,, ̵ ּ. ʽϴ...
Date: Sun, 20 Feb 2005 18:21:42 -0500
X-Mailer: The Bat! (v1.52f) Business
……
由于原数据集中这些问题比较多,因此,在展开具体的邮件内容解析前这里先统一进行转码。
!mkdir trec06p/change_code
# 读取标签文件信息
f = open('trec06p/full/index', 'r')
for line in tqdm(f):
str_list = line.split(" ")
# 读取文件,获取字符集
content = codecs.open('trec06p/'+str_list[1][3:].replace("\n", ""),'rb').read()
source_encoding = chardet.detect(content)['encoding']
if not os.path.exists('trec06p/change_code' + str_list[1][7:-5].replace("\n", "")):
os.makedirs('trec06p/change_code' + str_list[1][7:-5].replace("\n", ""))
# 个别文件的source_encoding是None,这里要先进行筛选
if source_encoding is None:
pass
# 对字符集不是utf-8格式的文件尝试转码
elif source_encoding != 'utf-8':
# 转码如果失败,就跳过该文件
try:
content = content.decode(source_encoding).encode('utf-8')
codecs.open('trec06p/change_code'+str_list[1][7:].replace("\n", ""),'wb').write(content)
except UnicodeDecodeError:
print(str_list[1][7:].replace("\n", "") + "读取失败")
pass
# 字符集是utf-8格式的文件直接保存
else:
codecs.open('trec06p/change_code'+str_list[1][7:].replace("\n", ""),'wb').write(content)
2.2 解析邮件内容:Python的Email模块
首先,我们来看看Python自带的email模块,官方文档这样写到:
smtplib
和nntplib
等模块的功能。
具体用法上:
email.message
: 表示一封电子邮件信息email.parser
: 解析电子邮件信息email.generator
: 生成 MIME 文档email.policy
: 策略对象email.errors
: 异常和缺陷类email.headerregistry
: 自定义标头对象email.contentmanager
: 管理 MIME 内容
如果通过Email模块查看转码后的邮件内容,我们不仅可以轻松解析文件内容,还会发现,这封邮件的标题中非英文字符“存在感”特别明显。
fp = open("trec06p/data/100/059", "r", encoding='utf-8', errors='ignore')
msg = email.message_from_file(fp) # 直接文件创建message对象,这个时候也会做初步的解码
subject = msg.get("subject") # 取信件头里的subject, 也就是主题
# 下面的三行代码只是为了解码像=?gbk?Q?=CF=E0=C6=AC?=这样的subject
h = email.header.Header(subject)
dh = email.header.decode_header(h)
subject = dh[0][0]
print("subject:", subject)
print("from: ", email.utils.parseaddr(msg.get("from"))[1]) # 取from
print("to: ", email.utils.parseaddr(msg.get("to"))[1]) # 取to
print("content-type: ", email.utils.parseaddr(msg.get("Content-Type")))
fp.close()
我们也可以用chardet
库查看这些奇怪字符到底是哪类编码模式——这里原来是韩文。
content = codecs.open('trec06p/data/100/059','rb').read()
chardet.detect(content)['encoding']
'EUC-KR'
2.3 解析邮件内容:强大的MMPI库
mmpi,是一款使用python实现的开源邮件快速检测工具库,基于community框架设计开发。mmpi支持对邮件头、邮件正文、邮件附件的解析检测,并输出json检测报告。
mmpi,代码项目地址:https://github.com/a232319779/mmpi,pypi项目地址https://pypi.org/project/mmpi/
mmpi,邮件快速检测工具库检测逻辑:
- 支持解析提取邮件头数据,包括收件人、发件人的姓名和邮箱,邮件主题,邮件发送时间,以及邮件原始发送IP。通过检测发件人邮箱和邮件原始发送IP,实现对邮件头的检测。
- 支持对邮件正文的解析检测,提取text和html格式的邮件正文,对text邮件正文进行关键字匹配,对html邮件正文进行解析分析检测,实现探针邮件检测、钓鱼邮件检测、垃圾邮件检测等其他检测。
- 支持对邮件附件等解析检测
- ole文件格式:如doc、xls等,提取其中的vba宏代码、模板注入链接
- zip文件格式:提取压缩文件列表,统计文件名、文件格式等
- rtf文件格式:解析内嵌ole对象等
- 其他文件格式:如PE可执行文件
- 检测方式包括
- 基础信息规则检测方式
- yara规则检测方式
简而言之,MMPI库不仅能解析邮件内容,还能“分析”邮件内容,它自身已经能依据一些规则,做垃圾邮件检测了。
emp = mmpi()
emp.parse('trec06p/change_code/100/059')
report = emp.get_report()
print("mmpi解析结果: ", report)
# mmpi解析结构
for i in report:
print(i)
headers
body
attachments
signatures
# mmpi库将该邮件标记为垃圾邮件
report['signatures'][4]
{'name': 'spam_detection',
'description': 'SPAM Detection',
'severity': 3,
'marks': [{'type': 'html', 'tag': 'spam_detection'}],
'markcount': 1}
当然,MMPI库也有缺点:没有字符集时,会出现乱码。且由于MMPI库封装比较厉害,要进行改动较为困难,乱码问题很难通过二次开发解决。
# 显示该邮件标题会出现乱码
report['headers'][0]['Subject']
'recover back your lost youth'
# 此时,可以考虑使用正则表达式,检测提取的标题信息中是否含有非英文字符
regexp = re.compile(r'[^\x00-\x7f]')
if regexp.search(report['headers'][0]['Subject']):
print('matched')
matched
基于上面的处理,我们可以考虑进一步将问题简单化:对于英文邮件,尝试提取邮件标题内容进行垃圾邮件识别;而对于非英文的邮件,可以考虑直接用MMPI库自带的垃圾邮件识别功能进行分类,然后通过工程化的方法,将两个结果予以合并。接下来,项目将聚集于训练一个根据英文邮件标题识别垃圾邮件的模型。
2.4 提取英文邮件标题,划分训练集、验证集、测试集
# 检测提取的标题信息中是否含有非英文字符
regexp = re.compile(r'[^\x00-\x7f]')
# 从指定路径读取邮件文件内容信息
def get_data_in_a_file(original_path, save_path='all_email.txt'):
emp = mmpi()
emp.parse(original_path)
report = emp.get_report()
# 如果可以解析到邮件头信息
if report.get('headers') is not None:
try:
# 尝试提取邮件标题,如果成功提取到只含有英文字符的标题,则返回
if report['headers'][0]['Subject'] is not None:
if regexp.search(report['headers'][0]['Subject']):
pass
else:
return report['headers'][0]['Subject']
except TypeError:
print("读取失败:", report['headers'][0]['Subject'])
pass
注意:下面这段代码执行时间非常长,当然,也会生成个别脏数据还需要手动清洗。因此读者可以跳过下面两个cell,直接使用项目提供的处理后训练集、验证集、测试集文件。
# 读取标签文件信息
f = open('trec06p/full/index', 'r')
for line in f:
str_list = line.split(" ")
# 设置垃圾邮件的标签为0
if str_list[0] == 'spam':
label = '0'
# 设置正常邮件标签为1
elif str_list[0] == 'ham':
label = '1'
text = get_data_in_a_file('trec06p/full/' + str(str_list[1].split("\n")[0]))
if text is not None:
with open("all_email.txt","a+") as f:
f.write(text + '\t' + label + '\n')
data_list_path="./"
with open(os.path.join(data_list_path, 'eval_list.txt'), 'w', encoding='utf-8') as f_eval:
f_eval.seek(0)
f_eval.truncate()
with open(os.path.join(data_list_path, 'train_list.txt'), 'w', encoding='utf-8') as f_train:
f_train.seek(0)
f_train.truncate()
with open(os.path.join(data_list_path, 'test_list.txt'), 'w', encoding='utf-8') as f_test:
f_test.seek(0)
f_test.truncate()
with open(os.path.join(data_list_path, 'all_email.txt'), 'r', encoding='utf-8') as f_data:
lines = f_data.readlines()
i = 0
with open(os.path.join(data_list_path, 'eval_list.txt'), 'a', encoding='utf-8') as f_eval,open(os.path.join(data_list_path, 'test_list.txt'), 'a', encoding='utf-8') as f_test,open(os.path.join(data_list_path, 'train_list.txt'), 'a', encoding='utf-8') as f_train:
for line in lines:
words = line.split('\t')[-1].replace('\n', '')
label = line.split('\t')[0]
labs = ""
# 划分验证集
if i % 10 == 1:
labs = label + '\t' + words + '\n'
f_eval.write(labs)
# 划分测试集
elif i % 10 == 2:
labs = label + '\t' + words + '\n'
f_test.write(labs)
# 划分训练集
else:
labs = label + '\t' + words + '\n'
f_train.write(labs)
i += 1
print("数据列表生成完成!")
数据列表生成完成!
2.5 自定义数据集
class SelfDefinedDataset(paddle.io.Dataset):
def __init__(self, data):
super(SelfDefinedDataset, self).__init__()
self.data = data
def __getitem__(self, idx):
return self.data[idx]
def __len__(self):
return len(self.data)
def get_labels(self):
return ["0", "1"]
def txt_to_list(file_name):
res_list = []
for line in open(file_name):
res_list.append(line.strip().split('\t'))
return res_list
trainlst = txt_to_list('train_list.txt')
devlst = txt_to_list('eval_list.txt')
testlst = txt_to_list('test_list.txt')
train_ds, dev_ds, test_ds = SelfDefinedDataset.get_datasets([trainlst, devlst, testlst])
#获得标签列表
label_list = train_ds.get_labels()
2.6 训练数据分析
#打印标签
print(label_list)
#看看数据长什么样子,分别打印训练集、验证集、测试集的前3条数据。
print("训练集数据:{}\n".format(train_ds[0:3]))
print("验证集数据:{}\n".format(dev_ds[0:3]))
print("测试集数据:{}\n".format(test_ds[0:3]))
print("训练集样本个数:{}".format(len(train_ds)))
print("验证集样本个数:{}".format(len(dev_ds)))
print("测试集样本个数:{}".format(len(test_ds)))
['0', '1']
训练集数据:[['new Catholic mailing list now up and running', '1'], ['Greetings', '1'], ['LOANS @ 3.17% (27 term)', '0']]
验证集数据:[['re[12]:', '0'], ['Re: VtALtUM news', '0'], ["Re: Plan 9 beginner's questions", '1']]
测试集数据:[['Take a moment to explore this.', '0'], ['Job offer for your person', '0'], ['re[12]:', '0']]
训练集样本个数:2460
验证集样本个数:308
测试集样本个数:308
# 统计训练集正负样本数量
spam = 0
for data in train_ds:
if data[1] == '0':
spam += 1
print("训练集垃圾邮件数量:{}".format(spam))
print("训练集正常邮件数量:{}".format(len(train_ds) - spam))
训练集垃圾邮件数量:1318
训练集正常邮件数量:1142
因此,从统计结果看,训练集样本数量还是比较均衡的。
2.7 数据预处理
在PaddleNLP Transformer API中选择ELECTRA算法,只需参考之前BertTokenizer进行数据处理、加载BERT预训练模型的做法即可。
#调用ppnlp.transformers.ElectraTokenizer进行数据处理,tokenizer可以把原始输入文本转化成模型model可接受的输入数据格式。
tokenizer = ppnlp.transformers.ElectraTokenizer.from_pretrained("electra-large")
#数据预处理
def convert_example(example,tokenizer,label_list,max_seq_length=64,is_test=False):
if is_test:
text = example
else:
text, label = example
#tokenizer.encode方法能够完成切分token,映射token ID以及拼接特殊token
encoded_inputs = tokenizer.encode(text=text, max_seq_len=max_seq_length)
input_ids = encoded_inputs["input_ids"]
#注意,在早前的PaddleNLP版本中,token_type_ids叫做segment_ids
segment_ids = encoded_inputs["token_type_ids"]
if not is_test:
label_map = {}
for (i, l) in enumerate(label_list):
label_map[l] = i
label = label_map[label]
label = np.array([label], dtype="int64")
return input_ids, segment_ids, label
else:
return input_ids, segment_ids
#数据迭代器构造方法
def create_dataloader(dataset, trans_fn=None, mode='train', batch_size=1, use_gpu=False, pad_token_id=0, batchify_fn=None):
if trans_fn:
dataset = dataset.apply(trans_fn, lazy=True)
if mode == 'train' and use_gpu:
sampler = paddle.io.DistributedBatchSampler(dataset=dataset, batch_size=batch_size, shuffle=True)
else:
shuffle = True if mode == 'train' else False #如果不是训练集,则不打乱顺序
sampler = paddle.io.BatchSampler(dataset=dataset, batch_size=batch_size, shuffle=shuffle) #生成一个取样器
dataloader = paddle.io.DataLoader(dataset, batch_sampler=sampler, return_list=True, collate_fn=batchify_fn)
return dataloader
#使用partial()来固定convert_example函数的tokenizer, label_list, max_seq_length, is_test等参数值
trans_fn = partial(convert_example, tokenizer=tokenizer, label_list=label_list, max_seq_length=64, is_test=False)
batchify_fn = lambda samples, fn=Tuple(Pad(axis=0,pad_val=tokenizer.pad_token_id), Pad(axis=0, pad_val=tokenizer.pad_token_id), Stack(dtype="int64")):[data for data in fn(samples)]
#训练集迭代器
train_loader = create_dataloader(train_ds, mode='train', batch_size=64, batchify_fn=batchify_fn, trans_fn=trans_fn)
#验证集迭代器
dev_loader = create_dataloader(dev_ds, mode='dev', batch_size=64, batchify_fn=batchify_fn, trans_fn=trans_fn)
#测试集迭代器
test_loader = create_dataloader(test_ds, mode='test', batch_size=64, batchify_fn=batchify_fn, trans_fn=trans_fn)
三、模型训练
3.1 加载预训练模型
model = ppnlp.transformers.ElectraForSequenceClassification.from_pretrained("electra-large", num_classes=2)
3.2 训练与可视化
#设置训练超参数
#学习率
learning_rate = 5e-5
#训练轮次
epochs = 10
#学习率预热比率
warmup_proption = 0.1
#权重衰减系数
weight_decay = 0.01
num_training_steps = len(train_loader) * epochs
num_warmup_steps = int(warmup_proption * num_training_steps)
def get_lr_factor(current_step):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
else:
return max(0.0,
float(num_training_steps - current_step) /
float(max(1, num_training_steps - num_warmup_steps)))
#学习率调度器
lr_scheduler = paddle.optimizer.lr.LambdaDecay(learning_rate, lr_lambda=lambda current_step: get_lr_factor(current_step))
#优化器
optimizer = paddle.optimizer.AdamW(
learning_rate=lr_scheduler,
parameters=model.parameters(),
weight_decay=weight_decay,
apply_decay_param_fun=lambda x: x in [
p.name for n, p in model.named_parameters()
if not any(nd in n for nd in ["bias", "norm"])
])
#损失函数
criterion = paddle.nn.loss.CrossEntropyLoss()
#评估函数
metric = paddle.metric.Accuracy()
#评估函数,设置返回值,便于VisualDL记录
def evaluate(model, criterion, metric, data_loader):
model.eval()
metric.reset()
losses = []
for batch in data_loader:
input_ids, segment_ids, labels = batch
logits = model(input_ids, segment_ids)
loss = criterion(logits, labels)
losses.append(loss.numpy())
correct = metric.compute(logits, labels)
metric.update(correct)
accu = metric.accumulate()
print("eval loss: %.5f, accu: %.5f" % (np.mean(losses), accu))
model.train()
metric.reset()
return np.mean(losses), accu
#开始训练
global_step = 0
with LogWriter(logdir="./log") as writer:
for epoch in range(1, epochs + 1):
for step, batch in enumerate(train_loader, start=1): #从训练数据迭代器中取数据
input_ids, segment_ids, labels = batch
logits = model(input_ids, segment_ids)
loss = criterion(logits, labels) #计算损失
probs = F.softmax(logits, axis=1)
correct = metric.compute(probs, labels)
metric.update(correct)
acc = metric.accumulate()
global_step += 1
if global_step % 10 == 0 :
print("global step %d, epoch: %d, batch: %d, loss: %.5f, acc: %.5f" % (global_step, epoch, step, loss, acc))
#记录训练过程
writer.add_scalar(tag="train/loss", step=global_step, value=loss)
writer.add_scalar(tag="train/acc", step=global_step, value=acc)
loss.backward()
optimizer.step()
lr_scheduler.step()
optimizer.clear_gradients()
eval_loss, eval_acc = evaluate(model, criterion, metric, dev_loader)
#记录评估过程
writer.add_scalar(tag="eval/loss", step=epoch, value=eval_loss)
writer.add_scalar(tag="eval/acc", step=epoch, value=eval_acc)
3.3 保存模型和网络结构
# Convert to static graph with specific input description
model = paddle.jit.to_static(
model,
input_spec=[
paddle.static.InputSpec(
shape=[None, None], dtype="int64"), # input_ids
paddle.static.InputSpec(
shape=[None, None], dtype="int64") # segment_ids
])
# Save in static graph model.
paddle.jit.save(model, './static_graph_params')
四、预测效果
首先,评估一下模型在测试集上的表现,根据标题进行垃圾邮件分类的准确率达到了96%以上。看来,至少在该数据集的垃圾邮件中,标题特征还是非常明显的。
# 评估模型在测试集上的表现
evaluate(model, criterion, metric, test_loader)
eval loss: 0.10905, accu: 0.96753
(0.109047726, 0.9675324675324676)
完成上面的模型训练之后,我们得到了一个能够通过标题对英文垃圾邮件进行初步判定的模型。
实测下它的效果吧!
def predict(model, data, tokenizer, label_map, batch_size=1):
examples = []
for text in data:
input_ids, segment_ids = convert_example(text, tokenizer, label_list=label_map.values(), max_seq_length=128, is_test=True)
examples.append((input_ids, segment_ids))
batchify_fn = lambda samples, fn=Tuple(Pad(axis=0, pad_val=tokenizer.pad_token_id), Pad(axis=0, pad_val=tokenizer.pad_token_id)): fn(samples)
batches = []
one_batch = []
for example in examples:
one_batch.append(example)
if len(one_batch) == batch_size:
batches.append(one_batch)
one_batch = []
if one_batch:
batches.append(one_batch)
results = []
model.eval()
for batch in batches:
input_ids, segment_ids = batchify_fn(batch)
input_ids = paddle.to_tensor(input_ids)
segment_ids = paddle.to_tensor(segment_ids)
logits = model(input_ids, segment_ids)
probs = F.softmax(logits, axis=1)
idx = paddle.argmax(probs, axis=1).numpy()
idx = idx.tolist()
labels = [label_map[i] for i in idx]
results.extend(labels)
return results
label_map = {0: '垃圾邮件', 1: '正常邮件'}
data = ['Good effects of Ephedra', 'Watch INFX like a hawk [D E T A I L S NEW PICK friday it is] drift hypophyseal', 'Re: Off topic: right font on MSIE4']
predictions = predict(model, data, tokenizer, label_map, batch_size=64)
for idx, text in enumerate(data):
end(labels)
return results
label_map = {0: '垃圾邮件', 1: '正常邮件'}
data = ['Good effects of Ephedra', 'Watch INFX like a hawk [D E T A I L S NEW PICK friday it is] drift hypophyseal', 'Re: Off topic: right font on MSIE4']
predictions = predict(model, data, tokenizer, label_map, batch_size=64)
for idx, text in enumerate(data):
print('邮件标题: {} \n邮件标签: {}'.format(text, predictions[idx]))
邮件标题: Good effects of Ephedra
邮件标签: 垃圾邮件
邮件标题: Watch INFX like a hawk [D E T A I L S NEW PICK friday it is] drift hypophyseal
邮件标签: 垃圾邮件
邮件标题: Re: Off topic: right font on MSIE4
邮件标签: 正常邮件
参考资料
- 如何评价NLP算法ELECTRA的表现?
- 阅读笔记 – ELECTRA: PRE-TRAINING TEXT ENCODERS AS DISCRIMINATORS RATHER THAN GENERATORS
小结
1. 真实数据到训练模型的数据集,还是需要取舍的。有时候可以通过规则过滤了一些数据,从而缩小了模型的范围。但无论如何选择,根据业务场景进行判断才是最重要的。
2. 英文邮件短标题的分类效果能达到92%以上,说明这个场景有比较高的业务价值,在工程上可以作为规则(类似于字符集)后的下一道拦截关卡。
更多推荐
所有评论(0)