MLP-Mixer

drawing

paper:https://arxiv.org/abs/2105.01601

浅谈 MLP-Mixer

Hi guy!我们又见面了,这里将解析一篇来自谷歌的工作 MLP-Mixer

drawing

谈起 MLP-Mixer 之前,我们先了解一下 MLP 结构,即多层感知机(Multi-layer Perceptrons),理论上一定复杂程度的 MLP 可以拟合任何函数的,但是代价是大量的计算开销和参数量,这给纯 MLP 的模型发展造成了阻碍。之前提出的 CNN、RNN 就是通过将 Inductive Bias(归纳偏置) 引入模型里,从而能在计算资源有限、数据有限的情况下能取得很好的结果。

现在,我们有钱了、数据够了、计算资源足了、腰板硬了

MLP 计算量大?V100、A100、 TPU管够!噢我们是按机子算的不是按卡算的

MLP 参数量大容易过拟合?JFT-300M 数据安排上!

要什么 Inductive Bias,机器自己学去!

然后,大家就把视觉 Transformer 的 Self-Attention 部分替换成 MLP,然后再大力出奇迹,然后,咦,可以 work 的呀

就这样,视觉的 all-MLP 结构 MLP-Mixer 来了,一时间引爆了学术圈,有惊讶的,有赞赏的,有不看好的 。。。

不管怎么说,MLP-Mixer 证明了古老结构 MLP 的能力,视觉领域形成了 MLP --> CNN --> Transformer --> MLP 的轮回,一时间后续很多基于 MLP 的工作如雨后竹笋般频出,如 ResMLP、CycleMLP、gMLP、ViP、ConvMLP 等

MLP-Mixer 尽管没有展现出 SOTA 的性能,但是其给学术界了不少的启发

流程解析

我们来看一下 MLP-Mixer 的总体结构把,这里我们直接看代码来理解

首先我们先导入包,定义一些关于参数的初始化的设置,良好的初始化可以帮助模型收敛到最优解,以及无任何操作的 Identity 模块和 训练时起到正则作用的DropPath 模块

from functools import partial

import paddle
import paddle.nn as nn

normal_ = nn.initializer.Normal(std=1e-6)                  # normal 初始化
zeros_ = nn.initializer.Constant(value=0.0)                # zeros 初始化
ones_ = nn.initializer.Constant(value=1.0)                 # ones 初始化
xavier_uniform_ = nn.initializer.XavierUniform()           # xavier 初始化
trunc_normal_ = nn.initializer.TruncatedNormal(std=0.02)   # trunc_normal 初始化


class Identity(nn.Layer):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x

      
def drop_path(inputs, drop_prob=0., training=False): 
    """drop path op"""
    if drop_prob == 0. or not training:
        return inputs
    keep_prob = 1 - drop_prob
    shape = (inputs.shape[0], ) + (1, ) * (inputs.ndim - 1) 
    random_tensor = keep_prob + paddle.rand(shape, dtype=inputs.dtype)
    random_tensor = random_tensor.floor()
    output = inputs.divide(keep_prob) * random_tensor
    return output


class DropPath(nn.Layer):
    """DropPath class"""
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, inputs):
        return drop_path(inputs, self.drop_prob, self.training)

定义一个 MLP,不关注 Dropout 结构的话,MLP 结构大致如下所示

class Mlp(nn.Layer):
    """ MLP Layer
    """
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        drop_probs = (drop, drop)

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return 

下面是 PatchEmbed 操作,PatchEmbed 操作本质是将图片维度 N, C, H, W 映射到 N, num_patches, embed_dim,具体过程如下所示

对应的代码实现如下

