★★★ 本文源自AlStudio社区精品项目,【点击此处】查看更多精品内容 >>>

一、论文概述

小样本学习是一个具有挑战性的问题,因为每个新任务只提供了非常少的训练样本。解决这个挑战的有效研究方法之一,是通过学习由查询(query)样本和少量支持(support)样本之间的相似性度量驱,从而获得切实可行的深度特征表示。从统计上讲,这相当于衡量图像间特征的依赖性。先前的方法要么只使用边缘分布而不考虑联合分布,导致表示能力有限,要么通过利用联合分布虽然能得到更好的结果,但也引入了很高的计算成本。在本文中,作者提出了一种用于小样本分类的深度布朗距离协方差(DeepBDC)方法。DeepBDC的核心思想是通过衡量嵌入特征的联合特征函数与边缘乘积之间的差异来学习图像表示。由于BDC度量是解耦的,作者将其制定为一个高度模块化的算法。此外,作者在两个不同的小样本分类框架中实例化了DeepBDC。在六个标准的小样本图像基准测试上进行了实验,涵盖了一般对象识别、细粒度分类和跨域分类。实验评估结果显示,DeepBDC在性能上明显优于其他方法,并取得了新的最先进结果。

作者基于传统小样本学习中的元学习框架,和迁移学习框架,分别构建了下图所示的两种学习结构。本项目也是基于作者在github上开源的代码进行改写,采用飞桨深度学习框架重新实现的。训练收敛后,测试指标能够与论文中的指标平齐。

原论文地址:https://arxiv.org/pdf/2204.04567.pdf

作者开源代码地址:https://github.com/Fei-Long121/DeepBDC

二、复现精度

基于paddlepaddle深度学习框架,对文献算法进行复现后,本项目在mini-ImageNet上达到的测试精度,如下表所示。

task 5-Way 1-Shot 5-Way 5-Shot
原始论文精度 STL DeepBDC 67.83±0.43 85.45±0.29
本项目精度 STL DeepBDC 66.68±0.44 84.35±0.29

模型分别基于元学习方式(Meta)和迁移学习方式(STL)分别进行了代码实现。

1、Meta DeepBDC

Meta DeepBDC模型训练包括了两个过程,首先是模型预训练,按照典型分类网络的训练过程,将整个训练集送入backbone进行训练;然后是微调过程,按照episode training的训练范式,配置为5-Way 1-Shot和5-Shot方式进行微调训练。这两个训练过程的训练超参数设置如下:

(1)预训练过程

超参数名 设置值
lr 5e-2
t_lr 1e-3
gamma 0.1
epoch 170
milestones 100 150
batch_size 512
reduce_dim 640
dropout_rate 0.8

(2)5-way 1-shot微调训练过程

超参数名 设置值
lr 1e-4
gamma 0.1
epoch 100
train_n_episode 1000
milestones 40 80
train_n_way 5
n_shot 1
reduce_dim 640

(3)5-way 5-shot微调训练过程

超参数名 设置值
lr 1e-4
gamma 0.1
epoch 100
train_n_episode 600
milestones 40 80
train_n_way 5
n_shot 5
reduce_dim 640

2、STL DeepBDC

STL DeepBDC模型训练包括了两个过程,首先是模型预训练,按照典型分类网络的训练过程,将整个训练集送入backbone进行训练;然后是3次自蒸馏训练过程。这两个训练过程的训练超参数设置如下:

(1)预训练过程

超参数名 设置值
lr 5e-2
t_lr 1e-3
gamma 0.1
epoch 170
milestones 100 150
batch_size 512
reduce_dim 128
dropout_rate 0.5

(2)自蒸馏过程

超参数名 设置值
lr 5e-2
t_lr 1e-3
gamma 0.1
epoch 170
milestones 100 150
batch_size 512
reduce_dim 128
dropout_rate 0.5

三、数据集

miniImageNet数据集节选自ImageNet数据集。
DeepMind团队首次将miniImageNet数据集用于小样本学习研究,从此miniImageNet成为了元学习和小样本领域的基准数据集。
关于该数据集的介绍可以参考https://blog.csdn.net/wangkaidehao/article/details/105531837

miniImageNet是由Oriol Vinyals等在Matching Networks
中首次提出的,该文献是小样本分类任务的开山制作,也是本次复现论文关于该数据集的参考文献。在Matching Networks中,
作者提出对ImageNet中的类别和样本进行抽取(参见其Appendix B),形成了一个数据子集,将其命名为miniImageNet。
划分方法,作者仅给出了一个文本文件进行说明。
Vinyals在文中指明了miniImageNet图片尺寸为84x84。因此,后续小样本领域的研究者,均是基于原始图像,在代码中进行预处理,
将图像缩放到84x84的规格。

