基于PaddlePaddle复现Few-shot learning
本项目在参与【飞桨启航菁英计划】过程中完成,由百度官方提供算力支持,基于PaddlePaddle复现论文《Matching Networks for One Shot Learning》
本项目主要包括以下内容:
-
项目简介
-
模型背景及介绍
-
复现结果
-
关于本项目
1、简介
本项目为飞桨启航菁英计划中的《Matching Networks for One Shot Learning》论文复现代码。
依赖环境及安装包:
-
paddlepaddle-gpu2.1.2
-
python3.7
代码在Omniglot数据集上进行训练,k-way=5, n-shot=1, 精度98.1
2、模型背景及介绍
参考论文:《Matching Networks for One Shot Learnings》
源代码(Pytorch实现):few-shot
论文简介:
人类可以可以通过非常少量的样本学习到一个新的概念,但是最好的深度学习模型依然需要成百上千的例子来学习到一个新的概念。因此本文就考虑如何通过一个样本就让深度学习模型学会一个新概念。
传统上训练出一个模型需要使用很多样本进行很多次的参数更新,因此作者认为可以使用一个无参数的模型。参考KNN这种度量式的做法,作者将有参数的模型和无参数的模型进行了结合。
在模型层面上,作者提出了一个Matching Networks, 将注意力机制和记忆机制引入快速学习任务中。在训练流程上,作者训练模型时遵循了一个很简单的规则,即测试和训练条件必须匹配。作者在训练时仅用每个类别中很少的样本进行训练,因为在测试时也使用的是很少的样本。
模型结构如图所示:
gθ和 fθ分别是对训练数据和测试数据的编码函数。Matching Networks可以简洁表示为计算一个无标签样本的标签为y^的概率,这个计算方法跟KNN很像,相当于是加权后的KNN:
P ( y ^ ∣ x ^ , S ) = ∑ i = 1 k a ( x ^ , x i ) y i P(\hat{y} \mid \hat{x}, S)=\sum_{i=1}^{k} a\left(\hat{x}, x_{i}\right) y_{i} P(y^∣x^,S)=i=1∑ka(x^,xi)yi
其中xi,yi是输入的支撑集(support set)中的样本S, a类似于注意力机制中的核函数,用来度量匹配度。
a ( x ^ , x i ) = e c ( f ( x ^ ) , g ( x i ) ) ∑ j = 1 k e c ( f ( x ^ ) , g ( x j ) ) a\left(\hat{x}, x_{i}\right)=\frac{e^{c\left(f(\hat{x}), g\left(x_{i}\right)\right)}}{\sum_{j=1}^{k} e^{c\left(f(\hat{x}), g\left(x_{j}\right)\right)}} a(x^,xi)=∑j=1kec(f(x^),g(xj))ec(f(x^),g(xi))
在这里公式 f 定义了对测试样本的编码方式,对于Figure 1 中的 gθ,公式 g 定义了对训练样本的编码方式,对应于Figure 1 中的 fθ。这个公式先计算了一个余弦距离,然后在做一个softmax归一化。
训练函数g:
g是一个BiLSTM,它的输入是xi和支撑集 S 。
g ( x i , S ) = h i → + h i ← + g ′ ( x i ) h ⃗ i , c ⃗ i = LSTM ( g ′ ( x i ) , h ⃗ i − 1 , c ⃗ i − 1 ) h ← i , c ← i = LSTM ( g ′ ( x i ) , h ← i + 1 , c ˉ i + 1 ) \begin{gathered} g\left(x_{i}, S\right)=\overrightarrow{h_{i}}+\overleftarrow{h_{i}}+g^{\prime}\left(x_{i}\right) \\ \vec{h}_{i}, \vec{c}_{i}=\operatorname{LSTM}\left(g^{\prime}\left(x_{i}\right), \vec{h}_{i-1}, \vec{c}_{i-1}\right) \\ \stackrel{\leftarrow}{h}_{i}, \overleftarrow{c}_{i}=\operatorname{LSTM}\left(g^{\prime}\left(x_{i}\right), \stackrel{\leftarrow}{h}_{i+1}, \bar{c}_{i+1}\right) \end{gathered} g(xi,S)=hi+hi+g′(xi)hi,ci=LSTM(g′(xi),hi−1,ci−1)h←i,ci=LSTM(g′(xi),h←i+1,cˉi+1)
g(xi,S)是一个神经网络,比如VGG或者Inception。
测试函数f:
f是一个迭代了K步的 LSTM,它的输出是LSTM最后输出的隐状态 h。即 f(x,S) =hk,其中 hk 由(3)式决定:
h ^ k , c k = LSTM ( f ′ ( x ^ ) , [ h k − 1 , r k − 1 ] , c k − 1 ) h k = h ^ k + f ′ ( x ^ ) r k − 1 = ∑ i = 1 ∣ S ∣ a ( h k − 1 , g ( x i ) ) g ( x i ) a ( h k − 1 , g ( x i ) ) = e h k − 1 T g ( x i ) / ∑ j = 1 ∣ S ∣ e h k − 1 T g ( x j ) \begin{aligned} \hat{h}_{k}, c_{k} &=\operatorname{LSTM}\left(f^{\prime}(\hat{x}),\left[h_{k-1}, r_{k-1}\right], c_{k-1}\right) \\ h_{k} &=\hat{h}_{k}+f^{\prime}(\hat{x}) \\ r_{k-1} &=\sum_{i=1}^{|S|} a\left(h_{k-1}, g\left(x_{i}\right)\right) g\left(x_{i}\right) \\ a\left(h_{k-1}, g\left(x_{i}\right)\right) &=e^{h_{k-1}^{T} g\left(x_{i}\right)} / \sum_{j=1}^{|S|} e^{h_{k-1}^{T} g\left(x_{j}\right)} \end{aligned} h^k,ckhkrk−1a(hk−1,g(xi))=LSTM(f′(x^),[hk−1,rk−1],ck−1)=h^k+f′(x^)=i=1∑∣S∣a(hk−1,g(xi))g(xi)=ehk−1Tg(xi)/j=1∑∣S∣ehk−1Tg(xj)
其中, f′是一个embedding函数,比如一个CNN。
项目结构:
├── pretrained/
│ ├── omniglot_n=1_k=5_q=15_nv=1_kv=5_qv=1_dist=l2_fce=None.pdparams
├── few_shot/
│ ├── __init__
│ ├── callbacks
│ ├── core
│ ├── datasets
│ ├── eval
│ ├── maml
│ ├── matching
│ ├── metrics
│ ├── models
│ ├── proto
│ ├── train
│ ├── utils
├── config
├── matching_nets
├── prepare_omniglot
├── requirements.txt
├── README.md
1.prepare the Omniglot dataset from the raw Omniglot dataset
-
Augment classes with rotations in multiples of 90 degrees.
-
Downsize images to 28x28
-
Uses background and evaluation sets present in the raw dataset
Edit the DATA_PATH
variable in config.py
to the location where you store the Omniglot and miniImagenet datasets.
After acquiring the data and running the setup scripts your folder structure should look like
DATA_PATH/
Omniglot/
images_background/
images_evaluation/
miniImageNet/
images_background/
images_evaluation/
Place the extracted files into DATA_PATH/Omniglot_Raw
and run prepare_omniglot.py
def handle_alphabet(folder):
print('{}...'.format(folder.split('/')[-1]))
for rotate in [0, 90, 180, 270]:
# Create new folders for each augmented alphabet
mkdir(f'{folder}.{rotate}')
for root, character_folders, _ in os.walk(folder):
for character_folder in character_folders:
# For each character folder in an alphabet rotate and resize all of the images and save
# to the new folder
handle_characters(folder, root + '/' + character_folder, rotate)
for root, alphabets, _ in os.walk(prepared_omniglot_location + 'images_background/'):
for alphabet in sorted(alphabets):
handle_alphabet(root + alphabet)
2.training
for epoch in range(1, epochs+1):
callbacks.on_epoch_begin(epoch)
epoch_logs = {}
for batch_index, batch in enumerate(dataloader):
batch_logs = dict(batch=batch_index, size=(batch_size or 1))
callbacks.on_batch_begin(batch_index, batch_logs)
x, y = prepare_batch(batch)
loss, y_pred = fit_function(model, optimiser, loss_fn, x, y, **fit_function_kwargs)
batch_logs['loss'] = loss.cpu().numpy() #loss.items()
# Loops through all metrics
batch_logs = batch_metrics(model, y_pred, y, metrics, batch_logs)
callbacks.on_batch_end(batch_index, batch_logs)
# Run on epoch end
callbacks.on_epoch_end(epoch, epoch_logs)
3.test
def evaluate(model: Layer, dataloader: DataLoader, prepare_batch: Callable, metrics: List[Union[str, Callable]],
loss_fn: Callable = None, prefix: str = 'val_', suffix: str = ''):
logs = {}
seen = 0
totals = {m: 0 for m in metrics}
if loss_fn is not None:
totals['loss'] = 0
model.eval()
with paddle.no_grad():
for batch in dataloader:
x, y = prepare_batch(batch)
y_pred = model(x)
seen += x.shape[0]
if loss_fn is not None:
totals['loss'] += loss_fn(y_pred, y).item() * x.shape[0]
for m in metrics:
if isinstance(m, str):
v = NAMED_METRICS[m](y, y_pred)
else:
# Assume metric is a callable function
v = m(y, y_pred)
totals[m] += v * x.shape[0]
for m in ['loss'] + metrics:
logs[prefix + m + suffix] = totals[m] / seen
return logs
4.results
Omniglot | |
---|---|
k-way | 5 |
n-shot | 1 |
Pytorch Published (l2) | 98.3 |
This paddle Repo (l2) | 98.85 |
4、关于本项目
参考论文:《Matching Networks for One Shot Learnings》
源代码(Pytorch实现):few-shot
我们的GitHub项目:few-shot
更多推荐
所有评论(0)