Triplet Attention: A Lightweight and Effective Plug-and-Play Attention Module

Abstract

Owing to their ability to build inter-dependencies between channels or spatial locations, attention mechanisms have been studied extensively in recent years and applied widely across computer vision tasks. In this paper, we investigate lightweight yet effective attention mechanisms and propose Triplet Attention, a novel method that computes attention weights by capturing cross-dimension interaction through a three-branch structure. For an input tensor, Triplet Attention builds inter-dimensional dependencies via rotation operations followed by residual transformations, encoding inter-channel and spatial information at negligible computational overhead. The method is simple and efficient, and can easily be plugged into classic backbone networks as an add-on module. We demonstrate its effectiveness on a range of challenging tasks, including image classification on ImageNet-1k and object detection on the MS COCO and PASCAL VOC datasets. We further probe the behavior of Triplet Attention by inspecting GradCAM and GradCAM++ results. The empirical evaluation supports our intuition that capturing cross-dimension dependencies matters when computing attention weights.

1. Triplet Attention

The goal of this paper is to build cheap yet effective channel attention without any dimensionality reduction. Unlike CBAM and SENet, which need a certain number of learnable parameters to build inter-channel dependencies, this paper proposes an almost parameter-free mechanism, Triplet Attention, that models channel attention and spatial attention jointly.
The proposed Triplet Attention is shown in Figure 3. As the name suggests, it consists of three parallel branches: two of them capture cross-dimension interaction between the channel dimension C and the spatial dimension H or W, while the last branch, similar to CBAM, builds spatial attention. The outputs of the three branches are aggregated by averaging.
[Figure 3: the three parallel branches of Triplet Attention]

1.1 Cross-Dimension Interaction

The conventional way of computing channel attention first computes a weight per channel and then uses these weights to uniformly scale the feature maps. This approach has an important blind spot: to compute the channel weights, the input tensor is spatially collapsed to one pixel per channel by global average pooling. This discards a large amount of spatial information, so when attention is computed over these single-pixel channels, the inter-dependency between the channel dimension and the spatial dimensions is lost as well.
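To make the information loss concrete, here is a minimal, illustrative shape check (not from the paper): global average pooling collapses each (H, W) plane to a single value, so any channel weights computed from it can no longer depend on where activations occurred.

import paddle

x = paddle.randn([1, 64, 32, 32])      # (N, C, H, W) feature map
gap = paddle.nn.AdaptiveAvgPool2D(1)   # global average pooling
print(gap(x).shape)                    # [1, 64, 1, 1]: all spatial detail is gone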
CBAM, proposed later, models spatial attention in addition to channel attention and thereby alleviates the missing spatial dependencies, but one problem remains: its channel attention and spatial attention are separated and computed independently of each other. Building on the idea of spatial attention, this paper introduces the concept of cross-dimension interaction, which addresses this problem by capturing the interaction between the spatial dimensions and the channel dimension of the input tensor.

1.2 Z-pool

The Z-pool layer reduces the zeroth dimension of the tensor to two by concatenating the average-pooled and max-pooled features along that dimension. This lets the layer preserve a rich representation of the actual tensor while shrinking its depth, keeping the subsequent computation light. It can be written as:
$$Z\text{-pool}(\chi) = \left[\operatorname{MaxPool}_{0d}(\chi),\ \operatorname{AvgPool}_{0d}(\chi)\right]$$
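As a quick, illustrative sketch (the Z_Pool layer actually used in the reproduction appears in Section 2.4.1), Z-pool applied to the channel axis of a standard (N, C, H, W) tensor behaves as follows:

import paddle

def z_pool(x):
    # concatenate max- and average-pooled slices along the pooled axis
    return paddle.concat(
        [x.max(axis=1, keepdim=True), x.mean(axis=1, keepdim=True)], axis=1)

x = paddle.randn([1, 64, 32, 32])
print(z_pool(x).shape)  # [1, 2, 32, 32]: the C axis is reduced to 2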

1.3 Triplet Attention

Given an input tensor $\chi \in R^{C \times H \times W}$, Triplet Attention first passes it to its three branches.
The first branch builds interaction between the H dimension and the C dimension.

To achieve this, the input tensor $\chi$ is rotated 90° counter-clockwise along the H axis. The rotated tensor $\hat{\chi}_1$ has shape $(W \times H \times C)$. After Z-pool, the tensor $\hat{\chi}_1^{*}$ has shape $(2 \times H \times C)$; it is then passed through a standard convolution layer with kernel size $k \times k$ followed by batch normalization, yielding an intermediate output of shape $(1 \times H \times C)$. Passing this tensor through a sigmoid produces the attention weights, which are applied to $\hat{\chi}_1$. Finally, the output is rotated 90° clockwise along the H axis so that it matches the shape of the input.
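The shape bookkeeping of this branch can be sketched as follows (illustrative only: the reproduction in Section 2.4.1 realizes the same permutation with einops.Rearrange, and batch normalization is omitted here for brevity). The "rotation" amounts to a (C, H, W) → (W, H, C) axis permutation:

import paddle
import paddle.nn.functional as F

x = paddle.randn([1, 64, 32, 32])                        # (N, C, H, W)
x_hat1 = paddle.transpose(x, [0, 3, 2, 1])               # rotate: (N, W, H, C)
x_hat1_star = paddle.concat(                             # Z-pool: (N, 2, H, C)
    [x_hat1.max(axis=1, keepdim=True), x_hat1.mean(axis=1, keepdim=True)], axis=1)
conv = paddle.nn.Conv2D(2, 1, kernel_size=7, padding=3)  # k = 7, as in the code below
weights = F.sigmoid(conv(x_hat1_star))                   # (N, 1, H, C) attention weights
out = paddle.transpose(x_hat1 * weights, [0, 3, 2, 1])   # rotate back: (N, C, H, W)
print(out.shape)                                         # [1, 64, 32, 32]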
The second branch builds interaction between the C dimension and the W dimension.

Analogously, the input tensor $\chi$ is rotated 90° counter-clockwise along the W axis. The rotated tensor $\hat{\chi}_2$ has shape $(H \times C \times W)$. After Z-pool, the tensor $\hat{\chi}_2^{*}$ has shape $(2 \times C \times W)$; it is then passed through a standard convolution layer with kernel size $k \times k$ followed by batch normalization, yielding an intermediate output of shape $(1 \times C \times W)$. Passing this tensor through a sigmoid produces the attention weights, which are applied to $\hat{\chi}_2$. Finally, the output is rotated 90° clockwise along the W axis so that it matches the shape of the input.
The third branch builds interaction between the H dimension and the W dimension.

The channels of the input tensor $\chi$ are reduced to two by Z-pool. The reduced tensor $\hat{\chi}_3$ of shape $(2 \times H \times W)$ is passed through a standard convolution layer with kernel size $k \times k$, followed by batch normalization. The output then goes through a sigmoid activation to produce attention weights of shape $(1 \times H \times W)$, which are applied to the input $\chi$. Finally, the refined tensors of shape $(C \times H \times W)$ produced by the three branches are aggregated by simple averaging.
The final output tensor is:
$$y=\frac{1}{3}\left(\overline{\hat{\chi}_{1}\,\sigma\left(\psi_{1}\left(\hat{\chi}_{1}^{*}\right)\right)}+\overline{\hat{\chi}_{2}\,\sigma\left(\psi_{2}\left(\hat{\chi}_{2}^{*}\right)\right)}+\chi\,\sigma\left(\psi_{3}\left(\hat{\chi}_{3}\right)\right)\right)$$
where $\sigma$ denotes the sigmoid function, $\psi_{1}$, $\psi_{2}$, $\psi_{3}$ are the $k \times k$ convolutions in the three branches, and the overline denotes rotating the tensor 90° clockwise back to the original $(C \times H \times W)$ orientation.

1.4 Parameter Analysis

[Table: parameter overhead of Triplet Attention compared with other attention modules]
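The parameter count of the module itself can be worked out directly from its structure (a back-of-the-envelope check, consistent with the paddle.summary output in Section 2.4.1): each branch contains one k×k convolution with 2 input channels and 1 output channel, plus a BatchNorm over 1 channel.

k = 7
conv_params = 2 * 1 * k * k        # 98 parameters per branch
bn_params = 4                      # scale, shift + running mean/variance (1 channel)
total = 3 * (conv_params + bn_params)
print(total)                       # 306, matching paddle.summary in Section 2.4.1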

2. Code Reproduction

2.1 Download and Import the Required Packages

!pip install einops-0.3.0-py3-none-any.whl
!pip install paddlex
%matplotlib inline
import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure
import paddle
from paddle import nn
import paddle.nn.functional as F
import paddle.vision.transforms as transforms
from paddle.vision.datasets import Cifar10
from paddle.io import DataLoader
import paddlex
from einops.layers.paddle import Rearrange

2.2 Create the Datasets

train_tfm = transforms.Compose([
    transforms.Resize((230, 230)),
    transforms.ColorJitter(brightness=0.2,contrast=0.2, saturation=0.2),
    paddlex.transforms.MixupImage(),
    transforms.RandomResizedCrop(224, scale=(0.6, 1.0)),
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomRotation(20),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

test_tfm = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
paddle.vision.set_image_backend('cv2')
# Use the Cifar10 dataset
train_dataset = Cifar10(data_file='data/data152754/cifar-10-python.tar.gz', mode='train', transform = train_tfm)
val_dataset = Cifar10(data_file='data/data152754/cifar-10-python.tar.gz', mode='test',transform = test_tfm)
print("train_dataset: %d" % len(train_dataset))
print("val_dataset: %d" % len(val_dataset))
train_dataset: 50000
val_dataset: 10000
batch_size=128
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=2)
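
As an optional sanity check (not part of the original script), one batch can be pulled from the loader to confirm the expected shapes:

x, y = next(iter(train_loader))
print(x.shape, y.shape)  # [128, 3, 224, 224] [128]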

2.3 Label Smoothing

class LabelSmoothingCrossEntropy(nn.Layer):
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, pred, target):

        confidence = 1. - self.smoothing
        log_probs = F.log_softmax(pred, axis=-1)
        idx = paddle.stack([paddle.arange(log_probs.shape[0]), target], axis=1)
        nll_loss = paddle.gather_nd(-log_probs, index=idx)
        smooth_loss = paddle.mean(-log_probs, axis=-1)
        loss = confidence * nll_loss + self.smoothing * smooth_loss

        return loss.mean()
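
A quick usage sketch of the loss above (illustrative values): with smoothing=0.1, the effective target distribution places 0.9 on the true class plus a uniform 0.1/num_classes on every class.

import paddle

criterion = LabelSmoothingCrossEntropy(smoothing=0.1)
logits = paddle.randn([4, 10])            # batch of 4 samples, 10 classes
targets = paddle.to_tensor([0, 3, 5, 9])  # integer class labels
print(criterion(logits, targets))         # scalar mean loss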

2.4 AlexNet-Triplet Attention

2.4.1 Triplet Attention
class BasicConv(nn.Layer):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0,
        dilation=1, groups=1, relu=True, bn=True,bias=False):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2D(in_planes, out_planes, kernel_size=kernel_size,
            stride=stride, padding=padding, dilation=dilation, groups=groups,
            bias_attr=bias)
        self.bn = nn.BatchNorm2D(out_planes, epsilon=1e-5, momentum=0.01) \
            if bn else None

        self.relu = nn.ReLU() if relu else None

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x
class Z_Pool(nn.Layer):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # concatenate max- and mean-pooled maps along axis 1: (N, C, H, W) -> (N, 2, H, W)
        return paddle.concat((paddle.max(x, 1).unsqueeze(1), paddle.mean(x, 1).unsqueeze(1)), axis=1)


class SpatialGate(nn.Layer):
    def __init__(self):
        super(SpatialGate, self).__init__()
        kernel_size = 7
        self.compress = Z_Pool()
        self.spatial = BasicConv(
            2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2, relu=False
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x_compress = self.compress(x)
        x_out = self.spatial(x_compress)
        scale = self.sigmoid(x_out)
        return x * scale

class TripletAttention(nn.Layer):
    def __init__(self, no_spatial=False):
        super(TripletAttention, self).__init__()
        self.ChannelGateH = nn.Sequential(Rearrange('b c h w -> b h c w'), SpatialGate(), Rearrange('b h c w -> b c h w'))
        self.ChannelGateW = nn.Sequential(Rearrange('b c h w -> b w h c'), SpatialGate(), Rearrange('b w h c -> b c h w'))
        self.no_spatial = no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate()

    def forward(self, x):
        x_out1 = self.ChannelGateH(x)
        x_out2 = self.ChannelGateW(x)
        if not self.no_spatial:
            x_out = self.SpatialGate(x)
            x_out = (1 / 3) * (x_out + x_out1 + x_out2)
        else:
            x_out = (1 / 2) * (x_out1 + x_out2)
        return x_out
model = TripletAttention()
paddle.summary(model, (1, 3, 224, 224))


---------------------------------------------------------------------------
 Layer (type)       Input Shape          Output Shape         Param #    
===========================================================================
  Rearrange-1    [[1, 3, 224, 224]]    [1, 224, 3, 224]          0       
   Z_Pool-1      [[1, 224, 3, 224]]     [1, 2, 3, 224]           0       
   Conv2D-1       [[1, 2, 3, 224]]      [1, 1, 3, 224]          98       
 BatchNorm2D-1    [[1, 1, 3, 224]]      [1, 1, 3, 224]           4       
  BasicConv-1     [[1, 2, 3, 224]]      [1, 1, 3, 224]           0       
   Sigmoid-2      [[1, 1, 3, 224]]      [1, 1, 3, 224]           0       
 SpatialGate-1   [[1, 224, 3, 224]]    [1, 224, 3, 224]          0       
  Rearrange-2    [[1, 224, 3, 224]]    [1, 3, 224, 224]          0       
  Rearrange-3    [[1, 3, 224, 224]]    [1, 224, 224, 3]          0       
   Z_Pool-2      [[1, 224, 224, 3]]     [1, 2, 224, 3]           0       
   Conv2D-2       [[1, 2, 224, 3]]      [1, 1, 224, 3]          98       
 BatchNorm2D-2    [[1, 1, 224, 3]]      [1, 1, 224, 3]           4       
  BasicConv-2     [[1, 2, 224, 3]]      [1, 1, 224, 3]           0       
   Sigmoid-3      [[1, 1, 224, 3]]      [1, 1, 224, 3]           0       
 SpatialGate-2   [[1, 224, 224, 3]]    [1, 224, 224, 3]          0       
  Rearrange-4    [[1, 224, 224, 3]]    [1, 3, 224, 224]          0       
   Z_Pool-3      [[1, 3, 224, 224]]    [1, 2, 224, 224]          0       
   Conv2D-3      [[1, 2, 224, 224]]    [1, 1, 224, 224]         98       
 BatchNorm2D-3   [[1, 1, 224, 224]]    [1, 1, 224, 224]          4       
  BasicConv-3    [[1, 2, 224, 224]]    [1, 1, 224, 224]          0       
   Sigmoid-4     [[1, 1, 224, 224]]    [1, 1, 224, 224]          0       
 SpatialGate-3   [[1, 3, 224, 224]]    [1, 3, 224, 224]          0       
===========================================================================
Total params: 306
Trainable params: 294
Non-trainable params: 12
---------------------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 10.40
Params size (MB): 0.00
Estimated Total Size (MB): 10.97
---------------------------------------------------------------------------






{'total_params': 306, 'trainable_params': 294}
class AlexNet_TA(nn.Layer):
    def __init__(self,num_classes=10):
        super().__init__()
        self.features=nn.Sequential(
            nn.Conv2D(3,48, kernel_size=11, stride=4, padding=11//2),
            TripletAttention(),
            nn.ReLU(),
            nn.MaxPool2D(kernel_size=3,stride=2),
            nn.Conv2D(48,128, kernel_size=5, padding=2),
            TripletAttention(),
            nn.ReLU(),
            nn.MaxPool2D(kernel_size=3,stride=2),
            nn.Conv2D(128, 192,kernel_size=3,stride=1,padding=1),
            TripletAttention(),
            nn.ReLU(),
            nn.Conv2D(192,192,kernel_size=3,stride=1,padding=1),
            TripletAttention(),
            nn.ReLU(),
            nn.Conv2D(192,128,kernel_size=3,stride=1,padding=1),
            TripletAttention(),
            nn.ReLU(),
            nn.MaxPool2D(kernel_size=3,stride=2),
        )
        self.classifier=nn.Sequential(
            nn.Linear(6*6*128,2048),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(2048,2048),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(2048,num_classes),
        )
 
 
    def forward(self,x):
        x = self.features(x)
        x = paddle.flatten(x, 1)
        x=self.classifier(x)
 
        return x
model = AlexNet_TA(num_classes=10)
paddle.summary(model, (1, 3, 224, 224))


2.5 Training

learning_rate = 0.001
n_epochs = 50
paddle.seed(42)
np.random.seed(42)
work_path = 'work/model'

model = AlexNet_TA(num_classes=10)

criterion = LabelSmoothingCrossEntropy()

scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=learning_rate, T_max=50000 // batch_size * n_epochs, verbose=False)
optimizer = paddle.optimizer.Adam(parameters=model.parameters(), learning_rate=scheduler, weight_decay=1e-5)

best_acc = 0.0
val_acc = 0.0
loss_record = {'train': {'loss': [], 'iter': []}, 'val': {'loss': [], 'iter': []}}   # for recording loss
acc_record = {'train': {'acc': [], 'iter': []}, 'val': {'acc': [], 'iter': []}}      # for recording accuracy

loss_iter = 0
acc_iter = 0

for epoch in range(n_epochs):
    # ---------- Training ----------
    model.train()
    train_num = 0.0
    train_loss = 0.0

    val_num = 0.0
    val_loss = 0.0
    accuracy_manager = paddle.metric.Accuracy()
    val_accuracy_manager = paddle.metric.Accuracy()
    print("#===epoch: {}, lr={:.10f}===#".format(epoch, optimizer.get_lr()))
    for batch_id, data in enumerate(train_loader):
        x_data, y_data = data
        labels = paddle.unsqueeze(y_data, axis=1)

        logits = model(x_data)

        loss = criterion(logits, y_data)

        acc = paddle.metric.accuracy(logits, labels)
        accuracy_manager.update(acc)
        if batch_id % 10 == 0:
            loss_record['train']['loss'].append(loss.numpy())
            loss_record['train']['iter'].append(loss_iter)
            loss_iter += 1

        loss.backward()

        optimizer.step()
        scheduler.step()
        optimizer.clear_grad()
        
        train_loss += loss
        train_num += len(y_data)

    total_train_loss = (train_loss / train_num) * batch_size
    train_acc = accuracy_manager.accumulate()
    acc_record['train']['acc'].append(train_acc)
    acc_record['train']['iter'].append(acc_iter)
    acc_iter += 1
    # Print the information.
    print("#===epoch: {}, train loss is: {}, train acc is: {:2.2f}%===#".format(epoch, total_train_loss.numpy(), train_acc*100))

    # ---------- Validation ----------
    model.eval()

    for batch_id, data in enumerate(val_loader):

        x_data, y_data = data
        labels = paddle.unsqueeze(y_data, axis=1)
        with paddle.no_grad():
          logits = model(x_data)

        loss = criterion(logits, y_data)

        acc = paddle.metric.accuracy(logits, labels)
        val_accuracy_manager.update(acc)

        val_loss += loss
        val_num += len(y_data)

    total_val_loss = (val_loss / val_num) * batch_size
    loss_record['val']['loss'].append(total_val_loss.numpy())
    loss_record['val']['iter'].append(loss_iter)
    val_acc = val_accuracy_manager.accumulate()
    acc_record['val']['acc'].append(val_acc)
    acc_record['val']['iter'].append(acc_iter)
    
    print("#===epoch: {}, val loss is: {}, val acc is: {:2.2f}%===#".format(epoch, total_val_loss.numpy(), val_acc*100))

    # ===================save====================
    if val_acc > best_acc:
        best_acc = val_acc
        paddle.save(model.state_dict(), os.path.join(work_path, 'best_model.pdparams'))
        paddle.save(optimizer.state_dict(), os.path.join(work_path, 'best_optimizer.pdopt'))

print(best_acc)
paddle.save(model.state_dict(), os.path.join(work_path, 'final_model.pdparams'))
paddle.save(optimizer.state_dict(), os.path.join(work_path, 'final_optimizer.pdopt'))


2.6 Experimental Results

def plot_learning_curve(record, title='loss', ylabel='CE Loss'):
    ''' Plot learning curve of your CNN '''
    maxtrain = max(map(float, record['train'][title]))
    maxval = max(map(float, record['val'][title]))
    ymax = max(maxtrain, maxval) * 1.1
    mintrain = min(map(float, record['train'][title]))
    minval = min(map(float, record['val'][title]))
    ymin = min(mintrain, minval) * 0.9

    total_steps = len(record['train'][title])
    x_1 = list(map(int, record['train']['iter']))
    x_2 = list(map(int, record['val']['iter']))
    figure(figsize=(10, 6))
    plt.plot(x_1, record['train'][title], c='tab:red', label='train')
    plt.plot(x_2, record['val'][title], c='tab:cyan', label='val')
    plt.ylim(ymin, ymax)
    plt.xlabel('Training steps')
    plt.ylabel(ylabel)
    plt.title('Learning curve of {}'.format(title))
    plt.legend()
    plt.show()
plot_learning_curve(loss_record, title='loss', ylabel='CE Loss')

[Figure: training and validation loss curves]

plot_learning_curve(acc_record, title='acc', ylabel='Accuracy')

[Figure: training and validation accuracy curves]

import time
work_path = 'work/model'
model = AlexNet_TA(num_classes=10)
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
model.eval()
aa = time.time()
for batch_id, data in enumerate(val_loader):

    x_data, y_data = data
    labels = paddle.unsqueeze(y_data, axis=1)
    with paddle.no_grad():
        logits = model(x_data)
bb = time.time()
print("Throughput:{}".format(int(len(val_dataset)//(bb - aa))))
Throughput:544
def get_cifar10_labels(labels):  
    """Return the text labels of the CIFAR10 dataset."""
    text_labels = [
        'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog',
        'horse', 'ship', 'truck']
    return [text_labels[int(i)] for i in labels]
def show_images(imgs, num_rows, num_cols, pred=None, gt=None, scale=1.5):  
    """Plot a list of images."""
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if paddle.is_tensor(img):
            ax.imshow(img.numpy())
        else:
            ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if pred or gt:
            ax.set_title("pt: " + pred[i] + "\ngt: " + gt[i])
    return axes
work_path = 'work/model'
X, y = next(iter(DataLoader(val_dataset, batch_size=18)))
model = AlexNet_TA(num_classes=10)
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
model.eval()
logits = model(X)
y_pred = paddle.argmax(logits, -1)
X = paddle.transpose(X, [0, 2, 3, 1])
axes = show_images(X.reshape((18, 224, 224, 3)), 1, 18, pred=get_cifar10_labels(y_pred), gt=get_cifar10_labels(y))
plt.show()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). (repeated 18 times, once per image)

[Figure: sample images with predicted (pt) and ground-truth (gt) labels]

3. AlexNet

3.1 AlexNet

class AlexNet(nn.Layer):
    def __init__(self,num_classes=10):
        super(AlexNet, self).__init__()
        self.features=nn.Sequential(
            nn.Conv2D(3,48, kernel_size=11, stride=4, padding=11//2),
            nn.ReLU(),
            nn.MaxPool2D(kernel_size=3,stride=2),
            nn.Conv2D(48,128, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool2D(kernel_size=3,stride=2),
            nn.Conv2D(128, 192,kernel_size=3,stride=1,padding=1),
            nn.ReLU(),
            nn.Conv2D(192,192,kernel_size=3,stride=1,padding=1),
            nn.ReLU(),
            nn.Conv2D(192,128,kernel_size=3,stride=1,padding=1),
            nn.ReLU(),
            nn.MaxPool2D(kernel_size=3,stride=2),
        )
        self.classifier=nn.Sequential(
            nn.Linear(6*6*128,2048),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(2048,2048),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(2048,num_classes),
        )
 
 
    def forward(self,x):
        x = self.features(x)
        x = paddle.flatten(x, 1)
        x=self.classifier(x)
 
        return x
model = AlexNet(num_classes=10)
paddle.summary(model, (1, 3, 224, 224))


3.2 Training

learning_rate = 0.001
n_epochs = 50
paddle.seed(42)
np.random.seed(42)
work_path = 'work/model1'

model = AlexNet(num_classes=10)

criterion = LabelSmoothingCrossEntropy()

scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=learning_rate, T_max=50000 // batch_size * n_epochs, verbose=False)
optimizer = paddle.optimizer.Adam(parameters=model.parameters(), learning_rate=scheduler, weight_decay=1e-5)

best_acc = 0.0
val_acc = 0.0
loss_record1 = {'train': {'loss': [], 'iter': []}, 'val': {'loss': [], 'iter': []}}   # for recording loss
acc_record1 = {'train': {'acc': [], 'iter': []}, 'val': {'acc': [], 'iter': []}}      # for recording accuracy

loss_iter = 0
acc_iter = 0

for epoch in range(n_epochs):
    # ---------- Training ----------
    model.train()
    train_num = 0.0
    train_loss = 0.0

    val_num = 0.0
    val_loss = 0.0
    accuracy_manager = paddle.metric.Accuracy()
    val_accuracy_manager = paddle.metric.Accuracy()
    print("#===epoch: {}, lr={:.10f}===#".format(epoch, optimizer.get_lr()))
    for batch_id, data in enumerate(train_loader):
        x_data, y_data = data
        labels = paddle.unsqueeze(y_data, axis=1)

        logits = model(x_data)

        loss = criterion(logits, y_data)

        acc = paddle.metric.accuracy(logits, labels)
        accuracy_manager.update(acc)
        if batch_id % 10 == 0:
            loss_record1['train']['loss'].append(loss.numpy())
            loss_record1['train']['iter'].append(loss_iter)
            loss_iter += 1

        loss.backward()

        optimizer.step()
        scheduler.step()
        optimizer.clear_grad()
        
        train_loss += loss
        train_num += len(y_data)

    total_train_loss = (train_loss / train_num) * batch_size
    train_acc = accuracy_manager.accumulate()
    acc_record1['train']['acc'].append(train_acc)
    acc_record1['train']['iter'].append(acc_iter)
    acc_iter += 1
    # Print the information.
    print("#===epoch: {}, train loss is: {}, train acc is: {:2.2f}%===#".format(epoch, total_train_loss.numpy(), train_acc*100))

    # ---------- Validation ----------
    model.eval()

    for batch_id, data in enumerate(val_loader):

        x_data, y_data = data
        labels = paddle.unsqueeze(y_data, axis=1)
        with paddle.no_grad():
          logits = model(x_data)

        loss = criterion(logits, y_data)

        acc = paddle.metric.accuracy(logits, labels)
        val_accuracy_manager.update(acc)

        val_loss += loss
        val_num += len(y_data)

    total_val_loss = (val_loss / val_num) * batch_size
    loss_record1['val']['loss'].append(total_val_loss.numpy())
    loss_record1['val']['iter'].append(loss_iter)
    val_acc = val_accuracy_manager.accumulate()
    acc_record1['val']['acc'].append(val_acc)
    acc_record1['val']['iter'].append(acc_iter)
    
    print("#===epoch: {}, val loss is: {}, val acc is: {:2.2f}%===#".format(epoch, total_val_loss.numpy(), val_acc*100))

    # ===================save====================
    if val_acc > best_acc:
        best_acc = val_acc
        paddle.save(model.state_dict(), os.path.join(work_path, 'best_model.pdparams'))
        paddle.save(optimizer.state_dict(), os.path.join(work_path, 'best_optimizer.pdopt'))

print(best_acc)
paddle.save(model.state_dict(), os.path.join(work_path, 'final_model.pdparams'))
paddle.save(optimizer.state_dict(), os.path.join(work_path, 'final_optimizer.pdopt'))


3.3 Experimental Results

plot_learning_curve(loss_record1, title='loss', ylabel='CE Loss')

[Figure: training and validation loss curves]

plot_learning_curve(acc_record1, title='acc', ylabel='Accuracy')

[Figure: training and validation accuracy curves]

import time
work_path = 'work/model1'
model = AlexNet(num_classes=10)
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
model.eval()
aa = time.time()
for batch_id, data in enumerate(val_loader):

    x_data, y_data = data
    labels = paddle.unsqueeze(y_data, axis=1)
    with paddle.no_grad():
        logits = model(x_data)
bb = time.time()
print("Throughput:{}".format(int(len(val_dataset)//(bb - aa))))
Throughput:562
work_path = 'work/model1'
X, y = next(iter(DataLoader(val_dataset, batch_size=18)))
model = AlexNet(num_classes=10)
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
model.eval()
logits = model(X)
y_pred = paddle.argmax(logits, -1)
X = paddle.transpose(X, [0, 2, 3, 1])
axes = show_images(X.reshape((18, 224, 224, 3)), 1, 18, pred=get_cifar10_labels(y_pred), gt=get_cifar10_labels(y))
plt.show()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). (repeated 18 times, once per image)

[Figure: sample images with predicted (pt) and ground-truth (gt) labels]

4. Comparison of Results

| Model | Train Acc | Val Acc | Parameters |
| --- | --- | --- | --- |
| AlexNet w/o Triplet Attention | 0.8003 | 0.82417 | 14601930 |
| AlexNet w/ Triplet Attention | 0.8936 | 0.87767 | 14603460 |

Summary

With only 1,530 additional parameters, Triplet Attention considerably speeds up convergence and improves accuracy (validation accuracy +0.0535).

This article is a repost; original project: https://aistudio.baidu.com/aistudio/projectdetail/4368594
