[AI Expert Training Camp] Reproducing the RiR Paper


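Abstract

Residual networks (ResNets) have achieved state-of-the-art results on computer vision tasks. We propose Resnet in Resnet (RiR), a deep dual-stream architecture that generalizes ResNets and standard CNNs and is easy to implement, with no additional computational overhead. RiR further improves on ResNets (on the same CIFAR-10 dataset, with the same data augmentation as ResNets) and sets a new state of the art on CIFAR-100.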

1. RiR

The paper proposes a generalized residual network architecture that generalizes both ResNets and standard CNNs. Its modular unit is the generalized residual block, a parallel structure made of a residual stream $\mathrm{r}$ and a transient stream $\mathrm{t}$. The residual stream uses an identity shortcut connection like the one in ResNet, while the transient stream uses standard convolutional layers. In addition, two sets of filters, $W_{l, \mathrm{r} \rightarrow \mathrm{t}}$ and $W_{l, \mathrm{t} \rightarrow \mathrm{r}}$, apply cross convolutions between the two streams:

$$
\begin{array}{c}
\mathrm{r}_{l+1}=\sigma\left(\operatorname{conv}\left(\mathrm{r}_{l}, W_{l, \mathrm{r} \rightarrow \mathrm{r}}\right)+\operatorname{conv}\left(\mathrm{t}_{l}, W_{l, \mathrm{t} \rightarrow \mathrm{r}}\right)+\operatorname{shortcut}\left(\mathrm{r}_{l}\right)\right) \\
\mathrm{t}_{l+1}=\sigma\left(\operatorname{conv}\left(\mathrm{r}_{l}, W_{l, \mathrm{r} \rightarrow \mathrm{t}}\right)+\operatorname{conv}\left(\mathrm{t}_{l}, W_{l, \mathrm{t} \rightarrow \mathrm{t}}\right)\right)
\end{array}
$$

The residual stream preserves the optimization properties of residual units, while the transient stream allows features extracted by earlier layers to be discarded. The framework of the generalized residual block is shown below.

[Figure 1: the generalized residual block and the architectures built from it]

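If the weights of the residual stream are zero, the generalized residual block is equivalent to a standard convolutional layer; if the weights of the transient stream are zero, it is equivalent to a standard residual block. By stacking generalized residual blocks, the network can learn the various structures that Figure 1b makes possible (for example, Figure 1c). The generalized residual block therefore has greater information-processing capacity, and it can be used not only in CNNs but also in other types of networks. Replacing each conv in the original residual block with a generalized residual block (Figure 1b) yields a new architecture, ResNet in ResNet (RiR, Figure 1d). Figure 2 summarizes the relationship between the CNN, ResNet Init, ResNet, and RiR architectures.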
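The two special cases described above can be checked on a toy version of the block. The sketch below is only an illustration under a strong simplifying assumption: each convolution is replaced by multiplication with a scalar weight, and the names `generalized_block`, `w_rr`, `w_tr`, `w_rt`, and `w_tt` are hypothetical, not taken from the paper or the code in this post.

```python
# Toy version of the generalized residual block on scalar "feature maps".
# Assumption: every conv is replaced by multiplication with a scalar weight,
# which is enough to show how zeroing weights recovers the special cases.

def relu(x):
    return max(0.0, x)

def generalized_block(r, t, w_rr, w_tr, w_rt, w_tt):
    """One dual-stream update; returns (r_next, t_next)."""
    r_next = relu(w_rr * r + w_tr * t + r)  # residual stream keeps the identity shortcut
    t_next = relu(w_rt * r + w_tt * t)      # transient stream has no shortcut
    return r_next, t_next

# Case 1: all weights touching the transient stream are zero.
# The block behaves like a single-layer ResNet block on r.
resnet_like = generalized_block(r=1.0, t=2.0, w_rr=0.5, w_tr=0.0, w_rt=0.0, w_tt=0.0)

# Case 2: the residual stream is zeroed (input and weights).
# The block behaves like a plain layer on t.
cnn_like = generalized_block(r=0.0, t=2.0, w_rr=0.0, w_tr=0.0, w_rt=0.0, w_tt=0.5)

print(resnet_like)
print(cnn_like)
```

In the first case only the residual stream carries a signal (`relu(0.5 * 1.0 + 1.0)`), and in the second only the transient stream does (`relu(0.5 * 2.0)`).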

[Figure 2: relationship between the CNN, ResNet Init, ResNet, and RiR architectures]

2. Code reproduction

2.1 Install and import the required packages

!pip install paddlex
%matplotlib inline
import paddle
import paddle.fluid as fluid
import numpy as np
import matplotlib.pyplot as plt
from paddle.vision.datasets import Cifar10
from paddle.vision.transforms import Transpose
from paddle.io import Dataset, DataLoader
from paddle import nn
import paddle.nn.functional as F
import paddle.vision.transforms as transforms
import os
from matplotlib.pyplot import figure
import paddlex

2.2 Create the dataset

train_tfm = transforms.Compose([
    transforms.Resize((40, 40)),
    transforms.ColorJitter(brightness=0.2,contrast=0.2, saturation=0.2),
    transforms.RandomResizedCrop(32, scale=(0.6, 1.0)),
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomRotation(20),
    paddlex.transforms.MixupImage(),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

test_tfm = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
paddle.vision.set_image_backend('cv2')
# Use the CIFAR-10 dataset
train_dataset = Cifar10(data_file='data/data152754/cifar-10-python.tar.gz', mode='train', transform = train_tfm, )
val_dataset = Cifar10(data_file='data/data152754/cifar-10-python.tar.gz', mode='test',transform = test_tfm)
print("train_dataset: %d" % len(train_dataset))
print("val_dataset: %d" % len(val_dataset))
train_dataset: 50000
val_dataset: 10000
batch_size=128
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=2)

2.3 Label smoothing

class LabelSmoothingCrossEntropy(nn.Layer):
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, pred, target):

        confidence = 1. - self.smoothing
        log_probs = F.log_softmax(pred, axis=-1)
        idx = paddle.stack([paddle.arange(log_probs.shape[0]), target], axis=1)
        nll_loss = paddle.gather_nd(-log_probs, index=idx)
        smooth_loss = paddle.mean(-log_probs, axis=-1)
        loss = confidence * nll_loss + self.smoothing * smooth_loss

        return loss.mean()
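To make the loss above easier to follow, here is a dependency-free sketch of the same formula for a single sample. It mirrors the structure of `LabelSmoothingCrossEntropy` (a weighted sum of the target's negative log-likelihood and the mean negative log-probability), but the function names and the example logits are illustrative assumptions.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_z for v in logits]

def label_smoothing_ce(logits, target, smoothing=0.1):
    """Label-smoothed cross-entropy for one sample."""
    log_probs = log_softmax(logits)
    nll = -log_probs[target]                   # usual cross-entropy term
    smooth = -sum(log_probs) / len(log_probs)  # mean over all classes
    return (1.0 - smoothing) * nll + smoothing * smooth

logits = [2.0, 0.5, -1.0]  # hypothetical logits for a 3-class problem
plain_ce = label_smoothing_ce(logits, target=0, smoothing=0.0)
smoothed = label_smoothing_ce(logits, target=0, smoothing=0.1)
print(plain_ce, smoothed)
```

With `smoothing=0.0` the function reduces to ordinary cross-entropy. When the target is already the most likely class, smoothing increases the loss, which discourages over-confident predictions.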

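2.4 Building the RiR model

The experiments here use RiR-18, whose network structure is shown in the figure below:
[Figure: RiR-18 network structure]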

class RiR_Init(nn.Layer):
    def __init__(self, in_channel, out_channel=None, stride=1):
        super().__init__()
        self.in_channel = in_channel
        # Resolve the default first so the layers below never receive None.
        out_channel = out_channel if out_channel is not None else in_channel
        self.out_channel = out_channel
        self.stride = stride
        self.conv_res1 = nn.Conv2D(in_channel, out_channel, 3, stride, 1)
        self.conv_res2 = nn.Conv2D(in_channel, out_channel, 3, stride, 1)
        self.conv1 = nn.Conv2D(in_channel, out_channel, 3, stride, 1)
        self.conv2 = nn.Conv2D(in_channel, out_channel, 3, stride, 1)
        self.bnres = nn.BatchNorm2D(out_channel)
        self.relures = nn.ReLU()
        self.bn = nn.BatchNorm2D(out_channel)
        self.relu = nn.ReLU()
        self.resize_indentity = (in_channel != out_channel) or (stride != 1)
        self.indentity_connection = nn.Conv2D(in_channel, out_channel, 1, stride) if self.resize_indentity else nn.Identity()

    def forward(self, x_res, x_tran):
        x_shortcut = self.indentity_connection(x_res)

        x_res1 = self.conv_res1(x_res)
        x_res2 = self.conv_res2(x_res)

        x1 = self.conv1(x_tran)
        x2 = self.conv2(x_tran)

        out_res = x_res1 + x1 + x_shortcut
        out_tran = x_res2 + x2

        out_res = self.relures(self.bnres(out_res))
        out_tran = self.relu(self.bn(out_tran))

        return out_res, out_tran
class RiRBlock(nn.Layer):
    def __init__(self, in_channel, out_channel, stride=1):
        super().__init__()
        self.rir_init1 = RiR_Init(in_channel, out_channel, stride)
        self.rir_init2 = RiR_Init(out_channel ,out_channel, 1)
        self.resize_indentity = (in_channel != out_channel) or (stride != 1)
        self.indentity_connection1 = nn.Conv2D(in_channel, out_channel, 1, stride) if self.resize_indentity else nn.Identity()
        self.indentity_connection2 = nn.Conv2D(in_channel, out_channel, 1, stride) if self.resize_indentity else nn.Identity()

    def forward(self, x_res, x_tran):
        # Each stream uses its own shortcut projection.
        x_shortcut1 = self.indentity_connection1(x_res)
        x_shortcut2 = self.indentity_connection2(x_tran)
        out_res, out_tran = self.rir_init1(x_res, x_tran)
        out_res, out_tran = self.rir_init2(out_res, out_tran)

        out_res = x_shortcut1 + out_res
        out_tran = x_shortcut2 + out_tran

        return out_res, out_tran
        
class RiRInitBlock(nn.Layer):
    def __init__(self, in_channel, out_channel, stride=1):
        super().__init__()
        self.conv_res = nn.Conv2D(in_channel, out_channel, 3, stride, 1)
        self.conv = nn.Conv2D(in_channel, out_channel, 3, stride, 1)
        self.bnres = nn.BatchNorm2D(out_channel)
        self.relures = nn.ReLU()
        self.bn = nn.BatchNorm2D(out_channel)
        self.relu = nn.ReLU()
    
    def forward(self, x):
        x_res = self.conv_res(x)
        x_tran = self.conv(x)

        x_res = self.relures(self.bnres(x_res))
        x_tran = self.relu(self.bn(x_tran))

        return x_res, x_tran
class RiRFinalBlock(nn.Layer):
    def __init__(self):
        super().__init__()
    
    def forward(self, x_res, x_tran):
        return paddle.concat([x_res, x_tran], axis=1)
class RiR(nn.Layer):
    def __init__(self, channels, blocks, in_channels=3, in_size=(32, 32), num_classes=10):
        super().__init__()

        assert len(channels) == len(blocks), 'the length of channels is not the same as the length of blocks'

        self.init = RiRInitBlock(in_channels, channels[0], 1)

        self.stage = nn.LayerList()

        for i in range(len(blocks)):
            if i == 0:
                for j in range(blocks[i]):
                    self.stage.append(RiRBlock(channels[i], channels[i]))
            else:
                for j in range(blocks[i]):
                    self.stage.append(RiRBlock(channels[i-1] if j==0 else channels[i], channels[i], stride = 2 if j==0 else 1))

        self.final = RiRFinalBlock()

        self.classifier = nn.Sequential(nn.Conv2D(channels[-1] * 2, num_classes, 1),
            nn.AdaptiveAvgPool2D(1), nn.Flatten(1))
        
        self.apply(self._init_weights)

    def _init_weights(self, m):
        zeros_ = nn.initializer.Constant(value=0.)
        ones_ = nn.initializer.Constant(value=1.)
        if isinstance(m, (nn.Linear, nn.Conv2D)):
            # Instantiate the initializer, then apply it to the weight in place.
            nn.initializer.KaimingNormal()(m.weight)
            if m.bias is not None:
                zeros_(m.bias)
        elif isinstance(m, (nn.BatchNorm2D)):
            zeros_(m.bias)
            ones_(m.weight)

    def forward(self, x):
        
        x_res, x_tran = self.init(x)

        for i in range(len(self.stage)):
            x_res, x_tran = self.stage[i](x_res, x_tran)
        
        out = self.final(x_res, x_tran)

        out = self.classifier(out)

        return out
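As a sanity check on the configuration `RiR([48, 96, 192], [2, 3, 3])` used in the following sections, the sketch below traces the spatial size of each stage and counts weighted layers. It is plain arithmetic and does not import Paddle. The helper names are hypothetical, and the depth count (initial conv, two generalized layers per `RiRBlock`, and the 1x1 classifier conv) is one plausible reading of the "18" in RiR-18, not something the source states.

```python
def conv_out_size(size, kernel=3, stride=1, padding=1):
    """Output size of a conv layer along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1

def trace_rir(channels, blocks, in_size=32):
    size = conv_out_size(in_size)  # RiRInitBlock: 3x3 conv, stride 1, padding 1
    stages = []
    for i, (c, n) in enumerate(zip(channels, blocks)):
        if i > 0:
            # The first block of every later stage uses stride 2.
            size = conv_out_size(size, stride=2)
        stages.append((c, n, size))
    # Assumed counting: init conv + 2 layers per block + classifier conv.
    depth = 1 + 2 * sum(blocks) + 1
    # RiRFinalBlock concatenates both streams along the channel axis.
    concat_channels = 2 * channels[-1]
    return stages, depth, concat_channels

stages, depth, concat_channels = trace_rir([48, 96, 192], [2, 3, 3])
print(stages)
print(depth, concat_channels)
```

For a 32x32 input this gives feature maps of 32, 16, and 8 pixels per side across the three stages, 384 channels entering the classifier, and a depth of 18 under the stated counting assumption.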

2.5 Model parameters and FLOPs

model = RiR([48, 96, 192], [2, 3, 3])
paddle.summary(model, (batch_size, 3, 32, 32))


model = RiR([48, 96, 192], [2, 3, 3])
paddle.flops(model, (batch_size, 3, 32, 32))
<class 'paddle.nn.layer.conv.Conv2D'>'s flops has been counted
<class 'paddle.nn.layer.norm.BatchNorm2D'>'s flops has been counted
<class 'paddle.nn.layer.activation.ReLU'>'s flops has been counted
Cannot find suitable count function for <class 'paddle.nn.layer.common.Identity'>. Treat it as zero FLOPs.
Cannot find suitable count function for <class '__main__.RiRFinalBlock'>. Treat it as zero FLOPs.
<class 'paddle.nn.layer.pooling.AdaptiveAvgPool2D'>'s flops has been counted
Cannot find suitable count function for <class 'paddle.fluid.dygraph.nn.Flatten'>. Treat it as zero FLOPs.
Total Flops: 164831593728     Total Params: 9532234
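Note that the total above was measured with an input of shape `(batch_size, 3, 32, 32)`, that is, a batch of 128 images. Assuming the count scales linearly with the batch dimension (which the exact divisibility below is consistent with), the per-image cost can be recovered as follows:

```python
batch_size = 128
total_flops = 164831593728  # "Total Flops" reported for a (128, 3, 32, 32) input
total_params = 9532234      # "Total Params" reported by the same call

# Divide by the batch size to get the cost of a single 32x32 image.
flops_per_image = total_flops // batch_size

print(flops_per_image)
print(round(flops_per_image / 1e9, 2), "GFLOPs per image")
print(round(total_params / 1e6, 2), "M parameters")
```

This works out to roughly 1.29 GFLOPs per image and 9.53 M parameters.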






2.6 Training

learning_rate = 0.001
n_epochs = 50
paddle.seed(42)
np.random.seed(42)
work_path = 'work/model'

# RiR-18
model = RiR([48, 96, 192], [2, 3, 3])

criterion = LabelSmoothingCrossEntropy()

scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=learning_rate, T_max=50000 // batch_size * n_epochs, verbose=False)
optimizer = paddle.optimizer.Adam(parameters=model.parameters(), learning_rate=scheduler, weight_decay=1e-5)

gate = 0.0
threshold = 0.0
best_acc = 0.0
val_acc = 0.0
loss_record = {'train': {'loss': [], 'iter': []}, 'val': {'loss': [], 'iter': []}}   # for recording loss
acc_record = {'train': {'acc': [], 'iter': []}, 'val': {'acc': [], 'iter': []}}      # for recording accuracy

loss_iter = 0
acc_iter = 0

for epoch in range(n_epochs):
    # ---------- Training ----------
    model.train()
    train_num = 0.0
    train_loss = 0.0

    val_num = 0.0
    val_loss = 0.0
    accuracy_manager = paddle.metric.Accuracy()
    val_accuracy_manager = paddle.metric.Accuracy()
    print("#===epoch: {}, lr={:.10f}===#".format(epoch, optimizer.get_lr()))
    for batch_id, data in enumerate(train_loader):
        x_data, y_data = data
        labels = paddle.unsqueeze(y_data, axis=1)

        logits = model(x_data)

        loss = criterion(logits, y_data)

        acc = paddle.metric.accuracy(logits, labels)
        accuracy_manager.update(acc)
        if batch_id % 10 == 0:
            loss_record['train']['loss'].append(loss.numpy())
            loss_record['train']['iter'].append(loss_iter)
            loss_iter += 1

        loss.backward()

        optimizer.step()
        scheduler.step()
        optimizer.clear_grad()
        
        train_loss += loss
        train_num += len(y_data)

    total_train_loss = (train_loss / train_num) * batch_size
    train_acc = accuracy_manager.accumulate()
    acc_record['train']['acc'].append(train_acc)
    acc_record['train']['iter'].append(acc_iter)
    acc_iter += 1
    # Print the information.
    print("#===epoch: {}, train loss is: {}, train acc is: {:2.2f}%===#".format(epoch, total_train_loss.numpy(), train_acc*100))

    # ---------- Validation ----------
    model.eval()

    for batch_id, data in enumerate(val_loader):

        x_data, y_data = data
        labels = paddle.unsqueeze(y_data, axis=1)
        with paddle.no_grad():
            logits = model(x_data)

        loss = criterion(logits, y_data)

        acc = paddle.metric.accuracy(logits, labels)
        val_accuracy_manager.update(acc)

        val_loss += loss
        val_num += len(y_data)

    total_val_loss = (val_loss / val_num) * batch_size
    loss_record['val']['loss'].append(total_val_loss.numpy())
    loss_record['val']['iter'].append(loss_iter)
    val_acc = val_accuracy_manager.accumulate()
    acc_record['val']['acc'].append(val_acc)
    acc_record['val']['iter'].append(acc_iter)
    
    print("#===epoch: {}, val loss is: {}, val acc is: {:2.2f}%===#".format(epoch, total_val_loss.numpy(), val_acc*100))

    # ===================save====================
    if val_acc > best_acc:
        best_acc = val_acc
        paddle.save(model.state_dict(), os.path.join(work_path, 'best_model.pdparams'))
        paddle.save(optimizer.state_dict(), os.path.join(work_path, 'best_optimizer.pdopt'))

print(best_acc)
paddle.save(model.state_dict(), os.path.join(work_path, 'final_model.pdparams'))
paddle.save(optimizer.state_dict(), os.path.join(work_path, 'final_optimizer.pdopt'))
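The training loop steps the scheduler once per batch, so `T_max` is expressed in iterations rather than epochs. The sketch below evaluates the standard closed-form cosine annealing curve for the same hyperparameters. It assumes a minimum learning rate of 0 and is meant to illustrate the shape of the schedule, not to reproduce Paddle's internal implementation; `cosine_lr` is a hypothetical helper.

```python
import math

def cosine_lr(step, base_lr, t_max, eta_min=0.0):
    """Closed-form cosine annealing from base_lr down to eta_min."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max)) / 2

batch_size, n_epochs, base_lr = 128, 50, 0.001
steps_per_epoch = 50000 // batch_size  # matches drop_last=True in the train loader
t_max = steps_per_epoch * n_epochs

for step in (0, t_max // 2, t_max):
    print(step, cosine_lr(step, base_lr, t_max))
```

With 390 iterations per epoch, `t_max` is 19500: the learning rate starts at 0.001, passes through 0.0005 halfway, and reaches approximately 0 at the end of epoch 50.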


2.7 Experimental results

def plot_learning_curve(record, title='loss', ylabel='CE Loss'):
    ''' Plot learning curve of your CNN '''
    maxtrain = max(map(float, record['train'][title]))
    maxval = max(map(float, record['val'][title]))
    ymax = max(maxtrain, maxval) * 1.1
    mintrain = min(map(float, record['train'][title]))
    minval = min(map(float, record['val'][title]))
    ymin = min(mintrain, minval) * 0.9

    total_steps = len(record['train'][title])
    x_1 = list(map(int, record['train']['iter']))
    x_2 = list(map(int, record['val']['iter']))
    figure(figsize=(10, 6))
    plt.plot(x_1, record['train'][title], c='tab:red', label='train')
    plt.plot(x_2, record['val'][title], c='tab:cyan', label='val')
    plt.ylim(ymin, ymax)
    plt.xlabel('Training steps')
    plt.ylabel(ylabel)
    plt.title('Learning curve of {}'.format(title))
    plt.legend()
    plt.show()
plot_learning_curve(loss_record, title='loss', ylabel='CE Loss')

[Figure: learning curve of loss (train and val)]

plot_learning_curve(acc_record, title='acc', ylabel='Accuracy')

[Figure: learning curve of acc (train and val)]

import time
work_path = 'work/model'
model = RiR([48, 96, 192], [2, 3, 3])
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
model.eval()
aa = time.time()
for batch_id, data in enumerate(val_loader):

    x_data, y_data = data
    labels = paddle.unsqueeze(y_data, axis=1)
    with paddle.no_grad():
        logits = model(x_data)
bb = time.time()
print("Throughput:{}".format(int(len(val_dataset)//(bb - aa))))
Throughput:1462
def get_cifar10_labels(labels):  
    """Return the text labels of the CIFAR-10 dataset."""
    text_labels = [
        'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog',
        'horse', 'ship', 'truck']
    return [text_labels[int(i)] for i in labels]
def show_images(imgs, num_rows, num_cols, pred=None, gt=None, scale=1.5):  
    """Plot a list of images."""
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if paddle.is_tensor(img):
            ax.imshow(img.numpy())
        else:
            ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if pred or gt:
            ax.set_title("pt: " + pred[i] + "\ngt: " + gt[i])
    return axes
work_path = 'work/model'
X, y = next(iter(DataLoader(val_dataset, batch_size=18)))
model = RiR([48, 96, 192], [2, 3, 3])
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
model.eval()
logits = model(X)
y_pred = paddle.argmax(logits, -1)
X = paddle.transpose(X, [0, 2, 3, 1])
axes = show_images(X.reshape((18, 32, 32, 3)), 1, 18, pred=get_cifar10_labels(y_pred), gt=get_cifar10_labels(y))
plt.show()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
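The clipping warning appears because the images passed to `imshow` are still normalized, so many values fall outside [0, 1]. One possible remedy, sketched here on a single pixel without Paddle, is to undo the normalization before plotting. `denormalize_pixel` is a hypothetical helper, and the mean and std values are the ones used in `test_tfm`.

```python
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)

def denormalize_pixel(pixel):
    """Invert Normalize(mean, std) for one RGB pixel and clamp to [0, 1]."""
    return tuple(min(1.0, max(0.0, v * s + m)) for v, s, m in zip(pixel, STD, MEAN))

raw = (0.5, 0.25, 1.0)  # hypothetical pixel after ToTensor
normalized = tuple((v - m) / s for v, m, s in zip(raw, MEAN, STD))
restored = denormalize_pixel(normalized)
print(normalized)
print(restored)
```

Applied per channel to a whole batch, the same operation (`x * std + mean`) would bring the displayed images back into the valid range.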

[Figure: 18 validation images with predicted (pt) and ground-truth (gt) labels]

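Comparison experiment

The test results for ResNet-18 (built as described in the RiR paper) are in main-copy1.ipynb.

| Model | Acc | Parameters | FLOPs |
| --- | --- | --- | --- |
| RiR-18 | 0.92514 | 9,532,234 | 164831593728 |
| ResNet-18 | 0.91980 | 9,574,474 | 165021910272 |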

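Summary

The paper proposes a generalized residual architecture that can be implemented with a simple modification to the original scheme (ResNet Init). Applying ResNet Init to the original ResNet yields the RiR architecture, which achieves very good results.
The comparison experiment shows that RiR outperforms ResNet with fewer parameters and FLOPs, improving accuracy by about 0.5 percentage points (0.92514 vs. 0.91980, a relative gain of roughly 0.6%).
Future work: improve RiR further and compare performance on larger datasets.

Open-source link: https://aistudio.baidu.com/aistudio/projectdetail/4294062?shared=1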
