This project covers the following:

  • Project introduction

  • Model background and description

  • Reproduction results

  • About this project

1. Introduction

This project is a reproduction of the paper *Matching Networks for One Shot Learning*, carried out as part of the 飞桨启航菁英计划 (a PaddlePaddle research program).

Dependencies:

  • paddlepaddle-gpu 2.1.2

  • python 3.7

The code is trained on the Omniglot dataset with k-way = 5, n-shot = 1, reaching an accuracy of 98.1.

2. Model Background and Description

Reference paper: *Matching Networks for One Shot Learning*

Source code (PyTorch implementation): few-shot

Paper summary:

Humans can learn a new concept from very few examples, yet the best deep-learning models still need hundreds or thousands of examples to learn one. This paper therefore asks how a deep-learning model can learn a new concept from a single example.

Training a model traditionally requires many samples and many rounds of parameter updates, so the authors argue for a non-parametric component. Borrowing from metric-based methods such as kNN, they combine a parametric model with a non-parametric one.

At the model level, the authors propose Matching Networks, which bring attention and memory mechanisms into the rapid-learning task. For the training procedure, they follow one simple rule: test and training conditions must match. They train with only a few examples per class, because only a few examples per class will be available at test time; a sketch of such episodic sampling follows.
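To make the "train as you test" rule concrete, here is a hypothetical sketch of how a k-way n-shot episode could be sampled; the function name and the `class_to_images` mapping are my assumptions, not the repo's code:

```python
import random

def sample_episode(class_to_images, k_way=5, n_shot=1, q_queries=15):
    """Hypothetical sketch: build one k-way n-shot episode so that the
    training condition matches the test condition."""
    classes = random.sample(list(class_to_images), k_way)
    support, queries = [], []
    for label, cls in enumerate(classes):
        images = random.sample(class_to_images[cls], n_shot + q_queries)
        support += [(img, label) for img in images[:n_shot]]   # n_shot per class
        queries += [(img, label) for img in images[n_shot:]]   # q_queries per class
    return support, queries
```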

The model architecture is shown in Figure 1 of the paper.

gθ and fθ are the encoding functions for the training (support) data and the test (query) data, respectively. A Matching Network can be summarised as computing the probability that an unlabelled sample x̂ has label ŷ, much like a weighted kNN:

$$P(\hat{y} \mid \hat{x}, S)=\sum_{i=1}^{k} a\left(\hat{x}, x_{i}\right) y_{i}$$

where xi, yi are the samples of the input support set S, and a is a kernel similar to the one used in attention mechanisms, measuring the degree of match:

$$a\left(\hat{x}, x_{i}\right)=\frac{e^{c\left(f(\hat{x}),\, g\left(x_{i}\right)\right)}}{\sum_{j=1}^{k} e^{c\left(f(\hat{x}),\, g\left(x_{j}\right)\right)}}$$

Here f defines the encoding of the test sample, corresponding to fθ in Figure 1, and g defines the encoding of the training samples, corresponding to gθ in Figure 1. The kernel first computes a cosine distance c and then applies a softmax normalisation.
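To make the two formulas concrete, here is a minimal Paddle sketch of the attention kernel and the weighted prediction; the function name and tensor layout are my assumptions, not the repo's exact code:

```python
import paddle
import paddle.nn.functional as F

def matching_net_predictions(f_query, g_support, y_support, eps=1e-8):
    """Weighted-kNN prediction of Matching Networks (a minimal sketch).

    f_query:   [q, d]  embedded query samples f(x_hat)
    g_support: [k, d]  embedded support samples g(x_i)
    y_support: [k, c]  one-hot labels y_i of the support set
    Returns    [q, c]  P(y_hat | x_hat, S) for every query sample.
    """
    # Cosine similarity c(f(x_hat), g(x_i)) between every query/support pair
    f_norm = f_query / (paddle.norm(f_query, p=2, axis=1, keepdim=True) + eps)
    g_norm = g_support / (paddle.norm(g_support, p=2, axis=1, keepdim=True) + eps)
    similarities = paddle.matmul(f_norm, g_norm, transpose_y=True)  # [q, k]

    # Softmax over the support set gives the attention kernel a(x_hat, x_i)
    attention = F.softmax(similarities, axis=1)                     # [q, k]

    # P(y_hat | x_hat, S) = sum_i a(x_hat, x_i) * y_i
    return paddle.matmul(attention, y_support)                      # [q, c]
```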

The training-sample encoder g:

g is a BiLSTM whose inputs are xi and the support set S:

$$\begin{gathered} g\left(x_{i}, S\right)=\overrightarrow{h_{i}}+\overleftarrow{h_{i}}+g^{\prime}\left(x_{i}\right) \\ \vec{h}_{i}, \vec{c}_{i}=\operatorname{LSTM}\left(g^{\prime}\left(x_{i}\right), \vec{h}_{i-1}, \vec{c}_{i-1}\right) \\ \overleftarrow{h}_{i}, \overleftarrow{c}_{i}=\operatorname{LSTM}\left(g^{\prime}\left(x_{i}\right), \overleftarrow{h}_{i+1}, \overleftarrow{c}_{i+1}\right) \end{gathered}$$

g′(xi) is a neural-network embedding, e.g. a VGG or Inception network.
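A minimal Paddle sketch of this full-context embedding (the class name and the batch-of-one layout are my assumptions, not the repo's exact code):

```python
import paddle
import paddle.nn as nn

class BidirectionalLSTM_FCE(nn.Layer):
    """Sketch of g(x_i, S): a BiLSTM over the support set with a skip connection."""
    def __init__(self, d):
        super().__init__()
        # direction='bidirect' runs forward and backward passes over S
        self.lstm = nn.LSTM(input_size=d, hidden_size=d, direction='bidirect')

    def forward(self, support_embeddings):
        # support_embeddings: [1, |S|, d], the stacked g'(x_i)
        out, _ = self.lstm(support_embeddings)             # [1, |S|, 2d]
        forward, backward = paddle.split(out, 2, axis=-1)  # h_fwd, h_bwd
        # g(x_i, S) = h_fwd_i + h_bwd_i + g'(x_i)
        return forward + backward + support_embeddings
```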

The test-sample encoder f:

f is an LSTM unrolled for K steps, and its output is the LSTM's final hidden state h, i.e. f(x̂, S) = hK, where hk is given by equation (3) of the paper:

$$\begin{aligned} \hat{h}_{k}, c_{k} &=\operatorname{LSTM}\left(f^{\prime}(\hat{x}),\left[h_{k-1}, r_{k-1}\right], c_{k-1}\right) \\ h_{k} &=\hat{h}_{k}+f^{\prime}(\hat{x}) \\ r_{k-1} &=\sum_{i=1}^{|S|} a\left(h_{k-1}, g\left(x_{i}\right)\right) g\left(x_{i}\right) \\ a\left(h_{k-1}, g\left(x_{i}\right)\right) &=e^{h_{k-1}^{T} g\left(x_{i}\right)} / \sum_{j=1}^{|S|} e^{h_{k-1}^{T} g\left(x_{j}\right)} \end{aligned}$$

where f′ is an embedding function, e.g. a CNN.
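A corresponding Paddle sketch of the K-step attentional LSTM; carrying the concatenation [h, r] as the cell's hidden state is one way to realise equation (3), and all names here are my assumptions:

```python
import paddle
import paddle.nn as nn
import paddle.nn.functional as F

class AttentionLSTM_FCE(nn.Layer):
    """Sketch of f(x_hat, S): an LSTM unrolled K steps with read-attention over S."""
    def __init__(self, d, unrolling_steps):
        super().__init__()
        self.d = d
        self.K = unrolling_steps
        # Hidden state carries the concatenation [h_{k-1}, r_{k-1}] (size 2d)
        self.cell = nn.LSTMCell(input_size=d, hidden_size=2 * d)

    def forward(self, f_x, g_S):
        # f_x: [q, d] query embeddings f'(x_hat); g_S: [|S|, d] support embeddings
        q = f_x.shape[0]
        h = paddle.zeros([q, self.d])
        c = paddle.zeros([q, 2 * self.d])
        for _ in range(self.K):
            # a(h_{k-1}, g(x_i)): softmax of dot products over the support set
            attention = F.softmax(paddle.matmul(h, g_S, transpose_y=True), axis=1)
            r = paddle.matmul(attention, g_S)              # readout r_{k-1}
            state = paddle.concat([h, r], axis=1)          # [h_{k-1}, r_{k-1}]
            _, (state, c) = self.cell(f_x, (state, c))     # one LSTM step
            h = state[:, :self.d] + f_x                    # h_k = h_hat_k + f'(x_hat)
        return h
```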

Project structure:

```
├── pretrained/
│   ├── omniglot_n=1_k=5_q=15_nv=1_kv=5_qv=1_dist=l2_fce=None.pdparams
├── few_shot/
│   ├── __init__.py
│   ├── callbacks.py
│   ├── core.py
│   ├── datasets.py
│   ├── eval.py
│   ├── maml.py
│   ├── matching.py
│   ├── metrics.py
│   ├── models.py
│   ├── proto.py
│   ├── train.py
│   ├── utils.py
├── config.py
├── matching_nets.py
├── prepare_omniglot.py
├── requirements.txt
├── README.md
```

Step 1: Prepare the Omniglot dataset from the raw Omniglot data.

  • Augment classes with rotations in multiples of 90 degrees.

  • Downsize images to 28x28.

  • Use the background and evaluation splits present in the raw dataset.

Edit the DATA_PATH variable in config.py to point to the location where you store the Omniglot and miniImageNet datasets.
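For example, config.py might contain little more than the following (the path is a placeholder):

```python
# config.py -- sketch; point DATA_PATH at your dataset root
DATA_PATH = '/path/to/data'
```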

After acquiring the data and running the setup script, your folder structure should look like this:

```
DATA_PATH/
    Omniglot/
        images_background/
        images_evaluation/
    miniImageNet/
        images_background/
        images_evaluation/
```

Place the extracted files into DATA_PATH/Omniglot_Raw and run prepare_omniglot.py:

```python
import os

def handle_alphabet(folder):
    print('{}...'.format(folder.split('/')[-1]))
    for rotate in [0, 90, 180, 270]:
        # Create a new folder for each rotation-augmented copy of the alphabet
        mkdir(f'{folder}.{rotate}')
        for root, character_folders, _ in os.walk(folder):
            for character_folder in character_folders:
                # Rotate and resize all images of each character folder and
                # save them into the new folder
                handle_characters(folder, root + '/' + character_folder, rotate)

# Augment every alphabet in the background set with all four rotations
for root, alphabets, _ in os.walk(prepared_omniglot_location + 'images_background/'):
    for alphabet in sorted(alphabets):
        handle_alphabet(root + alphabet)
```
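The helpers mkdir and handle_characters are defined elsewhere in prepare_omniglot.py. As a hypothetical sketch of what handle_characters does (rotate each image, downsize it to 28x28 and save it into the augmented alphabet folder), assuming PIL is available:

```python
from PIL import Image

def handle_characters(alphabet_folder, character_folder, rotate):
    # Hypothetical sketch of the helper used above
    output_folder = f'{alphabet_folder}.{rotate}/{character_folder.split("/")[-1]}'
    mkdir(output_folder)  # same mkdir helper the script already uses
    for img_name in os.listdir(character_folder):
        img = Image.open(f'{character_folder}/{img_name}')
        img = img.rotate(rotate).resize((28, 28), Image.LANCZOS)
        img.save(f'{output_folder}/{img_name}')
```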

Step 2: Training

```python
for epoch in range(1, epochs + 1):
    callbacks.on_epoch_begin(epoch)
    epoch_logs = {}
    for batch_index, batch in enumerate(dataloader):
        batch_logs = dict(batch=batch_index, size=(batch_size or 1))
        callbacks.on_batch_begin(batch_index, batch_logs)
        x, y = prepare_batch(batch)
        # One training episode: forward pass, loss and optimiser step
        loss, y_pred = fit_function(model, optimiser, loss_fn, x, y, **fit_function_kwargs)
        batch_logs['loss'] = float(loss.numpy())
        # Loop through all metrics for this batch
        batch_logs = batch_metrics(model, y_pred, y, metrics, batch_logs)
        callbacks.on_batch_end(batch_index, batch_logs)
    # Run once per epoch, outside the batch loop
    callbacks.on_epoch_end(epoch, epoch_logs)
```
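For reference, a minimal sketch of what the fit_function of a matching-network episode might do; the signature and internals are my assumptions, and it reuses the matching_net_predictions sketch from earlier:

```python
import paddle
import paddle.nn.functional as F

def fit_function(model, optimiser, loss_fn, x, y, n_shot, k_way, q_queries):
    """Sketch of one matching-network training episode (assumed internals).

    x is assumed to hold the n_shot*k_way support images first,
    followed by the q_queries*k_way query images.
    """
    model.train()
    embeddings = model(x)                          # embed support and queries
    support = embeddings[:n_shot * k_way]
    queries = embeddings[n_shot * k_way:]
    y_support = F.one_hot(y[:n_shot * k_way], k_way)

    # Attention-weighted kNN prediction, as in matching_net_predictions above
    y_pred = matching_net_predictions(queries, support, y_support)

    # NLL loss on (clipped) log-probabilities, then one optimiser step
    loss = loss_fn(paddle.log(paddle.clip(y_pred, 1e-8, 1.0)), y[n_shot * k_way:])
    loss.backward()
    optimiser.step()
    optimiser.clear_grad()
    return loss, y_pred
```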

Step 3: Testing

```python
from typing import Callable, List, Union

import paddle
from paddle.io import DataLoader
from paddle.nn import Layer

from few_shot.metrics import NAMED_METRICS


def evaluate(model: Layer, dataloader: DataLoader, prepare_batch: Callable,
             metrics: List[Union[str, Callable]], loss_fn: Callable = None,
             prefix: str = 'val_', suffix: str = ''):
    """Run the model over a dataloader and return sample-weighted average metrics."""
    logs = {}
    seen = 0
    totals = {m: 0 for m in metrics}
    if loss_fn is not None:
        totals['loss'] = 0
    model.eval()
    with paddle.no_grad():
        for batch in dataloader:
            x, y = prepare_batch(batch)
            y_pred = model(x)
            seen += x.shape[0]
            if loss_fn is not None:
                totals['loss'] += loss_fn(y_pred, y).item() * x.shape[0]
            for m in metrics:
                if isinstance(m, str):
                    v = NAMED_METRICS[m](y, y_pred)  # look the metric up by name
                else:
                    v = m(y, y_pred)                 # assume the metric is callable
                totals[m] += v * x.shape[0]
    # Only report 'loss' if a loss_fn was supplied
    loss_keys = ['loss'] if loss_fn is not None else []
    for m in loss_keys + metrics:
        logs[prefix + m + suffix] = totals[m] / seen
    return logs
```
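A typical call might look like this; val_loader, prepare_nshot_task and the metric name are assumptions based on the repo layout, not guaranteed names:

```python
import paddle

# Average loss and accuracy over the validation episodes
logs = evaluate(model, val_loader, prepare_batch=prepare_nshot_task,
                metrics=['categorical_accuracy'], loss_fn=paddle.nn.NLLLoss())
print(logs)  # {'val_loss': ..., 'val_categorical_accuracy': ...}
```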

3. Reproduction Results

|                         | Omniglot |
| ----------------------- | -------- |
| k-way                   | 5        |
| n-shot                  | 1        |
| PyTorch published (l2)  | 98.3     |
| This Paddle repo (l2)   | 98.85    |

4. About This Project

Reference paper: *Matching Networks for One Shot Learning*

Source code (PyTorch implementation): few-shot

Our GitHub project: few-shot