至于如何缩放到84x84,本领域研究者各有各的方法,通常与研究者的个人理解相关,但一般对实验结果影响不大。本次文献论文原文,未能给出
miniImageNet的具体实现方法,本项目即参考领域内较为通用的预处理方法进行处理。

  • 数据集大小:
    • miniImageNet包含100类共60000张彩色图片,其中每类有600个样本。
      mini-imagenet一共有2.86GB
  • 数据格式:
|- miniImagenet
|  |- images/
|  |  |- n0153282900000005.jpg 
|  |  |- n0153282900000006.jpg
|  |  |- …
|  |- train.csv
|  |- test.csv
|  |- val.csv

数据集链接:miniImagenet

四、环境依赖

  • 硬件:
    • x86 cpu
    • NVIDIA GPU
  • 框架:
    • PaddlePaddle = 2.4
  • 其他依赖项:
    • numpy==1.19.3
    • tqdm==4.59.0
    • Pillow==8.3.1
    • sklearn1.0.1

五、快速开始

1、解压数据集和源代码:

!unzip -n -d ./data/ ./data/data105646/mini-imagenet-sxc.zip

%cd /home/aistudio/
!unzip -n -d ./data/ ./data/data105646/mini-imagenet-sxc.zip
%cd /home/aistudio
!unzip -o DeepBDC-baseline-paddle.zip
# 生成json文件
!cp write_miniImagenet_filelist.py /home/aistudio/data/mini-imagenet-sxc/
%cd /home/aistudio/data/mini-imagenet-sxc/
!python write_miniImagenet_filelist.py

3、执行以下命令启动元学习训练和评估:

cd /home/aistudio/DeepBDC-baseline-paddle
bash ./scripts/mini_imagenet/run_meta_deepbdc/run.sh

模型开始训练,运行完毕后,训练log和模型参数保存在./checkpoints/mini_imagenet/ResNet12_meta_deepbdc_5way_1shot_metatrain/和./checkpoints/mini_imagenet/ResNet12_meta_deepbdc_5way_5shot_metatrain/目录下,分别是:

best_model.pdparams  # 最优模型参数文件
output.log  # 训练LOG信息

训练完成后,可将上述文件手动保存到其他目录下,避免被后续训练操作覆盖。评估过程将在训练完成后自动进行,评估结果将打印在输出栏。

注意:如有报错,可能是由于run.sh文本文件在windows下创建,换行符为\r\t,导致linux下不能正确识别。可自行进入.sh文件,复制出其中的命令运行即可。
%cd /home/aistudio/DeepBDC-baseline-paddle
!bash ./scripts/mini_imagenet/run_meta_deepbdc/run.sh

4、执行以下命令启动迁移学习训练和评估:

cd /home/aistudio/DeepBDC-baseline-paddle
bash ./scripts/mini_imagenet/run_stl_deepbdc/run.sh

模型开始训练,运行完毕后,预训练和自蒸馏的log和模型参数保存在以下路径

./checkpoints/mini_imagenet/ResNet12_stl_deepbdc_pretrain
./checkpoints/mini_imagenet/ResNet12_stl_deepbdc_distill_born1
./checkpoints/mini_imagenet/ResNet12_stl_deepbdc_distill_born1
./checkpoints/mini_imagenet/ResNet12_stl_deepbdc_distill_born1

模型和LOG信息为这些路径下的相应文件:

last_model.pdparams  # 最优模型参数文件
output.log  # 训练LOG信息

训练完成后,可将上述文件手动保存到其他目录下,避免被后续训练操作覆盖。评估过程将在训练完成后自动进行,评估结果将打印在输出栏。

注意:如有报错,可能是由于run.sh文本文件在windows下创建,换行符为\r\t,导致linux下不能正确识别。可自行进入.sh文件,复制出其中的命令运行即可。
!cd /home/aistudio/DeepBDC-baseline-paddle
!bash ./scripts/mini_imagenet/run_stl_deepbdc/run.sh

六、代码结构与详细说明

6.1 代码结构

