Reproducing MLP-Mixer and Investigating the Effect of Positional Encoding

Abstract

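        Convolutional neural networks (CNNs) are the go-to model for computer vision. Recently, attention-based networks such as the Vision Transformer have also become popular. In this paper, we show that while convolutions and attention are both sufficient for good performance, neither of them is necessary. We present MLP-Mixer, an architecture based exclusively on multi-layer perceptrons (MLPs). MLP-Mixer contains two types of layers: one with MLPs applied independently to image patches (i.e. "mixing" the per-location features), and one with MLPs applied across patches (i.e. "mixing" spatial information). When trained on large datasets, or with modern regularization schemes, MLP-Mixer attains competitive scores on image classification benchmarks, with pre-training and inference cost comparable to state-of-the-art models. We hope that these results spark further research beyond the realms of well-established CNNs and Transformers.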

1. MLP-Mixer

1.1 Introduction

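        MLP-Mixer was proposed by the Google AI team in 2021 in the paper "MLP-Mixer: An all-MLP Architecture for Vision". Convolutional neural networks (CNNs) are the go-to model for computer vision, and attention-based networks such as the Vision Transformer (ViT) have recently become popular as well. The authors show that although convolution and attention are each sufficient for good performance, neither is necessary: a pure MLP combined with nonlinear activations and Layer Normalization reaches comparable performance, with pre-training and inference costs comparable to state-of-the-art models.
        What is the benefit of using only fully connected layers? Their inductive bias is weaker. Inductive bias can be viewed as the heuristic, or "values", by which a learning algorithm picks one hypothesis out of a huge hypothesis space. The inductive bias of a CNN lies in the convolution operation: only positions inside the receptive field interact, which encodes the locality of images. The inductive bias of a recurrent network (RNN) lies in continuity and locality along the time dimension. ViT already continues the long-standing trend of removing hand-crafted visual features and inductive biases from neural networks so that the model learns from raw data alone. MLP-Mixer goes one step further.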

1.2 MLP-Mixer Network Architecture

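        The overall idea of MLP-Mixer is as follows: the input image is first split into non-overlapping patches, a per-patch fully-connected layer turns each patch into a feature embedding, and the embeddings are then passed through N Mixer layers. Finally, Mixer uses a standard classification head: global average pooling followed by a fully-connected layer. On closer inspection, the first step is the same as in ViT; the Mixer layer replaces the Transformer block, and the final output feeds directly into the fully-connected layer without a class token. In addition, because the token-mixing MLP assigns a separate weight to each patch position, swapping any two tokens changes the output, so Mixer is sensitive to input order and needs no position embedding.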
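The two mixing steps and the order-sensitivity argument can be illustrated with a small NumPy sketch. It is only an illustration under stated assumptions: the MLPs have no biases, GELU uses the tanh approximation, LayerNorm has no learnable scale or shift, and the weights are random. The names `mixer_layer`, `tok_w1`, `tok_w2`, `ch_w1` and `ch_w2` are hypothetical and do not appear in the paper or in the Paddle code of this post.

```python
import numpy as np

def gelu(x):
    # tanh approximation of GELU (assumption; the exact erf form also works)
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))

def layer_norm(x, eps=1e-5):
    # normalize over the channel (last) axis, without learnable parameters
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def mlp(x, w1, w2):
    # two dense layers with a GELU in between, applied along the last axis
    return gelu(x @ w1) @ w2

def mixer_layer(x, p):
    # x has shape (num_patches, channels)
    # token mixing: transpose so that the MLP acts along the patch axis
    x = x + mlp(layer_norm(x).T, p["tok_w1"], p["tok_w2"]).T
    # channel mixing: the MLP acts on each patch independently
    x = x + mlp(layer_norm(x), p["ch_w1"], p["ch_w2"])
    return x

rng = np.random.default_rng(0)
num_patches, channels, dim_s, dim_c = 4, 8, 6, 16
params = {
    "tok_w1": rng.normal(size=(num_patches, dim_s)),
    "tok_w2": rng.normal(size=(dim_s, num_patches)),
    "ch_w1": rng.normal(size=(channels, dim_c)),
    "ch_w2": rng.normal(size=(dim_c, channels)),
}
x = rng.normal(size=(num_patches, channels))
out = mixer_layer(x, params)

# Swap patches 0 and 1 in the input and compare with the swapped output.
perm = [1, 0, 2, 3]
out_swapped = mixer_layer(x[perm], params)
order_sensitive = not np.allclose(out_swapped, out[perm])

# The channel-mixing MLP alone treats every patch identically.
ch_a = mlp(layer_norm(x[perm]), params["ch_w1"], params["ch_w2"])
ch_b = mlp(layer_norm(x), params["ch_w1"], params["ch_w2"])[perm]
channel_equivariant = np.allclose(ch_a, ch_b)

print(out.shape, order_sensitive, channel_equivariant)
```

Only the token-mixing step ties weights to patch positions, which is why the whole layer is not permutation-equivariant even though the channel-mixing step is.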

2. Code Reproduction

2.1 Install and Import the Required Libraries

!pip install paddlex
%matplotlib inline
import paddle
import numpy as np
import matplotlib.pyplot as plt
from paddle.vision.datasets import Cifar10
from paddle.vision.transforms import Transpose
from paddle.io import Dataset, DataLoader
from paddle import nn
import paddle.nn.functional as F
import paddle.vision.transforms as transforms
import os
from matplotlib.pyplot import figure
import paddlex
from functools import partial

2.2 Creating the Dataset

train_tfm = transforms.Compose([
    transforms.Resize((230, 230)),
    transforms.ColorJitter(brightness=0.4,contrast=0.4, saturation=0.4),
    transforms.RandomResizedCrop(224, scale=(0.6, 1.0)),
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomRotation(30),
    paddlex.transforms.MixupImage(),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

test_tfm = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
paddle.vision.set_image_backend('cv2')
# Use the CIFAR-10 dataset
train_dataset = Cifar10(data_file='data/data152754/cifar-10-python.tar.gz', mode='train', transform = train_tfm, )
val_dataset = Cifar10(data_file='data/data152754/cifar-10-python.tar.gz', mode='test',transform = test_tfm)
print("train_dataset: %d" % len(train_dataset))
print("val_dataset: %d" % len(val_dataset))
train_dataset: 50000
val_dataset: 10000
batch_size=128
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=2)

2.3 Building the Model

2.3.1 Label Smoothing
class LabelSmoothingCrossEntropy(nn.Layer):
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, pred, target):

        confidence = 1. - self.smoothing
        log_probs = F.log_softmax(pred, axis=-1)
        idx = paddle.stack([paddle.arange(log_probs.shape[0]), target], axis=1)
        nll_loss = paddle.gather_nd(-log_probs, index=idx)
        smooth_loss = paddle.mean(-log_probs, axis=-1)
        loss = confidence * nll_loss + self.smoothing * smooth_loss

        return loss.mean()
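
The loss above mixes the usual negative log-likelihood with the mean negative log-probability over all classes. A minimal NumPy sketch of the same formula follows; it assumes the form used in the class above (weight `1 - smoothing` on the true class plus `smoothing` spread uniformly over all K classes), and the function names are illustrative. It also checks that this equals cross-entropy against the soft target `(1 - smoothing) * one_hot + smoothing / K`.

```python
import numpy as np

def log_softmax(logits):
    # numerically stable log-softmax over the last axis
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

def label_smoothing_ce(logits, target, smoothing=0.1):
    log_probs = log_softmax(logits)
    nll = -log_probs[np.arange(len(target)), target]   # loss on the true class
    smooth = -log_probs.mean(axis=-1)                  # uniform part over all classes
    return ((1.0 - smoothing) * nll + smoothing * smooth).mean()

def soft_target_ce(logits, target, smoothing=0.1):
    # same quantity written as cross-entropy against smoothed one-hot targets
    num_classes = logits.shape[-1]
    q = np.full(logits.shape, smoothing / num_classes)
    q[np.arange(len(target)), target] += 1.0 - smoothing
    return -(q * log_softmax(logits)).sum(axis=-1).mean()

rng = np.random.default_rng(0)
logits = rng.normal(size=(5, 10))
target = rng.integers(0, 10, size=5)
loss_a = label_smoothing_ce(logits, target)
loss_b = soft_target_ce(logits, target)

# With uniform logits every class has probability 1/K, so the loss is log(K).
uniform_loss = label_smoothing_ce(np.zeros((3, 10)), np.array([0, 1, 2]))
```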
2.3.2 DropPath
def drop_path(x, drop_prob=0.0, training=False):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = paddle.to_tensor(1 - drop_prob)
    shape = (paddle.shape(x)[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + paddle.rand(shape, dtype=x.dtype)
    random_tensor = paddle.floor(random_tensor)  # binarize
    output = x.divide(keep_prob) * random_tensor
    return output


class DropPath(nn.Layer):
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
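
To see what `drop_path` does to a batch, here is a NumPy re-statement of the same arithmetic (an illustration only, with a hypothetical name `drop_path_np`): each sample gets a single Bernoulli(keep_prob) draw, kept samples are scaled by `1 / keep_prob`, and the expected value of the output therefore stays equal to the input.

```python
import numpy as np

def drop_path_np(x, drop_prob, rng):
    if drop_prob == 0.0:
        return x
    keep_prob = 1.0 - drop_prob
    # one random draw per sample, broadcast over all remaining axes
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = np.floor(keep_prob + rng.random(shape))  # 1 with prob keep_prob, else 0
    return x / keep_prob * mask

rng = np.random.default_rng(0)
x = np.ones((20000, 2, 3))
y = drop_path_np(x, drop_prob=0.1, rng=rng)

per_sample = y.reshape(len(y), -1)
dropped = (per_sample == 0).all(axis=1)                   # whole residual branch zeroed
kept = np.isclose(per_sample, 1.0 / 0.9).all(axis=1)      # whole branch rescaled
print(dropped.mean(), y.mean())
```

Because the mask is shared across a sample, a residual branch is either skipped entirely or kept entirely for that sample, which is the stochastic-depth behaviour described in the docstring above.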
2.3.3 Building the MLP-Mixer Model
class PreNormResidual(nn.Layer):
    def __init__(self, dim, fn, drp=0.0):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)
        self.drop_path = DropPath(drp)

    def forward(self, x):
        return self.drop_path(self.fn(self.norm(x))) + x
class FeedForward(nn.Layer):
    def __init__(self, dim, hidden_dim, dropout = 0., dense = nn.Linear):
        super().__init__()
        self.net = nn.Sequential(
            dense(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            dense(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)
class MLPMixer(nn.Layer):
    def __init__(self, in_channels = 3, d_model = 512, num_classes = 1000, patch_size = 16, image_size = 224,
        depth = 12, dim_c = 2048, dim_s = 256, dropout = 0.0, drp = 0.0):
        super().__init__()

        assert image_size % patch_size==0, 'The image_size cannot be divided by the patch_size'

        self.num_patches = (image_size // patch_size) * (image_size // patch_size)

        self.patcher = nn.Conv2D(in_channels, d_model, patch_size, patch_size)

        chan_first, chan_last = partial(nn.Conv1D, kernel_size = 1), nn.Linear
        self.model = nn.Sequential(
            *[nn.Sequential(
                PreNormResidual(d_model, FeedForward(self.num_patches, dim_s, dropout, chan_first), drp),
                PreNormResidual(d_model, FeedForward(d_model, dim_c, dropout, chan_last), drp)
            ) for _ in range(depth)]
        )

        self.active = nn.LayerNorm(d_model)
        self.classifier = nn.Linear(d_model, num_classes)

    def forward(self, x):
        x = self.patcher(x)
        B, C, H, W = x.shape
        x = x.transpose([0, 2, 3, 1]).reshape((B, -1, C))
        x = self.model(x)
        x = self.active(x)
        x = paddle.mean(x, 1)
        x = self.classifier(x)

        return x
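
The `patcher` convolution uses a kernel size and stride both equal to `patch_size`, which amounts to flattening each non-overlapping patch and applying one shared linear projection. The NumPy sketch below illustrates the resulting shapes for a 224x224 RGB image with 16x16 patches; `patchify` and the random projection `w` are hypothetical helpers for illustration, not part of the model code.

```python
import numpy as np

def patchify(img, patch_size):
    # img: (channels, height, width) -> (num_patches, channels * patch_size**2)
    c, h, w = img.shape
    assert h % patch_size == 0 and w % patch_size == 0
    gh, gw = h // patch_size, w // patch_size
    x = img.reshape(c, gh, patch_size, gw, patch_size)
    x = x.transpose(1, 3, 0, 2, 4)          # (gh, gw, c, p, p), row-major patch order
    return x.reshape(gh * gw, c * patch_size * patch_size)

rng = np.random.default_rng(0)
img = rng.normal(size=(3, 224, 224))
patches = patchify(img, 16)                 # one row per patch

# Hypothetical projection weights standing in for the learned patch embedding.
w = rng.normal(size=(3 * 16 * 16, 512)) * 0.02
tokens = patches @ w                        # (num_patches, d_model)
print(patches.shape, tokens.shape)
```

With `patch_size=32` the same image yields 7 x 7 = 49 patches, so the token-mixing MLP of an /32 model operates on a much shorter sequence than an /16 model.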
2.3.4 Model Parameters
# MLP-Mixer-S/32
model = MLPMixer(d_model=512, num_classes=10, patch_size=32, depth=8, dim_c = 2048, dim_s = 256)
paddle.summary(model, (batch_size, 3, 224, 224))


# MLP-Mixer-S/16
model = MLPMixer(d_model=512, num_classes=10, patch_size=16, depth=8, dim_c = 2048, dim_s = 256, drp=0.1)
paddle.summary(model, (batch_size, 3, 224, 224))


# MLP-Mixer-B/32
model = MLPMixer(d_model=768, num_classes=10, patch_size=32, depth=12, dim_c = 3072, dim_s = 384)
paddle.summary(model, (batch_size, 3, 224, 224))


# MLP-Mixer-B/16
model = MLPMixer(d_model=768, num_classes=10, patch_size=16, depth=12, dim_c = 3072, dim_s = 384)
paddle.summary(model, (batch_size, 3, 224, 224))


# MLP-Mixer-L/32
model = MLPMixer(d_model=1024, num_classes=10, patch_size=32, depth=24, dim_c = 4096, dim_s = 512)
paddle.summary(model, (batch_size, 3, 224, 224))


# MLP-Mixer-L/16
model = MLPMixer(d_model=1024, num_classes=10, patch_size=16, depth=24, dim_c = 4096, dim_s = 512)
paddle.summary(model, (batch_size, 3, 224, 224))


# MLP-Mixer-H/14
model = MLPMixer(d_model=1280, num_classes=10, patch_size=14, depth=32, dim_c = 5120, dim_s = 640)
paddle.summary(model, (batch_size, 3, 224, 224))
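
As a rough cross-check of `paddle.summary`, the number of trainable parameters can be estimated by hand from the layer definitions in `MLPMixer`. The sketch below assumes every dense layer, the patch convolution and every LayerNorm carry both weights and biases, as in the class above; `mixer_param_count` is a hypothetical helper, and its results have not been compared against the summaries in the original notebook.

```python
def mixer_param_count(d_model, num_classes, patch_size, depth, dim_c, dim_s,
                      in_channels=3, image_size=224):
    num_patches = (image_size // patch_size) ** 2

    def mlp_params(dim, hidden):
        # two dense layers (dim -> hidden -> dim), each with a bias
        return dim * hidden + hidden + hidden * dim + dim

    patcher = in_channels * d_model * patch_size * patch_size + d_model
    layer_norm = 2 * d_model                       # scale and shift
    token_block = layer_norm + mlp_params(num_patches, dim_s)
    channel_block = layer_norm + mlp_params(d_model, dim_c)
    head = layer_norm + d_model * num_classes + num_classes
    return patcher + depth * (token_block + channel_block) + head

# Configurations copied from the calls above (num_classes=10 for CIFAR-10).
variants = {
    "S/32": dict(d_model=512, patch_size=32, depth=8, dim_c=2048, dim_s=256),
    "S/16": dict(d_model=512, patch_size=16, depth=8, dim_c=2048, dim_s=256),
    "B/16": dict(d_model=768, patch_size=16, depth=12, dim_c=3072, dim_s=384),
    "L/16": dict(d_model=1024, patch_size=16, depth=24, dim_c=4096, dim_s=512),
}
counts = {name: mixer_param_count(num_classes=10, **cfg) for name, cfg in variants.items()}
for name, n in counts.items():
    print(f"MLP-Mixer-{name}: {n / 1e6:.1f}M parameters")

# Tiny configuration that can be verified by hand.
tiny = mixer_param_count(d_model=2, num_classes=3, patch_size=2, depth=1,
                         dim_c=4, dim_s=3, in_channels=1, image_size=4)
```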


2.4 Training

learning_rate = 0.001
n_epochs = 100
paddle.seed(42)
np.random.seed(42)
work_path = 'work/model'

# MLP-Mixer-S/16
model = MLPMixer(d_model=512, num_classes=10, patch_size=16, depth=8, dim_c = 2048, dim_s = 256, drp=0.1)

criterion = LabelSmoothingCrossEntropy()

scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=learning_rate, T_max=50000 // batch_size * n_epochs, verbose=False)
optimizer = paddle.optimizer.Adam(parameters=model.parameters(), learning_rate=scheduler, weight_decay=1e-5)

gate = 0.0
threshold = 0.0
best_acc = 0.0
val_acc = 0.0
loss_record = {'train': {'loss': [], 'iter': []}, 'val': {'loss': [], 'iter': []}}   # for recording loss
acc_record = {'train': {'acc': [], 'iter': []}, 'val': {'acc': [], 'iter': []}}      # for recording accuracy

loss_iter = 0
acc_iter = 0

for epoch in range(n_epochs):
    # ---------- Training ----------
    model.train()
    train_num = 0.0
    train_loss = 0.0

    val_num = 0.0
    val_loss = 0.0
    accuracy_manager = paddle.metric.Accuracy()
    val_accuracy_manager = paddle.metric.Accuracy()
    print("#===epoch: {}, lr={:.10f}===#".format(epoch, optimizer.get_lr()))
    for batch_id, data in enumerate(train_loader):
        x_data, y_data = data
        labels = paddle.unsqueeze(y_data, axis=1)

        logits = model(x_data)

        loss = criterion(logits, y_data)

        acc = paddle.metric.accuracy(logits, labels)
        accuracy_manager.update(acc)
        if batch_id % 10 == 0:
            loss_record['train']['loss'].append(loss.numpy())
            loss_record['train']['iter'].append(loss_iter)
            loss_iter += 1

        loss.backward()

        optimizer.step()
        scheduler.step()
        optimizer.clear_grad()

        train_loss += loss
        train_num += len(y_data)

    total_train_loss = (train_loss / train_num) * batch_size
    train_acc = accuracy_manager.accumulate()
    acc_record['train']['acc'].append(train_acc)
    acc_record['train']['iter'].append(acc_iter)
    acc_iter += 1
    # Print the information.
    print("#===epoch: {}, train loss is: {}, train acc is: {:2.2f}%===#".format(epoch, total_train_loss.numpy(), train_acc*100))

    # ---------- Validation ----------
    model.eval()

    for batch_id, data in enumerate(val_loader):

        x_data, y_data = data
        labels = paddle.unsqueeze(y_data, axis=1)
        with paddle.no_grad():
          logits = model(x_data)

        loss = criterion(logits, y_data)

        acc = paddle.metric.accuracy(logits, labels)
        val_accuracy_manager.update(acc)

        val_loss += loss
        val_num += len(y_data)

    total_val_loss = (val_loss / val_num) * batch_size
    loss_record['val']['loss'].append(total_val_loss.numpy())
    loss_record['val']['iter'].append(loss_iter)
    val_acc = val_accuracy_manager.accumulate()
    acc_record['val']['acc'].append(val_acc)
    acc_record['val']['iter'].append(acc_iter)

    print("#===epoch: {}, val loss is: {}, val acc is: {:2.2f}%===#".format(epoch, total_val_loss.numpy(), val_acc*100))

    # ===================save====================
    if val_acc > best_acc:
        best_acc = val_acc
        paddle.save(model.state_dict(), os.path.join(work_path, 'best_model.pdparams'))
        paddle.save(optimizer.state_dict(), os.path.join(work_path, 'best_optimizer.pdopt'))

print(best_acc)
paddle.save(model.state_dict(), os.path.join(work_path, 'final_model.pdparams'))
paddle.save(optimizer.state_dict(), os.path.join(work_path, 'final_optimizer.pdopt'))
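
The scheduler above is stepped once per batch, with `T_max` set to the total number of training iterations. The following plain-Python sketch shows the closed-form cosine curve such a schedule is expected to follow; it assumes a minimum learning rate of 0 and that the closed form matches what `CosineAnnealingDecay` produces for steps up to `T_max`. The helper `cosine_lr` is illustrative, not a Paddle API.

```python
import math

def cosine_lr(step, base_lr, t_max, eta_min=0.0):
    # half a cosine period from base_lr (step 0) down to eta_min (step t_max)
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + math.cos(math.pi * step / t_max))

batch_size, n_epochs, base_lr = 128, 100, 0.001
steps_per_epoch = 50000 // batch_size      # matches drop_last=True in train_loader
t_max = steps_per_epoch * n_epochs

for epoch in (0, 25, 50, 75, 100):
    step = epoch * steps_per_epoch
    print(f"epoch {epoch:3d}: lr = {cosine_lr(step, base_lr, t_max):.6f}")
```

Since `drop_last=True` discards the final partial batch, each epoch has exactly `50000 // 128` iterations, so the schedule reaches its minimum at the end of the last epoch.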


2.5 Result Analysis

def plot_learning_curve(record, title='loss', ylabel='CE Loss'):
    ''' Plot learning curve of the model '''
    maxtrain = max(map(float, record['train'][title]))
    maxval = max(map(float, record['val'][title]))
    ymax = max(maxtrain, maxval) * 1.1
    mintrain = min(map(float, record['train'][title]))
    minval = min(map(float, record['val'][title]))
    ymin = min(mintrain, minval) * 0.9

    total_steps = len(record['train'][title])
    x_1 = list(map(int, record['train']['iter']))
    x_2 = list(map(int, record['val']['iter']))
    figure(figsize=(10, 6))
    plt.plot(x_1, record['train'][title], c='tab:red', label='train')
    plt.plot(x_2, record['val'][title], c='tab:cyan', label='val')
    plt.ylim(ymin, ymax)
    plt.xlabel('Training steps')
    plt.ylabel(ylabel)
    plt.title('Learning curve of {}'.format(title))
    plt.legend()
    plt.show()
2.5.1 Loss and Accuracy Curves
plot_learning_curve(loss_record, title='loss', ylabel='CE Loss')


plot_learning_curve(acc_record, title='acc', ylabel='Accuracy')


import time
work_path = 'work/model'
model = MLPMixer(d_model=512, num_classes=10, patch_size=16, depth=8, dim_c = 2048, dim_s = 256, drp=0.1)
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
model.eval()
aa = time.time()
for batch_id, data in enumerate(val_loader):

    x_data, y_data = data
    labels = paddle.unsqueeze(y_data, axis=1)
    with paddle.no_grad():
        logits = model(x_data)
bb = time.time()
print("Throughput:{}".format(int(len(val_dataset)//(bb - aa))))
Throughput:683
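
The number above is images per second over the whole validation loop, so it includes data loading and preprocessing as well as the forward pass. A generic timing helper might look like the sketch below; `measure_throughput` and the dummy workload are hypothetical, and on a GPU the device would probably need to be synchronized before reading the clock, which is omitted here.

```python
import time

def measure_throughput(run_batch, batches, warmup=1):
    """Return processed items per second for run_batch over the given batches."""
    batches = list(batches)
    for batch in batches[:warmup]:          # warm-up runs are not timed
        run_batch(batch)
    n_items = 0
    start = time.perf_counter()
    for batch in batches:
        run_batch(batch)
        n_items += len(batch)
    elapsed = time.perf_counter() - start
    return n_items / max(elapsed, 1e-12)    # guard against a zero interval

# Dummy workload standing in for model(x_data); each "batch" holds 128 items.
dummy_batches = [[0.5] * 128 for _ in range(5)]
items_per_second = measure_throughput(lambda b: sum(v * v for v in b), dummy_batches)
print(f"Throughput: {items_per_second:.0f} items/s")
```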
2.5.2 Comparing Predictions with Ground-Truth Labels
def get_cifar10_labels(labels):
    """Return the text labels of the CIFAR-10 dataset."""
    text_labels = [
        'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog',
        'horse', 'ship', 'truck']
    return [text_labels[int(i)] for i in labels]
def show_images(imgs, num_rows, num_cols, pred=None, gt=None, scale=1.5):
    """Plot a list of images."""
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if paddle.is_tensor(img):
            ax.imshow(img.numpy())
        else:
            ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if pred or gt:
            ax.set_title("pt: " + pred[i] + "\ngt: " + gt[i])
    return axes
work_path = 'work/model'
X, y = next(iter(DataLoader(val_dataset, batch_size=18)))
model = MLPMixer(d_model=512, num_classes=10, patch_size=16, depth=8, dim_c = 2048, dim_s = 256, drp=0.1)
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
model.eval()
logits = model(X)
y_pred = paddle.argmax(logits, -1)
X = paddle.transpose(X, [0, 2, 3, 1])
axes = show_images(X.reshape((18, 224, 224, 3)), 1, 18, pred=get_cifar10_labels(y_pred), gt=get_cifar10_labels(y))
plt.show()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
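
The clipping warning appears because the displayed tensors are still normalized with the ImageNet mean and standard deviation, so many values fall outside [0, 1]. One way to avoid it, sketched here in NumPy under the assumption that images are in height-width-channel layout at display time, is to undo the normalization first; `denormalize` is an illustrative helper, not part of the original notebook.

```python
import numpy as np

# Same statistics as in transforms.Normalize above.
MEAN = np.array([0.485, 0.456, 0.406])
STD = np.array([0.229, 0.224, 0.225])

def denormalize(img_hwc):
    # invert (x - mean) / std per channel, then clip to the displayable range
    return np.clip(img_hwc * STD + MEAN, 0.0, 1.0)

rng = np.random.default_rng(0)
raw = rng.random((4, 4, 3))                 # stand-in for an image with values in [0, 1)
normalized = (raw - MEAN) / STD             # what the model receives
restored = denormalize(normalized)          # what imshow should receive
print(normalized.min(), restored.min(), restored.max())
```

In the cell above this would be applied after the transpose to `[0, 2, 3, 1]`, for example on `X.numpy()`, before passing the images to `show_images`.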


2.5.3 Visualizing the Results
!pip install interpretdl
import interpretdl as it
work_path = 'work/model'
model = MLPMixer(d_model=512, num_classes=10, patch_size=16, depth=8, dim_c = 2048, dim_s = 256, drp=0.1)
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
X, y = next(iter(DataLoader(val_dataset, batch_size=18)))
lime = it.LIMECVInterpreter(model, use_cuda=True)
lime_weights = lime.interpret(X.numpy()[3], interpret_class=y.numpy()[3], batch_size=100, num_samples=10000, visual=True)
100%|██████████| 10000/10000 [00:53<00:00, 185.76it/s]


lime_weights = lime.interpret(X.numpy()[13], interpret_class=y.numpy()[13], batch_size=100, num_samples=10000, visual=True)
100%|██████████| 10000/10000 [00:53<00:00, 160.09it/s]


3. Comparison Experiments

Model              Train Acc   Val Acc
MLP-Mixer w/o PE   0.9411      0.82872
MLP-Mixer w/ PE    0.9384      0.84088

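PE denotes positional encoding. The results with positional encoding are in main-Copy1.ipynb, and the complete code for visualizing the positional encoding is in main-Copy2.ipynb. The visualization is shown in the figure below.

[Figure: visualization of the positional encoding]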

Summary

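        MLP-Mixer shows that MLPs alone can achieve strong classification performance, matching CNNs and Transformers when the model and the dataset are both large.
        Although MLP-Mixer does not use positional encoding, adding it raised validation accuracy by about 1.2 percentage points (0.84088 vs. 0.82872) and reduced overfitting (the train-validation gap shrank from about 0.112 to about 0.098), which suggests that positional encoding is still helpful for MLP-Mixer in this setting.
        Future work: experiment on larger datasets to further verify the observations above.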
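
The comparison in the summary can be reproduced directly from the table in Section 3. The short script below only re-does that arithmetic; the dictionary layout is an illustrative choice, and the numbers are copied from the table.

```python
# Accuracies copied from the comparison table (PE = positional encoding).
results = {
    "w/o PE": {"train": 0.9411, "val": 0.82872},
    "w/ PE": {"train": 0.9384, "val": 0.84088},
}

# Gain in validation accuracy from adding positional encoding.
val_gain = results["w/ PE"]["val"] - results["w/o PE"]["val"]

# Train-validation gap as a simple indicator of overfitting.
gaps = {name: r["train"] - r["val"] for name, r in results.items()}

print(f"validation gain: {val_gain * 100:.2f} percentage points")
for name, gap in gaps.items():
    print(f"train-val gap {name}: {gap:.4f}")
```

These are single-run numbers on CIFAR-10, so the difference should be read as an observation from this experiment rather than a general result.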

This post is a repost; original work: https://aistudio.baidu.com/aistudio/projectdetail/4302078
