UNITER Implemented in Paddle

A Paddle-framework implementation of UNITER: UNiversal Image-TExt Representation Learning.

Note: the project root is /home/aistudio/work/UNITER-Paddle.

1. Paper Overview

This project reproduces the UNITER model with the Paddle framework. Joint image-text embedding is the bedrock of most vision-and-language tasks, in which multimodal inputs are processed simultaneously to achieve a joint understanding of vision and text. UNITER is a universal image-text representation learned through large-scale pre-training on four image-text datasets (COCO, Visual Genome, Conceptual Captions, and SBU Captions), and it can power a wide range of downstream tasks with joint multimodal embeddings. UNITER adopts four pre-training tasks: Masked Language Modeling (MLM), Masked Region Modeling (MRM, with three variants), Image-Text Matching (ITM), and Word-Region Alignment (WRA). Unlike prior work that applies joint random masking to both modalities, UNITER uses conditional masking in its pre-training tasks (i.e., masked language/region modeling is conditioned on full observation of the image/text). Extensive experiments show that UNITER achieves state-of-the-art results on six vision-and-language tasks (across more than nine datasets), including Visual Question Answering, Image-Text Retrieval, Referring Expression Comprehension, and Visual Commonsense Reasoning.

[Figure: overview of the UNITER model and its pre-training tasks]
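To make conditional masking concrete: in an MLM step only text tokens are corrupted while every image region remains observed, and symmetrically for MRM. A toy sketch (the function, names, and 15% probability are illustrative, not taken from the paper's code):

import numpy as np

def conditional_mask(num_tokens, num_regions, task, p=0.15, seed=None):
    # Conditional masking corrupts exactly ONE modality per sample:
    # 'mlm' masks text tokens while every image region stays observed;
    # 'mrm' masks image regions while every text token stays observed.
    rng = np.random.default_rng(seed)
    txt_mask = (rng.random(num_tokens) < p) if task == 'mlm' else np.zeros(num_tokens, dtype=bool)
    img_mask = (rng.random(num_regions) < p) if task == 'mrm' else np.zeros(num_regions, dtype=bool)
    return txt_mask, img_mask

# e.g., conditional_mask(12, 36, task='mlm') masks ~15% of the 12 tokens
# and none of the 36 regions.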

Paper:

  • [1] Y. Chen, L. Li, L. Yu, et al., “UNITER: UNiversal Image-TExt Representation Learning”, ECCV, 2020.

Reference project: https://github.com/ChenRocks/UNITER

2. Code Reproduction

2.1 Dataset

This project uses the Flickr30k dataset, which contains 31,783 images, each paired with 5 captions. The training, validation, and test splits contain 29,783, 1,000, and 1,000 images (with their captions), respectively. Pre-extracted bottom-up features are used and can be downloaded here.
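Each region in the bottom-up features is paired with a 7-dimensional location feature (normalized left/top/right/bottom coordinates plus width, height, and area), as described in the paper. A small sketch of how such a feature can be derived from a raw bounding box (the helper name is ours, not from the repo):

import numpy as np

def bbox_to_pos_feat(bbox, img_w, img_h):
    # UNITER's 7-d region location feature:
    # [x1, y1, x2, y2, w, h, w*h], all normalized by the image size.
    x1, y1, x2, y2 = bbox
    x1, x2 = x1 / img_w, x2 / img_w
    y1, y2 = y1 / img_h, y2 / img_h
    w, h = x2 - x1, y2 - y1
    return np.array([x1, y1, x2, y2, w, h, w * h], dtype=np.float32)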

# The dataset has been uploaded to AI Studio
# See: https://aistudio.baidu.com/aistudio/datasetdetail/128538

# Pre-trained weights in Paddle format have also been uploaded to AI Studio
# See: https://aistudio.baidu.com/aistudio/datasetdetail/128538

# After downloading or mounting the dataset and the pre-trained weights,
# update a few parameters in the config files (configs/retrieval_train.yaml and configs/retrieval_test.yaml):
# DATA_DIR (dataset directory), FEAT_FILE (feature file), PRETRAINED-DIR (path to the pre-trained weights)

# Unzip the mounted dataset and pre-trained weights
!unzip -q /home/aistudio/data/data128538/UNITER-flickr30k-ir.zip -d data/
!unzip -q /home/aistudio/data/data129717/flickr30k_retrieval_large_22Y_02M_27D_15H.zip -d exp/

2.2 Requirements

  • Hardware: CPU, GPU

  • Software:

    • Python 3.7
    • PaddlePaddle-GPU == 2.2.1
    • PaddleNLP == 2.2.1
cd /home/aistudio/work/UNITER-Paddle/
pip install -r requirements.txt

2.3 Building the Model

class UniterModel(UniterPreTrainedModel):
    """ Modification for Joint Vision-Language Encoding."""
    def __init__(self, config, img_dim):
        super().__init__(config)
        self.embeddings = UniterTextEmbeddings(config)
        self.img_embeddings = UniterImageEmbeddings(config, img_dim)
        self.encoder = UniterEncoder(config)
        self.pooler = BertPooler(config)
        self.apply(self.init_weights)

    def _compute_txt_embeddings(self, input_ids, position_ids,
                                txt_type_ids=None):
        output = self.embeddings(input_ids, position_ids, txt_type_ids)
        return output

    def _compute_img_embeddings(self, img_feat, img_pos_feat, img_masks=None,
                                img_type_ids=None):
        if img_type_ids is None:
            img_type_ids = paddle.ones_like(img_feat[:, :, 0], dtype='int64')
        img_type_embeddings = self.embeddings.token_type_embeddings(
            img_type_ids)
        output = self.img_embeddings(img_feat, img_pos_feat,
                                     img_type_embeddings, img_masks)
        return output

    def _compute_img_txt_embeddings(self, input_ids, position_ids,
                                    img_feat, img_pos_feat,
                                    gather_index, img_masks=None,
                                    txt_type_ids=None, img_type_ids=None):
        txt_emb = self._compute_txt_embeddings(
            input_ids, position_ids, txt_type_ids)
        img_emb = self._compute_img_embeddings(
            img_feat, img_pos_feat, img_masks, img_type_ids)
        # align back to most compact input
        gather_index = gather_index.unsqueeze(-1).expand(
            (-1, -1, self.config.hidden_size))
        embedding_output = paddle_gather(paddle.concat([txt_emb, img_emb], axis=1),
                                         dim=1, index=gather_index)
        return embedding_output

    def forward(self, input_ids, position_ids,
                img_feat, img_pos_feat,
                attention_mask, gather_index=None, img_masks=None,
                output_all_encoded_layers=True,
                txt_type_ids=None, img_type_ids=None):
        # compute self-attention mask
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        extended_attention_mask = extended_attention_mask.cast(dtype=paddle.float32)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        # embedding layer
        if input_ids is None:
            # image only
            embedding_output = self._compute_img_embeddings(
                img_feat, img_pos_feat, img_masks, img_type_ids)
        elif img_feat is None:
            # text only
            embedding_output = self._compute_txt_embeddings(
                input_ids, position_ids, txt_type_ids)
        else:
            embedding_output = self._compute_img_txt_embeddings(
                input_ids, position_ids,
                img_feat, img_pos_feat,
                gather_index, img_masks, txt_type_ids, img_type_ids)

        encoded_layers = self.encoder(
            embedding_output, extended_attention_mask,
            output_all_encoded_layers=output_all_encoded_layers)
        if not output_all_encoded_layers:
            encoded_layers = encoded_layers[-1]
        return encoded_layers
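Note that paddle_gather, used in _compute_img_txt_embeddings above, is not a built-in Paddle op; the repo defines it to mimic torch.gather. A minimal sketch of one possible implementation on top of paddle.gather_nd (the repo's actual helper may differ):

import paddle

def paddle_gather(x, dim, index):
    # Emulate torch.gather: build full N-d indices and use gather_nd.
    index_shape = index.shape
    if dim < 0:
        dim = len(x.shape) + dim
    nd_index = []
    for k in range(len(x.shape)):
        if k == dim:
            nd_index.append(index.flatten())
        else:
            reshape_shape = [1] * len(x.shape)
            reshape_shape[k] = x.shape[k]
            dim_index = paddle.arange(x.shape[k], dtype=index.dtype).reshape(reshape_shape)
            nd_index.append(paddle.expand(dim_index, index_shape).flatten())
    full_index = paddle.transpose(paddle.stack(nd_index), [1, 0]).astype('int64')
    return paddle.gather_nd(x, full_index).reshape(index_shape)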

2.4 Training

export PYTHONPATH=$PWD:$PYTHONPATH
CUDA_VISIBLE_DEVICES='0,1,2,3' python -m paddle.distributed.launch tools/finetune_retrieval.py --cfg_file configs/retrieval_train.yaml
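The core of tools/finetune_retrieval.py is the epoch loop below, which alternates between batches sampled from the image side and the text side for hard-negative training: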
def train_epoch(model, trn_dataloader_img, trn_dataloader_txt, optimizer, scheduler, epoch, logger, args, cfg):
    # Set mode for training
    model.train()
    # Set epoch for trn_sampler
    trn_dataloader_img.batch_sampler.set_epoch(epoch)
    trn_dataloader_txt.batch_sampler.set_epoch(epoch)

    logger.info('=====> Start epoch {}:'.format(epoch + 1))

    print_steps = cfg['MONITOR']['PRINT_FREQ']
    grad_accum_steps = cfg['OPTIMIZATION']['GRADIENT_ACCUMULATION_STEPS']

    train_loss = 0.
    optim_steps = 0
    train_iter_img = iter(trn_dataloader_img)
    for step, batch_txt in enumerate(trn_dataloader_txt):
        # hard text from image
        try:
            batch_img = next(train_iter_img)
        except StopIteration:
            train_iter_img = iter(trn_dataloader_img)
            batch_img = next(train_iter_img)

        # Forward img
        loss_img = model(batch_img, sample_from='i', compute_loss=True)
        if args.n_gpus > 1:
            loss_img = loss_img.mean()
        if grad_accum_steps > 1:
            loss_img = loss_img / grad_accum_steps
        # Backward img
        loss_img.backward()

        # Forward txt
        loss_txt = model(batch_txt, sample_from='t', compute_loss=True)
        if args.n_gpus > 1:
            loss_txt = loss_txt.mean()
        if grad_accum_steps > 1:
            loss_txt = loss_txt / grad_accum_steps
        # Backward txt
        loss_txt.backward()

        train_loss += (float(loss_img) + float(loss_txt))

        # Update parameters
        if (step + 1) % grad_accum_steps == 0:
            optimizer.step()
            scheduler.step()
            optimizer.clear_grad()

            # Count optimization steps
            optim_steps += 1

            # Print log
            if optim_steps % print_steps == 0:
                logger.info('Epoch [%d], step [%d], training loss: %.5f' % (
                            epoch + 1, optim_steps, (float(loss_img) + float(loss_txt))))

    train_loss = train_loss / (step + 1)
    logger.info('** ** Epoch [%d] done! Training loss: %.5f ** **'
                 % (epoch + 1, train_loss))
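train_epoch receives an optimizer and LR scheduler built elsewhere in the training script. A minimal sketch of a typical setup (AdamW with linear warmup then linear decay; the actual hyperparameters come from configs/retrieval_train.yaml and the repo's builder may differ):

import paddle

def build_optimizer(model, lr=5e-5, warmup_steps=100, total_steps=5000, weight_decay=0.01):
    # Linear warmup followed by linear decay, the usual schedule for
    # BERT-style fine-tuning (illustrative values, not from the repo).
    def lr_lambda(step):
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))

    scheduler = paddle.optimizer.lr.LambdaDecay(learning_rate=lr, lr_lambda=lr_lambda)
    optimizer = paddle.optimizer.AdamW(learning_rate=scheduler,
                                       parameters=model.parameters(),
                                       weight_decay=weight_decay)
    return optimizer, scheduler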

2.5 Evaluation

# Before testing, specify the model to evaluate in configs/retrieval_test.yaml (i.e., set the EVAL-CHECKPOINT_DIR parameter).
python tools/evaluate_retrieval.py --cfg_file configs/retrieval_test.yaml
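tools/evaluate_retrieval.py builds the test dataloader, restores the fine-tuned checkpoint, fills a caption-by-image score matrix, and reports recall metrics: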
import os
import sys
sys.path.insert(0, '.')
import argparse
from tqdm import tqdm

# paddle
import paddle
from paddle.io import DataLoader, BatchSampler

# model
from models.uniter_retrieval import UniterForImageTextRetrievalHardNeg
# dataset
from datasets.retrieval_dataset import itm_eval_collate, ItmEvalDataset
# config
from config.default import get_cfg_defaults
# utils
from utils.utils import compute_ranks
from utils.io_utils import TxtTokLmdb, DetectFeatLmdb


def main(args, cfg):
    # 1. Create test dataloader
    test_img_db = DetectFeatLmdb(img_dir=cfg['DATASET']['IMG_DIR'])
    test_txt_dir = os.path.join(cfg['DATASET']['TXT_DIR'], 'itm_flickr30k_{}.db'.format(cfg['DATASET']['TEST']))
    test_txt_db = TxtTokLmdb(db_dir=test_txt_dir, max_txt_len=-1)
    test_dataset = ItmEvalDataset(test_txt_db, test_img_db, mini_batch_size=cfg['OPTIMIZATION']['DEV_BATCH_SIZE'])
    test_sampler = BatchSampler(dataset=test_dataset,
                                batch_size=1,
                                shuffle=False,
                                drop_last=False)
    test_dataloader = DataLoader(test_dataset,
                                 batch_sampler=test_sampler,
                                 collate_fn=itm_eval_collate)

    # 2. Build model
    config = os.path.join(cfg['PRETRAINED']['DIR'], cfg['PRETRAINED']['CONFIG'])
    checkpoint = paddle.load(os.path.join(args.checkpoint_dir, 'paddle_model.bin'))['model']
    model = UniterForImageTextRetrievalHardNeg.from_pretrained(
        config, checkpoint, img_dim=cfg['INPUT']['IMG_DIM'],
        margin=cfg['INPUT']['MARGIN'], hard_size=cfg['INPUT']['HARD_NEG_SIZE'])
    print('Load state dict from %s.' % args.checkpoint_dir)
    model.eval()

    # 3. Start to evaluate
    score_matrix = paddle.zeros((len(test_dataloader.dataset),
                                len(test_dataloader.dataset.all_img_ids)))

    for i, mini_batches in enumerate(tqdm(test_dataloader)):
        j = 0
        for batch in mini_batches:
            with paddle.no_grad():
                scores = model(batch, compute_loss=False)
            bs = scores.shape[0]
            score_matrix[i, j:j+bs] = scores.squeeze(1)
            j += bs
        assert j == score_matrix.shape[1]

    test_dataset = test_dataloader.dataset
    all_txt_ids = [ids for ids in test_dataset.ids]
    all_img_ids = test_dataset.all_img_ids
    assert score_matrix.shape == [len(all_txt_ids), len(all_img_ids)]

    test_results = compute_ranks(score_matrix, all_txt_ids, all_img_ids,
                                 test_dataset.txt2img, test_dataset.img2txts)

    print("T2I Retrieval: {:.4f} @ R1, {:.4f} @ R5, {:.4f} @ R10".format(
        test_results['txt_r1'], test_results['txt_r5'], test_results['txt_r10']))
    print("I2T Retrieval: {:.4f} @ R1, {:.4f} @ R5, {:.4f} @ R10".format(
        test_results['img_r1'], test_results['img_r5'], test_results['img_r10']))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg_file', type=str, required=True,
        help='Path to the config file for a specific experiment.')
    args = parser.parse_args()

    # Get the default config & merge from cfg_file
    cfg = get_cfg_defaults()
    cfg.merge_from_file(args.cfg_file)

    # Make sure checkpoint dir exists
    args.checkpoint_dir = cfg['EVAL']['CHECKPOINT_DIR']
    assert os.path.isdir(args.checkpoint_dir), \
        "Please make sure the specified checkpoint dir and eval epoch exist."

    # Call main
    main(args, cfg)
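compute_ranks (imported from utils.utils) turns the caption-by-image score matrix into retrieval metrics. A simplified sketch of the text-to-image Recall@K it presumably computes (the real helper also covers the image-to-text direction):

import numpy as np

def t2i_recall_at_k(score_matrix, txt_ids, img_ids, txt2img, k=1):
    # A caption counts as a hit if its ground-truth image appears among
    # the k highest-scoring images in that caption's row.
    scores = np.asarray(score_matrix)
    hits = 0
    for row, txt_id in enumerate(txt_ids):
        topk = np.argsort(-scores[row])[:k]
        hits += int(txt2img[txt_id] in [img_ids[c] for c in topk])
    return hits / len(txt_ids)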

2.6 Results

This project evaluates the model on the downstream Image-Text Retrieval task on Flickr30K; the reproduced accuracy is as follows.

| Metric | Paper | Reproduced |
| --- | --- | --- |
| IR-flickr30K-R1 | 73.66 | 74.02 |

3. References

  • https://github.com/ChenRocks/UNITER

4. Summary

  • Some issues came up while running TIPC for this project; the PaddlePaddle engineers were very helpful and resolved them (see the issue for details). Many thanks!


This article is a repost.
Original project link
