gMLP & aMLP: MLPs with Gating

Abstract

Transformers have become one of the most important architectural innovations in deep learning and have enabled many breakthroughs over the past few years. Here we propose gMLP, a simple, attention-free network architecture based solely on MLPs (multi-layer perceptrons) with gating, and show that it can perform as well as Transformers in key language and vision applications. Our comparisons show that self-attention is not critical for Vision Transformers, since gMLP can reach the same accuracy. For BERT, our model matches Transformers on pre-training perplexity and is better on some downstream tasks. On the fine-tuning tasks where gMLP performs worse, substantially enlarging the gMLP model can close the gap with Transformers. In general, our experiments show that gMLP can scale as well as Transformers with increased data and compute.

1. gMLP & aMLP

1.1 gMLP

gMLP consists of a stack of L blocks of identical size and structure. Let $X \in \mathbb{R}^{n \times d}$ be the token representations with sequence length $n$ and dimension $d$. Each block is defined as:

$$Z=\sigma(XU), \quad \tilde{Z}=s(Z), \quad Y=\tilde{Z}V$$

The key component above is $s(\cdot)$, which captures spatial interactions; if it is removed, the block can no longer mix information across tokens and is no different from an ordinary FFN. The authors choose a module called the Spatial Gating Unit (SGU) as $s(\cdot)$ to capture spatial dependencies. In addition, gMLP follows the same input/output conventions as BERT for NLP and ViT for CV tasks.

To enable cross-token interaction, $s(\cdot)$ must operate over the spatial dimension. The simplest choice is a linear projection:

$$f_{W,b}(Z) = WZ + b$$

The paper applies a gating operation to this projection to make training easier:

$$s(Z) = Z \odot f_{W,b}(Z)$$

Further experiments showed that splitting $Z$ along the channel dimension improves performance, so the formula is rewritten as:

$$s(Z) = Z_1 \odot f_{W,b}(Z_2)$$
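To make the shapes concrete, here is a minimal sketch of the split-gating equation in Paddle (variable names are ours for illustration; the spatial weights act across the n tokens, not the channels):

import paddle

n, d_ffn = 196, 768                     # sequence length, expanded channel width
Z = paddle.randn([1, n, d_ffn])         # output of the first channel projection
Z1, Z2 = paddle.chunk(Z, 2, axis=-1)    # split along the channel dimension
W = paddle.randn([n, n]) * 0.01         # spatial weights (the paper initializes these near zero)
b = paddle.ones([n, 1])                 # bias initialized as ones
sZ = Z1 * (paddle.matmul(W, Z2) + b)    # s(Z) = Z1 ⊙ (W Z2 + b)
print(sZ.shape)                         # [1, 196, 384]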
![gMLP overview](https://ai-studio-static-online.cdn.bcebos.com/2c15312b248c4cd79505f63eadb3f550c61947eae7b14c7d801d1be1ba8ca550)

1.2 aMLP

Compared with gMLP, aMLP only adds a single-head self-attention with head size 64 ("tiny attention"), as shown in the figure below:
(figure: aMLP block diagram)

2. Code Reproduction

2.1 Install and Import the Required Libraries

!pip install paddlex
%matplotlib inline
import os
import numpy as np
import paddle
from paddle import nn
import paddle.nn.functional as F
import paddle.vision.transforms as transforms
from paddle.vision.datasets import Cifar10
from paddle.io import DataLoader
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure
import paddlex

2.2 Create the Dataset

train_tfm = transforms.Compose([
    transforms.Resize((230, 230)),
    transforms.ColorJitter(brightness=0.4,contrast=0.4, saturation=0.4),
    transforms.RandomResizedCrop(224, scale=(0.6, 1.0)),
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomRotation(30),
    paddlex.transforms.MixupImage(),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

test_tfm = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
paddle.vision.set_image_backend('cv2')
# Use the Cifar10 dataset
train_dataset = Cifar10(data_file='data/data152754/cifar-10-python.tar.gz', mode='train', transform = train_tfm, )
val_dataset = Cifar10(data_file='data/data152754/cifar-10-python.tar.gz', mode='test',transform = test_tfm)
print("train_dataset: %d" % len(train_dataset))
print("val_dataset: %d" % len(val_dataset))
train_dataset: 50000
val_dataset: 10000
batch_size=128
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=2)
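As an optional sanity check, one batch can be drawn from the loader to confirm the tensor shapes before building the model:

x, y = next(iter(train_loader))
print(x.shape, y.shape)    # expected: [128, 3, 224, 224] and [128]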

2.3 Model Creation

2.3.1 Label Smoothing
class LabelSmoothingCrossEntropy(nn.Layer):
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, pred, target):
        confidence = 1. - self.smoothing
        log_probs = F.log_softmax(pred, axis=-1)
        # gather the log-probability of the true class for each sample
        idx = paddle.stack([paddle.arange(log_probs.shape[0]), target], axis=1)
        nll_loss = paddle.gather_nd(-log_probs, index=idx)
        # uniform component of the smoothed target distribution
        smooth_loss = paddle.mean(-log_probs, axis=-1)
        loss = confidence * nll_loss + self.smoothing * smooth_loss
        return loss.mean()
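A brief illustrative check on random logits: with smoothing=0.1 the target distribution puts 0.9 on the true class and spreads the remaining 0.1 uniformly, so on confident predictions the smoothed loss sits slightly above the plain cross entropy:

ls_ce = LabelSmoothingCrossEntropy(smoothing=0.1)
logits = paddle.randn([4, 10])              # 4 samples, 10 classes
target = paddle.to_tensor([1, 0, 3, 7])     # integer class labels
print(ls_ce(logits, target))                # smoothed loss
print(F.cross_entropy(logits, target))      # un-smoothed, for comparison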
2.3.2 DropPath
def drop_path(x, drop_prob=0.0, training=False):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = paddle.to_tensor(1 - drop_prob)
    shape = (paddle.shape(x)[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + paddle.rand(shape, dtype=x.dtype)
    random_tensor = paddle.floor(random_tensor)  # binarize
    output = x.divide(keep_prob) * random_tensor
    return output


class DropPath(nn.Layer):
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
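A small demonstration of the stochastic-depth behavior (values are illustrative): DropPath is the identity in eval mode, while in train mode it zeroes entire samples and rescales the survivors by 1/keep_prob so the expectation is unchanged:

dp = DropPath(drop_prob=0.5)
x = paddle.ones([4, 3])
dp.eval()
print(dp(x))    # identity in eval mode
dp.train()
print(dp(x))    # each row is either all zeros or scaled to 2.0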
2.3.3 Building the gMLP Model
class TinyAttention(nn.Layer):
    def __init__(self, d_in, d_out=None, d_attn=64):
        super().__init__()
        self.d_in = d_in
        self.d_out = d_out or d_in
        self.d_attn = d_attn
        self.qkv = nn.Linear(d_in, d_attn * 3)      # fused q/k/v projection
        self.proj = nn.Linear(d_attn, self.d_out)   # use self.d_out so d_out=None falls back to d_in
        self.softmax = nn.Softmax()

    def forward(self, x):
        qkv = self.qkv(x)
        q, k, v = paddle.chunk(qkv, 3, axis=-1)
        # single-head scaled dot-product attention with head size d_attn
        w = paddle.einsum('bnd, bmd->bnm', q, k)
        a = self.softmax(w * paddle.rsqrt(paddle.to_tensor(self.d_attn, dtype=paddle.float32)))
        x = paddle.einsum('bnm, bmd->bnd', a, v)
        return self.proj(x)
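A quick shape check of the tiny attention module, with dimensions matching how the SGU below uses it (d_ffn=768 in, d_ffn//2=384 out):

tn = TinyAttention(d_in=768, d_out=384, d_attn=64)
x = paddle.randn([2, 196, 768])    # [batch, tokens, channels]
print(tn(x).shape)                 # expected: [2, 196, 384]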
class SpatialGatingUnit(nn.Layer):
    def __init__(self, d_ffn, seq_len, tiny_attn=False):
        super().__init__()
        self.norm = nn.LayerNorm(d_ffn // 2)
        # spatial projection: a kernel-size-1 Conv1D mixing the seq_len token axis
        self.spatial_proj = nn.Conv1D(seq_len, seq_len, kernel_size=1)
        self.tiny_attn = tiny_attn
        # build the tiny attention branch only when it is used (the aMLP variant)
        self.tn = TinyAttention(d_ffn, d_ffn // 2) if tiny_attn else None
        # per the paper, initialize the spatial weights near zero and the bias as ones,
        # so each block behaves like a plain FFN at the start of training
        nn.initializer.Normal(std=1e-6)(self.spatial_proj.weight)
        nn.initializer.Constant(1.0)(self.spatial_proj.bias)

    def forward(self, x):
        u, v = paddle.chunk(x, 2, axis=-1)
        v = self.norm(v)
        if self.tiny_attn:
            v = self.tn(x) + self.spatial_proj(v)
        else:
            v = self.spatial_proj(v)
        return u * v
class gMLPBlock(nn.Layer):
    def __init__(self, d_model, d_ffn, seq_len, dpr=0.0, tiny_attn=False):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.channel_proj1 = nn.Linear(d_model, d_ffn)
        self.channel_proj2 = nn.Linear(d_ffn//2, d_model)
        self.sgu = SpatialGatingUnit(d_ffn, seq_len, tiny_attn)
        self.droppath = DropPath(dpr) if dpr > 0.0 else nn.Identity()

    def forward(self, x):
        residual = x
        x = self.norm(x)
        x = F.gelu(self.channel_proj1(x))
        x = self.sgu(x)
        x = self.channel_proj2(x)
        out = self.droppath(x) + residual
        return out
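A minimal forward-pass check of one block (196 = (224/16)^2 patch tokens matches the settings used later):

blk = gMLPBlock(d_model=128, d_ffn=768, seq_len=196)
x = paddle.randn([2, 196, 128])
print(blk(x).shape)    # expected: [2, 196, 128]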
class gMLP(nn.Layer):
    def __init__(self, in_channels=3, d_model=128, num_classes=1000, patch_size=16,
                 image_size=224, depth=30, d_ffn=768, tiny_attn=False, drp=0.0):
        super().__init__()

        assert image_size % patch_size == 0, 'image_size must be divisible by patch_size'

        self.num_patches = (image_size // patch_size) ** 2

        # non-overlapping patch embedding, as in ViT
        self.patcher = nn.Conv2D(in_channels, d_model, patch_size, patch_size)

        self.model = nn.Sequential(
            *[gMLPBlock(d_model, d_ffn, self.num_patches, drp, tiny_attn) for _ in range(depth)]
        )

        self.norm = nn.LayerNorm(d_model)
        self.classifier = nn.Linear(d_model, num_classes)

    def forward(self, x):
        x = self.patcher(x)                                 # [B, d_model, H/ps, W/ps]
        B, C, H, W = x.shape
        x = x.transpose([0, 2, 3, 1]).reshape((B, -1, C))   # [B, num_patches, d_model]
        x = self.model(x)
        x = self.norm(x)
        x = paddle.mean(x, 1)                               # global average pooling over tokens
        return self.classifier(x)
2.3.4 Model Parameters
# gMLP-Ti
model = gMLP(d_model=128, num_classes=10, patch_size=16, depth=30, d_ffn=768, drp=0.0)
paddle.summary(model, (1, 3, 224, 224))


# gMLP-S
model = gMLP(d_model=256, num_classes=10, patch_size=16, depth=30, d_ffn=1536, drp=0.05)
paddle.summary(model, (batch_size, 3, 224, 224))


# gMLP-B
model = gMLP(d_model=512, num_classes=10, patch_size=16, depth=30, d_ffn=3072, drp=0.2)
paddle.summary(model, (batch_size, 3, 224, 224))

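For reference, the aMLP variant from Section 1.2 corresponds to enabling the tiny attention branch; this is only a construction sketch, since the actual aMLP experiment lives in aMLP.ipynb:

# aMLP-Ti (sketch): gMLP-Ti plus a single-head tiny attention of size 64
model = gMLP(d_model=128, num_classes=10, patch_size=16, depth=30, d_ffn=768, drp=0.0, tiny_attn=True)
paddle.summary(model, (1, 3, 224, 224))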

2.4 Training

learning_rate = 0.001
n_epochs = 100
paddle.seed(42)
np.random.seed(42)
work_path = 'work/model'

# gMLP-Ti
model = gMLP(d_model=128, num_classes=10, patch_size=16, depth=30, d_ffn=768, drp=0.1)

criterion = LabelSmoothingCrossEntropy()

scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=learning_rate, T_max=50000 // batch_size * n_epochs, verbose=False)
optimizer = paddle.optimizer.Adam(parameters=model.parameters(), learning_rate=scheduler, weight_decay=1e-5)

best_acc = 0.0
val_acc = 0.0
loss_record = {'train': {'loss': [], 'iter': []}, 'val': {'loss': [], 'iter': []}}   # for recording loss
acc_record = {'train': {'acc': [], 'iter': []}, 'val': {'acc': [], 'iter': []}}      # for recording accuracy

loss_iter = 0
acc_iter = 0

for epoch in range(n_epochs):
    # ---------- Training ----------
    model.train()
    train_num = 0.0
    train_loss = 0.0

    val_num = 0.0
    val_loss = 0.0
    accuracy_manager = paddle.metric.Accuracy()
    val_accuracy_manager = paddle.metric.Accuracy()
    print("#===epoch: {}, lr={:.10f}===#".format(epoch, optimizer.get_lr()))
    for batch_id, data in enumerate(train_loader):
        x_data, y_data = data
        labels = paddle.unsqueeze(y_data, axis=1)

        logits = model(x_data)

        loss = criterion(logits, y_data)

        acc = paddle.metric.accuracy(logits, labels)
        accuracy_manager.update(acc)
        if batch_id % 10 == 0:
            loss_record['train']['loss'].append(loss.numpy())
            loss_record['train']['iter'].append(loss_iter)
            loss_iter += 1

        loss.backward()

        optimizer.step()
        scheduler.step()
        optimizer.clear_grad()

        train_loss += loss
        train_num += len(y_data)

    total_train_loss = (train_loss / train_num) * batch_size
    train_acc = accuracy_manager.accumulate()
    acc_record['train']['acc'].append(train_acc)
    acc_record['train']['iter'].append(acc_iter)
    acc_iter += 1
    # Print the information.
    print("#===epoch: {}, train loss is: {}, train acc is: {:2.2f}%===#".format(epoch, total_train_loss.numpy(), train_acc*100))

    # ---------- Validation ----------
    model.eval()

    for batch_id, data in enumerate(val_loader):

        x_data, y_data = data
        labels = paddle.unsqueeze(y_data, axis=1)
        with paddle.no_grad():
            logits = model(x_data)

        loss = criterion(logits, y_data)

        acc = paddle.metric.accuracy(logits, labels)
        val_accuracy_manager.update(acc)

        val_loss += loss
        val_num += len(y_data)

    total_val_loss = (val_loss / val_num) * batch_size
    loss_record['val']['loss'].append(total_val_loss.numpy())
    loss_record['val']['iter'].append(loss_iter)
    val_acc = val_accuracy_manager.accumulate()
    acc_record['val']['acc'].append(val_acc)
    acc_record['val']['iter'].append(acc_iter)

    print("#===epoch: {}, val loss is: {}, val acc is: {:2.2f}%===#".format(epoch, total_val_loss.numpy(), val_acc*100))

    # ===================save====================
    if val_acc > best_acc:
        best_acc = val_acc
        paddle.save(model.state_dict(), os.path.join(work_path, 'best_model.pdparams'))
        paddle.save(optimizer.state_dict(), os.path.join(work_path, 'best_optimizer.pdopt'))

print(best_acc)
paddle.save(model.state_dict(), os.path.join(work_path, 'final_model.pdparams'))
paddle.save(optimizer.state_dict(), os.path.join(work_path, 'final_optimizer.pdopt'))


2.5 Result Analysis

def plot_learning_curve(record, title='loss', ylabel='CE Loss'):
    ''' Plot the learning curve of the model '''
    maxtrain = max(map(float, record['train'][title]))
    maxval = max(map(float, record['val'][title]))
    ymax = max(maxtrain, maxval) * 1.1
    mintrain = min(map(float, record['train'][title]))
    minval = min(map(float, record['val'][title]))
    ymin = min(mintrain, minval) * 0.9

    x_1 = list(map(int, record['train']['iter']))
    x_2 = list(map(int, record['val']['iter']))
    figure(figsize=(10, 6))
    plt.plot(x_1, record['train'][title], c='tab:red', label='train')
    plt.plot(x_2, record['val'][title], c='tab:cyan', label='val')
    plt.ylim(ymin, ymax)
    plt.xlabel('Training steps')
    plt.ylabel(ylabel)
    plt.title('Learning curve of {}'.format(title))
    plt.legend()
    plt.show()
2.5.1 Loss and Accuracy Curves
plot_learning_curve(loss_record, title='loss', ylabel='CE Loss')


plot_learning_curve(acc_record, title='acc', ylabel='Accuracy')


import time
work_path = 'work/model'
model = gMLP(d_model=128, num_classes=10, patch_size=16, depth=30, d_ffn=768, drp=0.1)
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
model.eval()
aa = time.time()
for batch_id, data in enumerate(val_loader):

    x_data, y_data = data
    labels = paddle.unsqueeze(y_data, axis=1)
    with paddle.no_grad():
        logits = model(x_data)
bb = time.time()
print("Throughout:{}".format(int(len(val_dataset)//(bb - aa))))
Throughout:671
2.5.2 Comparing Predictions with Ground-Truth Labels
def get_cifar10_labels(labels):
    """返回CIFAR10数据集的文本标签。"""
    text_labels = [
        'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog',
        'horse', 'ship', 'truck']
    return [text_labels[int(i)] for i in labels]
def show_images(imgs, num_rows, num_cols, pred=None, gt=None, scale=1.5):
    """Plot a list of images."""
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if paddle.is_tensor(img):
            ax.imshow(img.numpy())
        else:
            ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if pred or gt:
            ax.set_title("pt: " + pred[i] + "\ngt: " + gt[i])
    return axes
work_path = 'work/model'
X, y = next(iter(DataLoader(val_dataset, batch_size=18)))
model = gMLP(d_model=128, num_classes=10, patch_size=16, depth=30, d_ffn=768, drp=0.1)
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
model.eval()
logits = model(X)
y_pred = paddle.argmax(logits, -1)
X = paddle.transpose(X, [0, 2, 3, 1])
axes = show_images(X.reshape((18, 224, 224, 3)), 1, 18, pred=get_cifar10_labels(y_pred), gt=get_cifar10_labels(y))
plt.show()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
(warning repeated once per image; it appears because the normalized image tensors are passed to imshow directly)


2.5.3 Visualization Results
!pip install interpretdl
import interpretdl as it
work_path = 'work/model'
model = gMLP(d_model=128, num_classes=10, patch_size=16, depth=30, d_ffn=768, drp=0.1)
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
X, y = next(iter(DataLoader(val_dataset, batch_size=18)))
lime = it.LIMECVInterpreter(model, use_cuda=True)
lime_weights = lime.interpret(X.numpy()[3], interpret_class=y.numpy()[3], batch_size=100, num_samples=10000, visual=True)
100%|██████████| 10000/10000 [00:52<00:00, 190.64it/s]


lime_weights = lime.interpret(X.numpy()[13], interpret_class=y.numpy()[13], batch_size=100, num_samples=10000, visual=True)
100%|██████████| 10000/10000 [00:51<00:00, 166.92it/s]


3. Comparative Experiments

| model | Acc | parameters |
| --- | --- | --- |
| gMLP | 0.88479 | 5.7M |
| aMLP | 0.86986 | 10.9M |

The aMLP experiment can be found in aMLP.ipynb.

Summary

  1. The paper proposes a simple architecture, gMLP, that is comparable to Transformers in BERT perplexity and ViT accuracy.
  2. gMLP is also scalable.
  3. gMLP achieves satisfactory results on some challenging tasks and in certain cases significantly outperforms Transformers.
  4. The paper also finds that the inductive bias in the multi-head self-attention of Transformers is useful for downstream tasks that require cross-sentence alignment; judging from the comparative experiments here, however, it does not seem to help image classification, where it actually degrades performance.

Statement

This project is a reproduction of another project.
Original project link
