AoANet implemented with PaddlePaddle

An implementation of Attention on Attention for Image Captioning based on the PaddlePaddle framework.

Note: the project root is /home/aistudio/work/AoANet-Paddle.

Attention on Attention for Image Captioning is an ICCV 2019 paper on image captioning. Most existing image captioning models adopt an encoder-decoder architecture with attention: a CNN encoder first encodes the image into feature maps, and an RNN decoder then generates the caption word by word. An attention mechanism is usually applied during decoding so that the decoder can attend to different image regions when generating different words. The authors argue that in existing methods the decoder has no way to tell whether the attended region is correct, or whether attending to image regions is even needed for some words. To address this, they propose the Attention on Attention (AoA) module, shown below:

AoA adopts the multi-head self-attention mechanism from the Transformer to better model the relations between objects in the image. In addition, AoA introduces a gating mechanism that filters out irrelevant attention results and keeps only useful ones. The authors apply the AoA module in both the encoder and the decoder: it helps the encoder model the relations between objects, and gives the decoder more accurate attention results, improving overall performance.
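To make the gating idea concrete, here is a minimal PaddlePaddle sketch of an AoA block. It follows the paper's formulation (an information vector and a sigmoid attention gate, both computed from the query and the attention result), but the class and layer names are illustrative, not this repo's actual code:

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

class AoA(nn.Layer):
    # Attention on Attention: a standard multi-head attention result is
    # gated by a sigmoid gate conditioned on both the query and the
    # attention result, filtering out irrelevant attended information.
    def __init__(self, d_model, num_heads=8):
        super().__init__()
        self.attn = nn.MultiHeadAttention(d_model, num_heads)
        self.info = nn.Linear(2 * d_model, d_model)  # information vector i
        self.gate = nn.Linear(2 * d_model, d_model)  # attention gate g

    def forward(self, q, k, v):
        v_hat = self.attn(q, k, v)                   # conventional attention result
        qv = paddle.concat([q, v_hat], axis=-1)
        i = self.info(qv)
        g = F.sigmoid(self.gate(qv))
        return g * i                                 # gated ("attended") information

In the encoder this is applied as self-attention over the image region features, e.g. AoA(d_model=1024)(feats, feats, feats).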

1. Introduction

This project reproduces, in PaddlePaddle, the Attention on Attention model proposed in Attention on Attention for Image Captioning. On top of conventional self-attention, the model adds a gating mechanism that filters out attention information irrelevant to the query, and uses multi-head attention to model the relations between different objects.

Paper:

  • [1] L. Huang, W. Wang, J. Chen, X. Wei, “Attention on Attention for Image Captioning”, ICCV, 2019.

Reference project:

2. Reproduced Accuracy

All metrics were obtained by evaluating the model on the COCO2014 test split.

Metric      BLEU-1  BLEU-2  BLEU-3  BLEU-4  METEOR  ROUGE-L  CIDEr-D  SPICE
Paper       0.805   0.652   0.510   0.391   0.290   0.589    1.289    0.227
This repo   0.802   0.648   0.504   0.385   0.286   0.585    1.271    0.222

3. Dataset

The dataset used in this project is COCO2014. It contains 123,287 images, each paired with 5 captions; the training, validation, and test splits contain 113,287, 5,000, and 5,000 images (with their captions) respectively. The project uses pre-extracted bottom-up features, which can be downloaded here (the script download_dataset.sh is provided to fetch both the captions and the image features).
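For orientation, bottom-up features are usually stored as one numpy archive per image, holding a matrix of per-region feature vectors. A hedged example of reading one such file (the directory layout, file name, and key below are assumptions, not necessarily this repo's exact format):

import numpy as np

# Illustrative only: the path and the 'feat' key are assumptions about
# the pre-extracted bottom-up feature files.
feats = np.load('data/cocobu_att/391895.npz')['feat']
print(feats.shape)  # e.g. (num_regions, 2048): one vector per detected region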

4. Environment and Dependencies

  • Hardware: CPU; GPU with more than 11 GB of memory

  • Software:

    • Python 3.8
    • Java 1.8.0
    • PaddlePaddle == 2.1.0

5. Quick Start

Step 1: Install the environment and dependencies

%cd /home/aistudio/work/AoANet-Paddle/
!pip install -r requirements.txt

Step 2: Download the data

# Download the dataset and features
bash ./download_dataset.sh
# Download the files needed to compute the evaluation metrics
bash ./coco-caption/get_google_word2vec_model.sh
bash ./coco-caption/get_stanford_models.sh

Note: if you want to skip the downloads (which may take several hours depending on your network), you can directly mount the data I uploaded to AI Studio. The dataset addresses are as follows:

Use the following commands to replace the corresponding folders in the project:

# Unpack the dataset and features
unzip -q /home/aistudio/data/data106442/data.zip -d /home/aistudio/work/AoANet-Paddle

# Unpack the evaluation-metric files
rm -rf /home/aistudio/work/AoANet-Paddle/coco-caption
unzip -q /home/aistudio/data/data110358/coco-caption.zip -d /home/aistudio/work/AoANet-Paddle

Step 3: Data preprocessing

Note: if you mounted the preprocessed data I uploaded to AI Studio, you can skip this step.

!python prepro.py
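As a rough illustration of the kind of work a caption-preprocessing script does (the threshold, token names, and index convention below are my assumptions, not necessarily what prepro.py uses):

from collections import Counter

def build_vocab(captions, threshold=5):
    # Count word frequencies over all training captions and keep only
    # words seen at least `threshold` times; everything else is mapped
    # to a single UNK token at caption-encoding time.
    counts = Counter(w for cap in captions for w in cap.lower().split())
    vocab = [w for w, c in counts.items() if c >= threshold]
    vocab.append('UNK')
    # word -> index, reserving index 0 for padding / end of sentence
    return {w: i + 1 for i, w in enumerate(vocab)}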

Step 4: Training

Training proceeds in two stages (see Section 3.3 of the paper for details); a sketch of the second-stage loss follows this list:

  • Training with Cross Entropy (XE) Loss
!bash ./train_xe.sh
  • CIDEr-D Score Optimization
!bash ./train_rl.sh
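The second stage optimizes CIDEr-D directly via self-critical sequence training (SCST): the reward of a sampled caption is baselined by the reward of the greedily decoded caption. A minimal sketch of that loss under assumed tensor names and shapes (not the repo's exact code):

import paddle

def scst_loss(sample_logprobs, mask, sample_cider, greedy_cider):
    # sample_logprobs: [batch, T] log-probs of the sampled caption tokens
    # mask:            [batch, T] 1 for real tokens, 0 for padding
    # Self-critical baseline: advantage = sampled CIDEr-D minus greedy
    # CIDEr-D, so no learned value function is needed.
    advantage = (sample_cider - greedy_cider).unsqueeze(1)  # [batch, 1]
    loss = -(advantage * sample_logprobs * mask).sum() / mask.sum()
    return loss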

Step 5: Evaluation

  • Evaluate the model from the train_xe stage
!python eval.py --model log/log_aoa/model.pdparams --infos_path log/log_aoa/infos_aoa.pkl --num_images -1 --language_eval 1 --beam_size 2 --batch_size 100 --split test
  • Evaluate the model from the train_rl stage
%cd /home/aistudio/work/AoANet-Paddle/
!python eval.py --model log/log_aoa_rl/model.pdparams --infos_path log/log_aoa_rl/infos_aoa.pkl --num_images -1 --language_eval 1 --beam_size 2 --batch_size 100 --split test

Sample output:

/home/aistudio/work/AoANet-Paddle
assigned 113287 images to split train
assigned 5000 images to split val
assigned 5000 images to split test
W1224 15:20:28.318421  8227 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1
W1224 15:20:28.321627  8227 device_context.cc:465] device: 0, cuDNN Version: 7.6.
loading annotations into memory...
0:00:00.358340
creating index...
index created!
using 5000/5000 predictions
Loading and preparing results...     
DONE (t=0.02s)
creating index...
index created!
tokenization...
PTBTokenizer tokenized 307085 tokens at 1376004.94 tokens per second.
PTBTokenizer tokenized 52085 tokens at 370379.28 tokens per second.
setting up scorers...
computing Bleu score...
{'testlen': 47086, 'reflen': 47050, 'guess': [47086, 42086, 37086, 32086], 'correct': [37752, 22032, 11286, 5495]}
ratio: 1.0007651434643783
Bleu_1: 0.802
Bleu_2: 0.648
Bleu_3: 0.504
Bleu_4: 0.385
computing METEOR score...
METEOR: 0.286
computing Rouge score...
ROUGE_L: 0.585
computing CIDEr score...
CIDEr: 1.271
computing SPICE score...
Parsing reference captions
Initiating Stanford parsing pipeline
[main] INFO edu.stanford.nlp.pipeline.StanfordCoreNLP - Adding annotator tokenize
[main] INFO edu.stanford.nlp.pipeline.TokenizerAnnotator - TokenizerAnnotator: No tokenizer type provided. Defaulting to PTBTokenizer.
[main] INFO edu.stanford.nlp.pipeline.StanfordCoreNLP - Adding annotator ssplit
[main] INFO edu.stanford.nlp.pipeline.StanfordCoreNLP - Adding annotator parse
[main] INFO edu.stanford.nlp.parser.common.ParserGrammar - Loading parser from serialized file edu/stanford/nlp/models/lexparser/englishPCFG.ser.gz ... 
done [0.4 sec].
[main] INFO edu.stanford.nlp.pipeline.StanfordCoreNLP - Adding annotator lemma
[main] INFO edu.stanford.nlp.pipeline.StanfordCoreNLP - Adding annotator ner
Loading classifier from edu/stanford/nlp/models/ner/english.all.3class.distsim.crf.ser.gz ... done [2.4 sec].
Loading classifier from edu/stanford/nlp/models/ner/english.muc.7class.distsim.crf.ser.gz ... done [0.5 sec].
Loading classifier from edu/stanford/nlp/models/ner/english.conll.4class.distsim.crf.ser.gz ... done [0.7 sec].
Threads( StanfordCoreNLP ) [02:03.746 minutes]
Threads( StanfordCoreNLP ) [01:55.967 minutes]
Threads( StanfordCoreNLP ) [56.490 seconds]
Parsing test captions
Threads( StanfordCoreNLP ) [37.668 seconds]
SPICE evaluation took: 5.818 min
SPICE: 0.222
{'Bleu_1': 0.8017669795692818, 'Bleu_2': 0.6478615698898944, 'Bleu_3': 0.5036144276330157, 'Bleu_4': 0.3845799617610523, 'METEOR': 0.2863164333219832, 'ROUGE_L': 0.5850462605214983, 'CIDEr': 1.2711546250804326, 'SPICE': 0.22230362654373134, 'bad_count_rate': 0.0008}

Inference with the Pretrained Model

Model download: Google Drive

Place the downloaded model weights and training info under the log directory, then run the Step 5 commands to evaluate.
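For reference, eval.py does roughly the equivalent of the following when given --model (a minimal sketch; `model` stands for an already-constructed AoANet instance):

import paddle

# Load the released weights and put the model in inference mode.
state_dict = paddle.load('log/log_aoa_rl/model.pdparams')
model.set_state_dict(state_dict)
model.eval()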

6. Reproduction Notes

This is my second time taking part in the paper-reproduction challenge hosted by PaddlePaddle. Compared with the previous round, reproducing this paper taught me the following:

  • Because of design differences between the frameworks, Paddle and PyTorch inevitably differ in their APIs, and some PyTorch APIs have no direct Paddle counterpart. I ran into exactly this while reproducing this paper: you have to work out what the missing API does and rebuild it from Paddle primitives (see the sketch after this list). While doing so, you must both keep the behavior identical (forward and backward alignment) and keep the reimplementation fast.

  • The PaddlePaddle team is great! If you hit problems, file an issue or ask the Paddle engineers in the dedicated group chat. It is far more efficient than fumbling around on your own!
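As a concrete illustration of rebuilding a missing API (a hypothetical stand-in chosen by me, not necessarily one this repo needed): PyTorch's Tensor.masked_fill has no direct counterpart in Paddle 2.1, but it can be composed from paddle.full and paddle.where:

import paddle

def masked_fill(x, mask, value):
    # Emulate torch.Tensor.masked_fill with Paddle primitives: where
    # mask is True take the fill value, elsewhere keep x.
    # mask is assumed to be a bool tensor of the same shape as x.
    fill = paddle.full(x.shape, value, x.dtype)
    return paddle.where(mask, fill, x)

A typical use is masking attention logits before softmax, e.g. masked_fill(scores, pad_mask, -1e9).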
