基于检索的多层次文本分类

1、项目说明

以前的分类任务中,标签信息作为无实际意义,独立存在的one-hot编码形式存在,这种做法会潜在的丢失标签的语义信息,本方案把文本分类任务中的标签信息转换成含有语义信息的语义向量,将文本分类任务转换成向量检索和匹配的任务。这样做的好处是对于一些类别标签不是很固定的场景,或者需要经常有一些新增类别的需求的情况非常合适。另外,对于一些新的相关的分类任务,这种方法也不需要模型重新学习或者设计一种新的模型结构来适应新的任务。总的来说,这种基于检索的文本分类方法能够有很好的拓展性,能够利用标签里面包含的语义信息,不需要重新进行学习。这种方法可以应用到相似标签推荐,文本标签标注,金融风险事件分类,政务信访分类等领域。

本方案是基于语义索引模型的分类,语义索引模型的目标是:给定输入文本,模型可以从海量候选召回库中快速、准确地召回一批语义相关文本。基于语义索引的分类方法有两种,第一种方法是直接把标签变成召回库,即把输入文本和标签的文本进行匹配,第二种是利用召回的文本带有类别标签,把召回文本的类别标签作为给定输入文本的类别。本方案使用双塔模型,训练阶段引入In-batch Negatives 策略,使用hnswlib建立索引库,并把标签作为召回库,进行召回测试。最后利用召回的结果使用 Accuracy 指标来评估语义索引模型的分类的效果。下面用一张图来展示与传统的微调方案的区别,在预测阶段,微调的方式则是用分类器分类得到的结果,而基于检索的方式是通过比较文本和标签的相似度得到的分类结果。

在这里插入图片描述

本项目源代码全部开源在 PaddleNLP。

详细内容可参考文本分类系统方案:https://github.com/PaddlePaddle/PaddleNLP/tree/develop/applications/text_classification


1.1 应用特色

  • 低门槛

    • 手把手搭建层次化文本分类
  • 效果好

    • 业界领先的文本分类模型: ERNIE 3.0
    • 业界领先的检索预训练模型: RocketQA Dual Encoder
    • 针对数据量少,类别体系复杂,类别数目多,并且经常需要变动的场景领先解决方案
  • 性能快

    • 基于 Paddle Inference 快速抽取向量
    • 基于 Milvus 快速查询和高性能建库
    • 基于 Paddle Serving 高性能部署

1.2 文本分类流程设计

对于相对均衡的样本,选择分类的方式,极端不均衡和类别变化频繁的样本,使用向量索引的方式分类。基于检索的搜索方案的流程图如图所示,对于向量索引的模型,我们选用基于ERNIE3.0训练的RocketQA模型,然后把文本和标签构成句子对训练语义向量模型,训练结束后得到语义向量抽取模型,然后把标签放入模型抽取标签向量(Label Embedding),最后放入索引向量库中。紧接着对于需要分类的文本,利用语义向量抽取模型抽取句子语义向量(Sentence Embedding),最后利用ANN引擎从标签索引库中找最相似的标签向量,然后得到对应的标签,并把该标签当作最终的分类文本。

2、安装说明

AI Studio平台默认安装了Paddle和PaddleNLP,并定期更新版本。 如需手动更新,可参考如下说明:

# 首次更新完以后,重启后方能生效
!pip install --upgrade paddlenlp
# 安装相关的依赖包
!pip install -r requirements.txt
# 加载系统的API
import abc
import sys
from functools import partial
import argparse
import os
import random
import time
import numpy as np
# 加载飞桨的API
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import inference
# 加载PaddleNLP的API
import paddlenlp as ppnlp
from paddlenlp.data import Stack, Tuple, Pad
from paddlenlp.datasets import load_dataset, MapDataset
from paddlenlp.transformers import LinearDecayWithWarmup
from paddlenlp.utils.downloader import get_path_from_url
import paddle_serving_client.io as serving_io
from paddlenlp.transformers import AutoModelForSequenceClassification, AutoTokenizer

3、数据准备

下载数据集:

