使用PaddleNLP识别垃圾邮件(四):用RoBERTa做中文邮件标题分类
使用PaddleNLP的RoBERTa预训练模型,根据提取的中文邮件标题判断邮件是否为垃圾邮件,并完成批量邮件分类的python部署预测。
本文是《使用PaddleNLP识别垃圾邮件》系列第四篇,该系列持续更新中……
系列背景介绍:《使用PaddleNLP识别垃圾邮件》系列项目,针对当前企业面临的垃圾邮件问题,尝试使用深度学习的方法,探索多语言垃圾邮件的内容、标题提取与分类识别。
该系列还有一个姊妹篇,《使用PaddleNLP进行恶意网页识别》,欢迎感兴趣的读者点进来交流评论。
系列目录
- 使用PaddleNLP识别垃圾邮件(一):准确率98.5%的垃圾邮件分类器
- 使用PaddleNLP的文本分类LSTM模型,提取中文邮件内容判断邮件是否为垃圾邮件。
- 使用PaddleNLP识别垃圾邮件(二):用BERT做中文邮件内容分类
- 使用PaddleNLP的BERT预训练模型,根据提取的中文邮件内容判断邮件是否为垃圾邮件。
- 使用PaddleNLP识别垃圾邮件(三):用ELECTRA做英文邮件标题分类
- 介绍在Python中解析eml邮件内容的办法:email模块和mmpi库;
- 使用PaddleNLP的ELECTRA预训练模型,根据提取的英文邮件标题判断邮件是否为垃圾邮件。
- 使用PaddleNLP识别垃圾邮件(四):用RoBERTa做中文邮件标题分类
- 升级到最新自定义数据集方法;
- 使用PaddleNLP模型库,大幅简化开发流程;
- 使用PaddleNLP的RoBERTa预训练模型,根据提取的中文邮件标题判断邮件是否为垃圾邮件;
- 完成完整的批量邮件分类部署流程。
背景知识
RoBERTa算法
- RoBERTa: A Robustly Optimized BERT Pretraining Approach
- RoBERTa项目中文地址:https://github.com/brightmart/roberta_zh
RoBERTa是BERT的改进版,通过改进训练任务和数据生成方式、训练更久、使用更大批次、使用更多数据等获得了SOTA的效果;可以用BERT直接加载。A Robustly Optimized BERT,顾名思义,就是简单粗暴的BERT调优……
在模型规模、算力和数据上,与BERT相比主要有以下几点改进:
- 更大的模型参数量(论文提供的训练时间来看,模型使用 1024 块 V100 GPU 训练了 1 天的时间)
- 更大bacth size。RoBERTa 在训练过程中使用了更大的bacth size。尝试过从 256 到 8000 不等的bacth size。
- 更多的训练数据(包括:CC-NEWS 等在内的 160GB 纯文本。而最初的BERT使用16GB BookCorpus数据集和英语维基百科进行训练)
另外,RoBERTa在训练方法上有以下改进:
- 去掉下一句预测(NSP)任务
- 动态掩码。BERT 依赖随机掩码和预测 token。原版的 BERT 实现在数据预处理期间执行一次掩码,得到一个静态掩码。 而 RoBERTa 使用了动态掩码:每次向模型输入一个序列时都会生成新的掩码模式。这样,在大量数据不断输入的过程中,模型会逐渐适应不同的掩码策略,学习不同的语言表征。
- 文本编码。Byte-Pair Encoding(BPE)是字符级和词级别表征的混合,支持处理自然语言语料库中的众多常见词汇。原版的 BERT 实现使用字符级别的 BPE 词汇,大小为 30K,是在利用启发式分词规则对输入进行预处理之后学得的。Facebook 研究者没有采用这种方式,而是考虑用更大的 byte 级别 BPE 词汇表来训练 BERT,这一词汇表包含 50K 的 subword 单元,且没有对输入作任何额外的预处理或分词。
RoBERTa中文版
RoBERTa中文版所指的中文预训练RoBERTa模型只指按照RoBERTa论文主要精神训练的模型。包括:
1、数据生成方式和任务改进:取消下一个句子预测,并且数据连续从一个文档中获得。
2、更大更多样性的数据:使用30G中文训练,包含3亿个句子,100亿个字(即token)。由新闻、社区讨论、多个百科,包罗万象,覆盖数十万个主题,所以数据具有多样性(为了更有多样性,可以可以加入网络书籍、小说、故事类文学、微博等)。
3、训练更久:总共训练了近20万,总共见过近16亿个训练数据(instance); 在Cloud TPU v3-256 上训练了24小时,相当于在TPU v3-8(128G显存)上需要训练一个月。
4、更大批次:使用了超大(8k)的批次batch size。
5、调整优化器等超参数。
除以上外,RoBERTa中文版,使用了全词掩码(whole word mask)。在全词Mask中,如果一个完整的词的部分WordPiece子词被mask,则同属该词的其他部分也会被mask,即全词掩码。
RoBERTa中并没有直接实现动态掩码。通过复制一个训练样本得到多份数据,每份数据使用不同掩码,并加大复制的分数,可间接得到动态掩码效果。
PaddleNLP提供的RoBERTa预训练模型
在PaddleNLP Transformer API中,提供了下面几种RoBERTa算法的预训练模型。
数据集介绍
TREC 2006 Spam Track Public Corpora是一个公开的垃圾邮件语料库,由国际文本检索会议提供,分为英文数据集(trec06p)和中文数据集(trec06c),其中所含的邮件均来源于真实邮件保留了邮件的原有格式和内容。
除TREC 2006外,还有TREC 2005和TREC 2007的英文垃圾邮件数据集(对,没有中文),本项目中,仅使用TREC 2006提供的中文数据集进行演示。TREC 2005-2007的垃圾邮件数据集,均已整理在项目挂载的数据集中,感兴趣的读者可以自行fork。
文件目录形式:delay和full分别是一种垃圾邮件过滤器的过滤机制,full目录下,是理想的邮件分类结果,我们可以视为研究的标签。
trec06c
│
└───data
│ │ 000
│ │ 001
│ │ ...
│ └───215
└───delay
│ │ index
└───full
│ │ index
一、环境配置
本项目基于Paddle 2.0 编写,如果你的环境不是本版本,请先参考官网安装 Paddle 2.0 。
# 升级paddlenlp到2.0.2
!pip install -U paddlenlp
# 加载开源邮件快速检测工具库mmpi
!pip install mmpi
!pip install yara-python
# 导入相关的模块
import numpy as np
import re
import jieba
import os
import random
import paddle
import paddlenlp
from paddlenlp.data import Stack, Pad, Tuple
import paddle.nn.functional as F
import paddle.nn as nn
from visualdl import LogWriter
from functools import partial #partial()函数可以用来固定某些参数值,并返回一个新的callable对象
from tqdm import tqdm
from mmpi import mmpi
print(paddle.__version__)
2.0.2
二、数据加载
2.1 数据集准备
# 解压数据集
!tar xvf data/data89631/trec06c.tgz
随机选择一个文件,查看邮件文件的具体内容。中文邮件有可能编码格式是gb2312
,因此要注意,如果指定了utf-8
编码,可能看到的会是乱码。
f = open('trec06c/data/050/105', 'r', encoding='gb2312', errors='ignore')
text = ''
for line in f:
line = line.strip().strip('\n')
if len(line) > 1:
print(line)
text = text + line
我们可以看到,读取原始eml
格式文件,包括邮件的正文在内,可以得到相当多的信息,比如收件人、发件人、邮件标题、时间等。
但是,最关键的中文的邮件标题会被显示成:Subject: =?gb2312?B?tee2r83yxNy/qsv4xvcvtqjG2s/7yqexyjAwOjI2OjU0?=
这种形式,还需要一次额外转码。
正如在使用PaddleNLP识别垃圾邮件(三):用ELECTRA做英文邮件标题分类中已经演示过的,使用Python自带的Email模块,中文邮件标题的显示结果也是一样。
Received: from cfiin.com ([61.153.234.29])
by spam-gw.ccert.edu.cn (MIMEDefang) with ESMTP id j7OGD6fP030969
for <ma@ccert.edu.cn>; Sun, 28 Aug 2005 03:00:33 +0800 (CST)
Message-ID: <200508250013.j7OGD6fP030969@spam-gw.ccert.edu.cn>
From: =?GB2312?B?wO7Qoc2u?= <ren@cfiin.com>
Subject: =?gb2312?B?tee2r83yxNy/qsv4xvcvtqjG2s/7yqexyjAwOjI2OjU0?=
To: ma@ccert.edu.cn
Content-Type: text/html;charset="GB2312"
Reply-To: ren@cfiiiin.com
Date: Sun, 28 Aug 2005 03:14:19 +0800
X-Priority: 3
X-Mailer: Microsoft Outlook Express 6.00.2800.1106
REFRESH(3 sec): http://www.wst00.com/index.htm
最新窃听、监视、透视设备
万事通(国际)电子科技隶属迪克伟业(国际)集团子公司。专业经
营最新高科技电子产品业务,为香港同行业之龙头企业。本公司专业经
营最新高科技电子产品业务,产品包括:
电话通话内容监听器 GSM移动电话拦截系统 手机窃听器
超远程窃听器 微型随身窃听器(世界上最小的窃听器) 透视滤镜
监视眼镜 WL-01高性能隔墙监听器 电子追踪器
电话传真拦截器 针孔摄像机(超远程传送) 卫星全球定位系统
游戏机反遥控 掌中手机跟踪定位器 电动万能开锁器
隐形墨水 万能钥匙 针孔照相机
定期消失笔 电表控制器 信用卡设备
手机窃听器(骇客手机),在手机里置入一个监听芯片,需要监听的时候,
我幌卤鹑说氖只号码,加入预留的密码,对方的手机将在不响铃的情况下自动
油ǎ处入接听状态,而对方却毫无知觉,然后对方周围的一切却尽收你的耳中
此手机用于追踪调查对婚姻不忠和包二奶的可收到奇特效果,还可用于家长监
学生是否迷恋网吧等
电动万能开锁器:电动工具是目前世界上最先进的工具之一,它由直流电机组
成,可充电,充一次电用100个小时左右。它的使用比较广阔,能开电脑打孔锁
⑵车锁、片子锁、挂锁、一字型门锁等。它附件有6个探针,开什么锁换什么样
的针,深入锁空内靠住撬杆,如果第一个弹子长,手向下倾斜;第一个弹子短,
向上移动。扣动开关,每秒种50下的跳动速度,是您能达到瞬间开启。
本工具保修一年。
声明:本网站所有产品均为娱乐所用,不得用于非法用途。否则后果自负。
联系人:陈先生
电话:020-33887956 传 真:020-33887959 服务热线: 13828469020
电子邮箱:webmaster@fd788.com
国际网站:http://www.wst00.com/
……
2.2 MMPI库解析提取邮件数据
相比之下,在提取中文邮件标题方面,MMPI库要方便得多,因为它能自动转码,直接得到解析后的中文标题,而不是需要我们对着一串奇怪的字符伤脑筋。
参考资料:基于Python的邮件快速检测工具库
emp = mmpi()
emp.parse('trec06c/data/001/281')
report = emp.get_report()
print("mmpi解析结果: ", report)
f = open('trec06c/data/001/281', 'r', encoding='gb2312', errors='ignore')
text = ''
for line in f:
line = line.strip().strip('\n')
if len(line) > 1:
print(line)
text = text + line
然而,碰到像上面这封邮件,mmpi库在内容读取方面却“翻车”了……中英文夹杂的时候,中文邮件内容却出现了乱码……看来,要找到两全其美的办法,同时支持解析中文邮件的标题和内容,应对各种复杂场景,还需要考虑更多。
不过,既然可以提取中文邮件标题了,我们不妨试试看,只提取邮件标题,分类效果如何。毕竟,在使用PaddleNLP识别垃圾邮件(三):用ELECTRA做英文邮件标题分类项目中,根据英文邮件标题进行垃圾邮件分类的准确率达到了96%以上,也许中文邮件分类也能尝试一下?
使用mmpi库,要查看邮件标题,其实非常简单,像下面这样:
report['headers'][0]['Subject']
'Re: 教父这个烂片'
那么接下来,就是将中文邮件标题的提取过程和处理过程串接起来,生成可以让PaddleNLP训练的数据集了。
2.3 提取中文邮件标题,划分训练集、验证集、测试集
# 去掉非中文字符
def clean_str(string):
string = re.sub(r"[^\u4e00-\u9fff]", " ", string)
string = re.sub(r"\s{2,}", " ", string)
return string.strip()
# 从指定路径读取邮件文件内容信息
def get_data_in_a_file(original_path, save_path='all_email.txt'):
emp = mmpi()
emp.parse(original_path)
report = emp.get_report()
# 如果可以解析到邮件头信息
if report.get('headers') is not None:
return clean_str(report['headers'][0]['Subject'])
中文邮件标题提取的时候有时候会提取到''
——注意是''
不是None!
所以要进行过滤,经测试,处理了这个问题之后,提取的数据中就没有脏数据了!可以放心使用
# 读取标签文件信息
f = open('trec06c/full/index', 'r')
for line in f:
str_list = line.split(" ")
# 设置垃圾邮件的标签为0
if str_list[0] == 'spam':
label = '0'
# 设置正常邮件标签为1
elif str_list[0] == 'ham':
label = '1'
text = get_data_in_a_file('trec06c/full/' + str(str_list[1].split("\n")[0]))
# 注意,处理逻辑不是text非空,实际上这样还是会生成脏数据,应该用下面的方法
if text is not '':
with open("all_email.txt","a+") as f:
f.write(text + '\t' + label + '\n')
data_list_path="./"
with open(os.path.join(data_list_path, 'eval_list.txt'), 'w', encoding='utf-8') as f_eval:
f_eval.seek(0)
f_eval.truncate()
with open(os.path.join(data_list_path, 'train_list.txt'), 'w', encoding='utf-8') as f_train:
f_train.seek(0)
f_train.truncate()
with open(os.path.join(data_list_path, 'test_list.txt'), 'w', encoding='utf-8') as f_test:
f_test.seek(0)
f_test.truncate()
with open(os.path.join(data_list_path, 'all_email.txt'), 'r', encoding='utf-8') as f_data:
lines = f_data.readlines()
i = 0
with open(os.path.join(data_list_path, 'eval_list.txt'), 'a', encoding='utf-8') as f_eval,open(os.path.join(data_list_path, 'test_list.txt'), 'a', encoding='utf-8') as f_test,open(os.path.join(data_list_path, 'train_list.txt'), 'a', encoding='utf-8') as f_train:
for line in lines:
words = line.split('\t')[-1].replace('\n', '')
label = line.split('\t')[0]
labs = ""
# 划分验证集
if i % 10 == 1:
labs = label + '\t' + words + '\n'
f_eval.write(labs)
# 划分测试集
elif i % 10 == 2:
labs = label + '\t' + words + '\n'
f_test.write(labs)
# 划分训练集
else:
labs = label + '\t' + words + '\n'
f_train.write(labs)
i += 1
print("数据列表生成完成!")
数据列表生成完成!
2.4 自定义数据集
在之前的几个项目里,有很多读者反馈项目里能正常执行的代码,到本地就会报错,尤其是自定义数据集部分。
其实,这是因为项目环境和本地安装环境的不同。在之前的项目里,paddlepaddle的版本一般是2.0.2,而paddlenlp的版本是2.0rc,正因如此,原来才能这样自定义数据集:
class SelfDefinedDataset(paddle.io.Dataset):
def __init__(self, data):
super(SelfDefinedDataset, self).__init__()
self.data = data
def __getitem__(self, idx):
return self.data[idx]
def __len__(self):
return len(self.data)
def get_labels(self):
return ["0", "1"]
def txt_to_list(file_name):
res_list = []
for line in open(file_name):
res_list.append(line.strip().split('\t'))
return res_list
trainlst = txt_to_list('train_list.txt')
devlst = txt_to_list('eval_list.txt')
testlst = txt_to_list('test_list.txt')
train_ds, dev_ds, test_ds = SelfDefinedDataset.get_datasets([trainlst, devlst, testlst])
但如果读者使用的是本地环境,安装的paddlenlp版本可能是2.0.2,就像当前项目的环境,那么自定义数据集时就会出现报错:
---------------------------------------------------------------------------AttributeError Traceback (most recent call last)<ipython-input-33-f4bfc4346c92> in <module>
23 testlst = txt_to_list('test_list.txt')
24
---> 25 train_ds, dev_ds, test_ds = SelfDefinedDataset.get_datasets([trainlst, devlst, testlst])
AttributeError: type object 'SelfDefinedDataset' has no attribute 'get_datasets'
也正是因此,我们需要参考最新paddlenlp文档自定义数据集
2.4.1 从本地文件创建数据集
从本地文件创建数据集时,官方推荐根据本地数据集的格式给出读取function并传入 load_dataset()
中创建数据集。
from paddlenlp.datasets import load_dataset
def read(data_path):
with open(data_path, 'r', encoding='utf-8') as f:
# 跳过列名
next(f)
for line in f:
words, labels = line.strip('\n').split('\t')
words = words.split('\002')
labels = labels.split('\002')
yield {'text': words[0], 'label': labels[0]}
# data_path为read()方法的参数
train_ds = load_dataset(read,data_path='train_list.txt',splits='train',lazy=False)
dev_ds = load_dataset(read,data_path='eval_list.txt',splits='dev',lazy=False)
test_ds = load_dataset(read,data_path='test_list.txt',splits='test',lazy=False)
官方文档说明了这么做的理由:
- 将数据读取代码写成生成器(generator)的形式,这样可以更好的构建 MapDataset 和 IterDataset 两种数据集
- 将单条数据写成字典的格式,这样可以更方便的监测数据流向
- 事实上,MapDataset在绝大多数时候都可以满足要求。一般只有在数据集过于庞大无法一次性加载进内存的时候我们才考虑使用IterDataset。任何人都可以方便的定义属于自己的数据集。
需要注意的是,只有PaddleNLP内置的数据集具有将数据中的label自动转为id的功能(详细条件参见 创建DatasetBuilder)。
自定义数据集需要在自定义的convert to feature方法中添加label转id的功能——当然,在本项目中,已经直接在前面的数据预处理阶段将label转id,就免去了这个步骤。
2.4.2 从 paddle.io.Dataset/IterableDataset
创建数据集
在之前的项目中,我们其实都是用的这种方法自定义数据集。事实上,是用PaddleNLP内置的 MapDataset
和 IterDataset
API。由于API有更新,注意到这里不需要原来的get_datasets
函数了,直接用。
class SelfDefinedDataset(paddle.io.Dataset):
def __init__(self, path):
def load_data_from_source(path):
data = []
for line in open(path):
data.append(line.strip().split('\t'))
return data
self.data = load_data_from_source(path)
def __getitem__(self, idx):
return self.data[idx]
def __len__(self):
return len(self.data)
train_ds = SelfDefinedDataset('train_list.txt')
dev_ds = SelfDefinedDataset('eval_list.txt')
test_ds = SelfDefinedDataset('test_list.txt')
效果如下
训练集数据:[['非财务经理的财务管理 沙盘模拟', '0'], ['低点代开发票', '0'], ['一边上网冲浪 一边赚钱 何乐而不为', '0']]
验证集数据:[['问一部魏宗万的电影名称', '1'], ['低点代开发票', '0'], ['深圳协恒实业有限公司', '0']]
测试集数据:[['公司业务 代开发票', '0'], ['合作', '0'], ['公司业务 优惠发票', '0']]
2.5 训练数据分析
#看看数据长什么样子,分别打印训练集、验证集、测试集的前3条数据。
print("训练集数据:{}\n".format(train_ds[0:3]))
print("验证集数据:{}\n".format(dev_ds[0:3]))
print("测试集数据:{}\n".format(test_ds[0:3]))
print("训练集样本个数:{}".format(len(train_ds)))
print("验证集样本个数:{}".format(len(dev_ds)))
print("测试集样本个数:{}".format(len(test_ds)))
训练集数据:[{'text': '低点代开发票', 'label': '0'}, {'text': '一边上网冲浪 一边赚钱 何乐而不为', 'label': '0'}, {'text': '优惠代开各种发票', 'label': '0'}]
验证集数据:[{'text': '低点代开发票', 'label': '0'}, {'text': '深圳协恒实业有限公司', 'label': '0'}, {'text': '帮帮忙啊', 'label': '1'}]
测试集数据:[{'text': '合作', 'label': '0'}, {'text': '公司业务 优惠发票', 'label': '0'}, {'text': '业务合作', 'label': '0'}]
训练集样本个数:2905
验证集样本个数:363
测试集样本个数:363
# 统计训练集正负样本数量
spam = 0
for data in train_ds:
if data['label'] == '0':
spam += 1
print("训练集垃圾邮件数量:{}".format(spam))
print("训练集正常邮件数量:{}".format(len(train_ds) - spam))
训练集垃圾邮件数量:2157
训练集正常邮件数量:748
情况稍微比较不妙,看来中文邮件分类的数据集中,正负样本比例不能算特别均衡,但是又好像在可以接受的范围之内,我们可以继续往下看。
三、模型训练
3.1 模型选择
在PaddleNLP Transformer API中选择RoBERTa算法,考虑到中文邮件标题分类这个场景,要分类的基本都是非常短的语句,也许不宜使用太深的神经网络,因此rbt3
似乎就是一个很好的选择。
3.2 准备PaddleNLP模型库
一旦弄清楚如何自定义PaddleNLP数据集,就可以根据PaddleNLP模型库提供的预训练模型使用方法,对使用预训练模型Fine-tune完成中文文本分类任务目录下的文件稍作修改,开始迁移学习训练。
根据文档说明,主要步骤如下:
- 加载数据集:PaddleNLP内置了多种数据集,内置的数据集可以一键导入,在本项目中,则需要用2.4.1的处理方法进行替代。
# 注释掉内置数据集的导入
# train_ds, dev_ds = load_dataset("chnsenticorp", splits=["train", "dev"])
# data_path为read()方法的参数,用本地数据集导入进行替换。
train_ds = load_dataset(read,data_path='/home/aistudio/train_list.txt',splits='train',lazy=False)
dev_ds = load_dataset(read,data_path='/home/aistudio/eval_list.txt',splits='dev',lazy=False)
test_ds = load_dataset(read,data_path='/home/aistudio/test_list.txt',splits='test',lazy=False)
- 加载预训练模型:PaddleNLP的预训练模型可以很容易地通过
from_pretrained()
方法加载。 第一个参数是汇总表中对应的Pretrained Weight
,可加载对应的预训练权重。BertForSequenceClassification
初始化__init__
所需的其他参数,如num_classes
等, 也是通过from_pretrained()
传入。Tokenizer
使用同样的from_pretrained
方法加载。
model = ppnlp.transformers.RobertaForSequenceClassification.from_pretrained('rbt3', num_class=2)
tokenizer = ppnlp.transformers.RobertaTokenizer.from_pretrained('rbt3')
-
通过
Dataset
的map
函数,使用tokenizer
将dataset
从原始文本处理成模型的输入。 -
定义
BatchSampler
和DataLoader
,shuffle数据、组合Batch。 -
定义训练所需的优化器,loss函数等,就可以开始进行模型fine-tune任务——在本项目中,可以结合中文邮件标题文本的特点,重点注意修改
max_seq_length
和batch_size
超参数。 -
【项目新增】引入VisualDL进行训练过程的可视化。
with LogWriter(logdir="./logdir") as writer:
……
#记录训练过程
writer.add_scalar(tag="train/loss", step=global_step, value=loss)
writer.add_scalar(tag="train/acc", step=global_step, value=acc)
……
#记录评估过程
writer.add_scalar(tag="eval/loss", step=epoch, value=eval_loss)
writer.add_scalar(tag="eval/acc", step=epoch, value=eval_acc)
# 拉取模型库
!git clone https://gitee.com/paddlepaddle/PaddleNLP.git
# 将文本分类迁移学习训练目录单独移出来,便于查看和修改
# 本项目修改后的结果,可以直接查看./pretrained_models目录
# !mv PaddleNLP/examples/text_classification/pretrained_models ./pretrained_models
3.3 开始训练
!python pretrained_models/train.py
最终,模型在测试集上准确率达到98.3%以上,分类效果还不错。
test result...
eval loss: 0.06442, accu: 0.98347
训练过程中,可支持配置的参数如下:
save_dir
:可选,保存训练模型的目录;默认保存在当前目录checkpoints文件夹下。max_seq_length
:可选,ERNIE/BERT模型使用的最大序列长度,最大不能超过512, 若出现显存不足,请适当调低这一参数;默认为128。batch_size
:可选,批处理大小,请结合显存情况进行调整,若出现显存不足,请适当调低这一参数;默认为32。learning_rate
:可选,Fine-tune的最大学习率;默认为5e-5。weight_decay
:可选,控制正则项力度的参数,用于防止过拟合,默认为0.00。epochs
: 训练轮次,默认为3。warmup_proption
:可选,学习率warmup策略的比例,如果0.1,则学习率会在前10%训练step的过程中从0慢慢增长到learning_rate, 而后再缓慢衰减,默认为0.1。init_from_ckpt
:可选,模型参数路径,热启动模型训练;默认为None。seed
:可选,随机种子,默认为1000.device
: 选用什么设备进行训练,可选cpu或gpu。如使用gpu训练则参数gpus指定GPU卡号。
3.4 模型保存
训练过程中会自动保存模型在指定的save_dir
中。 如:
checkpoints/
├── model_100
│ ├── model_config.json
│ ├── model_state.pdparams
│ ├── tokenizer_config.json
│ └── vocab.txt
└── ...
我们不妨来看看json文件的具体内容,比如model_config.json
定义了我们使用的rbt3
模型进行序列分类任务的具体信息
{ "init_args": [
{
"attention_probs_dropout_prob": 0.1,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"max_position_embeddings": 512,
"num_attention_heads": 12,
"num_hidden_layers": 3,
"type_vocab_size": 2,
"vocab_size": 21128,
"pad_token_id": 0,
"init_class": "RobertaModel"
}
],
"init_class": "RobertaForSequenceClassification"
}
而tokenizer_config.json
主要为了定位vocab.txt
{
"do_lower_case":true
"vocab_file":"/home/aistudio/.paddlenlp/models/rbt3/vocab.txt"
"init_class":"RobertaTokenizer"
}
3.5 模型导出和可视化
使用动态图训练结束之后,可以运行pretrained_models/export_model.py
将动态图参数导出成静态图参数,静态图参数保存在--output_path
指定路径中,在保存的静态图信息中,我们还可以在VisualDL中查看模型的网络结构。
很显然,rbt3
的网络结构要简单多了,毕竟layer只有3层。
对于PaddleNLP模型库提供的默认配置,在pretrained_models/export_model.py
中,最重要的是修改下面两行代码:
# 根据数据集,指定标签的写法
label_map = {0: 'spam', 1: 'ham'}
# 加载训练时选定的预训练模型
model = ppnlp.transformers.RobertaForSequenceClassification.from_pretrained('rbt3', num_class=len(label_map))
# 模型导出
= ppnlp.transformers.RobertaForSequenceClassification.from_pretrained('rbt3', num_class=len(label_map))
# 模型导出
!python pretrained_models/export_model.py --params_path=./checkpoint/model_360/model_state.pdparams --output_path=./checkpoint/static_graph_params
3.6 模型的预测和部署
在预测阶段,本项目希望把数据处理流程在预测和部署时串接起来。因此,无论是动态图预测,还是动态图转静态图部署之后的预测,都有下列事项需要注意:
-
配置预训练模型为
rbt3
,需要更换model 和 tokenizer -
将数据处理流程加入预测阶段代码中,当然,考虑到带预测的邮件文件数量比较多,这里加入了限制条件,选取三封邮件进行抽样,实际使用时可以根据需要,删除限制条件。
修改后的文件,参考pretrained_models/predict.py
和pretrained_models/deploy/python/predict.py
,部分关键代码:
# 引入数据处理工具库
from mmpi import mmpi
import re
......
# 定义处理函数
# 去掉非中文字符
def clean_str(string):
string = re.sub(r"[^\u4e00-\u9fff]", " ", string)
string = re.sub(r"\s{2,}", " ", string)
return string.strip()
# 从指定路径读取邮件文件内容信息
def get_data_in_a_file(original_path):
emp = mmpi()
emp.parse(original_path)
report = emp.get_report()
# 如果可以解析到邮件头信息
if report.get('headers') is not None:
return clean_str(report['headers'][0]['Subject'])
......
# 将处理结果传入输入预测的数据list中
data = []
for line in open('/home/aistudio/trec06c/full/index', 'r'):
str_list = line.split(" ")
if str_list[1].split("\n")[0] in ['../data/000/011','../data/000/021','../data/000/031']:
text = get_data_in_a_file('trec06c/full/' + str(str_list[1].split("\n")[0]))
data.append(text)
......
3.6.1 动态图预测
可以直接调用predict函数输出预测结果。
!python pretrained_models/predict.py --params_path=./checkpoint/model_360/model_state.pdparams
3.6.2 python部署预测
!python pretrained_models/deploy/python/predict.py --model_file=./checkpoint/static_graph_params.pdmodel --params_file=./checkpoint/static_graph_params.pdiparams
预测结果如下
Data: 低点代开发票 Label: 垃圾邮件
Data: 深圳协恒实业有限公司 Label: 垃圾邮件
Data: 帮帮忙啊 Label: 正常邮件
参考资料
- RoBERTa: A Robustly Optimized BERT Pretraining Approach
- RoBERTa项目中文地址:https://github.com/brightmart/roberta_zh
- RoBERTa模型原理总结
- RoBERTa 详解
小结
- 其实只要弄清了最关键的一步,即如何自定义数据集,使用PaddleNLP库进行文本分类的迁移学习还是非常方便的。特别要注意的,就是不同的自定义数据集方法返回结果的差异:
过去,一般是这样……
训练集数据:[['非财务经理的财务管理 沙盘模拟', '0'], ['低点代开发票', '0'], ['一边上网冲浪 一边赚钱 何乐而不为', '0']]
验证集数据:[['问一部魏宗万的电影名称', '1'], ['低点代开发票', '0'], ['深圳协恒实业有限公司', '0']]
测试集数据:[['公司业务 代开发票', '0'], ['合作', '0'], ['公司业务 优惠发票', '0']]
现在,PaddleNLP模型库推荐的是这样……
训练集数据:[{'text': '低点代开发票', 'label': '0'}, {'text': '一边上网冲浪 一边赚钱 何乐而不为', 'label': '0'}, {'text': '优惠代开各种发票', 'label': '0'}]
验证集数据:[{'text': '低点代开发票', 'label': '0'}, {'text': '深圳协恒实业有限公司', 'label': '0'}, {'text': '帮帮忙啊', 'label': '1'}]
测试集数据:[{'text': '合作', 'label': '0'}, {'text': '公司业务 优惠发票', 'label': '0'}, {'text': '业务合作', 'label': '0'}]
- 中文邮件短标题的分类效果能达到98.3%以上,比英文邮件的92%高出不少,这有可能是数据集的问题,当然,有可能是这个场景下,确实英文垃圾邮件不太明显?(这是非常有可能的,注册过几个外网网站你就知道,即使是那些国外大厂,也超爱发各种广告邮件),但是我们也应当意识到,由于数据集有限,中文邮件短标题面对当前的实际场景,效果可能会比较一般,这个模型只是针对当前数据集的较优方案。开发适合用户需要的垃圾邮件分类器,肯定还是需要整合用户自己的邮件分类数据和标注。
更多推荐
所有评论(0)