0. Preface

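This project comes from the competition "航旅纵横 Domain Knowledge Question Answering Evaluation", specifically its sub-task 2, "航旅纵横 Cup, Task 2: Paragraph-Level Answer Extraction".
It uses RocketQA to implement paragraph-level answer extraction.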

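A question answering (QA) system is an advanced form of information retrieval system: it answers questions posed in natural language with accurate, concise natural-language answers. QA systems play an important role in applications such as search engines, intelligent customer service, and intelligent assistants. Retrieval-based QA is an important category of QA system that retrieves the answer to a question from a large body of text. The figure below illustrates how a retrieval-based QA system works.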

[Figure: functional overview of a retrieval-based question answering system]

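Mainstream retrieval-based QA systems usually consist of several serial modules: retrieval (retriever), ranking (reranker), and answer extraction (reader). With the development of large-scale pre-trained models, researchers began to explore QA models based on deep semantic representations. Thanks to the high-quality semantic representation space produced by pre-trained models and to end-to-end training, semantic models can deliver better results while simplifying the cascaded architecture and feature engineering of traditional QA systems, which greatly reduces system complexity.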
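To make the serial retriever, reranker, reader layout concrete, here is a toy sketch in plain Python. It is not RocketQA code: `char_overlap`, `retrieve`, `rerank`, and `answer` are hypothetical names, and the character-overlap score is only a stand-in for the learned semantic models discussed here.

```python
from typing import Callable, List, Tuple


def char_overlap(query: str, passage: str) -> float:
    """Toy relevance score: fraction of distinct query characters found in the passage."""
    q = set(query)
    return len(q & set(passage)) / len(q) if q else 0.0


def retrieve(query: str, corpus: List[str], top_k: int) -> List[str]:
    """Stage 1 (retriever): cheaply narrow the corpus down to top_k candidates."""
    return sorted(corpus, key=lambda p: char_overlap(query, p), reverse=True)[:top_k]


def rerank(query: str, candidates: List[str],
           scorer: Callable[[str, str], float]) -> List[Tuple[str, float]]:
    """Stage 2 (reranker): rescore the candidates, normally with a stronger model."""
    scored = [(p, scorer(query, p)) for p in candidates]
    return sorted(scored, key=lambda item: item[1], reverse=True)


def answer(query: str, corpus: List[str],
           scorer: Callable[[str, str], float] = char_overlap, top_k: int = 3) -> str:
    """Stage 3 (reader stand-in): return the best-ranked passage as the answer."""
    ranked = rerank(query, retrieve(query, corpus, top_k), scorer)
    return ranked[0][0] if ranked else ""


corpus = [
    "check-in opens two hours before departure",
    "pets travel in the hold",
    "baggage allowance is 20 kg",
]
best = answer("when does check-in open", corpus)
```

In a real system each stage would be a separate model. The point of the sketch is only that every stage consumes the output of the previous one, which is the cascade that end-to-end semantic models aim to simplify.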

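Against this background, Baidu NLP proposed RocketQA, a series of semantics-based question answering models. Work in the RocketQA series has been accepted at several top international natural language processing conferences and also plays an important role in Baidu's search business. To make state-of-the-art semantic retrieval and ranking techniques for QA easily accessible to more developers, Baidu NLP and PaddlePaddle (飞桨) jointly released a development toolkit based on RocketQA.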

At present, this project uses RocketQA to implement the first two stages, namely retrieval and ranking.

Quick start: train the model and generate results from the command line:


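cd work  # switch to the working directory

python dataset.py  # generate training data in the required format

python train.py  # train the model

python predict.py  # run prediction and generate the submission file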


I. Installing RocketQA

  1. RocketQA source code: https://github.com/PaddlePaddle/RocketQA (see the repository for detailed examples and implementation details)
!pip install rocketqa

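II. Data preprocessing

RocketQA requires the data to be converted into its expected format.

1. Preparing the dataset

In this project, the extracted dataset and the task 1 result file are already available in the work/data directory.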

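2. Generating the training set in the required format

cd work  # switch to the working directory


python dataset.py  # generate the training set train.tsv

The generated train.tsv is written to the /work/mine directory.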
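As a quick illustration of the file layout, the sketch below writes and re-reads two rows in the four-column form that dataset.py produces: query, title (left empty in this project), paragraph, and a 0/1 label, separated by tabs. The helper names `write_rows` and `read_rows` and the example paragraphs are made up for the illustration.

```python
import csv
import io


def write_rows(handle, rows):
    """Write (query, title, paragraph, label) tuples as tab-separated lines."""
    writer = csv.writer(handle, delimiter="\t")
    for query, title, para, label in rows:
        writer.writerow([query, title, para, label])


def read_rows(handle):
    """Read the rows back, converting the label column to int."""
    reader = csv.reader(handle, delimiter="\t")
    return [(q, t, p, int(label)) for q, t, p, label in reader]


# An in-memory buffer stands in for train.tsv so the example touches no files.
buffer = io.StringIO()
write_rows(buffer, [
    ("机场几点开始值机", "", "example positive paragraph", 1),
    ("机场几点开始值机", "", "example negative paragraph", 0),
])
buffer.seek(0)
rows = read_rows(buffer)
```

Note that `csv.writer` wraps a field in quotes when it contains a tab, a quote character, or a line break. If the consumer of the file splits on raw tabs, it may be safer to strip such characters from the paragraph text before writing.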

Path settings are defined in /work/config.py:

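# -*- coding: utf-8 -*-
# config.py: shared path and device settings
class Config(object):
    def __init__(self):
        self.dataPath = './data/'    # input data (section.xlsx, train.txt, task_1.txt)
        self.minePath = './mine/'    # generated files such as train.tsv
        self.modelPath = './model/'  # model directory
        self.logPath = './logs'      # log directory
        self.use_gpu = 1             # GPU flag (not read by the scripts shown here)
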
# dataset.py: build train.tsv for fine-tuning the cross encoder
import csv
import json
import random

import pandas as pd
from config import Config

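class Dataset(object):
    relations = {}  # candidate key -> paragraph text
    keys = []       # all candidate keys
    total = 0       # rows we attempted to write
    success = 0     # rows actually written
    def __init__(self):
        self.cf = Config()
        self.dataPath = self.cf.dataPath
        # newline='' is what the csv module expects; UTF-8 keeps the Chinese text intact
        f = open(self.cf.minePath+'/train.tsv', 'wt', encoding='utf-8', newline='')
        self.tsv = csv.writer(f, delimiter='\t')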

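    def _detail2dk(self, detail, islist=False):
        # turn a detail path into a comma-joined string, e.g. "h1_0,0,h2_6,text"
        if not islist:
            # the spreadsheet stores the path as a string; normalise it to JSON first
            detail = detail.strip().replace(" ", "").replace("'", '"')
            detail = json.loads(detail)
        dk = [str(i) for i in detail]
        dk = ','.join(dk)
        return dk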

    def _contents(self):
        contents = pd.read_excel(self.dataPath+'/section.xlsx')
        for i in range(len(contents)):
            dk = self._detail2dk(contents.loc[i]['detail'])
            key = contents.loc[i]['content-key']+'-'+dk
            self.keys.append(key)
            self.relations[key] = contents.loc[i]['text']

    def _train(self):
        # each line of train.txt holds one JSON object with a question and its answers
        with open(self.dataPath+'/train.txt', encoding='utf-8') as f:
            for line in f:
                data = line.strip().split('\t')
                data = json.loads(data[0])
                text = data['question']
                # skip missing questions (text != text is True only for NaN)
                if text is None or text != text:
                    continue
                answer = []
                for item in data['answer']:
                    dk = self._detail2dk(item['detail'], True)
                    key = item['content-key']+'-'+dk
                    answer.append(key)
                if len(answer) < 1:
                    continue

                # drop duplicate answer keys before sampling
                answer = list(set(answer))
                self._select(text, answer)

    def _select(self, text, keys):
        # negatives: candidate paragraphs that are not labelled as answers to this question
        diffs = list(set(self.keys).difference(set(keys)))
        nums = 2000  # maximum number of negatives sampled per question
        #nums = len(keys)*2000
        nums = min(nums, len(diffs))
        print(len(diffs), nums, len(keys))
        self.total = self.total + nums + len(keys)
        others = random.sample(diffs, nums)

        # row layout: query, title (left empty), paragraph, label
        for k in keys:
            if k in self.relations:
                content = self.relations[k]
                self.tsv.writerow([text, '', content, 1])  # positive example
                self.success = self.success + 1

        for k in others:
            if k in self.relations:
                content = self.relations[k]
                self.tsv.writerow([text, '', content, 0])  # negative example
                self.success = self.success + 1

    
    def run(self):
        self._contents()
        self._train()
        print(self.total)    # rows attempted
        print(self.success)  # rows written

if __name__ == '__main__':
    dataset = Dataset()
    dataset.run()
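The script identifies each paragraph by a key of the form content-key, a hyphen, then the comma-joined detail path. The sketch below shows that round trip in isolation. It assumes the detail column holds a Python-style list literal such as "['h1_0', 0, 'h2_6', 'text']" and parses it with `ast.literal_eval` instead of the quote replacement used above. The names `detail_to_key` and `key_to_answer` and the shortened content key are hypothetical.

```python
import ast


def detail_to_key(content_key, detail):
    """Build "<content-key>-<part>,<part>,..." from a list or its string form."""
    if isinstance(detail, str):
        # Assumption: the string is a Python list literal, e.g. "['h1_0', 0, 'text']"
        detail = ast.literal_eval(detail.strip())
    return content_key + "-" + ",".join(str(part) for part in detail)


def key_to_answer(key):
    """Invert detail_to_key, restoring integer indices in the detail path."""
    content_key, _, detail_str = key.partition("-")
    detail = [int(p) if p.isdigit() else p for p in detail_str.split(",")]
    return {"content-key": content_key, "detail": detail}


key = detail_to_key("8f4f", "['h1_0', 0, 'h2_6', 'text']")
restored = key_to_answer(key)
```

Splitting only on the first hyphen (via `partition`) keeps the round trip safe even if a detail part were ever to contain a hyphen itself.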




III. Training

RocketQA offers many models; zh_dureader_ce is used for training here.


import rocketqa

# init cross encoder, and set device and batch_size
cross_encoder = rocketqa.load_model(model="zh_dureader_ce", use_cuda=True, device_id=0, batch_size=25)

# fine-tune the cross encoder loaded above ("zh_dureader_ce") for 2 epochs
cross_encoder.train('./mine/train.tsv', 2, 'models', save_steps=10000, learning_rate=1e-5, log_folder='logs')


The model is saved under models; here a checkpoint is written every 10000 steps.

batch_size depends on GPU memory, so adjust it for your own hardware. This example was run on a V100 16G in AI Studio.
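To get a feel for how batch_size, the epoch count, and save_steps interact, here is a small back-of-the-envelope helper. It assumes that one step processes one batch on a single device and that a final partial batch still counts as a step; RocketQA's own bookkeeping may differ. The function name and the row count of 1,000,000 are purely illustrative.

```python
import math


def training_schedule(num_examples, batch_size, epochs, save_steps):
    """Return (steps per epoch, total steps, steps at which a checkpoint is due)."""
    steps_per_epoch = math.ceil(num_examples / batch_size)
    total_steps = steps_per_epoch * epochs
    checkpoints = list(range(save_steps, total_steps + 1, save_steps))
    return steps_per_epoch, total_steps, checkpoints


# Hypothetical train.tsv with 1,000,000 rows and the settings used in this section
steps_per_epoch, total_steps, checkpoints = training_schedule(
    num_examples=1_000_000, batch_size=25, epochs=2, save_steps=10_000
)
```

Under these assumptions, halving batch_size doubles the number of steps, so the same save_steps value produces twice as many checkpoints on disk.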

IV. Prediction and generating the competition submission file

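The prediction script below reads the task 1 results from work/data/task_1.txt and writes the generated submission file to work/mine/task_2.txt.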

Load your own fine-tuned model:

cross_encoder = rocketqa.load_model(model="./ce_models/config.json", use_cuda=True, device_id=0)

You can edit the model's config.json file to suit your needs:

{
    "model_type": "cross_encoder",
    "max_seq_len": 384,
    "model_conf_path": "zh_config.json",
    "model_vocab_path": "zh_vocab.txt",
    "model_checkpoint_path": "step_50000",
    "for_cn": true,
    "share_parameter": 0
}

Set the model_checkpoint_path field to the step checkpoint saved by your own training run.
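If you switch between checkpoints often, the config file can be generated rather than edited by hand. The sketch below writes a config.json with the fields listed above. `write_ce_config` is a hypothetical helper, and it assumes that zh_config.json, zh_vocab.txt, and the checkpoint are placed next to the config file.

```python
import json
import os
import tempfile


def write_ce_config(model_dir, checkpoint_name, max_seq_len=384):
    """Write a cross-encoder config.json pointing at the given checkpoint."""
    config = {
        "model_type": "cross_encoder",
        "max_seq_len": max_seq_len,
        "model_conf_path": "zh_config.json",
        "model_vocab_path": "zh_vocab.txt",
        "model_checkpoint_path": checkpoint_name,
        "for_cn": True,
        "share_parameter": 0,
    }
    path = os.path.join(model_dir, "config.json")
    with open(path, "w", encoding="utf-8") as f:
        json.dump(config, f, ensure_ascii=False, indent=4)
    return path


# A temporary directory stands in for the real model folder in this demo.
demo_dir = tempfile.mkdtemp()
config_path = write_ce_config(demo_dir, "step_50000")
```

The returned path is what you would then pass as the model argument of rocketqa.load_model, in the same way as the config.json paths used in this section.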


import rocketqa
from config import Config
import pandas as pd
import json

def detail2dk(detail, islist = False):
    if islist == False:
        detail = detail.strip().replace(" ","").replace("'",'"')
        detail = json.loads(detail)
    dk = [str(i) for i in detail]
    dk = ','.join(dk)
    return dk



cf = Config()
contents = pd.read_excel(cf.dataPath+'/section.xlsx')
taskf_2 = open('mine/task_2.txt', 'w', encoding='utf-8')
#cross_encoder = rocketqa.load_model(model="zh_dureader_ce",use_cuda=True, device_id=0)
cross_encoder = rocketqa.load_model(model="./mine/config.json", use_cuda=True, device_id=0)

# read the task 1 results: one JSON object per line with a question and its candidate content-keys
f_task1 = open(cf.dataPath+'/task_1.txt', encoding='utf-8')
while True:
    line = f_task1.readline()
    if not line:
        break
    data = line.strip().split('\t')
    data = json.loads(data[0])
    relations = {}  # candidate key -> paragraph text, in insertion order
    keys = []
    text = data['question']
    # collect every paragraph of each document that task 1 returned for this question
    for key_ in data['answer']:
        contents_ = contents.loc[contents['content-key'] == key_]
        # "row" avoids shadowing the outer loop variable "line"
        for i, row in contents_.iterrows():
            dk = detail2dk(row['detail'])
            key = row['content-key']+'-'+dk
            keys.append(key)
            relations[key] = row['text']

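    results_2 = []
    # build parallel lists from relations so that keys, paragraphs and scores stay aligned
    # (the separate keys list can contain duplicates, which would shift the indices)
    cand_keys = list(relations.keys())
    query_list = [text] * len(cand_keys)
    para_list = [relations[k] for k in cand_keys]

    # the cross encoder yields one relevance score per (query, paragraph) pair
    ratios = list(cross_encoder.matching(query=query_list, para=para_list)) if para_list else []
    # keep every paragraph scoring at least as high as the 5th best (ties included)
    rankTop = sorted(ratios, reverse=True)[0:5]
    threshold = rankTop[-1] if rankTop else None

    for ik in range(len(ratios)):
        if ratios[ik] >= threshold:
            bks = cand_keys[ik].split('-', 1)
            dks = bks[1].split(',')
            # restore integer indices in the detail path
            dks_2 = [int(d) if d.isdigit() else d for d in dks]
            results_2.append({"content-key": bks[0], "detail": dks_2})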

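    re_2 = {"question":text,'answer':results_2}
    print(re_2)
    taskf_2.write(json.dumps(re_2,ensure_ascii=False))
    taskf_2.write('\n')

# close both files so the submission is fully flushed to disk
f_task1.close()
taskf_2.close()

Example output printed by the script (one line per question):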



{'question': '机场几点开始值机', 'answer': [{'content-key': '8f4f560a6087c88394afc18db6b04448', 'detail': ['h1_0', 0, 'h2_6', 'text']}, {'content-key': '385217662e0fb9b27f7a9ba40fa14ae2', 'detail': ['h1_0', 0, 'h2_2', 'text']}, {'content-key': '385217662e0fb9b27f7a9ba40fa14ae2', 'detail': ['h1_0', 0, 'h2_11', 'title']}, {'content-key': '385217662e0fb9b27f7a9ba40fa14ae2', 'detail': ['h1_0', 0, 'h2_11', 'text']}, {'content-key': '0d903b873c849cf00d654414b7db5a9e', 'detail': ['h1_0', 0, 'h2_0', 'text']}]}
{'question': '取机票的时候可以选位置吗', 'answer': [{'content-key': '8f4f560a6087c88394afc18db6b04448', 'detail': ['h1_0', 0, 'h2_7', 'text']}, {'content-key': '8b1efca8af2eff91b012e3346128da24', 'detail': ['h1_0', 0, 'h2_9', 'text']}, {'content-key': '8b1efca8af2eff91b012e3346128da24', 'detail': ['h1_0', 0, 'h2_14', 'text']}, {'content-key': '8b1efca8af2eff91b012e3346128da24', 'detail': ['h1_0', 0, 'h2_25', 'text']}, {'content-key': '8b1efca8af2eff91b012e3346128da24', 'detail': ['h1_0', 0, 'h2_29', 'text']}, {'content-key': '8b1efca8af2eff91b012e3346128da24', 'detail': ['h1_0', 0, 'h2_30', 'text']}, {'content-key': '8b1efca8af2eff91b012e3346128da24', 'detail': ['h1_0', 0, 'h2_38', 'text']}]}
{'question': '航空行李须知', 'answer': [{'content-key': '4c20b81a35d0e2668ff5425c386b04be', 'detail': ['h2_1', 0, 'h3_4', 0, 'h4_8', 0, 'h5_8', 0, 'texts', 8, 'text']}, {'content-key': '014376012b92ea61fbf175090fa6a2d9', 'detail': ['h3_0', 0, 'texts', 0, 'table']}, {'content-key': '9b7ace40c5a6f354f18239b7ec1c5a3e', 'detail': ['h3_0', 0, 'texts', 0, 'table']}, {'content-key': '9b7ace40c5a6f354f18239b7ec1c5a3e', 'detail': ['h3_1', 0, 'h4_0', 0, 'texts', 0, 'table']}, {'content-key': '9b7ace40c5a6f354f18239b7ec1c5a3e', 'detail': ['h3_1', 0, 'h4_1', 0, 'texts', 0, 'table']}]}
{'question': '航空托运宠物', 'answer': [{'content-key': '2c2ccfd5d0c6c22156bd31a0a519a529', 'detail': ['h1_0', 0, 'texts', 10, 'text']}, {'content-key': '2c2ccfd5d0c6c22156bd31a0a519a529', 'detail': ['h1_0', 0, 'texts', 21, 'text']}, {'content-key': '2c2ccfd5d0c6c22156bd31a0a519a529', 'detail': ['h1_0', 0, 'texts', 28, 'text']}, {'content-key': '2c2ccfd5d0c6c22156bd31a0a519a529', 'detail': ['h1_0', 0, 'texts', 36, 'text']}, {'content-key': '2c2ccfd5d0c6c22156bd31a0a519a529', 'detail': ['h1_0', 0, 'texts', 48, 'table']}]}
{'question': '面部受伤的旅客还可以乘机吗?', 'answer': [{'content-key': '4e639c45641fc6c204bfc3fa0631bac1', 'detail': ['h1_0', 0, 'title']}, {'content-key': '4e639c45641fc6c204bfc3fa0631bac1', 'detail': ['h1_0', 0, 'texts', 3, 'text']}, {'content-key': '360bd038cf4dc139a858f8118ff1e6cf', 'detail': ['h1_0', 0, 'h2_3', 'text']}, {'content-key': 'e68660af5d233eaa37ce2f2241f6b042', 'detail': ['h1_0', 0, 'h2_2', 'text']}, {'content-key': '9af70e36b48c6dcb513f55172170e385', 'detail': ['h1_0', 0, 'h2_1', 'text']}]}
{'question': '航站楼之间怎么走', 'answer': [{'content-key': '360bd038cf4dc139a858f8118ff1e6cf', 'detail': ['h1_0', 0, 'h2_28', 'text']}, {'content-key': 'e68660af5d233eaa37ce2f2241f6b042', 'detail': ['h1_0', 0, 'h2_30', 'text']}, {'content-key': 'c687e21856d0fb37c63724598155f6ee', 'detail': ['h1_0', 0, 'h2_0', 'title']}, {'content-key': 'c687e21856d0fb37c63724598155f6ee', 'detail': ['h1_0', 0, 'h2_5', 'title']}, {'content-key': 'c687e21856d0fb37c63724598155f6ee', 'detail': ['h1_0', 0, 'h2_25', 'title']}]}
{'question': '飞机提前多久安检', 'answer': [{'content-key': '360bd038cf4dc139a858f8118ff1e6cf', 'detail': ['h1_0', 0, 'h2_5', 'text']}, {'content-key': '360bd038cf4dc139a858f8118ff1e6cf', 'detail': ['h1_0', 0, 'h2_9', 'text']}, {'content-key': 'e68660af5d233eaa37ce2f2241f6b042', 'detail': ['h1_0', 0, 'h2_4', 'text']}, {'content-key': '52dc8edfc394977c9d9c0c1952d6850a', 'detail': ['h1_0', 0, 'h2_8', 'title']}, {'content-key': '52dc8edfc394977c9d9c0c1952d6850a', 'detail': ['h1_0', 0, 'h2_8', 'text']}]}
{'question': '飞机最晚什么时候值机', 'answer': [{'content-key': '8f4f560a6087c88394afc18db6b04448', 'detail': ['h1_0', 0, 'h2_6', 'text']}, {'content-key': '385217662e0fb9b27f7a9ba40fa14ae2', 'detail': ['h1_0', 0, 'h2_2', 'text']}, {'content-key': '385217662e0fb9b27f7a9ba40fa14ae2', 'detail': ['h1_0', 0, 'h2_11', 'title']}, {'content-key': '385217662e0fb9b27f7a9ba40fa14ae2', 'detail': ['h1_0', 0, 'h2_11', 'text']}, {'content-key': '0d903b873c849cf00d654414b7db5a9e', 'detail': ['h1_0', 0, 'h2_0', 'text']}]}
{'question': '显示无餐食', 'answer': [{'content-key': 'a70cb3e5e5e02ce3f5eca9af1a5d19e6', 'detail': ['h1_0', 0, 'texts', 10, 'text']}, {'content-key': '072d85b2c0f07c7b22073fdca0542de5', 'detail': ['h1_0', 0, 'h2_18', 'title']}, {'content-key': '072d85b2c0f07c7b22073fdca0542de5', 'detail': ['h1_0', 0, 'h2_19', 'title']}, {'content-key': '072d85b2c0f07c7b22073fdca0542de5', 'detail': ['h1_0', 0, 'h2_20', 'title']}, {'content-key': '072d85b2c0f07c7b22073fdca0542de5', 'detail': ['h1_0', 0, 'h2_20', 'text']}]}
{'question': '我的行李额是0', 'answer': [{'content-key': '014376012b92ea61fbf175090fa6a2d9', 'detail': ['h3_0', 0, 'texts', 0, 'table']}, {'content-key': '9b7ace40c5a6f354f18239b7ec1c5a3e', 'detail': ['h3_0', 0, 'texts', 0, 'table']}, {'content-key': '9b7ace40c5a6f354f18239b7ec1c5a3e', 'detail': ['h3_1', 0, 'texts', 0, 'text']}, {'content-key': '9b7ace40c5a6f354f18239b7ec1c5a3e', 'detail': ['h3_1', 0, 'h4_0', 0, 'texts', 0, 'table']}, {'content-key': '9b7ace40c5a6f354f18239b7ec1c5a3e', 'detail': ['h3_1', 0, 'h4_1', 0, 'texts', 0, 'table']}]}
{'question': '我要去英国上学,我从昆明出发到香港转机,我托运的行李可以带多少', 'answer': [{'content-key': '014376012b92ea61fbf175090fa6a2d9', 'detail': ['h3_0', 0, 'texts', 0, 'table']}, {'content-key': '9b7ace40c5a6f354f18239b7ec1c5a3e', 'detail': ['h3_0', 0, 'texts', 0, 'table']}, {'content-key': '9b7ace40c5a6f354f18239b7ec1c5a3e', 'detail': ['h3_1', 0, 'texts', 0, 'text']}, {'content-key': '9b7ace40c5a6f354f18239b7ec1c5a3e', 'detail': ['h3_1', 0, 'h4_0', 0, 'texts', 0, 'table']}, {'content-key': '9b7ace40c5a6f354f18239b7ec1c5a3e', 'detail': ['h3_1', 0, 'h4_1', 0, 'texts', 0, 'table']}]}
{'question': '老人身份证丢了用户口本可以坐飞机吗', 'answer': [{'content-key': '221e26f97b81cfa347fe3275356f150a', 'detail': ['h1_0', 0, 'texts', 2, 'text']}, {'content-key': '221e26f97b81cfa347fe3275356f150a', 'detail': ['h1_0', 0, 'texts', 3, 'text']}, {'content-key': '221e26f97b81cfa347fe3275356f150a', 'detail': ['h1_0', 0, 'texts', 4, 'text']}, {'content-key': '221e26f97b81cfa347fe3275356f150a', 'detail': ['h1_0', 0, 'texts', 5, 'text']}, {'content-key': '221e26f97b81cfa347fe3275356f150a', 'detail': ['h1_0', 0, 'texts', 12, 'text']}]}
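The selection step of the prediction script keeps every paragraph whose score is at least as high as the fifth-best score, so tied scores can yield more than five answers for a question. The standalone sketch below isolates that rule; `top_k_with_ties` is a hypothetical name and the scores are invented.

```python
def top_k_with_ties(scores, k=5):
    """Return indices of all scores >= the k-th highest score, in input order."""
    if not scores or k <= 0:
        return []
    # The k-th best score (or the lowest one if there are fewer than k scores)
    threshold = sorted(scores, reverse=True)[:k][-1]
    return [i for i, s in enumerate(scores) if s >= threshold]


scores = [0.9, 0.1, 0.5, 0.5, 0.3]
selected = top_k_with_ties(scores, k=2)
```

With k=2 the call above selects three indices, because the two paragraphs scoring 0.5 tie for second place. If the submission format requires at most five answers per question, truncating the sorted result to k entries would be a stricter alternative.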

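V. Summary

With data preprocessing, a pre-trained RocketQA model, training, and inference, this project generates a prediction file that can be submitted to the competition with little effort.

Directions for improvement:

1. Enlarge the training set by adjusting the nums value in dataset.py

2. Switch to one of the other RocketQA models

3. Tune batch_size, the learning rate, max_seq_len, and similar settings

I hope this is helpful.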


Reposted from: https://aistudio.baidu.com/aistudio/projectdetail/4388971