if(not os.path.exists('baike_qa_category.zip')):
    get_path_from_url('https://paddlenlp.bj.bcebos.com/applications/baike_qa_category.zip',root_dir='.')

!unzip -o baike_qa_category.zip
Archive:  baike_qa_category.zip
  inflating: data/label.txt          
  inflating: data/.DS_Store          
  inflating: data/dev.txt            
  inflating: data/train.txt          

3.1 加载数据

from data import read_text_pair, convert_example, create_dataloader, gen_id2corpus, gen_text_file, convert_corpus_example
from paddlenlp.transformers import AutoModel, AutoTokenizer
train_set_file = 'data/train.txt'
train_ds = load_dataset(read_text_pair,
                            data_path=train_set_file,
                            lazy=False)
max_seq_length = 384
pretrained_model = AutoModel.from_pretrained(
        'rocketqa-zh-dureader-query-encoder')
tokenizer = AutoTokenizer.from_pretrained(
        'rocketqa-zh-dureader-query-encoder')

trans_func = partial(convert_example,
                         tokenizer=tokenizer,
                         max_seq_length=max_seq_length)
batchify_fn = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype='int64'
            ),  # query_input
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id, dtype='int64'
            ),  # query_segment
        Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype='int64'
            ),  # title_input
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id, dtype='int64'
            ),  # tilte_segment
    ): [data for data in fn(samples)]
# batch_size设置大一点,效果会更好一点
batch_size = 24
train_data_loader = create_dataloader(train_ds,
                                          mode='train',
                                          batch_size=batch_size,
                                          batchify_fn=batchify_fn,
                                          trans_fn=trans_func)
recall_result_dir = "recall_result_dir"
recall_result_file = "recall_result.txt"
evaluate_result = "evaluate_result.txt"
similar_text_pair_file = "data/dev.txt"
corpus_file = 'data/label.txt' 
batchify_fn_dev = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype='int64'
            ),  # text_input
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id, dtype='int64'
            ),  # text_segment
    ): [data for data in fn(samples)]

eval_func = partial(convert_example,
                            tokenizer=tokenizer,
                            max_seq_length=max_seq_length)
id2corpus = gen_id2corpus(corpus_file)
# conver_example function's input must be dict
corpus_list = [{idx: text} for idx, text in id2corpus.items()]
corpus_ds = MapDataset(corpus_list)
corpus_data_loader = create_dataloader(corpus_ds,
                                               mode='predict',
                                               batch_size=batch_size,
                                               batchify_fn=batchify_fn_dev,
                                               trans_fn=eval_func)
# convert_corpus_example
query_func = partial(convert_example,
                             tokenizer=tokenizer,
                             max_seq_length=max_seq_length)
text_list, _ = gen_text_file(similar_text_pair_file)
query_ds = MapDataset(text_list)
query_data_loader = create_dataloader(query_ds,
                                              mode='predict',
                                              batch_size=batch_size,
                                              batchify_fn=batchify_fn_dev,
                                              trans_fn=query_func)
if not os.path.exists(recall_result_dir):
    os.mkdir(recall_result_dir)
recall_result_file = os.path.join(recall_result_dir,
                                          recall_result_file)
# 打印标签数据,标签文本数据被映射成了ID的形式
for batch in corpus_data_loader:
    print(batch)
    break
