[论文复现赛第6期] UNITER论文复现冠军方案
在flickr30k数据集上微调预训练UNITER模型,实现SOTA跨模态图文检索性能
基于paddle实现的UNITER模型
基于paddle框架的UNITER: UNiversal Image-TExt Representation Learning实现
注:本项目根目录在/home/aistudio/work/UNITER-Paddle下
一、论文解读
本项目使用paddle框架复现UNITER模型。图像和文本的联合嵌入是大多数视觉和语言任务的基础,在这些任务中,同时处理多模态输入以实现对视觉和文本的共同理解。UNITER是一种通用的图像文本表示形式,它是通过对四个图像文本数据集(COCO,Visual Genome,Conceptual Captions和SBU Captions)进行大规模预训练学习而得,可以为大量具有联合多模式嵌入的多个下游任务提供支持。 UNITER采用了四个预训练任务:掩码语言建模(MLM),掩码区域建模(MRM,具有三个变体),图像文本匹配(ITM)和字区域对齐(WRA)。与先前将联合随机掩码应用于这两种方法的工作不同,UNITER在预训练任务上使用条件掩码(即,掩码语言/区域建模取决于对图像/文本的完全观察)。 广泛的实验表明,UNITER在六个视觉语言任务(超过九个数据集)上实现了最新的技术水平,包括视觉问答,图像文本检索,指代表达理解,视觉常识推理。
论文:
- [1] Y. Chen, L. Li, L. Yu, and et. al, “UNITER: UNiversal Image-TExt Representation Learning”, ECCV, 2020.
参考项目:
- UNITER [官方实现]
二、代码复现
2.1 数据集
本项目所使用的数据集为Flickr30k。该数据集共包含31783张图像,每张图像对应5个标题。训练集、验证集和测试集分别为29783、1000、1000张图像及其对应的标题。本项目使用预提取的bottom-up
特征,可以从这里下载得到。
# 相关数据集已上传至Aistudio
# 详情见: https://aistudio.baidu.com/aistudio/datasetdetail/128538
# paddle格式的预训练权重也已上传至Aistudio
# 详情见: https://aistudio.baidu.com/aistudio/datasetdetail/128538
# 下载或挂载数据集和预训练权重之后
# 需要修改配置文件(configs/retrieval_train.yaml和configs/retrieval_test.yaml)的一些参数:
# DATA_DIR (数据集目录), FEAT_FILE (特征文件), PRETRAINED-DIR (预训练权重路径)
# 挂载数据集
!unzip -q /home/aistudio/data/data128538/UNITER-flickr30k-ir.zip -d data/
!unzip -q /home/aistudio/data/data129717/flickr30k_retrieval_large_22Y_02M_27D_15H.zip -d exp/
2.2 环境依赖
-
硬件:CPU、GPU
-
软件:
- Python 3.7
- PaddlePaddle-GPU == 2.2.1
- PaddleNLP==2.2.1
cd /home/aistudio/work/UNITER-Paddle/
pip install -r requirements.txt
2.3 模型的创建
class UniterModel(UniterPreTrainedModel):
    """Modification for Joint Vision-Language Encoding.

    Embeds text tokens and image region features into a shared space and
    runs them through a single Transformer encoder. Supports image-only,
    text-only, and joint image-text inputs.
    """

    def __init__(self, config, img_dim):
        """
        Args:
            config: UNITER/BERT-style model configuration object.
            img_dim: dimensionality of the pre-extracted image region features.
        """
        super().__init__(config)
        self.embeddings = UniterTextEmbeddings(config)
        self.img_embeddings = UniterImageEmbeddings(config, img_dim)
        self.encoder = UniterEncoder(config)
        self.pooler = BertPooler(config)
        self.apply(self.init_weights)

    def _compute_txt_embeddings(self, input_ids, position_ids,
                                txt_type_ids=None):
        # Token + position (+ optional segment) embeddings for the text side.
        output = self.embeddings(input_ids, position_ids, txt_type_ids)
        return output

    def _compute_img_embeddings(self, img_feat, img_pos_feat, img_masks=None,
                                img_type_ids=None):
        # Image regions default to token type 1 (text side uses type 0).
        if img_type_ids is None:
            img_type_ids = paddle.ones_like(img_feat[:, :, 0], dtype='int64')
        img_type_embeddings = self.embeddings.token_type_embeddings(
            img_type_ids)
        output = self.img_embeddings(img_feat, img_pos_feat,
                                     img_type_embeddings, img_masks)
        return output

    def _compute_img_txt_embeddings(self, input_ids, position_ids,
                                    img_feat, img_pos_feat,
                                    gather_index, img_masks=None,
                                    txt_type_ids=None, img_type_ids=None):
        txt_emb = self._compute_txt_embeddings(
            input_ids, position_ids, txt_type_ids)
        img_emb = self._compute_img_embeddings(
            img_feat, img_pos_feat, img_masks, img_type_ids)
        # align back to most compact input
        gather_index = gather_index.unsqueeze(-1).expand(
            (-1, -1, self.config.hidden_size))
        embedding_output = paddle_gather(
            paddle.concat([txt_emb, img_emb], axis=1),
            dim=1, index=gather_index)
        return embedding_output

    def forward(self, input_ids, position_ids,
                img_feat, img_pos_feat,
                attention_mask, gather_index=None, img_masks=None,
                output_all_encoded_layers=True,
                txt_type_ids=None, img_type_ids=None):
        """Encode the inputs and return the Transformer hidden states.

        Returns:
            A list of per-layer hidden states if ``output_all_encoded_layers``
            is True, otherwise the last layer's hidden states only.
        """
        # compute self-attention mask: 1 -> attend (0 additive bias),
        # 0 -> masked (-10000 additive bias before softmax)
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        extended_attention_mask = extended_attention_mask.cast(
            dtype=paddle.float32)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        # embedding layer
        if input_ids is None:
            # image only
            embedding_output = self._compute_img_embeddings(
                img_feat, img_pos_feat, img_masks, img_type_ids)
        elif img_feat is None:
            # text only
            embedding_output = self._compute_txt_embeddings(
                input_ids, position_ids, txt_type_ids)
        else:
            embedding_output = self._compute_img_txt_embeddings(
                input_ids, position_ids,
                img_feat, img_pos_feat,
                gather_index, img_masks, txt_type_ids, img_type_ids)
        encoded_layers = self.encoder(
            embedding_output, extended_attention_mask,
            output_all_encoded_layers=output_all_encoded_layers)
        if not output_all_encoded_layers:
            encoded_layers = encoded_layers[-1]
        # BUG FIX: the original ended with a bare ``return`` and silently
        # returned None; callers need the encoder output.
        return encoded_layers