class PatchEmbed(nn.Layer):
    """ Patch Embedding Layer
    """
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        norm_layer=None,
        flatten=True,
    ):
        super().__init__()
        img_size = (img_size, img_size)
        patch_size = (patch_size, patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0],
                          img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten

        self.proj = nn.Conv2D(in_chans,
                              embed_dim,
                              kernel_size=patch_size,
                              stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else Identity()

    def forward(self, x):
        B, C, H, W = x.shape
        assert (
            H == self.img_size[0]
        ), f"Input image height ({H}) doesn't match model ({self.img_size[0]})."
        assert (
            W == self.img_size[1]
        ), f"Input image width ({W}) doesn't match model ({self.img_size[1]})."
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose([0, 2, 1])  # BCHW -> BNC
        x = self.norm(x)
        return x

接下来就是 MLP-Mixer Block 的搭建了,我们看一下 Block 结构图

更细一点来说,我们看一下一个 patch 经过 PatchEmbed 被拉平,其自然包含了 channel 信息,对行进行MLP 操作可以提取 channel 信息

同理,对于 N, num_patches, embed_dim ,对列进行 MLP 操作可以提取 spatial 信息,如下所示

了解上述原理,代码就不难理解了

class MixerBlock(nn.Layer):
    """ Residual Block w/ token mixing and channel MLPs
    """
    def __init__(
        self,
        dim,
        seq_len,
        mlp_ratio=(0.5, 4.0),
        mlp_layer=Mlp,
        norm_layer=partial(nn.LayerNorm, epsilon=1e-6),
        act_layer=nn.GELU,
        drop=0.0,
        drop_path=0.0,
    ):
        super().__init__()
        tokens_dim = int(mlp_ratio[0] * dim)
        channels_dim = int(mlp_ratio[1] * dim)
        self.norm1 = norm_layer(dim)
        self.mlp_tokens = mlp_layer(seq_len,
                                    tokens_dim,
                                    act_layer=act_layer,
                                    drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
        self.norm2 = norm_layer(dim)
        self.mlp_channels = mlp_layer(dim,
                                      channels_dim,
                                      act_layer=act_layer,
                                      drop=drop)

    def forward(self, x):
        x = x + self.drop_path(
            self.mlp_tokens(self.norm1(x).transpose([0, 2, 1])).transpose(
                [0, 2, 1]))
        x = x + self.drop_path(self.mlp_channels(self.norm2(x)))
        return x

最后就是搭建 MLP-Mixer 了

class MlpMixer(nn.Layer):
    def __init__(
        self,
        num_classes=1000,
        img_size=224,
        in_chans=3,
        patch_size=16,
        num_blocks=8,
        embed_dim=512,
        mlp_ratio=(0.5, 4.0),
        block_layer=MixerBlock,
        mlp_layer=Mlp,
        norm_layer=partial(nn.LayerNorm, epsilon=1e-6),
        act_layer=nn.GELU,
        drop_rate=0.0,
        drop_path_rate=0.0,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.num_features = (
            self.embed_dim
        ) = embed_dim  # num_features for consistency with other models

        self.stem = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        # FIXME drop_path (stochastic depth scaling rule or all the same?)
        self.blocks = nn.Sequential(*[
            block_layer(
                embed_dim,
                self.stem.num_patches,
                mlp_ratio,
                mlp_layer=mlp_layer,
                norm_layer=norm_layer,
                act_layer=act_layer,
                drop=drop_rate,
                drop_path=drop_path_rate,
            ) for _ in range(num_blocks)
        ])
        self.norm = norm_layer(embed_dim)
        self.head = nn.Linear(hidden_dim, self.num_classes)
        self.init_weights()

    def init_weights(self):
        for n, m in self.named_sublayers():
            _init_weights(m, n)

    def forward_features(self, x):
        x = self.stem(x)
        x = self.blocks(x)
        x = self.norm(x)
        x = x.mean(axis=1)
        x = self.head(x)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        return x


def _init_weights(m, n):
    """ Mixer weight initialization
    """
    if isinstance(m, nn.Linear):
        if n.startswith("head"):
            zeros_(m.weight)
            zeros_(m.bias)
        else:
            xavier_uniform_(m.weight)
            if m.bias is not None:
                if "mlp" in n:
                    normal_(m.bias)
                else:
                    zeros_(m.bias)
    elif isinstance(m, nn.Conv2D):
        trunc_normal_(m.weight)
        if m.bias is not None:
            zeros_(m.bias)
    elif isinstance(m, nn.LayerNorm):
        zeros_(m.bias)
        ones_(m.weight)

这样模型就搭建完毕啦,我们看一下 MLP-Mixer 在 JFT-300M 下性能情况

结合开头 MLP-Mixer 在 ImageNet-1K 下的预训练性能来看,MLP-Mixer 需要更大的数据来自己学习归纳偏置,从而展现出相比 ConvNet 网络更高的性能,这也表明了为什么基于 ImageNet-1K 下 MLP-Mixer B/16 性能不如更大的 MLP-Mixer L/16,当数据量不够时候越大的模型会出现过拟合

PASSL 已支持 MLP-Mixer

PASSL 包含 SimCLR、MoCo v1/v2、BYOL、CLIP 等基于对比学习的图像自监督算法以及 Vision Transformer、Swin Transformer、BEiT、CvT、T2T-ViT、MLP-Mixer 等视觉 Transformer 及相关算法,欢迎 star ~

PASSL github:https://github.com/PaddlePaddle/PASSL

MLP-Mixer 性能

The results are evaluated on ImageNet2012 validation set

ArchWeightTop-1 AccTop-5 AccCrop ratio# Params
mlp_mixer_b16_224pretrain 1k76.6092.230.87560.0M
mlp_mixer_l16_224pretrain 1k72.0687.670.875208.2M

更详细内容可见:https://github.com/PaddlePaddle/PASSL/tree/main/configs/mlp_mixer

!git clone https://github.com/PaddlePaddle/PASSL.git  # 克隆 PASSL
!pip install ftfy   # 安装依赖
!pip install regex  # 安装依赖
%cd PASSL
import paddle

from passl.modeling.backbones import build_backbone 
from passl.modeling.heads import build_head 
from passl.utils.config import get_config      


class Model(paddle.nn.Layer):
    def __init__(self, cfg_file):
        super().__init__()
        cfg = get_config(cfg_file)
        self.backbone = build_backbone(cfg.model.architecture)
        self.head = build_head(cfg.model.head)

    def forward(self, x):

        x = self.backbone(x)
        x = self.head(x)
        return x


cfg_file  = 'configs/mlp_mixer/mlp-mixer_b16_224.yaml'  # MLP-Mixer 配置文件
m = Model(cfg_file)                                     # 模型组网
x = paddle.randn([2, 3, 224, 224])   # test
out = m(x)

loss = out.sum()
loss.backward()
print('Single iteration completed successfully')

总结

大力出奇迹的 MLP-Mixer 尽管展现了具有竞争力的水准,但是背后也参杂了现代网络结构的技巧,例如类 ViT 的架构、LayeNorm、GELU、skip connection 等,这也是为什么现在 MLP-Mixer 能出圈的原因之一。但是 MLP-Mixer 缺点也很明显,其扩展性如做下游任务不是很方便(后续的 CycleMLP 解决了这个问题)。实际上,MLP-Mixer 的成功出圈,与其说是 MLP 古老的结构简单有效,不如说是时代造就了它。

Logo

学大模型,用大模型上飞桨星河社区!每天8点V100G算力免费领!免费领取ERNIE 4.0 100w Token >>>

更多推荐