[Tensor(shape=[24, 19], dtype=int64, place=Place(gpu_pinned), stop_gradient=True,
       [[1   , 82  , 227 , 17963, 287 , 70  , 30  , 343 , 2275, 2   , 0   , 0   ,
         0   , 0   , 0   , 0   , 0   , 0   , 0   ],
        [1   , 128 , 941 , 17963, 305 , 742 , 30  , 163 , 716 , 94  , 159 , 2   ,
         0   , 0   , 0   , 0   , 0   , 0   , 0   ],
        [1   , 691 , 736 , 30  , 407 , 193 , 188 , 390 , 30  , 1553, 64  , 407 ,
         193 , 2   , 0   , 0   , 0   , 0   , 0   ],
        [1   , 80  , 227 , 17963, 147 , 18  , 30  , 137 , 405 , 18  , 489 , 30  ,
         139 , 405 , 2   , 0   , 0   , 0   , 0   ],
        [1   , 80  , 227 , 17963, 147 , 18  , 30  , 389 , 26  , 80  , 227 , 30  ,
         32  , 159 , 138 , 318 , 401 , 525 , 2   ],
        [1   , 691 , 736 , 30  , 103 , 147 , 30  , 1208, 813 , 103 , 147 , 2   ,
         0   , 0   , 0   , 0   , 0   , 0   , 0   ],
        [1   , 21  , 205 , 30  , 188 , 494 , 17963, 2608, 3255, 30  , 1134, 954 ,
         17963, 661 , 737 , 2   , 0   , 0   , 0   ],
        [1   , 691 , 736 , 30  , 407 , 193 , 188 , 390 , 30  , 274 , 2180, 407 ,
         193 , 2   , 0   , 0   , 0   , 0   , 0   ],
        [1   , 80  , 227 , 17963, 147 , 18  , 30  , 38  , 35  , 18  , 147 , 30  ,
         111 , 38  , 18  , 2   , 0   , 0   , 0   ],
        [1   , 278 , 26  , 17963, 38  , 625 , 30  , 161 , 645 , 2   , 0   , 0   ,
         0   , 0   , 0   , 0   , 0   , 0   , 0   ],
        [1   , 80  , 227 , 17963, 147 , 18  , 30  , 38  , 35  , 18  , 147 , 30  ,
         21  , 122 , 18  , 2   , 0   , 0   , 0   ],
        [1   , 1030, 320 , 30  , 29  , 320 , 30  , 371 , 979 , 2   , 0   , 0   ,
         0   , 0   , 0   , 0   , 0   , 0   , 0   ],
        [1   , 82  , 227 , 17963, 287 , 70  , 30  , 2001, 497 , 2   , 0   , 0   ,
         0   , 0   , 0   , 0   , 0   , 0   , 0   ],
        [1   , 691 , 736 , 30  , 407 , 193 , 188 , 390 , 30  , 821 , 325 , 188 ,
         390 , 2   , 0   , 0   , 0   , 0   , 0   ],
        [1   , 343 , 621 , 30  , 128 , 367 , 343 , 621 , 2   , 0   , 0   , 0   ,
         0   , 0   , 0   , 0   , 0   , 0   , 0   ],
        [1   , 80  , 227 , 17963, 147 , 18  , 30  , 245 , 225 , 212 , 399 , 2   ,
         0   , 0   , 0   , 0   , 0   , 0   , 0   ],
        [1   , 691 , 736 , 30  , 326 , 292 , 111 , 38  , 147 , 2   , 0   , 0   ,
         0   , 0   , 0   , 0   , 0   , 0   , 0   ],
        [1   , 343 , 621 , 30  , 1015, 19  , 343 , 621 , 30  , 241 , 734 , 203 ,
         280 , 2   , 0   , 0   , 0   , 0   , 0   ],
        [1   , 80  , 227 , 17963, 147 , 18  , 30  , 8   , 68  , 18  , 147 , 30  ,
         405 , 545 , 18  , 2   , 0   , 0   , 0   ],
        [1   , 80  , 227 , 17963, 147 , 18  , 30  , 389 , 26  , 80  , 227 , 30  ,
         53  , 112 , 141 , 401 , 525 , 2   , 0   ],
        [1   , 691 , 736 , 30  , 137 , 147 , 30  , 2249, 1601, 137 , 147 , 2   ,
         0   , 0   , 0   , 0   , 0   , 0   , 0   ],
        [1   , 691 , 736 , 30  , 760 , 1411, 147 , 30  , 760 , 1411, 396 , 2   ,
         0   , 0   , 0   , 0   , 0   , 0   , 0   ],
        [1   , 343 , 621 , 30  , 1032, 98  , 343 , 621 , 2   , 0   , 0   , 0   ,
         0   , 0   , 0   , 0   , 0   , 0   , 0   ],
        [1   , 1599, 354 , 30  , 604 , 799 , 17963, 287 , 660 , 2   , 0   , 0   ,
         0   , 0   , 0   , 0   , 0   , 0   , 0   ]]), Tensor(shape=[24, 19], dtype=int64, place=Place(gpu_pinned), stop_gradient=True,
       [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])]
