【AI达人特训营】GFNet:用于图像分类的全局滤波网络

摘要

        近年来在视觉的自我注意和纯多层感知器(MLP)模型方面取得的进展显示出在减少归纳偏置的情况下实现有前景的性能方面的巨大潜力。这些模型通常基于从原始数据中学习空间位置之间的相互作用。随着图像尺寸的增加,自我注意和MLP的复杂性呈二次增长,这使得当需要高分辨率特征时,这些模型难以放大。在本文中,我们提出了一种概念上简单但计算效率高的全局滤波网络(GFNet),它学习频率域内具有对数线性复杂度的长期空间相关性。我们的架构用三个关键操作取代了ViT中的自我注意层:二维离散傅里叶变换、频域特征和可学习的全局滤波器之间的元素相乘,以及二维傅里叶反变换。我们在ImageNet和下游任务上展示了我们的模型良好的准确性/复杂性权衡。我们的结果表明,GFNet在效率、泛化能力和鲁棒性方面可以成为Transformer式模型和CNNs的一个非常有竞争力的替代方案。

1. GFNet

1.1 总体架构

        本文的架构如图1所示与ViT类似,唯一的不同之处是将自注意力模块替换为全局滤波层。
在这里插入图片描述

1.2 全局滤波层

        全局滤波层主要包含三个部分:2D FFT、learnable global filters和2D IFFT:

  1. 2D FFT使用快速傅里叶变换将图像从空间域转换到频域,公式如下:
    $\boldsymbol{X}=\mathcal{F}[\boldsymbol{x}] \in \mathbb{C}^{H \times W \times D}$
  2. 使用可学习的全局滤波器与频域特征进行相乘:
    $\tilde{\boldsymbol{X}}=\boldsymbol{K} \odot \boldsymbol{X}$
  3. 2D IFFT使用快速傅里叶逆变换将图像从频域转换回空间域:
    $\boldsymbol{x} \leftarrow \mathcal{F}^{-1}[\tilde{\boldsymbol{X}}]$

2. 代码复现

2.1 下载并导入所需的库

!pip install paddlex
%matplotlib inline
import paddle
import paddle.fluid as fluid
import numpy as np
import matplotlib.pyplot as plt
from paddle.vision.datasets import Cifar10
from paddle.vision.transforms import Transpose
from paddle.io import Dataset, DataLoader
from paddle import nn
import paddle.nn.functional as F
import paddle.vision.transforms as transforms
import os
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure
import paddlex
import math

2.2 创建数据集

