Intelligent Text Proofreading Baseline
Includes an ERNIE-CSC character-correction model, a T5 generation model, a NeZha classification model, and more.
Baseline for the preliminary round of the Intelligent Text Proofreading Competition
1. Dataset
# List the data folder
!ls -l /home/aistudio/all_data/preliminary_a_data/
total 368228
-rw-r--r-- 1 aistudio aistudio 214928 Jul 15 15:14 preliminary_a_test_source.json
-rw-r--r-- 1 aistudio aistudio 400324 Jul 15 15:14 preliminary_extend_train.json
-rw-r--r-- 1 aistudio aistudio 376015154 Jul 15 15:14 preliminary_train.json
-rw-r--r-- 1 aistudio aistudio 423209 Jul 15 15:14 preliminary_val.json
-rw-r--r-- 1 aistudio aistudio 1381 Jul 15 15:14 README.md
- preliminary_train: ~1,000,000 pseudo-data samples, all negative
- preliminary_extend_train: ~1,000 real-scenario training samples, all negative
- preliminary_val: ~1,000 real-scenario validation samples (roughly 500 positive and 500 negative)
- preliminary_a_test_source: ~1,000 real-scenario test samples (roughly 500 positive and 500 negative)
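To get a feel for the data format, here is a minimal inspection sketch. It assumes each record carries the source, target, type and id fields that the rest of this notebook relies on; the exact set of fields per file is otherwise an assumption.
import json
# Peek at one validation record (fields assumed: source, target, type, id).
with open('/home/aistudio/all_data/preliminary_a_data/preliminary_val.json') as f:
    val_data = json.load(f)
print(len(val_data))
print(val_data[0])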
2. Approach
We can divide the data into four classes: correct text (no correction needed), misspelled characters, semantic errors, and misspelled characters combined with semantic errors.
As a first approximation, a sample whose correction has the same length as the source can be treated as a misspelling, and one whose length changes as a semantic error; mixed errors are hard to tell apart and are not handled here.
Inspecting the data, however, reveals two caveats to this heuristic:
- same-length pairs can still be semantic errors, caused by transposed word order
- many semantic errors are repeated characters or words
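The heuristic itself is easy to write down. A minimal sketch (the label names match the label_map used by the classifier later in this notebook):
def coarse_error_type(source: str, target: str) -> str:
    """Coarse labeling by the length heuristic described above."""
    if source == target:
        return 'Positive'            # no correction needed
    if len(source) == len(target):
        return 'Misspelled words'    # caveat: may actually be transposed word order
    return 'Semantic Error'          # often repeated characters or words

# coarse_error_type('我爱爱你', '我爱你') -> 'Semantic Error'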
3. Method
- Classify the data.
- For samples identified as misspellings, run character-level correction first; if the text comes back unchanged, fall back to semantic-error correction.
- Semantic-error correction has two parts:
  1. algorithmic removal of repeated content
  2. correction with a semantic-error model
The models trained by the author have been uploaded as the dataset 文本智能校对预训练模型 (pretrained models for intelligent text proofreading).
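The cascade can be summarized in a few lines (a sketch only; ernie_predict, remove_duplication and T5_predict are defined in later sections, and Section 9 contains the full run):
def correct(source: str, error_type: str) -> str:
    if error_type == 'Positive':
        return source                       # nothing to fix
    if error_type == 'Misspelled words':
        fixed = ernie_predict([source])[0]  # character-level correction first
        if fixed != source:
            return fixed
    deduped = remove_duplication(source)    # fall back to semantic correction
    if deduped != source:
        return deduped
    return T5_predict([source])[0]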
4. Training the Classification Model
- Starting from NeZha, ERNIE, or another pretrained model, first train with preliminary_train as the training set and preliminary_extend_train as the validation set; checkpoints from this stage are saved in the *_ckpt folder.
- Then, continuing from that model, train with preliminary_extend_train as the training set and preliminary_val as the validation set; checkpoints are saved in the *_ft_ckpt folder.
- Finally, continuing from the previous model, train with preliminary_val as the training set; checkpoints are saved in the *_el_ckpt folder. The command below runs this final stage; a sketch of the first two stages follows it.
!pip install --upgrade paddlenlp
!python train_classification.py --model_name_or_path nezha_el_ckpt --learning_rate 2e-5 \
    --train_data_path all_data/preliminary_a_data/preliminary_val.json --train_data_is_ground_eval True \
    --eval_data_path all_data/preliminary_a_data/preliminary_val.json --eval_data_is_ground_eval True \
    --max_seq_length 128 --batch_size 32 \
    --model_save_path nezha_el_ckpt --epoch 30 --print_step 10 --eval_step 20
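For reference, the first two stages might be run like this. This is only a sketch: it assumes train_classification.py accepts the same flags as above, and the starting checkpoint name (nezha-base-chinese from PaddleNLP), epoch counts, and step intervals are assumptions, not tested values.
# Stage 1: pseudo data as training set, extend_train as validation set.
!python train_classification.py --model_name_or_path nezha-base-chinese --learning_rate 2e-5 \
    --train_data_path all_data/preliminary_a_data/preliminary_train.json \
    --eval_data_path all_data/preliminary_a_data/preliminary_extend_train.json \
    --max_seq_length 128 --batch_size 32 \
    --model_save_path nezha_ckpt --epoch 3 --print_step 100 --eval_step 1000
# Stage 2: continue from nezha_ckpt, train on preliminary_extend_train,
# validate on preliminary_val, and save to nezha_ft_ckpt with the same flag pattern.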
Predict the type of each test sample
import json
from paddlenlp.transformers import AutoModelForSequenceClassification, AutoTokenizer
from eval import predict
model_name = "nezha_el_ckpt"
num_classes = 3
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_classes=num_classes)
tokenizer = AutoTokenizer.from_pretrained(model_name)
[2022-07-19 00:50:06,729] [ INFO] - We are using <class 'paddlenlp.transformers.nezha.modeling.NeZhaForSequenceClassification'> to load 'nezha_el_ckpt'.
[2022-07-19 00:50:11,175] [ INFO] - We are using <class 'paddlenlp.transformers.nezha.tokenizer.NeZhaTokenizer'> to load 'nezha_el_ckpt'.
result_save = []
with open('all_data/preliminary_a_data/preliminary_a_test_source.json') as f:
    test_raw_data = json.load(f)
test_data = [{'text': data['source']} for data in test_raw_data]
label_map = {0: 'Positive', 1: 'Misspelled words', 2: 'Semantic Error'}
results = predict(model, test_data, tokenizer, label_map, batch_size=32)
for idx, text in enumerate(test_data):
    result_save.append({'source': text['text'], 'type': results[idx], 'id': test_raw_data[idx]['id']})
    print('Data: {} \t Label: {}'.format(text, results[idx]))
print(len(result_save))
with open('preliminary_a_test_source_with_type.json', 'w') as f:
    json.dump(result_save, f)
1019
5. Separating the Two Error Types
import json
import os

def split_by_error_type(src_json, out_name, filter_negative=False):
    """Split a data file into same-length (misspelling) and
    different-length (semantic error) source/target pairs."""
    misspelled_words_data, semantic_error_data = '', ''
    with open(f'all_data/preliminary_a_data/{src_json}') as f:
        all_data = json.load(f)
    for data in all_data:
        # the validation set also contains positive samples; keep only negatives
        if filter_negative and data['type'] != 'negative':
            continue
        line = data['source'] + '\t' + data['target'] + '\n'
        if len(data['source']) == len(data['target']):
            misspelled_words_data += line
        else:
            semantic_error_data += line
    # make sure the output folders exist
    os.makedirs('all_data/preliminary_a_misspelled_words_data', exist_ok=True)
    os.makedirs('all_data/preliminary_a_semantic_error_data', exist_ok=True)
    with open(f'all_data/preliminary_a_misspelled_words_data/{out_name}', 'w') as f:
        f.write(misspelled_words_data)
    with open(f'all_data/preliminary_a_semantic_error_data/{out_name}', 'w') as f:
        f.write(semantic_error_data)

split_by_error_type('preliminary_train.json', 'preliminary_train.txt')
split_by_error_type('preliminary_extend_train.json', 'preliminary_extend_train.txt')
split_by_error_type('preliminary_val.json', 'preliminary_val.txt', filter_negative=True)
6. ERNIE-CSC Model: Correcting Misspelled Characters
Training
%cd ernie_csc
/home/aistudio/ernie_csc
!pip install -r requirements.txt
! python download.py --data_dir ./extra_train_ds/ --url https://github.com/wdimmy/Automatic-Corpus-Generation/raw/master/corpus/train.sgml
100%|█████████████████████████████████████| 22934/22934 [17:17<00:00, 22.10it/s]
! python change_sgml_to_txt.py -i extra_train_ds/train.sgml -o extra_train_ds/train.txt
!python train.py --batch_size 40 --logging_steps 100 --epochs 5 --learning_rate 5e-5 --max_seq_length 128 \
    --model_name_or_path ernie-3.0-xbase-zh \
    --full_model_path ../ernie_csc_pre_ckpt/best_model.pdparams \
    --output_dir ../ernie_csc_ft_ckpt/ --extra_train_ds_dir ./extra_train_ds
Export
!python export_model.py --model_name_or_path ernie-3.0-xbase-zh --params_path ../ernie_csc_pre_ckpt/best_model.pdparams --output_path ../ernie_csc_infer_model/static_graph_params
[2022-07-17 16:34:19,931] [    INFO] - Already cached /home/aistudio/.paddlenlp/models/ernie-3.0-xbase-zh/ernie_3.0_xbase_zh.pdparams
W0717 16:34:19.933907 77787 gpu_context.cc:278] Please NOTE: device: 0, GPU Compute Capability: 8.0, Driver API Version: 11.2, Runtime API Version: 11.2
W0717 16:34:19.937021 77787 gpu_context.cc:306] device: 0, cuDNN Version: 8.2.
[2022-07-17 16:34:26,181] [    INFO] - Weights from pretrained model not used in ErnieModel: ['cls.predictions.layer_norm.bias', 'cls.predictions.transform.weight', 'cls.predictions.decoder_bias', 'cls.predictions.transform.bias', 'cls.predictions.layer_norm.weight']
Prediction
from paddlenlp.transformers import ErnieTokenizer
from predict import Predictor
from paddlenlp.data import Vocab

tokenizer = ErnieTokenizer.from_pretrained('ernie-3.0-xbase-zh')
pinyin_vocab = Vocab.load_vocabulary('./pinyin_vocab.txt',
                                     unk_token='[UNK]',
                                     pad_token='[PAD]')
predictor = Predictor('../ernie_csc_infer_model/static_graph_params.pdmodel',
                      '../ernie_csc_infer_model/static_graph_params.pdiparams',
                      'gpu', 128, tokenizer, pinyin_vocab)
samples = [
    '遇到逆竟时,我们必须勇于面对,而且要愈挫愈勇,这样我们才能朝著成功之路前进。',
    '人生就是如此,经过磨练才能让自己更加拙壮,才能使自己更加乐观。',
]
results = predictor.predict(samples, batch_size=2)
for source, target in zip(samples, results):
    print("Source:", source)
    print("Target:", target)
[2022-07-19 00:04:04,168] [ INFO] - Already cached /home/aistudio/.paddlenlp/models/ernie-3.0-xbase-zh/ernie_3.0_xbase_zh_vocab.txt
[2022-07-19 00:04:04,195] [ INFO] - tokenizer config file saved in /home/aistudio/.paddlenlp/models/ernie-3.0-xbase-zh/tokenizer_config.json
[2022-07-19 00:04:04,197] [ INFO] - Special tokens file saved in /home/aistudio/.paddlenlp/models/ernie-3.0-xbase-zh/special_tokens_map.json
Source: 遇到逆竟时,我们必须勇于面对,而且要愈挫愈勇,这样我们才能朝著成功之路前进。
Target: 遇到逆境时,我们必须勇于面对,而且要愈挫愈勇,这样我们才能朝著成功之路前进。
Source: 人生就是如此,经过磨练才能让自己更加拙壮,才能使自己更加乐观。
Target: 人生就是如此,经过磨练才能让自己更加茁壮,才能使自己更加乐观。
7. T5 Model: Correcting Grammatical Errors
Training
%cd t5
/home/aistudio/t5
!python train.py --batch_size 32 --logging_steps 100 --epochs 5 --learning_rate 5e-5 --max_seq_length 128 \
    --model_name_or_path Langboat/mengzi-t5-base \
    --output_dir ../t5_ft_ckpt/ --train_ds_dir ../all_data/preliminary_a_semantic_error_data --eval_ds_dir ../all_data/preliminary_a_semantic_error_data/eval
Prediction
import paddle
from paddlenlp.transformers import T5ForConditionalGeneration, T5Tokenizer
import json
tokenizer = T5Tokenizer.from_pretrained('Langboat/mengzi-t5-base')
model = T5ForConditionalGeneration.from_pretrained('../t5_ft_ckpt/model_38000_best')
[2022-07-18 20:03:57,432] [ INFO] - Already cached /home/aistudio/.paddlenlp/models/Langboat/mengzi-t5-base/spiece.model
[2022-07-18 20:03:57,434] [ INFO] - Downloading https://bj.bcebos.com/paddlenlp/models/community/Langboat/mengzi-t5-base/added_tokens.json and saved to /home/aistudio/.paddlenlp/models/Langboat/mengzi-t5-base
[2022-07-18 20:03:57,436] [ INFO] - Downloading added_tokens.json from https://bj.bcebos.com/paddlenlp/models/community/Langboat/mengzi-t5-base/added_tokens.json
[2022-07-18 20:03:57,471] [ INFO] - Downloading https://bj.bcebos.com/paddlenlp/models/community/Langboat/mengzi-t5-base/special_tokens_map.json and saved to /home/aistudio/.paddlenlp/models/Langboat/mengzi-t5-base
[2022-07-18 20:03:57,474] [ INFO] - Downloading special_tokens_map.json from https://bj.bcebos.com/paddlenlp/models/community/Langboat/mengzi-t5-base/special_tokens_map.json
[2022-07-18 20:03:57,517] [ INFO] - Already cached /home/aistudio/.paddlenlp/models/Langboat/mengzi-t5-base/tokenizer_config.json
W0718 20:03:57.622402 39570 gpu_context.cc:278] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 10.1
W0718 20:03:57.625485 39570 gpu_context.cc:306] device: 0, cuDNN Version: 7.6.
res_t5_pre = []
with open('../all_data/preliminary_a_data/preliminary_a_test_source.json') as f:
    all_data = json.load(f)
for data in all_data:
    text = data['source']
    inputs = tokenizer(text)
    inputs = {k: paddle.to_tensor([v]) for (k, v) in inputs.items()}
    output = model.generate(**inputs)
    gen_text = tokenizer.decode(list(output[0].numpy()[0]), skip_special_tokens=True)
    # the tokenizer emits half-width commas; convert back to full-width
    gen_text = gen_text.replace(',', ',')
    # Generation may stop early and miss the tail of the sentence: locate the
    # last 3 generated characters in the source and, if found, append the rest.
    gen_text = gen_text if text.rfind(gen_text[-3:]) == -1 else gen_text + text[text.rfind(gen_text[-3:]) + 3:]
    res_t5_pre.append({'inference': gen_text, 'id': data['id']})
print(len(res_t5_pre))
1019
with open('../preliminary_a_test_inference.json', 'w') as f:
json.dump(res_t5_pre, f, ensure_ascii=False)
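To make the tail-stitching step concrete, here is a small illustration with hypothetical strings:
# Suppose the model stopped generating early (hypothetical strings):
source = '天气很好,我们去公园散步,然后回家吃饭。'
gen_text = '天气很好,我们去公园散步,'
anchor = gen_text[-3:]            # last 3 characters of the generated text
pos = source.rfind(anchor)        # locate them in the source
if pos != -1:
    gen_text = gen_text + source[pos + 3:]  # append the missing source tail
print(gen_text)  # 天气很好,我们去公园散步,然后回家吃饭。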
8. De-duplication
Efficiency was not a concern here, so a relatively slow nested-loop approach is used.
def remove_duplication(text: str):
    """Remove the first repeated span found in text (one removal per call)."""
    length = len(text)
    for i in range(length):
        # longest candidate span ending at i that could repeat right after it
        cp_range = min(i + 1, length - i - 1)
        j = 0
        flag = False
        for j in range(cp_range):
            # does text[i-j:i+1] repeat immediately after position i?
            if text[i - j:i + 1] == text[i + 1:i + 1 + j + 1]:
                flag = True
                break
        if flag:
            return text.replace(text[i - j:i + 1], '', 1)
    return text
remove_duplication('我爱爱你')
'我爱你'
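Note that the function removes only the first duplicated span it finds and then returns. If a sentence contains several duplications, one possibility is to iterate to a fixed point (a sketch, not part of the original pipeline):
def remove_all_duplications(text: str) -> str:
    # apply remove_duplication repeatedly until the text stops changing
    while True:
        shorter = remove_duplication(text)
        if shorter == text:
            return text
        text = shorter

# remove_all_duplications('谢谢谢你你') -> '谢你'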
9. Putting It All Together
First make sure the working directory is /home/aistudio, then run the classification model to label the test data.
import json
from paddlenlp.transformers import AutoModelForSequenceClassification, AutoTokenizer
from eval import predict
model_name = "nezha_el_ckpt"
num_classes = 3
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_classes=num_classes)
tokenizer = AutoTokenizer.from_pretrained(model_name)
result_save = []
with open('all_data/preliminary_a_data/preliminary_a_test_source.json') as f:
    test_raw_data = json.load(f)
test_data = [{'text': data['source']} for data in test_raw_data]
label_map = {0: 'Positive', 1: 'Misspelled words', 2: 'Semantic Error'}
results = predict(model, test_data, tokenizer, label_map, batch_size=32)
for idx, text in enumerate(test_data):
    result_save.append({'source': text['text'], 'type': results[idx], 'id': test_raw_data[idx]['id']})
    print('Data: {} \t Label: {}'.format(text, results[idx]))
print(len(result_save))
with open('all_data/preliminary_a_test_source_with_type.json', 'w') as f:
    json.dump(result_save, f)
Next, define the ERNIE-CSC typo-correction function.
%cd ernie_csc/
from paddlenlp.transformers import ErnieTokenizer
from predict import Predictor
from paddlenlp.data import Vocab
%cd ../

ernie_tokenizer = ErnieTokenizer.from_pretrained('ernie-3.0-xbase-zh')
pinyin_vocab = Vocab.load_vocabulary('ernie_csc/pinyin_vocab.txt',
                                     unk_token='[UNK]',
                                     pad_token='[PAD]')
ernie_predictor = Predictor('ernie_csc_infer_model/static_graph_params.pdmodel',
                            'ernie_csc_infer_model/static_graph_params.pdiparams',
                            'gpu', 128, ernie_tokenizer, pinyin_vocab)

def ernie_predict(samples: list, batch_size=2) -> list:
    results = ernie_predictor.predict(samples, batch_size=batch_size)
    return results

# Usage example
# samples = [
#     '遇到逆竟时,我们必须勇于面对,而且要愈挫愈勇,这样我们才能朝著成功之路前进。',
#     '人生就是如此,经过磨练才能让自己更加拙壮,才能使自己更加乐观。',
# ]
# results = ernie_predict(samples)
# for source, target in zip(samples, results):
#     print("Source:", source)
#     print("Target:", target)
/home/aistudio/ernie_csc
/home/aistudio
[2022-07-19 12:25:40,615] [ INFO] - Already cached /home/aistudio/.paddlenlp/models/ernie-3.0-xbase-zh/ernie_3.0_xbase_zh_vocab.txt
[2022-07-19 12:25:40,644] [ INFO] - tokenizer config file saved in /home/aistudio/.paddlenlp/models/ernie-3.0-xbase-zh/tokenizer_config.json
[2022-07-19 12:25:40,647] [ INFO] - Special tokens file saved in /home/aistudio/.paddlenlp/models/ernie-3.0-xbase-zh/special_tokens_map.json
Then define the T5 correction function.
import paddle
from paddlenlp.transformers import T5ForConditionalGeneration, T5Tokenizer

T5_tokenizer = T5Tokenizer.from_pretrained('Langboat/mengzi-t5-base')
T5_model = T5ForConditionalGeneration.from_pretrained('./t5_ft_ckpt/model_38000_best')

def T5_predict(samples: list) -> list:
    res_t5_pre = []
    for text in samples:
        inputs = T5_tokenizer(text)
        inputs = {k: paddle.to_tensor([v]) for (k, v) in inputs.items()}
        output = T5_model.generate(**inputs)
        gen_text = T5_tokenizer.decode(list(output[0].numpy()[0]), skip_special_tokens=True)
        # the tokenizer emits half-width punctuation; convert back to full-width
        gen_text = gen_text.replace(',', ',')
        gen_text = gen_text.replace('?', '?')
        gen_text = gen_text.replace(';', ';')
        gen_text = gen_text.replace('!', '!')
        # Generation may stop early and miss the tail of the sentence: locate
        # the last 3 generated characters in the source and, if found, append
        # the remaining source tail.
        gen_text = gen_text if text.rfind(gen_text[-3:]) == -1 else gen_text + text[text.rfind(gen_text[-3:]) + 3:]
        res_t5_pre.append(gen_text)
    return res_t5_pre
# Usage example
# samples = [
#     '遇到逆竟时,我们必须勇于面对,而且要愈挫愈勇,这样我们才能朝著成功之路前进。',
#     '人生就是如此,经过磨练才能让自己更加拙壮,才能使自己更加乐观。',
# ]
# results = T5_predict(samples)
# for source, target in zip(samples, results):
#     print("Source:", source)
#     print("Target:", target)
[2022-07-19 12:26:04,885] [ INFO] - Already cached /home/aistudio/.paddlenlp/models/Langboat/mengzi-t5-base/spiece.model
[2022-07-19 12:26:04,888] [ INFO] - Downloading https://bj.bcebos.com/paddlenlp/models/community/Langboat/mengzi-t5-base/added_tokens.json and saved to /home/aistudio/.paddlenlp/models/Langboat/mengzi-t5-base
[2022-07-19 12:26:04,890] [ INFO] - Downloading added_tokens.json from https://bj.bcebos.com/paddlenlp/models/community/Langboat/mengzi-t5-base/added_tokens.json
[2022-07-19 12:26:04,954] [ INFO] - Downloading https://bj.bcebos.com/paddlenlp/models/community/Langboat/mengzi-t5-base/special_tokens_map.json and saved to /home/aistudio/.paddlenlp/models/Langboat/mengzi-t5-base
[2022-07-19 12:26:04,957] [ INFO] - Downloading special_tokens_map.json from https://bj.bcebos.com/paddlenlp/models/community/Langboat/mengzi-t5-base/special_tokens_map.json
[2022-07-19 12:26:05,019] [ INFO] - Already cached /home/aistudio/.paddlenlp/models/Langboat/mengzi-t5-base/tokenizer_config.json
Define the de-duplication function (same as in Section 8).
def remove_duplication(text: str):
    length = len(text)
    for i in range(length):
        cp_range = min(i + 1, length - i - 1)
        j = 0
        flag = False
        for j in range(cp_range):
            if text[i - j:i + 1] == text[i + 1:i + 1 + j + 1]:
                flag = True
                break
        if flag:
            return text.replace(text[i - j:i + 1], '', 1)
    return text
Predict the results
import json

res = []
with open('all_data/preliminary_a_test_source_with_type.json') as f:
    raw_data = json.load(f)
# For simplicity, calls to the same model are not batched into a single list.
for data in raw_data:
    if data['type'] == 'Positive':
        res.append({'inference': data['source'], 'id': data['id']})
    elif data['type'] == 'Misspelled words':
        csc_res = ernie_predict([data['source']])[0]
        if csc_res == data['source']:
            # CSC changed nothing: fall back to de-duplication, then T5
            remove_res = remove_duplication(csc_res)
            if remove_res == csc_res:
                t5_res = T5_predict([csc_res])[0]
                res.append({'inference': t5_res, 'id': data['id']})
            else:
                res.append({'inference': remove_res, 'id': data['id']})
        else:
            res.append({'inference': csc_res, 'id': data['id']})
    else:  # Semantic Error: de-duplicate first, then T5 if nothing changed
        remove_res = remove_duplication(data['source'])
        if remove_res == data['source']:
            t5_res = T5_predict([data['source']])[0]
            res.append({'inference': t5_res, 'id': data['id']})
        else:
            res.append({'inference': remove_res, 'id': data['id']})
print(len(res))
with open('preliminary_a_test_inference5.json', 'w', encoding='utf-8') as f:
    json.dump(res, f, ensure_ascii=False)
1019
This article is a repost; the original is at https://aistudio.baidu.com/aistudio/projectdetail/4340298?shared=1