GCT: Gaussian Context Transformer

Abstract

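        In recent years, many channel attention blocks have been proposed to enhance the representational power of deep convolutional neural networks (CNNs). These methods typically learn the relationship between global context and attention activations through fully connected layers or linear transformations. However, we find empirically that, despite introducing many parameters, these attention blocks may not learn this relationship well. In this paper, we hypothesize that the relationship is predetermined. Based on this hypothesis, we propose a simple yet highly effective channel attention block, the Gaussian Context Transformer (GCT), which achieves contextual feature excitation using a Gaussian function that satisfies the presupposed relationship. Depending on whether the standard deviation of the Gaussian function is learnable, we develop two versions of GCT: GCT-B0 and GCT-B1. GCT-B0 is a parameter-free channel attention block with a fixed standard deviation; it maps global context directly to attention activations without any learning. GCT-B1 is the parameterized version, which adaptively learns the standard deviation to enhance the mapping ability. Extensive experiments on the ImageNet and MS COCO benchmarks show that our GCTs yield consistent improvements across various deep CNNs and detectors. Compared with state-of-the-art channel attention blocks such as SE and ECA, our GCTs are superior in both effectiveness and efficiency.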

1. GCT

1.1 Overview

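        Based on the hypothesis that global context and attention activations are negatively correlated, the authors propose a new channel attention module, the Gaussian Context Transformer (GCT), which maps the global context directly to an attention map through a Gaussian function that encodes the presupposed negative correlation. The basic structure of GCT is shown in the figure below. Specifically, after global average pooling (GAP), GCT first normalizes the channel vector to stabilize the distribution of the global context; a Gaussian function then performs the excitation operation (a transform step followed by an activation step) on the normalized global context to obtain the attention activations (the attention map). When the Gaussian function is fixed, the model is called GCT-B0. Note that GCT-B0 is a parameter-free attention block: it models the global context without learning any feature transformation.
[Figure: basic structure of GCT]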

1.2 GCT

        GCT consists of three operations: global context aggregation (GCA), normalization, and Gaussian context excitation (GCE).

  1. GCA: global average pooling aggregates each sample's global context across the spatial dimensions, producing a channel descriptor:
    $$\mathbf{z} = \left\{ z_{k} = \frac{1}{H \times W} \sum_{i=1}^{W} \sum_{j=1}^{H} \mathbf{X}_{k}(i, j) : k \in \{1, \ldots, C\} \right\}$$

  2. Normalization: the channel descriptor is standardized with its mean μ and standard deviation σ taken over the channels:
    $$\hat{\mathbf{z}} = \frac{1}{\sigma}(\mathbf{z} - \mu)$$

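  3. GCE: To satisfy the hypothesis that channel attention learns a negative correlation (the farther a channel's global context lies from the mean, the smaller its attention activation), the excitation function F must be continuous and satisfy the following conditions:

    1. Its values lie in the half-open interval (0, 1].
    2. F(z) attains its maximum at z = 0.
    3. It is monotonically decreasing for z > 0 and monotonically increasing for z < 0.
    4. F(z) tends to 0 as z tends to positive or negative infinity.

    The paper defines GCE as a Gaussian function:
    $$G(\hat{\mathbf{z}}) = a e^{-\frac{(\hat{\mathbf{z}}-b)^{2}}{2 c^{2}}}$$
    Setting a = 1 and b = 0 so that the conditions above hold simplifies this to:
    $$\mathbf{g} = G(\hat{\mathbf{z}}) = e^{-\frac{\hat{\mathbf{z}}^{2}}{2 c^{2}}}$$
    where c is the standard deviation of the Gaussian function, which controls how much the attention activations differ across channels: the larger the standard deviation, the smaller the differences between channels.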
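The three operations above can be traced by hand on a tiny feature map. The following is a minimal sketch in plain Python, not the Paddle implementation from Section 2: the function name `gct_attention` and the toy input are made up for illustration, and the population standard deviation is an assumption that may differ slightly from what `paddle.std` computes.

```python
import math

def gct_attention(feature_map, c=2.0):
    """Toy GCT on one sample. feature_map is a list of C channels,
    each an H x W nested list. Returns (z, z_hat, g)."""
    # 1. GCA: average every channel over its spatial positions.
    z = [sum(sum(row) for row in ch) / (len(ch) * len(ch[0])) for ch in feature_map]

    # 2. Normalization: standardize z across channels
    #    (population standard deviation, an assumption of this sketch).
    mu = sum(z) / len(z)
    sigma = math.sqrt(sum((v - mu) ** 2 for v in z) / len(z))
    z_hat = [(v - mu) / sigma for v in z]

    # 3. GCE: Gaussian excitation g = exp(-z_hat^2 / (2 c^2)).
    g = [math.exp(-(v ** 2) / (2.0 * c ** 2)) for v in z_hat]
    return z, z_hat, g

# Hypothetical input: 3 channels of size 2 x 2 with means 1, 2 and 6.
x = [
    [[1.0, 1.0], [1.0, 1.0]],
    [[2.0, 2.0], [2.0, 2.0]],
    [[6.0, 6.0], [6.0, 6.0]],
]
z, z_hat, g = gct_attention(x)
for k in range(len(z)):
    print(f"channel {k}: z={z[k]:.1f}  z_hat={z_hat[k]:+.4f}  g={g[k]:.4f}")
```

The channel whose context is closest to the mean (the second one) receives the largest activation, and the outlier channel receives the smallest, which is the negative correlation the hypothesis describes.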

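        For the parameter-free GCT (GCT-B0), c is usually set to 2. For the parameterized GCT (GCT-B1), c is computed from a learnable parameter θ as follows (by default α = 3, β = 1):
$$c = \alpha \cdot \operatorname{sigmoid}(\theta) + \beta$$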
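Because the sigmoid lies strictly between 0 and 1, this parameterization keeps c inside the open interval (β, α + β), that is (1, 4) with the defaults. The sketch below, in plain Python with hypothetical helper names `gaussian_std` and `excitation`, evaluates c for a few values of θ and shows how a larger c flattens the attention. The value θ = 0.5 is the initial value used by the GCT-B1 code in Section 2.5.1.

```python
import math

def gaussian_std(theta, alpha=3.0, beta=1.0):
    """c = alpha * sigmoid(theta) + beta, bounded to (beta, alpha + beta)."""
    return alpha / (1.0 + math.exp(-theta)) + beta

def excitation(z_hat, c):
    """Gaussian excitation for one normalized context value."""
    return math.exp(-(z_hat ** 2) / (2.0 * c ** 2))

for theta in (-10.0, 0.0, 0.5, 10.0):
    c = gaussian_std(theta)
    # G(1) and G(2) are the activations of channels that lie 1 and 2
    # standard deviations away from the mean context.
    print(f"theta={theta:+5.1f}  c={c:.4f}  G(1)={excitation(1.0, c):.4f}  G(2)={excitation(2.0, c):.4f}")
```

As θ grows, c approaches 4 and the activations of off-mean channels move closer to 1, so the block suppresses them less.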

2. Code Reproduction

2.1 Install and import the required packages

!pip install paddlex
%matplotlib inline
import math
import os

import matplotlib.pyplot as plt
import numpy as np
import paddle
import paddle.nn.functional as F
import paddle.vision.transforms as transforms
import paddlex
from matplotlib.pyplot import figure
from paddle import nn
from paddle.io import DataLoader
from paddle.vision.datasets import Cifar10

2.2 Create the dataset

train_tfm = transforms.Compose([
    transforms.Resize((130, 130)),
    transforms.ColorJitter(brightness=0.2,contrast=0.2, saturation=0.2),
    transforms.RandomResizedCrop(128, scale=(0.6, 1.0)),
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomRotation(20),
    paddlex.transforms.MixupImage(),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

test_tfm = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
paddle.vision.set_image_backend('cv2')
# Use the CIFAR-10 dataset
train_dataset = Cifar10(data_file='data/data152754/cifar-10-python.tar.gz', mode='train', transform = train_tfm)
val_dataset = Cifar10(data_file='data/data152754/cifar-10-python.tar.gz', mode='test',transform = test_tfm)
print("train_dataset: %d" % len(train_dataset))
print("val_dataset: %d" % len(val_dataset))
train_dataset: 50000
val_dataset: 10000
batch_size=128
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=4)

2.3 Label smoothing

class LabelSmoothingCrossEntropy(nn.Layer):
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, pred, target):

        confidence = 1. - self.smoothing
        log_probs = F.log_softmax(pred, axis=-1)
        idx = paddle.stack([paddle.arange(log_probs.shape[0]), target], axis=1)
        nll_loss = paddle.gather_nd(-log_probs, index=idx)
        smooth_loss = paddle.mean(-log_probs, axis=-1)
        loss = confidence * nll_loss + self.smoothing * smooth_loss

        return loss.mean()
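
To make the arithmetic of the class above easier to follow, here is a minimal single-sample sketch in plain Python. The function name `label_smoothing_ce` and the example logits are illustrative only; the formula mirrors the `forward` method: the usual negative log-likelihood weighted by `1 - smoothing`, plus the mean negative log-probability over all classes weighted by `smoothing`.

```python
import math

def label_smoothing_ce(logits, target, smoothing=0.1):
    """Label-smoothed cross entropy for one sample (list of logits)."""
    # Numerically stable log-softmax.
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(v - m) for v in logits))
    log_probs = [v - log_sum for v in logits]

    nll_loss = -log_probs[target]                    # standard cross entropy
    smooth_loss = -sum(log_probs) / len(log_probs)   # mean over all classes
    return (1.0 - smoothing) * nll_loss + smoothing * smooth_loss

# Hypothetical logits for a 3-class problem; class 0 is the ground truth.
plain = label_smoothing_ce([2.0, 0.5, -1.0], target=0, smoothing=0.0)
smoothed = label_smoothing_ce([2.0, 0.5, -1.0], target=0, smoothing=0.1)
print(f"plain CE: {plain:.4f}, smoothed CE: {smoothed:.4f}")
```

With `smoothing=0.0` the function reduces to ordinary cross entropy; a positive value penalizes over-confident predictions.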

2.4 AlexNet-GCT-B0

2.4.1 GCT-B0
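class GCT_B0(nn.Layer):
    def __init__(self, c=2):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2D(1)
        self.c = c  # fixed standard deviation of the Gaussian

    def forward(self, x):
        # GCA: global average pooling, [N, C, H, W] -> [N, C, 1, 1]
        out = self.pool(x)

        # Normalization: standardize the context across the channel axis
        mean = paddle.mean(out, axis=1, keepdim=True)
        std = paddle.std(out, axis=1, keepdim=True)
        out = (out - mean) / std

        # GCE: Gaussian excitation exp(-z^2 / (2 c^2))
        out = paddle.exp(-paddle.pow(out, 2) / (2 * math.pow(self.c, 2)))

        # Re-weight the input channels with the attention map
        out = out * x

        return out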
model = GCT_B0()
paddle.summary(model, (1, 64, 224, 224))
-------------------------------------------------------------------------------
   Layer (type)         Input Shape          Output Shape         Param #    
===============================================================================
AdaptiveAvgPool2D-1 [[1, 64, 224, 224]]     [1, 64, 1, 1]            0       
===============================================================================
Total params: 0
Trainable params: 0
Non-trainable params: 0
-------------------------------------------------------------------------------
Input size (MB): 12.25
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 12.25
-------------------------------------------------------------------------------



W0807 11:07:43.047647   350 gpu_resources.cc:61] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 10.1
W0807 11:07:43.051434   350 gpu_resources.cc:91] device: 0, cuDNN Version: 7.6.





{'total_params': 0, 'trainable_params': 0}
2.4.2 AlexNet-GCT-B0
class AlexNet_GCT_B0(nn.Layer):
    def __init__(self,num_classes=10):
        super().__init__()
        self.features=nn.Sequential(
            nn.Conv2D(3,48, kernel_size=11, stride=4, padding=11//2),
            GCT_B0(),
            nn.ReLU(),
            nn.MaxPool2D(kernel_size=3,stride=2),
            nn.Conv2D(48,128, kernel_size=5, padding=2),
            GCT_B0(),
            nn.ReLU(),
            nn.MaxPool2D(kernel_size=3,stride=2),
            nn.Conv2D(128, 192,kernel_size=3,stride=1,padding=1),
            GCT_B0(),
            nn.ReLU(),
            nn.Conv2D(192,192,kernel_size=3,stride=1,padding=1),
            GCT_B0(),
            nn.ReLU(),
            nn.Conv2D(192,128,kernel_size=3,stride=1,padding=1),
            GCT_B0(),
            nn.ReLU(),
            nn.MaxPool2D(kernel_size=3,stride=2),
        )
        self.classifier=nn.Sequential(
            nn.Linear(3 * 3 * 128,2048),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(2048,2048),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(2048,num_classes),
        )
 
 
    def forward(self,x):
        x = self.features(x)
        x = paddle.flatten(x, 1)
        x=self.classifier(x)
 
        return x
model = AlexNet_GCT_B0(num_classes=10)
paddle.summary(model, (1, 3, 128, 128))
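
The `3 * 3 * 128` input size of the first linear layer follows from the spatial size of the feature map after the last pooling layer. The sketch below recomputes it in plain Python for a 128 x 128 input, assuming the usual floor-mode output-size formula for convolution and pooling; the helper `out_size` is hypothetical. GCT blocks and ReLU are omitted because they do not change the shape.

```python
def out_size(size, kernel, stride=1, padding=0):
    """Output length of a conv/pool layer along one spatial axis (floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1

# (name, kernel, stride, padding) for the shape-changing layers of the network.
layers = [
    ("conv1", 11, 4, 5),
    ("pool1", 3, 2, 0),
    ("conv2", 5, 1, 2),
    ("pool2", 3, 2, 0),
    ("conv3", 3, 1, 1),
    ("conv4", 3, 1, 1),
    ("conv5", 3, 1, 1),
    ("pool3", 3, 2, 0),
]

size = 128  # input height and width
for name, kernel, stride, padding in layers:
    size = out_size(size, kernel, stride, padding)
    print(f"{name}: {size} x {size}")

flat_features = size * size * 128  # 128 channels after conv5
print("flattened features:", flat_features)
```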


2.4.3 Training
learning_rate = 0.001
n_epochs = 50
paddle.seed(42)
np.random.seed(42)
work_path = 'work/model'

model = AlexNet_GCT_B0(num_classes=10)

criterion = LabelSmoothingCrossEntropy()

scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=learning_rate, T_max=50000 // batch_size * n_epochs, verbose=False)
optimizer = paddle.optimizer.Adam(parameters=model.parameters(), learning_rate=scheduler, weight_decay=1e-5)

gate = 0.0
threshold = 0.0
best_acc = 0.0
val_acc = 0.0
loss_record = {'train': {'loss': [], 'iter': []}, 'val': {'loss': [], 'iter': []}}   # for recording loss
acc_record = {'train': {'acc': [], 'iter': []}, 'val': {'acc': [], 'iter': []}}      # for recording accuracy

loss_iter = 0
acc_iter = 0

for epoch in range(n_epochs):
    # ---------- Training ----------
    model.train()
    train_num = 0.0
    train_loss = 0.0

    val_num = 0.0
    val_loss = 0.0
    accuracy_manager = paddle.metric.Accuracy()
    val_accuracy_manager = paddle.metric.Accuracy()
    print("#===epoch: {}, lr={:.10f}===#".format(epoch, optimizer.get_lr()))
    for batch_id, data in enumerate(train_loader):
        x_data, y_data = data
        labels = paddle.unsqueeze(y_data, axis=1)

        logits = model(x_data)

        loss = criterion(logits, y_data)

        acc = paddle.metric.accuracy(logits, labels)
        accuracy_manager.update(acc)
        if batch_id % 10 == 0:
            loss_record['train']['loss'].append(loss.numpy())
            loss_record['train']['iter'].append(loss_iter)
            loss_iter += 1

        loss.backward()

        optimizer.step()
        scheduler.step()
        optimizer.clear_grad()
        
        train_loss += loss
        train_num += len(y_data)

    total_train_loss = (train_loss / train_num) * batch_size
    train_acc = accuracy_manager.accumulate()
    acc_record['train']['acc'].append(train_acc)
    acc_record['train']['iter'].append(acc_iter)
    acc_iter += 1
    # Print the information.
    print("#===epoch: {}, train loss is: {}, train acc is: {:2.2f}%===#".format(epoch, total_train_loss.numpy(), train_acc*100))

    # ---------- Validation ----------
    model.eval()

    for batch_id, data in enumerate(val_loader):

        x_data, y_data = data
        labels = paddle.unsqueeze(y_data, axis=1)
        with paddle.no_grad():
          logits = model(x_data)

        loss = criterion(logits, y_data)

        acc = paddle.metric.accuracy(logits, labels)
        val_accuracy_manager.update(acc)

        val_loss += loss
        val_num += len(y_data)

    total_val_loss = (val_loss / val_num) * batch_size
    loss_record['val']['loss'].append(total_val_loss.numpy())
    loss_record['val']['iter'].append(loss_iter)
    val_acc = val_accuracy_manager.accumulate()
    acc_record['val']['acc'].append(val_acc)
    acc_record['val']['iter'].append(acc_iter)
    
    print("#===epoch: {}, val loss is: {}, val acc is: {:2.2f}%===#".format(epoch, total_val_loss.numpy(), val_acc*100))

    # ===================save====================
    if val_acc > best_acc:
        best_acc = val_acc
        paddle.save(model.state_dict(), os.path.join(work_path, 'best_model.pdparams'))
        paddle.save(optimizer.state_dict(), os.path.join(work_path, 'best_optimizer.pdopt'))

print(best_acc)
paddle.save(model.state_dict(), os.path.join(work_path, 'final_model.pdparams'))
paddle.save(optimizer.state_dict(), os.path.join(work_path, 'final_optimizer.pdopt'))


2.4.4 Experimental results
def plot_learning_curve(record, title='loss', ylabel='CE Loss'):
    ''' Plot learning curve of your CNN '''
    maxtrain = max(map(float, record['train'][title]))
    maxval = max(map(float, record['val'][title]))
    ymax = max(maxtrain, maxval) * 1.1
    mintrain = min(map(float, record['train'][title]))
    minval = min(map(float, record['val'][title]))
    ymin = min(mintrain, minval) * 0.9

    total_steps = len(record['train'][title])
    x_1 = list(map(int, record['train']['iter']))
    x_2 = list(map(int, record['val']['iter']))
    figure(figsize=(10, 6))
    plt.plot(x_1, record['train'][title], c='tab:red', label='train')
    plt.plot(x_2, record['val'][title], c='tab:cyan', label='val')
    plt.ylim(ymin, ymax)
    plt.xlabel('Training steps')
    plt.ylabel(ylabel)
    plt.title('Learning curve of {}'.format(title))
    plt.legend()
    plt.show()
plot_learning_curve(loss_record, title='loss', ylabel='CE Loss')


plot_learning_curve(acc_record, title='acc', ylabel='Accuracy')


import time
work_path = 'work/model'
model = AlexNet_GCT_B0(num_classes=10)
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
model.eval()
aa = time.time()
for batch_id, data in enumerate(val_loader):

    x_data, y_data = data
    labels = paddle.unsqueeze(y_data, axis=1)
    with paddle.no_grad():
        logits = model(x_data)
bb = time.time()
print("Throughput:{}".format(int(len(val_dataset)//(bb - aa))))
Throughput:2631
def get_cifar10_labels(labels):  
    """Return the text labels of the CIFAR-10 dataset."""
    text_labels = [
        'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog',
        'horse', 'ship', 'truck']
    return [text_labels[int(i)] for i in labels]
def show_images(imgs, num_rows, num_cols, pred=None, gt=None, scale=1.5):  
    """Plot a list of images."""
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if paddle.is_tensor(img):
            ax.imshow(img.numpy())
        else:
            ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if pred or gt:
            ax.set_title("pt: " + pred[i] + "\ngt: " + gt[i])
    return axes
work_path = 'work/model'
X, y = next(iter(DataLoader(val_dataset, batch_size=18)))
model = AlexNet_GCT_B0(num_classes=10)
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
model.eval()
logits = model(X)
y_pred = paddle.argmax(logits, -1)
X = paddle.transpose(X, [0, 2, 3, 1])
axes = show_images(X.reshape((18, 128, 128, 3)), 1, 18, pred=get_cifar10_labels(y_pred), gt=get_cifar10_labels(y))
plt.show()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

2.5 AlexNet-GCT-B1

2.5.1 GCT-B1
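class GCT_B1(nn.Layer):
    def __init__(self, alpha=3, beta=1):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2D(1)
        self.alpha = alpha
        self.beta = beta
        # Learnable scalar theta; c = alpha * sigmoid(theta) + beta
        self.theta = self.create_parameter([1], default_initializer=nn.initializer.Constant(0.5))

    def forward(self, x):
        # GCA: global average pooling, [N, C, H, W] -> [N, C, 1, 1]
        out = self.pool(x)

        # Normalization: standardize the context across the channel axis
        mean = paddle.mean(out, axis=1, keepdim=True)
        std = paddle.std(out, axis=1, keepdim=True)
        out = (out - mean) / std

        # Compute c inside forward with tensor ops so that theta receives gradients
        c = self.alpha * F.sigmoid(self.theta) + self.beta
        out = paddle.exp(-paddle.pow(out, 2) / (2 * paddle.pow(c, 2)))

        # Re-weight the input channels with the attention map
        out = out * x

        return out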

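The parameter does not show up in the summary of the standalone block, but the parameter count is shown in the AlexNet-GCT-B1 summary further down.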

model = GCT_B1()
paddle.summary(model, (1, 64, 224, 224))
--------------------------------------------------------------------------------
    Layer (type)         Input Shape          Output Shape         Param #    
================================================================================
AdaptiveAvgPool2D-22 [[1, 64, 224, 224]]     [1, 64, 1, 1]            0       
================================================================================
Total params: 0
Trainable params: 0
Non-trainable params: 0
--------------------------------------------------------------------------------
Input size (MB): 12.25
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 12.25
--------------------------------------------------------------------------------






{'total_params': 0, 'trainable_params': 0}
2.5.2 AlexNet-GCT-B1
class AlexNet_GCT_B1(nn.Layer):
    def __init__(self,num_classes=10):
        super().__init__()
        self.features=nn.Sequential(
            nn.Conv2D(3,48, kernel_size=11, stride=4, padding=11//2),
            GCT_B1(),
            nn.ReLU(),
            nn.MaxPool2D(kernel_size=3,stride=2),
            nn.Conv2D(48,128, kernel_size=5, padding=2),
            GCT_B1(),
            nn.ReLU(),
            nn.MaxPool2D(kernel_size=3,stride=2),
            nn.Conv2D(128, 192,kernel_size=3,stride=1,padding=1),
            GCT_B1(),
            nn.ReLU(),
            nn.Conv2D(192,192,kernel_size=3,stride=1,padding=1),
            GCT_B1(),
            nn.ReLU(),
            nn.Conv2D(192,128,kernel_size=3,stride=1,padding=1),
            GCT_B1(),
            nn.ReLU(),
            nn.MaxPool2D(kernel_size=3,stride=2),
        )
        self.classifier=nn.Sequential(
            nn.Linear(3 * 3 * 128,2048),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(2048,2048),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(2048,num_classes),
        )
 
 
    def forward(self,x):
        x = self.features(x)
        x = paddle.flatten(x, 1)
        x=self.classifier(x)
 
        return x
model = AlexNet_GCT_B1(num_classes=10)
paddle.summary(model, (1, 3, 128, 128))


2.5.3 Training
learning_rate = 0.001
n_epochs = 50
paddle.seed(42)
np.random.seed(42)
work_path = 'work/model2'

model = AlexNet_GCT_B1(num_classes=10)

criterion = LabelSmoothingCrossEntropy()

scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=learning_rate, T_max=50000 // batch_size * n_epochs, verbose=False)
optimizer = paddle.optimizer.Adam(parameters=model.parameters(), learning_rate=scheduler, weight_decay=1e-5)

gate = 0.0
threshold = 0.0
best_acc = 0.0
val_acc = 0.0
loss_record2 = {'train': {'loss': [], 'iter': []}, 'val': {'loss': [], 'iter': []}}   # for recording loss
acc_record2 = {'train': {'acc': [], 'iter': []}, 'val': {'acc': [], 'iter': []}}      # for recording accuracy

loss_iter = 0
acc_iter = 0

for epoch in range(n_epochs):
    # ---------- Training ----------
    model.train()
    train_num = 0.0
    train_loss = 0.0

    val_num = 0.0
    val_loss = 0.0
    accuracy_manager = paddle.metric.Accuracy()
    val_accuracy_manager = paddle.metric.Accuracy()
    print("#===epoch: {}, lr={:.10f}===#".format(epoch, optimizer.get_lr()))
    for batch_id, data in enumerate(train_loader):
        x_data, y_data = data
        labels = paddle.unsqueeze(y_data, axis=1)

        logits = model(x_data)

        loss = criterion(logits, y_data)

        acc = paddle.metric.accuracy(logits, labels)
        accuracy_manager.update(acc)
        if batch_id % 10 == 0:
            loss_record2['train']['loss'].append(loss.numpy())
            loss_record2['train']['iter'].append(loss_iter)
            loss_iter += 1

        loss.backward()

        optimizer.step()
        scheduler.step()
        optimizer.clear_grad()
        
        train_loss += loss
        train_num += len(y_data)

    total_train_loss = (train_loss / train_num) * batch_size
    train_acc = accuracy_manager.accumulate()
    acc_record2['train']['acc'].append(train_acc)
    acc_record2['train']['iter'].append(acc_iter)
    acc_iter += 1
    # Print the information.
    print("#===epoch: {}, train loss is: {}, train acc is: {:2.2f}%===#".format(epoch, total_train_loss.numpy(), train_acc*100))

    # ---------- Validation ----------
    model.eval()

    for batch_id, data in enumerate(val_loader):

        x_data, y_data = data
        labels = paddle.unsqueeze(y_data, axis=1)
        with paddle.no_grad():
          logits = model(x_data)

        loss = criterion(logits, y_data)

        acc = paddle.metric.accuracy(logits, labels)
        val_accuracy_manager.update(acc)

        val_loss += loss
        val_num += len(y_data)

    total_val_loss = (val_loss / val_num) * batch_size
    loss_record2['val']['loss'].append(total_val_loss.numpy())
    loss_record2['val']['iter'].append(loss_iter)
    val_acc = val_accuracy_manager.accumulate()
    acc_record2['val']['acc'].append(val_acc)
    acc_record2['val']['iter'].append(acc_iter)
    
    print("#===epoch: {}, val loss is: {}, val acc is: {:2.2f}%===#".format(epoch, total_val_loss.numpy(), val_acc*100))

    # ===================save====================
    if val_acc > best_acc:
        best_acc = val_acc
        paddle.save(model.state_dict(), os.path.join(work_path, 'best_model.pdparams'))
        paddle.save(optimizer.state_dict(), os.path.join(work_path, 'best_optimizer.pdopt'))

print(best_acc)
paddle.save(model.state_dict(), os.path.join(work_path, 'final_model.pdparams'))
paddle.save(optimizer.state_dict(), os.path.join(work_path, 'final_optimizer.pdopt'))


2.5.4 Experimental results
plot_learning_curve(loss_record2, title='loss', ylabel='CE Loss')


plot_learning_curve(acc_record2, title='acc', ylabel='Accuracy')


import time
work_path = 'work/model2'
model = AlexNet_GCT_B1(num_classes=10)
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
model.eval()
aa = time.time()
for batch_id, data in enumerate(val_loader):

    x_data, y_data = data
    labels = paddle.unsqueeze(y_data, axis=1)
    with paddle.no_grad():
        logits = model(x_data)
bb = time.time()
print("Throughput:{}".format(int(len(val_dataset)//(bb - aa))))
Throughput:2339
work_path = 'work/model2'
X, y = next(iter(DataLoader(val_dataset, batch_size=18)))
model = AlexNet_GCT_B1(num_classes=10)
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
model.eval()
logits = model(X)
y_pred = paddle.argmax(logits, -1)
X = paddle.transpose(X, [0, 2, 3, 1])
axes = show_images(X.reshape((18, 128, 128, 3)), 1, 18, pred=get_cifar10_labels(y_pred), gt=get_cifar10_labels(y))
plt.show()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

3. AlexNet

3.1 AlexNet

class AlexNet(nn.Layer):
    def __init__(self,num_classes=10):
        super().__init__()
        self.features=nn.Sequential(
            nn.Conv2D(3,48, kernel_size=11, stride=4, padding=11//2),
            nn.ReLU(),
            nn.MaxPool2D(kernel_size=3,stride=2),
            nn.Conv2D(48,128, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool2D(kernel_size=3,stride=2),
            nn.Conv2D(128, 192,kernel_size=3,stride=1,padding=1),
            nn.ReLU(),
            nn.Conv2D(192,192,kernel_size=3,stride=1,padding=1),
            nn.ReLU(),
            nn.Conv2D(192,128,kernel_size=3,stride=1,padding=1),
            nn.ReLU(),
            nn.MaxPool2D(kernel_size=3,stride=2),
        )
        self.classifier=nn.Sequential(
            nn.Linear(3 * 3 * 128,2048),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(2048,2048),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(2048,num_classes),
        )
 
 
    def forward(self,x):
        x = self.features(x)
        x = paddle.flatten(x, 1)
        x=self.classifier(x)
 
        return x
model = AlexNet(num_classes=10)
paddle.summary(model, (1, 3, 128, 128))


3.2 Training

learning_rate = 0.001
n_epochs = 50
paddle.seed(42)
np.random.seed(42)
work_path = 'work/model1'

model = AlexNet(num_classes=10)

criterion = LabelSmoothingCrossEntropy()

scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=learning_rate, T_max=50000 // batch_size * n_epochs, verbose=False)
optimizer = paddle.optimizer.Adam(parameters=model.parameters(), learning_rate=scheduler, weight_decay=1e-5)

gate = 0.0
threshold = 0.0
best_acc = 0.0
val_acc = 0.0
loss_record1 = {'train': {'loss': [], 'iter': []}, 'val': {'loss': [], 'iter': []}}   # for recording loss
acc_record1 = {'train': {'acc': [], 'iter': []}, 'val': {'acc': [], 'iter': []}}      # for recording accuracy

loss_iter = 0
acc_iter = 0

for epoch in range(n_epochs):
    # ---------- Training ----------
    model.train()
    train_num = 0.0
    train_loss = 0.0

    val_num = 0.0
    val_loss = 0.0
    accuracy_manager = paddle.metric.Accuracy()
    val_accuracy_manager = paddle.metric.Accuracy()
    print("#===epoch: {}, lr={:.10f}===#".format(epoch, optimizer.get_lr()))
    for batch_id, data in enumerate(train_loader):
        x_data, y_data = data
        labels = paddle.unsqueeze(y_data, axis=1)

        logits = model(x_data)

        loss = criterion(logits, y_data)

        acc = paddle.metric.accuracy(logits, labels)
        accuracy_manager.update(acc)
        if batch_id % 10 == 0:
            loss_record1['train']['loss'].append(loss.numpy())
            loss_record1['train']['iter'].append(loss_iter)
            loss_iter += 1

        loss.backward()

        optimizer.step()
        scheduler.step()
        optimizer.clear_grad()
        
        train_loss += loss
        train_num += len(y_data)

    total_train_loss = (train_loss / train_num) * batch_size
    train_acc = accuracy_manager.accumulate()
    acc_record1['train']['acc'].append(train_acc)
    acc_record1['train']['iter'].append(acc_iter)
    acc_iter += 1
    # Print the information.
    print("#===epoch: {}, train loss is: {}, train acc is: {:2.2f}%===#".format(epoch, total_train_loss.numpy(), train_acc*100))

    # ---------- Validation ----------
    model.eval()

    for batch_id, data in enumerate(val_loader):

        x_data, y_data = data
        labels = paddle.unsqueeze(y_data, axis=1)
        with paddle.no_grad():
          logits = model(x_data)

        loss = criterion(logits, y_data)

        acc = paddle.metric.accuracy(logits, labels)
        val_accuracy_manager.update(acc)

        val_loss += loss
        val_num += len(y_data)

    total_val_loss = (val_loss / val_num) * batch_size
    loss_record1['val']['loss'].append(total_val_loss.numpy())
    loss_record1['val']['iter'].append(loss_iter)
    val_acc = val_accuracy_manager.accumulate()
    acc_record1['val']['acc'].append(val_acc)
    acc_record1['val']['iter'].append(acc_iter)
    
    print("#===epoch: {}, val loss is: {}, val acc is: {:2.2f}%===#".format(epoch, total_val_loss.numpy(), val_acc*100))

    # ===================save====================
    if val_acc > best_acc:
        best_acc = val_acc
        paddle.save(model.state_dict(), os.path.join(work_path, 'best_model.pdparams'))
        paddle.save(optimizer.state_dict(), os.path.join(work_path, 'best_optimizer.pdopt'))

print(best_acc)
paddle.save(model.state_dict(), os.path.join(work_path, 'final_model.pdparams'))
paddle.save(optimizer.state_dict(), os.path.join(work_path, 'final_optimizer.pdopt'))


3.3 Experimental results

plot_learning_curve(loss_record1, title='loss', ylabel='CE Loss')


plot_learning_curve(acc_record1, title='acc', ylabel='Accuracy')


import time
work_path = 'work/model1'
model = AlexNet(num_classes=10)
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
model.eval()
aa = time.time()
for batch_id, data in enumerate(val_loader):

    x_data, y_data = data
    labels = paddle.unsqueeze(y_data, axis=1)
    with paddle.no_grad():
        logits = model(x_data)
bb = time.time()
print("Throughput:{}".format(int(len(val_dataset)//(bb - aa))))
Throughput:2291
work_path = 'work/model1'
X, y = next(iter(DataLoader(val_dataset, batch_size=18)))
model = AlexNet(num_classes=10)
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
model.eval()
logits = model(X)
y_pred = paddle.argmax(logits, -1)
X = paddle.transpose(X, [0, 2, 3, 1])
axes = show_images(X.reshape((18, 128, 128, 3)), 1, 18, pred=get_cifar10_labels(y_pred), gt=get_cifar10_labels(y))
plt.show()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

4. Comparison of experimental results

Model               Train Acc   Val Acc   Parameters
AlexNet             0.7713      0.80083   7524042
AlexNet w/ GCT-B0   0.8060      0.81725   7524042
AlexNet w/ GCT-B1   0.7965      0.81576   7524047
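
The parameter counts in the table can be reproduced by hand from the layer definitions in Sections 2 and 3. The sketch below does this in plain Python, assuming every convolution and linear layer carries a bias; the helper names are hypothetical. GCT-B0 adds nothing, while each of the five GCT-B1 blocks adds one scalar θ.

```python
def conv_params(c_in, c_out, kernel):
    """Weights plus biases of a square 2D convolution."""
    return c_in * c_out * kernel * kernel + c_out

def linear_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out

convs = [(3, 48, 11), (48, 128, 5), (128, 192, 3), (192, 192, 3), (192, 128, 3)]
linears = [(3 * 3 * 128, 2048), (2048, 2048), (2048, 10)]

baseline = sum(conv_params(*c) for c in convs) + sum(linear_params(*l) for l in linears)
with_gct_b0 = baseline            # parameter-free blocks
with_gct_b1 = baseline + 5 * 1    # one learnable theta per block, five blocks

print("AlexNet:", baseline)
print("AlexNet w/ GCT-B0:", with_gct_b0)
print("AlexNet w/ GCT-B1:", with_gct_b1)
```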

Summary

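        With only 5 additional parameters (GCT-B1) or none at all (GCT-B0), GCT speeds up convergence and improves validation accuracy over the AlexNet baseline (+0.01642 for GCT-B0 and +0.01493 for GCT-B1).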

This article is a repost.
Original project link
