本文是《使用PaddleNLP识别垃圾邮件》系列第三篇,该系列持续更新中……

系列背景介绍:《使用PaddleNLP识别垃圾邮件》系列项目,针对当前企业面临的垃圾邮件问题,尝试使用深度学习的方法,探索多语言垃圾邮件的内容、标题提取与分类识别。

该系列还有一个姊妹篇,《使用PaddleNLP进行恶意网页识别》,欢迎感兴趣的读者点进来交流评论。
https://ai-studio-static-online.cdn.bcebos.com/d500f483e55544a5b929abad59de208f180c068cc81648009fab60a0b6d9bda2

系列目录

背景知识:ELECTRA算法

本项目体验的是——号称在吊打BERT的同时,算力能节约一半以上的ELECTRA算法。

ELECTRA将BERT与类似于GAN的结构相结合,并辅以新的预训练任务来做预训练。结果:在更少的参数量和数据下,效果超越BERT,并且仅用1/4的算力就达到了SOTA模型RoBERTa的效果。

https://i-blog.csdnimg.cn/blog_migrate/6b1575c76460ef7d207bed3d3b924b2b.jpeg

实现

NLP式的Generator-Discriminator

在这里插入图片描述

  • ELECTRA最主要的贡献是提出了新的预训练任务和框架,把生成式的Masked language model(MLM)预训练任务改成了判别式的Replaced token detection(RTD)任务,判断当前token是否被语言模型替换过。
  • 一点思考:随机替换一些输入中的字词,再让BERT去预测是否替换过可以吗?答案:不可以,因为随机替换过于简单,效果不好。
  • 为了让替换更加真实,作者提出了利用一个基于MLM的Generator来替换example中的某些个token,然后丢给Discriminator来判别。

目标函数

由于生成器的输入输出都为句子,而句子中的字词都是离散的,因此判别器的梯度无法传给生成器,因此生成器的训练目标仍旧和BERT一样,而RTD的训练就是传统的二分类:

在这里插入图片描述

所以整个预训练框架的loss为:

在这里插入图片描述

因为判别器的任务相对来说容易些,RTD loss相对MLM loss会很小,因此加上一个系数,作者训练时使用了50

ELECTRA与GAN的区别

在这里插入图片描述

总结

关于ELECTRA,知乎上有个非常精辟的总结:

ELECTRA和BERT最大的不同应该是在于两个方面:

  • masked(replaced) tokens的选择
  • training objective

第一个,token的选择BERT是随机的,这意味着什么呢?比如句子“我想吃苹果”,BERT可以mask为“我想吃苹[MASK]”,这样一来实际上去学它就很简单,如果mask为“我[MASK]吃苹果”,那么去学这个“想”就相对困难了。

换句话说,BERT的mask可能会有很多简单的token,去学这些token就算是简单的bilstm都可以做的。这样一来,一个简单的想法就是,不随机mask,去专门选那些对模型来说学习困难的token。

怎么做呢?这就是ELECTRA非常牛逼的地方了,train一个简单的MLM,当做模型对训练难度的先验,简单的自动过滤(在这里就是sample出来的和原句子一样),复杂的后面再学。

还是举“我想吃苹果”这个例子。比如我这里还是mask为“我想吃苹[MASK]”,MLM这个生成器可以以很高的概率sample到“果”,但是对“我[MASK]想吃苹果”,MLM就很难说大概率采样到“想”了,也可能是“不”、“真”等等……

总的来说,MLM的作用就是为自动选择masked tokens提供了一种非常有效的方法

第二个,既然MLM选择了一些token,那么该怎么去学呢?当然这个地方也可以像BERT那样,如果MLM采样的保持不变,就相当于原BERT中不mask;如果变了,就mask,然后再用BERT的方法去train。

然而ELECTRA另辟蹊径,用一个二分类去判断每个token是否已经被换过了。这就把一个DAE(或者LM)任务转换为了一个分类任务(或者序列标注)。这有两个好处:

1)每个token都能contribute to some extent

2)缓解distribution的问题

第一点是和MLM联系起来的(这也是ELECTRA精妙的地方了)。如果MLM牛逼,那么discriminate的难度就很大,从而就可以看作是hard example。

