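[ECCV 2022] Wave-ViT: Unifying Wavelets and ViT for Visual Representation Learning
A brief analysis of the paper "Wave-ViT: Unifying Wavelet and Transformers for Visual Representation Learning"
★★★ This article is based on a featured project from the AI Studio community.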
Abstract
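Multi-scale Vision Transformers (ViTs) have become an important backbone for computer vision, but the cost of self-attention in a Transformer grows quadratically with the number of input patches. Existing solutions therefore typically downsample the keys/values (e.g., with average pooling) to cut the computational cost substantially. In this work, we argue that such aggressive downsampling is irreversible and inevitably loses information, especially the high-frequency components of objects (such as texture details). Inspired by wavelet theory, we construct a new Wavelet Vision Transformer (Wave-ViT) that unifies invertible downsampling with wavelet transforms and self-attention learning. This design enables self-attention learning with lossless downsampling of keys/values and helps achieve a better trade-off between efficiency and accuracy. Furthermore, the inverse wavelet transform is used to strengthen the self-attention output by aggregating local context with an enlarged receptive field. Extensive experiments on multiple vision tasks (image recognition, object detection, and instance segmentation) validate the superiority of Wave-ViT: at comparable FLOPs, it outperforms state-of-the-art ViT backbones.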
1. Wave-ViT
1.1 Discrete Wavelet Transform
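The discrete cosine transform (DCT), like the Fourier transform, is a classical spectral-analysis tool. It describes the frequency content of a signal over its entire time span (or the time behavior over the entire frequency range), so it works well for stationary signals but has clear shortcomings for non-stationary ones. JPEG splits an image into 8×8 blocks, applies a DCT to each block, and achieves compression by discarding high-frequency information. The higher the compression ratio, the more frequency information is discarded; in the extreme, a JPEG image keeps only the coarse appearance of the image and loses all fine detail. The wavelet transform is a more modern spectral-analysis tool: it can describe the frequency content of a local time segment as well as the time behavior of a local frequency band, so it also handles non-stationary signals well. It transforms an image into a set of wavelet coefficients that can be compressed and stored efficiently, and because wavelet-based compression does not quantize the image in independent blocks, it avoids the blocking artifacts that are typical of DCT compression.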
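To make the JPEG description concrete, here is a minimal NumPy sketch of the block-DCT idea: the image is split into 8×8 tiles, each tile goes through an orthonormal 2D DCT-II, only the low-frequency top-left `keep × keep` coefficients are kept, and the tile is transformed back. This is a simplification under stated assumptions: real JPEG quantizes the coefficients with a quantization table and entropy-codes them instead of simply truncating, and the helper names `dct_matrix` and `block_dct_compress` are hypothetical.

```python
import numpy as np

def dct_matrix(n=8):
    """Orthonormal DCT-II matrix D, so that D @ block @ D.T is the 2D DCT of a block."""
    k = np.arange(n)[:, None]  # frequency index (rows)
    i = np.arange(n)[None, :]  # sample index (columns)
    d = np.sqrt(2.0 / n) * np.cos(np.pi * (2 * i + 1) * k / (2 * n))
    d[0, :] = np.sqrt(1.0 / n)  # the DC row uses a smaller normalization
    return d

def block_dct_compress(img, keep=4, n=8):
    """Keep only the top-left keep x keep DCT coefficients of every n x n tile."""
    h, w = img.shape
    assert h % n == 0 and w % n == 0, "image size must be a multiple of the block size"
    d = dct_matrix(n)
    mask = np.zeros((n, n))
    mask[:keep, :keep] = 1.0  # low-frequency corner
    out = np.empty_like(img, dtype=float)
    for r in range(0, h, n):
        for c in range(0, w, n):
            coeff = d @ img[r:r + n, c:c + n] @ d.T              # forward 2D DCT of one tile
            out[r:r + n, c:c + n] = d.T @ (coeff * mask) @ d     # inverse after truncation
    return out

rng = np.random.default_rng(0)
img = rng.uniform(0, 255, size=(16, 16))
print(np.abs(block_dct_compress(img, keep=8) - img).max())   # ~0: nothing discarded
print(np.abs(block_dct_compress(img, keep=2) - img).mean())  # large: high frequencies dropped
```

Because each tile is truncated independently, the reconstruction error differs from tile to tile, which is where the visible block boundaries come from.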
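The Haar wavelet transform is the simplest and earliest wavelet transform. It was proposed by Alfred Haar in 1910, hence its name, and it is characterized by a simple computation and a compact representation. The Haar transform splits a signal into two sub-signals: a low-frequency approximation component and a high-frequency detail component. Applied to an image along both rows and columns, it yields four components: a low-frequency component and horizontal, vertical, and diagonal high-frequency components, as the code below visualizes. The original image can be reconstructed losslessly from these four components.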
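```python
!pip install pywavelets
import numpy as np
from matplotlib import pyplot as plt
import pywt
from PIL import Image

# load the image and keep only its first (R) channel as a grayscale input
img = Image.open("catdog.jpg")
img = np.array(img)[:, :, 0]
plt.imshow(img, cmap='gray')
plt.show()

# one-level 2D Haar DWT: approximation (LL) plus horizontal, vertical and diagonal details
LLY, (LHY, HLY, HHY) = pywt.dwt2(img, 'haar')
plt.subplot(2, 2, 1)
plt.imshow(LLY, cmap="Greys")
plt.title("LL (approximation)")
plt.subplot(2, 2, 2)
plt.imshow(LHY, cmap="Greys")
plt.title("horizontal detail")
plt.subplot(2, 2, 3)
plt.imshow(HLY, cmap="Greys")
plt.title("vertical detail")
plt.subplot(2, 2, 4)
plt.imshow(HHY, cmap="Greys")
plt.title("diagonal detail")
plt.show()

# inverse DWT: reconstruct the image from the four sub-bands
img = pywt.idwt2((LLY, (LHY, HLY, HHY)), 'haar')
plt.imshow(img, cmap='gray')
plt.show()
```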
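What `pywt.dwt2(img, 'haar')` computes at one level can be written out by hand: every non-overlapping 2×2 patch [[a, b], [c, d]] contributes one coefficient to each of the four sub-bands. The NumPy sketch below uses the orthonormal Haar scaling factor 1/2 and assumes an even height and width; `haar_dwt2` and `haar_idwt2` are hypothetical helper names. The sign and naming conventions of the three detail bands may differ from PyWavelets', but the perfect-reconstruction property does not depend on that choice.

```python
import numpy as np

def haar_dwt2(x):
    """One-level orthonormal 2D Haar DWT of an (H, W) array with even H and W."""
    a = x[0::2, 0::2]  # top-left pixel of each 2x2 patch
    b = x[0::2, 1::2]  # top-right
    c = x[1::2, 0::2]  # bottom-left
    d = x[1::2, 1::2]  # bottom-right
    ll = (a + b + c + d) / 2  # low-pass in both directions (2 x the local mean)
    lh = (a + b - c - d) / 2  # top row minus bottom row: horizontal edges
    hl = (a - b + c - d) / 2  # left column minus right column: vertical edges
    hh = (a - b - c + d) / 2  # diagonal detail
    return ll, (lh, hl, hh)

def haar_idwt2(ll, bands):
    """Exact inverse of haar_dwt2: rebuild every 2x2 patch from its four coefficients."""
    lh, hl, hh = bands
    h, w = ll.shape
    x = np.empty((2 * h, 2 * w), dtype=np.result_type(ll, float))
    x[0::2, 0::2] = (ll + lh + hl + hh) / 2
    x[0::2, 1::2] = (ll + lh - hl - hh) / 2
    x[1::2, 0::2] = (ll - lh + hl - hh) / 2
    x[1::2, 1::2] = (ll - lh - hl + hh) / 2
    return x

x = np.arange(16, dtype=float).reshape(4, 4)
ll, bands = haar_dwt2(x)
print(np.allclose(haar_idwt2(ll, bands), x))  # True: the transform is lossless
```

Each sub-band has half the height and width of the input, yet nothing is lost: four half-resolution bands carry exactly as many numbers as the original image.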
1.2 Wavelets Block
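As shown in Fig. 2(a), the cost of the original self-attention grows quadratically with the number of input tokens, which makes it expensive. To reduce this cost, some methods downsample K and V as in Fig. 2(b), but the average pooling they use discards high-frequency details and hurts performance. The paper therefore proposes, as shown in Fig. 2(c), a self-attention block with lossless, invertible downsampling based on the wavelet transform: the Wavelets block.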
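The difference between Fig. 2(b) and Fig. 2(c) can be seen on a single 2×2 patch. Two patches with the same mean collapse to the same value under 2×2 average pooling, so pooling cannot be inverted, whereas the Haar DWT keeps them apart; up to a factor of 2, its LL band is exactly the pooled value. The NumPy sketch below is only an illustration: `avg_pool2` and `haar_bands` are hypothetical helpers, and the Haar coefficients use the orthonormal 1/2 scaling.

```python
import numpy as np

def avg_pool2(x):
    """2x2 average pooling with stride 2, as in the downsampling of Fig. 2(b)."""
    return (x[0::2, 0::2] + x[0::2, 1::2] + x[1::2, 0::2] + x[1::2, 1::2]) / 4

def haar_bands(x):
    """One-level orthonormal 2D Haar DWT, stacked as (LL, LH, HL, HH)."""
    a, b = x[0::2, 0::2], x[0::2, 1::2]
    c, d = x[1::2, 0::2], x[1::2, 1::2]
    return np.stack([a + b + c + d, a + b - c - d, a - b + c - d, a - b - c + d]) / 2

smooth = np.array([[5.0, 5.0], [5.0, 5.0]])   # flat patch
texture = np.array([[9.0, 1.0], [1.0, 9.0]])  # checkerboard detail with the same mean

print(avg_pool2(smooth), avg_pool2(texture))  # identical: the detail is gone after pooling
print(haar_bands(texture).ravel())            # LL=10, LH=0, HL=0, HH=8: the detail survives in HH
```

Since the four Haar bands together hold as many values as the input, the DWT halves the spatial resolution (and hence the number of K/V tokens) without discarding anything, which is the property the Wavelets block relies on.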
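First, the classic Haar wavelet transform decomposes the (channel-reduced) input into four sub-bands, which are concatenated and passed through a convolution to produce K and V; these then take part in self-attention with Q. In parallel, the inverse DWT is applied to the four sub-bands, and its result is concatenated with the self-attention output and fused by a linear layer into a feature map with the desired number of channels. According to wavelet theory, the reconstructed feature map $X^{r}$ preserves the details of the original input $\tilde{x}$. Notably, compared with a single 3×3 convolution, the DWT-convolution-IDWT path in this wavelet block provides stronger local contextualization with an enlarged receptive field, while the additional computation and memory cost is almost negligible. The computation is given by the following formulas: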
$$
\begin{aligned}
\mathrm{head}_{j} &= \mathrm{Attention}^{w}\left(Q_{j}, K_{j}^{w}, V_{j}^{w}\right) = \mathrm{Softmax}\left(\frac{Q_{j} {K_{j}^{w}}^{\top}}{\sqrt{D_{h}}}\right) V_{j}^{w}, \\
\mathrm{WaveletsBlock}(X) &= \mathrm{MultiHead}^{w}\left(X W^{q}, X^{c} W^{k}, X^{c} W^{v}, X^{r}\right), \\
\mathrm{MultiHead}^{w}\left(Q, K, V, X^{r}\right) &= \mathrm{Concat}\left(\mathrm{head}_{0}, \mathrm{head}_{1}, \ldots, \mathrm{head}_{N_{h}}, X^{r}\right) \tilde{W}^{O}.
\end{aligned}
$$
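As a shape-level illustration of these equations, the NumPy sketch below evaluates $\mathrm{MultiHead}^{w}$ with random weights: $Q$ comes from all $N$ tokens of $X$, $K$ and $V$ come from the $N'$ downsampled tokens of $X^{c}$, and the reconstructed features $X^{r}$ are concatenated with the heads before the output projection $\tilde{W}^{O}$. All sizes are assumptions chosen for illustration ($X^{r}$ is given $C/4$ channels, matching the code later in this article), the function name `wavelet_multihead` is hypothetical, and the sketch uses `num_heads` heads indexed from 0.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def wavelet_multihead(x, x_c, x_r, w_q, w_k, w_v, w_o, num_heads):
    """x: (N, C) tokens, x_c: (N', C) downsampled DWT tokens, x_r: (N, C_r) IDWT features."""
    n, c = x.shape
    d_h = c // num_heads
    # project, then split channels into heads: (num_heads, tokens, d_h)
    q = (x @ w_q).reshape(n, num_heads, d_h).transpose(1, 0, 2)
    k = (x_c @ w_k).reshape(-1, num_heads, d_h).transpose(1, 0, 2)
    v = (x_c @ w_v).reshape(-1, num_heads, d_h).transpose(1, 0, 2)
    attn = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(d_h))   # (num_heads, N, N')
    heads = (attn @ v).transpose(1, 0, 2).reshape(n, c)         # Concat(head_j)
    return np.concatenate([heads, x_r], axis=-1) @ w_o, attn    # (N, C)

rng = np.random.default_rng(0)
N, N_prime, C, heads = 64, 4, 32, 4
x = rng.normal(size=(N, C))
x_c = rng.normal(size=(N_prime, C))
x_r = rng.normal(size=(N, C // 4))
w_q, w_k, w_v = (rng.normal(size=(C, C)) for _ in range(3))
w_o = rng.normal(size=(C + C // 4, C))
out, attn = wavelet_multihead(x, x_c, x_r, w_q, w_k, w_v, w_o, heads)
print(out.shape, attn.shape)  # (64, 32) (4, 64, 4)
```

The attention matrix is $N \times N'$ instead of $N \times N$; that reduction is where the savings come from, while $X^{r}$ re-injects full-resolution local detail through the output projection.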
2. Code Reproduction
2.1 Install and import the required libraries
```python
!pip install einops-0.3.0-py3-none-any.whl
!pip install paddlex
%matplotlib inline
import os
import itertools
from functools import partial

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure
import pywt
import paddle
from paddle import nn
import paddle.nn.functional as F
import paddle.vision.transforms as transforms
from paddle.vision.datasets import Cifar10
from paddle.vision.transforms import Transpose
from paddle.io import Dataset, DataLoader
from paddle.autograd import PyLayer
import paddlex
from einops import rearrange, repeat
```
2.2 Create the dataset
```python
train_tfm = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.6, 1.0)),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomRotation(20),
    # PaddleX's MixupImage is documented for detection pipelines; verify it has an effect here
    paddlex.transforms.MixupImage(),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

test_tfm = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

paddle.vision.set_image_backend('cv2')
# use the CIFAR-10 dataset (32x32 images, upscaled to 224x224 by the transforms above)
train_dataset = Cifar10(data_file='data/data152754/cifar-10-python.tar.gz', mode='train', transform=train_tfm)
val_dataset = Cifar10(data_file='data/data152754/cifar-10-python.tar.gz', mode='test', transform=test_tfm)
print("train_dataset: %d" % len(train_dataset))
print("val_dataset: %d" % len(val_dataset))
```

```
train_dataset: 50000
val_dataset: 10000
```

```python
batch_size = 128
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=4)
```
2.3 Build the model
2.3.1 Label smoothing
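```python
class LabelSmoothingCrossEntropy(nn.Layer):
    """Cross-entropy with label smoothing: (1 - eps) * NLL + eps * mean_k(-log p_k)."""
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, pred, target):
        confidence = 1. - self.smoothing
        log_probs = F.log_softmax(pred, axis=-1)
        target = target.reshape([-1])  # accept labels of shape (B,) or (B, 1)
        # gather -log p[target] per sample using (row, label) index pairs
        idx = paddle.stack([paddle.arange(log_probs.shape[0]), target], axis=1)
        nll_loss = paddle.gather_nd(-log_probs, index=idx)
        # uniform-target term: average -log p over all classes
        smooth_loss = paddle.mean(-log_probs, axis=-1)
        loss = confidence * nll_loss + self.smoothing * smooth_loss
        return loss.mean()
```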
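The loss above is equivalent to ordinary cross-entropy against a smoothed target distribution that puts $1-\varepsilon+\varepsilon/K$ on the true class and $\varepsilon/K$ on every other class ($K$ classes, $\varepsilon$ = `smoothing`), because $(1-\varepsilon)(-\log p_y) + \varepsilon \cdot \frac{1}{K}\sum_k(-\log p_k) = \sum_k q_k(-\log p_k)$. The NumPy check below computes both forms; the function names are hypothetical and the random logits are only test data.

```python
import numpy as np

def log_softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

def smoothing_loss_as_in_layer(logits, target, eps=0.1):
    """(1 - eps) * NLL + eps * mean_k(-log p_k), averaged over the batch."""
    log_p = log_softmax(logits)
    nll = -log_p[np.arange(len(target)), target]
    smooth = -log_p.mean(axis=-1)
    return ((1 - eps) * nll + eps * smooth).mean()

def smoothing_loss_soft_targets(logits, target, eps=0.1):
    """Cross-entropy against q = (1 - eps) * one_hot + eps / K."""
    k = logits.shape[-1]
    q = np.full(logits.shape, eps / k)
    q[np.arange(len(target)), target] += 1 - eps
    return -(q * log_softmax(logits)).sum(axis=-1).mean()

rng = np.random.default_rng(0)
logits = rng.normal(size=(5, 10))
target = rng.integers(0, 10, size=5)
print(smoothing_loss_as_in_layer(logits, target), smoothing_loss_soft_targets(logits, target))
```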
2.3.2 DropPath
```python
def drop_path(x, drop_prob=0.0, training=False):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = paddle.to_tensor(1 - drop_prob)
    shape = (paddle.shape(x)[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + paddle.rand(shape, dtype=x.dtype)
    random_tensor = paddle.floor(random_tensor)  # binarize
    output = x.divide(keep_prob) * random_tensor
    return output


class DropPath(nn.Layer):
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
```
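Stochastic depth drops the whole residual branch for randomly chosen samples and rescales the surviving samples by 1/keep_prob, so the expected output is unchanged. The NumPy sketch below follows the same rule as `drop_path` above; the function name `drop_path_np` and the explicitly seeded generator are assumptions for illustration.

```python
import numpy as np

def drop_path_np(x, drop_prob, rng, training=True):
    """Zero each sample's branch with probability drop_prob and rescale the rest."""
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1.0 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)        # one decision per sample
    mask = np.floor(keep_prob + rng.random(shape))     # 1 with probability keep_prob, else 0
    return x / keep_prob * mask

rng = np.random.default_rng(0)
x = np.ones((10000, 4, 8))
out = drop_path_np(x, drop_prob=0.2, rng=rng)
print(out.mean())  # close to 1.0: the expectation is preserved
```

Each sample is either all zeros or all 1/0.8 = 1.25 here; `floor(keep_prob + U)` with U uniform on [0, 1) is 1 exactly when U ≥ drop_prob.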
2.3.3 Build the Wave-ViT model
```python
class DWT_Function(PyLayer):
    """Haar DWT as four depthwise stride-2 convolutions, one per sub-band."""
    @staticmethod
    def forward(ctx, x, w_ll, w_lh, w_hl, w_hh):
        ctx.save_for_backward(w_ll, w_lh, w_hl, w_hh)
        ctx.shape = x.shape

        dim = x.shape[1]
        # each filter is applied to every channel independently (groups=dim)
        x_ll = F.conv2d(x, w_ll.expand((dim, -1, -1, -1)), stride=2, groups=dim)
        x_lh = F.conv2d(x, w_lh.expand((dim, -1, -1, -1)), stride=2, groups=dim)
        x_hl = F.conv2d(x, w_hl.expand((dim, -1, -1, -1)), stride=2, groups=dim)
        x_hh = F.conv2d(x, w_hh.expand((dim, -1, -1, -1)), stride=2, groups=dim)
        # band-major layout: [LL, LH, HL, HH] -> (B, 4C, H/2, W/2)
        x = paddle.concat([x_ll, x_lh, x_hl, x_hh], axis=1)
        return x

    @staticmethod
    def backward(ctx, dx):
        # the DWT is linear, so the input gradient is its adjoint applied to dx:
        # a stride-2 transposed convolution with the same filters
        w_ll, w_lh, w_hl, w_hh = ctx.saved_tensor()
        B, C, H, W = ctx.shape
        # regroup from band-major to channel-major order
        dx = dx.reshape((B, 4, -1, H // 2, W // 2))
        dx = dx.transpose([0, 2, 1, 3, 4]).reshape((B, -1, H // 2, W // 2))

        filters = paddle.concat([w_ll, w_lh, w_hl, w_hh], axis=0)
        filters = repeat(filters, 'o i h w -> (repeat o) i h w', repeat=C)
        dx = F.conv2d_transpose(dx, filters, stride=2, groups=C)
        return dx, None, None, None, None


class IDWT_Function(PyLayer):
    """Inverse Haar DWT as a stride-2 grouped transposed convolution."""
    @staticmethod
    def forward(ctx, x, filters):
        ctx.save_for_backward(filters)
        ctx.shape = x.shape

        B, _, H, W = x.shape
        # band-major (B, 4C, H, W) -> channel-major, 4 sub-bands per channel
        x = x.reshape((B, 4, -1, H, W)).transpose([0, 2, 1, 3, 4])
        C = x.shape[1]
        x = x.reshape((B, -1, H, W))
        filters = repeat(filters, 'o i h w -> (repeat o) i h w', repeat=C)
        x = F.conv2d_transpose(x, filters, stride=2, groups=C)
        return x

    @staticmethod
    def backward(ctx, dx):
        # adjoint of the transposed convolution: a stride-2 convolution with the same filters
        filters = ctx.saved_tensor()
        filters = filters[0]
        B, C, H, W = ctx.shape
        C = C // 4

        w_ll, w_lh, w_hl, w_hh = paddle.unbind(filters, axis=0)
        x_ll = F.conv2d(dx, w_ll.unsqueeze(1).expand((C, -1, -1, -1)), stride=2, groups=C)
        x_lh = F.conv2d(dx, w_lh.unsqueeze(1).expand((C, -1, -1, -1)), stride=2, groups=C)
        x_hl = F.conv2d(dx, w_hl.unsqueeze(1).expand((C, -1, -1, -1)), stride=2, groups=C)
        x_hh = F.conv2d(dx, w_hh.unsqueeze(1).expand((C, -1, -1, -1)), stride=2, groups=C)
        dx = paddle.concat([x_ll, x_lh, x_hl, x_hh], axis=1)
        return dx, None


class DWT_2D(nn.Layer):
    def __init__(self, wave):
        super(DWT_2D, self).__init__()
        w = pywt.Wavelet(wave)
        # F.conv2d computes a cross-correlation, so the decomposition filters are reversed
        dec_hi = paddle.to_tensor(w.dec_hi[::-1], dtype='float32')
        dec_lo = paddle.to_tensor(w.dec_lo[::-1], dtype='float32')

        # separable 2D filters as outer products of the 1D low-/high-pass filters
        w_ll = dec_lo.unsqueeze(0) * dec_lo.unsqueeze(1)
        w_lh = dec_lo.unsqueeze(0) * dec_hi.unsqueeze(1)
        w_hl = dec_hi.unsqueeze(0) * dec_lo.unsqueeze(1)
        w_hh = dec_hi.unsqueeze(0) * dec_hi.unsqueeze(1)

        # fixed, non-trainable filters of shape (1, 1, k, k)
        self.register_buffer('w_ll', w_ll.unsqueeze(0).unsqueeze(0))
        self.register_buffer('w_lh', w_lh.unsqueeze(0).unsqueeze(0))
        self.register_buffer('w_hl', w_hl.unsqueeze(0).unsqueeze(0))
        self.register_buffer('w_hh', w_hh.unsqueeze(0).unsqueeze(0))

    def forward(self, x):
        return DWT_Function.apply(x, self.w_ll, self.w_lh, self.w_hl, self.w_hh)


class IDWT_2D(nn.Layer):
    def __init__(self, wave):
        super(IDWT_2D, self).__init__()
        w = pywt.Wavelet(wave)
        rec_hi = paddle.to_tensor(w.rec_hi, dtype='float32')
        rec_lo = paddle.to_tensor(w.rec_lo, dtype='float32')

        w_ll = rec_lo.unsqueeze(0) * rec_lo.unsqueeze(1)
        w_lh = rec_lo.unsqueeze(0) * rec_hi.unsqueeze(1)
        w_hl = rec_hi.unsqueeze(0) * rec_lo.unsqueeze(1)
        w_hh = rec_hi.unsqueeze(0) * rec_hi.unsqueeze(1)

        w_ll = w_ll.unsqueeze(0).unsqueeze(1)
        w_lh = w_lh.unsqueeze(0).unsqueeze(1)
        w_hl = w_hl.unsqueeze(0).unsqueeze(1)
        w_hh = w_hh.unsqueeze(0).unsqueeze(1)
        # synthesis filters stacked as (4, 1, k, k) in LL, LH, HL, HH order
        filters = paddle.concat([w_ll, w_lh, w_hl, w_hh], axis=0)
        self.register_buffer('filters', filters)

    def forward(self, x):
        return IDWT_Function.apply(x, self.filters)
```
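`DWT_Function.backward` applies a stride-2 transposed convolution with the analysis filters, and `IDWT_Function.backward` applies a stride-2 convolution with the synthesis filters. This works because a one-level Haar DWT is a linear map $y = Wx$, so the gradient is $W^{\top}\,\partial L/\partial y$; and since the Haar filters are orthonormal, $W^{\top} = W^{-1}$, i.e. the DWT's backward pass is itself an inverse DWT. The NumPy sketch below checks both facts by building $W$ explicitly for a small single-channel image; the helper `haar_matrix` is hypothetical and uses the orthonormal 1/2 scaling.

```python
import numpy as np

def haar_matrix(h, w):
    """Matrix W with W @ x.ravel() = [LL, LH, HL, HH] (each flattened) for an (h, w) image."""
    n = h * w
    signs = [(1, 1, 1, 1), (1, 1, -1, -1), (1, -1, 1, -1), (1, -1, -1, 1)]  # weights of a, b, c, d
    rows = []
    for sa, sb, sc, sd in signs:              # one block of rows per sub-band
        for r in range(0, h, 2):
            for c in range(0, w, 2):
                row = np.zeros(n)
                row[r * w + c] = sa / 2              # a: top-left
                row[r * w + c + 1] = sb / 2          # b: top-right
                row[(r + 1) * w + c] = sc / 2        # c: bottom-left
                row[(r + 1) * w + c + 1] = sd / 2    # d: bottom-right
                rows.append(row)
    return np.stack(rows)

W = haar_matrix(4, 6)
x = np.random.default_rng(0).normal(size=24)
g = np.random.default_rng(1).normal(size=24)  # stand-in for the upstream gradient dL/dy

print(np.allclose(W.T @ W, np.eye(24)))        # orthonormal, so W^T = W^{-1}
print(np.allclose(W.T @ (W @ x), x))           # the "backward" operator inverts the DWT
print(np.isclose((W @ x) @ g, x @ (W.T @ g)))  # adjoint identity <Wx, g> = <x, W^T g>
```

For the Haar wavelet the reversed decomposition filters built in `DWT_2D` coincide with the reconstruction filters built in `IDWT_2D`, which is consistent with this picture.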
```python
class ClassAttention(nn.Layer):
    def __init__(self, dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.head_dim = head_dim
        self.scale = head_dim**-0.5

        self.kv = nn.Linear(dim, dim * 2)
        self.q = nn.Linear(dim, dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        B, N, C = x.shape
        kv = self.kv(x).reshape((B, N, 2, self.num_heads, self.head_dim)).transpose([2, 0, 3, 1, 4])
        k, v = kv[0], kv[1]
        q = self.q(x[:, :1, :]).reshape((B, self.num_heads, 1, self.head_dim))
        attn = ((q * self.scale) @ k.transpose([0, 1, 3, 2]))
        attn = F.softmax(attn, axis=-1)
        cls_embed = (attn @ v).transpose([0, 2, 1, 3]).reshape((B, 1, self.head_dim * self.num_heads))
        cls_embed = self.proj(cls_embed)
        return cls_embed


class FFN(nn.Layer):
    def __init__(self, in_features, hidden_features):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, in_features)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x


class ClassBlock(nn.Layer):
    def __init__(self, dim, num_heads, mlp_ratio, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.norm2 = norm_layer(dim)
        self.attn = ClassAttention(dim, num_heads)
        self.mlp = FFN(dim, int(dim * mlp_ratio))

    def forward(self, x):
        cls_embed = x[:, :1]
        cls_embed = cls_embed + self.attn(self.norm1(x))
        cls_embed = cls_embed + self.mlp(self.norm2(cls_embed))
        return paddle.concat([cls_embed, x[:, 1:]], axis=1)


class DWConv(nn.Layer):
    def __init__(self, dim=768):
        super(DWConv, self).__init__()
        self.dwconv = nn.Conv2D(dim, dim, 3, 1, 1, bias_attr=True, groups=dim)

    def forward(self, x, H, W):
        B, N, C = x.shape
        x = x.transpose([0, 2, 1]).reshape((B, C, H, W))
        x = self.dwconv(x)
        x = x.flatten(2).transpose([0, 2, 1])
        return x


class PVT2FFN(nn.Layer):
    def __init__(self, in_features, hidden_features):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, in_features)

    def forward(self, x, H, W):
        x = self.fc1(x)
        x = self.dwconv(x, H, W)
        x = self.act(x)
        x = self.fc2(x)
        return x
```
```python
class WaveAttention(nn.Layer):
    def __init__(self, dim, num_heads, sr_ratio):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5
        self.sr_ratio = sr_ratio

        self.dwt = DWT_2D(wave='haar')
        self.idwt = IDWT_2D(wave='haar')
        # 1x1 conv: reduce channels to dim/4 so that the 4 DWT sub-bands add up to dim again
        self.reduce = nn.Sequential(
            nn.Conv2D(dim, dim // 4, kernel_size=1, padding=0, stride=1),
            nn.BatchNorm2D(dim // 4),
            nn.ReLU(),
        )
        # 3x3 conv over the concatenated sub-bands
        self.filter = nn.Sequential(
            nn.Conv2D(dim, dim, kernel_size=3, padding=1, stride=1, groups=1),
            nn.BatchNorm2D(dim),
            nn.ReLU(),
        )
        # optional further spatial reduction of K/V by sr_ratio
        self.kv_embed = nn.Conv2D(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) if sr_ratio > 1 else nn.Identity()
        self.q = nn.Linear(dim, dim)
        self.kv = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim * 2)
        )
        # fuse the attention output (dim) with the IDWT features (dim/4)
        self.proj = nn.Linear(dim + dim // 4, dim)

    def forward(self, x, H, W):
        B, N, C = x.shape
        q = self.q(x).reshape((B, N, self.num_heads, C // self.num_heads)).transpose([0, 2, 1, 3])

        x = x.reshape((B, H, W, C)).transpose([0, 3, 1, 2])  # tokens -> (B, C, H, W)
        x_dwt = self.dwt(self.reduce(x))                      # (B, C, H/2, W/2)
        x_dwt = self.filter(x_dwt)

        x_idwt = self.idwt(x_dwt)                             # (B, C/4, H, W)
        x_idwt = x_idwt.reshape((B, -1, x_idwt.shape[-2] * x_idwt.shape[-1])).transpose([0, 2, 1])

        kv = self.kv_embed(x_dwt).reshape((B, C, -1)).transpose([0, 2, 1])  # (B, N', C)
        kv = self.kv(kv).reshape((B, -1, 2, self.num_heads, C // self.num_heads)).transpose([2, 0, 3, 1, 4])
        k, v = kv[0], kv[1]

        attn = (q @ k.transpose([0, 1, 3, 2])) * self.scale   # (B, heads, N, N')
        attn = F.softmax(attn, axis=-1)
        x = (attn @ v).transpose([0, 2, 1, 3]).reshape((B, N, C))
        x = self.proj(paddle.concat([x, x_idwt], axis=-1))
        return x


class Attention(nn.Layer):
    """Standard multi-head self-attention (no K/V downsampling)."""
    def __init__(self, dim, num_heads):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.q = nn.Linear(dim, dim)
        self.kv = nn.Linear(dim, dim * 2)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, H, W):
        B, N, C = x.shape
        q = self.q(x).reshape((B, N, self.num_heads, C // self.num_heads)).transpose([0, 2, 1, 3])
        kv = self.kv(x).reshape((B, -1, 2, self.num_heads, C // self.num_heads)).transpose([2, 0, 3, 1, 4])
        k, v = kv[0], kv[1]
        attn = (q @ k.transpose([0, 1, 3, 2])) * self.scale
        attn = F.softmax(attn, axis=-1)
        x = (attn @ v).transpose([0, 2, 1, 3]).reshape((B, N, C))
        x = self.proj(x)
        return x
```
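To follow the tensor shapes through `WaveAttention.forward`, the pure-Python sketch below traces one call symbolically, without Paddle. It assumes H and W are divisible by 2 × sr_ratio, as they are in the 224×224 setting used here; `wave_attention_shapes` is a hypothetical helper, not part of the original code.

```python
def wave_attention_shapes(B, H, W, C, num_heads, sr_ratio):
    """Return the main intermediate shapes of WaveAttention.forward as a dict."""
    assert C % 4 == 0 and C % num_heads == 0
    assert H % (2 * sr_ratio) == 0 and W % (2 * sr_ratio) == 0
    N = H * W
    h, w = H // 2, W // 2                     # the DWT halves the resolution
    kh, kw = h // sr_ratio, w // sr_ratio     # kv_embed: conv with stride sr_ratio
    return {
        "q": (B, num_heads, N, C // num_heads),
        "reduce": (B, C // 4, H, W),          # 1x1 conv: C -> C/4
        "dwt": (B, C, h, w),                  # 4 sub-bands x C/4 channels
        "filter": (B, C, h, w),               # 3x3 conv keeps the shape
        "idwt_tokens": (B, N, C // 4),        # IDWT back to H x W, flattened
        "kv_tokens": (B, kh * kw, C),
        "attn": (B, num_heads, N, kh * kw),
        "proj_in": (B, N, C + C // 4),        # concat(attention output, IDWT tokens)
        "out": (B, N, C),
    }

for name, shape in wave_attention_shapes(B=1, H=56, W=56, C=64, num_heads=2, sr_ratio=4).items():
    print(f"{name:12s} {shape}")
```

With these stage-1 settings each of the 3136 queries attends to only 49 keys.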
```python
class Block(nn.Layer):
    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio,
                 drop_path=0.,
                 norm_layer=nn.LayerNorm,
                 sr_ratio=1,
                 block_type='wave'
                 ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.norm2 = norm_layer(dim)

        if block_type == 'std_att':
            self.attn = Attention(dim, num_heads)
        else:
            self.attn = WaveAttention(dim, num_heads, sr_ratio)
        self.mlp = PVT2FFN(in_features=dim, hidden_features=int(dim * mlp_ratio))
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x, H, W):
        x = x + self.drop_path(self.attn(self.norm1(x), H, W))
        x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
        return x


class DownSamples(nn.Layer):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.proj = nn.Conv2D(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
        self.norm = nn.LayerNorm(out_channels)

    def forward(self, x):
        x = self.proj(x)
        _, _, H, W = x.shape
        x = x.flatten(2).transpose([0, 2, 1])
        x = self.norm(x)
        return x, H, W


class Stem(nn.Layer):
    def __init__(self, in_channels, stem_hidden_dim, out_channels):
        super().__init__()
        hidden_dim = stem_hidden_dim
        self.conv = nn.Sequential(
            nn.Conv2D(in_channels, hidden_dim, kernel_size=7, stride=2,
                      padding=3, bias_attr=False),  # 112x112
            nn.BatchNorm2D(hidden_dim),
            nn.ReLU(),
            nn.Conv2D(hidden_dim, hidden_dim, kernel_size=3, stride=1,
                      padding=1, bias_attr=False),  # 112x112
            nn.BatchNorm2D(hidden_dim),
            nn.ReLU(),
            nn.Conv2D(hidden_dim, hidden_dim, kernel_size=3, stride=1,
                      padding=1, bias_attr=False),  # 112x112
            nn.BatchNorm2D(hidden_dim),
            nn.ReLU(),
        )
        self.proj = nn.Conv2D(hidden_dim,
                              out_channels,
                              kernel_size=3,
                              stride=2,
                              padding=1)
        self.norm = nn.LayerNorm(out_channels)

    def forward(self, x):
        x = self.conv(x)
        x = self.proj(x)
        _, _, H, W = x.shape
        x = x.flatten(2).transpose([0, 2, 1])
        x = self.norm(x)
        return x, H, W
```
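`Stem` and `DownSamples` determine each stage's token grid. Using the standard convolution output-size formula $\lfloor (H + 2p - k)/s \rfloor + 1$ (dilation 1), the sketch below follows a 224×224 input through the layers defined above; `conv_out` is a hypothetical helper.

```python
def conv_out(size, kernel, stride, padding):
    """Output length of a convolution along one spatial axis (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1

size = conv_out(224, kernel=7, stride=2, padding=3)      # Stem: 7x7 stride-2 conv -> 112
for _ in range(2):
    size = conv_out(size, kernel=3, stride=1, padding=1)  # Stem: two 3x3 stride-1 convs keep 112
size = conv_out(size, kernel=3, stride=2, padding=1)      # Stem.proj -> 56 (stage 1)
grids = [size]
for _ in range(3):                                        # DownSamples before stages 2-4
    size = conv_out(size, kernel=3, stride=2, padding=1)
    grids.append(size)
print(grids)  # [56, 28, 14, 7]
```

Because the Wavelets block halves the grid again before K/V are formed, stage-1 and stage-2 sides must be divisible by 2 × sr_ratio; 56 / 8 and 28 / 4 both give 7.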
```python
class WaveViT(nn.Layer):
    def __init__(self,
                 in_chans=3,
                 num_classes=1000,
                 stem_hidden_dim=32,
                 embed_dims=[64, 128, 320, 448],
                 num_heads=[2, 4, 10, 14],
                 mlp_ratios=[8, 8, 4, 4],
                 drop_path_rate=0.,
                 norm_layer=nn.LayerNorm,
                 depths=[3, 4, 6, 3],
                 sr_ratios=[4, 2, 1, 1],
                 num_stages=4,
                 token_label=True,
                 **kwargs
                 ):
        super().__init__()
        self.num_classes = num_classes
        self.depths = depths
        self.num_stages = num_stages

        dpr = [x.item() for x in paddle.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
        cur = 0

        for i in range(num_stages):
            if i == 0:
                patch_embed = Stem(in_chans, stem_hidden_dim, embed_dims[i])
            else:
                patch_embed = DownSamples(embed_dims[i - 1], embed_dims[i])

            block = nn.LayerList([Block(
                dim=embed_dims[i],
                num_heads=num_heads[i],
                mlp_ratio=mlp_ratios[i],
                drop_path=dpr[cur + j],
                norm_layer=norm_layer,
                sr_ratio=sr_ratios[i],
                block_type='wave' if i < 2 else 'std_att')  # Wavelets blocks in stages 1-2 only
                for j in range(depths[i])])

            norm = norm_layer(embed_dims[i])
            cur += depths[i]

            setattr(self, f"patch_embed{i + 1}", patch_embed)
            setattr(self, f"block{i + 1}", block)
            setattr(self, f"norm{i + 1}", norm)

        post_layers = ['ca']
        self.post_network = nn.LayerList([
            ClassBlock(
                dim=embed_dims[-1],
                num_heads=num_heads[-1],
                mlp_ratio=mlp_ratios[-1],
                norm_layer=norm_layer)
            for _ in range(len(post_layers))
        ])

        # classification head
        self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()
        self.apply(self._init_weights)

    def _init_weights(self, m):
        tn = nn.initializer.TruncatedNormal(std=.02)
        km = nn.initializer.KaimingNormal()
        one = nn.initializer.Constant(1.0)
        zero = nn.initializer.Constant(0.0)
        if isinstance(m, nn.Linear):
            tn(m.weight)
            if m.bias is not None:
                zero(m.bias)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2D)):
            zero(m.bias)
            one(m.weight)
        elif isinstance(m, nn.Conv2D):
            km(m.weight)
            if m.bias is not None:
                zero(m.bias)

    def forward_cls(self, x):
        B, N, C = x.shape
        cls_tokens = x.mean(axis=1, keepdim=True)  # class token initialized as the mean token
        x = paddle.concat((cls_tokens, x), axis=1)
        for block in self.post_network:
            x = block(x)
        return x

    def forward_features(self, x):
        B = x.shape[0]
        for i in range(self.num_stages):
            patch_embed = getattr(self, f"patch_embed{i + 1}")
            block = getattr(self, f"block{i + 1}")
            x, H, W = patch_embed(x)
            for blk in block:
                x = blk(x, H, W)

            if i != self.num_stages - 1:
                norm = getattr(self, f"norm{i + 1}")
                x = norm(x)
                x = x.reshape((B, H, W, -1)).transpose([0, 3, 1, 2])  # tokens -> (B, C, H, W)

        x = self.forward_cls(x)[:, 0]
        norm = getattr(self, f"norm{self.num_stages}")
        x = norm(x)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x

    def forward_tokens(self, x, H, W):
        B = x.shape[0]
        x = x.reshape((B, -1, x.shape[-1]))

        for i in range(self.num_stages):
            if i != 0:
                patch_embed = getattr(self, f"patch_embed{i + 1}")
                x, H, W = patch_embed(x)

            block = getattr(self, f"block{i + 1}")
            for blk in block:
                x = blk(x, H, W)

            if i != self.num_stages - 1:
                norm = getattr(self, f"norm{i + 1}")
                x = norm(x)
                x = x.reshape((B, H, W, -1)).transpose([0, 3, 1, 2])

        x = self.forward_cls(x)
        norm = getattr(self, f"norm{self.num_stages}")
        x = norm(x)
        return x


def wavevit_s(pretrained=False, **kwargs):
    model = WaveViT(
        stem_hidden_dim=32,
        embed_dims=[64, 128, 320, 448],
        num_heads=[2, 4, 10, 14],
        mlp_ratios=[8, 8, 4, 4],
        norm_layer=partial(nn.LayerNorm, epsilon=1e-6),
        depths=[3, 4, 6, 3],
        sr_ratios=[4, 2, 1, 1],
        **kwargs)
    return model


def wavevit_b(pretrained=False, **kwargs):
    model = WaveViT(
        stem_hidden_dim=64,
        embed_dims=[64, 128, 320, 512],
        num_heads=[2, 4, 10, 16],
        mlp_ratios=[8, 8, 4, 4],
        norm_layer=partial(nn.LayerNorm, epsilon=1e-6),
        depths=[3, 4, 12, 3],
        sr_ratios=[4, 2, 1, 1],
        **kwargs)
    return model


def wavevit_l(pretrained=False, **kwargs):
    model = WaveViT(
        stem_hidden_dim=64,
        embed_dims=[96, 192, 384, 512],
        num_heads=[3, 6, 12, 16],
        mlp_ratios=[8, 8, 4, 4],
        norm_layer=partial(nn.LayerNorm, epsilon=1e-6),
        depths=[3, 6, 18, 3],
        sr_ratios=[4, 2, 1, 1],
        **kwargs)
    return model
```
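Putting the pieces together for `wavevit_s` at 224×224 input: stages 1 and 2 use `WaveAttention` (K/V downsampled by the DWT's factor of 2 and then by `sr_ratio`), while stages 3 and 4 use plain `Attention`. The pure-Python sketch below hard-codes the grid sides 56/28/14/7 produced by the stem and downsampling layers, counts the query and key/value tokens of one attention call per stage, and compares the per-head attention-matrix size with full self-attention. It also reproduces the linear stochastic-depth schedule that `WaveViT.__init__` builds with `paddle.linspace`, using an assumed `drop_path_rate` of 0.1 (the factory functions leave it at the default 0). The helper name `stage_attention_sizes` is hypothetical.

```python
def stage_attention_sizes(grids, sr_ratios, block_types):
    """(N, N') per stage: number of query tokens and key/value tokens of one attention call."""
    sizes = []
    for side, sr, kind in zip(grids, sr_ratios, block_types):
        n = side * side
        # WaveAttention: DWT halves each side, then kv_embed divides by sr; Attention: no reduction
        kv_side = side // 2 // sr if kind == "wave" else side
        sizes.append((n, kv_side * kv_side))
    return sizes

grids = [56, 28, 14, 7]                     # token grid side per stage for a 224x224 input
sr_ratios = [4, 2, 1, 1]                    # wavevit_s
block_types = ["wave", "wave", "std_att", "std_att"]
depths = [3, 4, 6, 3]

for kind, (n, n_kv) in zip(block_types, stage_attention_sizes(grids, sr_ratios, block_types)):
    print(f"{kind:8s} N={n:5d} N'={n_kv:4d} attention entries per head: {n * n_kv:,} (full: {n * n:,})")

# linear stochastic-depth schedule, one rate per block, for an assumed drop_path_rate of 0.1
drop_path_rate, total = 0.1, sum(depths)
dpr = [drop_path_rate * i / (total - 1) for i in range(total)]
print(f"{len(dpr)} blocks, rates from {dpr[0]} to {dpr[-1]:.3f}")
```

In stage 1 this is 3136 × 49 attention entries per head instead of 3136 × 3136, a 64× reduction, which is why the Wavelets block is only needed in the high-resolution early stages.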
2.3.4 Model parameters
```python
model = wavevit_s(num_classes=10)
paddle.summary(model, (1, 3, 224, 224))

model = wavevit_b(num_classes=10)
paddle.summary(model, (1, 3, 224, 224))

model = wavevit_l(num_classes=10)
paddle.summary(model, (1, 3, 224, 224))
```
2.4 Training
```python
learning_rate = 0.0001
n_epochs = 50
paddle.seed(42)
np.random.seed(42)
work_path = 'work/model'
os.makedirs(work_path, exist_ok=True)

# WaveViT-Small
model = wavevit_s(num_classes=10)

criterion = LabelSmoothingCrossEntropy()

# the scheduler is stepped every iteration, so T_max is a number of iterations
scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=learning_rate, T_max=50000 // batch_size * n_epochs, verbose=False)
optimizer = paddle.optimizer.Adam(parameters=model.parameters(), learning_rate=scheduler, weight_decay=1e-5)

best_acc = 0.0
val_acc = 0.0
loss_record = {'train': {'loss': [], 'iter': []}, 'val': {'loss': [], 'iter': []}}  # for recording loss
acc_record = {'train': {'acc': [], 'iter': []}, 'val': {'acc': [], 'iter': []}}  # for recording accuracy
loss_iter = 0
acc_iter = 0

for epoch in range(n_epochs):
    # ---------- Training ----------
    model.train()
    train_num = 0.0
    train_loss = 0.0
    val_num = 0.0
    val_loss = 0.0
    accuracy_manager = paddle.metric.Accuracy()
    val_accuracy_manager = paddle.metric.Accuracy()
    print("#===epoch: {}, lr={:.10f}===#".format(epoch, optimizer.get_lr()))
    for batch_id, data in enumerate(train_loader):
        x_data, y_data = data
        labels = paddle.unsqueeze(y_data, axis=1)

        logits = model(x_data)
        loss = criterion(logits, y_data)

        # compute() returns per-sample correctness; update() accumulates it over the epoch
        accuracy_manager.update(accuracy_manager.compute(logits, labels))
        if batch_id % 10 == 0:
            loss_record['train']['loss'].append(float(loss))
            loss_record['train']['iter'].append(loss_iter)
            loss_iter += 1

        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.clear_grad()

        # accumulate plain floats so that no autograd graph is kept alive
        train_loss += float(loss) * len(y_data)
        train_num += len(y_data)

    total_train_loss = train_loss / train_num
    train_acc = accuracy_manager.accumulate()
    acc_record['train']['acc'].append(train_acc)
    acc_record['train']['iter'].append(acc_iter)
    acc_iter += 1
    # Print the information.
    print("#===epoch: {}, train loss is: {:.4f}, train acc is: {:2.2f}%===#".format(epoch, total_train_loss, train_acc * 100))

    # ---------- Validation ----------
    model.eval()
    with paddle.no_grad():
        for batch_id, data in enumerate(val_loader):
            x_data, y_data = data
            labels = paddle.unsqueeze(y_data, axis=1)
            logits = model(x_data)
            loss = criterion(logits, y_data)
            val_accuracy_manager.update(val_accuracy_manager.compute(logits, labels))
            val_loss += float(loss) * len(y_data)
            val_num += len(y_data)

    total_val_loss = val_loss / val_num
    loss_record['val']['loss'].append(total_val_loss)
    loss_record['val']['iter'].append(loss_iter)
    val_acc = val_accuracy_manager.accumulate()
    acc_record['val']['acc'].append(val_acc)
    acc_record['val']['iter'].append(acc_iter)

    print("#===epoch: {}, val loss is: {:.4f}, val acc is: {:2.2f}%===#".format(epoch, total_val_loss, val_acc * 100))

    # ===================save====================
    if val_acc > best_acc:
        best_acc = val_acc
        paddle.save(model.state_dict(), os.path.join(work_path, 'best_model.pdparams'))
        paddle.save(optimizer.state_dict(), os.path.join(work_path, 'best_optimizer.pdopt'))

print(best_acc)
paddle.save(model.state_dict(), os.path.join(work_path, 'final_model.pdparams'))
paddle.save(optimizer.state_dict(), os.path.join(work_path, 'final_optimizer.pdopt'))
```
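Because `CosineAnnealingDecay` is stepped once per iteration here, `T_max` is measured in iterations: 50000 // 128 = 390 iterations per epoch (the last partial batch is dropped), times 50 epochs = 19,500. The sketch below evaluates the standard closed-form cosine annealing $\eta_t = \eta_{\min} + \tfrac{1}{2}(\eta_0 - \eta_{\min})(1 + \cos(\pi t / T_{\max}))$ with $\eta_{\min} = 0$ to show how the learning rate decays over training. `cosine_lr` is a hypothetical helper; Paddle's scheduler may compute the value recursively, so tiny floating-point differences are possible.

```python
import math

def cosine_lr(step, base_lr=1e-4, t_max=50000 // 128 * 50, eta_min=0.0):
    """Closed-form cosine annealing from base_lr down to eta_min over t_max steps."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max))

t_max = 50000 // 128 * 50
print(t_max)  # 19500
for epoch in (0, 10, 25, 40, 49):
    print(epoch, f"{cosine_lr(epoch * 390):.3e}")  # learning rate at the start of each epoch
```

The learning rate reaches half its initial value at the midpoint of training and approaches 0 in the final epoch.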
2.5 Result analysis
```python
def plot_learning_curve(record, title='loss', ylabel='CE Loss'):
    ''' Plot the training and validation learning curves. '''
    maxtrain = max(map(float, record['train'][title]))
    maxval = max(map(float, record['val'][title]))
    ymax = max(maxtrain, maxval) * 1.1
    mintrain = min(map(float, record['train'][title]))
    minval = min(map(float, record['val'][title]))
    ymin = min(mintrain, minval) * 0.9

    x_1 = list(map(int, record['train']['iter']))
    x_2 = list(map(int, record['val']['iter']))
    figure(figsize=(10, 6))
    plt.plot(x_1, record['train'][title], c='tab:red', label='train')
    plt.plot(x_2, record['val'][title], c='tab:cyan', label='val')
    plt.ylim(ymin, ymax)
    plt.xlabel('Training steps')
    plt.ylabel(ylabel)
    plt.title('Learning curve of {}'.format(title))
    plt.legend()
    plt.show()


plot_learning_curve(loss_record, title='loss', ylabel='CE Loss')
plot_learning_curve(acc_record, title='acc', ylabel='Accuracy')
```
```python
import time

work_path = 'work/model'
model = wavevit_s(num_classes=10)
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
model.eval()

# note: this wall-clock timing also includes data loading and the first (warm-up) batches
aa = time.time()
with paddle.no_grad():
    for batch_id, data in enumerate(val_loader):
        x_data, y_data = data
        logits = model(x_data)
bb = time.time()
print("Throughput:{}".format(int(len(val_dataset) // (bb - aa))))
```

```
Throughput:483
```
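The throughput above divides the validation-set size by the wall-clock time of the whole loop, so it also includes data loading and any warm-up cost of the first batches. A slightly more careful measurement, sketched below, times only pre-loaded batches, skips a few warm-up iterations, and uses `time.perf_counter()`. The helper `measure_throughput` and its parameters are hypothetical; on a GPU the device should also be synchronized before reading the clock (the commented Paddle usage assumes `paddle.device.cuda.synchronize` for that), which the sketch delegates to its `sync` argument.

```python
import time

def measure_throughput(run_batch, batches, warmup=2, count=len, sync=lambda: None):
    """Items per second for run_batch over batches, excluding the first `warmup` batches.

    run_batch: callable that processes one batch (e.g. runs the model on it)
    batches:   list of already-loaded batches, so data loading is not timed
    count:     returns the number of items (images) in one batch
    sync:      called before each clock read, e.g. a device synchronize on GPU
    """
    for batch in batches[:warmup]:
        run_batch(batch)
    sync()
    start = time.perf_counter()
    n_items = 0
    for batch in batches[warmup:]:
        run_batch(batch)
        n_items += count(batch)
    sync()
    return n_items / (time.perf_counter() - start)

# toy usage: a dummy workload on lists standing in for image batches
dummy_batches = [list(range(128)) for _ in range(10)]
print(f"{measure_throughput(lambda b: sum(v * v for v in b), dummy_batches):.0f} items/s")

# hypothetical Paddle usage (model in eval mode, inside paddle.no_grad()):
# batches = [x for x, _ in itertools.islice(val_loader, 20)]
# measure_throughput(model, batches, count=lambda x: x.shape[0],
#                    sync=paddle.device.cuda.synchronize)
```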
```python
def get_cifar10_labels(labels):
    """Return the text labels of the CIFAR-10 dataset."""
    text_labels = [
        'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog',
        'horse', 'ship', 'truck']
    return [text_labels[int(i)] for i in labels]


def show_images(imgs, num_rows, num_cols, pred=None, gt=None, scale=1.5):
    """Plot a list of images."""
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if paddle.is_tensor(img):
            ax.imshow(img.numpy())
        else:
            ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if pred or gt:
            ax.set_title("pt: " + pred[i] + "\ngt: " + gt[i])
    return axes


work_path = 'work/model'
X, y = next(iter(DataLoader(val_dataset, batch_size=18)))
model = wavevit_s(num_classes=10)
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
model.eval()
with paddle.no_grad():
    logits = model(X)
y_pred = paddle.argmax(logits, -1)

# undo Normalize(mean, std) so that imshow receives values in [0, 1]
mean = paddle.to_tensor([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))
std = paddle.to_tensor([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))
X = paddle.clip(X * std + mean, 0.0, 1.0)
X = paddle.transpose(X, [0, 2, 3, 1])  # (N, C, H, W) -> (N, H, W, C)
axes = show_images(X, 1, 18, pred=get_cifar10_labels(y_pred), gt=get_cifar10_labels(y))
plt.show()
```
```python
!pip install interpretdl
import interpretdl as it

work_path = 'work/model'
model = wavevit_s(num_classes=10)
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
model.eval()
X, y = next(iter(DataLoader(val_dataset, batch_size=18)))
lime = it.LIMECVInterpreter(model)
lime_weights = lime.interpret(X.numpy()[3], interpret_class=y.numpy()[3], batch_size=100, num_samples=10000, visual=True)
```
Summary
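In this paper, the authors study the idea of unifying the standard Transformer block with invertible downsampling, achieving efficient multi-scale self-attention learning through lossless downsampling. They propose a new Transformer block, the Wavelets block, which uses the discrete wavelet transform (DWT) to perform invertible downsampling of the keys/values in self-attention. In addition, the inverse DWT (IDWT) reconstructs the downsampled DWT output, strengthening the output of the Wavelets block by aggregating local context with an enlarged receptive field.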
References
- Wave-ViT: Unifying Wavelet and Transformers for Visual Representation Learning (ECCV 2022)
- YehLi/ImageNetModel
- [ECCV2022] Wave-ViT: Unifying Wavelet and Transformers for Visual Representation Learning
This article is a repost of the original project.