在数据增强、蒸馏剪枝下ERNIE3.0模型性能提升

以CBLUE数据集中医疗搜索检索词意图分类为例:

本项目首先讲解了数据增强和数据蒸馏的方案,并在后面章节进行效果展示,结果预览:

模型ACCPrecisionRecallF1average_of_acc_and_f1
ERNIE 3.0 Base0.802550.93171470.9082840.9198500.86120
ERNIE 3.0 Base+数据增强0.79795390.9010040.928990.914780.8563
ERNIE 3.0 Base+剪裁保留比0.50.798460.9512570.894970.922250.8603
ERNIE 3.0 Base +剪裁保留比2/30.80920710.94153840.9053250.9230760.86614

1.环境安装

gensim安装最新版本:pip install gensim

tqdm安装:pip install tqdm

LAC安装最新版本:pip install lac


Gensim库介绍

Gensim是在做自然语言处理时较为经常用到的一个工具库,主要用来以无监督的方式从原始的非结构化文本当中来学习到文本隐藏层的主题向量表达。

主要包括TF-IDF,LSA,LDA,word2vec,doc2vec等多种模型。

Tqdm

是一个快速,可扩展的Python进度条,可以在 Python 长循环中添加一个进度提示信息,用户只需要封装任意的迭代器 tqdm(iterator)。目的为了程序显示的美观

中文词法分析-LAC

LAC是一个联合的词法分析模型,整体性地完成中文分词、词性标注、专名识别任务。LAC既可以认为是Lexical Analysis of Chinese的首字母缩写,也可以认为是LAC Analyzes Chinese的递归缩写。

LAC基于一个堆叠的双向GRU结构,在长文本上准确复刻了百度AI开放平台上的词法分析算法。效果方面,分词、词性、专名识别的整体准确率95.5%;单独评估专名识别任务,F值87.1%(准确90.3,召回85.4%),总体略优于开放平台版本。在效果优化的基础上,LAC的模型简洁高效,内存开销不到100M,而速度则比百度AI开放平台提高了57%

在这里插入图片描述

LAC链接:https://www.paddlepaddle.org.cn/modelbasedetail/lac

!pip install --upgrade paddlenlp
!pip install gensim
!pip install tqdm
!pip install lac

2.数据增强方案介绍

数据增强工具提供4种增强策略:遮盖、删除、同词性词替换、词向量近义词替换

在这里插入图片描述

!unzip ERNIE-.zip -d ./ERNIE
#添加ERNIE工具包
# %cd ./train
# %cd ..
# !rm -rf .ipynb_checkpoints
如果程序报错:
可以发现提示有一个.ipynb_checkpoints的文件。但当我去对应的文件夹找时根本看不到这个文件,所以猜测是一个隐藏文件。所以通过终端进入对应的目录:输入cd coco进入对应目录,输入ls -a显示所有文件。然后输入rm -rf .ipynb_checkpoints删除该文件。再次输入ls -a查看文件是否被删除。

下载词表,词表有1.7G会花点时间。下面以情感分析数据样例展示demo,看看数据增强的效果。

!wget -q --no-check-certificate http://bj.bcebos.com/wenxin-models/vec2.txt

python data_aug.py “输入文件夹的目录” “输出文件夹的目录”

  • data_aug.py脚本传参说明
shell输入:
    python data_aug.py -h

shell输出:
    usage: data_aug.py [-h] [-n AUG_TIMES] [-c COLUMN_NUMBER] [-u UNK]
                       [-t TRUNCATE] [-r POS_REPLACE] [-w W2V_REPLACE]
                       [-e ERNIE_REPLACE] [--unk_token UNK_TOKEN]
                       input output
    
    main
    
    positional arguments:
      input                                                #原始待增强数据文件所在文件夹,带label的,一个或多个文本列
      output                                               #输出文件路径
    
    optional arguments:
      -h, --help            show this help message and exit
      -n AUG_TIMES, --aug_times AUG_TIMES                  #数据集数目放大n倍,output行数为input的n+1倍      
      -c COLUMN_NUMBER, --column_number COLUMN_NUMBER      #明文文件中所要增强列的列序号,多列用逗号分割,如:1,2
      -u UNK, --unk UNK                                    #unk 增强策略的概率
      -t TRUNCATE, --truncate TRUNCATE                     #truncate 增强策略的概率
      -r POS_REPLACE, --pos_replace POS_REPLACE            #pos_replace 增强策略的概率
      -w W2V_REPLACE, --w2v_replace W2V_REPLACE            #w2v_replace 增强策略的概率
      --unk_token UNK_TOKEN                    