第二点是,如果我们像BERT那样去预测真正的token,也即通过一个classifier [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-1vxEU1BO-1635902021512)(https://www.zhihu.com/equation?tex=%5Ctilde%7BC%7D%5Cin%5Cmathbb%7BR%7D%5E%7Bd%5Ctimes+%7CV%7C%7D)] 的话,那么它相比二分类器 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-qCHZsbAv-1635902021520)(https://www.zhihu.com/equation?tex=C%5Cin%5Cmathbb%7BR%7D%5E%7Bd%5Ctimes+2%7D)] 而言就需要更多的计算量,而且还要suffer 由于 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-ZUNQmeES-1635902021529)(https://www.zhihu.com/equation?tex=%7CV%7C)] 较大导致的分布问题。

以上两点总结起来就是:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-aH7Bia5y-1635902021536)(https://www.zhihu.com/equation?tex=%5Cboxed%7B%5B%5Ccolor%7Bblue%7D%7B%5Ctext%7BBERT%7D%7D%5D%3A%5Ctext%7BMasked+tokens%7D%5Crightarrow%5Ctext%7BPredict+them%7D%7D+%5C%5C+%5Ctext%7BFaster%7D%5CDownarrow+%5Ctext%7BStronger%7D+%5C%5C+%5Cboxed%7B%5B%5Ccolor%7Bred%7D%7B%5Ctext%7BELECTRA%7D%7D%5D%3A+%5Ctext%7BReplaced+tokens%7D%5Crightarrow%5Ctext%7BDistinguish+them%7D%7D)]

数据集介绍

TREC 2006 Spam Track Public Corpora是一个公开的垃圾邮件语料库,由国际文本检索会议提供,分为英文数据集(trec06p)和中文数据集(trec06c),其中所含的邮件均来源于真实邮件保留了邮件的原有格式和内容。

除TREC 2006外,还有TREC 2005和TREC 2007的英文垃圾邮件数据集(对,没有中文),本项目中,仅使用TREC 2006提供的英文数据集进行演示。TREC 2005-2007的垃圾邮件数据集,均已整理在项目挂载的数据集中,感兴趣的读者可以自行fork。

文件目录形式:delay和full分别是一种垃圾邮件过滤器的过滤机制,full目录下,是理想的邮件分类结果,我们可以视为研究的标签。

trec06c
│
└───data
│   │   000
│   │   001
│   │   ...
│   └───215
└───delay
│   │   index
└───full
│   │   index  

一、环境配置

本项目基于Paddle 2.0 编写,如果你的环境不是本版本,请先参考官网安装 Paddle 2.0 。

# 加载开源邮件快速检测工具库mmpi
!pip install mmpi
!pip install yara-python
# 导入相关的模块
import re
import jieba
import os 
import random
import paddle
import paddlenlp as ppnlp
from paddlenlp.data import Stack, Pad, Tuple
import paddle.nn.functional as F
import paddle.nn as nn
from visualdl import LogWriter
import numpy as np
from functools import partial #partial()函数可以用来固定某些参数值,并返回一个新的callable对象
from tqdm import tqdm
import codecs
import chardet
import email
from mmpi import mmpi
print(paddle.__version__)
2.0.2

https://bbs.pediy.com/thread-264909.html

二、数据加载

2.1 数据集准备

# 解压数据集
!tar xvf data/data89631/trec06p.tgz

随机选择一个文件,查看邮件文件的具体内容。请注意,一定要设置errors='ignore',否则可能因为个别字符无法识别,导致内容无法读取。

f = open('trec06p/data/100/059', 'r', encoding='utf-8', errors='ignore')
text = ''
for line in f:
    line = line.strip().strip('\n')
    if len(line) > 1:
        print(line)
        text = text + line 

我们可以看到,在邮件的正文开始前,有相当多的信息,比如收件人、发件人、邮件标题、时间等。但是在本项目使用的多语言邮件数据集中,字符编码格式也是个让人头疼的问题,比如下面这封邮件,就属于典型的垃圾邮件,标题是比较混乱的符号。

Received: from 221.201.116.189 (unknown [221.201.116.189])
by jalapeno.cc.columbia.edu (8.13.0/8.13.0) with ESMTP id j1PLKJKZ017426
for <romanianclub@send.columbia.edu>; Sun, 20 Feb 2005 23:23:04 -0500 (EST)
Received: from 60.199.250.207 by ; Sun, 20 Feb 2005 22:18:42 -0100
Message-ID: <IJDNTHVEZCAQUORIDVRANKF@yahoo.com>
From: "" <j90jil0pipojgh@naver.com>
Reply-To: "" <j90jil0pipojgh7@naver.com>
To: romanianclub@send.columbia.edu
Subject:  ,, = ,, ̵  ּ.  ʽϴ...
Date: Sun, 20 Feb 2005 18:21:42 -0500
X-Mailer: The Bat! (v1.52f) Business
……

由于原数据集中这些问题比较多,因此,在展开具体的邮件内容解析前这里先统一进行转码。

!mkdir trec06p/change_code
# 读取标签文件信息
f = open('trec06p/full/index', 'r')
for line in tqdm(f):
    str_list = line.split(" ")
    # 读取文件,获取字符集
    content = codecs.open('trec06p/'+str_list[1][3:].replace("\n", ""),'rb').read()
    source_encoding = chardet.detect(content)['encoding']
    if not os.path.exists('trec06p/change_code' + str_list[1][7:-5].replace("\n", "")):
        os.makedirs('trec06p/change_code' + str_list[1][7:-5].replace("\n", ""))
    # 个别文件的source_encoding是None,这里要先进行筛选
    if source_encoding is None:
        pass
    # 对字符集不是utf-8格式的文件尝试转码
    elif source_encoding != 'utf-8':
        # 转码如果失败,就跳过该文件
        try:
            content = content.decode(source_encoding).encode('utf-8')
            codecs.open('trec06p/change_code'+str_list[1][7:].replace("\n", ""),'wb').write(content)
        except UnicodeDecodeError:
            print(str_list[1][7:].replace("\n", "") + "读取失败")
            pass
    # 字符集是utf-8格式的文件直接保存
    else:
        codecs.open('trec06p/change_code'+str_list[1][7:].replace("\n", ""),'wb').write(content)

2.2 解析邮件内容:Python的Email模块

首先,我们来看看Python自带的email模块,官方文档这样写到:

email 包是一个用于管理电子邮件消息的库。 它 并非 被设计为执行向 SMTP (RFC 2821), NNTP 或其他服务器发送电子邮件消息的操作;这些是 smtplibnntplib 等模块的功能。 email 包试图尽可能遵循 RFC,支持 RFC 5233RFC 6532,以及与 MIME 相关的各个 RFC 例如 RFC 2045, RFC 2046, RFC 2047, RFC 2183RFC 2231

具体用法上:

如果通过Email模块查看转码后的邮件内容,我们不仅可以轻松解析文件内容,还会发现,这封邮件的标题中非英文字符“存在感”特别明显。

fp = open("trec06p/data/100/059", "r", encoding='utf-8', errors='ignore')
msg = email.message_from_file(fp) # 直接文件创建message对象,这个时候也会做初步的解码
subject = msg.get("subject") # 取信件头里的subject, 也就是主题
# 下面的三行代码只是为了解码像=?gbk?Q?=CF=E0=C6=AC?=这样的subject
h = email.header.Header(subject)
dh = email.header.decode_header(h)
subject = dh[0][0]
print("subject:", subject)
print("from: ", email.utils.parseaddr(msg.get("from"))[1]) # 取from
print("to: ", email.utils.parseaddr(msg.get("to"))[1]) # 取to
print("content-type: ", email.utils.parseaddr(msg.get("Content-Type")))
fp.close()

我们也可以用chardet库查看这些奇怪字符到底是哪类编码模式——这里原来是韩文。

content = codecs.open('trec06p/data/100/059','rb').read()
chardet.detect(content)['encoding']
'EUC-KR'

2.3 解析邮件内容:强大的MMPI库

mmpi,是一款使用python实现的开源邮件快速检测工具库,基于community框架设计开发。mmpi支持对邮件头、邮件正文、邮件附件的解析检测,并输出json检测报告。

mmpi,代码项目地址:https://github.com/a232319779/mmpi,pypi项目地址https://pypi.org/project/mmpi/

mmpi,邮件快速检测工具库检测逻辑:

  • 支持解析提取邮件头数据,包括收件人、发件人的姓名和邮箱,邮件主题,邮件发送时间,以及邮件原始发送IP。通过检测发件人邮箱和邮件原始发送IP,实现对邮件头的检测。
  • 支持对邮件正文的解析检测,提取text和html格式的邮件正文,对text邮件正文进行关键字匹配,对html邮件正文进行解析分析检测,实现探针邮件检测、钓鱼邮件检测、垃圾邮件检测等其他检测。
  • 支持对邮件附件等解析检测
    • ole文件格式:如doc、xls等,提取其中的vba宏代码、模板注入链接
    • zip文件格式:提取压缩文件列表,统计文件名、文件格式等
    • rtf文件格式:解析内嵌ole对象等
    • 其他文件格式:如PE可执行文件
  • 检测方式包括
    • 基础信息规则检测方式
    • yara规则检测方式

简而言之,MMPI库不仅能解析邮件内容,还能“分析”邮件内容,它自身已经能依据一些规则,做垃圾邮件检测了。

emp = mmpi()
emp.parse('trec06p/change_code/100/059')
report = emp.get_report()
print("mmpi解析结果: ", report)
# mmpi解析结构
for i in report:
    print(i)
headers
body
attachments
signatures
# mmpi库将该邮件标记为垃圾邮件
report['signatures'][4]
{'name': 'spam_detection',
 'description': 'SPAM Detection',
 'severity': 3,
 'marks': [{'type': 'html', 'tag': 'spam_detection'}],
 'markcount': 1}

当然,MMPI库也有缺点:没有字符集时,会出现乱码。且由于MMPI库封装比较厉害,要进行改动较为困难,乱码问题很难通过二次开发解决。

# 显示该邮件标题会出现乱码
report['headers'][0]['Subject']
'recover back your lost youth'
# 此时,可以考虑使用正则表达式,检测提取的标题信息中是否含有非英文字符
regexp = re.compile(r'[^\x00-\x7f]')
if regexp.search(report['headers'][0]['Subject']):
  print('matched')
matched

基于上面的处理,我们可以考虑进一步将问题简单化:对于英文邮件,尝试提取邮件标题内容进行垃圾邮件识别;而对于非英文的邮件,可以考虑直接用MMPI库自带的垃圾邮件识别功能进行分类,然后通过工程化的方法,将两个结果予以合并。接下来,项目将聚集于训练一个根据英文邮件标题识别垃圾邮件的模型。

2.4 提取英文邮件标题,划分训练集、验证集、测试集

# 检测提取的标题信息中是否含有非英文字符
regexp = re.compile(r'[^\x00-\x7f]')
# 从指定路径读取邮件文件内容信息
def get_data_in_a_file(original_path, save_path='all_email.txt'):
    emp = mmpi()
    emp.parse(original_path)
    report = emp.get_report()
    # 如果可以解析到邮件头信息
    if report.get('headers') is not None:
        try:
            # 尝试提取邮件标题,如果成功提取到只含有英文字符的标题,则返回
            if report['headers'][0]['Subject'] is not None:
                if regexp.search(report['headers'][0]['Subject']):
                    pass
                else:
                    return report['headers'][0]['Subject']
        except TypeError:
            print("读取失败:", report['headers'][0]['Subject'])
            pass

注意:下面这段代码执行时间非常长,当然,也会生成个别脏数据还需要手动清洗。因此读者可以跳过下面两个cell,直接使用项目提供的处理后训练集、验证集、测试集文件。

# 读取标签文件信息
f = open('trec06p/full/index', 'r')
for line in f:
    str_list = line.split(" ")
    # 设置垃圾邮件的标签为0
    if str_list[0] == 'spam':
        label = '0'
    # 设置正常邮件标签为1
    elif str_list[0] == 'ham':
        label = '1'
    text = get_data_in_a_file('trec06p/full/' + str(str_list[1].split("\n")[0]))
    if text is not None:
        with open("all_email.txt","a+") as f:
                        f.write(text + '\t' + label + '\n')
data_list_path="./"

with open(os.path.join(data_list_path, 'eval_list.txt'), 'w', encoding='utf-8') as f_eval:
    f_eval.seek(0)
    f_eval.truncate()
    
with open(os.path.join(data_list_path, 'train_list.txt'), 'w', encoding='utf-8') as f_train:
    f_train.seek(0)
    f_train.truncate() 

with open(os.path.join(data_list_path, 'test_list.txt'), 'w', encoding='utf-8') as f_test:
    f_test.seek(0)
    f_test.truncate()

with open(os.path.join(data_list_path, 'all_email.txt'), 'r', encoding='utf-8') as f_data:
    lines = f_data.readlines()

i = 0
with open(os.path.join(data_list_path, 'eval_list.txt'), 'a', encoding='utf-8') as f_eval,open(os.path.join(data_list_path, 'test_list.txt'), 'a', encoding='utf-8') as f_test,open(os.path.join(data_list_path, 'train_list.txt'), 'a', encoding='utf-8') as f_train:
    for line in lines:
        words = line.split('\t')[-1].replace('\n', '')
        label = line.split('\t')[0]
        labs = ""
        # 划分验证集
        if i % 10 == 1:
            labs = label + '\t' + words + '\n'
            f_eval.write(labs)
        # 划分测试集
        elif i % 10 == 2:
            labs = label + '\t' + words + '\n'
            f_test.write(labs)
        # 划分训练集
        else:
            labs = label + '\t' + words + '\n'
            f_train.write(labs)
        i += 1
    
print("数据列表生成完成!")
数据列表生成完成!

2.5 自定义数据集

class SelfDefinedDataset(paddle.io.Dataset):
    def __init__(self, data):
        super(SelfDefinedDataset, self).__init__()
        self.data = data

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return len(self.data)
        
    def get_labels(self):
        return ["0", "1"]

def txt_to_list(file_name):
    res_list = []
    for line in open(file_name):
        res_list.append(line.strip().split('\t'))
    return res_list

trainlst = txt_to_list('train_list.txt')
devlst = txt_to_list('eval_list.txt')
testlst = txt_to_list('test_list.txt')

train_ds, dev_ds, test_ds = SelfDefinedDataset.get_datasets([trainlst, devlst, testlst])
#获得标签列表
label_list = train_ds.get_labels()

2.6 训练数据分析

#打印标签
print(label_list)

#看看数据长什么样子,分别打印训练集、验证集、测试集的前3条数据。
print("训练集数据:{}\n".format(train_ds[0:3]))
print("验证集数据:{}\n".format(dev_ds[0:3]))
print("测试集数据:{}\n".format(test_ds[0:3]))

print("训练集样本个数:{}".format(len(train_ds)))
print("验证集样本个数:{}".format(len(dev_ds)))
print("测试集样本个数:{}".format(len(test_ds)))
['0', '1']
训练集数据:[['new Catholic mailing list now up and running', '1'], ['Greetings', '1'], ['LOANS @ 3.17% (27 term)', '0']]

验证集数据:[['re[12]:', '0'], ['Re: VtALtUM news', '0'], ["Re: Plan 9 beginner's questions", '1']]

测试集数据:[['Take a moment to explore this.', '0'], ['Job offer for your person', '0'], ['re[12]:', '0']]

训练集样本个数:2460
验证集样本个数:308
测试集样本个数:308
# 统计训练集正负样本数量
spam = 0
for data in train_ds:
    if data[1] == '0':
        spam += 1
print("训练集垃圾邮件数量:{}".format(spam))
print("训练集正常邮件数量:{}".format(len(train_ds) - spam))
训练集垃圾邮件数量:1318
训练集正常邮件数量:1142

因此,从统计结果看,训练集样本数量还是比较均衡的。

2.7 数据预处理

PaddleNLP Transformer API中选择ELECTRA算法,只需参考之前BertTokenizer进行数据处理、加载BERT预训练模型的做法即可。

#调用ppnlp.transformers.ElectraTokenizer进行数据处理,tokenizer可以把原始输入文本转化成模型model可接受的输入数据格式。
tokenizer = ppnlp.transformers.ElectraTokenizer.from_pretrained("electra-large")

#数据预处理
def convert_example(example,tokenizer,label_list,max_seq_length=64,is_test=False):
    if is_test:
        text = example
    else:
        text, label = example
    #tokenizer.encode方法能够完成切分token,映射token ID以及拼接特殊token
    encoded_inputs = tokenizer.encode(text=text, max_seq_len=max_seq_length)
    input_ids = encoded_inputs["input_ids"]
    #注意,在早前的PaddleNLP版本中,token_type_ids叫做segment_ids
    segment_ids = encoded_inputs["token_type_ids"]

    if not is_test:
        label_map = {}
        for (i, l) in enumerate(label_list):
            label_map[l] = i

        label = label_map[label]
        label = np.array([label], dtype="int64")
        return input_ids, segment_ids, label
    else:
        return input_ids, segment_ids

#数据迭代器构造方法
def create_dataloader(dataset, trans_fn=None, mode='train', batch_size=1, use_gpu=False, pad_token_id=0, batchify_fn=None):
    if trans_fn:
        dataset = dataset.apply(trans_fn, lazy=True)

    if mode == 'train' and use_gpu:
        sampler = paddle.io.DistributedBatchSampler(dataset=dataset, batch_size=batch_size, shuffle=True)
    else:
        shuffle = True if mode == 'train' else False #如果不是训练集,则不打乱顺序
        sampler = paddle.io.BatchSampler(dataset=dataset, batch_size=batch_size, shuffle=shuffle) #生成一个取样器
    dataloader = paddle.io.DataLoader(dataset, batch_sampler=sampler, return_list=True, collate_fn=batchify_fn)
    return dataloader

#使用partial()来固定convert_example函数的tokenizer, label_list, max_seq_length, is_test等参数值
trans_fn = partial(convert_example, tokenizer=tokenizer, label_list=label_list, max_seq_length=64, is_test=False)
batchify_fn = lambda samples, fn=Tuple(Pad(axis=0,pad_val=tokenizer.pad_token_id), Pad(axis=0, pad_val=tokenizer.pad_token_id), Stack(dtype="int64")):[data for data in fn(samples)]
#训练集迭代器
train_loader = create_dataloader(train_ds, mode='train', batch_size=64, batchify_fn=batchify_fn, trans_fn=trans_fn)
#验证集迭代器
dev_loader = create_dataloader(dev_ds, mode='dev', batch_size=64, batchify_fn=batchify_fn, trans_fn=trans_fn)
#测试集迭代器
test_loader = create_dataloader(test_ds, mode='test', batch_size=64, batchify_fn=batchify_fn, trans_fn=trans_fn)

三、模型训练

3.1 加载预训练模型

model = ppnlp.transformers.ElectraForSequenceClassification.from_pretrained("electra-large", num_classes=2)

3.2 训练与可视化

#设置训练超参数

#学习率
learning_rate = 5e-5 
#训练轮次
epochs = 10
#学习率预热比率
warmup_proption = 0.1
#权重衰减系数
weight_decay = 0.01

num_training_steps = len(train_loader) * epochs
num_warmup_steps = int(warmup_proption * num_training_steps)

def get_lr_factor(current_step):
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))
    else:
        return max(0.0,
                    float(num_training_steps - current_step) /
                    float(max(1, num_training_steps - num_warmup_steps)))