# Training augmentation pipeline: resize -> color jitter -> random crop /
# flip / rotate -> mixup -> tensor conversion + ImageNet normalization.
train_tfm = transforms.Compose([
    transforms.Resize((230, 230)),
    transforms.ColorJitter(brightness=0.4,contrast=0.4, saturation=0.4),
    transforms.RandomResizedCrop(224, scale=(0.6, 1.0)),
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomRotation(20),
    # NOTE(review): paddlex MixupImage is normally applied to detection-style
    # dict samples — confirm it behaves as intended inside this Compose.
    paddlex.transforms.MixupImage(),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

# Evaluation pipeline: deterministic resize + the same normalization.
test_tfm = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
paddle.vision.set_image_backend('cv2')
# Use the Cifar10 dataset
train_dataset = Cifar10(data_file='data/data152754/cifar-10-python.tar.gz', mode='train', transform = train_tfm)
val_dataset = Cifar10(data_file='data/data152754/cifar-10-python.tar.gz', mode='test',transform = test_tfm)
print("train_dataset: %d" % len(train_dataset))
print("val_dataset: %d" % len(val_dataset))
train_dataset: 50000
val_dataset: 10000
batch_size=64
# drop_last=True keeps every training batch the same size (the per-epoch loss
# rescaling later relies on this); validation keeps its final partial batch.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=2)

2.3 模型的创建

2.3.1 标签平滑
class LabelSmoothingCrossEntropy(nn.Layer):
    """Cross-entropy with label smoothing.

    The target distribution mixes (1 - smoothing) of the one-hot label with
    `smoothing` spread uniformly over all classes.
    """

    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, pred, target):
        # Per-sample log-probabilities over the class axis.
        logp = F.log_softmax(pred, axis=-1)
        # Pick out -log p(target) for every row via gather_nd.
        rows = paddle.arange(logp.shape[0])
        gather_idx = paddle.stack([rows, target], axis=1)
        target_term = paddle.gather_nd(-logp, index=gather_idx)
        # Uniform part: mean of -log p over all classes.
        uniform_term = paddle.mean(-logp, axis=-1)
        mixed = (1. - self.smoothing) * target_term + self.smoothing * uniform_term
        return mixed.mean()
2.3.2 DropPath
def drop_path(x, drop_prob=0.0, training=False):
    """Stochastic depth: randomly zero whole samples on a residual path.

    Each sample in the batch is kept with probability 1 - drop_prob and,
    when kept, rescaled by 1 / (1 - drop_prob) so the expected value is
    unchanged. Identity when drop_prob is 0 or outside training.
    See https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956
    for why this is distinct from "DropConnect".
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = paddle.to_tensor(1 - drop_prob)
    # One random value per sample, broadcast over all remaining dims.
    mask_shape = (paddle.shape(x)[0],) + (1,) * (x.ndim - 1)
    # floor(keep_prob + U[0,1)) is 1 with probability keep_prob, else 0.
    mask = paddle.floor(keep_prob + paddle.rand(mask_shape, dtype=x.dtype))
    return x.divide(keep_prob) * mask


class DropPath(nn.Layer):
    """Layer wrapper around drop_path() so stochastic depth can be used as a
    sublayer (picks up self.training automatically)."""

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
2.3.3 GFNet模型的创建
class Mlp(nn.Layer):
    """Transformer-style feed-forward block: Linear -> activation -> dropout
    -> Linear -> dropout."""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        # Sublayers are registered in the same order as the reference
        # implementation so parameter initialization is unchanged.
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
class GlobalFilter(nn.Layer):
    """Global filter layer: 2D rFFT -> element-wise product with a learnable
    complex filter -> inverse rFFT.

    Args:
        dim: channel dimension of the token features.
        h: height of the token grid.
        w: width of the rfft2 half-spectrum, normally h // 2 + 1.
    """

    def __init__(self, dim, h=14, w=8):
        super().__init__()
        # Stored as (h, w, dim, 2): real/imag parts in the trailing axis so
        # the parameter stays float32; viewed as complex in forward().
        self.complex_weight = self.create_parameter(attr=None, shape=(h, w, dim, 2),
            dtype='float32', is_bias=False,
            default_initializer=nn.initializer.Assign(paddle.randn(shape=(h, w, dim, 2), dtype='float32') * 0.02))
        self.w = w
        self.h = h

    def forward(self, x, spatial_size=None):
        """x: (B, N, C) token sequence; N must equal a * b of the token grid.

        spatial_size: optional (a, b); defaults to a square grid of sqrt(N).
        """
        B, N, C = x.shape
        if spatial_size is None:
            a = b = int(math.sqrt(N))
        else:
            a, b = spatial_size

        x = x.reshape((B, a, b, C))

        # BUGFIX: the original `paddle.to_tensor(x, paddle.float32)` builds a
        # fresh tensor with stop_gradient=True, severing autograd through this
        # branch; astype performs the cast while keeping the graph intact.
        x = x.astype('float32')

        # Half-spectrum FFT over the spatial axes (last transformed axis is
        # reduced to b // 2 + 1 — this matches the stored filter width).
        x = paddle.fft.rfft2(x, axes=(1, 2), norm='ortho')
        weight = paddle.as_complex(self.complex_weight)
        x = x * weight
        # s=(a, b) restores the exact spatial size on the inverse transform.
        x = paddle.fft.irfft2(x, s=(a, b), axes=(1, 2), norm='ortho')

        x = x.reshape((B, N, C))

        return x
class Block(nn.Layer):
    """GFNet encoder block: LayerNorm -> GlobalFilter -> LayerNorm -> MLP,
    wrapped in one residual connection with optional stochastic depth.
    (Unlike a ViT block there is a single residual around the whole unit.)
    """

    def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, h=14, w=8):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.filter = GlobalFilter(dim, h=h, w=w)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)

    def forward(self, x):
        mixed = self.filter(self.norm1(x))
        out = self.mlp(self.norm2(mixed))
        return x + self.drop_path(out)
class GFNet(nn.Layer):
    """Global Filter Network for image classification.

    A ViT-style backbone where self-attention is replaced by GlobalFilter
    layers: patches are embedded with a strided Conv2D, mixed by `depth`
    Blocks, mean-pooled over tokens, and classified by a linear head.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                 mlp_ratio=4., representation_size=None, uniform_drop=False,
                 drop_rate=0., drop_path_rate=0., norm_layer=None,
                 dropcls=0):
        super().__init__()

        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        norm_layer = norm_layer or nn.LayerNorm

        assert img_size % patch_size == 0, 'Image size cannot be divided by patch size'

        # Patch embedding: a stride-`patch_size` conv == non-overlapping patches.
        self.patch_embed = nn.Conv2D(in_chans, embed_dim, patch_size, patch_size)
        num_patches = (img_size//patch_size) * (img_size//patch_size)

        # Learnable absolute position embedding (GFNet uses no class token).
        self.pos_embed = self.create_parameter(attr=None, shape=(1, num_patches, embed_dim),
            dtype='float32', is_bias=False)
        self.pos_drop = nn.Dropout(p=drop_rate)

        # Filter grid: h x w where w = h // 2 + 1 matches the rfft2 half-spectrum.
        h = img_size // patch_size
        w = h // 2 + 1
        if uniform_drop:
            print('using uniform droppath with expect rate', drop_path_rate)
            dpr = [drop_path_rate for _ in range(depth)]  # stochastic depth decay rule
        else:
            print('using linear droppath with expect rate', drop_path_rate * 0.5)
            dpr = [x.item() for x in paddle.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule

        self.blocks = nn.LayerList([
            Block(
                dim=embed_dim, mlp_ratio=mlp_ratio,
                drop=drop_rate, drop_path=dpr[i], norm_layer=norm_layer, h=h, w=w)
            for i in range(depth)])

        self.norm = norm_layer(embed_dim)

        # Optional pre-logits projection (ViT-style representation layer).
        if representation_size:
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(
                nn.Linear(embed_dim, representation_size),
                nn.Tanh()
           )
        else:
            self.pre_logits = nn.Identity()

        # Optional dropout applied right before the classifier head.
        if dropcls > 0:
            print('dropout %.2f before classifier' % dropcls)
            self.final_dropout = nn.Dropout(p=dropcls)
        else:
            self.final_dropout = nn.Identity()

        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        # Truncated-normal init for the position embedding, then recursively
        # initialize all sublayers via _init_weights.
        init = paddle.nn.initializer.TruncatedNormal(mean=0.0, std=.02)
        init(self.pos_embed)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Linear/Conv2D: trunc-normal weights, zero bias.
        # LayerNorm/GroupNorm: unit scale, zero shift.
        zeros_ = nn.initializer.Constant(value=0.)
        ones_ = nn.initializer.Constant(value=1.)
        if isinstance(m, (nn.Linear, nn.Conv2D)):
            init = paddle.nn.initializer.TruncatedNormal(mean=0.0, std=.02)
            init(m.weight)
            if isinstance(m, (nn.Linear, nn.Conv2D)) and m.bias is not None:
                zeros_(m.bias)
        elif isinstance(m, (nn.LayerNorm, nn.GroupNorm)):
            zeros_(m.bias)
            ones_(m.weight)


    def forward_features(self, x):
        """(B, C, H, W) image -> (B, embed_dim) pooled token features."""
        B = x.shape[0]  # NOTE(review): B is unused below
        x = self.patch_embed(x)
        # (B, D, h, w) -> (B, N, D) token sequence.
        x = paddle.flatten(x, 2).transpose([0, 2, 1])
        x = x + self.pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        # Final norm then mean-pool over the token axis (no class token).
        x = self.norm(x).mean(1)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.final_dropout(x)
        x = self.head(x)
        return x
2.3.4 模型的参数
# ---- Parameter counts of the four GFNet variants (via paddle.summary) ----
# GFNet-Ti
model = GFNet(num_classes=10, embed_dim=256, depth=12)
paddle.summary(model, (batch_size, 3, 224, 224))

在这里插入图片描述

# GFNet-XS
model = GFNet(num_classes=10, embed_dim=384, depth=12)
paddle.summary(model, (batch_size, 3, 224, 224))

在这里插入图片描述

# GFNet-S
model = GFNet(num_classes=10, embed_dim=384, depth=19)
paddle.summary(model, (batch_size, 3, 224, 224))

在这里插入图片描述

# GFNet-B
model = GFNet(num_classes=10, embed_dim=512, depth=19)
paddle.summary(model, (batch_size, 3, 224, 224))

在这里插入图片描述

2.4 训练

# ---- Training hyperparameters and bookkeeping ----
learning_rate = 5e-4
n_epochs = 300
paddle.seed(42)
np.random.seed(42)
work_path = 'work/model'
os.makedirs(work_path, exist_ok=True)  # ensure the checkpoint dir exists

# GFNet-XS
model = GFNet(num_classes=10, embed_dim=384, depth=12)

criterion = LabelSmoothingCrossEntropy()

# Global-norm gradient clipping + cosine decay over the whole run
# (50000 // batch_size steps per epoch; the scheduler is stepped per batch).
grad_norm = paddle.nn.ClipGradByGlobalNorm(clip_norm=1)
scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=learning_rate, T_max=50000 // batch_size * n_epochs,
    eta_min=1e-5,verbose=False)
optimizer = paddle.optimizer.AdamW(parameters=model.parameters(), learning_rate=scheduler, weight_decay=0.05
    ,grad_clip=grad_norm)

# (Removed unused `gate` / `threshold` variables from the original.)
best_acc = 0.0
val_acc = 0.0
loss_record = {'train': {'loss': [], 'iter': []}, 'val': {'loss': [], 'iter': []}}   # for recording loss
acc_record = {'train': {'acc': [], 'iter': []}, 'val': {'acc': [], 'iter': []}}      # for recording accuracy

loss_iter = 0
acc_iter = 0

# ---- Main training / validation loop ----
for epoch in range(n_epochs):
    # ---------- Training ----------
    model.train()
    train_num = 0.0
    train_loss = 0.0

    val_num = 0.0
    val_loss = 0.0
    accuracy_manager = paddle.metric.Accuracy()
    val_accuracy_manager = paddle.metric.Accuracy()
    print("#===epoch: {}, lr={:.10f}===#".format(epoch, optimizer.get_lr()))
    for batch_id, data in enumerate(train_loader):
        x_data, y_data = data
        # paddle.metric.accuracy expects labels with shape (batch, 1).
        labels = paddle.unsqueeze(y_data, axis=1)

        logits = model(x_data)

        loss = criterion(logits, y_data)

        acc = paddle.metric.accuracy(logits, labels)
        accuracy_manager.update(acc)
        # Record the training loss every 10 batches.
        if batch_id % 10 == 0:
            loss_record['train']['loss'].append(loss.numpy())
            loss_record['train']['iter'].append(loss_iter)
            loss_iter += 1

        loss.backward()

        optimizer.step()
        scheduler.step()  # cosine schedule stepped per batch, not per epoch
        optimizer.clear_grad()

        train_loss += loss
        train_num += len(y_data)

    # loss is batch-averaged, so rescale the sum by batch_size / train_num to
    # get the mean per-batch loss (exact because train_loader drops the last
    # partial batch).
    total_train_loss = (train_loss / train_num) * batch_size
    train_acc = accuracy_manager.accumulate()
    acc_record['train']['acc'].append(train_acc)
    acc_record['train']['iter'].append(acc_iter)
    acc_iter += 1
    # Print the information.
    print("#===epoch: {}, train loss is: {}, train acc is: {:2.2f}%===#".format(epoch, total_train_loss.numpy(), train_acc*100))

    # ---------- Validation ----------
    model.eval()

    for batch_id, data in enumerate(val_loader):

        x_data, y_data = data
        labels = paddle.unsqueeze(y_data, axis=1)
        with paddle.no_grad():
          logits = model(x_data)

        loss = criterion(logits, y_data)

        acc = paddle.metric.accuracy(logits, labels)
        val_accuracy_manager.update(acc)

        val_loss += loss
        val_num += len(y_data)

    # NOTE(review): val_loader keeps its final partial batch, so this
    # batch_size rescaling slightly misweights that last batch's loss.
    total_val_loss = (val_loss / val_num) * batch_size
    loss_record['val']['loss'].append(total_val_loss.numpy())
    loss_record['val']['iter'].append(loss_iter)
    val_acc = val_accuracy_manager.accumulate()
    acc_record['val']['acc'].append(val_acc)
    acc_record['val']['iter'].append(acc_iter)

    print("#===epoch: {}, val loss is: {}, val acc is: {:2.2f}%===#".format(epoch, total_val_loss.numpy(), val_acc*100))

    # ===================save====================
    # Keep the checkpoint with the best validation accuracy.
    if val_acc > best_acc:
        best_acc = val_acc
        paddle.save(model.state_dict(), os.path.join(work_path, 'best_model.pdparams'))
        paddle.save(optimizer.state_dict(), os.path.join(work_path, 'best_optimizer.pdopt'))

print(best_acc)
paddle.save(model.state_dict(), os.path.join(work_path, 'final_model.pdparams'))
paddle.save(optimizer.state_dict(), os.path.join(work_path, 'final_optimizer.pdopt'))

在这里插入图片描述

2.5 结果分析

2.5.1 位置编码可视化
# ---- Position-embedding visualization ----
# pos_embed is (1, 196, 384); restore the 14x14 token grid and move channels
# first so ans[c] is one 14x14 positional map.
ans = model.pos_embed.squeeze().reshape((14, 14, 384)).transpose([2, 0, 1])
# Mean over all 384 channels.
plt.imshow(ans.mean(axis=0))
plt.colorbar()
<matplotlib.colorbar.Colorbar at 0x7f6adec5fc10>

在这里插入图片描述

# First channel.
plt.imshow(ans[0, :, :])
plt.colorbar()
<matplotlib.colorbar.Colorbar at 0x7f6ac03653d0>

在这里插入图片描述

# A middle channel.
plt.imshow(ans[191, :, :])
plt.colorbar()
<matplotlib.colorbar.Colorbar at 0x7f6ac028e5d0>

在这里插入图片描述

# Last channel.
plt.imshow(ans[383, :, :])
plt.colorbar()
<matplotlib.colorbar.Colorbar at 0x7f6ac02338d0>

在这里插入图片描述

2.5.2 loss和acc曲线
def plot_learning_curve(record, title='loss', ylabel='CE Loss'):
    '''Plot the train/val learning curve stored in `record`.

    Args:
        record: dict shaped like {'train': {title: [...], 'iter': [...]},
                                  'val':   {title: [...], 'iter': [...]}}.
        title: key of the metric to plot ('loss' or 'acc').
        ylabel: y-axis label text.
    '''
    # Pad the y-range by 10% above / below the observed extremes.
    maxtrain = max(map(float, record['train'][title]))
    maxval = max(map(float, record['val'][title]))
    ymax = max(maxtrain, maxval) * 1.1
    mintrain = min(map(float, record['train'][title]))
    minval = min(map(float, record['val'][title]))
    ymin = min(mintrain, minval) * 0.9

    # (Removed unused `total_steps` local from the original.)
    x_1 = list(map(int, record['train']['iter']))
    x_2 = list(map(int, record['val']['iter']))
    figure(figsize=(10, 6))
    plt.plot(x_1, record['train'][title], c='tab:red', label='train')
    plt.plot(x_2, record['val'][title], c='tab:cyan', label='val')
    plt.ylim(ymin, ymax)
    plt.xlabel('Training steps')
    plt.ylabel(ylabel)
    plt.title('Learning curve of {}'.format(title))
    plt.legend()
    plt.show()
plot_learning_curve(loss_record, title='loss', ylabel='CE Loss')

在这里插入图片描述

plot_learning_curve(acc_record, title='acc', ylabel='Accuracy')

在这里插入图片描述

# ---- Inference throughput of the best checkpoint on the validation set ----
import time
work_path = 'work/model'
model = GFNet(num_classes=10, embed_dim=384, depth=12)
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
model.eval()
aa = time.time()
for batch_id, data in enumerate(val_loader):

    x_data, y_data = data
    labels = paddle.unsqueeze(y_data, axis=1)
    with paddle.no_grad():
        logits = model(x_data)
bb = time.time()
# BUGFIX: the label was misspelled "Throughout"; the value is images/second.
print("Throughput:{}".format(int(len(val_dataset)//(bb - aa))))
using linear droppath with expect rate 0.0
Throughout:648
2.5.3 预测与真实标签比较
def get_cifar10_labels(labels):
    """Map CIFAR-10 class indices to their English text labels."""
    text_labels = [
        'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog',
        'horse', 'ship', 'truck']
    names = []
    for label in labels:
        names.append(text_labels[int(label)])
    return names
def show_images(imgs, num_rows, num_cols, pred=None, gt=None, scale=1.5):
    """Plot a grid of images, optionally titled with predicted / true labels.

    Args:
        imgs: iterable of HxWxC images (paddle tensors or numpy arrays).
        num_rows, num_cols: grid shape.
        pred, gt: optional label-string sequences aligned with imgs.
        scale: per-cell figure scale factor.

    Returns:
        Flattened array of the matplotlib axes.
    """
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if paddle.is_tensor(img):
            ax.imshow(img.numpy())
        else:
            ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        # BUGFIX: the original indexed both pred[i] and gt[i] whenever either
        # was given, crashing if only one of the two lists was supplied.
        title_parts = []
        if pred:
            title_parts.append("pt: " + pred[i])
        if gt:
            title_parts.append("gt: " + gt[i])
        if title_parts:
            ax.set_title("\n".join(title_parts))
    return axes
# ---- Compare predictions with ground truth on 18 validation images ----
work_path = 'work/model'
X, y = next(iter(DataLoader(val_dataset, batch_size=18)))
model = GFNet(num_classes=10, embed_dim=384, depth=12)
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
model.eval()
logits = model(X)
y_pred = paddle.argmax(logits, -1)
# NCHW -> NHWC so imshow can render each image directly.
X = paddle.transpose(X, [0, 2, 3, 1])
axes = show_images(X.reshape((18, 224, 224, 3)), 1, 18, pred=get_cifar10_labels(y_pred), gt=get_cifar10_labels(y))
plt.show()
using linear droppath with expect rate 0.0


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

在这里插入图片描述

# Collect the qualified names of every sublayer in the model.
layers_name = [sub_name for sub_name, _ in model.named_sublayers()]
print(layers_name)
['patch_embed', 'pos_drop', 'blocks', 'blocks.0', 'blocks.0.norm1', 'blocks.0.filter', 'blocks.0.drop_path', 'blocks.0.norm2', 'blocks.0.mlp', 'blocks.0.mlp.fc1', 'blocks.0.mlp.act', 'blocks.0.mlp.fc2', 'blocks.0.mlp.drop', 'blocks.1', 'blocks.1.norm1', 'blocks.1.filter', 'blocks.1.drop_path', 'blocks.1.norm2', 'blocks.1.mlp', 'blocks.1.mlp.fc1', 'blocks.1.mlp.act', 'blocks.1.mlp.fc2', 'blocks.1.mlp.drop', 'blocks.2', 'blocks.2.norm1', 'blocks.2.filter', 'blocks.2.drop_path', 'blocks.2.norm2', 'blocks.2.mlp', 'blocks.2.mlp.fc1', 'blocks.2.mlp.act', 'blocks.2.mlp.fc2', 'blocks.2.mlp.drop', 'blocks.3', 'blocks.3.norm1', 'blocks.3.filter', 'blocks.3.drop_path', 'blocks.3.norm2', 'blocks.3.mlp', 'blocks.3.mlp.fc1', 'blocks.3.mlp.act', 'blocks.3.mlp.fc2', 'blocks.3.mlp.drop', 'blocks.4', 'blocks.4.norm1', 'blocks.4.filter', 'blocks.4.drop_path', 'blocks.4.norm2', 'blocks.4.mlp', 'blocks.4.mlp.fc1', 'blocks.4.mlp.act', 'blocks.4.mlp.fc2', 'blocks.4.mlp.drop', 'blocks.5', 'blocks.5.norm1', 'blocks.5.filter', 'blocks.5.drop_path', 'blocks.5.norm2', 'blocks.5.mlp', 'blocks.5.mlp.fc1', 'blocks.5.mlp.act', 'blocks.5.mlp.fc2', 'blocks.5.mlp.drop', 'blocks.6', 'blocks.6.norm1', 'blocks.6.filter', 'blocks.6.drop_path', 'blocks.6.norm2', 'blocks.6.mlp', 'blocks.6.mlp.fc1', 'blocks.6.mlp.act', 'blocks.6.mlp.fc2', 'blocks.6.mlp.drop', 'blocks.7', 'blocks.7.norm1', 'blocks.7.filter', 'blocks.7.drop_path', 'blocks.7.norm2', 'blocks.7.mlp', 'blocks.7.mlp.fc1', 'blocks.7.mlp.act', 'blocks.7.mlp.fc2', 'blocks.7.mlp.drop', 'blocks.8', 'blocks.8.norm1', 'blocks.8.filter', 'blocks.8.drop_path', 'blocks.8.norm2', 'blocks.8.mlp', 'blocks.8.mlp.fc1', 'blocks.8.mlp.act', 'blocks.8.mlp.fc2', 'blocks.8.mlp.drop', 'blocks.9', 'blocks.9.norm1', 'blocks.9.filter', 'blocks.9.drop_path', 'blocks.9.norm2', 'blocks.9.mlp', 'blocks.9.mlp.fc1', 'blocks.9.mlp.act', 'blocks.9.mlp.fc2', 'blocks.9.mlp.drop', 'blocks.10', 'blocks.10.norm1', 'blocks.10.filter', 'blocks.10.drop_path', 'blocks.10.norm2', 
'blocks.10.mlp', 'blocks.10.mlp.fc1', 'blocks.10.mlp.act', 'blocks.10.mlp.fc2', 'blocks.10.mlp.drop', 'blocks.11', 'blocks.11.norm1', 'blocks.11.filter', 'blocks.11.drop_path', 'blocks.11.norm2', 'blocks.11.mlp', 'blocks.11.mlp.fc1', 'blocks.11.mlp.act', 'blocks.11.mlp.fc2', 'blocks.11.mlp.drop', 'norm', 'pre_logits', 'final_dropout', 'head']
def show_single_filter(gf):
    """Return the centered magnitude spectrum (h x h) of one global filter."""
    # gf: global filter: (h, h // 2 + 1, 2)
    h = gf.shape[0]
    # Real/imag pair -> complex half-spectrum.
    gf_complex = paddle.as_complex(gf)
    # Back to the spatial domain at full h x h resolution...
    gf_spatial = paddle.fft.irfft2(gf_complex, axes=(0, 1), s=(h, h))
    # ...then a full (two-sided) FFT so the whole spectrum can be displayed.
    gf_complex = paddle.fft.fft2(gf_spatial, axes=(0, 1))
    # Shift the zero-frequency component to the center of the map.
    gf_complex = paddle.fft.fftshift(gf_complex, axes=(0, 1))
    gf_abs = gf_complex.abs()
    return gf_abs

# Collect the magnitude spectra of the first 12 channels of every layer's
# global filter (12 layers x 12 channels = 144 maps, each with a leading
# singleton axis from [None]).
n_viz_channel = 12
global_filters = []
for i_layer in range(12):
    weight = model_state_dict[f'blocks.{i_layer}.filter.complex_weight']
    for i_channel in range(n_viz_channel):
        global_filters.append(show_single_filter(weight[:, :, i_channel])[None])
global_filters = paddle.stack(global_filters)
global_filters.shape
[144, 1, 14, 14]
2.5.4 全局滤波器可视化

        第i行代表第i个全局滤波器,第j列代表第j个通道

# Draw a 12x12 grid: row i = layer i's filter, column j = channel j.
fig = plt.figure(figsize = (12, 12))
rows = 12
cols = 12
for idx in range(rows * cols):
    img_array = global_filters[idx].detach().cpu().squeeze().numpy()
    # matplotlib subplot indices are 1-based.
    ax = fig.add_subplot(rows, cols, idx + 1)
    plt.axis('off')  # hide every subplot's axes
    plt.imshow(img_array, cmap='YlGnBu')

plt.subplots_adjust(wspace = 0, hspace = 0)  # remove gaps between subplots
plt.savefig('filter_map.png', dpi=100)

在这里插入图片描述

总结

        本文探究在频域上进行图像分类,效果还不错,但是训练花费很大,本文训练了300个epochs仍未完全收敛,有兴趣可以将epochs调大些,余弦退火的最小学习率再调小一些
        未来工作:对GFNet进行改进,尝试让其更快更好地收敛

开源链接:https://aistudio.baidu.com/aistudio/projectdetail/4293392?shared=1

Logo

学大模型,用大模型上飞桨星河社区!每天8点V100G算力免费领!免费领取ERNIE 4.0 100w Token >>>

更多推荐