Reposted from AI Studio
Project link: https://aistudio.baidu.com/aistudio/projectdetail/3168859

Competition Background

A question-answering system has three main components: question understanding, information retrieval, and answer extraction. Question understanding is the first and one of the most critical parts of a QA system, and it has a wide range of applications, such as duplicate comment identification and similar question identification.

Duplicate question detection is a common text mining task with practical applications in many Q&A communities: it enables answer aggregation across duplicate questions, answer recommendation, automatic QA, and more. Because Chinese wording is diverse and flexible, this competition asks participants to build a duplicate question identification algorithm.

https://challenge.xfyun.cn/topic/info?type=chinese-question-similarity

Competition Task

Participants are asked to score the similarity of each pair of questions.

  • Training set: about 5,000 question pairs with labels. If the two questions are the same question, the label is 1; otherwise it is 0.
  • Test set: about 5,000 question pairs; participants must predict the labels.

Evaluation Rules

  1. Data description
    The training set provides question pairs and labels, separated by \t. The test set provides question pairs only, also separated by \t.
e.g. 世界上什么东西最恐怖 世界上最恐怖的东西是什么? 1
Explanation: "世界上什么东西最恐怖" and "世界上最恐怖的东西是什么" ask the same question, so the pair is a duplicate and the label is 1.
  2. Evaluation metric
    The evaluation metric is accuracy; the maximum score is 1. For reference, it is computed as follows:
from sklearn.metrics import accuracy_score
y_pred = [0, 2, 1, 3]
y_true = [0, 1, 2, 3]
accuracy_score(y_true, y_pred)  # -> 0.5 here: 2 of the 4 predictions match
!pip install paddle-ernie > log.log
import numpy as np
import paddle as P

# Import ERNIE and run a quick encoding test
from ernie.tokenizing_ernie import ErnieTokenizer
from ernie.modeling_ernie import ErnieModel

# Try to get pretrained model from server, make sure you have network connection
model = ErnieModel.from_pretrained('ernie-1.0')    
model.eval()
tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')

ids, _ = tokenizer.encode('hello world')
ids = P.to_tensor(np.expand_dims(ids, 0))  # insert extra `batch` dimension
pooled, encoded = model(ids)                 # eager execution
print(pooled.numpy())    
downloading https://ernie-github.cdn.bcebos.com/model-ernie1.0.1.tar.gz: 788478KB [00:13, 58506.84KB/s]                            
W1202 20:53:59.563866   128 device_context.cc:404] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1
W1202 20:53:59.568325   128 device_context.cc:422] device: 0, cuDNN Version: 7.6.


[[-1.         -1.          0.9947966  -0.99986964 -0.7872068  -1.
  ...
  -0.14292398  1.         -0.9975951  -0.9999982  -0.9997338  -0.99937415]]
(pooled output: a 1 × 768 sentence-level vector; the full printout is truncated here)
import sys
import numpy as np
import pandas as pd
from sklearn.metrics import f1_score
import paddle as P

from ernie.tokenizing_ernie import ErnieTokenizer
from ernie.modeling_ernie import ErnieModelForSequenceClassification
# Load the training set
train_df = pd.read_csv('train.csv', sep='\t', names=['question1', 'question2', 'label'])
train_df = train_df.sample(frac=1.0)  # shuffle the rows
train_df.head()
      question1             question2                              label
2570  关于西游记的问题        关于《西游记》的一些问题                   1
4635  星天牛吃什么呀?        天牛是什么                              0
1458  亲稍等我帮您查询下      亲请您稍等一下这边帮您核实一下可以吗         1
4750  备份怎么下载           云备份怎么下载                           1
2402  最近有什么战争?        最近有什么娱乐新闻                       0
# Encode a sentence pair
tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')
tokenizer.encode('泰囧完整版下载', 'エウテルペ完整版下载')
(array([    1,  1287,  4663,   328,   407,   511,    86,   763,     2,
        17963, 17963, 17963, 17963, 17963,   328,   407,   511,    86,
          763,     2]),
 array([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]))
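encode returns two arrays: the token ids of the joined pair (with [CLS]/[SEP] markers) and the sentence (segment) ids, where 0 marks tokens belonging to the first question and 1 marks tokens belonging to the second.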
# Model hyperparameters
BATCH=16
MAX_SEQLEN=72
LR=5e-5
EPOCH=1

# Sentence-pair classification model
ernie = ErnieModelForSequenceClassification.from_pretrained('ernie-1.0', num_labels=2)
optimizer = P.optimizer.Adam(LR,parameters=ernie.parameters())
tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/ernie/modeling_ernie.py:296: DeprecationWarning: The 'warn' method is deprecated, use 'warning' instead
  log.warn('param:%s not set in pretrained model, skip' % k)
[WARNING] 2021-12-02 21:05:35,518 [modeling_ernie.py:  296]:    param:classifier.weight not set in pretrained model, skip
[WARNING] 2021-12-02 21:05:35,519 [modeling_ernie.py:  296]:    param:classifier.bias not set in pretrained model, skip
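These warnings are expected: the classification head (classifier.weight and classifier.bias) is not part of the pretrained checkpoint, so it is randomly initialized and learned during fine-tuning.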
# Encode every sentence pair in a dataframe
def make_data(df):
    data = []
    for _, row in df.iterrows():
        # Encode the question pair, truncate to MAX_SEQLEN, then zero-pad to a fixed length
        text_id, _ = tokenizer.encode(row.question1, row.question2)
        text_id = text_id[:MAX_SEQLEN]
        text_id = np.pad(text_id, [0, MAX_SEQLEN - len(text_id)], mode='constant')
        data.append((text_id, row.label))
    return data

train_data = make_data(train_df.iloc[:-1000])
val_data = make_data(train_df.iloc[-1000:])
# Assemble one batch of data
def get_batch_data(data, i):
    d = data[i*BATCH: (i + 1) * BATCH]
    feature, label = zip(*d)
    feature = np.stack(feature)  # stack BATCH samples into a single numpy array
    label = np.stack(list(label))
    feature = P.to_tensor(feature)  # convert the numpy arrays to paddle tensors with to_tensor
    label = P.to_tensor(label)
    return feature, label
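For example, get_batch_data(train_data, 0) returns the first batch as a feature tensor of shape [BATCH, MAX_SEQLEN] (16 × 72 with the settings above) and a label tensor of shape [BATCH].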
# Model training and validation
for i in range(EPOCH):
    np.random.shuffle(train_data)  # shuffle the data every epoch for better training
    ernie.train()
    for j in range(len(train_data) // BATCH):
        feature, label = get_batch_data(train_data, j)
        loss, _ = ernie(feature, labels=label) 
        loss.backward()
        optimizer.minimize(loss)
        ernie.clear_gradients()
        if j % 50 == 0:
            print('Train %d: loss %.5f' % (j, loss.numpy()))
        
        # Validate every 100 steps
        if j % 100 == 0:
            all_pred, all_label = [], []
            with P.no_grad():
                ernie.eval()
                # Use a separate counter so the outer training-loop index j is not overwritten
                for k in range(len(val_data) // BATCH):
                    feature, label = get_batch_data(val_data, k)
                    loss, logits = ernie(feature, labels=label)

                    all_pred.extend(logits.argmax(-1).numpy())
                    all_label.extend(label.numpy())
                ernie.train()
            acc = (np.array(all_label) == np.array(all_pred)).astype(np.float32).mean()
            print('Val acc %.5f' % acc)
Train 0: loss 0.70721
Val acc 0.54335
Train 50: loss 0.27199
Train 100: loss 0.35329
Val acc 0.88004
Train 150: loss 0.16315
Train 200: loss 0.08876
Val acc 0.89113
test_df = pd.read_csv('test.csv', sep='\t', names=['question1', 'question2'])
test_df['label'] = 0  # placeholder labels so make_data can be reused for the test set

test_data = make_data(test_df.iloc[:])
# Model inference on the test set
all_pred, all_label = [], []
with P.no_grad():
    ernie.eval()
    for j in range(int(np.ceil(len(test_data) / BATCH))):  # ceil so the last partial batch is not dropped
        feature, label = get_batch_data(test_data, j)
        loss, logits = ernie(feature, labels=label)

        all_pred.extend(logits.argmax(-1).numpy())
        all_label.extend(label.numpy())
pd.DataFrame({
    'label': all_pred,
}).to_csv('submit.csv', index=None)

Summary and Outlook

  1. This project uses ERNIE for a sentence-pair (NSP-style) classification task, fine-tuned with a simple matching objective.
  2. Continued pretraining of the ERNIE model on in-domain data could be considered to improve accuracy.
  3. A sentence-BERT style bi-encoder is also worth trying; a minimal sketch follows this list.
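
A minimal sketch of the sentence-BERT idea, reusing the ErnieModel and ErnieTokenizer shown at the top of this project: each question is encoded separately and the pooled vectors are compared with cosine similarity. This only illustrates the scoring side (sentence-BERT additionally fine-tunes the encoder in a siamese setup), and the 0.9 threshold is an arbitrary example, not a tuned value.

import numpy as np
import paddle as P
from ernie.tokenizing_ernie import ErnieTokenizer
from ernie.modeling_ernie import ErnieModel

# Bi-encoder: one shared encoder applied to each question independently
bi_encoder = ErnieModel.from_pretrained('ernie-1.0')
bi_encoder.eval()
bi_tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')

def sentence_vector(text):
    # Encode one question on its own and use the pooled output as its sentence vector
    ids, _ = bi_tokenizer.encode(text)
    ids = P.to_tensor(np.expand_dims(ids, 0))
    pooled, _ = bi_encoder(ids)
    return pooled

with P.no_grad():
    v1 = sentence_vector('世界上什么东西最恐怖')
    v2 = sentence_vector('世界上最恐怖的东西是什么?')
    sim = P.nn.functional.cosine_similarity(v1, v2).numpy()[0]
    pred = int(sim > 0.9)  # illustrative threshold, not tuned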