#学习率调度器
lr_scheduler = paddle.optimizer.lr.LambdaDecay(learning_rate, lr_lambda=lambda current_step: get_lr_factor(current_step))

#优化器
optimizer = paddle.optimizer.AdamW(
    learning_rate=lr_scheduler,
    parameters=model.parameters(),
    weight_decay=weight_decay,
    apply_decay_param_fun=lambda x: x in [
        p.name for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    ])

#损失函数
criterion = paddle.nn.loss.CrossEntropyLoss()
#评估函数
metric = paddle.metric.Accuracy()
#评估函数,设置返回值,便于VisualDL记录
def evaluate(model, criterion, metric, data_loader):
    model.eval()
    metric.reset()
    losses = []
    for batch in data_loader:
        input_ids, segment_ids, labels = batch
        logits = model(input_ids, segment_ids)
        loss = criterion(logits, labels)
        losses.append(loss.numpy())
        correct = metric.compute(logits, labels)
        metric.update(correct)
        accu = metric.accumulate()
    print("eval loss: %.5f, accu: %.5f" % (np.mean(losses), accu))
    model.train()
    metric.reset()
    return np.mean(losses), accu
#开始训练
global_step = 0
with LogWriter(logdir="./log") as writer:
    for epoch in range(1, epochs + 1):    
        for step, batch in enumerate(train_loader, start=1): #从训练数据迭代器中取数据
            input_ids, segment_ids, labels = batch
            logits = model(input_ids, segment_ids)
            loss = criterion(logits, labels) #计算损失
            probs = F.softmax(logits, axis=1)
            correct = metric.compute(probs, labels)
            metric.update(correct)
            acc = metric.accumulate()

            global_step += 1
            if global_step % 10 == 0 :
                print("global step %d, epoch: %d, batch: %d, loss: %.5f, acc: %.5f" % (global_step, epoch, step, loss, acc))
                #记录训练过程
                writer.add_scalar(tag="train/loss", step=global_step, value=loss)
                writer.add_scalar(tag="train/acc", step=global_step, value=acc)
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.clear_gradients()
        eval_loss, eval_acc = evaluate(model, criterion, metric, dev_loader)
        #记录评估过程
        writer.add_scalar(tag="eval/loss", step=epoch, value=eval_loss)
        writer.add_scalar(tag="eval/acc", step=epoch, value=eval_acc)