2.4 训练
export PYTHONPATH=$PWD:$PYTHONPATH
CUDA_VISIBLE_DEVICES='0, 1, 2, 3' python -m paddle.distributed.launch tools/finetune_retrieval.py --cfg_file configs/retrieval_train.yaml
def train_epoch(model, trn_dataloader_img, trn_dataloader_txt, optimizer, scheduler, epoch, logger, args, cfg):# {{{
    """Run one fine-tuning epoch with paired image-side and text-side batches.

    For every text batch, a hard-negative image batch is drawn (cycling the
    image dataloader when exhausted); both directions contribute to the loss,
    with optional gradient accumulation.

    Args:
        model: retrieval model; called with ``sample_from='i'`` / ``'t'``.
        trn_dataloader_img: dataloader yielding image-anchored batches.
        trn_dataloader_txt: dataloader yielding text-anchored batches;
            its length defines the number of steps in the epoch.
        optimizer / scheduler: Paddle optimizer and LR scheduler.
        epoch: zero-based epoch index (used for sampler seeding and logs).
        logger: logger for progress messages.
        args: runtime arguments; ``args.n_gpus`` controls loss reduction.
        cfg: experiment config (print frequency, gradient accumulation).
    """
    # Set mode for training
    model.train()
    # Set epoch for trn_sampler so distributed shuffling differs per epoch
    trn_dataloader_img.batch_sampler.set_epoch(epoch)
    trn_dataloader_txt.batch_sampler.set_epoch(epoch)
    logger.info('=====> Start epoch {}:'.format(epoch + 1))
    print_steps = cfg['MONITOR']['PRINT_FREQ']
    grad_accum_steps = cfg['OPTIMIZATION']['GRADIENT_ACCUMULATION_STEPS']
    train_loss = 0.
    optim_steps = 0
    train_iter_img = iter(trn_dataloader_img)
    step = -1  # guard against an empty text dataloader
    for step, batch_txt in enumerate(trn_dataloader_txt):
        # hard text from image: cycle the image dataloader when exhausted
        try:
            batch_img = next(train_iter_img)
        except StopIteration:
            train_iter_img = iter(trn_dataloader_img)
            batch_img = next(train_iter_img)
        # Forward img
        loss_img = model(batch_img, sample_from='i', compute_loss=True)
        if args.n_gpus > 1:
            loss_img = loss_img.mean()
        if grad_accum_steps > 1:
            loss_img = loss_img / grad_accum_steps
        # Backward img
        loss_img.backward()
        # Forward txt
        loss_txt = model(batch_txt, sample_from='t', compute_loss=True)
        if args.n_gpus > 1:
            loss_txt = loss_txt.mean()
        if grad_accum_steps > 1:
            loss_txt = loss_txt / grad_accum_steps
        # Backward txt
        loss_txt.backward()
        train_loss += (float(loss_img) + float(loss_txt))
        # Update parameters once every grad_accum_steps micro-batches
        if (step + 1) % grad_accum_steps == 0:
            optimizer.step()
            scheduler.step()
            optimizer.clear_grad()
            # Count optimization steps
            optim_steps += 1
            # Print log
            if optim_steps % print_steps == 0:
                logger.info('Epoch [%d], step [%d], training loss: %.5f' % (
                    epoch + 1, optim_steps, (float(loss_img) + float(loss_txt))))
    # BUG FIX: average over the number of processed batches (step + 1).
    # The original divided by ``step`` (last 0-based index), which both
    # over-weighted the average and raised ZeroDivisionError for a
    # single-batch epoch.
    train_loss = train_loss / (step + 1) if step >= 0 else 0.
    logger.info('** ** Epoch [%d] done! Training loss: %.5f ** **'
                % (epoch + 1, train_loss))# }}}
