前言

本项目基于PaddlePaddle框架复现了《Attention on Attention for Image Captioning》论文,目的是依托飞桨深度学习平台的优势,开拓基于飞桨框架的IC研究道路。

这是一篇IC方向比较经典的论文,指标分数也是相当客观。虽然不是IC最前沿的论文,但其中的模型方法值得去学习借鉴,并可以作为未来IC研究的基线模型学习。

依赖环境:

  • paddlepaddle-gpu2.1.2
  • python3.7

代码在coco2014数据集上训练,其复现精度:

SCST(Self-critical Sequence Training)

Bleu_1Bleu_2Bleu_3Bleu_4METEORROUGE_LCIDErSPICE
0.8100.6580.5110.3910.2860.5891.2830.220

模型背景及其介绍

参考论文:《Attention on Attention for Image Captioning》论文链接

发表于2019 IEEE/CVF International Conference on Computer Vision (ICCV)

这篇IC论文针对解码器几乎不了解相关向量和给定注意力信息之间的关系问题,提出增加额外的注意力,以确定注意力结果和查询结果的相关性。作者提出了一个“Attention on Attention”(AoA)模块。AoA首先使用注意力的结果和当前的上下文生成一个“信息向量”和一个“注意力门”,然后通过对它们进行逐元素乘法来增加额外的注意力,最终获得“关注信息”,即预期的有用知识。作者把AoA模型应用与编码与解码阶段,将其命名为AoA Network(AoANet)。

这篇论文的强大创新之处在于使用Transformer的编码部分和LSTM解码器做了巧妙的结合,并在注意力上加入了门控注意,实现了二次注意的效果。我们可以知道,这篇论文提出的模型在2019年获得了较高的性能分数,具有很好的优越性和普适性,非常适合作为未来IC研究的基线模型研究学习,其中的方法和思路值得大家借鉴。

具体细节我们看一下:

在AoANet的编码器部分,通过模型结构分析,实际上是重构了Transformer的自注意力编码方式,在图像特征提取网络和解码器之间架起桥梁,获取图像模态间和模态内的上下文信息表示,我认为这是模型性能获得很大提高的关键。

然后,作者设计在传统注意力基础上加上了门控注意,表示额外的注意,并把软注意力换成了多头自注意力,通过相同的一组编码参数以同质方式对模态关系建模。下面展示了注意力和注意力上的注意力(AoA):

(a) 注意模块根据Q和K之间的相似度得分生成加权平均V;

(b) AoA生成信息向量I和注意门G,并通过元素乘法添加另一个注意。

在AoANet的解码部分,使用了传统的LSTM循环神经网络作为解码器主体部分, 借于它具有记忆功能和处理序列信息的优越性。然后对传统的注意力部分做了设计,更换使用了效果更好的AoA模块。

参考项目地址链接

复现论文代码github地址链接

数据集

coco2014 image captions 论文,采用“Karpathy” data split 论文

数据集总大小:123287张

训练集:113287张

验证集:5000张

测试集:5000张

标签文件:dataset_coco.json

运行

解压预训练数据到work/data/目录下

预加载数据包括: 通过Faster R-CNN提取的coco2014图像显著区域特征(cocobu_att)、池化特征(cocobu_fc)、边框特征(cocobu_box);
cocotalk.json;cocotalk_label.h5。

上述预训练数据也可以通过命令 !python3 scripts/make_bu_data.py 和 !python3 scripts/prepro_labels.py 获得

显著区域特征(cocobu_att)因数据过大,原数据分成了cocobu_att_train和cocobu_att_val上传

%cd /home/aistudio/work/data/
!unzip -oq /home/aistudio/data/data107198/cocobu_att_train.zip
!unzip -oq /home/aistudio/data/data107198/cocobu_att_val.zip
!unzip -oq /home/aistudio/data/data107198/cocobu_fc.zip
!unzip -oq /home/aistudio/data/data107198/cocobu_box.zip
/home/aistudio/work/data

加载完成后,我们把cocobu_att_train和cocobu_att_val合并成cocobu_att