3.3 保存模型和网络结构

# Convert to static graph with specific input description
model = paddle.jit.to_static(
    model,
    input_spec=[
        paddle.static.InputSpec(
            shape=[None, None], dtype="int64"),  # input_ids
        paddle.static.InputSpec(
            shape=[None, None], dtype="int64")  # segment_ids
    ])
# Save in static graph model.
paddle.jit.save(model, './static_graph_params')

四、预测效果

首先,评估一下模型在测试集上的表现,根据标题进行垃圾邮件分类的准确率达到了96%以上。看来,至少在该数据集的垃圾邮件中,标题特征还是非常明显的。

# 评估模型在测试集上的表现
evaluate(model, criterion, metric, test_loader)
eval loss: 0.10905, accu: 0.96753





(0.109047726, 0.9675324675324676)

完成上面的模型训练之后,我们得到了一个能够通过标题对英文垃圾邮件进行初步判定的模型。

实测下它的效果吧!https://ai-studio-static-online.cdn.bcebos.com/d500f483e55544a5b929abad59de208f180c068cc81648009fab60a0b6d9bda2

def predict(model, data, tokenizer, label_map, batch_size=1):
    examples = []
    for text in data:
        input_ids, segment_ids = convert_example(text, tokenizer, label_list=label_map.values(),  max_seq_length=128, is_test=True)
        examples.append((input_ids, segment_ids))

    batchify_fn = lambda samples, fn=Tuple(Pad(axis=0, pad_val=tokenizer.pad_token_id), Pad(axis=0, pad_val=tokenizer.pad_token_id)): fn(samples)
    batches = []
    one_batch = []
    for example in examples:
        one_batch.append(example)
        if len(one_batch) == batch_size:
            batches.append(one_batch)
            one_batch = []
    if one_batch:
        batches.append(one_batch)

    results = []
    model.eval()
    for batch in batches:
        input_ids, segment_ids = batchify_fn(batch)
        input_ids = paddle.to_tensor(input_ids)
        segment_ids = paddle.to_tensor(segment_ids)
        logits = model(input_ids, segment_ids)
        probs = F.softmax(logits, axis=1)
        idx = paddle.argmax(probs, axis=1).numpy()
        idx = idx.tolist()
        labels = [label_map[i] for i in idx]
        results.extend(labels)
    return results
