This article originates from a featured project in the AI Studio community.

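Abstract

        Vision Transformers have shown great potential in computer vision tasks. Most recent work has focused on elaborating the spatial token mixer to improve performance. However, we observe that a well-designed general architecture can significantly improve the performance of the whole backbone, regardless of which spatial token mixer it is equipped with. This paper proposes UniNeXt, an improved general architecture for vision backbones. To verify its effectiveness, we instantiate the spatial token mixer with various typical and modern designs, including both convolution and attention modules. Compared with the architectures in which they were originally proposed, our UniNeXt architecture steadily boosts the performance of all the spatial token mixers and narrows the performance gap among them. Surprisingly, UniNeXt equipped with naive local window attention even outperforms the previous state of the art. Interestingly, the ranking of these spatial token mixers also changes under UniNeXt, suggesting that an excellent spatial token mixer may be stifled by a suboptimal general architecture, which further shows the importance of studying the general architecture of vision backbones.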

1. UniNeXt

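        This paper focuses on the general architecture of backbones for computer vision tasks. Several recent studies concentrate on improving performance through an elaborately designed spatial token mixer (STM). The authors argue, however, that a well-designed general architecture can significantly improve the performance of the whole backbone regardless of which spatial token mixer it uses. The paper therefore proposes UniNeXt, an improved general architecture.
        To verify its effectiveness, the paper instantiates the STM with a variety of classic and modern designs. The experimental results show that, compared with the architectures in which they were originally proposed, the proposed architecture steadily improves the performance of all STMs and narrows the performance gap among them. Surprisingly, UniNeXt equipped with the most basic local window attention even outperforms the previous state of the art. This suggests that an excellent STM can be held back by a suboptimal general architecture, which further underlines the importance of research on general architectures for vision backbones.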

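        The main design idea of UniNeXt is to strengthen the inductive bias in three ways:

  1. Add a parallel EC (Embedded Convolution) branch to the STM
  2. Add a PC (Post Convolution) module after the FFN
  3. Add a 3×3 depthwise separable convolution, HdC (High-dimensional Convolution), inside the FFN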

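        As shown in Figure 2, UniNeXt still follows the common pyramid structure with four hierarchical stages, each consisting of a downsampling layer followed by several Unified Blocks. Each downsampling layer reduces the spatial resolution by a factor of 2 and doubles the number of channels. The number of tokens stays constant within the Unified Blocks of a stage. Finally, global average pooling (GAP) and a fully connected layer are applied to perform image classification.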
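
The pyramid layout can be made concrete with a little arithmetic. The sketch below is only an illustration: `stage_shapes` is a hypothetical helper, and it assumes (following the stride-2 stem and stride-2 merge layers of the implementation in Section 2.3.3) that the token map reaches stride 4 before stage 1 and is halved again before every later stage.

```python
def stage_shapes(img_size=224, embed_dim=64, num_stages=4):
    """Return (height, width, channels) of the token map in each stage.

    Assumption: the stem and the first merge layer each halve the resolution
    (overall stride 4 before stage 1); every later merge layer halves the
    resolution again and doubles the channel count.
    """
    shapes = []
    side, channels = img_size // 4, embed_dim
    for _ in range(num_stages):
        shapes.append((side, side, channels))
        side, channels = side // 2, channels * 2
    return shapes


shapes = stage_shapes()
for i, (h, w, c) in enumerate(shapes, start=1):
    print(f"stage {i}: {h}x{w} tokens, {c} channels")
```

For a 224×224 input, every stage resolution under these assumptions is a multiple of 7, so a window size of 7 needs no padding.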

1.1 High-dimensional Convolution (HdC)

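        HdC is a lightweight depthwise separable convolution that locally fuses high-dimensional features, thereby encoding high-dimensional implicit features and improving model performance. The authors inherit and extend this convolutional embedding mechanism: the first linear layer of the MLP maps the features to a high-dimensional feature F, and a depthwise separable convolution then fuses F locally, which encodes high-dimensional implicit features and improves the efficiency of the model. In this way, UniNeXt models high-dimensional features better and achieves better performance while keeping the model lightweight.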
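
To give a feel for "lightweight", the following back-of-the-envelope sketch compares the parameter count of the two FFN linear layers with that of a 3×3 depthwise convolution on the hidden features. The values `dim = 64` and `mlp_ratio = 4` are illustrative assumptions, and the helper names are hypothetical.

```python
def linear_params(in_features, out_features):
    # weight matrix plus bias vector
    return in_features * out_features + out_features


def depthwise_conv_params(channels, kernel_size=3):
    # one k x k filter per channel plus one bias per channel
    return channels * kernel_size * kernel_size + channels


dim, mlp_ratio = 64, 4            # illustrative values, not taken from the paper
hidden = dim * mlp_ratio

ffn = linear_params(dim, hidden) + linear_params(hidden, dim)
hdc = depthwise_conv_params(hidden)
print(f"FFN linear layers: {ffn} params, HdC depthwise conv: {hdc} params")
print(f"relative overhead: {hdc / ffn:.1%}")
```

Under these assumptions the depthwise convolution adds well under 10% to the parameters of the FFN it is embedded in.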

1.2 Embedded Convolution (EC)

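        Next, the paper proposes Embedded Convolution (EC), motivated by two considerations:

  1. EC strengthens the inductive bias of the model, which is crucial for learning and generalization.
  2. EC is compatible with all existing token mixers, which makes it flexible and easy to implement across architectures.

        EC therefore strengthens the model's ability to capture spatial information and improves its performance. Moreover, because EC is a generic convolution operation that is compatible with various STMs, it can be used flexibly across vision tasks.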

1.3 Post Convolution (PC)

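        Finally, to further strengthen the convolutional inductive bias and enhance local representations, the paper introduces a carefully designed convolutional structure, Post Convolution (PC). By introducing a lightweight 3×3 depthwise convolution, PC significantly improves model performance, especially on dense prediction tasks. Specifically, the token matrix is first converted into a 2D image representation, a depthwise convolution then fuses local context, the result is flattened back into tokens, and a residual connection is applied to prevent excessive weight scaling. This effectively strengthens the model's ability to capture local image information and thus improves its performance.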
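
The four PC steps (tokens to 2D map, depthwise convolution, flatten, residual) can be sketched in plain NumPy. This is a minimal illustration rather than the authors' implementation: `depthwise_conv3x3` and `post_conv` are hypothetical helpers, the convolution has no bias, and a single sample without a batch dimension is assumed.

```python
import numpy as np


def depthwise_conv3x3(x, weight):
    """3x3 depthwise convolution (cross-correlation) with zero padding.

    x:      (C, H, W) feature map
    weight: (C, 3, 3), one filter per channel
    """
    C, H, W = x.shape
    padded = np.pad(x, ((0, 0), (1, 1), (1, 1)))
    out = np.zeros_like(x)
    for i in range(3):
        for j in range(3):
            # each channel is filtered only by its own 3x3 kernel
            out += weight[:, i, j][:, None, None] * padded[:, i:i + H, j:j + W]
    return out


def post_conv(tokens, H, W, weight):
    """tokens: (N, C) with N == H * W. Returns tokens + DWConv(tokens)."""
    N, C = tokens.shape
    assert N == H * W, "token count must match the spatial size"
    fmap = tokens.T.reshape(C, H, W)          # tokens -> 2D feature map
    fused = depthwise_conv3x3(fmap, weight)   # local context fusion
    fused = fused.reshape(C, N).T             # flatten back to tokens
    return tokens + fused                     # residual connection


rng = np.random.default_rng(0)
tokens = rng.standard_normal((7 * 7, 8))
weight = rng.standard_normal((8, 3, 3)) * 0.1
out = post_conv(tokens, 7, 7, weight)
print(out.shape)
```

Because of the residual connection, an all-zero kernel leaves the tokens unchanged, so the block can start close to an identity mapping.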

2. Code Reproduction

2.1 Install and import the required libraries

!pip install einops-0.3.0-py3-none-any.whl
!pip install paddlex
%matplotlib inline
import paddle
import numpy as np
import matplotlib.pyplot as plt
from paddle.vision.datasets import Cifar10
from paddle.vision.transforms import Transpose
from paddle.io import Dataset, DataLoader
from paddle import nn
import paddle.nn.functional as F
import paddle.vision.transforms as transforms
import os
from matplotlib.pyplot import figure
import paddlex
import itertools
from einops import rearrange

2.2 Create the dataset

train_tfm = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.6, 1.0)),
    transforms.ColorJitter(brightness=0.2,contrast=0.2, saturation=0.2),
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomRotation(20),
    paddlex.transforms.MixupImage(),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

test_tfm = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
paddle.vision.set_image_backend('cv2')
# Use the CIFAR-10 dataset
train_dataset = Cifar10(data_file='data/data152754/cifar-10-python.tar.gz', mode='train', transform=train_tfm)
val_dataset = Cifar10(data_file='data/data152754/cifar-10-python.tar.gz', mode='test', transform=test_tfm)
print("train_dataset: %d" % len(train_dataset))
print("val_dataset: %d" % len(val_dataset))
train_dataset: 50000
val_dataset: 10000
batch_size=128
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=4)

2.3 Build the model

2.3.1 Label smoothing
class LabelSmoothingCrossEntropy(nn.Layer):
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, pred, target):

        confidence = 1. - self.smoothing
        log_probs = F.log_softmax(pred, axis=-1)
        idx = paddle.stack([paddle.arange(log_probs.shape[0]), target], axis=1)
        nll_loss = paddle.gather_nd(-log_probs, index=idx)
        smooth_loss = paddle.mean(-log_probs, axis=-1)
        loss = confidence * nll_loss + self.smoothing * smooth_loss

        return loss.mean()
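
The loss above mixes the usual negative log-likelihood of the true class with the mean negative log-probability over all classes. A framework-free NumPy sketch of the same formula follows; `label_smoothing_ce` is a hypothetical helper name and the sample logits are made up for illustration.

```python
import numpy as np


def label_smoothing_ce(logits, target, smoothing=0.1):
    """logits: (B, K) float array, target: (B,) integer class indices."""
    shifted = logits - logits.max(axis=-1, keepdims=True)       # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    nll = -log_probs[np.arange(len(target)), target]            # true-class term
    smooth = -log_probs.mean(axis=-1)                           # uniform-target term
    return ((1.0 - smoothing) * nll + smoothing * smooth).mean()


logits = np.array([[2.0, 0.5, -1.0], [0.0, 0.0, 0.0]])
target = np.array([0, 2])
print(label_smoothing_ce(logits, target, smoothing=0.1))
```

With `smoothing=0.0` the function reduces to ordinary cross-entropy, and for perfectly uniform logits it returns log(K) whatever the smoothing value.
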
2.3.2 DropPath
def drop_path(x, drop_prob=0.0, training=False):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = paddle.to_tensor(1 - drop_prob)
    shape = (paddle.shape(x)[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + paddle.rand(shape, dtype=x.dtype)
    random_tensor = paddle.floor(random_tensor)  # binarize
    output = x.divide(keep_prob) * random_tensor
    return output


class DropPath(nn.Layer):
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
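
The key property of stochastic depth is that whole samples are dropped and the survivors are rescaled by 1 / keep_prob, so the expected output equals the input. The NumPy sketch below mirrors the logic of `drop_path` under that reading; `drop_path_np` is a hypothetical name and the random generator is passed in explicitly for reproducibility.

```python
import numpy as np


def drop_path_np(x, drop_prob, training, rng):
    """x: (B, ...) array. Zeroes whole samples with probability drop_prob."""
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1.0 - drop_prob
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)      # one Bernoulli draw per sample
    mask = np.floor(keep_prob + rng.random(mask_shape))   # 1 with prob keep_prob, else 0
    return x / keep_prob * mask                           # rescale the surviving samples


rng = np.random.default_rng(0)
x = np.ones((10000, 4))
y = drop_path_np(x, drop_prob=0.2, training=True, rng=rng)
print(np.unique(y), y.mean())
```

Each row of `y` is either entirely zero or entirely scaled up, and the mean stays close to 1.
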
2.3.3 Build the UniNeXt model
def local_group(x, H, W, ws):
    '''
    ws: window size
    x: BNC
    return: BG ws*ws C
    '''
    B, _, C = x.shape
    pad_right, pad_bottom, pad_opt = 0, 0, False

    if H % ws != 0 or W % ws != 0:
        pad_opt =True
        # reshape (B, N, C) -> (B, H, W, C)
        x = x.reshape((B, H, W, C))
        # padding right & below (zero when that side is already a multiple of ws)
        pad_right = (ws - W % ws) % ws
        pad_bottom = (ws - H % ws) % ws
        x = F.pad(x, (0, pad_right, 0, pad_bottom), data_format='NHWC')
        H = H + pad_bottom
        W = W + pad_right
        N = H * W
        # reshape (B, H, W, C) -> (B, N, C)
        x = x.reshape((B, N, C))
    Gh = H//ws
    Gw = W//ws
    x = x.reshape((B, Gh, ws, Gw, ws, C)).transpose([0, 1, 3, 2, 4, 5]).reshape((B * Gh * Gw, ws * ws, C))

    return x, H, W, pad_right, pad_bottom, pad_opt



def img2group(x, H, W, ws, num_head):
    '''
    x: B, H*W, C
    return : (B G) head  N C
    '''
    # After group x:B G N C
    x, H, W, pad_right, pad_bottom, pad_opt = local_group(x, H, W, ws)
    B, N, C =x.shape
    x = x.reshape((B, N, num_head, C//num_head)).transpose([0, 2, 1, 3])

    return x, H, W, pad_right, pad_bottom, pad_opt


def group2image(x, H, W, pad_right, pad_bottom, pad_opt, ws):
    # Input x: (BG G) Head n C
    # Output x: B N C

    BG, Head, n, C = x.shape
    Gh, Gw = H // ws, W // ws
    Gn = Gh * Gw
    nb1 = BG // Gn
    x = x.reshape((nb1, Gh, Gw, Head, ws, ws, C)).transpose([0, 1, 4, 2, 5, 3, 6]).reshape((nb1, -1, Head * C))

    if pad_opt:
        x = x.reshape((nb1, H, W, Head * C))
        x = x[:, :(H - pad_bottom), :(W - pad_right), :]
        x = x.reshape((nb1, -1, Head * C))

    return x
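
The reshape and transpose pattern used for grouping tokens into local windows, and for undoing it, can be checked in isolation. The NumPy sketch below is a simplified illustration that leaves out the head split and assumes H and W are already multiples of the window size; `window_partition` and `window_merge` are hypothetical names.

```python
import numpy as np


def window_partition(x, H, W, ws):
    """(B, H*W, C) -> (B * Gh * Gw, ws*ws, C); assumes H and W are multiples of ws."""
    B, _, C = x.shape
    Gh, Gw = H // ws, W // ws
    x = x.reshape(B, Gh, ws, Gw, ws, C).transpose(0, 1, 3, 2, 4, 5)
    return x.reshape(B * Gh * Gw, ws * ws, C)


def window_merge(windows, H, W, ws):
    """Inverse of window_partition: (B * Gh * Gw, ws*ws, C) -> (B, H*W, C)."""
    Gh, Gw = H // ws, W // ws
    B = windows.shape[0] // (Gh * Gw)
    C = windows.shape[-1]
    x = windows.reshape(B, Gh, Gw, ws, ws, C).transpose(0, 1, 3, 2, 4, 5)
    return x.reshape(B, H * W, C)


H = W = 14
x = np.arange(2 * H * W * 3, dtype=np.float32).reshape(2, H * W, 3)
windows = window_partition(x, H, W, ws=7)
restored = window_merge(windows, H, W, ws=7)
print(windows.shape, np.array_equal(x, restored))
```

Each window holds the tokens of one contiguous ws × ws patch, so attention computed per window only mixes spatially neighbouring tokens.
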
# dwconv MLP
class Mlp(nn.Layer):
    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.GELU,
                 drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()

        self.dwconv = nn.Conv2D(hidden_features, hidden_features, 3, 1, 1, groups=hidden_features)
        self.norm_act = nn.Sequential(
            nn.LayerNorm(hidden_features),
            nn.GELU()
        )

        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x, H, W):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        B, N, C = x.shape
        x1 = x.transpose([0, 2, 1]).reshape((B, C, H, W))
        x1 = self.dwconv(x1)
        x1 = x1.reshape((B, C, -1)).transpose([0, 2, 1])
        x1 = self.norm_act(x1)
        x = x + x1
        x = self.fc2(x)
        x = self.drop(x)
        return x
class DilatedAttention(nn.Layer):
    def __init__(self, dim, ws=7, num_heads=8, attn_drop=0., qk_scale=None):
        super().__init__()
        self.dim = dim
        self.ws = ws
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.attn_drop = nn.Dropout(attn_drop)
        self.lepe = nn.Conv2D(in_channels=dim, out_channels=dim, kernel_size=3, stride=1, padding=1, groups=dim, bias_attr=True)
        self.proj = nn.Linear(dim, dim)

    def forward(self, qkv, H, W):
        """
        qkv: B N C  after proj
        H, W: img h and w
        """
        q, k, v = qkv[0], qkv[1], qkv[2]
        # lepe
        B, _, vc = v.shape
        lepe = v.transpose([0, 2, 1]).reshape((B, vc, H, W))
        lepe = self.lepe(lepe)
        lepe = lepe.reshape((B, vc, -1)).transpose([0, 2, 1])

        N = q.shape[1]
        assert N == H * W, "flatten img_tokens has wrong size"
        # B,N,C -->BG,H,N,C
        q, H_new, W_new, pad_right, pad_bottom, pad_opt = img2group(q, H, W, self.ws, self.num_heads)
        k, _, _, _, _, _ = img2group(k, H, W, self.ws, self.num_heads)
        v, _, _, _, _, _ = img2group(v, H, W, self.ws, self.num_heads)

        q = q * self.scale
        # B head N C @ B head C N --> B head N N
        attn = (q @ k.transpose([0, 1, 3, 2]))
        attn = F.softmax(attn, axis=-1)
        attn = self.attn_drop(attn)
        x = attn @ v
        # x: BHnC -> BNC
        x = group2image(x, H_new, W_new, pad_right, pad_bottom, pad_opt, self.ws)

        # add lepe
        x = x + lepe

        # proj
        x = self.proj(x)
        return x
class UnifiedBlock(nn.Layer):
    def __init__(self,
                 dim,
                 num_heads,
                 ws=7,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 ):
        super().__init__()
        self.mlp_ratio = mlp_ratio
        mlp_hidden_dim = int(dim * mlp_ratio)

        # Unified Mixer
        self.norm1 = nn.LayerNorm(dim)
        self.attns = DilatedAttention(dim, ws=ws, num_heads=num_heads, attn_drop=attn_drop, qk_scale=qk_scale)
        self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        # ICMP
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
                       out_features=dim, act_layer=act_layer, drop=drop)

        # PC
        self.cpe = nn.Conv2D(dim, dim, 3, 1, 1, groups=dim)

    def forward(self, x, H, W):
        """
        x: B, H*W, C
        """
        B, L, C = x.shape
        assert L == H * W, "flatten img_tokens has wrong size"
        img = self.norm1(x)

        qkv = self.qkv(img)
        qkv = qkv.reshape((B, -1, 3, C)).transpose([2, 0, 1, 3])

        x = x + self.drop_path(self.attns(qkv, H, W))
        x = x + self.drop_path(self.mlp(self.norm2(x), H, W))

        pe = self.cpe(x.transpose([0, 2, 1]).reshape((B, C, H, W)))
        pe = pe.reshape((B, C, -1)).transpose([0, 2, 1])
        x = x + pe

        return x
class Merge_Block(nn.Layer):
    def __init__(self, dim, dim_out, norm_layer=nn.LayerNorm):
        super().__init__()
        self.conv = nn.Conv2D(dim, dim_out, 3, 2, 1)
        self.norm = norm_layer(dim_out)

    def forward(self, x):
        B, new_HW, C = x.shape
        H = W = int(np.sqrt(new_HW))
        x = x.transpose([0, 2, 1]).reshape((B, C, H, W))
        x = self.conv(x)
        B, C = x.shape[:2]
        x = x.reshape((B, C, -1)).transpose([0, 2, 1])
        x = self.norm(x)

        return x
class UniNeXt(nn.Layer):
    """ Vision Transformer with support for patch or hybrid CNN input stage
    """
    def __init__(self,
                 in_chans=3,
                 num_classes=1000,
                 embed_dim=96,
                 depth=[2,2,6,2],
                 ws = [7,7,7,7],
                 num_heads=(3, 6, 12, 24),  # one entry per stage; indexed as heads[i] below
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 norm_layer=nn.LayerNorm):
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        heads=num_heads
        #------------- stem -----------------------
        stem_out = embed_dim // 2
        self.stem1 = nn.Conv2D(in_chans, stem_out, 3, 2, 1)
        self.norm_act1 = nn.Sequential(
            nn.LayerNorm(stem_out),
            nn.GELU()
        )
        self.stem2 = nn.Conv2D(stem_out, stem_out, 3, 1, 1)
        self.norm_act2 = nn.Sequential(
            nn.LayerNorm(stem_out),
            nn.GELU()
        )
        self.stem3 = nn.Conv2D(stem_out, stem_out, 3, 1, 1)
        self.norm_act3 = nn.Sequential(
            nn.LayerNorm(stem_out),
            nn.GELU()
        )
        #----------------------------------------


        self.merge0 = Merge_Block(stem_out, embed_dim)

        curr_dim = embed_dim
        dpr = [x.item() for x in paddle.linspace(0, drop_path_rate, np.sum(depth))]  # stochastic depth decay rule
        self.stage1 = nn.LayerList([
            UnifiedBlock(
                dim=curr_dim,
                num_heads=heads[0],
                ws=ws[0],
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[i])
            for i in range(depth[0])])

        self.merge1 = Merge_Block(curr_dim, curr_dim * 2)
        curr_dim = curr_dim * 2
        self.stage2 = nn.LayerList(
            [UnifiedBlock(
                dim=curr_dim,
                num_heads=heads[1],
                ws=ws[1],
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[np.sum(depth[:1])+i])
            for i in range(depth[1])])

        self.merge2 = Merge_Block(curr_dim, curr_dim * 2)
        curr_dim = curr_dim * 2
        self.stage3 = nn.LayerList(
            [UnifiedBlock(
                dim=curr_dim,
                num_heads=heads[2],  # stage 3 uses its own head count and window size
                ws=ws[2],
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[np.sum(depth[:2])+i])
            for i in range(depth[2])])

        self.merge3 = Merge_Block(curr_dim, curr_dim * 2)
        curr_dim = curr_dim * 2
        self.stage4 = nn.LayerList(
            [UnifiedBlock(
                dim=curr_dim,
                num_heads=heads[3],
                ws=ws[3],
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[np.sum(depth[:-1])+i])
            for i in range(depth[-1])])

        self.norm = norm_layer(curr_dim)
        # Classifier head
        self.head = nn.Linear(curr_dim, num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)
    def _init_weights(self, m):
        tn = nn.initializer.TruncatedNormal(std=.02)
        kaiming = nn.initializer.KaimingNormal()
        zero = nn.initializer.Constant(0.)
        one = nn.initializer.Constant(1.)
        if isinstance(m, nn.Linear):
            tn(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                zero(m.bias)

        if isinstance(m, nn.Conv2D):
            kaiming(m.weight)
            if isinstance(m, nn.Conv2D) and m.bias is not None:
                zero(m.bias)

        if isinstance(m, (nn.BatchNorm2D, nn.LayerNorm)):
            one(m.weight)
            zero(m.bias)

    def forward_features(self, x):
        B, _, H, W = x.shape
        H0, W0, H1, W1, H2, W2, H3, W3, H4, W4 = H // 2, W // 2, H // 4, W // 4, H // 8, W // 8, H // 16, W // 16, H // 32, W // 32
        # stem
        x = self.stem1(x)
        c1 = x.shape[1]
        x = x.reshape((B, c1, -1)).transpose([0, 2, 1])
        x = self.norm_act1(x)
        x = x.transpose([0, 2, 1]).reshape([B, c1, H0, W0])
        x = self.stem2(x)
        c2 = x.shape[1]
        x = x.reshape((B, c2, -1)).transpose([0, 2, 1])
        x = self.norm_act2(x)
        x = x.transpose([0, 2, 1]).reshape((B, c2, H0, W0))
        x = self.stem3(x)
        c3 = x.shape[1]
        x = x.reshape((B, c3, -1)).transpose([0, 2, 1])
        x = self.norm_act3(x)

        x = self.merge0(x)
        C = x.shape[2]

        for blk in self.stage1:
            x = blk(x, H1, W1)

        for pre, blocks, H_i, W_i in zip([self.merge1, self.merge2, self.merge3],
                                              [self.stage2, self.stage3, self.stage4],
                                              [H2, H3, H4],
                                              [W2, W3, W4]):
            x = pre(x)
            C = x.shape[2]
            for blk in blocks:
                x = blk(x, H_i, W_i)

        x = self.norm(x)
        return paddle.mean(x, axis=1)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x
num_classes = 10

def UniNeXt_T():
    model = UniNeXt(embed_dim=64, depth=[2,2,18,2], num_classes=num_classes,
        ws = [7,7,7,7], num_heads=[2,4,8,16], mlp_ratio=4.)

    return model


def UniNeXt_S():
    model = UniNeXt(embed_dim=96, depth=[2,2,18,2], num_classes=num_classes,
        ws = [7,7,7,7], num_heads=[3,6,12,24], mlp_ratio=4.)

    return model


def UniNeXt_B():
    model = UniNeXt(embed_dim=128, depth=[2,2,18,2], num_classes=num_classes,
        ws = [7,7,7,7], num_heads=[4,8,16,32], mlp_ratio=4.)

    return model
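
A quick sanity check on the three configurations: assuming stage i runs at `embed_dim * 2**i` channels and uses `num_heads[i]` heads, the per-head dimension works out to the same value in every stage of every variant. The `head_dims` helper is hypothetical and only does the arithmetic.

```python
configs = {
    "UniNeXt_T": {"embed_dim": 64, "num_heads": [2, 4, 8, 16]},
    "UniNeXt_S": {"embed_dim": 96, "num_heads": [3, 6, 12, 24]},
    "UniNeXt_B": {"embed_dim": 128, "num_heads": [4, 8, 16, 32]},
}


def head_dims(embed_dim, num_heads):
    # stage i has embed_dim * 2**i channels split across num_heads[i] heads
    return [embed_dim * 2 ** i // h for i, h in enumerate(num_heads)]


per_variant = {name: head_dims(**cfg) for name, cfg in configs.items()}
for name, dims in per_variant.items():
    print(name, dims)
```

Doubling the head count together with the channel count keeps each head at 32 dimensions, so the attention scale `head_dim ** -0.5` is the same in all stages.
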
2.3.4 Model parameters
model = UniNeXt_T()
paddle.summary(model, (1, 3, 224, 224))

model = UniNeXt_S()
paddle.summary(model, (1, 3, 224, 224))

model = UniNeXt_B()
paddle.summary(model, (1, 3, 224, 224))

2.4 Training

learning_rate = 0.0001
n_epochs = 100
paddle.seed(42)
np.random.seed(42)
work_path = 'work/model'

# UniNeXt-T
model = UniNeXt_T()

criterion = LabelSmoothingCrossEntropy()

scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=learning_rate, T_max=50000 // batch_size * n_epochs, verbose=False)
optimizer = paddle.optimizer.Adam(parameters=model.parameters(), learning_rate=scheduler, weight_decay=1e-5)

gate = 0.0
threshold = 0.0
best_acc = 0.0
val_acc = 0.0
loss_record = {'train': {'loss': [], 'iter': []}, 'val': {'loss': [], 'iter': []}}   # for recording loss
acc_record = {'train': {'acc': [], 'iter': []}, 'val': {'acc': [], 'iter': []}}      # for recording accuracy

loss_iter = 0
acc_iter = 0

for epoch in range(n_epochs):
    # ---------- Training ----------
    model.train()
    train_num = 0.0
    train_loss = 0.0

    val_num = 0.0
    val_loss = 0.0
    accuracy_manager = paddle.metric.Accuracy()
    val_accuracy_manager = paddle.metric.Accuracy()
    print("#===epoch: {}, lr={:.10f}===#".format(epoch, optimizer.get_lr()))
    for batch_id, data in enumerate(train_loader):
        x_data, y_data = data
        labels = paddle.unsqueeze(y_data, axis=1)

        logits = model(x_data)

        loss = criterion(logits, y_data)

        acc = paddle.metric.accuracy(logits, labels)
        accuracy_manager.update(acc)
        if batch_id % 10 == 0:
            loss_record['train']['loss'].append(loss.numpy())
            loss_record['train']['iter'].append(loss_iter)
            loss_iter += 1

        loss.backward()

        optimizer.step()
        scheduler.step()
        optimizer.clear_grad()

        train_loss += loss
        train_num += len(y_data)

    total_train_loss = (train_loss / train_num) * batch_size
    train_acc = accuracy_manager.accumulate()
    acc_record['train']['acc'].append(train_acc)
    acc_record['train']['iter'].append(acc_iter)
    acc_iter += 1
    # Print the information.
    print("#===epoch: {}, train loss is: {}, train acc is: {:2.2f}%===#".format(epoch, total_train_loss.numpy(), train_acc*100))

    # ---------- Validation ----------
    model.eval()

    for batch_id, data in enumerate(val_loader):

        x_data, y_data = data
        labels = paddle.unsqueeze(y_data, axis=1)
        with paddle.no_grad():
            logits = model(x_data)

        loss = criterion(logits, y_data)

        acc = paddle.metric.accuracy(logits, labels)
        val_accuracy_manager.update(acc)

        val_loss += loss
        val_num += len(y_data)

    total_val_loss = (val_loss / val_num) * batch_size
    loss_record['val']['loss'].append(total_val_loss.numpy())
    loss_record['val']['iter'].append(loss_iter)
    val_acc = val_accuracy_manager.accumulate()
    acc_record['val']['acc'].append(val_acc)
    acc_record['val']['iter'].append(acc_iter)

    print("#===epoch: {}, val loss is: {}, val acc is: {:2.2f}%===#".format(epoch, total_val_loss.numpy(), val_acc*100))

    # ===================save====================
    if val_acc > best_acc:
        best_acc = val_acc
        paddle.save(model.state_dict(), os.path.join(work_path, 'best_model.pdparams'))
        paddle.save(optimizer.state_dict(), os.path.join(work_path, 'best_optimizer.pdopt'))

print(best_acc)
paddle.save(model.state_dict(), os.path.join(work_path, 'final_model.pdparams'))
paddle.save(optimizer.state_dict(), os.path.join(work_path, 'final_optimizer.pdopt'))
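
Because the scheduler is stepped once per batch, `T_max` counts iterations rather than epochs. The sketch below reproduces that arithmetic and evaluates the standard closed-form cosine annealing curve at a few points. `cosine_annealing_lr` is a hypothetical helper, not Paddle's API; it assumes the closed form matches the scheduler for steps up to `T_max` with a minimum learning rate of 0.

```python
import math


def cosine_annealing_lr(step, base_lr, t_max, eta_min=0.0):
    """Closed-form cosine annealing for 0 <= step <= t_max."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + math.cos(math.pi * step / t_max))


batch_size, n_epochs, base_lr = 128, 100, 1e-4
steps_per_epoch = 50000 // batch_size      # full batches per epoch with drop_last=True
t_max = steps_per_epoch * n_epochs         # total number of scheduler steps

for step in (0, t_max // 2, t_max):
    print(step, cosine_annealing_lr(step, base_lr, t_max))
```

The learning rate starts at the base value, reaches half of it midway through training, and decays to the minimum at the final iteration.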

2.5 Result analysis

def plot_learning_curve(record, title='loss', ylabel='CE Loss'):
    ''' Plot learning curve of your CNN '''
    maxtrain = max(map(float, record['train'][title]))
    maxval = max(map(float, record['val'][title]))
    ymax = max(maxtrain, maxval) * 1.1
    mintrain = min(map(float, record['train'][title]))
    minval = min(map(float, record['val'][title]))
    ymin = min(mintrain, minval) * 0.9

    total_steps = len(record['train'][title])
    x_1 = list(map(int, record['train']['iter']))
    x_2 = list(map(int, record['val']['iter']))
    figure(figsize=(10, 6))
    plt.plot(x_1, record['train'][title], c='tab:red', label='train')
    plt.plot(x_2, record['val'][title], c='tab:cyan', label='val')
    plt.ylim(ymin, ymax)
    plt.xlabel('Training steps')
    plt.ylabel(ylabel)
    plt.title('Learning curve of {}'.format(title))
    plt.legend()
    plt.show()
plot_learning_curve(loss_record, title='loss', ylabel='CE Loss')


plot_learning_curve(acc_record, title='acc', ylabel='Accuracy')


import time
work_path = 'work/model'
model = UniNeXt_T()
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
model.eval()
aa = time.time()
for batch_id, data in enumerate(val_loader):

    x_data, y_data = data
    labels = paddle.unsqueeze(y_data, axis=1)
    with paddle.no_grad():
        logits = model(x_data)
bb = time.time()
print("Throughput:{}".format(int(len(val_dataset)//(bb - aa))))
Throughput:403
def get_cifar10_labels(labels):
    """Return the text labels of the CIFAR-10 dataset."""
    text_labels = [
        'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog',
        'horse', 'ship', 'truck']
    return [text_labels[int(i)] for i in labels]
def show_images(imgs, num_rows, num_cols, pred=None, gt=None, scale=1.5):
    """Plot a list of images."""
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if paddle.is_tensor(img):
            ax.imshow(img.numpy())
        else:
            ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if pred or gt:
            ax.set_title("pt: " + pred[i] + "\ngt: " + gt[i])
    return axes
work_path = 'work/model'
X, y = next(iter(DataLoader(val_dataset, batch_size=18)))
model = UniNeXt_T()
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
model.eval()
logits = model(X)
y_pred = paddle.argmax(logits, -1)
X = paddle.transpose(X, [0, 2, 3, 1])
axes = show_images(X.reshape((18, 224, 224, 3)), 1, 18, pred=get_cifar10_labels(y_pred), gt=get_cifar10_labels(y))
plt.show()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


!pip install interpretdl
import interpretdl as it
work_path = 'work/model'
model = UniNeXt_T()
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
X, y = next(iter(DataLoader(val_dataset, batch_size=18)))
lime = it.LIMECVInterpreter(model)
lime_weights = lime.interpret(X.numpy()[3], interpret_class=y.numpy()[3], batch_size=10, num_samples=1000, visual=True)
100%|██████████| 1000/1000 [00:11<00:00, 90.72it/s]



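Summary

        This paper proposes a unified architecture named UniNeXt, which strengthens the inductive bias of the network by combining the spatial token mixer with additional components: a parallel EC branch in the STM, a PC module after the FFN, and a 3×3 depthwise separable convolution (HdC) inside the FFN. Through extensive experiments, the authors verify the effectiveness of the framework and call on researchers to pay attention to macro-architecture design rather than only to elaborately designed modules (ad hoc tweaks?).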

References

  1. UniNeXt: Exploring A Unified Architecture for Vision Recognition
  2. jianlong-yuan/UniNeXt
  3. Alibaba open-sources a new-generation general neural network architecture: UniNeXt

This article is a repost.
Original project link