├── data                               # 数据处理相关
│   ├── datamgr.py                       # data manager模块
│   ├── dataset.py                       # data set模块
├── methods                             # 模型相关
│   ├── bdc_module.py                     # BDC核心模块
│   ├── meta_deepbdc.py                    # 元训练算法
│   ├── stl_deepbdc.py                    # 简单迁移学习训练算法
│   ├── template.py                      # 训练模板
├── network                             # backbone
│   ├── conv.py                         # Conv-4和Conv-6代码实现
│   ├── resnet.py                        # ResNet-12代码实现
├── scripts                             # 运行工程脚本
│   ├── mini_imagenet                     
│   │   ├── run_meta_deepbdc                    
│   │   │   ├── run.sh                 # 运行微调训练
│   │   ├── run_stl_deepbdc                    
│   │   │   ├── run.sh                 # 运行微调训练
├── distillation.py                        # 自蒸馏训练代码
├── meta_train.py                         # 微调训练代码
├── pretrain.py                          # 预训练代码
├── test.py                             # 测试代码
├── utils.py                            # 公共调用函数

6.2 参数说明

可以在 pretrain.pydistillation.pymeta_train.py 中设置训练与评估相关参数,具体如下:

参数 默认值 说明
----batch_size 256 batch size
–lr 0.05 初始学习率
–wd 5e-4 weight decay超参
–gamma 0.1 lr_scheduler衰减系数
–milestones 80, 120 达到相应epoch后,lr_scheduler开始衰减
–epoch 150 遍历数据集的迭代轮数
–gpu True 是否使用GPU进行训练
–dataset mini_imagenet 指定训练数据集
–data_path ‘’ 指定数据集的路径
–model ResNet-12 指定采用的backbone
–val meta 指定验证方式
–train_n_way 20 小样本训练类别数
–val_n_episode 600 验证时测试多少个episode
–val_n_way 5 小样本验证类别数
–n_shot 1 给定支持样本的个数
–n_query 15 指定查询样本的个数
–num_classes 64 指定base set类别总数
–save_freq 50 指定每隔多少个epoch保存一次模型参数
–seed 0 指定随机数种子
–resume ‘’ 指定恢复训练时加载的中间参数文件路径
–reduce_dim ‘’ 指定BDC模块特征降维的维度
–dropout_rate ‘’ 指定BDC模块的dropout概率

6.3 训练流程

可参考快速开始章节中的描述

训练输出

执行训练开始后,将得到类似如下的输出。每一轮epoch训练将会打印当前training loss、training acc、val loss、val acc以及训练kl散度。

Epoch 0 | Batch 0/150 | Loss 4.158544
best model! save...
val loss is 0.00, val acc is 37.46
model best acc is 37.46, best acc epoch is 0
This epoch use 7.61 minutes
train loss is 3.72, train acc is 10.84
Epoch 1 | Batch 0/150 | Loss 3.052964
val loss is 0.00, val acc is 37.46
model best acc is 37.46, best acc epoch is 0
This epoch use 3.73 minutes
train loss is 2.96, train acc is 25.28
Epoch 2 | Batch 0/150 | Loss 2.588413
val loss is 0.00, val acc is 37.46
model best acc is 37.46, best acc epoch is 0
This epoch use 3.71 minutes
train loss is 2.59, train acc is 33.27
...

6.4 测试流程

可参考快速开始章节中的描述,也可直接执行以下命令:

echo "============= meta-test 1-shot ============="
python test.py --dataset mini_imagenet --data_path /home/aistudio/data/mini-imagenet-sxc --model ResNet12 --method stl_deepbdc --image_size 84 --gpu --n_shot 1 --model_path ./checkpoints/mini_imagenet/ResNet12_stl_deepbdc_distill_born3/last_model.pdparams --test_task_nums 5 --penalty_C 0.1 --reduce_dim 128 --test_n_episode 2000

echo "============= meta-test 5-shot ============="
python test.py --dataset mini_imagenet --data_path /home/aistudio/data/mini-imagenet-sxc --model ResNet12 --method stl_deepbdc --image_size 84 --gpu --n_shot 5 --model_path ./checkpoints/mini_imagenet/ResNet12_stl_deepbdc_distill_born3/last_model.pdparams --test_task_nums 5 --penalty_C 2 --reduce_dim 128 --test_n_episode 2000

七、模型信息

训练完成后,模型和相关LOG保存在./results/5w1s和./results/5w5s目录下。

训练和测试日志保存在results目录下。

信息 说明
发布者 hrdwsong
时间 2023.06
框架版本 Paddle 2.4
应用场景 小样本学习
支持硬件 GPU、CPU
Aistudio地址 https://aistudio.baidu.com/aistudio/projectdetail/6402612?sUid=527829&shared=1&ts=1687140394955

请点击此处查看本环境基本用法.

Please click here for more detailed instructions.

此文章为搬运
原项目链接

Logo

学大模型,用大模型上飞桨星河社区!每天8点V100G算力免费领!免费领取ERNIE 4.0 100w Token >>>

更多推荐