label_map = {0: '垃圾邮件', 1: '正常邮件'}
data = ['Good effects of Ephedra', 'Watch INFX like a hawk [D E T A I L S NEW PICK friday  it is]  drift hypophyseal', 'Re: Off topic: right font on MSIE4']
predictions = predict(model, data, tokenizer, label_map, batch_size=64)
for idx, text in enumerate(data):
end(labels)
    return results
label_map = {0: '垃圾邮件', 1: '正常邮件'}
data = ['Good effects of Ephedra', 'Watch INFX like a hawk [D E T A I L S NEW PICK friday  it is]  drift hypophyseal', 'Re: Off topic: right font on MSIE4']
predictions = predict(model, data, tokenizer, label_map, batch_size=64)
for idx, text in enumerate(data):
    print('邮件标题: {} \n邮件标签: {}'.format(text, predictions[idx]))
邮件标题: Good effects of Ephedra 
邮件标签: 垃圾邮件
邮件标题: Watch INFX like a hawk [D E T A I L S NEW PICK friday  it is]  drift hypophyseal 
邮件标签: 垃圾邮件
邮件标题: Re: Off topic: right font on MSIE4 
邮件标签: 正常邮件

参考资料

小结

1. 真实数据到训练模型的数据集,还是需要取舍的。有时候可以通过规则过滤了一些数据,从而缩小了模型的范围。但无论如何选择,根据业务场景进行判断才是最重要的。
2. 英文邮件短标题的分类效果能达到92%以上,说明这个场景有比较高的业务价值,在工程上可以作为规则(类似于字符集)后的下一道拦截关卡。

Logo

学大模型,用大模型上飞桨星河社区!每天8点V100G算力免费领!免费领取ERNIE 4.0 100w Token >>>

更多推荐