分类问题中:推荐使用前三种即可,w2v词向量近义词替换可以不用,花费时间太长。

!python data_aug.py --unk 0.25 --truncate 0.25 --pos 0.5 --w2v 0 ./train ./output

demo结果展示:

机器 背面 似乎 被 撕 了 张 什么 标签 , 残 胶 还在 。 但是 又 看 不 出 是 什么 标签 不见 了 , 该 有 的 都 在 , 怪	0
机器 背面 似乎 被 撕 了 张 什么 标签 , 胶 还在 。 但是 又 看 不 出 是 什么 标签 不见 了 , 该 有 的 都 在 , 怪	0
机器 背面 了 张 什么 标签 , 残 胶 还在 。 但是 又 看 不 出 是 什么 标签  了 , 该在 , 怪	0
呵呵 , 虽然 表皮 看上去 不错 很 精致 , 但是 我 还是 能 看得出来 是 盗 的 。 但是 里面 的 内容 真 的 不错 , 我 妈 爱 看 , 我自己 也 学 着 找 一些 穴位 。	0
呵呵 , 虽然 表皮 看上去 不错 很 精致 , 但是 我 还是 能 看得出来 是 盗 的 。 但是 里面 的 内容 真 的 不错 , 我😄妈 爱 看 , 我自己 也 学 着 找 一些 穴位 😄	0
呵呵 , 虽然 表皮 看上去 不错 很 精致 , 但是 我 还😄 能 看得出来 是 盗😄😄😄。 但是 里面 的 内容 真 的 不错 , 我 妈 爱 看 ,😄😄😄😄😄😄😄学 着 找 😄😄😄😄😄😄😄	0
😄😄😄😄😄虽然 表皮 看上去 不错 很 精致 , 但是 我 还是 能 看得出来 是 盗 的 。 但是 里面 的 内容 真 的 不错 , 我 妈 爱 看 , 我自己 也 学 着 找 一些 穴位 。	0
😄😄😄😄😄😄😄 表皮 看上去 不错 很 精致 , 但是 我 还是 能 看得出来 是 盗 的 。 但是 里面 的 内容 真 的 不错 , 我 妈 爱 看 , 我自己 也 学 着 找 一些 穴位 。	0
地理 位置 佳 , 在 市中心 。 酒店 服务 好 、 早餐 品种 丰富 。 我 住 的 商务 数码 房 电脑 宽带 速度 满意 , 房间 还算 干净 , 离 湖南路小吃街 近 。	1
地理 位置 佳 , 在 市中心 。 酒店 服务 好 、 早餐 品种 丰富 。 我 住 的 商务 数码 房 电脑 宽带 速度 满意 , 房间 还算 干净 , 离 湖南路小吃街 近。。	1
地理 位置 佳 , 在 市中心 。 酒店 服务 好 、 早餐 品种 丰富 。 我 住 的 商务 数码 房 电脑 宽带 速度 满意 , 机器 还算 干净 , 离 湖南路小吃街 近 。	1
地理 位置 佳 , 在 市中心 。 酒店 服务 好 、 早餐 品种 丰富 。 我 住 的 商务 数码 房 电脑 宽带 速度 满意 , 房间 还算 干净 , 离 湖南路小吃街 近 。	1
地理 位置 佳 , 在 市中心 。 酒店 服务 好 、 早餐 品种 丰富 。 我 住 的 商务 数码 房 电脑 宽
我 看 是 书 的 还 可以 , 但是 我 订 的 书 迟迟 还 到 能 半个月 , 都 没有 收到 打电话 也 没

3.数据蒸馏技术

ERNIE数据蒸馏三步

Step 1. 使用ERNIE模型对输入标注数据对进行fine-tune,得到Teacher Model

Step 2. 使用ERNIE Service对以下无监督数据进行预测:

  • 用户提供的大规模无标注数据,需与标注数据同源
  • 对标注数据进行数据增强,具体增强策略
  • 对无标注数据和数据增强数据进行一定比例混合

Step 3. 使用步骤2的数据训练出Student Model

数据增强

目前采用三种数据增强策略策略,对于不用的任务可以特定的比例混合。三种数据增强策略包括:

添加噪声:对原始样本中的词,以一定的概率(如0.1)替换为”UNK”标签

