螺旋桨RNA结构预测竞赛:Unpaired Probability Prediction

AI_Drug队伍 B榜第5名方案

1. 赛题说明与理解

“RNA碱基不成对概率”衡量了RNA序列在各个点位是否能形成稳定的碱基对(base pair),是RNA结构的重要属性,并可被应用在mRNA疫苗序列设计、药物研发等领域。例如mRNA疫苗序列通常不稳定,而RNA碱基不成对概率较高的点位正是易被降解的位置;又如RNA 碱基不成对概率较高的点位通常更容易与其他RNA序列相互作用,形成RNA-RNA binding等,这一特性也被广泛应用于疾病诊断(如快速肺结核检测)和RNA药物研发。

paddle平台提供了训练数据集和测试数据集(A榜和B榜测试数据) ,具体介绍如下:
训练数据集: 包括RNA序列 + LinearFold预测结构(模型输入)和 RNA碱基不成对概率标签(模型输出)
A榜测试数据集: 包括RNA序列(模型输入);
B榜测试数据集: 包括无任何公开结构信息的RNA序列。

(以上引用链接 https://aistudio.baidu.com/aistudio/competition/detail/61)

总的来说,就是给定一个RNA一级结构序列和二级结构序列,预测序列中每一个位置的碱基不成对的概率。
当时第一想法就是类似NLP任务中的序列标注任务或者命名实体识别相关任务(预测每一个位置的标签),因此后续比赛过程中尝试的模型主要是NLP和CNN相关的模型及其组合。

2. 数据集的理解和分析

数据集包括4750条训练集数据、250条验证集数据、444条A榜测试数据、112条B榜测试数据。
RNA的长度差异性巨大,并且对训练集和验证集的标签进行了统计,发现不成对的概率值集中在区间[0-1]的两端,在中间的值很少。
并且研究了二级结构序列和标签的关系,发现位置为’.‘的不成对概率平均值为0.805,’(’ 的不成对概率平均值为0.140, ')‘的不成对概率平均值为0.142。左右括号的不成对概率的均值和标准差都很接近,’.'的不成对概率的均值和标准差都较大。

下图为标签的分布,后续在比赛过程中我们尝试过将模型预测的值二值极端化(更接近0或1,在A榜能提高0.02,但是B榜的效果不好,所以在B榜上没有采用这种方法)

3. 特征工程

  1. 比赛过程中我们主要对提供的训练数据,RNA序列和二级结构序列进行ngram编码并使用word embedding的思路,在训练过程中一并学习embedding。尝试ngram大小为1,3,5,7的方案,经过实验比较,ngram为1和3的实验效果最好。数据集的处理和ngram词表生成的代码文件为:dataset.py,原始数据以及处理好的数据存放在mydata文件夹中,ngram词表存放在vocab文件夹中。输入模型时,对RNA一级结构的特征向量和二级结构的特征向量进行简单相加。
  2. 数据增强:我们还尝试根据RNA序列使用linear fold的linear_rna.linear_fold_v函数生成3个相应的二级结构序列,处理方法同上。后面实验发现该方法对部分模型在训练集和验证集上的性能有所提高,但是在A榜和B榜数据集上的性能有所下降,说明该数据增强的方法让我们的模型过拟合了,因此最终的模型没有使用这个方法。数据处理以及特征生成的文件为:dataset_aug.py

4. 网络架构设计

  1. 由于rna序列较长,为了让模型能够关注更长范围的序列相关信息,我们采用transformer作为模型的主结构,来对rna序列的每个token进行编码,经过测试我们发现
    transformer的层数在4层以下模型才能收敛较好。
  2. 如果仅仅使用transformer来对token进行编码能取得一定的模型效果,但考虑到rna序列中每个碱基是否能配对在一定程度上与相邻的碱基存在关系,
    我们尝试在transformer顶层叠加双向的gru/lstm结构(因为这两种结构能重点关注序列短距离关系),经过测试,lstm与gru确实能够显著降低模型在测试集上的rmsd。
  3. 为了加强模型对提取特征的利用能力,我们使用两层MLP而不是单线性层来作为输出层。
  4. 综上,我们的主模型结构如下: input ==> ngram_embedding ==> transformer(1~4 layer) ==> bidirection_lstm/bidirection_gru ==>
    MLP,采用adam优化器进行优化。
  5. 模型训练具体使用参数的设置,详细见python main.py --help,单模型训练实例详见run.sh脚本。

模型整体框架以及流程示意图:

5. 安装运行必要的requirements

!pip install -r requirements.txt

6 .训练

训练的脚本:使用多套不同参数的transformer模型训练并保存到inference_checkpoint/transformers_gru文件夹中 (最后用于模型融合)

!mkdir log
# 多模型训练

!bash run_transformer.sh

7. 推理

运行python test.py --model_dir inference_checkpoint/transformers_gru --ensemble

对多个不同结构的transformer模型在不同的checkpoint下进行模型集.

最终保存下来的结果位置为: result/ensemble_transformers_gru/ensemble/predict.files.zip

!python test.py --model_dir inference_checkpoint/transformers_gru --ensemble

注:仅复现榜单第五名结果的代码以及相关checkpoints,请移步Github

8. 推理在比赛中使用的几个模型及其在A、B榜上的结果

模型Val上rmseA榜上rmse
TextCNN (3gram)0.2570.228
TextCNN_bigru (3gram)0.24380.226
TextCNN_bigru (3gram + 数据增强)0.2380.256
TextCNN_(3gram + 3层CNN卷积)0.2190.217
Transformer (1gram)0.2530.229
Transformer (3gram)0.2480.226
Transformer_bigru (3gram)0.2340.216
集成模型Val上rmseB榜上rmse
TextCNN_(3gram + 3层CNN卷积)0.2190.265
Transformer_bigru (3gram)0.2340.263

9. 总结以及优化思路

  1. TextCNN_bigru较transformer_bigru最大的优势就是训练速度快,模型稳定,结果方差小。但由于没有像transformer那样充分考虑上下文的信息,因此效果没有transfomer好
  2. 除使用Transformer、TextCNN以外,我们还尝试把RNA处理为图数据(能够充分考虑RNA的二级结构信息),并使用图网络模型来对rna进行建模,由于时间有限,该方案并没有作过多调优,效果略差于transformer,代码详见dataset.py中相关部分与model.GraphTransformer,可以通过运行!bash run_graph.sh脚本来实现图网络的训练。
  3. 大规模的预训练或许能够提升词向量的表达能力,可以使用RNA central收集的人类RNA序列数据进行预训练。
  4. 可以使用数据增强的技术,增加训练样本,提高模型的性能。
  5. 增强特征可以通过对RNA的一级结构以及二级结构进行分析,用常见的RNA序列分析软件提取更有效的特征输入模型。
  6. 可以尝试对TextCNN和transformers模块进行融合。
Logo

学大模型,用大模型上飞桨星河社区!每天8点V100G算力免费领!免费领取ERNIE 4.0 100w Token >>>

更多推荐