%cd /home/aistudio/work/data/
!mv cocobu_att_val/* cocobu_att_train/
!mv cocobu_att_train cocobu_att
!find . -type d -empty -delete
/home/aistudio/work/data

解压用于训练测试的文件coco-caption到work/目录下

%cd /home/aistudio/work/
!unzip -oq /home/aistudio/data/data118052/coco-caption.zip
/home/aistudio/work

安装依赖库

%cd /home/aistudio/work/
!pip install -r requirements.txt
/home/aistudio/work
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Collecting lmdb (from -r requirements.txt (line 1))
[?25l  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/4d/cf/3230b1c9b0bec406abb85a9332ba5805bdd03a1d24025c6bbcfb8ed71539/lmdb-1.3.0-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (298kB)
[K     |████████████████████████████████| 307kB 22.5MB/s eta 0:00:01
[?25hCollecting yacs==0.1.7 (from -r requirements.txt (line 2))
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/81/3b/40e876afde9f5ffa1cfdce10565aba85b0dc2e067ed551dfb566cfee6d4d/yacs-0.1.7-py3-none-any.whl
Collecting scikit-image (from -r requirements.txt (line 3))
[?25l  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/9a/44/8f8c7f9c9de7fde70587a656d7df7d056e6f05192a74491f7bc074a724d0/scikit_image-0.19.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (13.3MB)
[K     |████████████████████████████████| 13.3MB 9.6MB/s eta 0:00:01
[?25hRequirement already satisfied: PyYAML in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from yacs==0.1.7->-r requirements.txt (line 2)) (5.1.2)
Requirement already satisfied: networkx>=2.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-image->-r requirements.txt (line 3)) (2.4)
Requirement already satisfied: imageio>=2.4.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-image->-r requirements.txt (line 3)) (2.6.1)
Requirement already satisfied: scipy>=1.4.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-image->-r requirements.txt (line 3)) (1.6.3)
Requirement already satisfied: pillow!=7.1.0,!=7.1.1,!=8.3.0,>=6.1.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-image->-r requirements.txt (line 3)) (7.1.2)
Requirement already satisfied: packaging>=20.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-image->-r requirements.txt (line 3)) (21.3)
Requirement already satisfied: numpy>=1.17.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-image->-r requirements.txt (line 3)) (1.20.3)
Collecting tifffile>=2019.7.26 (from scikit-image->-r requirements.txt (line 3))
[?25l  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/d8/38/85ae5ed77598ca90558c17a2f79ddaba33173b31cf8d8f545d34d9134f0d/tifffile-2021.11.2-py3-none-any.whl (178kB)
[K     |████████████████████████████████| 184kB 7.7MB/s eta 0:00:01
[?25hCollecting PyWavelets>=1.1.1 (from scikit-image->-r requirements.txt (line 3))
[?25l  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/a1/9c/564511b6e1c4e1d835ed2d146670436036960d09339a8fa2921fe42dad08/PyWavelets-1.2.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl (6.1MB)
[K     |████████████████████████████████| 6.2MB 3.6MB/s eta 0:00:01
[?25hRequirement already satisfied: decorator>=4.3.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from networkx>=2.2->scikit-image->-r requirements.txt (line 3)) (4.4.2)
Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from packaging>=20.0->scikit-image->-r requirements.txt (line 3)) (3.0.6)
Installing collected packages: lmdb, yacs, tifffile, PyWavelets, scikit-image
Successfully installed PyWavelets-1.2.0 lmdb-1.3.0 scikit-image-0.19.1 tifffile-2021.11.2 yacs-0.1.7

训练

训练的日志和模型会放到work/log/目录下

训练过程过程分为两步:Cross-entropy Training和SCST(Self-critical Sequence Training)

Cross-entropy Training

Bleu_1Bleu_2Bleu_3Bleu_4METEORROUGE_LCIDErSPICE
0.7780.6230.4850.3770.2840.5781.1870.215
# Cross-entropy Training
!python3 train.py --cfg configs/aoa.yml
# SCST(Self-critical Sequence Training)
!python3 train.py --cfg configs/aoa_rl.yml

评估

解压预先训练好的模型日志log到work/目录下

加载work/log目录下保存的训练模型数据进行验证

%cd /home/aistudio/work/
!unzip -oq /home/aistudio/data/data118052/log.zip
!python3 eval.py

总结

我之前在论坛写了我的复现经验及心得。这里我给大家再重新呈现出来,并分享了我的解决办法,能够让大家一起讨论学习。

  1. paddle的索引切片处理问题,这里我写了一个函数用于解决,以后大家遇到类似问题可以直接使用:

re_i = paddle.to_tensor([2,1,0])

out = theta[:, re_i] # 会报错

可以转换为:

out = pd_index_slice(theta, re_i, 1) # 达到效果

def pd_index_slice(x, index, axes):
    y = paddle.to_tensor([0])
    for i, k in enumerate(index):
        xs = x.slice(axes=[axes], starts=[k], ends=[k + 1])
        if i == 0:
            y = xs
        else:
            y = paddle.concat([y, xs], axes)
    return y
  1. Paddle和Pytorch的gather()不一样,我使用F.one_hot()实现了Pytorch的gather()一样的功能。

  2. 处理变长序列问题,没有找到类似的功能函数,也是手写了一个处理变长序列的函数,后续会在留言部分贴出。

如果大家在IC研究道路上或运行本项目过程中遇到什么问题可以随时留言讨论,希望我们可以一起完善项目也能一起进步。


Logo

学大模型,用大模型上飞桨星河社区!每天8点V100G算力免费领!免费领取ERNIE 4.0 100w Token >>>

更多推荐