
Abstract

        Multi-scale Vision Transformers (ViTs) have become a key backbone for computer vision, yet the self-attention computation in a Transformer scales quadratically with the number of input patches. Consequently, existing solutions typically apply a downsampling operation (e.g., average pooling) over the keys/values to sharply reduce the computational cost. In this work, we argue that such an over-aggressive downsampling design is irreversible and inevitably causes information loss, especially for the high-frequency components of objects (e.g., texture details). Motivated by wavelet theory, we construct a new Wavelet Vision Transformer (Wave-ViT) that unifies invertible downsampling with the wavelet transform and self-attention learning. This scheme enables self-attention learning with lossless downsampling over the keys/values, pursuing a better trade-off between efficiency and accuracy. Furthermore, the inverse wavelet transform is leveraged to strengthen the self-attention outputs by aggregating local contexts with an enlarged receptive field. Extensive experiments over multiple vision tasks (e.g., image recognition, object detection, and instance segmentation) validate the superiority of Wave-ViT: under comparable FLOPs, it outperforms state-of-the-art ViT backbones.

1. Wave-ViT

1.1 Discrete Wavelet Transform

        The cosine transform is a classical spectral-analysis tool: it characterizes the frequency-domain behavior of an entire time-domain process (or the time-domain behavior of an entire frequency-domain process). It therefore works well for stationary processes but falls short for non-stationary ones. In JPEG, the discrete cosine transform (DCT) compresses an image in 8×8 blocks that are written to the file in sequence. The algorithm achieves compression by discarding frequency information, so the higher the compression ratio, the more frequency information is thrown away. In the extreme case, a JPEG image retains only the basic information describing the image's overall appearance, while fine details are lost. The wavelet transform is a modern spectral-analysis tool: it can examine the frequency-domain characteristics of a local time-domain segment as well as the time-domain characteristics of a local frequency band, so it handles even non-stationary processes with ease. It converts an image into a set of wavelet coefficients that can be compressed and stored efficiently. In addition, the coarse edges produced by wavelets represent images better, because they eliminate the blocking artifacts that DCT compression commonly exhibits.

        The Haar wavelet transform is the simplest and earliest wavelet transform. It was proposed by Alfred Haar in 1910, hence the name Haar wavelet, and it features a simple computation and a compact representation. The Haar wavelet transform decomposes a signal into two sub-signals: a low-frequency approximation component and a high-frequency detail component. Applied to an image, as shown in the figure below, it yields four sub-bands: a low-frequency component (LL), a horizontal high-frequency component (LH), a vertical high-frequency component (HL), and a diagonal high-frequency component (HH). The original image can be reconstructed losslessly from these four components.
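        To make the decomposition concrete, consider a single 2×2 pixel block $\begin{pmatrix} a & b \\ c & d \end{pmatrix}$. With the orthonormal Haar filters, one level of the 2D transform maps it to four coefficients (the exact orientation labels depend on the library's filter convention):

$$LL = \frac{a+b+c+d}{2}, \quad LH = \frac{a-b+c-d}{2}, \quad HL = \frac{a+b-c-d}{2}, \quad HH = \frac{a-b-c+d}{2}$$

        Every 2×2 block maps to exactly four coefficients, so the transform halves each spatial dimension without discarding any information, which is why it is exactly invertible.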

!pip install pywavelets
import numpy as np
from matplotlib import pyplot as plt
import pywt
import PIL

# Load the image and keep a single channel as a grayscale array
img = PIL.Image.open("catdog.jpg")
img = np.array(img)[:, :, 0]
plt.imshow(img, cmap='gray')
plt.show()
# One level of the 2D Haar DWT: the approximation (LL) plus the
# horizontal (LH), vertical (HL) and diagonal (HH) detail sub-bands
LLY, (LHY, HLY, HHY) = pywt.dwt2(img, 'haar')
plt.subplot(2, 2, 1)
plt.imshow(LLY, cmap="Greys")
plt.subplot(2, 2, 2)
plt.imshow(LHY, cmap="Greys")
plt.subplot(2, 2, 3)
plt.imshow(HLY, cmap="Greys")
plt.subplot(2, 2, 4)
plt.imshow(HHY, cmap="Greys")
plt.show()

(Figure: the original grayscale image)

(Figure: the four Haar sub-bands LL, LH, HL, HH)

# Lossless reconstruction from the four sub-bands via the inverse DWT
img = pywt.idwt2((LLY, (LHY, HLY, HHY)), 'haar')
plt.imshow(img, cmap='gray')
plt.show()

(Figure: the image reconstructed by the inverse DWT)
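        A quick numerical check (a minimal sketch, assuming the same catdog.jpg as above) confirms that the DWT→IDWT round trip is lossless up to floating-point error; pywt pads odd-sized inputs, so the reconstruction is cropped back to the original shape:

original = np.array(PIL.Image.open("catdog.jpg"))[:, :, 0].astype(np.float64)
coeffs = pywt.dwt2(original, 'haar')
recon = pywt.idwt2(coeffs, 'haar')
print(np.allclose(recon[:original.shape[0], :original.shape[1]], original))  # True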

1.2 The Wavelets Block

        As shown in Figure 2(a), vanilla self-attention scales quadratically with the input size and is expensive to compute. To address this, some methods downsample K and V as in Figure 2(b) to reduce the computation, but the average-pooling operation discards high-frequency details and degrades performance. Therefore, as shown in Figure 2(c), this paper combines the wavelet transform to propose a self-attention block with lossless, invertible downsampling: the Wavelets block.
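        Concretely, for $N$ tokens with hidden dimension $D$, vanilla self-attention costs $O(N^2 D)$, while downsampling the keys/values by a factor $r$ along each spatial dimension reduces this to $O\left(N \cdot \frac{N}{r^2} \cdot D\right)$. A one-level DWT corresponds to $r = 2$, i.e., a 4× reduction of the attention cost, but unlike average pooling the downsampling remains invertible and loses no information.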

        First, the classical Haar wavelet transform is applied to obtain the four sub-bands, which are concatenated and passed through convolutions to generate K and V; self-attention is then computed against Q. In parallel, the four sub-bands are fed to the inverse discrete wavelet transform, the reconstruction is concatenated with the self-attention output, and a linear layer fuses the features into a feature map with the desired number of channels. (According to wavelet theory, the reconstructed feature map $X^r$ preserves every detail of the original input $\tilde{X}$. Notably, compared with a single 3×3 convolution, the DWT-convolution-IDWT pipeline inside this Wavelets block triggers stronger local contextualization with an enlarged receptive field, while the extra computation/memory cost is almost negligible.) The computation is given by the following formulas:

$$
\begin{aligned}
\operatorname{head}_{j} &= \operatorname{Attention}^{w}\left(Q_{j}, K_{j}^{w}, V_{j}^{w}\right)=\operatorname{Softmax}\left(\frac{Q_{j} \left(K_{j}^{w}\right)^{T}}{\sqrt{D_{h}}}\right) V_{j}^{w} \\
\operatorname{WaveletsBlock}(X) &= \operatorname{MultiHead}^{w}\left(X W^{q},\ X^{c} W^{k},\ X^{c} W^{v},\ X^{r}\right) \\
\operatorname{MultiHead}^{w}\left(Q, K, V, X^{r}\right) &= \operatorname{Concat}\left(\operatorname{head}_{0}, \operatorname{head}_{1}, \ldots, \operatorname{head}_{N_{h}}, X^{r}\right) \tilde{W}^{O}
\end{aligned}
$$
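        As a shape walkthrough (a sketch matching the WaveAttention implementation in Section 2.3.3, assuming input resolution $H \times W$ and channel dimension $C$):

# Pseudo-trace of the Wavelets block (comments only, not runnable on its own):
# x:      (B, H*W, C)        input tokens
# reduce: (B, C/4, H, W)     1x1 conv shrinks channels by 4 before the DWT
# dwt:    (B, C,   H/2, W/2) the four sub-bands stack on the channel axis, so
#                            channels return to C while the resolution halves
# filter: (B, C,   H/2, W/2) 3x3 conv contextualizes the sub-bands -> X^c
# idwt:   (B, C/4, H, W)     lossless reconstruction -> X^r
# K, V:   projected from X^c, so attention runs over H*W/4 keys/values
# output: Concat(heads, X^r) -> Linear(C + C/4 -> C)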

2. Code Reproduction

2.1 Download and import the required libraries

!pip install einops-0.3.0-py3-none-any.whl
!pip install paddlex
%matplotlib inline
import paddle
import numpy as np
import matplotlib.pyplot as plt
from paddle.vision.datasets import Cifar10
from paddle.vision.transforms import Transpose
from paddle.io import Dataset, DataLoader
from paddle import nn
import paddle.nn.functional as F
import paddle.vision.transforms as transforms
import os
from matplotlib.pyplot import figure
import paddlex
import itertools
from einops import rearrange, repeat
import pywt
from paddle.autograd import PyLayer
from functools import partial

2.2 Create the dataset

train_tfm = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.6, 1.0)),
    transforms.ColorJitter(brightness=0.2,contrast=0.2, saturation=0.2),
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomRotation(20),
    paddlex.transforms.MixupImage(),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

test_tfm = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
paddle.vision.set_image_backend('cv2')
# Use the CIFAR-10 dataset
train_dataset = Cifar10(data_file='data/data152754/cifar-10-python.tar.gz', mode='train', transform = train_tfm, )
val_dataset = Cifar10(data_file='data/data152754/cifar-10-python.tar.gz', mode='test',transform = test_tfm)
print("train_dataset: %d" % len(train_dataset))
print("val_dataset: %d" % len(val_dataset))
train_dataset: 50000
val_dataset: 10000
batch_size=128
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=4)

2.3 Building the model

2.3.1 Label smoothing
class LabelSmoothingCrossEntropy(nn.Layer):
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, pred, target):

        confidence = 1. - self.smoothing
        log_probs = F.log_softmax(pred, axis=-1)
        idx = paddle.stack([paddle.arange(log_probs.shape[0]), target], axis=1)
        nll_loss = paddle.gather_nd(-log_probs, index=idx)
        smooth_loss = paddle.mean(-log_probs, axis=-1)
        loss = confidence * nll_loss + self.smoothing * smooth_loss

        return loss.mean()
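A quick smoke test (a minimal sketch with made-up logits): with smoothing=0.1 the target class keeps 90% of the probability mass and the remainder is spread uniformly over all classes; with smoothing=0 the loss reduces to plain cross-entropy.

logits = paddle.to_tensor([[2.0, 0.5, 0.1], [0.2, 1.5, 0.3]])
targets = paddle.to_tensor([0, 1])
print(LabelSmoothingCrossEntropy(smoothing=0.1)(logits, targets))  # scalar loss tensor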
2.3.2 DropPath
def drop_path(x, drop_prob=0.0, training=False):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = paddle.to_tensor(1 - drop_prob)
    shape = (paddle.shape(x)[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + paddle.rand(shape, dtype=x.dtype)
    random_tensor = paddle.floor(random_tensor)  # binarize
    output = x.divide(keep_prob) * random_tensor
    return output


class DropPath(nn.Layer):
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
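Stochastic depth keeps the activation scale unchanged in expectation: each sample survives with probability 1−p, and survivors are rescaled by 1/(1−p). A quick check (a minimal sketch):

x = paddle.ones([10000, 8])
y = drop_path(x, drop_prob=0.3, training=True)
print(float(y.mean()))  # ≈ 1.0; about 30% of rows are zeroed, the rest are x / 0.7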
2.3.3 Building the Wave-ViT model
class DWT_Function(PyLayer):
    @staticmethod
    def forward(ctx, x, w_ll, w_lh, w_hl, w_hh):
        ctx.save_for_backward(w_ll, w_lh, w_hl, w_hh)
        ctx.shape = x.shape

        # Each sub-band is a stride-2 depthwise convolution with the
        # corresponding 2x2 Haar filter, applied independently per channel
        dim = x.shape[1]
        x_ll = F.conv2d(x, w_ll.expand((dim, -1, -1, -1)), stride=2, groups=dim)
        x_lh = F.conv2d(x, w_lh.expand((dim, -1, -1, -1)), stride=2, groups=dim)
        x_hl = F.conv2d(x, w_hl.expand((dim, -1, -1, -1)), stride=2, groups=dim)
        x_hh = F.conv2d(x, w_hh.expand((dim, -1, -1, -1)), stride=2, groups=dim)
        x = paddle.concat([x_ll, x_lh, x_hl, x_hh], axis=1)
        return x

    @staticmethod
    def backward(ctx, dx):
        # The gradient of a strided convolution is the matching transposed
        # convolution with the same (concatenated) filters
        w_ll, w_lh, w_hl, w_hh = ctx.saved_tensor()
        B, C, H, W = ctx.shape
        dx = dx.reshape((B, 4, -1, H//2, W//2))

        dx = dx.transpose([0, 2, 1, 3, 4]).reshape((B, -1, H//2, W//2))
        filters = paddle.concat([w_ll, w_lh, w_hl, w_hh], axis=0)
        filters = repeat(filters, 'o i h w -> (repeat o) i h w', repeat=C)
        dx = F.conv2d_transpose(dx, filters, stride=2, groups=C)

        return dx, None, None, None, None
class IDWT_Function(PyLayer):
    @staticmethod
    def forward(ctx, x, filters):
        ctx.save_for_backward(filters)
        ctx.shape = x.shape

        # Re-interleave the four sub-bands per channel, then reconstruct with a
        # stride-2 transposed convolution (the inverse of the analysis step)
        B, _, H, W = x.shape
        x = x.reshape((B, 4, -1, H, W)).transpose([0, 2, 1, 3, 4])
        C = x.shape[1]
        x = x.reshape((B, -1, H, W))
        filters = repeat(filters, 'o i h w -> (repeat o) i h w', repeat=C)
        x = F.conv2d_transpose(x, filters, stride=2, groups=C)
        return x

    @staticmethod
    def backward(ctx, dx):
        # Gradient of the transposed convolution: the matching strided
        # convolutions with the four synthesis filters
        filters = ctx.saved_tensor()[0]
        B, C, H, W = ctx.shape
        C = C // 4

        w_ll, w_lh, w_hl, w_hh = paddle.unbind(filters, axis=0)
        x_ll = F.conv2d(dx, w_ll.unsqueeze(1).expand((C, -1, -1, -1)), stride=2, groups=C)
        x_lh = F.conv2d(dx, w_lh.unsqueeze(1).expand((C, -1, -1, -1)), stride=2, groups=C)
        x_hl = F.conv2d(dx, w_hl.unsqueeze(1).expand((C, -1, -1, -1)), stride=2, groups=C)
        x_hh = F.conv2d(dx, w_hh.unsqueeze(1).expand((C, -1, -1, -1)), stride=2, groups=C)
        dx = paddle.concat([x_ll, x_lh, x_hl, x_hh], axis=1)

        return dx, None
class DWT_2D(nn.Layer):
    def __init__(self, wave):
        super(DWT_2D, self).__init__()
        # Build the four 2D analysis filters as outer products of the 1D
        # decomposition filters (time-reversed for convolution)
        w = pywt.Wavelet(wave)
        dec_hi = paddle.to_tensor(w.dec_hi[::-1], dtype='float32')
        dec_lo = paddle.to_tensor(w.dec_lo[::-1], dtype='float32')

        w_ll = dec_lo.unsqueeze(0) * dec_lo.unsqueeze(1)
        w_lh = dec_lo.unsqueeze(0) * dec_hi.unsqueeze(1)
        w_hl = dec_hi.unsqueeze(0) * dec_lo.unsqueeze(1)
        w_hh = dec_hi.unsqueeze(0) * dec_hi.unsqueeze(1)

        self.register_buffer('w_ll', w_ll.unsqueeze(0).unsqueeze(0))
        self.register_buffer('w_lh', w_lh.unsqueeze(0).unsqueeze(0))
        self.register_buffer('w_hl', w_hl.unsqueeze(0).unsqueeze(0))
        self.register_buffer('w_hh', w_hh.unsqueeze(0).unsqueeze(0))

    def forward(self, x):
        return DWT_Function.apply(x, self.w_ll, self.w_lh, self.w_hl, self.w_hh)
class IDWT_2D(nn.Layer):
    def __init__(self, wave):
        super(IDWT_2D, self).__init__()
        w = pywt.Wavelet(wave)
        rec_hi = paddle.to_tensor(w.rec_hi, dtype='float32')
        rec_lo = paddle.to_tensor(w.rec_lo, dtype='float32')

        w_ll = rec_lo.unsqueeze(0) * rec_lo.unsqueeze(1)
        w_lh = rec_lo.unsqueeze(0) * rec_hi.unsqueeze(1)
        w_hl = rec_hi.unsqueeze(0) * rec_lo.unsqueeze(1)
        w_hh = rec_hi.unsqueeze(0) * rec_hi.unsqueeze(1)

        w_ll = w_ll.unsqueeze(0).unsqueeze(1)
        w_lh = w_lh.unsqueeze(0).unsqueeze(1)
        w_hl = w_hl.unsqueeze(0).unsqueeze(1)
        w_hh = w_hh.unsqueeze(0).unsqueeze(1)

        filters = paddle.concat([w_ll, w_lh, w_hl, w_hh], axis=0)
        self.register_buffer('filters', filters)

    def forward(self, x):
        return IDWT_Function.apply(x, self.filters)
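Mirroring the pywt check in Section 1.1, a round-trip sanity check (a minimal sketch with a random tensor) confirms that this module pair should be lossless up to floating-point error:

dwt, idwt = DWT_2D('haar'), IDWT_2D('haar')
x = paddle.randn([2, 16, 32, 32])
x.stop_gradient = False  # DWT/IDWT are PyLayers, exercised here as in training
rec = idwt(dwt(x))
print(float((rec - x).abs().max()))  # close to 0 (~1e-6)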
class ClassAttention(nn.Layer):
    def __init__(self, dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.head_dim = head_dim
        self.scale = head_dim**-0.5
        self.kv = nn.Linear(dim, dim * 2)
        self.q = nn.Linear(dim, dim)
        self.proj = nn.Linear(dim, dim)


    def forward(self, x):
        B, N, C = x.shape
        kv = self.kv(x).reshape((B, N, 2, self.num_heads, self.head_dim)).transpose([2, 0, 3, 1, 4])
        k, v = kv[0], kv[1]
        # Only the class token (position 0) forms the query, so the attention
        # cost is linear in the sequence length
        q = self.q(x[:, :1, :]).reshape((B, self.num_heads, 1, self.head_dim))
        attn = ((q * self.scale) @ k.transpose([0, 1, 3, 2]))
        attn = F.softmax(attn, axis=-1)
        cls_embed = (attn @ v).transpose([0, 2, 1, 3]).reshape((B, 1, self.head_dim * self.num_heads))
        cls_embed = self.proj(cls_embed)
        return cls_embed
class FFN(nn.Layer):
    def __init__(self, in_features, hidden_features):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, in_features)


    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x
class ClassBlock(nn.Layer):
    def __init__(self, dim, num_heads, mlp_ratio, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.norm2 = norm_layer(dim)
        self.attn = ClassAttention(dim, num_heads)
        self.mlp = FFN(dim, int(dim * mlp_ratio))


    def forward(self, x):
        cls_embed = x[:, :1]
        cls_embed = cls_embed + self.attn(self.norm1(x))
        cls_embed = cls_embed + self.mlp(self.norm2(cls_embed))
        return paddle.concat([cls_embed, x[:, 1:]], axis=1)
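A shape smoke test (a minimal sketch using the stage-4 dimensions of Wave-ViT-S): the block updates only the class token and passes the patch tokens through unchanged.

blk = ClassBlock(dim=448, num_heads=14, mlp_ratio=4)
x = paddle.randn([2, 50, 448])
print(blk(x).shape)  # [2, 50, 448]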
class DWConv(nn.Layer):
    def __init__(self, dim=768):
        super(DWConv, self).__init__()
        self.dwconv = nn.Conv2D(dim, dim, 3, 1, 1, bias_attr=True, groups=dim)

    def forward(self, x, H, W):
        B, N, C = x.shape
        x = x.transpose([0, 2, 1]).reshape((B, C, H, W))
        x = self.dwconv(x)
        x = x.flatten(2).transpose([0, 2, 1])
        return x


class PVT2FFN(nn.Layer):
    def __init__(self, in_features, hidden_features):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, in_features)


    def forward(self, x, H, W):
        x = self.fc1(x)
        x = self.dwconv(x, H, W)
        x = self.act(x)
        x = self.fc2(x)
        return x
class WaveAttention(nn.Layer):
    def __init__(self, dim, num_heads, sr_ratio):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5
        self.sr_ratio = sr_ratio

        self.dwt = DWT_2D(wave='haar')
        self.idwt = IDWT_2D(wave='haar')
        self.reduce = nn.Sequential(
            nn.Conv2D(dim, dim//4, kernel_size=1, padding=0, stride=1),
            nn.BatchNorm2D(dim//4),
            nn.ReLU(),
        )
        self.filter = nn.Sequential(
            nn.Conv2D(dim, dim, kernel_size=3, padding=1, stride=1, groups=1),
            nn.BatchNorm2D(dim),
            nn.ReLU(),
        )
        self.kv_embed = nn.Conv2D(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) if sr_ratio > 1 else nn.Identity()
        self.q = nn.Linear(dim, dim)
        self.kv = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim * 2)
        )
        self.proj = nn.Linear(dim + dim // 4, dim)


    def forward(self, x, H, W):
        B, N, C = x.shape
        q = self.q(x).reshape((B, N, self.num_heads, C // self.num_heads)).transpose([0, 2, 1, 3])

        # Reduce channels by 4, then a one-level DWT halves the resolution while
        # stacking the four sub-bands back to C channels (lossless downsampling)
        x = x.reshape((B, H, W, C)).transpose([0, 3, 1, 2])
        x_dwt = self.dwt(self.reduce(x))
        x_dwt = self.filter(x_dwt)

        # The IDWT reconstructs X^r at full resolution with an enlarged receptive field
        x_idwt = self.idwt(x_dwt)
        x_idwt = x_idwt.reshape((B, -1, x_idwt.shape[-2] * x_idwt.shape[-1])).transpose([0, 2, 1])

        # Keys/values are projected from the downsampled wavelet-domain features X^c
        kv = self.kv_embed(x_dwt).reshape((B, C, -1)).transpose([0, 2, 1])
        kv = self.kv(kv).reshape((B, -1, 2, self.num_heads, C // self.num_heads)).transpose([2, 0, 3, 1, 4])
        k, v = kv[0], kv[1]

        attn = (q @ k.transpose([0, 1, 3, 2])) * self.scale
        attn = F.softmax(attn, axis=-1)
        x = (attn @ v).transpose([0, 2, 1, 3]).reshape((B, N, C))
        # Fuse the attention output with the reconstructed X^r
        x = self.proj(paddle.concat([x, x_idwt], axis=-1))
        return x
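A quick shape check (a minimal sketch with the stage-1 configuration of Wave-ViT-S, i.e., 56×56 tokens, dim 64, sr_ratio 4): attention runs over 49 keys/values while the output keeps all 3136 tokens.

attn = WaveAttention(dim=64, num_heads=2, sr_ratio=4)
x = paddle.randn([2, 56 * 56, 64])
print(attn(x, 56, 56).shape)  # [2, 3136, 64]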
class Attention(nn.Layer):
    def __init__(self, dim, num_heads):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.q = nn.Linear(dim, dim)
        self.kv = nn.Linear(dim, dim * 2)
        self.proj = nn.Linear(dim, dim)


    def forward(self, x, H, W):
        B, N, C = x.shape
        q = self.q(x).reshape((B, N, self.num_heads, C // self.num_heads)).transpose([0, 2, 1, 3])
        kv = self.kv(x).reshape((B, -1, 2, self.num_heads, C // self.num_heads)).transpose([2, 0, 3, 1, 4])
        k, v = kv[0], kv[1]
        attn = (q @ k.transpose([0, 1, 3, 2])) * self.scale
        attn = F.softmax(attn, axis=-1)
        x = (attn @ v).transpose([0, 2, 1, 3]).reshape((B, N, C))
        x = self.proj(x)
        return x
class Block(nn.Layer):
    def __init__(self,
        dim,
        num_heads,
        mlp_ratio,
        drop_path=0.,
        norm_layer=nn.LayerNorm,
        sr_ratio=1,
        block_type = 'wave'
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.norm2 = norm_layer(dim)

        if block_type == 'std_att':
            self.attn = Attention(dim, num_heads)
        else:
            self.attn = WaveAttention(dim, num_heads, sr_ratio)
        self.mlp = PVT2FFN(in_features=dim, hidden_features=int(dim * mlp_ratio))
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()


    def forward(self, x, H, W):
        x = x + self.drop_path(self.attn(self.norm1(x), H, W))
        x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
        return x
class DownSamples(nn.Layer):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.proj = nn.Conv2D(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
        self.norm = nn.LayerNorm(out_channels)


    def forward(self, x):
        x = self.proj(x)
        _, _, H, W = x.shape
        x = x.flatten(2).transpose([0, 2, 1])
        x = self.norm(x)
        return x, H, W
class Stem(nn.Layer):
    def __init__(self, in_channels, stem_hidden_dim, out_channels):
        super().__init__()
        hidden_dim = stem_hidden_dim
        self.conv = nn.Sequential(
            nn.Conv2D(in_channels, hidden_dim, kernel_size=7, stride=2,
                      padding=3, bias_attr=False),  # 112x112
            nn.BatchNorm2D(hidden_dim),
            nn.ReLU(),
            nn.Conv2D(hidden_dim, hidden_dim, kernel_size=3, stride=1,
                      padding=1, bias_attr=False),  # 112x112
            nn.BatchNorm2D(hidden_dim),
            nn.ReLU(),
            nn.Conv2D(hidden_dim, hidden_dim, kernel_size=3, stride=1,
                      padding=1, bias_attr=False),  # 112x112
            nn.BatchNorm2D(hidden_dim),
            nn.ReLU(),
        )
        self.proj = nn.Conv2D(hidden_dim,
                              out_channels,
                              kernel_size=3,
                              stride=2,
                              padding=1)
        self.norm = nn.LayerNorm(out_channels)


    def forward(self, x):
        x = self.conv(x)
        x = self.proj(x)
        _, _, H, W = x.shape
        x = x.flatten(2).transpose([0, 2, 1])
        x = self.norm(x)
        return x, H, W
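A shape smoke test (a minimal sketch): the stem's two stride-2 convolutions downsample a 224×224 input by 4× before tokenization.

stem = Stem(in_channels=3, stem_hidden_dim=32, out_channels=64)
x, H, W = stem(paddle.randn([1, 3, 224, 224]))
print(x.shape, H, W)  # [1, 3136, 64] 56 56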
class WaveViT(nn.Layer):
    def __init__(self,
        in_chans=3,
        num_classes=1000,
        stem_hidden_dim = 32,
        embed_dims=[64, 128, 320, 448],
        num_heads=[2, 4, 10, 14],
        mlp_ratios=[8, 8, 4, 4],
        drop_path_rate=0.,
        norm_layer=nn.LayerNorm,
        depths=[3, 4, 6, 3],
        sr_ratios=[4, 2, 1, 1],
        num_stages=4,
        token_label=True,
        **kwargs
    ):
        super().__init__()
        self.num_classes = num_classes
        self.depths = depths
        self.num_stages = num_stages

        dpr = [x.item() for x in paddle.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
        cur = 0

        for i in range(num_stages):
            if i == 0:
                patch_embed = Stem(in_chans, stem_hidden_dim, embed_dims[i])
            else:
                patch_embed = DownSamples(embed_dims[i - 1], embed_dims[i])

            block = nn.LayerList([Block(
                dim = embed_dims[i],
                num_heads = num_heads[i],
                mlp_ratio = mlp_ratios[i],
                drop_path=dpr[cur + j],
                norm_layer=norm_layer,
                sr_ratio = sr_ratios[i],
                block_type='wave' if i < 2 else 'std_att')
            for j in range(depths[i])])

            norm = norm_layer(embed_dims[i])
            cur += depths[i]

            setattr(self, f"patch_embed{i + 1}", patch_embed)
            setattr(self, f"block{i + 1}", block)
            setattr(self, f"norm{i + 1}", norm)

        post_layers = ['ca']
        self.post_network = nn.LayerList([
            ClassBlock(
                dim = embed_dims[-1],
                num_heads = num_heads[-1],
                mlp_ratio = mlp_ratios[-1],
                norm_layer=norm_layer)
            for _ in range(len(post_layers))
        ])

        # classification head
        self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()
        self.apply(self._init_weights)

    def _init_weights(self, m):
        tn = nn.initializer.TruncatedNormal(std=.02)
        km = nn.initializer.KaimingNormal()
        one = nn.initializer.Constant(1.0)
        zero = nn.initializer.Constant(0.0)
        if isinstance(m, nn.Linear):
            tn(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                zero(m.bias)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2D)):
            zero(m.bias)
            one(m.weight)
        elif isinstance(m, nn.Conv2D):
            km(m.weight)
            if m.bias is not None:
                zero(m.bias)

    def forward_cls(self, x):
        B, N, C = x.shape
        cls_tokens = x.mean(axis=1, keepdim=True)
        x = paddle.concat((cls_tokens, x), axis=1)
        for block in self.post_network:
            x = block(x)
        return x

    def forward_features(self, x):
        B = x.shape[0]
        for i in range(self.num_stages):
            patch_embed = getattr(self, f"patch_embed{i + 1}")
            block = getattr(self, f"block{i + 1}")
            x, H, W = patch_embed(x)
            for blk in block:
                x = blk(x, H, W)

            if i != self.num_stages - 1:
                norm = getattr(self, f"norm{i + 1}")
                x = norm(x)
                x = x.reshape((B, H, W, -1)).transpose([0, 3, 1, 2])

        x = self.forward_cls(x)[:, 0]
        norm = getattr(self, f"norm{self.num_stages}")
        x = norm(x)
        return x

    def forward(self, x):

        x = self.forward_features(x)
        x = self.head(x)
        return x

    def forward_tokens(self, x, H, W):
        B = x.shape[0]
        x = x.reshape((B, -1, x.shape[-1]))

        for i in range(self.num_stages):
            if i != 0:
                patch_embed = getattr(self, f"patch_embed{i + 1}")
                x, H, W = patch_embed(x)

            block = getattr(self, f"block{i + 1}")
            for blk in block:
                x = blk(x, H, W)

            if i != self.num_stages - 1:
                norm = getattr(self, f"norm{i + 1}")
                x = norm(x)
                x = x.reshape((B, H, W, -1)).transpose([0, 3, 1, 2])

        x = self.forward_cls(x)
        norm = getattr(self, f"norm{self.num_stages}")
        x = norm(x)
        return x
def wavevit_s(pretrained=False, **kwargs):
    model = WaveViT(
        stem_hidden_dim = 32,
        embed_dims = [64, 128, 320, 448],
        num_heads = [2, 4, 10, 14],
        mlp_ratios = [8, 8, 4, 4],
        norm_layer = partial(nn.LayerNorm, epsilon=1e-6),
        depths = [3, 4, 6, 3],
        sr_ratios = [4, 2, 1, 1],
        **kwargs)
    return model


def wavevit_b(pretrained=False, **kwargs):
    model = WaveViT(
        stem_hidden_dim = 64,
        embed_dims = [64, 128, 320, 512],
        num_heads = [2, 4, 10, 16],
        mlp_ratios = [8, 8, 4, 4],
        norm_layer = partial(nn.LayerNorm, epsilon=1e-6),
        depths = [3, 4, 12, 3],
        sr_ratios = [4, 2, 1, 1],
        **kwargs)
    return model


def wavevit_l(pretrained=False, **kwargs):
    model = WaveViT(
        stem_hidden_dim = 64,
        embed_dims = [96, 192, 384, 512],
        num_heads = [3, 6, 12, 16],
        mlp_ratios = [8, 8, 4, 4],
        norm_layer = partial(nn.LayerNorm, epsilon=1e-6),
        depths = [3, 6, 18, 3],
        sr_ratios = [4, 2, 1, 1],
        **kwargs)
    return model
2.3.4 Model parameters
model = wavevit_s(num_classes=10)
paddle.summary(model, (1, 3, 224, 224))

model = wavevit_b(num_classes=10)
paddle.summary(model, (1, 3, 224, 224))

model = wavevit_l(num_classes=10)
paddle.summary(model, (1, 3, 224, 224))

2.4 Training

learning_rate = 0.0001
n_epochs = 50
paddle.seed(42)
np.random.seed(42)
work_path = 'work/model'

# WaveViT-Small
model = wavevit_s(num_classes=10)

criterion = LabelSmoothingCrossEntropy()

scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=learning_rate, T_max=50000 // batch_size * n_epochs, verbose=False)
optimizer = paddle.optimizer.Adam(parameters=model.parameters(), learning_rate=scheduler, weight_decay=1e-5)

best_acc = 0.0
val_acc = 0.0
loss_record = {'train': {'loss': [], 'iter': []}, 'val': {'loss': [], 'iter': []}}   # for recording loss
acc_record = {'train': {'acc': [], 'iter': []}, 'val': {'acc': [], 'iter': []}}      # for recording accuracy

loss_iter = 0
acc_iter = 0

for epoch in range(n_epochs):
    # ---------- Training ----------
    model.train()
    train_num = 0.0
    train_loss = 0.0

    val_num = 0.0
    val_loss = 0.0
    accuracy_manager = paddle.metric.Accuracy()
    val_accuracy_manager = paddle.metric.Accuracy()
    print("#===epoch: {}, lr={:.10f}===#".format(epoch, optimizer.get_lr()))
    for batch_id, data in enumerate(train_loader):
        x_data, y_data = data
        labels = paddle.unsqueeze(y_data, axis=1)

        logits = model(x_data)

        loss = criterion(logits, y_data)

        acc = paddle.metric.accuracy(logits, labels)
        accuracy_manager.update(acc)
        if batch_id % 10 == 0:
            loss_record['train']['loss'].append(loss.numpy())
            loss_record['train']['iter'].append(loss_iter)
            loss_iter += 1

        loss.backward()

        optimizer.step()
        scheduler.step()
        optimizer.clear_grad()

        train_loss += loss
        train_num += len(y_data)

    total_train_loss = (train_loss / train_num) * batch_size
    train_acc = accuracy_manager.accumulate()
    acc_record['train']['acc'].append(train_acc)
    acc_record['train']['iter'].append(acc_iter)
    acc_iter += 1
    # Print the information.
    print("#===epoch: {}, train loss is: {}, train acc is: {:2.2f}%===#".format(epoch, total_train_loss.numpy(), train_acc*100))

    # ---------- Validation ----------
    model.eval()

    for batch_id, data in enumerate(val_loader):

        x_data, y_data = data
        labels = paddle.unsqueeze(y_data, axis=1)
        with paddle.no_grad():
          logits = model(x_data)

        loss = criterion(logits, y_data)

        acc = paddle.metric.accuracy(logits, labels)
        val_accuracy_manager.update(acc)

        val_loss += loss
        val_num += len(y_data)

    total_val_loss = (val_loss / val_num) * batch_size
    loss_record['val']['loss'].append(total_val_loss.numpy())
    loss_record['val']['iter'].append(loss_iter)
    val_acc = val_accuracy_manager.accumulate()
    acc_record['val']['acc'].append(val_acc)
    acc_record['val']['iter'].append(acc_iter)

    print("#===epoch: {}, val loss is: {}, val acc is: {:2.2f}%===#".format(epoch, total_val_loss.numpy(), val_acc*100))

    # ===================save====================
    if val_acc > best_acc:
        best_acc = val_acc
        paddle.save(model.state_dict(), os.path.join(work_path, 'best_model.pdparams'))
        paddle.save(optimizer.state_dict(), os.path.join(work_path, 'best_optimizer.pdopt'))

print(best_acc)
paddle.save(model.state_dict(), os.path.join(work_path, 'final_model.pdparams'))
paddle.save(optimizer.state_dict(), os.path.join(work_path, 'final_optimizer.pdopt'))

2.5 Analysis of results

def plot_learning_curve(record, title='loss', ylabel='CE Loss'):
    ''' Plot learning curve of your CNN '''
    maxtrain = max(map(float, record['train'][title]))
    maxval = max(map(float, record['val'][title]))
    ymax = max(maxtrain, maxval) * 1.1
    mintrain = min(map(float, record['train'][title]))
    minval = min(map(float, record['val'][title]))
    ymin = min(mintrain, minval) * 0.9

    total_steps = len(record['train'][title])
    x_1 = list(map(int, record['train']['iter']))
    x_2 = list(map(int, record['val']['iter']))
    figure(figsize=(10, 6))
    plt.plot(x_1, record['train'][title], c='tab:red', label='train')
    plt.plot(x_2, record['val'][title], c='tab:cyan', label='val')
    plt.ylim(ymin, ymax)
    plt.xlabel('Training steps')
    plt.ylabel(ylabel)
    plt.title('Learning curve of {}'.format(title))
    plt.legend()
    plt.show()
plot_learning_curve(loss_record, title='loss', ylabel='CE Loss')

(Figure: training/validation loss curves)

plot_learning_curve(acc_record, title='acc', ylabel='Accuracy')

(Figure: training/validation accuracy curves)

import time
work_path = 'work/model'
model = wavevit_s(num_classes=10)
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
model.eval()
aa = time.time()
for batch_id, data in enumerate(val_loader):

    x_data, y_data = data
    labels = paddle.unsqueeze(y_data, axis=1)
    with paddle.no_grad():
        logits = model(x_data)
bb = time.time()
print("Throughout:{}".format(int(len(val_dataset)//(bb - aa))))
Throughout:483
def get_cifar10_labels(labels):
    """返回CIFAR10数据集的文本标签。"""
    text_labels = [
        'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog',
        'horse', 'ship', 'truck']
    return [text_labels[int(i)] for i in labels]
def show_images(imgs, num_rows, num_cols, pred=None, gt=None, scale=1.5):
    """Plot a list of images."""
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if paddle.is_tensor(img):
            ax.imshow(img.numpy())
        else:
            ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if pred or gt:
            ax.set_title("pt: " + pred[i] + "\ngt: " + gt[i])
    return axes
work_path = 'work/model'
X, y = next(iter(DataLoader(val_dataset, batch_size=18)))
model = wavevit_s(num_classes=10)
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
model.eval()
logits = model(X)
y_pred = paddle.argmax(logits, -1)
X = paddle.transpose(X, [0, 2, 3, 1])
axes = show_images(X.reshape((18, 224, 224, 3)), 1, 18, pred=get_cifar10_labels(y_pred), gt=get_cifar10_labels(y))
plt.show()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). (warning repeated once per image)

(Figure: predictions vs. ground truth for 18 validation images)

!pip install interpretdl
import interpretdl as it
work_path = 'work/model'
model = wavevit_s(num_classes=10)
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
X, y = next(iter(DataLoader(val_dataset, batch_size=18)))
lime = it.LIMECVInterpreter(model)
lime_weights = lime.interpret(X.numpy()[3], interpret_class=y.numpy()[3], batch_size=100, num_samples=10000, visual=True)
100%|██████████| 10000/10000 [01:05<00:00, 152.78it/s]

05<00:00, 152.78it/s]

(Figure: LIME interpretation of the model's prediction)

Summary

        In this paper, the authors explore the idea of unifying the canonical Transformer block with invertible downsampling, thereby enabling efficient multi-scale self-attention learning with lossless downsampling. They propose a new Transformer block, the Wavelets block, which uses the discrete wavelet transform (DWT) to perform invertible downsampling over the keys/values during self-attention learning. In addition, the inverse DWT (IDWT) is used to reconstruct the downsampled DWT output, strengthening the output of the Wavelets block by aggregating local context with an enlarged receptive field.

References

  1. Wave-ViT: Unifying Wavelet and Transformers for Visual Representation Learning
  2. YehLi/ImageNetModel
  3. 【ECCV2022】Wave-ViT: Unifying Wavelet and Transformers for Visual Representation Learning