# 分类文本数据,文本数据被映射成了ID的形式
for batch in query_data_loader:
    print(batch)
    break
[Tensor(shape=[24, 384], dtype=int64, place=Place(gpu_pinned), stop_gradient=True,
       [[1   , 647 , 358 , ..., 0   , 0   , 0   ],
        [1   , 75  , 883 , ..., 0   , 0   , 0   ],
        [1   , 2308, 889 , ..., 0   , 0   , 0   ],
        ...,
        [1   , 1228, 1928, ..., 0   , 0   , 0   ],
        [1   , 252 , 560 , ..., 0   , 0   , 0   ],
        [1   , 76  , 1234, ..., 0   , 0   , 0   ]]), Tensor(shape=[24, 384], dtype=int64, place=Place(gpu_pinned), stop_gradient=True,
       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]])]

4、 模型选择

模型Accuracy策略简要说明
ernie-3.0-medium-zh50.580ernie-3.0-medium-zh多分类,5个epoch,对于新增类别需要重新训练
In-batch Negatives + RocketQA49.755Inbatch-negative有监督训练,标签当作召回集,对新增类别不需要重新训练
In-batch Negatives + RocketQA + 投票51.756Inbatch-negative有监督训练,训练集当作召回集,对新增类别,需要至少一条的数据放入召回库中

对于类别体系不固定,且未来有新增类别可能性的情况,我们推荐使用RocketQA的语义检索的方案。

【注意】对于新增的类别,不需要重新训练检索模型,只需要把新增类别的标签抽取成向量,放入语义索引库中,就能够继续使用了。

margin = 0.2
scale = 20
# 设置为0,默认不对语义向量进行降维度
output_emb_size = 0
from model import SemanticIndexBatchNeg
model = SemanticIndexBatchNeg(pretrained_model,
                                  margin=margin,
                                  scale=scale,
                                  output_emb_size=output_emb_size)

model = paddle.DataParallel(model)

5、模型训练

5.1 参数配置

# 正常需要设置50~100 epoch左右,为了演示方便,这里设置成了10
epochs = 5
learning_rate = 5E-5 
warmup_proportion = 0
save_dir ='checkpoints_recall'
weight_decay = 0.0
log_steps= 100
recall_num = 20
num_training_steps = len(train_data_loader) * epochs
lr_scheduler = LinearDecayWithWarmup(learning_rate, num_training_steps,
                                         warmup_proportion)
decay_params = [
        p.name for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    ]
optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler,
        parameters=model.parameters(),
        weight_decay=weight_decay,
        apply_decay_param_fun=lambda x: x in decay_params,
        grad_clip=nn.ClipGradByNorm(clip_norm=1.0))

5.2 模型评估

from data import build_index
# 索引的大小
hnsw_max_elements=1000000
# 控制时间和精度的平衡参数
hnsw_ef=100
hnsw_m=100
output_emb_size = 0
recall_num = 10

def recall(rs, N=10):
    recall_flags = [np.sum(r[0:N]) for r in rs]
    return np.mean(recall_flags)

