LIT: A Vision Transformer That Pays Less Attention

Abstract

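Transformers have become one of the dominant architectures in deep learning, particularly as a powerful alternative to convolutional neural networks (CNNs) in computer vision. However, in previous work, Transformer training and inference can be very expensive because self-attention has quadratic complexity in the length of the token sequence, especially for high-resolution dense prediction tasks. To this end, we propose a novel Less attention vIsion Transformer (LIT). It builds on the observation that the early self-attention layers in Transformers still focus on local patterns and bring only minor benefits in recent hierarchical vision Transformers. Specifically, we propose a hierarchical Transformer that uses pure multi-layer perceptrons (MLPs) to encode rich local patterns in the early stages. Self-attention modules are applied only in deeper layers, where they capture longer-range dependencies. Moreover, we propose a learnable deformable token merging module that adaptively fuses informative patches in a non-uniform manner. LIT achieves promising performance on image recognition tasks, including image classification, object detection and instance segmentation, and serves as a strong backbone for many vision tasks.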

1. LIT

1.1 Overview

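As shown in Figure 1, LIT adopts a hierarchical Transformer architecture. It uses MLP blocks in the early stages to model local patterns and self-attention layers in the later stages to model long-range dependencies. The two block types are defined as follows, where LN is layer normalization and MSA is multi-head self-attention: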

  1. MLP Block:
    $$\mathbf{X}_{l}=\mathbf{X}_{l-1}+\operatorname{MLP}\left(\operatorname{LN}\left(\mathbf{X}_{l-1}\right)\right)$$
  2. Transformer Block:
    $$\begin{aligned} \mathbf{X}_{l-1}^{\prime} &=\mathbf{X}_{l-1}+\operatorname{MSA}\left(\operatorname{LN}\left(\mathbf{X}_{l-1}\right)\right), \\ \mathbf{X}_{l} &=\mathbf{X}_{l-1}^{\prime}+\operatorname{MLP}\left(\operatorname{LN}\left(\mathbf{X}_{l-1}^{\prime}\right)\right) \end{aligned}$$
    [Figure 1: overall architecture of LIT]
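A back-of-envelope count shows why the early stages are where attention is most expensive. The sketch assumes the default configuration used in the code later on: a 224×224 input, 4×4 patches, and spatial resolution halved after each stage. `stage_token_counts` is a hypothetical helper written only for this illustration.

```python
def stage_token_counts(img_size=224, patch_size=4, num_stages=4):
    """Tokens per stage when every stage halves the spatial side length."""
    side = img_size // patch_size  # 56 for a 224x224 input with 4x4 patches
    return [(side // 2 ** i) ** 2 for i in range(num_stages)]


tokens = stage_token_counts()
attn_entries = [n * n for n in tokens]  # size of one head's N x N attention map
for stage, (n, a) in enumerate(zip(tokens, attn_entries), start=1):
    print(f"stage {stage}: N = {n:>5}, N^2 = {a:>12,}")
```

With these settings, one head's attention map in stage 1 has about 9.8 million entries, compared with about 38 thousand in stage 3. That is a 256× gap. It is consistent with the default `has_msa=[0, 0, 1, 1]` in the `LIT` class, which uses attention only in the last two stages.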

1.2 DTM

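Previous work on hierarchical vision Transformers (HVTs) relies on patch merging to build pyramid feature representations. However, these methods merge patches on a regular grid and ignore the fact that not every patch contributes equally to an output unit. Inspired by deformable convolution, we propose a Deformable Token Merging (DTM) module that learns a grid of offsets to adaptively sample more informative patches.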
[Figure: the Deformable Token Merging (DTM) module]
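The idea can be illustrated with a minimal single-channel sketch. It assumes a 2×2 kernel with stride 2, as in the `DeformablePatchMerging` code later on. `bilinear` and `deformable_merge` are hypothetical helpers, not the paper's implementation, and the DCNv2 modulation mask is omitted for simplicity. With all offsets at zero the operation reduces to an ordinary strided 2×2 merge. Non-zero offsets move each sampling point off the regular grid, and bilinear interpolation keeps the result differentiable with respect to the offsets.

```python
import math


def bilinear(img, y, x):
    """Sample a 2-D list `img` at fractional (y, x); out-of-bounds neighbours count as 0."""
    h, w = len(img), len(img[0])
    y0, x0 = math.floor(y), math.floor(x)
    dy, dx = y - y0, x - x0
    val = 0.0
    for yy, wy in ((y0, 1 - dy), (y0 + 1, dy)):
        for xx, wx in ((x0, 1 - dx), (x0 + 1, dx)):
            if 0 <= yy < h and 0 <= xx < w:
                val += wy * wx * img[yy][xx]
    return val


def deformable_merge(img, offsets, weights):
    """2x2, stride-2 deformable sampling on one channel.

    offsets[i][j] holds 4 (dy, dx) pairs, one per kernel tap, for output cell (i, j);
    weights holds the 4 kernel weights.
    """
    h, w = len(img), len(img[0])
    taps = [(0, 0), (0, 1), (1, 0), (1, 1)]  # regular grid positions inside the window
    out = []
    for i in range(h // 2):
        row = []
        for j in range(w // 2):
            acc = 0.0
            for k, (ky, kx) in enumerate(taps):
                dy, dx = offsets[i][j][k]
                acc += weights[k] * bilinear(img, 2 * i + ky + dy, 2 * j + kx + dx)
            row.append(acc)
        out.append(row)
    return out


img = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
zero = [[[(0.0, 0.0)] * 4 for _ in range(2)] for _ in range(2)]
print(deformable_merge(img, zero, [0.25] * 4))  # plain 2x2 average pooling
```

In DTM the offsets are not hand-set as in this demo. They are predicted per location by a small convolution over the input features, which is what `self.offset` does in the Paddle code.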

2. Code Reproduction

2.1 Install and import the required libraries

!pip install paddlex
%matplotlib inline
import os

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure

import paddle
from paddle import nn
import paddle.nn.functional as F
import paddle.vision.transforms as transforms
from paddle.io import Dataset, DataLoader
from paddle.vision.datasets import Cifar10
from paddle.vision.ops import DeformConv2D
import paddlex

2.2 Create the datasets

train_tfm = transforms.Compose([
    transforms.Resize((230, 230)),
    transforms.ColorJitter(brightness=0.4,contrast=0.4, saturation=0.4),
    transforms.RandomResizedCrop(224, scale=(0.4, 1.0)),
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomRotation(20),
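    # NOTE: mixup normally blends two images *and* their labels, which a per-image transform
    # inside Compose cannot do; check that this line has the intended effect.
    paddlex.transforms.MixupImage(),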
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

test_tfm = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
paddle.vision.set_image_backend('cv2')
# Use the CIFAR-10 dataset (loaded from a local archive)
train_dataset = Cifar10(data_file='data/data152754/cifar-10-python.tar.gz', mode='train', transform=train_tfm)
val_dataset = Cifar10(data_file='data/data152754/cifar-10-python.tar.gz', mode='test', transform=test_tfm)
print("train_dataset: %d" % len(train_dataset))
print("val_dataset: %d" % len(val_dataset))
train_dataset: 50000
val_dataset: 10000
batch_size=128
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=4)

2.3 Building the model

2.3.1 Label smoothing
class LabelSmoothingCrossEntropy(nn.Layer):
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, pred, target):

        confidence = 1. - self.smoothing
        log_probs = F.log_softmax(pred, axis=-1)
        idx = paddle.stack([paddle.arange(log_probs.shape[0]), target], axis=1)
        nll_loss = paddle.gather_nd(-log_probs, index=idx)
        smooth_loss = paddle.mean(-log_probs, axis=-1)
        loss = confidence * nll_loss + self.smoothing * smooth_loss

        return loss.mean()
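The loss above mixes the usual negative log-likelihood with the mean of `-log_probs` over all classes. The following NumPy sketch is an illustration separate from the Paddle code, and its helper names are hypothetical. It checks that this formula equals cross-entropy against a smoothed target q = (1 - ε)·onehot(y) + ε/K, where K is the number of classes.

```python
import numpy as np


def log_softmax(z):
    z = z - z.max(axis=-1, keepdims=True)  # subtract the row max for numerical stability
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))


def label_smoothing_ce(logits, target, smoothing=0.1):
    """Same formula as LabelSmoothingCrossEntropy.forward."""
    log_probs = log_softmax(logits)
    nll = -log_probs[np.arange(len(target)), target]
    smooth = -log_probs.mean(axis=-1)
    return ((1.0 - smoothing) * nll + smoothing * smooth).mean()


def soft_target_ce(logits, target, smoothing=0.1):
    """Cross-entropy against an explicitly smoothed target distribution."""
    num_classes = logits.shape[-1]
    q = np.full(logits.shape, smoothing / num_classes)
    q[np.arange(len(target)), target] += 1.0 - smoothing
    return -(q * log_softmax(logits)).sum(axis=-1).mean()


rng = np.random.default_rng(0)
logits = rng.normal(size=(4, 10))
target = np.array([3, 0, 9, 5])
print(label_smoothing_ce(logits, target), soft_target_ce(logits, target))
```

The two values agree because Σ_k q_k·(-log p_k) = (1 - ε)·(-log p_y) + ε·mean_k(-log p_k). With `smoothing=0` the loss reduces to plain cross-entropy.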
2.3.2 DropPath
def drop_path(x, drop_prob=0.0, training=False):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = paddle.to_tensor(1 - drop_prob)
    shape = (paddle.shape(x)[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + paddle.rand(shape, dtype=x.dtype)
    random_tensor = paddle.floor(random_tensor)  # binarize
    output = x.divide(keep_prob) * random_tensor
    return output


class DropPath(nn.Layer):
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
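`drop_path` zeroes entire samples rather than individual activations. It rescales the survivors by 1 / keep_prob so that the expected output equals the input. A NumPy mirror of the same logic makes this easy to check. `drop_path_np` is a hypothetical name used only for this illustration.

```python
import numpy as np


def drop_path_np(x, drop_prob, training, rng):
    """NumPy mirror of drop_path: zero whole samples, rescale the survivors."""
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1.0 - drop_prob
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # one random draw per sample
    mask = np.floor(keep_prob + rng.random(mask_shape))  # 1 with probability keep_prob, else 0
    return x / keep_prob * mask


rng = np.random.default_rng(0)
x = np.ones((10000, 4, 3))  # (batch, tokens, channels)
y = drop_path_np(x, drop_prob=0.2, training=True, rng=rng)
print(y.mean())  # near 1.0: survivors are scaled by 1 / 0.8
```

Because the residual branch is dropped per sample, a dropped sample simply passes through the identity shortcut `x + 0` in the blocks that use `DropPath`.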
2.3.3 Patch Embedding
def to_2tuple(x):
    # Accept either an int or an existing (h, w) pair
    return tuple(x) if isinstance(x, (tuple, list)) else (x, x)

class PatchEmbed(nn.Layer):

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2D(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose([0, 2, 1])  # B Ph*Pw C
        # _, _, H, W = x.shape
        if self.norm is not None:
            x = self.norm(x)
        return x
model =  PatchEmbed(norm_layer=nn.LayerNorm)
paddle.summary(model, (1, 3, 224, 224))
---------------------------------------------------------------------------
 Layer (type)       Input Shape          Output Shape         Param #    
===========================================================================
   Conv2D-1      [[1, 3, 224, 224]]    [1, 96, 56, 56]         4,704     
  LayerNorm-1     [[1, 3136, 96]]       [1, 3136, 96]           192      
===========================================================================
Total params: 4,896
Trainable params: 4,896
Non-trainable params: 0
---------------------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 4.59
Params size (MB): 0.02
Estimated Total Size (MB): 5.19
---------------------------------------------------------------------------






{'total_params': 4896, 'trainable_params': 4896}
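`PatchEmbed` uses a `Conv2D` whose kernel size equals its stride. That is the same as cutting the image into non-overlapping p×p patches, flattening each one, and applying a shared linear layer. The NumPy sketch below checks that equivalence on a single small image with no batch dimension. `patchify`, `patch_embed` and `conv_stride_p` are hypothetical helpers. The same view explains the 4,704 parameters reported for `Conv2D-1` above: 3·4·4·96 weights plus 96 biases.

```python
import numpy as np


def patchify(img, p):
    """(C, H, W) -> (H/p * W/p, C*p*p), patches in row-major order."""
    c, h, w = img.shape
    x = img.reshape(c, h // p, p, w // p, p)  # C, Ph, p, Pw, p
    x = x.transpose(1, 3, 0, 2, 4)            # Ph, Pw, C, p, p
    return x.reshape((h // p) * (w // p), c * p * p)


def patch_embed(img, weight, bias, p):
    """weight: (D, C, p, p) laid out like a Conv2D kernel; returns (N, D) tokens."""
    d = weight.shape[0]
    return patchify(img, p) @ weight.reshape(d, -1).T + bias


def conv_stride_p(img, weight, bias, p):
    """Reference: naive convolution with kernel_size = stride = p, flattened to (N, D)."""
    c, h, w = img.shape
    d = weight.shape[0]
    out = np.zeros((h // p, w // p, d))
    for i in range(h // p):
        for j in range(w // p):
            window = img[:, i * p:(i + 1) * p, j * p:(j + 1) * p]
            out[i, j] = np.tensordot(weight, window, axes=([1, 2, 3], [0, 1, 2])) + bias
    return out.reshape(-1, d)


rng = np.random.default_rng(0)
img = rng.normal(size=(3, 8, 8))
weight = rng.normal(size=(5, 3, 4, 4))
bias = rng.normal(size=5)
tokens = patch_embed(img, weight, bias, p=4)
print(tokens.shape)  # (4, 5): a 2x2 grid of patches, 5 channels each
```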
2.3.4 DTM
# DCNv2
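class DeformablePatchMerging(nn.Layer):

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        # NOTE: norm_layer is accepted so that LITLayer can pass it, but it is unused here;
        # the merged feature map is normalized with BatchNorm2D instead.
        self.input_resolution = input_resolution
        self.dim = dim
        self.kernel_size = 2
        self.stride = 2
        self.padding = 0
        self.c_in = dim
        self.c_out = dim*2

        def zero_attr():
            return paddle.ParamAttr(initializer=nn.initializer.Constant(0.0))

        # Offset predictor: 2 * 2 * 2 = 8 channels, one offset pair per tap of the 2x2 kernel.
        # Zero init (as in DCNv2) makes the layer start out as a regular 2x2 patch merge.
        self.offset = paddle.nn.Conv2D(dim, 2 * 2 * 2, kernel_size=2, stride=2, padding=0,
                                       weight_attr=zero_attr(), bias_attr=zero_attr())
        # Modulation mask predictor: one scalar per tap (2 * 2 = 4 channels); sigmoid(0) = 0.5 at init.
        self.mask = paddle.nn.Conv2D(dim, 2 * 2, kernel_size=2, stride=2, padding=0,
                                     weight_attr=zero_attr(), bias_attr=zero_attr())
        self.dconv = DeformConv2D(dim, dim * 2, kernel_size=2, stride=2, padding=0)
        self.norm_layer = nn.BatchNorm2D(dim*2)
        self.act_layer = nn.GELU()

    def forward(self, x):
        """
        x: B, H*W, C  ->  B, (H/2)*(W/2), 2C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."

        x = x.reshape([B, H, W, C]).transpose([0, 3, 1, 2])  # B C H W
        offset = self.offset(x)
        mask = F.sigmoid(self.mask(x))  # DCNv2 modulation scalars in [0, 1]
        x = self.dconv(x, offset, mask)
        # x = self.dconv(x, offset)  # DCNv1 variant without modulation
        x = self.act_layer(self.norm_layer(x)).flatten(2).transpose([0, 2, 1])  # B (H/2)*(W/2) 2C

        return x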
model =  DeformablePatchMerging((16, 16), 96)
paddle.summary(model, (1, 16 * 16, 96))
----------------------------------------------------------------------------------------------------
 Layer (type)                    Input Shape                      Output Shape         Param #    
====================================================================================================
   Conv2D-2                   [[1, 96, 16, 16]]                   [1, 8, 8, 8]          3,080     
   Conv2D-3                   [[1, 96, 16, 16]]                   [1, 4, 8, 8]          1,540     
DeformConv2D-1  [[1, 96, 16, 16], [1, 8, 8, 8], [1, 4, 8, 8]]    [1, 192, 8, 8]        73,920     
 BatchNorm2D-1                [[1, 192, 8, 8]]                   [1, 192, 8, 8]          768      
    GELU-1                    [[1, 192, 8, 8]]                   [1, 192, 8, 8]           0       
====================================================================================================
Total params: 79,308
Trainable params: 78,540
Non-trainable params: 768
----------------------------------------------------------------------------------------------------
Input size (MB): 0.09
Forward/backward pass size (MB): 0.29
Params size (MB): 0.30
Estimated Total Size (MB): 0.68
----------------------------------------------------------------------------------------------------






{'total_params': 79308, 'trainable_params': 78540}
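The `# DCNv2` comment refers to modulated deformable convolution. For one output location, the operation performed by `DeformConv2D` with a mask can be written as follows. Here p_0 is the top-left input position of the 2×2 window (stride 2), and p_k ∈ {(0,0), (0,1), (1,0), (1,1)} are the regular kernel taps. The symbols w_k, Δp_k and Δm_k are the kernel weights, the learned offsets and the modulation scalars.

$$
\mathbf{y}(p_0)=\sum_{k=1}^{K} w_k\,\mathbf{x}\left(p_0+p_k+\Delta p_k\right)\,\Delta m_k,\qquad K=2\times 2=4
$$

This is where the channel counts in the code come from. `self.offset` predicts 2K = 8 channels (an offset pair per tap), and `self.mask` predicts K = 4 channels. In the DCNv2 formulation, Δm_k lies in [0, 1], and x(·) at fractional positions is evaluated by bilinear interpolation.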
2.3.5 Self-attention
class RelPosAttention(nn.Layer):

    def __init__(self, dim, input_resolution, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = self.create_parameter(((2 * input_resolution[0] - 1) * (2 * input_resolution[1] - 1), num_heads),
            default_initializer=(nn.initializer.Assign(
            paddle.zeros(((2 * input_resolution[0] - 1) * (2 * input_resolution[1] - 1), num_heads)))))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = paddle.arange(self.input_resolution[0])
        coords_w = paddle.arange(self.input_resolution[1])
        coords = paddle.stack(paddle.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = paddle.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.transpose([1, 2, 0])  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.input_resolution[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.input_resolution[1] - 1
        relative_coords[:, :, 0] *= 2 * self.input_resolution[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        init = paddle.nn.initializer.TruncatedNormal(mean=0.0, std=.02)
        init(self.relative_position_bias_table)
        self.softmax = nn.Softmax(axis=-1)

    def forward(self, x):
        B_, N, C = x.shape
        assert C % self.num_heads == 0, 'C cannot be divided by num_heads'
        qkv = self.qkv(x).reshape((B_, N, 3, self.num_heads, C // self.num_heads)).transpose([2, 0, 3, 1, 4])
        q, k, v = qkv[0], qkv[1], qkv[2]  # unpack queries, keys and values: each (B, heads, N, head_dim)

        q = q * self.scale
        attn = (q @ k.transpose([0, 1, 3, 2]))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.flatten(0)].reshape((
            self.input_resolution[0] * self.input_resolution[1], self.input_resolution[0] * self.input_resolution[1], -1))  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.transpose([2, 0, 1])  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose([0, 2, 1, 3]).reshape((B_, N, C))
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
model =  RelPosAttention(96, (16, 16), num_heads=16)
paddle.summary(model, (1, 16 * 16, 96))
---------------------------------------------------------------------------
 Layer (type)       Input Shape          Output Shape         Param #    
===========================================================================
   Linear-1        [[1, 256, 96]]       [1, 256, 288]         27,936     
   Softmax-1    [[1, 16, 256, 256]]   [1, 16, 256, 256]          0       
   Dropout-1    [[1, 16, 256, 256]]   [1, 16, 256, 256]          0       
   Linear-2        [[1, 256, 96]]        [1, 256, 96]          9,312     
   Dropout-2       [[1, 256, 96]]        [1, 256, 96]            0       
===========================================================================
Total params: 37,248
Trainable params: 37,248
Non-trainable params: 0
---------------------------------------------------------------------------
Input size (MB): 0.09
Forward/backward pass size (MB): 16.94
Params size (MB): 0.14
Estimated Total Size (MB): 17.17
---------------------------------------------------------------------------






{'total_params': 37248, 'trainable_params': 37248}
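The index built in `RelPosAttention.__init__` maps every (query, key) token pair to a row of the bias table according to their 2-D offset. Pairs with the same offset therefore share one learned bias per head. The NumPy re-implementation below (hypothetical helper name) checks the index range and this translation invariance.

It also clarifies the parameter count. The summary above reports 37,248, which is exactly the two `Linear` layers (96·288 + 288 and 96·96 + 96). The bias table, (2·16 - 1)² × 16 = 15,376 entries, is not included in that root-level total, although it does appear as the `RelPosAttention-2` row in the `TransformerBlock` summary.

```python
import numpy as np


def relative_position_index(h, w):
    """NumPy version of the index built in RelPosAttention.__init__ (meshgrid uses 'ij' order)."""
    coords = np.stack(np.meshgrid(np.arange(h), np.arange(w), indexing="ij"))  # 2, h, w
    flat = coords.reshape(2, -1)                                               # 2, h*w
    rel = (flat[:, :, None] - flat[:, None, :]).transpose(1, 2, 0)             # h*w, h*w, 2
    rel[:, :, 0] += h - 1  # shift row offsets into [0, 2h-2]
    rel[:, :, 1] += w - 1  # shift column offsets into [0, 2w-2]
    rel[:, :, 0] *= 2 * w - 1
    return rel.sum(-1)     # h*w, h*w, values in [0, (2h-1)(2w-1))


h = w = 16
idx = relative_position_index(h, w)
table_size = (2 * h - 1) * (2 * w - 1)
print(idx.shape, idx.min(), idx.max(), table_size)
```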
class Attention(nn.Layer):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        self.dim = dim
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.softmax = nn.Softmax(axis=-1)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape((B, N, 3, self.num_heads, C // self.num_heads)).transpose([2, 0, 3, 1, 4])
        q, k, v = qkv[0], qkv[1], qkv[2]   # unpack queries, keys and values: each (B, heads, N, head_dim)

        attn = (q @ k.transpose([0, 1, 3, 2])) * self.scale
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose([0, 2, 1, 3]).reshape((B, N, C))
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
model =  Attention(96, num_heads=16)
paddle.summary(model, (1, 16 * 16, 96))
---------------------------------------------------------------------------
 Layer (type)       Input Shape          Output Shape         Param #    
===========================================================================
   Linear-3        [[1, 256, 96]]       [1, 256, 288]         27,648     
   Softmax-2    [[1, 16, 256, 256]]   [1, 16, 256, 256]          0       
   Dropout-3    [[1, 16, 256, 256]]   [1, 16, 256, 256]          0       
   Linear-4        [[1, 256, 96]]        [1, 256, 96]          9,312     
   Dropout-4       [[1, 256, 96]]        [1, 256, 96]            0       
===========================================================================
Total params: 36,960
Trainable params: 36,960
Non-trainable params: 0
---------------------------------------------------------------------------
Input size (MB): 0.09
Forward/backward pass size (MB): 16.94
Params size (MB): 0.14
Estimated Total Size (MB): 17.17
---------------------------------------------------------------------------






{'total_params': 36960, 'trainable_params': 36960}
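`Attention` is plain multi-head self-attention without the relative position bias. Note that `TransformerBlock` below uses `RelPosAttention`, not this class. The 27,648 vs. 27,936 difference in the `qkv` layer comes from `qkv_bias=False` here vs. `qkv_bias=True` in `RelPosAttention` (3·96 = 288 biases). A NumPy sketch of the same tensor reshuffling follows. Biases and dropout are omitted, and the helper names are hypothetical. Weights use Paddle's `nn.Linear` layout, `(in_features, out_features)`.

```python
import numpy as np


def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def multi_head_attention(x, w_qkv, w_proj, num_heads):
    """x: (B, N, C); w_qkv: (C, 3C); w_proj: (C, C). Biases and dropout omitted."""
    b, n, c = x.shape
    head_dim = c // num_heads
    qkv = (x @ w_qkv).reshape(b, n, 3, num_heads, head_dim).transpose(2, 0, 3, 1, 4)
    q, k, v = qkv[0], qkv[1], qkv[2]                                # each (B, heads, N, head_dim)
    attn = softmax(q @ k.transpose(0, 1, 3, 2) * head_dim ** -0.5)  # (B, heads, N, N)
    out = (attn @ v).transpose(0, 2, 1, 3).reshape(b, n, c)         # merge heads back
    return out @ w_proj, attn


rng = np.random.default_rng(0)
C, heads = 96, 16
x = rng.normal(size=(1, 256, C))
out, attn = multi_head_attention(x, rng.normal(size=(C, 3 * C)) * 0.02,
                                 rng.normal(size=(C, C)) * 0.02, heads)
print(out.shape, attn.shape)
```

The `(1, 16, 256, 256)` attention tensor is the same shape that `Softmax-2` reports in the summary.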
2.3.6 MLP
class Mlp(nn.Layer):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

class MLPBlock(nn.Layer):

    def __init__(self, dim, input_resolution, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.mlp_ratio = mlp_ratio
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
model =  MLPBlock(96, (16, 16), num_heads=64)
paddle.summary(model, (1, 16 * 16, 96))
---------------------------------------------------------------------------
 Layer (type)       Input Shape          Output Shape         Param #    
===========================================================================
  LayerNorm-2      [[1, 256, 96]]        [1, 256, 96]           192      
   Linear-5        [[1, 256, 96]]       [1, 256, 384]         37,248     
    GELU-2        [[1, 256, 384]]       [1, 256, 384]            0       
   Dropout-5       [[1, 256, 96]]        [1, 256, 96]            0       
   Linear-6       [[1, 256, 384]]        [1, 256, 96]         36,960     
     Mlp-1         [[1, 256, 96]]        [1, 256, 96]            0       
  Identity-1       [[1, 256, 96]]        [1, 256, 96]            0       
===========================================================================
Total params: 74,400
Trainable params: 74,400
Non-trainable params: 0
---------------------------------------------------------------------------
Input size (MB): 0.09
Forward/backward pass size (MB): 2.44
Params size (MB): 0.28
Estimated Total Size (MB): 2.82
---------------------------------------------------------------------------






{'total_params': 74400, 'trainable_params': 74400}
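`MLPBlock` is the pre-norm residual form X_l = X_{l-1} + MLP(LN(X_{l-1})) from Section 1.1. Below is a NumPy sketch of the forward pass with dropout and drop path omitted and LayerNorm at its initial affine values (weight 1, bias 0). It uses the erf-based GELU, which is assumed to match `nn.GELU`'s default. The helper names are hypothetical. The 74,400 parameters in the summary break down as 2·96 (LayerNorm), 96·384 + 384 (`fc1`) and 384·96 + 96 (`fc2`).

```python
import math

import numpy as np

_erf = np.vectorize(math.erf)


def gelu(x):
    return 0.5 * x * (1.0 + _erf(x / math.sqrt(2.0)))  # exact (erf-based) GELU


def layer_norm(x, eps=1e-5):
    mu = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)


def mlp_block(x, w1, b1, w2, b2):
    """X_l = X_{l-1} + MLP(LN(X_{l-1})), without dropout / drop path."""
    return x + gelu(layer_norm(x) @ w1 + b1) @ w2 + b2


rng = np.random.default_rng(0)
dim, ratio = 96, 4
w1, b1 = rng.normal(size=(dim, dim * ratio)) * 0.02, np.zeros(dim * ratio)
w2, b2 = rng.normal(size=(dim * ratio, dim)) * 0.02, np.zeros(dim)
x = rng.normal(size=(1, 256, dim))
y = mlp_block(x, w1, b1, w2, b2)
print(y.shape)
```

Unlike self-attention, every token is processed independently here. The cost grows linearly with the number of tokens, which is why this block is used in the high-resolution early stages.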
2.3.7 LIT
class TransformerBlock(nn.Layer):

    def __init__(self, dim, input_resolution, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.mlp_ratio = mlp_ratio
        self.norm1 = norm_layer(dim)
        self.attn = RelPosAttention(
            dim, input_resolution=input_resolution, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
model =  TransformerBlock(96, (16, 16), num_heads=16)
paddle.summary(model, (1, 16 * 16, 96))
-----------------------------------------------------------------------------
  Layer (type)        Input Shape          Output Shape         Param #    
=============================================================================
   LayerNorm-3       [[1, 256, 96]]        [1, 256, 96]           192      
    Linear-7         [[1, 256, 96]]       [1, 256, 288]         27,648     
    Softmax-3     [[1, 16, 256, 256]]   [1, 16, 256, 256]          0       
    Dropout-6     [[1, 16, 256, 256]]   [1, 16, 256, 256]          0       
    Linear-8         [[1, 256, 96]]        [1, 256, 96]          9,312     
    Dropout-7        [[1, 256, 96]]        [1, 256, 96]            0       
RelPosAttention-2    [[1, 256, 96]]        [1, 256, 96]         15,376     
   Identity-2        [[1, 256, 96]]        [1, 256, 96]            0       
   LayerNorm-4       [[1, 256, 96]]        [1, 256, 96]           192      
    Linear-9         [[1, 256, 96]]       [1, 256, 384]         37,248     
     GELU-3         [[1, 256, 384]]       [1, 256, 384]            0       
    Dropout-8        [[1, 256, 96]]        [1, 256, 96]            0       
    Linear-10       [[1, 256, 384]]        [1, 256, 96]         36,960     
      Mlp-2          [[1, 256, 96]]        [1, 256, 96]            0       
=============================================================================
Total params: 126,928
Trainable params: 126,928
Non-trainable params: 0
-----------------------------------------------------------------------------
Input size (MB): 0.09
Forward/backward pass size (MB): 19.75
Params size (MB): 0.48
Estimated Total Size (MB): 20.33
-----------------------------------------------------------------------------






{'total_params': 126928, 'trainable_params': 126928}
class LITLayer(nn.Layer):

    def __init__(self, dim, input_resolution, depth, num_heads,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, has_msa=True):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        block = TransformerBlock if has_msa else MLPBlock
        self.blocks = nn.LayerList([
            block(dim=dim, input_resolution=input_resolution,
                  num_heads=num_heads,
                  mlp_ratio=mlp_ratio,
                  qkv_bias=qkv_bias, qk_scale=qk_scale,
                  drop=drop, attn_drop=attn_drop,
                  drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                  norm_layer=norm_layer)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x):
        for blk in self.blocks:
            x = blk(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x
model =  LITLayer(96, (16, 16), 1, num_heads=16)
paddle.summary(model, (1, 16 * 16, 96))
------------------------------------------------------------------------------
   Layer (type)        Input Shape          Output Shape         Param #    
==============================================================================
   LayerNorm-5        [[1, 256, 96]]        [1, 256, 96]           192      
    Linear-11         [[1, 256, 96]]       [1, 256, 288]         27,936     
    Softmax-4      [[1, 16, 256, 256]]   [1, 16, 256, 256]          0       
    Dropout-9      [[1, 16, 256, 256]]   [1, 16, 256, 256]          0       
    Linear-12         [[1, 256, 96]]        [1, 256, 96]          9,312     
    Dropout-10        [[1, 256, 96]]        [1, 256, 96]            0       
RelPosAttention-3     [[1, 256, 96]]        [1, 256, 96]         15,376     
    Identity-3        [[1, 256, 96]]        [1, 256, 96]            0       
   LayerNorm-6        [[1, 256, 96]]        [1, 256, 96]           192      
    Linear-13         [[1, 256, 96]]       [1, 256, 384]         37,248     
      GELU-4         [[1, 256, 384]]       [1, 256, 384]            0       
    Dropout-11        [[1, 256, 96]]        [1, 256, 96]            0       
    Linear-14        [[1, 256, 384]]        [1, 256, 96]         36,960     
      Mlp-3           [[1, 256, 96]]        [1, 256, 96]            0       
TransformerBlock-2    [[1, 256, 96]]        [1, 256, 96]            0       
==============================================================================
Total params: 127,216
Trainable params: 127,216
Non-trainable params: 0
------------------------------------------------------------------------------
Input size (MB): 0.09
Forward/backward pass size (MB): 19.94
Params size (MB): 0.49
Estimated Total Size (MB): 20.52
------------------------------------------------------------------------------






{'total_params': 127216, 'trainable_params': 127216}
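The `TransformerBlock` and `LITLayer` summaries differ by 288 parameters. `TransformerBlock` defaults to `qkv_bias=False`, while `LITLayer` defaults to `qkv_bias=True` and passes it down, which adds 3·96 `qkv` biases. The hypothetical helper below reproduces both totals from the layer definitions above. It includes the relative position bias table, which is counted here because the attention layer is no longer the root of the summary.

```python
def transformer_block_params(dim, res, heads, mlp_ratio=4, qkv_bias=False):
    """Parameter count of TransformerBlock as defined above (with RelPosAttention)."""
    hidden = int(dim * mlp_ratio)
    layer_norms = 2 * (2 * dim)                     # norm1 + norm2: weight and bias each
    qkv = dim * 3 * dim + (3 * dim if qkv_bias else 0)
    proj = dim * dim + dim
    rel_pos_table = (2 * res[0] - 1) * (2 * res[1] - 1) * heads
    mlp = (dim * hidden + hidden) + (hidden * dim + dim)
    return layer_norms + qkv + proj + rel_pos_table + mlp


print(transformer_block_params(96, (16, 16), 16))                 # 126928 (TransformerBlock default)
print(transformer_block_params(96, (16, 16), 16, qkv_bias=True))  # 127216 (LITLayer default)
```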
class LIT(nn.Layer):
    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
                 embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                 mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 has_msa=[0, 0, 1, 1], **kwargs):
        super().__init__()

        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio
        self.has_msa = has_msa

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in paddle.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build layers
        self.layers = nn.LayerList()
        for i_layer in range(self.num_layers):
            layer = LITLayer(dim=int(embed_dim * 2 ** i_layer),
                               input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                                 patches_resolution[1] // (2 ** i_layer)),
                               depth=depths[i_layer],
                               num_heads=num_heads[i_layer],
                               mlp_ratio=self.mlp_ratio,
                               qkv_bias=qkv_bias, qk_scale=qk_scale,
                               drop=drop_rate, attn_drop=attn_drop_rate,
                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                               norm_layer=norm_layer,
                               downsample=DeformablePatchMerging if (i_layer < self.num_layers - 1) else None,
                               has_msa=self.has_msa[i_layer])
            self.layers.append(layer)

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1D(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)

    def _init_weights(self, m):
        trunc_normal_ = paddle.nn.initializer.TruncatedNormal(mean=0.0, std=.02)
        zeros_ = nn.initializer.Constant(value=0.)
        ones_ = nn.initializer.Constant(value=1.)
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            zeros_(m.bias)
            ones_(m.weight)

    def forward_features(self, x):
        x = self.patch_embed(x)

        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)  # B L C
        x = self.avgpool(x.transpose([0, 2, 1]))  # B C 1
        x = paddle.flatten(x, 1)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x
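`LIT.__init__` derives every stage's arguments from a handful of numbers. The hypothetical helper below mirrors that bookkeeping for the LIT-S defaults: resolution and width per stage, which block type `has_msa` selects, the linearly increasing stochastic-depth rates sliced per stage, and where `DeformablePatchMerging` is inserted. The `dpr` list is written with plain Python and is assumed to match `paddle.linspace` up to floating-point rounding. One consequence is that `num_heads` entries for stages built from `MLPBlock` are ignored, so the first two values in `num_heads` have no effect with the default `has_msa=[0, 0, 1, 1]`.

```python
def lit_stage_config(img_size=224, patch_size=4, embed_dim=96,
                     depths=(2, 2, 6, 2), has_msa=(0, 0, 1, 1), drop_path_rate=0.1):
    """Mirror of the per-stage arguments that LIT.__init__ passes to LITLayer."""
    side = img_size // patch_size
    total = sum(depths)
    # Linearly increasing drop-path rates from 0 to drop_path_rate over all blocks
    dpr = [drop_path_rate * i / (total - 1) for i in range(total)]
    stages = []
    for i, depth in enumerate(depths):
        stages.append({
            "resolution": (side // 2 ** i, side // 2 ** i),
            "dim": embed_dim * 2 ** i,
            "block": "TransformerBlock" if has_msa[i] else "MLPBlock",
            "drop_path": dpr[sum(depths[:i]):sum(depths[:i + 1])],
            "downsample": i < len(depths) - 1,  # DTM after every stage except the last
        })
    return stages


for s in lit_stage_config():
    print(s["resolution"], s["dim"], s["block"], [round(p, 3) for p in s["drop_path"]])
```

The last stage has width 96·2³ = 768, which is the `num_features` fed to the final `LayerNorm` and classification head.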
2.3.8 Model parameter counts
# LIT-S
model =  LIT(num_classes=10, embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24])
paddle.summary(model, (1, 3, 224, 224))


# LIT-M
model =  LIT(num_classes=10, embed_dim=96, depths=[2, 2, 18, 2], num_heads=[3, 6, 12, 24])
paddle.summary(model, (1, 3, 224, 224))


# LIT-B
model =  LIT(num_classes=10, embed_dim=128, depths=[2, 2, 18, 2], num_heads=[3, 6, 16, 32])
paddle.summary(model, (1, 3, 224, 224))


# LIT-ours
model =  LIT(num_classes=10, embed_dim=64, depths=[2, 2, 6, 2], num_heads=[3, 6, 16, 32])
paddle.summary(model, (1, 3, 224, 224))


2.4 Training

learning_rate = 5e-5
n_epochs = 100
paddle.seed(42)
np.random.seed(42)
work_path = 'work/model'
os.makedirs(work_path, exist_ok=True)  # make sure the checkpoint directory exists

# LIT-ours
model = LIT(num_classes=10, embed_dim=64, depths=[2, 2, 6, 2], num_heads=[3, 6, 16, 32])

criterion = LabelSmoothingCrossEntropy()

scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=learning_rate, T_max=50000 // batch_size * n_epochs, verbose=False)
optimizer = paddle.optimizer.Adam(parameters=model.parameters(), learning_rate=scheduler, weight_decay=1e-5)

gate = 0.0
threshold = 0.0
best_acc = 0.0
val_acc = 0.0
loss_record = {'train': {'loss': [], 'iter': []}, 'val': {'loss': [], 'iter': []}}   # for recording loss
acc_record = {'train': {'acc': [], 'iter': []}, 'val': {'acc': [], 'iter': []}}      # for recording accuracy

loss_iter = 0
acc_iter = 0

for epoch in range(n_epochs):
    # ---------- Training ----------
    model.train()
    train_num = 0.0
    train_loss = 0.0

    val_num = 0.0
    val_loss = 0.0
    accuracy_manager = paddle.metric.Accuracy()
    val_accuracy_manager = paddle.metric.Accuracy()
    print("#===epoch: {}, lr={:.10f}===#".format(epoch, optimizer.get_lr()))
    for batch_id, data in enumerate(train_loader):
        x_data, y_data = data
        labels = paddle.unsqueeze(y_data, axis=1)

        logits = model(x_data)

        loss = criterion(logits, y_data)

        acc = paddle.metric.accuracy(logits, labels)
        accuracy_manager.update(acc)
        if batch_id % 10 == 0:
            loss_record['train']['loss'].append(loss.numpy())
            loss_record['train']['iter'].append(loss_iter)
            loss_iter += 1

        loss.backward()

        optimizer.step()
        scheduler.step()
        optimizer.clear_grad()

        train_loss += loss.item() * len(y_data)  # .item() avoids keeping the autograd graph alive
        train_num += len(y_data)

    total_train_loss = train_loss / train_num  # mean per-sample loss over the epoch
    train_acc = accuracy_manager.accumulate()
    acc_record['train']['acc'].append(train_acc)
    acc_record['train']['iter'].append(acc_iter)
    acc_iter += 1
    # Print the information.
    print("#===epoch: {}, train loss is: {:.4f}, train acc is: {:2.2f}%===#".format(epoch, total_train_loss, train_acc*100))

    # ---------- Validation ----------
    model.eval()

    for batch_id, data in enumerate(val_loader):

        x_data, y_data = data
        labels = paddle.unsqueeze(y_data, axis=1)
        with paddle.no_grad():
            logits = model(x_data)
            loss = criterion(logits, y_data)

        acc = paddle.metric.accuracy(logits, labels)
        val_accuracy_manager.update(acc)

        val_loss += loss.item() * len(y_data)
        val_num += len(y_data)

    total_val_loss = val_loss / val_num  # exact mean even though the last batch is smaller
    loss_record['val']['loss'].append(total_val_loss)
    loss_record['val']['iter'].append(loss_iter)
    val_acc = val_accuracy_manager.accumulate()
    acc_record['val']['acc'].append(val_acc)
    acc_record['val']['iter'].append(acc_iter)

    print("#===epoch: {}, val loss is: {:.4f}, val acc is: {:2.2f}%===#".format(epoch, total_val_loss, val_acc*100))

    # ===================save====================
    if val_acc > best_acc:
        best_acc = val_acc
        paddle.save(model.state_dict(), os.path.join(work_path, 'best_model.pdparams'))
        paddle.save(optimizer.state_dict(), os.path.join(work_path, 'best_optimizer.pdopt'))

print(best_acc)
paddle.save(model.state_dict(), os.path.join(work_path, 'final_model.pdparams'))
paddle.save(optimizer.state_dict(), os.path.join(work_path, 'final_optimizer.pdopt'))
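The loop above turns per-batch mean losses into one number per epoch. There are two common ways to do this: rescale the sum of batch means by the nominal batch size, or weight each batch mean by the number of samples it actually contains. The following pure-Python sketch compares the two. The function names and the loss values are hypothetical and serve only to show that the two methods agree when every batch is full and diverge when the last batch is short.

```python
def mean_of_batch_means(batch_means, batch_sizes, batch_size):
    """(sum of per-batch mean losses / number of samples) * nominal batch size."""
    return sum(batch_means) / sum(batch_sizes) * batch_size


def per_sample_mean(batch_means, batch_sizes):
    """Exact average loss per sample: weight each batch mean by its size."""
    total = sum(m * n for m, n in zip(batch_means, batch_sizes))
    return total / sum(batch_sizes)


# All batches full: the two estimates agree.
print(mean_of_batch_means([1.0, 2.0, 3.0], [32, 32, 32], 32))  # 2.0
print(per_sample_mean([1.0, 2.0, 3.0], [32, 32, 32]))          # 2.0

# A short last batch (8 samples) with a high loss makes them diverge.
print(mean_of_batch_means([1.0, 1.0, 4.0], [32, 32, 8], 32))
print(per_sample_mean([1.0, 1.0, 4.0], [32, 32, 8]))
```

Under the first method, the short batch counts as much as a full one. The second method gives it only a quarter of the weight, which matches what "average loss per image" means.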

2.5 Results analysis

def plot_learning_curve(record, title='loss', ylabel='CE Loss'):
    ''' Plot the train/val learning curve stored in record[...][title]. '''
    maxtrain = max(map(float, record['train'][title]))
    maxval = max(map(float, record['val'][title]))
    ymax = max(maxtrain, maxval) * 1.1
    mintrain = min(map(float, record['train'][title]))
    minval = min(map(float, record['val'][title]))
    ymin = min(mintrain, minval) * 0.9

    x_1 = list(map(int, record['train']['iter']))
    x_2 = list(map(int, record['val']['iter']))
    figure(figsize=(10, 6))
    plt.plot(x_1, record['train'][title], c='tab:red', label='train')
    plt.plot(x_2, record['val'][title], c='tab:cyan', label='val')
    plt.ylim(ymin, ymax)
    plt.xlabel('Training steps')
    plt.ylabel(ylabel)
    plt.title('Learning curve of {}'.format(title))
    plt.legend()
    plt.show()
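In the loss plot, the x-axis does not count raw training batches. `loss_iter` advances once every 10 batches, when a training loss is logged. Each validation loss is logged once per epoch at whatever value the counter has reached, so both curves share one axis. In the accuracy plot, one unit equals one epoch. The sketch below replays only that bookkeeping, with no model involved. `simulate_loss_iters` and the figure of 25 batches per epoch are hypothetical.

```python
def simulate_loss_iters(n_epochs, batches_per_epoch, log_every=10):
    """Replay the loss_iter bookkeeping of the training loop (no model involved)."""
    train_iters, val_iters = [], []
    loss_iter = 0
    for _ in range(n_epochs):
        for batch_id in range(batches_per_epoch):
            if batch_id % log_every == 0:   # a training loss is logged here
                train_iters.append(loss_iter)
                loss_iter += 1
        # the validation loss is logged once per epoch at the current counter value
        val_iters.append(loss_iter)
    return train_iters, val_iters


train_iters, val_iters = simulate_loss_iters(n_epochs=2, batches_per_epoch=25)
print(train_iters)  # [0, 1, 2, 3, 4, 5]
print(val_iters)    # [3, 6]
```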
2.5.1 Loss and accuracy curves
plot_learning_curve(loss_record, title='loss', ylabel='CE Loss')

plot_learning_curve(acc_record, title='acc', ylabel='Accuracy')

import time
work_path = 'work/model'
model = LIT(num_classes=10, embed_dim=64, depths=[2, 2, 6, 2], num_heads=[3, 6, 16, 32])
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
model.eval()
aa = time.time()
for batch_id, data in enumerate(val_loader):

    x_data, y_data = data
    labels = paddle.unsqueeze(y_data, axis=1)
    with paddle.no_grad():
        logits = model(x_data)
bb = time.time()
print("Throughput: {} images/s".format(int(len(val_dataset) / (bb - aa))))
Throughput: 987 images/s
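The throughput above is the number of validation images divided by wall-clock time. That measurement also includes DataLoader time and the first, usually slower, iterations. The sketch below is a slightly more careful version: it skips a few warm-up batches and times only the forward passes. It is written against a generic `run_batch` callable, so the function name and parameters are hypothetical. With Paddle, `run_batch` would presumably wrap `model(x)` under `paddle.no_grad()`, and `batches` would be a small, preloaded subset of tensors. On GPU you would likely also want to synchronize the device before reading the clock, because kernels launch asynchronously. Recent Paddle 2.x releases provide `paddle.device.cuda.synchronize()` for this.

```python
import time


def measure_throughput(run_batch, batches, warmup=1, clock=time.perf_counter):
    """Return images per second over `batches`, skipping the first `warmup` batches.

    run_batch(batch) runs inference on one batch; each batch must support len().
    """
    batches = list(batches)
    for batch in batches[:warmup]:
        run_batch(batch)              # warm-up runs are excluded from the timing
    timed = batches[warmup:]
    start = clock()
    for batch in timed:
        run_batch(batch)
    elapsed = clock() - start
    n_images = sum(len(batch) for batch in timed)
    return n_images / elapsed


# Toy usage: each "batch" holds 64 items and takes about 1 ms to process.
fps = measure_throughput(lambda b: time.sleep(0.001), [list(range(64))] * 20)
print("Throughput: {:.0f} images/s".format(fps))
```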
2.5.2 Comparing predictions with ground-truth labels
def get_cifar10_labels(labels):
    """Return the text labels of the CIFAR10 dataset."""
    text_labels = [
        'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog',
        'horse', 'ship', 'truck']
    return [text_labels[int(i)] for i in labels]
def show_images(imgs, num_rows, num_cols, pred=None, gt=None, scale=1.5):
    """Plot a list of images."""
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if paddle.is_tensor(img):
            ax.imshow(img.numpy())
        else:
            ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if pred or gt:
            ax.set_title("pt: " + pred[i] + "\ngt: " + gt[i])
    return axes
work_path = 'work/model'
X, y = next(iter(DataLoader(val_dataset, batch_size=18)))
model = LIT(num_classes=10, embed_dim=64, depths=[2, 2, 6, 2], num_heads=[3, 6, 16, 32])
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
model.eval()
with paddle.no_grad():
    logits = model(X)
y_pred = paddle.argmax(logits, -1)
X = paddle.transpose(X, [0, 2, 3, 1])
axes = show_images(X.reshape((18, 224, 224, 3)), 1, 18, pred=get_cifar10_labels(y_pred), gt=get_cifar10_labels(y))
plt.show()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
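`imshow` prints the clipping warning once per image because the pixel values lie outside [0, 1]. This is most likely because the data pipeline normalizes the images. As a result, the displayed colors are distorted. Undoing the per-channel normalization (`x * std + mean`) and clipping to [0, 1] before plotting should give faithful colors. The pure-Python sketch below works on an H x W x 3 nested list. The mean/std values are an assumption (common ImageNet statistics) and should be replaced with whatever the transform pipeline actually uses. With Paddle tensors, the same arithmetic can be applied with broadcasting to the HWC tensor returned by `paddle.transpose`.

```python
# Assumed normalization constants (ImageNet statistics); replace them with the
# mean/std actually passed to transforms.Normalize in the data pipeline.
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def denormalize_hwc(img, mean=MEAN, std=STD):
    """Undo per-channel (x - mean) / std on an H x W x 3 nested list, clipped to [0, 1]."""
    return [[[min(1.0, max(0.0, v * s + m)) for v, m, s in zip(px, mean, std)]
             for px in row]
            for row in img]


print(denormalize_hwc([[[0.0, 0.0, 0.0]]]))    # [[[0.485, 0.456, 0.406]]]
print(denormalize_hwc([[[10.0, -10.0, 0.0]]]))  # [[[1.0, 0.0, 0.406]]]
```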

2.5.3 Visualizing the results
!pip install interpretdl
import interpretdl as it
work_path = 'work/model'
model = LIT(num_classes=10, embed_dim=64, depths=[2, 2, 6, 2], num_heads=[3, 6, 16, 32])
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
X, y = next(iter(DataLoader(val_dataset, batch_size=18)))
lime = it.LIMECVInterpreter(model, use_cuda=True)
lime_weights = lime.interpret(X.numpy()[3], interpret_class=y.numpy()[3], batch_size=100, num_samples=10000, visual=True)
100%|██████████| 10000/10000 [00:53<00:00, 160.37it/s]

lime_weights = lime.interpret(X.numpy()[13], interpret_class=y.numpy()[13], batch_size=100, num_samples=10000, visual=True)
100%|██████████| 10000/10000 [00:53<00:00, 162.21it/s]
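`LIMECVInterpreter` explains one prediction with LIME. The image is split into segments (superpixels) and many perturbed copies are generated with random segments greyed out; `num_samples=10000` presumably sets how many such copies are scored, and `batch_size=100` sets how many go through the model per forward pass. A locally weighted linear model is then fitted from "which segments were kept" to the class score. The toy sketch below is written from that description of LIME and is not interpretdl's implementation. The cosine-distance kernel, the `kernel_width` and `alpha` values, the segment count, and `toy_model` are all assumptions. On a model whose score is exactly linear in the kept segments, the fitted weights should recover that linear dependence.

```python
import math
import random


def cosine_distance(z, ref):
    """1 - cosine similarity; taken as 1.0 for an all-zero mask (assumption)."""
    dot = sum(a * b for a, b in zip(z, ref))
    nz = math.sqrt(sum(a * a for a in z))
    nr = math.sqrt(sum(b * b for b in ref))
    if nz == 0 or nr == 0:
        return 1.0
    return 1.0 - dot / (nz * nr)


def solve(a, b):
    """Solve the linear system a @ x = b (Gaussian elimination, partial pivoting)."""
    n = len(b)
    m = [list(row) + [b[i]] for i, row in enumerate(a)]
    for col in range(n):
        piv = max(range(col, n), key=lambda r: abs(m[r][col]))
        m[col], m[piv] = m[piv], m[col]
        for r in range(col + 1, n):
            f = m[r][col] / m[col][col]
            for c in range(col, n + 1):
                m[r][c] -= f * m[col][c]
    x = [0.0] * n
    for r in range(n - 1, -1, -1):
        x[r] = (m[r][n] - sum(m[r][c] * x[c] for c in range(r + 1, n))) / m[r][r]
    return x


def lime_segment_weights(predict, num_segments, num_samples=500,
                         kernel_width=0.25, alpha=1e-3, seed=0):
    """Fit score ~ b0 + sum_j w_j * z_j around the unperturbed image; return w."""
    rng = random.Random(seed)
    ones = [1] * num_segments
    # z[j] = 1 keeps segment j, 0 greys it out; the original image is always included
    masks = [ones] + [[rng.randint(0, 1) for _ in range(num_segments)]
                      for _ in range(num_samples - 1)]
    d = num_segments + 1                          # +1 for the intercept
    xtwx = [[0.0] * d for _ in range(d)]
    xtwy = [0.0] * d
    for z in masks:
        # samples closer to the original image get larger weights (locality)
        w = math.exp(-cosine_distance(z, ones) ** 2 / kernel_width ** 2)
        x = [1.0] + [float(v) for v in z]
        y = predict(z)
        for i in range(d):
            xtwy[i] += w * x[i] * y
            for j in range(d):
                xtwx[i][j] += w * x[i] * x[j]
    for i in range(1, d):                         # ridge penalty, intercept excluded
        xtwx[i][i] += alpha
    return solve(xtwx, xtwy)[1:]


def toy_model(z):
    """Hypothetical class score that depends only on segments 0 and 2."""
    return 0.1 + 0.6 * z[0] + 0.3 * z[2]


print([round(v, 3) for v in lime_segment_weights(toy_model, num_segments=4)])
```

In the real interpreter, the per-segment weights are what gets painted back onto the image. Segments with large positive weights are the regions that most support the explained class.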

Summary

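        LiT's idea is fairly simple. The early stages use MLPs to model local patterns at a lower computational cost, and the later stages use self-attention to model long-range dependencies. In addition, LiT replaces the regular grid-based Patch Merge with a deformable-convolution-based module that aggregates patches adaptively, which helps it extract better features.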
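The cost argument can be made concrete with rough multiply-add counts per block. In self-attention, the attention-matrix products (QKᵀ and attention·V) cost about 2·N²·C for N tokens of width C, while an MLP with expansion ratio r costs about 2·r·N·C². With r = 4, their ratio is N / (4C). The sketch below assumes a 224×224 input, a patch size of 4, side lengths halving at each stage, and stage widths of (64, 128, 256, 512). The widths are an assumption extrapolated from `embed_dim=64` and are not read from the model code. Under these assumptions, stage 1 has 3136 tokens, and its quadratic attention term is 3136 / 256 = 12.25 times the MLP cost. By stage 4 the ratio falls far below 1, which is consistent with using MLP-only blocks early and attention late.

```python
def stage_costs(img_size=224, patch=4, dims=(64, 128, 256, 512), mlp_ratio=4):
    """Rough multiply-add counts per block for each stage of a 4-stage pyramid."""
    side = img_size // patch
    out = []
    for stage, c in enumerate(dims, start=1):
        n = side * side                        # number of tokens in this stage
        out.append({
            "stage": stage,
            "tokens": n,
            "attn_matrix": 2 * n * n * c,      # Q @ K^T plus attn @ V: quadratic in n
            "attn_proj": 4 * n * c * c,        # Q, K, V and output projections
            "mlp": 2 * mlp_ratio * n * c * c,  # two linear layers C -> r*C -> C
        })
        side //= 2                             # patch merging halves each side
    return out


for s in stage_costs():
    print(s["stage"], s["tokens"], "{:.3f}".format(s["attn_matrix"] / s["mlp"]))
```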
        Future work: carefully tune the hyperparameters and verify LiT's effectiveness on larger datasets.

This article is a repost. Original project: https://aistudio.baidu.com/aistudio/projectdetail/4365614
