ACmix: Combining Self-Attention and Convolution

Abstract

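        Convolution and self-attention are two powerful techniques for representation learning, and they are usually regarded as two distinct approaches. In this paper, the authors show that most of the computation in the two paradigms is in fact carried out by the same operation, which reveals a strong underlying relationship between them. Both convolution and self-attention are decomposed into two stages. A conventional convolution with a k × k kernel can be decomposed into k × k individual 1×1 convolutions, followed by shift and summation operations. In the self-attention module, the projections of queries, keys, and values are interpreted as multiple 1×1 convolutions, followed by the computation of attention weights and the aggregation of the values. The first stage of both modules therefore consists of similar operations. More importantly, the first stage accounts for most of the computational complexity (quadratic in the channel size) compared with the second stage. This observation allows the two seemingly different paradigms to be combined into ACmix, which enjoys the benefits of both self-attention and convolution while adding minimal computational overhead compared with a pure convolution or pure self-attention counterpart. Extensive experiments show that the model achieves consistently improved results over competitive baselines on image recognition and downstream tasks.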

1. ACmix

1.1 Overview

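        Structurally, as shown in Figure 1, ACmix shares the 1×1 convolutions that are implicit in both convolution and self-attention, which reduces the computation of the first stage, while the second stage is nearly free by comparison.
[Figure 1]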

1.2 Revisiting Convolution

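        A conventional convolution applies a k × k kernel to transform and aggregate the features inside one window, then slides to the next position and repeats. Here the same operation is viewed differently: 1×1 kernels first transform every position of the feature map, and the results are then shifted and aggregated. The convolution is thus split into a transformation stage and a shift-and-aggregate stage:

  1. The input feature map is linearly projected with the kernel weights of one kernel position, which is the same as a standard 1×1 convolution.
  2. The projected feature maps are shifted according to their kernel positions and finally aggregated. It is easy to see that most of the computational cost lies in the 1×1 convolutions, while the subsequent shift and aggregation are lightweight.
    Formally: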
    $$
    \begin{aligned}
    \text{Stage I:}\quad & \tilde{g}_{ij}^{(p,q)} = K_{p,q} f_{ij}, \\
    \text{Stage II:}\quad & g_{ij}^{(p,q)} = \operatorname{Shift}\left(\tilde{g}_{ij}^{(p,q)},\ p-\lfloor k/2\rfloor,\ q-\lfloor k/2\rfloor\right), \\
    & g_{ij} = \sum_{p,q} g_{ij}^{(p,q)}.
    \end{aligned}
    $$
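The two-stage view of convolution can be checked numerically. The NumPy sketch below is an illustration rather than the paper's code: it assumes stride 1, zero padding, and the convention Shift(x, dy, dx)[i, j] = x[i + dy, j + dx]; the function names are made up for this example.

```python
import numpy as np

def conv2d_direct(f, K):
    """Standard k x k convolution (cross-correlation, zero padding, stride 1).
    f: (C_in, H, W), K: (C_out, C_in, k, k) -> (C_out, H, W)."""
    c_out, _, k, _ = K.shape
    _, H, W = f.shape
    r = k // 2
    fp = np.pad(f, ((0, 0), (r, r), (r, r)))
    out = np.zeros((c_out, H, W))
    for i in range(H):
        for j in range(W):
            patch = fp[:, i:i + k, j:j + k]              # (C_in, k, k)
            out[:, i, j] = np.tensordot(K, patch, axes=3)
    return out

def shift(x, dy, dx):
    """Shift(x, dy, dx)[:, i, j] = x[:, i + dy, j + dx], zeros outside the map."""
    _, H, W = x.shape
    out = np.zeros_like(x)
    out[:, max(0, -dy):min(H, H - dy), max(0, -dx):min(W, W - dx)] = \
        x[:, max(0, dy):min(H, H + dy), max(0, dx):min(W, W + dx)]
    return out

def conv2d_two_stage(f, K):
    """Same convolution as k*k 1x1 convolutions followed by shift and sum."""
    c_out, _, k, _ = K.shape
    r = k // 2
    out = np.zeros((c_out,) + f.shape[1:])
    for p in range(k):
        for q in range(k):
            # Stage I: 1x1 convolution with the weights at kernel position (p, q)
            g_tilde = np.einsum('oc,chw->ohw', K[:, :, p, q], f)
            # Stage II: shift by the kernel offset and accumulate
            out += shift(g_tilde, p - r, q - r)
    return out

rng = np.random.default_rng(0)
f = rng.standard_normal((4, 6, 5))
K = rng.standard_normal((8, 4, 3, 3))
max_err = np.abs(conv2d_direct(f, K) - conv2d_two_stage(f, K)).max()
print(max_err)  # difference between the two computations
```

The heavy part is the einsum in Stage I; Stage II only moves and adds already computed maps.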

1.3 Revisiting Self-Attention

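        Self-attention is now also widely used in computer vision, for example in the Transformer family of models. Compared with conventional convolution, it lets the model focus on important regions within a larger context. It can likewise be split into two stages:

  1. 1×1 convolutions project the input features into query, key, and value.
  2. The attention weights are computed and the value matrices are aggregated, that is, local features are gathered. The computational cost of this stage is small compared with the first stage, the same pattern as in convolution.
    Formally: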
    $$
    \begin{aligned}
    \text{Stage I:}\quad & q_{ij}^{(l)} = W_q^{(l)} f_{ij},\quad k_{ij}^{(l)} = W_k^{(l)} f_{ij},\quad v_{ij}^{(l)} = W_v^{(l)} f_{ij}, \\
    \text{Stage II:}\quad & g_{ij} = \|_{l=1}^{N} \left( \sum_{a,b \in \mathcal{N}_k(i,j)} \mathrm{A}\left(q_{ij}^{(l)}, k_{ab}^{(l)}\right) v_{ab}^{(l)} \right).
    \end{aligned}
    $$
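A minimal NumPy sketch of the two attention stages, for illustration only. It assumes A is a softmax of scaled dot products over the k × k neighbourhood, clips the neighbourhood at the image border (the Paddle implementation later in this post uses reflect padding and a positional encoding instead), and uses made-up function names.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def local_attention_head(f, Wq, Wk, Wv, k=3):
    """One attention head. f: (C, H, W); Wq, Wk, Wv: (d, C). Returns (d, H, W)."""
    d = Wq.shape[0]
    _, H, W = f.shape
    # Stage I: three 1x1 convolutions, i.e. per-pixel linear projections.
    q = np.einsum('dc,chw->dhw', Wq, f)
    key = np.einsum('dc,chw->dhw', Wk, f)
    v = np.einsum('dc,chw->dhw', Wv, f)
    # Stage II: attention weights over the neighbourhood, then aggregation.
    r = k // 2
    g = np.zeros((d, H, W))
    for i in range(H):
        for j in range(W):
            nbrs = [(a, b)
                    for a in range(max(0, i - r), min(H, i + r + 1))
                    for b in range(max(0, j - r), min(W, j + r + 1))]
            logits = np.array([q[:, i, j] @ key[:, a, b] for a, b in nbrs]) / np.sqrt(d)
            weights = softmax(logits)
            g[:, i, j] = sum(w * v[:, a, b] for w, (a, b) in zip(weights, nbrs))
    return g

def local_attention(f, heads, k=3):
    """Concatenate the outputs of the N heads along the channel axis."""
    return np.concatenate([local_attention_head(f, *w, k=k) for w in heads], axis=0)

rng = np.random.default_rng(0)
C, d, N = 8, 4, 2
heads = [tuple(rng.standard_normal((d, C)) for _ in range(3)) for _ in range(N)]
g = local_attention(rng.standard_normal((C, 5, 6)), heads)
print(g.shape)  # (8, 5, 6)
```

As in the convolution case, the projections in Stage I are dense matrix products over all channels, while Stage II only touches the k × k neighbours of each pixel.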

1.4 ACmix

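        The ACmix operation integrates convolution and self-attention. As the figure below shows, it consists of two stages:

  1. The input features are projected by three 1×1 convolutions and each result is reshaped into N pieces. This yields a rich set of intermediate features containing 3×N feature maps.
  2. The intermediate features are then used along two paths that follow different paradigms. For the self-attention path, the intermediate features are gathered into N groups, each containing three feature maps, one from each 1×1 convolution. These three maps serve as query, key, and value and follow the conventional multi-head self-attention module. For the convolution path with kernel size k, a lightweight fully connected layer generates k² feature maps. Shifting and aggregating the generated features then processes the input in a convolutional manner and gathers information from a local receptive field, as a conventional convolution does.
            Finally, the outputs of the two paths are added, with their strengths controlled by two learnable scalars: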
    $$
    F_{\text{out}} = \alpha F_{\text{att}} + \beta F_{\text{conv}}
    $$
    [Figure]

1.5 Improved Shift and Summation

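        The intermediate features in the convolution path follow the shift and summation operations of a conventional convolution module. Although these operations are theoretically lightweight, shifting tensors in different directions breaks data locality and is hard to vectorize, which can significantly hurt practical efficiency at inference time. Therefore, as the figure below shows, the paper replaces the inefficient tensor shifts with a grouped convolution with learnable kernels, using carefully designed kernel weights to carry out the ordinary convolution operation, which speeds up this part of the computation.
[Figure]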
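The idea of replacing tensor shifts with a convolution can be sketched in NumPy. The example assumes zero padding and stride 1, and uses one-hot kernels laid out like the initialisation in `reset_parameters` further down (kernel i is 1 at position (i // k, i % k)); the helper names are hypothetical.

```python
import numpy as np

def shift(x, dy, dx):
    """Shift(x, dy, dx)[i, j] = x[i + dy, j + dx], with zeros outside the map."""
    H, W = x.shape
    out = np.zeros_like(x)
    out[max(0, -dy):min(H, H - dy), max(0, -dx):min(W, W - dx)] = \
        x[max(0, dy):min(H, H + dy), max(0, dx):min(W, W + dx)]
    return out

def one_hot_kernels(k):
    """Kernel i is 1 at position (i // k, i % k) and 0 elsewhere."""
    kern = np.zeros((k * k, k, k))
    for i in range(k * k):
        kern[i, i // k, i % k] = 1.0
    return kern

def group_conv(x, kern):
    """One output channel of a grouped convolution (zero padding, stride 1).
    x: (k*k, H, W) input maps of the group; kern: (k*k, k, k)."""
    _, k, _ = kern.shape
    r = k // 2
    _, H, W = x.shape
    xp = np.pad(x, ((0, 0), (r, r), (r, r)))
    out = np.zeros((H, W))
    for i in range(H):
        for j in range(W):
            out[i, j] = np.sum(xp[:, i:i + k, j:j + k] * kern)
    return out

def shift_and_sum(x, k):
    """Explicitly shift map i by its kernel offset and sum the results."""
    r = k // 2
    return sum(shift(x[i], i // k - r, i % k - r) for i in range(k * k))

k = 3
x = np.random.default_rng(0).standard_normal((k * k, 6, 7))
max_err = np.abs(group_conv(x, one_hot_kernels(k)) - shift_and_sum(x, k)).max()
print(max_err)  # difference between the convolution and the explicit shifts
```

With one-hot kernels the convolution reproduces shift-and-sum; because the kernels are ordinary parameters, training can then move them away from this starting point.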

1.6 Parameter Analysis

[Figure]
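As a rough companion to this analysis, the per-pixel cost of each stage can be counted directly from the two-stage formulas in sections 1.2 and 1.3. This is a back-of-the-envelope sketch derived here, not the paper's table: it assumes equal input and output channel count C, counts one operation per multiply-accumulate or addition, and ignores softmax and positional encoding.

```python
def conv_cost(C, k):
    """Per-pixel operation count of a k x k convolution with C channels."""
    stage1 = k * k * C * C        # k^2 separate 1x1 convolutions, C x C each
    stage2 = k * k * C            # summing k^2 shifted C-channel maps
    return stage1, stage2

def attention_cost(C, k_a):
    """Per-pixel operation count of local self-attention with window k_a."""
    stage1 = 3 * C * C            # query, key and value projections
    stage2 = 2 * k_a * k_a * C    # q.k logits plus weighted sum of the values
    return stage1, stage2

for C in (64, 256, 1024):
    c1, c2 = conv_cost(C, 3)
    a1, a2 = attention_cost(C, 7)
    print(C, c1 / c2, a1 / a2)    # stage I / stage II ratio for each module
```

Stage I grows quadratically with C and stage II only linearly, so the first stage dominates once the channel count is large relative to the window size.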

2. Code Reproduction

2.1 Install and import the required packages

!pip install paddlex
%matplotlib inline
import paddle
import numpy as np
import matplotlib.pyplot as plt
from paddle.vision.datasets import Cifar10
from paddle.vision.transforms import Transpose
from paddle.io import Dataset, DataLoader
from paddle import nn
import paddle.nn.functional as F
import paddle.vision.transforms as transforms
import os
from matplotlib.pyplot import figure
import paddlex
from paddle import ParamAttr

2.2 Create the dataset

train_tfm = transforms.Compose([
    transforms.Resize((130, 130)),
    transforms.ColorJitter(brightness=0.2,contrast=0.2, saturation=0.2),
    paddlex.transforms.MixupImage(),
    transforms.RandomResizedCrop(128, scale=(0.6, 1.0)),
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomRotation(20),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

test_tfm = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
paddle.vision.set_image_backend('cv2')
# Use the CIFAR-10 dataset
train_dataset = Cifar10(data_file='data/data152754/cifar-10-python.tar.gz', mode='train', transform = train_tfm, )
val_dataset = Cifar10(data_file='data/data152754/cifar-10-python.tar.gz', mode='test',transform = test_tfm)
print("train_dataset: %d" % len(train_dataset))
print("val_dataset: %d" % len(val_dataset))
train_dataset: 50000
val_dataset: 10000
batch_size=128
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=2)

2.3 Label smoothing

class LabelSmoothingCrossEntropy(nn.Layer):
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, pred, target):

        confidence = 1. - self.smoothing
        log_probs = F.log_softmax(pred, axis=-1)
        idx = paddle.stack([paddle.arange(log_probs.shape[0]), target], axis=1)
        nll_loss = paddle.gather_nd(-log_probs, index=idx)
        smooth_loss = paddle.mean(-log_probs, axis=-1)
        loss = confidence * nll_loss + self.smoothing * smooth_loss

        return loss.mean()
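The loss above mixes the usual negative log-likelihood of the target class with the mean negative log-probability over all classes. A small pure-Python sketch for a single sample shows the same arithmetic; the function name is made up for this example.

```python
import math

def label_smoothing_ce(logits, target, smoothing=0.1):
    """Label-smoothed cross entropy for one sample (list of logits)."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    log_probs = [x - log_z for x in logits]
    nll = -log_probs[target]                       # standard cross entropy term
    smooth = -sum(log_probs) / len(log_probs)      # mean over all classes
    return (1.0 - smoothing) * nll + smoothing * smooth

uniform = label_smoothing_ce([0.0, 0.0, 0.0, 0.0], target=2)
plain = label_smoothing_ce([2.0, 0.5, -1.0], target=0, smoothing=0.0)
smoothed = label_smoothing_ce([2.0, 0.5, -1.0], target=0, smoothing=0.1)
print(uniform, plain, smoothed)
```

For uniform logits both terms equal log(number of classes), so the loss is log 4 here; with `smoothing=0.0` the function reduces to plain cross entropy.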

2.4 AlexNet-ACmix

2.4.1 ACmix
def position(H, W):
    loc_w = paddle.repeat_interleave(paddle.linspace(-1.0, 1.0, W).unsqueeze(0), repeats=H, axis=0)
    loc_h = paddle.repeat_interleave(paddle.linspace(-1.0, 1.0, H).unsqueeze(1), repeats=W, axis=1)
    loc = paddle.concat([loc_w.unsqueeze(0), loc_h.unsqueeze(0)], axis=0).unsqueeze(0)
    return loc


def stride(x, stride):
    b, c, h, w = x.shape
    return x[:, :, ::stride, ::stride]


class ACmix(nn.Layer):
    def __init__(self, in_planes, out_planes, kernel_att=7, head=4, kernel_conv=3, stride=1, dilation=1):
        super(ACmix, self).__init__()
        self.in_planes = in_planes
        self.out_planes = out_planes
        self.head = head
        self.kernel_att = kernel_att
        self.kernel_conv = kernel_conv
        self.stride = stride
        self.dilation = dilation
        self.rate1 = self.create_parameter([1], default_initializer=nn.initializer.Assign([0.5]))
        self.rate2 = self.create_parameter([1], default_initializer=nn.initializer.Assign([0.5]))
        self.head_dim = self.out_planes // self.head

        self.conv1 = nn.Conv2D(in_planes, out_planes, kernel_size=1)
        self.conv2 = nn.Conv2D(in_planes, out_planes, kernel_size=1)
        self.conv3 = nn.Conv2D(in_planes, out_planes, kernel_size=1)
        self.conv_p = nn.Conv2D(2, self.head_dim, kernel_size=1)

        self.padding_att = (self.dilation * (self.kernel_att - 1) + 1) // 2
        self.pad_att = nn.Pad2D(self.padding_att, mode='reflect')
        self.unfold = nn.Unfold(kernel_sizes=kernel_att, strides=self.stride, paddings=0)
        self.softmax = nn.Softmax(axis=1)

        self.fc = nn.Conv2D(3 * self.head, self.kernel_conv * self.kernel_conv, kernel_size=1, bias_attr=False)
        self.dep_conv = nn.Conv2D(self.kernel_conv * self.kernel_conv * self.head_dim, out_planes, 
            kernel_size=self.kernel_conv, bias_attr=True, groups=self.head_dim, padding=1, stride=stride)
        
        self.reset_parameters()

    def reset_parameters(self):
        kernel = paddle.zeros((self.kernel_conv * self.kernel_conv, self.kernel_conv, self.kernel_conv))
        for i in range(self.kernel_conv * self.kernel_conv):
            kernel[i, i // self.kernel_conv, i % self.kernel_conv] = 1.
        kernel = paddle.repeat_interleave(kernel.unsqueeze(0), self.out_planes, axis=0)
        # Create the parameter with the kernel's own shape (out_planes, k*k, k, k).
        self.dep_conv.weight = self.create_parameter(kernel.shape, default_initializer=nn.initializer.Assign(kernel))
        self.dep_conv.bias = self.create_parameter([1], default_initializer=nn.initializer.Assign([0.]))

    def forward(self, x):
        q, k, v = self.conv1(x), self.conv2(x), self.conv3(x)
        scaling = float(self.head_dim) ** -0.5
        b, c, h, w = q.shape
        h_out, w_out = h//self.stride, w//self.stride


    # att
        # positional encoding
        pe = self.conv_p(position(h, w))

        q_att = q.reshape((b*self.head, self.head_dim, h, w)) * scaling
        k_att = k.reshape((b*self.head, self.head_dim, h, w))
        v_att = v.reshape((b*self.head, self.head_dim, h, w))

        if self.stride > 1:
            q_att = stride(q_att, self.stride)
            q_pe = stride(pe, self.stride)
        else:
            q_pe = pe

        # b*head, head_dim, k_att^2, h_out, w_out
        unfold_k = self.unfold(self.pad_att(k_att)).reshape(
            (b * self.head, self.head_dim, self.kernel_att * self.kernel_att, h_out, w_out)) 
        # 1, head_dim, k_att^2, h_out, w_out
        unfold_rpe = self.unfold(self.pad_att(pe)).reshape(
            (1, self.head_dim, self.kernel_att * self.kernel_att, h_out, w_out)) 
        
        # (b*head, head_dim, 1, h_out, w_out) * (b*head, head_dim, k_att^2, h_out, w_out) -> (b*head, k_att^2, h_out, w_out)
        att = (q_att.unsqueeze(2) * (unfold_k + q_pe.unsqueeze(2) - unfold_rpe)).sum(1)
        att = self.softmax(att)

        out_att = self.unfold(self.pad_att(v_att)).reshape(
            (b * self.head, self.head_dim, self.kernel_att * self.kernel_att, h_out, w_out))
        out_att = (att.unsqueeze(1) * out_att).sum(2).reshape((b, self.out_planes, h_out, w_out))

    # conv
        f_all = self.fc(paddle.concat([q.reshape((b, self.head, self.head_dim, h*w)), 
            k.reshape((b, self.head, self.head_dim, h * w)), v.reshape((b, self.head, self.head_dim, h * w))], axis=1))
        f_conv = f_all.transpose([0, 2, 1, 3]).reshape((x.shape[0], -1, x.shape[-2], x.shape[-1]))
        
        out_conv = self.dep_conv(f_conv)

        return self.rate1 * out_att + self.rate2 * out_conv
model = ACmix(16, 64)
paddle.summary(model, (1, 16, 224, 224))
W0726 10:04:00.044970  5245 gpu_resources.cc:61] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 10.1
W0726 10:04:00.048735  5245 gpu_resources.cc:91] device: 0, cuDNN Version: 7.6.


---------------------------------------------------------------------------
 Layer (type)       Input Shape          Output Shape         Param #    
===========================================================================
   Conv2D-1     [[1, 16, 224, 224]]   [1, 64, 224, 224]        1,088     
   Conv2D-2     [[1, 16, 224, 224]]   [1, 64, 224, 224]        1,088     
   Conv2D-3     [[1, 16, 224, 224]]   [1, 64, 224, 224]        1,088     
   Conv2D-4      [[1, 2, 224, 224]]   [1, 16, 224, 224]         48       
    Pad2D-1     [[4, 16, 224, 224]]   [4, 16, 230, 230]          0       
   Unfold-1     [[4, 16, 230, 230]]    [4, 784, 50176]           0       
   Softmax-1    [[4, 49, 224, 224]]   [4, 49, 224, 224]          0       
   Conv2D-5     [[1, 12, 16, 50176]]  [1, 9, 16, 50176]         108      
   Conv2D-6     [[1, 144, 224, 224]]  [1, 64, 224, 224]        5,185     
===========================================================================
Total params: 8,605
Trainable params: 8,605
Non-trainable params: 0
---------------------------------------------------------------------------
Input size (MB): 3.06
Forward/backward pass size (MB): 1460.61
Params size (MB): 0.03
Estimated Total Size (MB): 1463.71
---------------------------------------------------------------------------






{'total_params': 8605, 'trainable_params': 8605}
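The total of 8,605 in the summary can be reproduced by hand. The sketch below counts the parameters of each sub-layer of `ACmix(16, 64)` as defined above; the helper names are made up, and the two scalars `rate1` and `rate2` are left out because the summary does not list the top-level layer itself.

```python
def conv2d_params(c_in, c_out, k, groups=1, bias=True):
    """Weights (and optional per-channel bias) of a 2D convolution."""
    return c_out * (c_in // groups) * k * k + (c_out if bias else 0)

def acmix_sublayer_params(in_planes, out_planes, head=4, kernel_conv=3):
    head_dim = out_planes // head
    qkv = 3 * conv2d_params(in_planes, out_planes, 1)              # conv1, conv2, conv3
    pos = conv2d_params(2, head_dim, 1)                            # conv_p
    fc = conv2d_params(3 * head, kernel_conv ** 2, 1, bias=False)  # fc
    # dep_conv: grouped weights plus the single-element bias set in reset_parameters()
    dep = conv2d_params(kernel_conv ** 2 * head_dim, out_planes,
                        kernel_conv, groups=head_dim, bias=False) + 1
    return qkv + pos + fc + dep

acmix_total = acmix_sublayer_params(16, 64)
conv_total = conv2d_params(16, 64, 3)
print(acmix_total, conv_total)  # 8605 9280
```

For comparison, an ordinary 3×3 convolution with the same channel counts has 9,280 parameters.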
2.4.2 AlexNet-ACmix
class AlexNet_ACmix(nn.Layer):
    def __init__(self,num_classes=10):
        super().__init__()
        self.features=nn.Sequential(
            nn.Conv2D(3,48, kernel_size=11, stride=4, padding=11//2),
            nn.ReLU(),
            nn.MaxPool2D(kernel_size=3,stride=2),
            nn.Conv2D(48,128, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool2D(kernel_size=3,stride=2),
            ACmix(128, 192,kernel_conv=3,stride=1),
            nn.ReLU(),
            ACmix(192,192,kernel_conv=3,stride=1),
            nn.ReLU(),
            ACmix(192,128,kernel_conv=3,stride=1),
            nn.ReLU(),
            nn.MaxPool2D(kernel_size=3,stride=2),
        )
        self.classifier=nn.Sequential(
            nn.Linear(3*3*128,2048),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(2048,2048),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(2048,num_classes),
        )
 
 
    def forward(self,x):
        x = self.features(x)
        x = paddle.flatten(x, 1)
        x=self.classifier(x)
 
        return x
model = AlexNet_ACmix(num_classes=10)
paddle.summary(model, (1, 3, 128, 128))
---------------------------------------------------------------------------
 Layer (type)       Input Shape          Output Shape         Param #    
===========================================================================
   Conv2D-7      [[1, 3, 128, 128]]    [1, 48, 32, 32]        17,472     
    ReLU-5       [[1, 48, 32, 32]]     [1, 48, 32, 32]           0       
  MaxPool2D-1    [[1, 48, 32, 32]]     [1, 48, 15, 15]           0       
   Conv2D-8      [[1, 48, 15, 15]]     [1, 128, 15, 15]       153,728    
    ReLU-6       [[1, 128, 15, 15]]    [1, 128, 15, 15]          0       
  MaxPool2D-2    [[1, 128, 15, 15]]     [1, 128, 7, 7]           0       
   Conv2D-9       [[1, 128, 7, 7]]      [1, 192, 7, 7]        24,768     
   Conv2D-10      [[1, 128, 7, 7]]      [1, 192, 7, 7]        24,768     
   Conv2D-11      [[1, 128, 7, 7]]      [1, 192, 7, 7]        24,768     
   Conv2D-12       [[1, 2, 7, 7]]       [1, 48, 7, 7]           144      
    Pad2D-2       [[4, 48, 7, 7]]      [4, 48, 13, 13]           0       
   Unfold-5      [[4, 48, 13, 13]]      [4, 2352, 49]            0       
   Softmax-2      [[4, 49, 7, 7]]       [4, 49, 7, 7]            0       
   Conv2D-13     [[1, 12, 48, 49]]      [1, 9, 48, 49]          108      
   Conv2D-14      [[1, 432, 7, 7]]      [1, 192, 7, 7]        15,553     
    ACmix-2       [[1, 128, 7, 7]]      [1, 192, 7, 7]           2       
    ReLU-7        [[1, 192, 7, 7]]      [1, 192, 7, 7]           0       
   Conv2D-15      [[1, 192, 7, 7]]      [1, 192, 7, 7]        37,056     
   Conv2D-16      [[1, 192, 7, 7]]      [1, 192, 7, 7]        37,056     
   Conv2D-17      [[1, 192, 7, 7]]      [1, 192, 7, 7]        37,056     
   Conv2D-18       [[1, 2, 7, 7]]       [1, 48, 7, 7]           144      
    Pad2D-3       [[4, 48, 7, 7]]      [4, 48, 13, 13]           0       
   Unfold-6      [[4, 48, 13, 13]]      [4, 2352, 49]            0       
   Softmax-3      [[4, 49, 7, 7]]       [4, 49, 7, 7]            0       
   Conv2D-19     [[1, 12, 48, 49]]      [1, 9, 48, 49]          108      
   Conv2D-20      [[1, 432, 7, 7]]      [1, 192, 7, 7]        15,553     
    ACmix-3       [[1, 192, 7, 7]]      [1, 192, 7, 7]           2       
    ReLU-8        [[1, 192, 7, 7]]      [1, 192, 7, 7]           0       
   Conv2D-21      [[1, 192, 7, 7]]      [1, 128, 7, 7]        24,704     
   Conv2D-22      [[1, 192, 7, 7]]      [1, 128, 7, 7]        24,704     
   Conv2D-23      [[1, 192, 7, 7]]      [1, 128, 7, 7]        24,704     
   Conv2D-24       [[1, 2, 7, 7]]       [1, 32, 7, 7]           96       
    Pad2D-4       [[4, 32, 7, 7]]      [4, 32, 13, 13]           0       
   Unfold-7      [[4, 32, 13, 13]]      [4, 1568, 49]            0       
   Softmax-4      [[4, 49, 7, 7]]       [4, 49, 7, 7]            0       
   Conv2D-25     [[1, 12, 32, 49]]      [1, 9, 32, 49]          108      
   Conv2D-26      [[1, 288, 7, 7]]      [1, 128, 7, 7]        10,369     
    ACmix-4       [[1, 192, 7, 7]]      [1, 128, 7, 7]           2       
    ReLU-9        [[1, 128, 7, 7]]      [1, 128, 7, 7]           0       
  MaxPool2D-3     [[1, 128, 7, 7]]      [1, 128, 3, 3]           0       
   Linear-1         [[1, 1152]]           [1, 2048]          2,361,344   
    ReLU-10         [[1, 2048]]           [1, 2048]              0       
   Dropout-1        [[1, 2048]]           [1, 2048]              0       
   Linear-2         [[1, 2048]]           [1, 2048]          4,196,352   
    ReLU-11         [[1, 2048]]           [1, 2048]              0       
   Dropout-2        [[1, 2048]]           [1, 2048]              0       
   Linear-3         [[1, 2048]]            [1, 10]            20,490     
===========================================================================
Total params: 7,051,159
Trainable params: 7,051,159
Non-trainable params: 0
---------------------------------------------------------------------------
Input size (MB): 0.19
Forward/backward pass size (MB): 13.31
Params size (MB): 26.90
Estimated Total Size (MB): 40.39
---------------------------------------------------------------------------






{'total_params': 7051159, 'trainable_params': 7051159}

2.5 Training

learning_rate = 0.001
n_epochs = 50
paddle.seed(42)
np.random.seed(42)
work_path = 'work/model'

model = AlexNet_ACmix(num_classes=10)

criterion = LabelSmoothingCrossEntropy()

scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=learning_rate, T_max=50000 // batch_size * n_epochs, verbose=False)
optimizer = paddle.optimizer.Adam(parameters=model.parameters(), learning_rate=scheduler, weight_decay=1e-5)

gate = 0.0
threshold = 0.0
best_acc = 0.0
val_acc = 0.0
loss_record = {'train': {'loss': [], 'iter': []}, 'val': {'loss': [], 'iter': []}}   # for recording loss
acc_record = {'train': {'acc': [], 'iter': []}, 'val': {'acc': [], 'iter': []}}      # for recording accuracy

loss_iter = 0
acc_iter = 0

for epoch in range(n_epochs):
    # ---------- Training ----------
    model.train()
    train_num = 0.0
    train_loss = 0.0

    val_num = 0.0
    val_loss = 0.0
    accuracy_manager = paddle.metric.Accuracy()
    val_accuracy_manager = paddle.metric.Accuracy()
    print("#===epoch: {}, lr={:.10f}===#".format(epoch, optimizer.get_lr()))
    for batch_id, data in enumerate(train_loader):
        x_data, y_data = data
        labels = paddle.unsqueeze(y_data, axis=1)

        logits = model(x_data)

        loss = criterion(logits, y_data)

        acc = paddle.metric.accuracy(logits, labels)
        accuracy_manager.update(acc)
        if batch_id % 10 == 0:
            loss_record['train']['loss'].append(loss.numpy())
            loss_record['train']['iter'].append(loss_iter)
            loss_iter += 1

        loss.backward()

        optimizer.step()
        scheduler.step()
        optimizer.clear_grad()
        
        train_loss += loss
        train_num += len(y_data)

    total_train_loss = (train_loss / train_num) * batch_size
    train_acc = accuracy_manager.accumulate()
    acc_record['train']['acc'].append(train_acc)
    acc_record['train']['iter'].append(acc_iter)
    acc_iter += 1
    # Print the information.
    print("#===epoch: {}, train loss is: {}, train acc is: {:2.2f}%===#".format(epoch, total_train_loss.numpy(), train_acc*100))

    # ---------- Validation ----------
    model.eval()

    for batch_id, data in enumerate(val_loader):

        x_data, y_data = data
        labels = paddle.unsqueeze(y_data, axis=1)
        with paddle.no_grad():
          logits = model(x_data)

        loss = criterion(logits, y_data)

        acc = paddle.metric.accuracy(logits, labels)
        val_accuracy_manager.update(acc)

        val_loss += loss
        val_num += len(y_data)

    total_val_loss = (val_loss / val_num) * batch_size
    loss_record['val']['loss'].append(total_val_loss.numpy())
    loss_record['val']['iter'].append(loss_iter)
    val_acc = val_accuracy_manager.accumulate()
    acc_record['val']['acc'].append(val_acc)
    acc_record['val']['iter'].append(acc_iter)
    
    print("#===epoch: {}, val loss is: {}, val acc is: {:2.2f}%===#".format(epoch, total_val_loss.numpy(), val_acc*100))

    # ===================save====================
    if val_acc > best_acc:
        best_acc = val_acc
        paddle.save(model.state_dict(), os.path.join(work_path, 'best_model.pdparams'))
        paddle.save(optimizer.state_dict(), os.path.join(work_path, 'best_optimizer.pdopt'))

print(best_acc)
paddle.save(model.state_dict(), os.path.join(work_path, 'final_model.pdparams'))
paddle.save(optimizer.state_dict(), os.path.join(work_path, 'final_optimizer.pdopt'))
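The learning rate in the loop above is stepped once per batch with `T_max = 50000 // batch_size * n_epochs`, which is 19,500 steps for a batch size of 128 and 50 epochs. As a rough illustration, the closed-form cosine annealing curve is sketched below in pure Python; that Paddle's `CosineAnnealingDecay` follows this curve with a minimum learning rate of 0 is an assumption here, and `cosine_lr` is a hypothetical helper.

```python
import math

def cosine_lr(step, base_lr=0.001, t_max=19500, eta_min=0.0):
    """Closed-form cosine annealing from base_lr down to eta_min over t_max steps."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + math.cos(math.pi * step / t_max))

t_max = 50000 // 128 * 50          # steps per epoch (drop_last=True) times epochs
for step in (0, t_max // 4, t_max // 2, t_max):
    print(step, cosine_lr(step, t_max=t_max))
```

With `drop_last=True` each epoch has exactly 390 batches, so the schedule reaches its minimum on the last training step.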


2.6 Results

def plot_learning_curve(record, title='loss', ylabel='CE Loss'):
    ''' Plot learning curve of your CNN '''
    maxtrain = max(map(float, record['train'][title]))
    maxval = max(map(float, record['val'][title]))
    ymax = max(maxtrain, maxval) * 1.1
    mintrain = min(map(float, record['train'][title]))
    minval = min(map(float, record['val'][title]))
    ymin = min(mintrain, minval) * 0.9

    total_steps = len(record['train'][title])
    x_1 = list(map(int, record['train']['iter']))
    x_2 = list(map(int, record['val']['iter']))
    figure(figsize=(10, 6))
    plt.plot(x_1, record['train'][title], c='tab:red', label='train')
    plt.plot(x_2, record['val'][title], c='tab:cyan', label='val')
    plt.ylim(ymin, ymax)
    plt.xlabel('Training steps')
    plt.ylabel(ylabel)
    plt.title('Learning curve of {}'.format(title))
    plt.legend()
    plt.show()
plot_learning_curve(loss_record, title='loss', ylabel='CE Loss')


plot_learning_curve(acc_record, title='acc', ylabel='Accuracy')


import time
work_path = 'work/model'
model = AlexNet_ACmix(num_classes=10)
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
model.eval()
aa = time.time()
for batch_id, data in enumerate(val_loader):

    x_data, y_data = data
    labels = paddle.unsqueeze(y_data, axis=1)
    with paddle.no_grad():
        logits = model(x_data)
bb = time.time()
print("Throughput:{}".format(int(len(val_dataset)//(bb - aa))))
Throughput:827
def get_cifar10_labels(labels):
    """Return the text labels of the CIFAR-10 dataset."""
    text_labels = [
        'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog',
        'horse', 'ship', 'truck']
    return [text_labels[int(i)] for i in labels]
def show_images(imgs, num_rows, num_cols, pred=None, gt=None, scale=1.5):  
    """Plot a list of images."""
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if paddle.is_tensor(img):
            ax.imshow(img.numpy())
        else:
            ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if pred or gt:
            ax.set_title("pt: " + pred[i] + "\ngt: " + gt[i])
    return axes
work_path = 'work/model'
X, y = next(iter(DataLoader(val_dataset, batch_size=18)))
model = AlexNet_ACmix(num_classes=10)
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
model.eval()
logits = model(X)
y_pred = paddle.argmax(logits, -1)
X = paddle.transpose(X, [0, 2, 3, 1])
axes = show_images(X.reshape((18, 128, 128, 3)), 1, 18, pred=get_cifar10_labels(y_pred), gt=get_cifar10_labels(y))
plt.show()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

3. AlexNet

3.1 AlexNet

class AlexNet(nn.Layer):
    def __init__(self,num_classes=10):
        super().__init__()
        self.features=nn.Sequential(
            nn.Conv2D(3,48, kernel_size=11, stride=4, padding=11//2),
            nn.ReLU(),
            nn.MaxPool2D(kernel_size=3,stride=2),
            nn.Conv2D(48,128, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool2D(kernel_size=3,stride=2),
            nn.Conv2D(128, 192,kernel_size=3,stride=1,padding=1),
            nn.ReLU(),
            nn.Conv2D(192,192,kernel_size=3,stride=1,padding=1),
            nn.ReLU(),
            nn.Conv2D(192,128,kernel_size=3,stride=1,padding=1),
            nn.ReLU(),
            nn.MaxPool2D(kernel_size=3,stride=2),
        )
        self.classifier=nn.Sequential(
            nn.Linear(3 * 3 * 128,2048),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(2048,2048),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(2048,num_classes),
        )
 
 
    def forward(self,x):
        x = self.features(x)
        x = paddle.flatten(x, 1)
        x=self.classifier(x)
 
        return x
model = AlexNet(num_classes=10)
paddle.summary(model, (1, 3, 128, 128))
---------------------------------------------------------------------------
 Layer (type)       Input Shape          Output Shape         Param #    
===========================================================================
  Conv2D-107     [[1, 3, 128, 128]]    [1, 48, 32, 32]        17,472     
    ReLU-40      [[1, 48, 32, 32]]     [1, 48, 32, 32]           0       
 MaxPool2D-16    [[1, 48, 32, 32]]     [1, 48, 15, 15]           0       
  Conv2D-108     [[1, 48, 15, 15]]     [1, 128, 15, 15]       153,728    
    ReLU-41      [[1, 128, 15, 15]]    [1, 128, 15, 15]          0       
 MaxPool2D-17    [[1, 128, 15, 15]]     [1, 128, 7, 7]           0       
  Conv2D-109      [[1, 128, 7, 7]]      [1, 192, 7, 7]        221,376    
    ReLU-42       [[1, 192, 7, 7]]      [1, 192, 7, 7]           0       
  Conv2D-110      [[1, 192, 7, 7]]      [1, 192, 7, 7]        331,968    
    ReLU-43       [[1, 192, 7, 7]]      [1, 192, 7, 7]           0       
  Conv2D-111      [[1, 192, 7, 7]]      [1, 128, 7, 7]        221,312    
    ReLU-44       [[1, 128, 7, 7]]      [1, 128, 7, 7]           0       
 MaxPool2D-18     [[1, 128, 7, 7]]      [1, 128, 3, 3]           0       
   Linear-16        [[1, 1152]]           [1, 2048]          2,361,344   
    ReLU-45         [[1, 2048]]           [1, 2048]              0       
  Dropout-11        [[1, 2048]]           [1, 2048]              0       
   Linear-17        [[1, 2048]]           [1, 2048]          4,196,352   
    ReLU-46         [[1, 2048]]           [1, 2048]              0       
  Dropout-12        [[1, 2048]]           [1, 2048]              0       
   Linear-18        [[1, 2048]]            [1, 10]            20,490     
===========================================================================
Total params: 7,524,042
Trainable params: 7,524,042
Non-trainable params: 0
---------------------------------------------------------------------------
Input size (MB): 0.19
Forward/backward pass size (MB): 1.81
Params size (MB): 28.70
Estimated Total Size (MB): 30.69
---------------------------------------------------------------------------






{'total_params': 7524042, 'trainable_params': 7524042}

3.2 Training

learning_rate = 0.001
n_epochs = 50
paddle.seed(42)
np.random.seed(42)
work_path = 'work/model1'

model = AlexNet(num_classes=10)

criterion = LabelSmoothingCrossEntropy()

scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=learning_rate, T_max=50000 // batch_size * n_epochs, verbose=False)
optimizer = paddle.optimizer.Adam(parameters=model.parameters(), learning_rate=scheduler, weight_decay=1e-5)

gate = 0.0
threshold = 0.0
best_acc = 0.0
val_acc = 0.0
loss_record1 = {'train': {'loss': [], 'iter': []}, 'val': {'loss': [], 'iter': []}}   # for recording loss
acc_record1 = {'train': {'acc': [], 'iter': []}, 'val': {'acc': [], 'iter': []}}      # for recording accuracy

loss_iter = 0
acc_iter = 0

for epoch in range(n_epochs):
    # ---------- Training ----------
    model.train()
    train_num = 0.0
    train_loss = 0.0

    val_num = 0.0
    val_loss = 0.0
    accuracy_manager = paddle.metric.Accuracy()
    val_accuracy_manager = paddle.metric.Accuracy()
    print("#===epoch: {}, lr={:.10f}===#".format(epoch, optimizer.get_lr()))
    for batch_id, data in enumerate(train_loader):
        x_data, y_data = data
        labels = paddle.unsqueeze(y_data, axis=1)

        logits = model(x_data)

        loss = criterion(logits, y_data)

        acc = paddle.metric.accuracy(logits, labels)
        accuracy_manager.update(acc)
        if batch_id % 10 == 0:
            loss_record1['train']['loss'].append(loss.numpy())
            loss_record1['train']['iter'].append(loss_iter)
            loss_iter += 1

        loss.backward()

        optimizer.step()
        scheduler.step()
        optimizer.clear_grad()
        
        train_loss += loss
        train_num += len(y_data)

    total_train_loss = (train_loss / train_num) * batch_size
    train_acc = accuracy_manager.accumulate()
    acc_record1['train']['acc'].append(train_acc)
    acc_record1['train']['iter'].append(acc_iter)
    acc_iter += 1
    # Print the information.
    print("#===epoch: {}, train loss is: {}, train acc is: {:2.2f}%===#".format(epoch, total_train_loss.numpy(), train_acc*100))

    # ---------- Validation ----------
    model.eval()

    for batch_id, data in enumerate(val_loader):

        x_data, y_data = data
        labels = paddle.unsqueeze(y_data, axis=1)
        with paddle.no_grad():
          logits = model(x_data)

        loss = criterion(logits, y_data)

        acc = paddle.metric.accuracy(logits, labels)
        val_accuracy_manager.update(acc)

        val_loss += loss
        val_num += len(y_data)

    total_val_loss = (val_loss / val_num) * batch_size
    loss_record1['val']['loss'].append(total_val_loss.numpy())
    loss_record1['val']['iter'].append(loss_iter)
    val_acc = val_accuracy_manager.accumulate()
    acc_record1['val']['acc'].append(val_acc)
    acc_record1['val']['iter'].append(acc_iter)
    
    print("#===epoch: {}, val loss is: {}, val acc is: {:2.2f}%===#".format(epoch, total_val_loss.numpy(), val_acc*100))

    # ===================save====================
    if val_acc > best_acc:
        best_acc = val_acc
        paddle.save(model.state_dict(), os.path.join(work_path, 'best_model.pdparams'))
        paddle.save(optimizer.state_dict(), os.path.join(work_path, 'best_optimizer.pdopt'))

print(best_acc)
paddle.save(model.state_dict(), os.path.join(work_path, 'final_model.pdparams'))
paddle.save(optimizer.state_dict(), os.path.join(work_path, 'final_optimizer.pdopt'))


3.3 Results

plot_learning_curve(loss_record1, title='loss', ylabel='CE Loss')


plot_learning_curve(acc_record1, title='acc', ylabel='Accuracy')


import time
work_path = 'work/model1'
model = AlexNet(num_classes=10)
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
model.eval()
aa = time.time()
for batch_id, data in enumerate(val_loader):

    x_data, y_data = data
    labels = paddle.unsqueeze(y_data, axis=1)
    with paddle.no_grad():
        logits = model(x_data)
bb = time.time()
print("Throughput:{}".format(int(len(val_dataset)//(bb - aa))))
Throughput:1174
work_path = 'work/model1'
X, y = next(iter(DataLoader(val_dataset, batch_size=18)))
model = AlexNet(num_classes=10)
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
model.eval()
logits = model(X)
y_pred = paddle.argmax(logits, -1)
X = paddle.transpose(X, [0, 2, 3, 1])
axes = show_images(X.reshape((18, 128, 128, 3)), 1, 18, pred=get_cifar10_labels(y_pred), gt=get_cifar10_labels(y))
plt.show()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

4. Comparison of Results

Model                Train Acc   Val Acc   Parameters
AlexNet w/o ACmix    0.7770      0.79855   7524042
AlexNet w/ ACmix     0.7915      0.81804   7051159

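Summary

        ACmix reduces the number of parameters (-472883) while considerably speeding up convergence and improving validation accuracy (+0.01949). In the throughput measurements above, however, the ACmix variant is slower (827 versus 1174 images per second).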
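The parameter difference can be traced back to the three layers that were swapped. The sketch below recomputes it from the layer definitions in this post (three 3×3 convolutions in AlexNet versus three ACmix layers with `head=4` and `kernel_conv=3`); the helper names are hypothetical, and each ACmix layer is counted with its two mixing scalars and the single-element `dep_conv` bias.

```python
def conv_params(c_in, c_out, k, groups=1, bias=True):
    """Weights (and optional per-channel bias) of a 2D convolution."""
    return c_out * (c_in // groups) * k * k + (c_out if bias else 0)

def acmix_params(c_in, c_out, head=4, k=3):
    head_dim = c_out // head
    return (3 * conv_params(c_in, c_out, 1)                 # q, k, v projections
            + conv_params(2, head_dim, 1)                   # positional encoding
            + conv_params(3 * head, k * k, 1, bias=False)   # fc
            + conv_params(k * k * head_dim, c_out, k,
                          groups=head_dim, bias=False) + 1  # dep_conv weights + bias
            + 2)                                            # rate1, rate2

layers = [(128, 192), (192, 192), (192, 128)]
saved = sum(conv_params(i, o, 3) - acmix_params(i, o) for i, o in layers)
print(saved)  # 472883
```

The saving comes from replacing each set of 3×3 weights, which scale with 9 × C_in × C_out, by three 1×1 projections that scale with 3 × C_in × C_out.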

This post is a repost. Original: https://aistudio.baidu.com/aistudio/projectdetail/4375641
