Reposted from AI Studio. Original article: Paper Reproduction: WS-DAN, a Classic for Fine-Grained Classification - PaddlePaddle AI Studio

Paper Reproduction: See Better Before Looking Closer: Weakly Supervised Data Augmentation Network for Fine-Grained Visual Classification

1. Introduction

The paper, titled See Better Before Looking Closer, is a classic work on fine-grained classification. "Fine-grained" means distinguishing subcategories within a broad category, such as bird or dog breeds, or car and aircraft models. On such problems, common backbones (VGG, ResNet, Inception) reach only middling accuracy. The paper "See Better Before Looking Closer: Weakly Supervised Data Augmentation Network for Fine-Grained Visual Classification" proposes a weakly supervised data augmentation network, i.e. an augmentation strategy guided by attention maps: not only the original image but also its augmented versions are fed into the network for training, and the losses are averaged. Figure 1 illustrates the idea. The upper half shows the training-time augmentations, Attention Cropping and Attention Dropping; the lower half shows the test-time augmentation, attention-based cropping followed by resizing back to the original input size before prediction.

Figure 1: Attention-guided data augmentation

Note that the usual augmentations applied during training are random dropping (occlusion) and random cropping. Such random operations have no sense of purpose and easily introduce noise: a crop may land entirely on background (no augmentation effect at all) or remove almost the whole object (which hurts convergence). The authors instead generate candidate regions from attention maps and crop or drop them purposefully, a masterstroke; Figure 2 contrasts random augmentation with attention-guided augmentation. Intuitively, the idea is this: an attention map highlights discriminative details of the subject, such as a bird's beak. If attention dropping erases the beak, the model is pushed to attend to other cues, such as the belly or feather color, thereby completing the augmentation. This is the essence of the paper, and it yields a clear accuracy gain, as the ablation experiments later show.

Figure 2: Random cropping vs. the attention-based approach
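The crop/drop idea can be sketched in a few lines. The helpers below are a toy illustration only (not the project's `batch_augment`; the function names, shapes, and the `theta` threshold are made up for the demo): the attention map is thresholded at a fraction of its maximum, then either the bounding box of the hot region is cropped, or that region is zeroed out for dropping.

```python
import numpy as np

def attention_crop_box(attn, theta=0.5):
    """Bounding box of the region where attention exceeds theta * max."""
    mask = attn >= theta * attn.max()
    ys, xs = np.where(mask)
    return ys.min(), ys.max() + 1, xs.min(), xs.max() + 1

def attention_drop(image, attn, theta=0.5):
    """Zero out the most-attended region so other parts must be learned."""
    mask = attn >= theta * attn.max()
    out = image.copy()
    out[mask] = 0
    return out

# Toy attention map: pretend the "beak" lights up at rows/cols 2-3.
attn = np.zeros((8, 8))
attn[2:4, 2:4] = 1.0
y0, y1, x0, x1 = attention_crop_box(attn)   # box around the hot region
dropped = attention_drop(np.ones((8, 8)), attn)  # hot region erased
```

In the real model the crop is then resized back to the network input size, and the drop is applied on a different attention channel than the crop.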

The backbone is an InceptionV3 network. Its mix6e layer supplies the feature maps, from which attention maps are generated to drive the augmentation; the attention maps and feature maps are then combined by bilinear attention pooling (BAP) into the final feature matrix, which is flattened and fed to a fully connected layer for classification. Figure 3 shows the training-time architecture. The attention-guided augmentation makes the network more robust, echoing the "See Better" in the title.

Figure 3: Network architecture during training
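As a rough sketch of what BAP computes (shapes and names here are illustrative; the project's models/bap.py is the reference implementation): each of the M attention maps re-weights the N feature channels, spatial pooling yields one part feature per attention map, and the M part features are concatenated into the feature matrix that goes to the classifier.

```python
import numpy as np

def bilinear_attention_pooling(features, attentions):
    """features: (B, N, H, W); attentions: (B, M, H, W) -> (B, M*N)."""
    B, N, H, W = features.shape
    M = attentions.shape[1]
    f = features.reshape(B, N, H * W)
    a = attentions.reshape(B, M, H * W)
    # Each attention map pools the feature maps: (B, M, N) part features.
    parts = np.matmul(a, f.transpose(0, 2, 1)) / (H * W)
    return parts.reshape(B, M * N)  # flattened feature matrix

feats = np.random.rand(2, 8, 6, 6).astype('float32')
attns = np.random.rand(2, 4, 6, 6).astype('float32')
fm = bilinear_attention_pooling(feats, attns)  # shape (2, 32)
```

The batched matmul above is exactly the einsum contraction 'imjk,injk->imn' followed by average pooling and flattening.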

The test-time architecture is largely the same, minus the dropping augmentation used during training: at test time we obviously want the input to carry as much information as possible, so nothing is dropped. The network therefore receives the original image together with the image cropped via attention-based object localization and resized, and the two predicted probabilities are averaged. This step, called Refinement, echoes the "Looking Closer" in the title.

Figure 4: Network architecture during testing

Those are the core ideas of the paper. This project reproduces it with PaddlePaddle 2.2.2.

Paper: See Better Before Looking Closer: Weakly Supervised Data Augmentation Network for Fine-Grained Visual Classification

Reference repo: GitHub - wvinzh/WS_DAN_PyTorch: PyTorch Implementation Of WS-DAN (See Better Before Looking Closer: Weakly Supervised Data Augmentation Network for Fine-Grained Visual Classification)

2. Reproduced Accuracy and Datasets

The datasets used in the paper are classic fine-grained benchmarks covering birds, aircraft, cars, and dogs; each contains images of subcategories within one broad category. This project reproduces results on the first three and matches the paper's accuracy. The results are listed below. The datasets are bundled with this project, so no separate download is needed.

| Dataset | Object | Category | Training | Testing | ACC (reproduced) | ACC (paper) |
| --- | --- | --- | --- | --- | --- | --- |
| CUB-200-2011 | Bird | 200 | 5994 | 5794 | 89.40 | 89.4 |
| fgvc-aircraft | Aircraft | 100 | 6667 | 3333 | 94.03 | 93.0 |
| Stanford-Cars | Car | 196 | 8144 | 8041 | 94.88 | 94.5 |
| Stanford-Dogs | Dogs | 120 | 12000 | 8580 | (not required) | 92.2 |

After extraction, the dataset folder under /home/aistudio/work/data is structured as follows:

Fine-grained
├── CUB_200_2011
    ├── images
    ├── images.txt
    ├── image_class_labels.txt
    ├── train_test_split.txt
├── Car
    ├── cars_test
    ├── cars_train
    ├── cars_test_annos_withlabels.mat
    ├── devkit
        ├── cars_train_annos.mat
├── fgvc-aircraft-2013b
    ├── data
        ├── variants.txt
        ├── images_variant_trainval.txt
        ├── images_variant_test.txt

3. Code Structure

To keep the workflow easy to follow, the main code lives in this Jupyter notebook (see Part 5). The remaining code under /home/aistudio/work is organized as follows:

/home/aistudio/work
├── datasets   # dataset definitions and loading
    ├── __init__.py  # dataset loading entry point
    ├── aircraft_dataset.py  # aircraft dataset definition
    ├── bird_dataset.py      # bird dataset definition
    ├── car_dataset.py       # car dataset definition
├── models  # model files
    ├── bap.py        # BAP module
    ├── inception.py  # InceptionV3 model
    ├── wsdan.py      # WS-DAN model
    ├── InceptionV3_pretrained.pdparams  # InceptionV3 pretrained weights
├── FGVC  # saved model parameters and training logs
    ├── aircraft/ckpt  # aircraft model parameters and training logs
        ├── *.pdparams # model weights
        ├── *.log      # training logs
    ├── bird/ckpt   # bird model parameters and training logs
        ├── *.pdparams # model weights
        ├── *.log      # training logs
    ├── car/ckpt    # car model parameters and training logs
        ├── *.pdparams # model weights
        ├── *.log      # training logs
├── imgs         # image assets used in this Markdown
├── config.py    # hyperparameter settings (editable)
├── train.py     # model training
└── utils.py     # utilities

4. Environment and Dependencies

  • Hardware:

    • 4-core x86 CPU
    • 1 × NVIDIA Tesla V100 GPU
  • Framework:

    • PaddlePaddle 2.2.2 + Python 3.7
  • Other dependencies:

    • numpy==1.19.2
    • tqdm==4.59.0
    • Pillow==8.3.1

5. Quick Start

  • First, run the content under Firstly to unpack the datasets.
  • The project ships trained models, so you can run the model-testing section directly (note: to switch the dataset being tested, just change the config.target_dataset variable in Step0 below).
  • To retrain, run Step1 through Step4 (note: to switch the training dataset, just change target_dataset in work/config.py).

Dataset extraction: run this command first to unpack the mounted dataset into the data folder


!cd data/ && unzip -oq /home/aistudio/data/data138113/Fine-grained.zip

Model testing: just change the name of the dataset you want to test (note: restart the notebook after each change, otherwise the logs are written to the same file)


import sys
sys.path.append('/home/aistudio/work')
import os
import logging
import config
import paddle
from datasets import getDataset
from paddle.io import DataLoader
from models.wsdan import WSDAN
from utils import CenterLoss, AverageMeter, TopKAccuracyMetric, batch_augment

# choose the dataset to test
config.target_dataset = 'bird'  # it can be 'car', 'bird', 'aircraft'

# logging config
logging.basicConfig(
    filename=os.path.join('/home/aistudio/work/FGVC/' + config.target_dataset + '/ckpt/', 'test.log'),
    filemode='w',
    format='%(asctime)s: %(levelname)s: [%(filename)s:%(lineno)d]: %(message)s',
    level=logging.INFO)
logging.info('Current Testing Model: {}'.format(config.target_dataset))

# read the dataset
train_dataset, val_dataset = getDataset(config.target_dataset, config.input_size)
train_loader, val_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=config.workers), DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=config.workers)

# output the dataset info
logging.info('Dataset Name:{dataset_name}, Train:[{train_num}], Val:[{val_num}]'.format(dataset_name=config.target_dataset, train_num=len(train_dataset), val_num=len(val_dataset)))
logging.info('Batch Size:[{0}], Train Batches:[{1}], Val Batches:[{2}]'.format(config.batch_size, len(train_loader), len(val_loader)))

# loss and metric
loss_container = AverageMeter(name='loss')
raw_metric = TopKAccuracyMetric(topk=(1, 5))
crop_metric = TopKAccuracyMetric(topk=(1, 5))
drop_metric = TopKAccuracyMetric(topk=(1, 5))
num_classes = train_dataset.num_classes

# network
net = WSDAN(num_classes=num_classes, num_attentions=config.num_attentions, net_name=config.net_name, pretrained=False)
feature_center = paddle.zeros(shape=[num_classes, config.num_attentions * net.num_features])
if config.target_dataset == 'bird':
    net_state_dict = paddle.load("work/FGVC/bird/ckpt/bird_model.pdparams")
if config.target_dataset == 'aircraft':
    net_state_dict = paddle.load("work/FGVC/aircraft/ckpt/aircraft_model.pdparams")
if config.target_dataset == 'car':
    net_state_dict = paddle.load("work/FGVC/car/ckpt/car_model.pdparams")
net.set_state_dict(net_state_dict)
net.eval()

# loss function
cross_entropy_loss = paddle.nn.CrossEntropyLoss()
center_loss = CenterLoss()

logs = {}
with paddle.no_grad():  # inference only; no gradients needed
    for i, (X, y) in enumerate(val_loader):
        # Raw image
        y_pred_raw, _, attention_map = net(X)

        # Object localization and refinement
        crop_images = batch_augment(X, attention_map, mode='crop', theta=0.1, padding_ratio=0.05)
        y_pred_crop, _, _ = net(crop_images)

        # Final prediction: average of raw and refined predictions
        y_pred = (y_pred_raw + y_pred_crop) / 2.

        # loss
        batch_loss = cross_entropy_loss(y_pred, y)
        epoch_loss = loss_container(batch_loss.item())

        # metrics: top-1 / top-5 accuracy
        epoch_acc = raw_metric(y_pred, y)

logs['val_{}'.format(loss_container.name)] = epoch_loss
logs['val_{}'.format(raw_metric.name)] = epoch_acc
batch_info = 'Val Loss {:.4f}, Val Acc ({:.2f}, {:.2f})'.format(epoch_loss, epoch_acc[0], epoch_acc[1])
logging.info(batch_info)
print(batch_info)
W0524 21:08:29.075071  6888 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1
W0524 21:08:29.079933  6888 device_context.cc:465] device: 0, cuDNN Version: 7.6.
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/tensor/creation.py:130: DeprecationWarning: `np.object` is a deprecated alias for the builtin `object`. To silence this warning, use `object` by itself. Doing this will not modify any behavior and is safe. 
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
Val Loss 0.4921, Val Acc (88.83, 97.57)

Training Step1: import the required packages and configure logging


import sys
sys.path.append('/home/aistudio/work')
import os
import time
import logging
from tqdm import tqdm
import config
import paddle
from datasets import getDataset
from paddle.io import DataLoader
from models.wsdan import WSDAN
from utils import CenterLoss, AverageMeter, TopKAccuracyMetric, batch_augment
import paddle.nn.functional as F
import datetime

# create the log directory if it does not exist
if not os.path.exists(config.save_dir):
    os.makedirs(config.save_dir)

# log file name and format (zero-padded so e.g. November is '11', not '011')
current = datetime.datetime.now()
log_name = '{}-{:02d}-{:02d}-{:02d}_{:02d}.log'.format(config.target_dataset, current.month, current.day, current.hour, current.minute)
logging.basicConfig(filename=os.path.join(config.save_dir, log_name), filemode='w', format='%(asctime)s: %(levelname)s: [%(filename)s:%(lineno)d]: %(message)s', level=logging.INFO)

# record which model is being trained
logging.info('Current Training Model: {}'.format(config.target_dataset))

Training Step2: load the dataset


# load the dataset
train_dataset, val_dataset = getDataset(config.target_dataset, config.input_size)
train_loader, val_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=config.workers), DataLoader(val_dataset, batch_size=config.batch_size * 4, shuffle=False, num_workers=config.workers)
num_classes = train_dataset.num_classes

# log the current dataset info
logging.info('Dataset Name:{dataset_name}, Train:[{train_num}], Val:[{val_num}]'.format(dataset_name=config.target_dataset, train_num=len(train_dataset), val_num=len(val_dataset)))
logging.info('Batch Size:[{0}], Train Batches:[{1}], Val Batches:[{2}]'.format(config.batch_size, len(train_loader), len(val_loader)))

Training Step3: define the model and loss functions


# loss and metric
loss_container = AverageMeter(name='loss')
raw_metric = TopKAccuracyMetric(topk=(1, 5))
crop_metric = TopKAccuracyMetric(topk=(1, 5))
drop_metric = TopKAccuracyMetric(topk=(1, 5))

logs = {}

if config.ckpt:
    pretrained = False
else:
    pretrained = True
net = WSDAN(num_classes=num_classes, num_attentions=config.num_attentions, net_name=config.net_name, pretrained=pretrained)
feature_center = paddle.zeros(shape=[num_classes, config.num_attentions * net.num_features])

# optimizer definition
scheduler = paddle.optimizer.lr.StepDecay(learning_rate=config.learning_rate, step_size=2, gamma=0.9)
optimizer = paddle.optimizer.Momentum(learning_rate=scheduler, momentum=0.9, weight_decay=1e-5, parameters=net.parameters())

# resume: load saved model and optimizer parameters
if config.ckpt:
    net_state_dict = paddle.load(config.save_dir + config.target_dataset + "_model.pdparams")
    optim_state_dict = paddle.load(config.save_dir + config.target_dataset + "_model.pdopt")
    net.set_state_dict(net_state_dict)
    optimizer.set_state_dict(optim_state_dict)

# loss functions
cross_entropy_loss = paddle.nn.CrossEntropyLoss()
center_loss = CenterLoss()

Training Step4: start training


if config.ckpt:
    start_epoch = config.model_num
else:
    start_epoch = 0

max_val_acc = 0  # best validation accuracy so far

# train for config.epochs epochs
for epoch in range(start_epoch, start_epoch + config.epochs):
    logs['epoch'] = epoch + 1
    logs['lr'] = optimizer.get_lr()
    logging.info('Epoch {:03d}, lr= {:g}'.format(epoch + 1, optimizer.get_lr()))
    print("Start epoch %d ==========,lr=%f" % (epoch + 1, optimizer.get_lr()))
    pbar = tqdm(total=len(train_loader), unit=' batches')
    pbar.set_description('Epoch {}/{}'.format(epoch + 1, config.epochs))

    # reset the metrics
    loss_container.reset()
    raw_metric.reset()
    crop_metric.reset()
    drop_metric.reset()

    # start training
    start_time = time.time()
    net.train()
    scheduler.step()
    for i, (X, y) in enumerate(train_loader):
        optimizer.clear_grad()
        y_pred_raw, feature_matrix, attention_map = net(X)

        # Update Feature Center
        feature_center_batch = F.normalize(feature_center[y], axis=-1)
        feature_center[y] += config.beta * (feature_matrix.detach() - feature_center_batch)

        # Attention Cropping
        with paddle.no_grad():
            crop_images = batch_augment(X, attention_map[:, :1, :, :], mode='crop', theta=(0.4, 0.6), padding_ratio=0.1)

        # crop images forward
        y_pred_crop, _, _ = net(crop_images)

        # Attention Dropping
        with paddle.no_grad():
            drop_images = batch_augment(X, attention_map[:, 1:, :, :], mode='drop', theta=(0.2, 0.5))

        # drop images forward
        y_pred_drop, _, _ = net(drop_images)

        # loss
        batch_loss = cross_entropy_loss(y_pred_raw, y) / 3. + \
                     cross_entropy_loss(y_pred_crop, y) / 3. + \
                     cross_entropy_loss(y_pred_drop, y) / 3. + \
                     center_loss(feature_matrix, feature_center_batch)

        # backward
        batch_loss.backward()
        optimizer.step()

        with paddle.no_grad():
            epoch_loss = loss_container(batch_loss.item())
            epoch_raw_acc = raw_metric(y_pred_raw, y)
            epoch_crop_acc = crop_metric(y_pred_crop, y)
            epoch_drop_acc = drop_metric(y_pred_drop, y)

        batch_info = 'Loss {:.4f}, Raw Acc ({:.2f}, {:.2f}), Crop Acc ({:.2f}, {:.2f}), Drop Acc ({:.2f}, {:.2f})'.format(
            epoch_loss, epoch_raw_acc[0], epoch_raw_acc[1],
            epoch_crop_acc[0], epoch_crop_acc[1], epoch_drop_acc[0], epoch_drop_acc[1])
        pbar.update()
        pbar.set_postfix_str(batch_info)

    # end of this epoch
    logs['train_{}'.format(loss_container.name)] = epoch_loss
    logs['train_raw_{}'.format(raw_metric.name)] = epoch_raw_acc
    logs['train_crop_{}'.format(crop_metric.name)] = epoch_crop_acc
    logs['train_drop_{}'.format(drop_metric.name)] = epoch_drop_acc
    logs['train_info'] = batch_info
    end_time = time.time()

    # write log for this epoch
    logging.info('Train: {}, Time {:3.2f}'.format(batch_info, end_time - start_time))

    # validate once after each training epoch
    net.eval()
    loss_container.reset()
    raw_metric.reset()
    start_time = time.time()
    with paddle.no_grad():
        for i, (X, y) in enumerate(val_loader):

            # Raw Image
            y_pred_raw, _, attention_map = net(X)

            # Object Localization and Refinement
            crop_images = batch_augment(X, attention_map, mode='crop', theta=0.1, padding_ratio=0.05)
            y_pred_crop, _, _ = net(crop_images)

            # Final prediction
            y_pred = (y_pred_raw + y_pred_crop) / 2.

            # loss
            batch_loss = cross_entropy_loss(y_pred, y)
            epoch_loss = loss_container(batch_loss.item())

            # metrics: top-1,5 error
            epoch_acc = raw_metric(y_pred, y)

    logs['val_{}'.format(loss_container.name)] = epoch_loss
    logs['val_{}'.format(raw_metric.name)] = epoch_acc
    end_time = time.time()
    batch_info = 'Val Loss {:.4f}, Val Acc ({:.2f}, {:.2f})'.format(epoch_loss, epoch_acc[0], epoch_acc[1])
    pbar.set_postfix_str('{}, {}'.format(logs['train_info'], batch_info))

    # write log for this epoch
    logging.info('Valid: {}, Time {:3.2f}'.format(batch_info, end_time - start_time))
    logging.info('')
    net.train()
    pbar.close()

    # save the model with the best validation accuracy
    if(epoch_acc[0] > max_val_acc):
        max_val_acc = epoch_acc[0]
        paddle.save(net.state_dict(), config.save_dir + config.target_dataset + "_model.pdparams")
        paddle.save(optimizer.state_dict(), config.save_dir + config.target_dataset + "_model.pdopt")

6. Reproduction Workflow

These are the steps I followed during the reproduction (for reference only):

  • Read the paper (a first pass; no need to dig into details yet)
  • Study the structure of the authors' reference code
  • Read the paper closely until the mechanism of every module is clear
  • Port to the new framework step by step, verifying alignment with the original after each step; overall it breaks down into four steps:

  • Accuracy below target? Re-examine the model structure, parameter initialization, hyperparameters, and data preprocessing, and revisit the paper's details if necessary.

7. Ablation Experiments

The heart of the paper is its data augmentation. On the bird dataset, the InceptionV3 baseline without augmentation reaches 86.4% accuracy. Adding random cropping or dropping helps only slightly, a 0.2-0.3% gain, whereas attention-guided cropping or dropping each raises accuracy by more than 1%. Combining attention cropping with attention dropping reaches 88.4%, a 2-point gain over the baseline. Every module contributes visibly.

| Data Augmentation | Accuracy (%) |
| --- | --- |
| Baseline | 86.4 |
| Random Cropping | 86.8 |
| Attention Cropping | 87.8 |
| Random Dropping | 86.7 |
| Attention Dropping | 87.4 |
| Attention Cropping + Attention Dropping | 88.4 |

8. Reproduction Notes

This reproduction contest entry covers a fine-grained classification paper whose principle is easy to grasp and whose weakly supervised idea is elegant and inspiring. The main work was porting the original PyTorch code to PaddlePaddle. During porting you can compare the PyTorch and PaddlePaddle API documentation side by side, or consult the official PyTorch 1.8 to Paddle 2.0 API mapping table for quick lookup.

In addition, the reproduction contest requires dynamic-to-static graph conversion, and some APIs do not support static graphs, so they must be replaced with others. For example, paddle.einsum does not support static graphs and can be rewritten with paddle.matmul and paddle.transpose.
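As an illustration of such a rewrite, the bilinear pooling contraction can be expressed with matmul over reshaped tensors. The sketch below uses NumPy to verify the equivalence (the shapes are made up for the demo; in the project the same rewrite would use paddle.matmul and paddle.transpose in place of the einsum call):

```python
import numpy as np

# Hypothetical shapes: batch B, M attention maps, N feature channels, H*W spatial.
B, M, N, H, W = 2, 4, 8, 6, 6
attentions = np.random.rand(B, M, H, W).astype('float32')
features = np.random.rand(B, N, H, W).astype('float32')

# einsum form (what paddle.einsum would compute, shown here via np.einsum):
fm_einsum = np.einsum('imjk,injk->imn', attentions, features)

# static-graph-friendly rewrite using only reshape/transpose/matmul:
a = attentions.reshape(B, M, H * W)              # (B, M, HW)
f = features.reshape(B, N, H * W)                # (B, N, HW)
fm_matmul = np.matmul(a, f.transpose(0, 2, 1))   # (B, M, N)

assert np.allclose(fm_einsum, fm_matmul, atol=1e-4)
```

Since the rewrite is a plain batched matrix multiply, it exports cleanly to a static graph and, in my experience, runs at least as fast as the einsum form.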

Finally, the paper's data augmentation idea transfers well to other classification problems and is instructive in its own right.

9. Model Information

After the training above completes, the model and its logs are saved under work/FGVC/<dataset name>/ckpt for inspection.

| Item | Description |
| --- | --- |
| Publisher | Victory8858 |
| Date | 2022.05 |
| Framework version | Paddle 2.2.2 |
| Application | Fine-grained classification |
| Supported hardware | GPU, CPU |
| AI Studio link | Paper Reproduction: WS-DAN, a Classic for Fine-Grained Classification - PaddlePaddle AI Studio |