同词性词替换:对原始样本中的所有词,以一定的概率(如0.1)替换为本数据集钟随机一个同词性的词

N-sampling:从原始样本中,随机选取位置截取长度为m的片段作为新的样本,其中片段的长度m为0到原始样本长度之间的随机值

在这里插入图片描述

模型剪裁,基于 PaddleNLP 的 Trainer API 发布提供了模型裁剪 API。裁剪 API 支持用户对 ERNIE 等Transformers 类下游任务微调模型进行裁剪。

具体效果在下一节展现,先安装好paddleslim库

!pip install paddleslim

4.基于ERNIR3.0文本模型微调

加载已有数据集:CBLUE数据集中医疗搜索检索词意图分类(训练)

数据集定义:
以公开数据集CBLUE数据集中医疗搜索检索词意图分类(KUAKE-QIC)任务为示例,在训练集上进行模型微调,并在开发集上使用准确率Accuracy评估模型表现。

数据集默认为:默认为"cblue"。

save_dir:保存训练模型的目录;默认保存在当前目录checkpoint文件夹下。

dataset:训练数据集;默认为"cblue"。

dataset_dir:本地数据集路径,数据集路径中应包含train.txt,dev.txt和label.txt文件;默认为None。

task_name:训练数据集;默认为"KUAKE-QIC"。

max_seq_length:ERNIE模型使用的最大序列长度,最大不能超过512, 若出现显存不足,请适当调低这一参数;默认为128。

model_name:选择预训练模型;默认为"ernie-3.0-base-zh"。

device: 选用什么设备进行训练,可选cpu、gpu、xpu、npu。如使用gpu训练,可使用参数gpus指定GPU卡号。

batch_size:批处理大小,请结合显存情况进行调整,若出现显存不足,请适当调低这一参数;默认为32。

learning_rate:Fine-tune的最大学习率;默认为6e-5。

weight_decay:控制正则项力度的参数,用于防止过拟合,默认为0.01。

early_stop:选择是否使用早停法(EarlyStopping);默认为False。

early_stop_nums:在设定的早停训练轮次内,模型在开发集上表现不再上升,训练终止;默认为4。
epochs: 训练轮次,默认为100。

warmup:是否使用学习率warmup策略;默认为False。

warmup_proportion:学习率warmup策略的比例数,如果设为0.1,则学习率会在前10%steps数从0慢慢增长到learning_rate, 而后再缓慢衰减;默认为0.1。

logging_steps: 日志打印的间隔steps数,默认5。

init_from_ckpt: 模型初始checkpoint参数地址,默认None。

seed:随机种子,默认为3。

# !pip install --upgrade paddlenlp
# !python train.py --warmup --early_stop --epochs 10 --model_name "ernie-3.0-base-zh" --max_seq_length 128 --batch_size 32 --logging_steps 10 --learning_rate 6e-5 
!python train.py --warmup --early_stop --epochs 5 --model_name ernie-3.0-base-zh --batch_size 16
#修改后的训练文件train_new2.py ,主要使用了paddlenlp.metrics.glue的AccuracyAndF1:准确率及F1-score,可用于GLUE中的MRPC 和QQP任务
#不过吐槽一下:    return (acc,precision,recall,f1,(acc + f1) / 2,) 最后一个指标竟然是加权平均.....
!python train_new2.py --warmup --early_stop --epochs 10 --save_dir "./checkpoint2" --batch_size 16 --model_name ernie-3.0-base-zh

训练结果部分展示:

[2022-08-16 19:58:36,834] [    INFO] - global step 1280, epoch: 3, batch: 412, loss: 0.23292, acc: 0.87106, speed: 16.54 step/s
[2022-08-16 19:58:37,392] [    INFO] - global step 1290, epoch: 3, batch: 422, loss: 0.22339, acc: 0.87130, speed: 17.94 step/s
[2022-08-16 19:58:37,960] [    INFO] - global step 1300, epoch: 3, batch: 432, loss: 0.22791, acc: 0.87182, speed: 17.68 step/s
(acc, precision, recall, f1, average_of_acc_and_f1):(0.8025575447570332, 0.9317147192716236, 0.908284023668639, 0.9198501872659175, 0.8612038660114754)

[2022-08-16 20:01:36,060] [ INFO] - Early stop!
[2022-08-16 20:01:36,060] [ INFO] - Save best accuracy text classification model in ./checkpoint2

程序运行时将会自动进行训练,评估,测试。同时训练过程中会自动保存开发集上最佳模型在指定的 save_dir 中,保存模型文件结构如下所示:

checkpoint/
├── model_config.json
├── model_state.pdparams
├── tokenizer_config.json
└── vocab.txt

NOTE:

如需恢复模型训练,则可以设置 init_from_ckpt , 如init_from_ckpt=checkpoint/model_state.pdparams

如需训练中文文本分类任务,只需更换预训练模型参数 model_name 。中文训练任务推荐使用"ernie-3.0-base-zh",更多可选模型可参考Transformer预训练模型。

4.1 加载自定义数据集(并通过数据增强训练)

从本地文件创建数据集

使用本地数据集来训练我们的文本分类模型,本项目支持使用固定格式本地数据集文件进行训练
如果需要对本地数据集进行数据标注,可以参考文本分类任务doccano数据标注使用指南进行文本分类数据标注。[这个放到下个项目讲解]

本项目将以CBLUE数据集中医疗搜索检索词意图分类(KUAKE-QIC)任务为例进行介绍如何加载本地固定格式数据集进行训练:

本地数据集目录结构如下:

data/
├── train.txt # 训练数据集文件
├── dev.txt # 开发数据集文件
├── label.txt # 分类标签文件
└── data.txt # 可选,待预测数据文件
# !wget https://paddlenlp.bj.bcebos.com/datasets/KUAKE_QIC.tar.gz
!tar -zxvf KUAKE_QIC.tar.gz
# !mv KUAKE_QIC data
KUAKE_QIC/
KUAKE_QIC/data.txt
KUAKE_QIC/train.txt
KUAKE_QIC/dev.txt
KUAKE_QIC/label.txt
# %cd ./KUAKE_QIC/train
# !rm -rf .ipynb_checkpoints
# %cd ./..
/home/aistudio
!python data_aug.py --unk 0 --truncate 0 --pos 1 --w2v 0 ./KUAKE_QIC/train ./output1

这里数据增强的时候不推荐使用,随机删除

会报错:ValueError: not enough values to unpack (expected 2, got 1)

解决方案:

(1)文档和代码分割的符号要统一,尽量换成英文符号。

(2)不能有多余的换行符,文档最后要检查一下,是不是有空白的几行,因为没有文字容易被人漏掉。

比较麻烦,因为会造成分割混乱,建议使用其他三种方式增强。

在训练过程中通过指定数据集路径参数dataset_dir进行: 单卡训练

python train.py --warmup --dataset_dir data/KUAKE_QIC

dataset_dir:本地数据集路径,数据集路径中应包含train.txt,dev.txt和label.txt文件;默认为None。

!mv output1/train_aug.txt ./KUAKE_QIC
%cd ./KUAKE_QIC
!mv train_aug.txt train.txt
%cd ..
#移动文件并重命名
/home/aistudio/KUAKE_QIC
/home/aistudio
!python train_new2.py --warmup --early_stop --epochs 10 --save_dir "./checkpoint3" --batch_size 16 --model_name ernie-3.0-base-zh --dataset_dir KUAKE_QIC

部分结果展示

[2022-08-16 23:43:18,093] [    INFO] - global step 2400, epoch: 2, batch: 234, loss: 0.60859, acc: 0.84437, speed: 19.27 step/s
(acc, precision, recall, f1, average_of_acc_and_f1):(0.7979539641943734, 0.9010043041606887, 0.9289940828402367, 0.9147851420247632, 0.8563695531095683)
[2022-08-16 23:43:24,522] [    INFO] - Save best F1 text classification model in ./checkpoint3
[2022-08-16 23:43:24,523] [    INFO] - best F1 performence has been updated: 0.91450 --> 0.91479

4.2 数据蒸馏

静态图导出

使用动态图训练结束之后,还可以将动态图参数导出成静态图参数,静态图模型将用于后续的推理部署工作。具体代码见静态图导出脚本,静态图参数保存在output_path指定路径中。

可支持配置的参数:

  • params_path:动态图训练保存的参数路径;默认为"./checkpoint/"。
  • output_path:静态图图保存的参数路径;默认为"./export"。

程序运行时将会自动导出模型到指定的 output_path 中,保存模型文件结构如下所示:

export/
├── float32.pdiparams
├── float32.pdiparams.info
└── float32.pdmodel
!python export_model.py \
    --params_path ./checkpoint3/ \
    --output_path ./export
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/setuptools/depends.py:2: DeprecationWarning: the imp module is deprecated in favour of importlib; see the module's documentation for alternative uses
  import imp