2.5 验证
# 测试之前,需要在configs/retrieval_test.yaml中指定测试的模型 (即修改EVAL-CHECKPOINT_DIR参数).
python tools/evaluate_retrieval.py --cfg_file configs/retrieval_test.yaml
import os
import sys
sys.path.insert(0, '.')
import argparse
from tqdm import tqdm
# paddle
import paddle
from paddle.io import DataLoader, BatchSampler
# model
from models.uniter_retrieval import UniterForImageTextRetrievalHardNeg
# dataset
from datasets.retrieval_dataset import itm_eval_collate, ItmEvalDataset
# config
from config.default import get_cfg_defaults
# utils
from utils.utils import compute_ranks
from utils.io_utils import TxtTokLmdb, DetectFeatLmdb
def main(args, cfg):
    """Score every (caption, image) pair on Flickr30k and report retrieval R@K."""
    # 1. Create test dataloader
    img_db = DetectFeatLmdb(img_dir=cfg['DATASET']['IMG_DIR'])
    txt_db_path = os.path.join(
        cfg['DATASET']['TXT_DIR'],
        'itm_flickr30k_{}.db'.format(cfg['DATASET']['TEST']))
    txt_db = TxtTokLmdb(db_dir=txt_db_path, max_txt_len=-1)
    eval_set = ItmEvalDataset(
        txt_db, img_db,
        mini_batch_size=cfg['OPTIMIZATION']['DEV_BATCH_SIZE'])
    sampler = BatchSampler(dataset=eval_set, batch_size=1,
                           shuffle=False, drop_last=False)
    eval_loader = DataLoader(eval_set,
                             batch_sampler=sampler,
                             collate_fn=itm_eval_collate)
    # 2. Build model
    config_path = os.path.join(cfg['PRETRAINED']['DIR'],
                               cfg['PRETRAINED']['CONFIG'])
    state = paddle.load(
        os.path.join(args.checkpoint_dir, 'paddle_model.bin'))['model']
    model = UniterForImageTextRetrievalHardNeg.from_pretrained(
        config_path, state, img_dim=cfg['INPUT']['IMG_DIM'],
        margin=cfg['INPUT']['MARGIN'], hard_size=cfg['INPUT']['HARD_NEG_SIZE'])
    print('Load state dict from %s.' % args.checkpoint_dir)
    model.eval()
    # 3. Start to evaluate: one row per caption, one column per image
    n_txt = len(eval_loader.dataset)
    n_img = len(eval_loader.dataset.all_img_ids)
    score_matrix = paddle.zeros((n_txt, n_img))
    for row, mini_batches in enumerate(tqdm(eval_loader)):
        col = 0
        for batch in mini_batches:
            with paddle.no_grad():
                sims = model(batch, compute_loss=False)
            n = sims.shape[0]
            score_matrix[row, col:col + n] = sims.squeeze(1)
            col += n
        assert col == score_matrix.shape[1]
    all_txt_ids = list(eval_set.ids)
    all_img_ids = eval_set.all_img_ids
    assert score_matrix.shape == [len(all_txt_ids), len(all_img_ids)]
    test_results = compute_ranks(score_matrix, all_txt_ids, all_img_ids,
                                 eval_set.txt2img, eval_set.img2txts)
    print("T2I Retrieval: {:.4f} @ R1, {:.4f} @ R5, {:.4f} @ R10".format(
        test_results['txt_r1'], test_results['txt_r5'], test_results['txt_r10']))
    print("I2T Retrieval: {:.4f} @ R1, {:.4f} @ R5, {:.4f} @ R10".format(
        test_results['img_r1'], test_results['img_r5'], test_results['img_r10']))
if __name__ == '__main__':
    # Parse the single required CLI argument: the experiment config file.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--cfg_file', type=str, required=True,
                            help='Path to the config file for a specific experiment.')
    args = arg_parser.parse_args()
    # Get the default config & merge from cfg_file
    cfg = get_cfg_defaults()
    cfg.merge_from_file(args.cfg_file)
    # Make sure checkpoint dir exists
    args.checkpoint_dir = cfg['EVAL']['CHECKPOINT_DIR']
    assert os.path.isdir(args.checkpoint_dir), \
        "Please make sure the specified checkpoint dir and eval epoch exist."
    # Call main
    main(args, cfg)
2.6 复现结果
本项目验证其在图文检索
Image-Text Retrieval
下游任务中的性能,所使用的数据集为Flickr30K,复现精度如下。
指标 | 原论文 | 复现精度 |
---|---|---|
IR-flickr30K-R1 | 73.66 | 74.02 |
三、参考资料
- https://github.com/ChenRocks/UNITER
四、总结
- 这篇文章在TIPC过程中,遇到了一些问题,飞桨的工程师非常热心的帮忙解决了,详情见issue,非常感谢~
请点击此处查看本环境基本用法.
Please click here for more detailed instructions.
此文章为搬运
原项目链接
更多推荐
所有评论(0)