@paddle.no_grad()
def evaluate(model, corpus_data_loader, query_data_loader, recall_result_file,
             text_list, id2corpus):
    # Load pretrained semantic model
    inner_model = model._layers
    final_index = build_index(corpus_data_loader, inner_model,output_emb_size=output_emb_size,hnsw_max_elements=hnsw_max_elements,
                hnsw_ef=hnsw_ef,
                hnsw_m=hnsw_m)
    query_embedding = inner_model.get_semantic_embedding(query_data_loader)
    with open(recall_result_file, 'w', encoding='utf-8') as f:
        for batch_index, batch_query_embedding in enumerate(query_embedding):
            recalled_idx, cosine_sims = final_index.knn_query(
                batch_query_embedding.numpy(), recall_num)
            batch_size = len(cosine_sims)
            for row_index in range(batch_size):
                text_index = batch_size * batch_index + row_index
                for idx, doc_idx in enumerate(recalled_idx[row_index]):
                    f.write("{}\t{}\t{}\n".format(
                        text_list[text_index]["text"], id2corpus[doc_idx],
                        1.0 - cosine_sims[row_index][idx]))
    text2similar = {}
    with open(similar_text_pair_file, 'r', encoding='utf-8') as f:
        for line in f:
            text_arr = line.rstrip().rsplit("\t")
            text, similar_text = text_arr[0], text_arr[1].replace('##', ',')
            text2similar[text] = similar_text
    rs = []
    with open(recall_result_file, 'r', encoding='utf-8') as f:
        relevance_labels = []
        for index, line in enumerate(f):
            if index % recall_num == 0 and index != 0:
                rs.append(relevance_labels)
                relevance_labels = []
            text_arr = line.rstrip().rsplit("\t")
            text, similar_text, cosine_sim = text_arr
            if text2similar[text] == similar_text:
                relevance_labels.append(1)
            else:
                relevance_labels.append(0)

    recall_N = []
    recall_num_list = [1, 5, 10, 20]
    for topN in recall_num_list:
        R = round(100 * recall(rs, N=topN), 3)
        recall_N.append(str(R))
    evaluate_result_file = os.path.join(recall_result_dir,
                                        evaluate_result)
    result = open(evaluate_result_file, 'a')
    res = []
    timestamp = time.strftime('%Y%m%d-%H%M%S', time.localtime())
    res.append(timestamp)
    for key, val in zip(recall_num_list, recall_N):
        print('recall@{}={}'.format(key, val))
        res.append(str(val))
    result.write('\t'.join(res) + '\n')
    return float(recall_N[0])

5.3 启动训练

save_root_dir='checkpoints'
global_step = 0
best_recall = 0.0
tic_train = time.time()
for epoch in range(1, epochs + 1):
    for step, batch in enumerate(train_data_loader, start=1):
        query_input_ids, query_token_type_ids, title_input_ids, title_token_type_ids = batch
        loss = model(query_input_ids=query_input_ids,
                         title_input_ids=title_input_ids,
                         query_token_type_ids=query_token_type_ids,
                         title_token_type_ids=title_token_type_ids)
        global_step += 1
        if global_step % log_steps == 0:
            print(
                    "global step %d, epoch: %d, batch: %d, loss: %.5f, speed: %.2f step/s"
                    % (global_step, epoch, step, loss, 10 /
                       (time.time() - tic_train)))
            tic_train = time.time()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.clear_grad()

    print("evaluating")
    recall_5 = evaluate(model, corpus_data_loader, query_data_loader,
                                recall_result_file, text_list, id2corpus)
    if recall_5 > best_recall:
        best_recall = recall_5
        save_dir = os.path.join(save_root_dir, "model_best")
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        save_param_path = os.path.join(save_dir, 'model_state.pdparams')
        paddle.save(model.state_dict(), save_param_path)
        tokenizer.save_pretrained(save_dir)
        with open(os.path.join(save_dir, "train_result.txt"),
                          'a',
                          encoding='utf-8') as fp:
            fp.write('epoch=%d, global_step: %d, recall: %s\n' %
                             (epoch, global_step, recall_5))

语义索引模型的收敛速度比较慢,建议设置较大的epoch,性能会提升比较大,但训练时间会比较长一点。

6、模型预测

from data import build_index
# 构建索引
final_index = build_index(corpus_data_loader, model._layers,output_emb_size=output_emb_size,hnsw_max_elements=hnsw_max_elements,
                hnsw_ef=hnsw_ef,
                hnsw_m=hnsw_m)