[32m[2022-08-17 11:33:00,031] [    INFO][0m - We are using <class 'paddlenlp.transformers.ernie.modeling.ErnieForSequenceClassification'> to load './checkpoint3/'.[0m
[0m
!pip install paddleslim==2.2.2
!unset CUDA_VISIBLE_DEVICES
!python -m paddle.distributed.launch --gpus "0" prune.py \
    --device "gpu" \
    --output_dir "./prune" \
    --per_device_train_batch_size 32 \
    --per_device_eval_batch_size 32 \
    --learning_rate 3e-5 \
    --num_train_epochs 5 \
    --logging_steps 10 \
    --save_steps 50 \
    --seed 3 \
    --dataset_dir "KUAKE_QIC" \
    --max_seq_length 128 \
    --params_dir "./checkpoint3" \
    --width_mult '0.5'

部分结果展示:

[2022-08-17 14:22:30,954] [    INFO] - width_mult: 0.5, eval loss: 0.63535, acc: 0.79847
(acc, precision, recall, f1, average_of_acc_and_f1):(0.7984654731457801, 0.9512578616352201, 0.8949704142011834, 0.9222560975609755, 0.8603607853533778)
[2022-08-17 14:22:35,870] [    INFO] - Save best F1 text classification model in ./prune/0.5
[2022-08-17 14:22:35,870] [    INFO] - best F1 performence has been updated: 0.92226 --> 0.92226
!unset CUDA_VISIBLE_DEVICES
!python -m paddle.distributed.launch --gpus "0" prune.py \
    --device "gpu" \
    --output_dir "./prune" \
    --per_device_train_batch_size 32 \
    --per_device_eval_batch_size 32 \
    --learning_rate 3e-5 \
    --num_train_epochs 5 \
    --logging_steps 10 \
    --save_steps 50 \
    --seed 3 \
    --dataset_dir "KUAKE_QIC" \
    --max_seq_length 128 \
    --params_dir "./checkpoint3" \
    --width_mult '2/3'
2022-08-17 14:53:45,544] [    INFO] - global step 3070, epoch: 2, batch: 904, loss: 0.709566, speed: 9.93 step/s
[2022-08-17 14:53:46,550] [    INFO] - global step 3080, epoch: 2, batch: 914, loss: 0.607238, speed: 9.94 step/s
[2022-08-17 14:53:47,558] [    INFO] - global step 3090, epoch: 2, batch: 924, loss: 0.718484, speed: 9.93 step/s
[2022-08-17 14:53:48,563] [    INFO] - global step 3100, epoch: 2, batch: 934, loss: 0.546288, speed: 9.95 step/s
[2022-08-17 14:53:50,206] [    INFO] - teacher model, eval loss: 0.66438, acc: 0.80358
[2022-08-17 14:53:50,207] [    INFO] - eval done total : 1.6434180736541748 s
[2022-08-17 14:53:53,568] [    INFO] - width_mult: 0.6666666666666666, eval loss: 0.60219, acc: 0.80921
(acc, precision, recall, f1, average_of_acc_and_f1):(0.8092071611253197, 0.9415384615384615, 0.9053254437869822, 0.923076923076923, 0.8661420421011213)
[2022-08-17 14:53:58,489] [    INFO] - Save best F1 text classification model in ./prune/0.6666666666666666
[2022-08-17 14:53:58,489] [    INFO] - best F1 performence has been updated: 0.92308 --> 0.92308

使用多卡训练可以指定多个GPU卡号,例如 --gpus “0,1”。如果设备只有一个GPU卡号默认为0,可使用nvidia-smi命令查看GPU使用情况。

可支持配置的参数:

TrainingArguments

output_dir:必须,保存模型输出和和中间checkpoint的输出目录;默认为 None 。
device: 选用什么设备进行裁剪,选择cpu、gpu。如使用gpu训练,可使用参数–gpus指定GPU卡号。

per_device_train_batch_size:训练集裁剪训练过程批处理大小,请结合显存情况进行调整,若出现显存不足,请适当调低这一参数;默认为32。

per_device_eval_batch_size:开发集评测过程批处理大小,请结合显存情况进行调整,若出现显存不足,请适当调低这一参数;默认为32。

learning_rate:训练最大学习率;默认为3e-5。

num_train_epochs: 训练轮次,使用早停法时可以选择100;默认为10。

logging_steps: 训练过程中日志打印的间隔steps数,默认5。

save_steps: 训练过程中保存模型checkpoint的间隔steps数,默认100。

seed:随机种子,默认为3。

DataArguments

dataset_dir:本地数据集路径,需包含train.txt,dev.txt,label.txt;默认为None。

**max_seq_length:**模型使用的最大序列长度,建议与训练过程保持一致, 若出现显存不足,请适当调低这一参数;默认为128。
ModelArguments

**params_dir:**待预测模型参数文件;默认为"./checkpoint/"。

width_mult:裁剪宽度保留的比例,表示对self_attention中的 q、k、v 以及 ffn 权重宽度的保留比例,默认是 ‘2/3’。

以上参数都可通过 python prune.py --dataset_dir xx --params_dir xx 的方式传入)

4.3 模型预测

输入待预测数据和数据标签对照列表,模型预测数据对应的标签

使用默认数据进行预测:

!python predict.py --params_path ./checkpoint3/ 
#也可以选择使用本地数据文件data/data.txt进行预测:
!python predict.py --params_path ./checkpoint3/ --dataset_dir ./KUAKE_QIC --device "cpu"
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/setuptools/depends.py:2: DeprecationWarning: the imp module is deprecated in favour of importlib; see the module's documentation for alternative uses
  import imp
[32m[2022-08-17 15:21:28,086] [    INFO][0m - We are using <class 'paddlenlp.transformers.ernie.modeling.ErnieForSequenceClassification'> to load './checkpoint3/'.[0m
[32m[2022-08-17 15:21:41,231] [    INFO][0m - We are using <class 'paddlenlp.transformers.ernie.tokenizer.ErnieTokenizer'> to load './checkpoint3/'.[0m
[32m[2022-08-17 15:21:41,564] [    INFO][0m - Prediction results save in output.txt.[0m
黑苦荞茶的功效与作用及食用方法 功效作用
交界痣会凸起吗 疾病表述
检查是否能怀孕挂什么科 就医建议
鱼油怎么吃咬破吃还是直接咽下去 其他
幼儿挑食的生理原因是 病因分析
[0m
!python predict.py \
    --device "cpu" \
    --dataset_dir ./KUAKE_QIC \
    --params_path "./prune/0.5" \

可支持配置的参数:

params_path:待预测模型参数文件夹;默认为"./checkpoint/"。

dataset_dir:本地数据集路径,数据集路径中应包含data.txt和label.txt文件;默认为None。

max_seq_length:ERNIE模型使用的最大序列长度,最大不能超过512, 若出现显存不足,请适当调低这一参数;默认为512。

batch_size:批处理大小,请结合显存情况进行调整,若出现显存不足,请适当调低这一参数;默认为32。

device: 选用什么设备进行训练,可选cpu、gpu、xpu、npu;默认为gpu。

5.总结

本项目首先讲解了数据增强和数据蒸馏的方案,并在后面章节进行效果展示,现在进行汇总

模型ACCPrecisionRecallF1average_of_acc_and_f1
ERNIE 3.0 Base0.802550.93171470.9082840.9198500.86120
ERNIE 3.0 Base+数据增强0.79795390.9010040.928990.914780.8563
ERNIE 3.0 Base+剪裁保留比0.50.798460.9512570.894970.922250.8603
ERNIE 3.0 Base +剪裁保留比2/30.80920710.94153840.9053250.9230760.86614

分析可得,

  • 首先数据增强后导致性能部分下降部分和预期的原因:
    随机mask、删除会产生过多噪声样本影响结果,推荐只使用同义词替换,本次样本数据量足够,且ERNIE性能本就优越,数据增强对结果提升在较大样本集可以忽略。

  • 其次,可以看到通过数据蒸馏后,模型性能变化不大,甚至在剪裁1/3之后,性能有小幅度提升

本次主要对分类模型加入数据增强、数据蒸馏,已经对性能指标进行细化,不只是ACC,个人比较关注F1情况,并作为保存模型依据。

展望: 后续将完善动态图和静态图转化部分,让蒸馏下来模型可以继续线上加载使用;其次将会考虑小样本学习在分类模型应用情况;最后将完成模型融合环节提升性能,并做可解释性分析。

本人博客:https://blog.csdn.net/sinat_39620217?type=blog


此文章为搬运
原项目链接

Logo

学大模型,用大模型上飞桨星河社区!每天8点V100G算力免费领!免费领取ERNIE 4.0 100w Token >>>

更多推荐