[2022-09-20 17:21:27,837] [    INFO] - start build index..........
[2022-09-20 17:21:28,296] [    INFO] - Total index number:316
query_embedding = model._layers.get_semantic_embedding(query_data_loader)
list_data=[]
for batch_index, batch_query_embedding in enumerate(query_embedding):
    recalled_idx, cosine_sims = final_index.knn_query(
                batch_query_embedding.numpy(), recall_num)
    batch_size = len(cosine_sims)
    for row_index in range(batch_size):
        text_index = batch_size * batch_index + row_index
        for idx, doc_idx in enumerate(recalled_idx[row_index]):
            list_data.append([text_list[text_index]["text"], id2corpus[doc_idx],
                        1.0 - cosine_sims[row_index][idx]])
# 打印若干预测的值
print(list_data[:20])
[['请问深入骨髓地喜欢一个人怎么办我不能确定对方是不是喜欢我,我却想我不能确定对方是不是喜欢我,我却想分分秒秒跟他在一起,有谁能告诉我如何能想他少一点', '教育/科学,理工学科,心理学', 0.6262896060943604], ['请问深入骨髓地喜欢一个人怎么办我不能确定对方是不是喜欢我,我却想我不能确定对方是不是喜欢我,我却想分分秒秒跟他在一起,有谁能告诉我如何能想他少一点', '烦恼,恋爱', 0.5953264832496643], ['请问深入骨髓地喜欢一个人怎么办我不能确定对方是不是喜欢我,我却想我不能确定对方是不是喜欢我,我却想分分秒秒跟他在一起,有谁能告诉我如何能想他少一点', '烦恼,情感情绪', 0.5355594754219055], ['请问深入骨髓地喜欢一个人怎么办我不能确定对方是不是喜欢我,我却想我不能确定对方是不是喜欢我,我却想分分秒秒跟他在一起,有谁能告诉我如何能想他少一点', '健康,精神心理科', 0.5196096301078796], ['请问深入骨髓地喜欢一个人怎么办我不能确定对方是不是喜欢我,我却想我不能确定对方是不是喜欢我,我却想分分秒秒跟他在一起,有谁能告诉我如何能想他少一点', '烦恼', 0.3823578953742981], ['请问深入骨髓地喜欢一个人怎么办我不能确定对方是不是喜欢我,我却想我不能确定对方是不是喜欢我,我却想分分秒秒跟他在一起,有谁能告诉我如何能想他少一点', '娱乐', 0.328216552734375], ['请问深入骨髓地喜欢一个人怎么办我不能确定对方是不是喜欢我,我却想我不能确定对方是不是喜欢我,我却想分分秒秒跟他在一起,有谁能告诉我如何能想他少一点', '健康,精神心理科,心理科', 0.26789504289627075], ['请问深入骨髓地喜欢一个人怎么办我不能确定对方是不是喜欢我,我却想我不能确定对方是不是喜欢我,我却想分分秒秒跟他在一起,有谁能告诉我如何能想他少一点', '烦恼,交友技巧', 0.2628575563430786], ['请问深入骨髓地喜欢一个人怎么办我不能确定对方是不是喜欢我,我却想我不能确定对方是不是喜欢我,我却想分分秒秒跟他在一起,有谁能告诉我如何能想他少一点', '电脑/网络,互联网', 0.226911723613739], ['请问深入骨髓地喜欢一个人怎么办我不能确定对方是不是喜欢我,我却想我不能确定对方是不是喜欢我,我却想分分秒秒跟他在一起,有谁能告诉我如何能想他少一点', '娱乐,电视', 0.21047455072402954], ['我登陆诛仙2时总说我账号密码错误,但是我打的是正确的,就算不对我?', '游戏,完美游戏,诛仙', 0.7549923658370972], ['我登陆诛仙2时总说我账号密码错误,但是我打的是正确的,就算不对我?', '游戏,金山游戏,封神榜', 0.543825089931488], ['我登陆诛仙2时总说我账号密码错误,但是我打的是正确的,就算不对我?', '游戏,完美游戏,神鬼传奇', 0.530568540096283], ['我登陆诛仙2时总说我账号密码错误,但是我打的是正确的,就算不对我?', '游戏,久游游戏,仙剑ol', 0.516210675239563], ['我登陆诛仙2时总说我账号密码错误,但是我打的是正确的,就算不对我?', '游戏,网络游戏,天龙八部', 0.4695323705673218], ['我登陆诛仙2时总说我账号密码错误,但是我打的是正确的,就算不对我?', '游戏,网络游戏,破天一剑', 0.46719789505004883], ['我登陆诛仙2时总说我账号密码错误,但是我打的是正确的,就算不对我?', '游戏,完美游戏,武林外传', 0.42875951528549194], ['我登陆诛仙2时总说我账号密码错误,但是我打的是正确的,就算不对我?', '游戏,单机游戏,仙剑奇侠传', 0.4053085446357727], ['我登陆诛仙2时总说我账号密码错误,但是我打的是正确的,就算不对我?', '游戏,完美游戏,口袋西游', 0.30933690071105957], ['我登陆诛仙2时总说我账号密码错误,但是我打的是正确的,就算不对我?', '游戏,网页游戏', 0.3086514472961426]]

7、模型部署

模型部署分为3步,第一步是把模型转换成静态图,这样能够提升速度;第二步是搭建一个分类索引引擎,构建标签向量库,可以使用Milvus,FAISS,Elastic Search等;第三步是PaddleServing部署,把模型部署成Serving的形式,可拓展性更好。

7.1 动转静导出

output_path='output'
# 切换成eval模式,关闭dropout
model.eval()
# Convert to static graph with specific input description
model = paddle.jit.to_static(
        model._layers,
        input_spec=[
            paddle.static.InputSpec(
                shape=[None, None], dtype="int64"),  # input_ids
            paddle.static.InputSpec(
                shape=[None, None], dtype="int64")  # segment_ids
        ])
# Save in static graph model.
save_path = os.path.join(output_path, "inference")
paddle.jit.save(model, save_path)

7.2 分类引擎

模型准备结束以后,开始搭建 Milvus 的向量检索引擎,用于文本语义向量的快速检索,本项目使用Milvus开源工具进行向量检索,Milvus 的搭建教程请参考官方教程 Milvus官方安装教程本案例使用的是 Milvus 的1.1.1 CPU版本,建议使用官方的 Docker 安装方式,简单快捷。

如果想使用最新版本的Milvus,可以参考Neural Search的实现,最新的Milvus更易用一点。

7.3 Paddle Serving 部署

dirname="output"
# 模型的路径
model_filename="inference.pdmodel"
# 参数的路径
params_filename="inference.pdiparams" 
# server的保存地址
server_path="serving_server"
# client的保存地址
client_path="serving_client"
# 指定输出的别名
feed_alias_names=None
# 制定输入的别名
fetch_alias_names="output_embedding"
# 设置为True会显示日志
show_proto=False
serving_io.inference_model_to_serving(
        dirname=dirname,
        serving_server=server_path,
        serving_client=client_path,
        model_filename=model_filename,
        params_filename=params_filename,
        show_proto=show_proto,
        feed_alias_names=feed_alias_names,
        fetch_alias_names=fetch_alias_names)
(dict_keys(['query_input_ids', 'title_input_ids']),
 dict_keys(['mean_0.tmp_0']))

搭建结束以后,就可以启动server部署服务,使用client端访问server端就行了。具体细节参考代码:
PaddleNLP文本分类应用

# 删除产生的模型文件
!rm -rf output/
# rm -rf checkpoints/
!rm -rf checkpoints_recall
!rm -rf serving_server/
!rm -rf serving_client/

8、模型优化

对于基于索引的方法除了用更多的数据集,调参数以外,还可以通过丰富标签的语义信息得到进一步的优化,由于标签大多都是关键词,这对于句子级别语义建模的方法不是很好,可以考虑把这些关键字,增加一些描述信息,这样就提供了更多的信息,对文本和标签的向量相似度计算很有帮助。


此文章为搬运
原项目链接

Logo

学大模型,用大模型上飞桨星河社区!每天8点V100G算力免费领!免费领取ERNIE 4.0 100w Token >>>

更多推荐