A2-Nets: Double Attention Networks

  • The paper's title captures its core idea well: first, second-order attention pooling gathers the key features of the entire image into a compact set; then a second attention mechanism adaptively distributes those features back to every location of the image.

paper:https://arxiv.org/pdf/1810.11579.pdf

github:https://github.com/pijiande/A2Net-DoubleAttentionlayer

Introduction

  • This time we reproduce a NeurIPS 2018 paper. It is a simple one: a plug-and-play module that reliably adds accuracy points to existing backbones.

  • The main contribution is a new attention mechanism that can be seen as an evolved version of SE; the paper benchmarks it across several CV tasks.

Implementation

  • The core idea of A2-Net is to first gather the key features of the entire space into a compact set and then adaptively distribute them to every location, so that subsequent convolution layers can perceive the whole space even without a large receptive field.
    The first attention step selectively gathers key features from the whole space, while the second step employs another attention mechanism to adaptively distribute a subset of the key features that complements each spatio-temporal location of the high-level task.

  • A2-Net is related to SENet, covariance pooling, Non-local networks, and Transformers, but differs in two ways: its first attention step implicitly computes second-order statistics of the pooled features, capturing complex appearance and motion correlations that the global average pooling used in SENet cannot; its second attention step adaptively distributes features from a compact bag, which is more efficient than exhaustively relating the features at all positions to each specific position as Non-local and Transformers do. A minimal shape walk-through of the two steps is sketched below.
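The math of the two steps reduces to a pair of batched matrix products. Below is a minimal sketch (illustrative tensor names, not the authors' code); note that, following the paper, the distribution weights V are normalized over the descriptor dimension, whereas the implementation below normalizes over locations:

import paddle
import paddle.nn.functional as F

b, c, n = 2, 8, 16 * 16                         # batch, channels, n = h*w locations
A = paddle.rand([b, c, n])                      # features to gather
B = F.softmax(paddle.rand([b, c, n]), axis=-1)  # gather attention, normalized over locations
V = F.softmax(paddle.rand([b, c, n]), axis=1)   # distribute attention, normalized over descriptors

G = paddle.bmm(B, A.transpose([0, 2, 1]))       # step 1 (gather): (b, c, c) global descriptors
Z = paddle.bmm(G, V)                            # step 2 (distribute): (b, c, n) refined features
print(G.shape, Z.shape)                         # [2, 8, 8] [2, 8, 256]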

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

class DoubleAtten(nn.Layer):
    """
    A2-Nets: Double Attention Networks. NeurIPS 2018
    """
    def __init__(self, in_c):
        """
        :param in_c: number of channels of the feature map being refined;
        the channel reduction/expansion used in the paper is omitted here.
        """
        super(DoubleAtten, self).__init__()
        self.in_c = in_c
        # Three 1x1 convolutions over the same input produce three feature maps
        # of identical size: A, B and V in the paper's notation.
        self.convA = nn.Conv2D(in_c, in_c, kernel_size=1)
        self.convB = nn.Conv2D(in_c, in_c, kernel_size=1)
        self.convV = nn.Conv2D(in_c, in_c, kernel_size=1)

    def forward(self, input):
        feature_maps = self.convA(input)
        atten_map = self.convB(input)
        b, c, h, w = feature_maps.shape

        feature_maps = feature_maps.reshape([b, 1, self.in_c, h*w])  # reshape A
        atten_map = atten_map.reshape([b, self.in_c, 1, h*w])        # reshape B into attention maps
        # Gather: multiply features with softmax-normalized attention maps to form
        # the (b, in_c, in_c) global descriptors. (The paper uses a weighted sum;
        # mean differs from it only by a constant 1/(h*w) factor.)
        global_descriptors = paddle.mean((feature_maps * F.softmax(atten_map, axis=-1)), axis=-1)

        v = self.convV(input)
        # Distribute: attention vectors per location. (Here softmax normalizes over
        # locations; the paper normalizes over the descriptor dimension instead.)
        atten_vectors = F.softmax(v.reshape([b, self.in_c, h*w]), axis=-1)
        # Left-multiply the global descriptors by the attention vectors.
        out = paddle.bmm(atten_vectors.transpose([0, 2, 1]), global_descriptors).transpose([0, 2, 1])

        return out.reshape([b, c, h, w])


Verification

input size = 4,512,16,16 --> DA --> output size = 4,512,16,16

if __name__=="__main__":
    a = paddle.rand([4,512,16,16])
    model = DoubleAtten(512)
    a = model(a)
    print(a.shape)  # [4, 512, 16, 16]

Validating DA performance

In the paper the authors test on ResNet-26, but run no experiments on the deeper ResNet-50, so here we build a ResNet-50 to check performance. The DA module is inserted after the third convolution (and its batch norm) of every bottleneck block, as the code below shows.

Building DA_ResNet50

import paddle
import paddle.nn as nn
from paddle.utils.download import get_weights_path_from_url

class BasicBlock(nn.Layer):
    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 downsample=None,
                 groups=1,
                 base_width=64,
                 dilation=1,
                 norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2D

        if dilation > 1:
            raise NotImplementedError(
                "Dilation > 1 not supported in BasicBlock")

        self.conv1 = nn.Conv2D(
            inplanes, planes, 3, padding=1, stride=stride, bias_attr=False)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2D(planes, planes, 3, padding=1, bias_attr=False)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class BottleneckBlock(nn.Layer):

    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 downsample=None,
                 groups=1,
                 base_width=64,
                 dilation=1,
                 norm_layer=None):
        super(BottleneckBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2D
        width = int(planes * (base_width / 64.)) * groups

        self.conv1 = nn.Conv2D(inplanes, width, 1, bias_attr=False)
        self.bn1 = norm_layer(width)

        self.conv2 = nn.Conv2D(
            width,
            width,
            3,
            padding=dilation,
            stride=stride,
            groups=groups,
            dilation=dilation,
            bias_attr=False)
        self.bn2 = norm_layer(width)

        self.conv3 = nn.Conv2D(
            width, planes * self.expansion, 1, bias_attr=False)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU()
        self.downsample = downsample
        self.stride = stride
        self.attention = DoubleAtten(planes * self.expansion)



    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)
        out = self.attention(out)
        
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Layer):

    def __init__(self,
                 block,
                 depth=50,
                 width=64,
                 num_classes=1000,
                 with_pool=True):
        super(ResNet, self).__init__()
        layer_cfg = {
            18: [2, 2, 2, 2],
            34: [3, 4, 6, 3],
            50: [3, 4, 6, 3],
            101: [3, 4, 23, 3],
            152: [3, 8, 36, 3]
        }
        layers = layer_cfg[depth]
        self.groups = 1
        self.base_width = width
        self.num_classes = num_classes
        self.with_pool = with_pool
        self._norm_layer = nn.BatchNorm2D

        self.inplanes = 64
        self.dilation = 1

        self.conv1 = nn.Conv2D(
            3,
            self.inplanes,
            kernel_size=7,
            stride=2,
            padding=3,
            bias_attr=False)
        self.bn1 = self._norm_layer(self.inplanes)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        if with_pool:
            self.avgpool = nn.AdaptiveAvgPool2D((1, 1))

        if num_classes > 0:
            self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2D(
                    self.inplanes,
                    planes * block.expansion,
                    1,
                    stride=stride,
                    bias_attr=False),
                norm_layer(planes * block.expansion), )

        layers = []
        layers.append(
            block(self.inplanes, planes, stride, downsample, self.groups,
                  self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        if self.with_pool:
            x = self.avgpool(x)

        if self.num_classes > 0:
            x = paddle.flatten(x, 1)
            x = self.fc(x)

        return x


# model_urls is referenced below; no pretrained weights exist for this DA
# variant, so it stays empty and the model is always built with pretrained=False.
model_urls = {}


def _resnet(arch, Block, depth, pretrained, **kwargs):
    model = ResNet(Block, depth, **kwargs)
    if pretrained:
        assert arch in model_urls, "{} model does not have a pretrained model now, you should set pretrained=False".format(
            arch)
        weight_path = get_weights_path_from_url(model_urls[arch][0],
                                                model_urls[arch][1])

        param = paddle.load(weight_path)
        model.set_dict(param)

    return model


def resnet18(pretrained=False, **kwargs):

    return _resnet('resnet18', BasicBlock, 18, pretrained, **kwargs)


def resnet34(pretrained=False, **kwargs):

    return _resnet('resnet34', BasicBlock, 34, pretrained, **kwargs)


def resnet50(pretrained=False, **kwargs):

    return _resnet('resnet50', BottleneckBlock, 50, pretrained, **kwargs)


def resnet101(pretrained=False, **kwargs):

    return _resnet('resnet101', BottleneckBlock, 101, pretrained, **kwargs)


def resnet152(pretrained=False, **kwargs):

    return _resnet('resnet152', BottleneckBlock, 152, pretrained, **kwargs)


def wide_resnet50_2(pretrained=False, **kwargs):
    kwargs['width'] = 64 * 2
    return _resnet('wide_resnet50_2', BottleneckBlock, 50, pretrained, **kwargs)


def wide_resnet101_2(pretrained=False, **kwargs):

    kwargs['width'] = 64 * 2
    return _resnet('wide_resnet101_2', BottleneckBlock, 101, pretrained,
                   **kwargs)

da_res50 = resnet50(num_classes=10)
paddle.Model(da_res50).summary((1,3,224,224))
-------------------------------------------------------------------------------
   Layer (type)         Input Shape          Output Shape         Param #    
===============================================================================
    Conv2D-105       [[1, 3, 224, 224]]   [1, 64, 112, 112]        9,408     
  BatchNorm2D-54    [[1, 64, 112, 112]]   [1, 64, 112, 112]         256      
      ReLU-18       [[1, 64, 112, 112]]   [1, 64, 112, 112]          0       
    MaxPool2D-2     [[1, 64, 112, 112]]    [1, 64, 56, 56]           0       
    Conv2D-107       [[1, 64, 56, 56]]     [1, 64, 56, 56]         4,096     
  BatchNorm2D-56     [[1, 64, 56, 56]]     [1, 64, 56, 56]          256      
      ReLU-19        [[1, 256, 56, 56]]    [1, 256, 56, 56]          0       
    Conv2D-108       [[1, 64, 56, 56]]     [1, 64, 56, 56]        36,864     
  BatchNorm2D-57     [[1, 64, 56, 56]]     [1, 64, 56, 56]          256      
    Conv2D-109       [[1, 64, 56, 56]]     [1, 256, 56, 56]       16,384     
  BatchNorm2D-58     [[1, 256, 56, 56]]    [1, 256, 56, 56]        1,024     
    Conv2D-110       [[1, 256, 56, 56]]    [1, 256, 56, 56]       65,792     
    Conv2D-111       [[1, 256, 56, 56]]    [1, 256, 56, 56]       65,792     
    Conv2D-112       [[1, 256, 56, 56]]    [1, 256, 56, 56]       65,792     
  DoubleAtten-18     [[1, 256, 56, 56]]    [1, 256, 56, 56]          0       
    Conv2D-106       [[1, 64, 56, 56]]     [1, 256, 56, 56]       16,384     
  BatchNorm2D-55     [[1, 256, 56, 56]]    [1, 256, 56, 56]        1,024     
BottleneckBlock-17   [[1, 64, 56, 56]]     [1, 256, 56, 56]          0       
    Conv2D-113       [[1, 256, 56, 56]]    [1, 64, 56, 56]        16,384     
  BatchNorm2D-59     [[1, 64, 56, 56]]     [1, 64, 56, 56]          256      
      ReLU-20        [[1, 256, 56, 56]]    [1, 256, 56, 56]          0       
    Conv2D-114       [[1, 64, 56, 56]]     [1, 64, 56, 56]        36,864     
  BatchNorm2D-60     [[1, 64, 56, 56]]     [1, 64, 56, 56]          256      
    Conv2D-115       [[1, 64, 56, 56]]     [1, 256, 56, 56]       16,384     
  BatchNorm2D-61     [[1, 256, 56, 56]]    [1, 256, 56, 56]        1,024     
    Conv2D-116       [[1, 256, 56, 56]]    [1, 256, 56, 56]       65,792     
    Conv2D-117       [[1, 256, 56, 56]]    [1, 256, 56, 56]       65,792     
    Conv2D-118       [[1, 256, 56, 56]]    [1, 256, 56, 56]       65,792     
  DoubleAtten-19     [[1, 256, 56, 56]]    [1, 256, 56, 56]          0       
BottleneckBlock-18   [[1, 256, 56, 56]]    [1, 256, 56, 56]          0       
    Conv2D-119       [[1, 256, 56, 56]]    [1, 64, 56, 56]        16,384     
  BatchNorm2D-62     [[1, 64, 56, 56]]     [1, 64, 56, 56]          256      
      ReLU-21        [[1, 256, 56, 56]]    [1, 256, 56, 56]          0       
    Conv2D-120       [[1, 64, 56, 56]]     [1, 64, 56, 56]        36,864     
  BatchNorm2D-63     [[1, 64, 56, 56]]     [1, 64, 56, 56]          256      
    Conv2D-121       [[1, 64, 56, 56]]     [1, 256, 56, 56]       16,384     
  BatchNorm2D-64     [[1, 256, 56, 56]]    [1, 256, 56, 56]        1,024     
    Conv2D-122       [[1, 256, 56, 56]]    [1, 256, 56, 56]       65,792     
    Conv2D-123       [[1, 256, 56, 56]]    [1, 256, 56, 56]       65,792     
    Conv2D-124       [[1, 256, 56, 56]]    [1, 256, 56, 56]       65,792     
  DoubleAtten-20     [[1, 256, 56, 56]]    [1, 256, 56, 56]          0       
BottleneckBlock-19   [[1, 256, 56, 56]]    [1, 256, 56, 56]          0       
    Conv2D-126       [[1, 256, 56, 56]]    [1, 128, 56, 56]       32,768     
  BatchNorm2D-66     [[1, 128, 56, 56]]    [1, 128, 56, 56]         512      
      ReLU-22        [[1, 512, 28, 28]]    [1, 512, 28, 28]          0       
    Conv2D-127       [[1, 128, 56, 56]]    [1, 128, 28, 28]       147,456    
  BatchNorm2D-67     [[1, 128, 28, 28]]    [1, 128, 28, 28]         512      
    Conv2D-128       [[1, 128, 28, 28]]    [1, 512, 28, 28]       65,536     
  BatchNorm2D-68     [[1, 512, 28, 28]]    [1, 512, 28, 28]        2,048     
    Conv2D-129       [[1, 512, 28, 28]]    [1, 512, 28, 28]       262,656    
    Conv2D-130       [[1, 512, 28, 28]]    [1, 512, 28, 28]       262,656    
    Conv2D-131       [[1, 512, 28, 28]]    [1, 512, 28, 28]       262,656    
  DoubleAtten-21     [[1, 512, 28, 28]]    [1, 512, 28, 28]          0       
    Conv2D-125       [[1, 256, 56, 56]]    [1, 512, 28, 28]       131,072    
  BatchNorm2D-65     [[1, 512, 28, 28]]    [1, 512, 28, 28]        2,048     
BottleneckBlock-20   [[1, 256, 56, 56]]    [1, 512, 28, 28]          0       
    Conv2D-132       [[1, 512, 28, 28]]    [1, 128, 28, 28]       65,536     
  BatchNorm2D-69     [[1, 128, 28, 28]]    [1, 128, 28, 28]         512      
      ReLU-23        [[1, 512, 28, 28]]    [1, 512, 28, 28]          0       
    Conv2D-133       [[1, 128, 28, 28]]    [1, 128, 28, 28]       147,456    
  BatchNorm2D-70     [[1, 128, 28, 28]]    [1, 128, 28, 28]         512      
    Conv2D-134       [[1, 128, 28, 28]]    [1, 512, 28, 28]       65,536     
  BatchNorm2D-71     [[1, 512, 28, 28]]    [1, 512, 28, 28]        2,048     
    Conv2D-135       [[1, 512, 28, 28]]    [1, 512, 28, 28]       262,656    
    Conv2D-136       [[1, 512, 28, 28]]    [1, 512, 28, 28]       262,656    
    Conv2D-137       [[1, 512, 28, 28]]    [1, 512, 28, 28]       262,656    
  DoubleAtten-22     [[1, 512, 28, 28]]    [1, 512, 28, 28]          0       
BottleneckBlock-21   [[1, 512, 28, 28]]    [1, 512, 28, 28]          0       
    Conv2D-138       [[1, 512, 28, 28]]    [1, 128, 28, 28]       65,536     
  BatchNorm2D-72     [[1, 128, 28, 28]]    [1, 128, 28, 28]         512      
      ReLU-24        [[1, 512, 28, 28]]    [1, 512, 28, 28]          0       
    Conv2D-139       [[1, 128, 28, 28]]    [1, 128, 28, 28]       147,456    
  BatchNorm2D-73     [[1, 128, 28, 28]]    [1, 128, 28, 28]         512      
    Conv2D-140       [[1, 128, 28, 28]]    [1, 512, 28, 28]       65,536     
  BatchNorm2D-74     [[1, 512, 28, 28]]    [1, 512, 28, 28]        2,048     
    Conv2D-141       [[1, 512, 28, 28]]    [1, 512, 28, 28]       262,656    
    Conv2D-142       [[1, 512, 28, 28]]    [1, 512, 28, 28]       262,656    
    Conv2D-143       [[1, 512, 28, 28]]    [1, 512, 28, 28]       262,656    
  DoubleAtten-23     [[1, 512, 28, 28]]    [1, 512, 28, 28]          0       
BottleneckBlock-22   [[1, 512, 28, 28]]    [1, 512, 28, 28]          0       
    Conv2D-144       [[1, 512, 28, 28]]    [1, 128, 28, 28]       65,536     
  BatchNorm2D-75     [[1, 128, 28, 28]]    [1, 128, 28, 28]         512      
      ReLU-25        [[1, 512, 28, 28]]    [1, 512, 28, 28]          0       
    Conv2D-145       [[1, 128, 28, 28]]    [1, 128, 28, 28]       147,456    
  BatchNorm2D-76     [[1, 128, 28, 28]]    [1, 128, 28, 28]         512      
    Conv2D-146       [[1, 128, 28, 28]]    [1, 512, 28, 28]       65,536     
  BatchNorm2D-77     [[1, 512, 28, 28]]    [1, 512, 28, 28]        2,048     
    Conv2D-147       [[1, 512, 28, 28]]    [1, 512, 28, 28]       262,656    
    Conv2D-148       [[1, 512, 28, 28]]    [1, 512, 28, 28]       262,656    
    Conv2D-149       [[1, 512, 28, 28]]    [1, 512, 28, 28]       262,656    
  DoubleAtten-24     [[1, 512, 28, 28]]    [1, 512, 28, 28]          0       
BottleneckBlock-23   [[1, 512, 28, 28]]    [1, 512, 28, 28]          0       
    Conv2D-151       [[1, 512, 28, 28]]    [1, 256, 28, 28]       131,072    
  BatchNorm2D-79     [[1, 256, 28, 28]]    [1, 256, 28, 28]        1,024     
      ReLU-26       [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
    Conv2D-152       [[1, 256, 28, 28]]    [1, 256, 14, 14]       589,824    
  BatchNorm2D-80     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
    Conv2D-153       [[1, 256, 14, 14]]   [1, 1024, 14, 14]       262,144    
  BatchNorm2D-81    [[1, 1024, 14, 14]]   [1, 1024, 14, 14]        4,096     
    Conv2D-154      [[1, 1024, 14, 14]]   [1, 1024, 14, 14]      1,049,600   
    Conv2D-155      [[1, 1024, 14, 14]]   [1, 1024, 14, 14]      1,049,600   
    Conv2D-156      [[1, 1024, 14, 14]]   [1, 1024, 14, 14]      1,049,600   
  DoubleAtten-25    [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
    Conv2D-150       [[1, 512, 28, 28]]   [1, 1024, 14, 14]       524,288    
  BatchNorm2D-78    [[1, 1024, 14, 14]]   [1, 1024, 14, 14]        4,096     
BottleneckBlock-24   [[1, 512, 28, 28]]   [1, 1024, 14, 14]          0       
    Conv2D-157      [[1, 1024, 14, 14]]    [1, 256, 14, 14]       262,144    
  BatchNorm2D-82     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
      ReLU-27       [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
    Conv2D-158       [[1, 256, 14, 14]]    [1, 256, 14, 14]       589,824    
  BatchNorm2D-83     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
    Conv2D-159       [[1, 256, 14, 14]]   [1, 1024, 14, 14]       262,144    
  BatchNorm2D-84    [[1, 1024, 14, 14]]   [1, 1024, 14, 14]        4,096     
    Conv2D-160      [[1, 1024, 14, 14]]   [1, 1024, 14, 14]      1,049,600   
    Conv2D-161      [[1, 1024, 14, 14]]   [1, 1024, 14, 14]      1,049,600   
    Conv2D-162      [[1, 1024, 14, 14]]   [1, 1024, 14, 14]      1,049,600   
  DoubleAtten-26    [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
BottleneckBlock-25  [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
    Conv2D-163      [[1, 1024, 14, 14]]    [1, 256, 14, 14]       262,144    
  BatchNorm2D-85     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
      ReLU-28       [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
    Conv2D-164       [[1, 256, 14, 14]]    [1, 256, 14, 14]       589,824    
  BatchNorm2D-86     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
    Conv2D-165       [[1, 256, 14, 14]]   [1, 1024, 14, 14]       262,144    
  BatchNorm2D-87    [[1, 1024, 14, 14]]   [1, 1024, 14, 14]        4,096     
    Conv2D-166      [[1, 1024, 14, 14]]   [1, 1024, 14, 14]      1,049,600   
    Conv2D-167      [[1, 1024, 14, 14]]   [1, 1024, 14, 14]      1,049,600   
    Conv2D-168      [[1, 1024, 14, 14]]   [1, 1024, 14, 14]      1,049,600   
  DoubleAtten-27    [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
BottleneckBlock-26  [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
    Conv2D-169      [[1, 1024, 14, 14]]    [1, 256, 14, 14]       262,144    
  BatchNorm2D-88     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
      ReLU-29       [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
    Conv2D-170       [[1, 256, 14, 14]]    [1, 256, 14, 14]       589,824    
  BatchNorm2D-89     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
    Conv2D-171       [[1, 256, 14, 14]]   [1, 1024, 14, 14]       262,144    
  BatchNorm2D-90    [[1, 1024, 14, 14]]   [1, 1024, 14, 14]        4,096     
    Conv2D-172      [[1, 1024, 14, 14]]   [1, 1024, 14, 14]      1,049,600   
    Conv2D-173      [[1, 1024, 14, 14]]   [1, 1024, 14, 14]      1,049,600   
    Conv2D-174      [[1, 1024, 14, 14]]   [1, 1024, 14, 14]      1,049,600   
  DoubleAtten-28    [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
BottleneckBlock-27  [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
    Conv2D-175      [[1, 1024, 14, 14]]    [1, 256, 14, 14]       262,144    
  BatchNorm2D-91     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
      ReLU-30       [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
    Conv2D-176       [[1, 256, 14, 14]]    [1, 256, 14, 14]       589,824    
  BatchNorm2D-92     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
    Conv2D-177       [[1, 256, 14, 14]]   [1, 1024, 14, 14]       262,144    
  BatchNorm2D-93    [[1, 1024, 14, 14]]   [1, 1024, 14, 14]        4,096     
    Conv2D-178      [[1, 1024, 14, 14]]   [1, 1024, 14, 14]      1,049,600   
    Conv2D-179      [[1, 1024, 14, 14]]   [1, 1024, 14, 14]      1,049,600   
    Conv2D-180      [[1, 1024, 14, 14]]   [1, 1024, 14, 14]      1,049,600   
  DoubleAtten-29    [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
BottleneckBlock-28  [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
    Conv2D-181      [[1, 1024, 14, 14]]    [1, 256, 14, 14]       262,144    
  BatchNorm2D-94     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
      ReLU-31       [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
    Conv2D-182       [[1, 256, 14, 14]]    [1, 256, 14, 14]       589,824    
  BatchNorm2D-95     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
    Conv2D-183       [[1, 256, 14, 14]]   [1, 1024, 14, 14]       262,144    
  BatchNorm2D-96    [[1, 1024, 14, 14]]   [1, 1024, 14, 14]        4,096     
    Conv2D-184      [[1, 1024, 14, 14]]   [1, 1024, 14, 14]      1,049,600   
    Conv2D-185      [[1, 1024, 14, 14]]   [1, 1024, 14, 14]      1,049,600   
    Conv2D-186      [[1, 1024, 14, 14]]   [1, 1024, 14, 14]      1,049,600   
  DoubleAtten-30    [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
BottleneckBlock-29  [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
    Conv2D-188      [[1, 1024, 14, 14]]    [1, 512, 14, 14]       524,288    
  BatchNorm2D-98     [[1, 512, 14, 14]]    [1, 512, 14, 14]        2,048     
      ReLU-32        [[1, 2048, 7, 7]]     [1, 2048, 7, 7]           0       
    Conv2D-189       [[1, 512, 14, 14]]     [1, 512, 7, 7]       2,359,296   
  BatchNorm2D-99      [[1, 512, 7, 7]]      [1, 512, 7, 7]         2,048     
    Conv2D-190        [[1, 512, 7, 7]]     [1, 2048, 7, 7]       1,048,576   
  BatchNorm2D-100    [[1, 2048, 7, 7]]     [1, 2048, 7, 7]         8,192     
    Conv2D-191       [[1, 2048, 7, 7]]     [1, 2048, 7, 7]       4,196,352   
    Conv2D-192       [[1, 2048, 7, 7]]     [1, 2048, 7, 7]       4,196,352   
    Conv2D-193       [[1, 2048, 7, 7]]     [1, 2048, 7, 7]       4,196,352   
  DoubleAtten-31     [[1, 2048, 7, 7]]     [1, 2048, 7, 7]           0       
    Conv2D-187      [[1, 1024, 14, 14]]    [1, 2048, 7, 7]       2,097,152   
  BatchNorm2D-97     [[1, 2048, 7, 7]]     [1, 2048, 7, 7]         8,192     
BottleneckBlock-30  [[1, 1024, 14, 14]]    [1, 2048, 7, 7]           0       
    Conv2D-194       [[1, 2048, 7, 7]]      [1, 512, 7, 7]       1,048,576   
  BatchNorm2D-101     [[1, 512, 7, 7]]      [1, 512, 7, 7]         2,048     
      ReLU-33        [[1, 2048, 7, 7]]     [1, 2048, 7, 7]           0       
    Conv2D-195        [[1, 512, 7, 7]]      [1, 512, 7, 7]       2,359,296   
  BatchNorm2D-102     [[1, 512, 7, 7]]      [1, 512, 7, 7]         2,048     
    Conv2D-196        [[1, 512, 7, 7]]     [1, 2048, 7, 7]       1,048,576   
  BatchNorm2D-103    [[1, 2048, 7, 7]]     [1, 2048, 7, 7]         8,192     
    Conv2D-197       [[1, 2048, 7, 7]]     [1, 2048, 7, 7]       4,196,352   
    Conv2D-198       [[1, 2048, 7, 7]]     [1, 2048, 7, 7]       4,196,352   
    Conv2D-199       [[1, 2048, 7, 7]]     [1, 2048, 7, 7]       4,196,352   
  DoubleAtten-32     [[1, 2048, 7, 7]]     [1, 2048, 7, 7]           0       
BottleneckBlock-31   [[1, 2048, 7, 7]]     [1, 2048, 7, 7]           0       
    Conv2D-200       [[1, 2048, 7, 7]]      [1, 512, 7, 7]       1,048,576   
  BatchNorm2D-104     [[1, 512, 7, 7]]      [1, 512, 7, 7]         2,048     
      ReLU-34        [[1, 2048, 7, 7]]     [1, 2048, 7, 7]           0       
    Conv2D-201        [[1, 512, 7, 7]]      [1, 512, 7, 7]       2,359,296   
  BatchNorm2D-105     [[1, 512, 7, 7]]      [1, 512, 7, 7]         2,048     
    Conv2D-202        [[1, 512, 7, 7]]     [1, 2048, 7, 7]       1,048,576   
  BatchNorm2D-106    [[1, 2048, 7, 7]]     [1, 2048, 7, 7]         8,192     
    Conv2D-203       [[1, 2048, 7, 7]]     [1, 2048, 7, 7]       4,196,352   
    Conv2D-204       [[1, 2048, 7, 7]]     [1, 2048, 7, 7]       4,196,352   
    Conv2D-205       [[1, 2048, 7, 7]]     [1, 2048, 7, 7]       4,196,352   
  DoubleAtten-33     [[1, 2048, 7, 7]]     [1, 2048, 7, 7]           0       
BottleneckBlock-32   [[1, 2048, 7, 7]]     [1, 2048, 7, 7]           0       
AdaptiveAvgPool2D-2  [[1, 2048, 7, 7]]     [1, 2048, 1, 1]           0       
     Linear-2           [[1, 2048]]            [1, 10]            20,490     
===============================================================================
Total params: 83,985,610
Trainable params: 83,879,370
Non-trainable params: 106,240
-------------------------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 429.91
Params size (MB): 320.38
Estimated Total Size (MB): 750.87
-------------------------------------------------------------------------------

{'total_params': 83985610, 'trainable_params': 83879370}
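For reference, the stock ResNet50 baseline (used for comparison below) can be summarized the same way to quantify the overhead of the DA modules; a quick sketch:

plain_res50 = paddle.vision.models.resnet50(num_classes=10)
paddle.Model(plain_res50).summary((1, 3, 224, 224))
# The stock backbone is far smaller (roughly 23.5M parameters with a 10-class
# head), so the three 1x1 convolutions added per bottleneck account for most
# of the ~84M total above.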

CIFAR-10 data preparation

import paddle.vision.transforms as T
from paddle.vision.datasets import Cifar10
paddle.set_device('gpu')

# Data preparation: ToTensor first (HWC uint8 -> CHW float32 in [0, 1]),
# then normalize with ImageNet statistics in CHW layout.
transform = T.Compose([
    T.Resize(size=(224,224)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], data_format='CHW')
])

train_dataset = Cifar10(mode='train', transform=transform)
train_loader = paddle.io.DataLoader(train_dataset, batch_size=64, shuffle=True, drop_last=True)
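As a quick sanity check (a sketch, not part of the original notebook), pull one batch and confirm the transform yields CHW tensors matching the 224x224 model input:

x, y = next(iter(train_loader))
print(x.shape, y.shape)  # images should come out as [64, 3, 224, 224]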

Training ResNet50 on CIFAR-10

# Model preparation: plain ResNet50 baseline
res50 = paddle.vision.models.resnet50(num_classes=10)
res50.train()


# Training setup
epoch_num = 10
optim = paddle.optimizer.Adam(learning_rate=0.001,parameters=res50.parameters())
loss_fn = paddle.nn.CrossEntropyLoss()

res50_loss = []
res50_acc = []

for epoch in range(epoch_num):
    for batch_id, data in enumerate(train_loader):
        inputs = data[0]            
        labels = data[1].unsqueeze(1)            
        predicts = res50(inputs)    

        loss = loss_fn(predicts, labels)
        acc = paddle.metric.accuracy(predicts, labels)
        loss.backward()

        if batch_id % 100 == 0: 
            print("epoch: {}, batch_id: {}, loss is: {}, acc is: {}".format(epoch, batch_id, loss.numpy(), acc.numpy()))

        if batch_id % 20 == 0:

            res50_loss.append(loss.numpy())
            res50_acc.append(acc.numpy())

        optim.step()
        optim.clear_grad()


epoch: 0, batch_id: 0, loss is: [2.8489058], acc is: [0.1875]
epoch: 0, batch_id: 100, loss is: [1.5169208], acc is: [0.453125]
epoch: 0, batch_id: 200, loss is: [1.4419072], acc is: [0.453125]
epoch: 0, batch_id: 300, loss is: [1.303793], acc is: [0.5625]
epoch: 0, batch_id: 400, loss is: [1.1453766], acc is: [0.5]
epoch: 0, batch_id: 500, loss is: [0.97646654], acc is: [0.609375]
epoch: 0, batch_id: 600, loss is: [0.9759238], acc is: [0.6875]
epoch: 0, batch_id: 700, loss is: [0.8780391], acc is: [0.609375]
epoch: 1, batch_id: 0, loss is: [1.042079], acc is: [0.703125]
epoch: 1, batch_id: 100, loss is: [0.83035773], acc is: [0.703125]
epoch: 1, batch_id: 200, loss is: [0.8466539], acc is: [0.765625]
epoch: 1, batch_id: 300, loss is: [0.93592036], acc is: [0.71875]
epoch: 1, batch_id: 400, loss is: [0.90168834], acc is: [0.71875]
epoch: 1, batch_id: 500, loss is: [0.79210997], acc is: [0.6875]
epoch: 1, batch_id: 600, loss is: [0.76145744], acc is: [0.734375]
epoch: 1, batch_id: 700, loss is: [0.8608778], acc is: [0.765625]
epoch: 2, batch_id: 0, loss is: [0.71813166], acc is: [0.75]
epoch: 2, batch_id: 100, loss is: [0.66491127], acc is: [0.765625]
epoch: 2, batch_id: 200, loss is: [0.5883076], acc is: [0.84375]
epoch: 2, batch_id: 300, loss is: [0.49795514], acc is: [0.8125]
epoch: 2, batch_id: 400, loss is: [0.85554576], acc is: [0.75]
epoch: 2, batch_id: 500, loss is: [0.6621829], acc is: [0.8125]
epoch: 2, batch_id: 600, loss is: [0.5962929], acc is: [0.796875]
epoch: 2, batch_id: 700, loss is: [0.42166245], acc is: [0.875]
epoch: 3, batch_id: 0, loss is: [0.44783327], acc is: [0.8125]
epoch: 3, batch_id: 100, loss is: [0.51038265], acc is: [0.796875]
epoch: 3, batch_id: 200, loss is: [0.5499249], acc is: [0.828125]
epoch: 3, batch_id: 300, loss is: [0.3780781], acc is: [0.859375]
epoch: 3, batch_id: 400, loss is: [0.5779753], acc is: [0.828125]
epoch: 3, batch_id: 500, loss is: [0.52967703], acc is: [0.828125]
epoch: 3, batch_id: 600, loss is: [0.34549445], acc is: [0.921875]
epoch: 3, batch_id: 700, loss is: [0.80611515], acc is: [0.734375]
epoch: 4, batch_id: 0, loss is: [0.21761498], acc is: [0.953125]
epoch: 4, batch_id: 100, loss is: [0.558988], acc is: [0.828125]
epoch: 4, batch_id: 200, loss is: [0.4339955], acc is: [0.84375]
epoch: 4, batch_id: 300, loss is: [0.45421124], acc is: [0.84375]
epoch: 4, batch_id: 400, loss is: [0.55842537], acc is: [0.78125]
epoch: 4, batch_id: 500, loss is: [0.46884495], acc is: [0.859375]
epoch: 4, batch_id: 600, loss is: [0.29323775], acc is: [0.875]
epoch: 4, batch_id: 700, loss is: [0.24813098], acc is: [0.9375]
epoch: 5, batch_id: 0, loss is: [0.26143038], acc is: [0.890625]
epoch: 5, batch_id: 100, loss is: [0.28214324], acc is: [0.90625]
epoch: 5, batch_id: 200, loss is: [0.28322464], acc is: [0.90625]
epoch: 5, batch_id: 300, loss is: [0.5579711], acc is: [0.796875]
epoch: 5, batch_id: 400, loss is: [0.2829948], acc is: [0.921875]
epoch: 5, batch_id: 500, loss is: [0.22569823], acc is: [0.921875]
epoch: 5, batch_id: 600, loss is: [0.36216933], acc is: [0.859375]
epoch: 5, batch_id: 700, loss is: [0.39805618], acc is: [0.84375]
epoch: 6, batch_id: 0, loss is: [0.17924805], acc is: [0.953125]
epoch: 6, batch_id: 100, loss is: [0.29918194], acc is: [0.90625]
epoch: 6, batch_id: 200, loss is: [0.18964693], acc is: [0.90625]
epoch: 6, batch_id: 300, loss is: [0.36167118], acc is: [0.90625]
epoch: 6, batch_id: 400, loss is: [0.2007184], acc is: [0.9375]
epoch: 6, batch_id: 500, loss is: [0.15255482], acc is: [0.90625]
epoch: 6, batch_id: 600, loss is: [0.33478457], acc is: [0.84375]
epoch: 6, batch_id: 700, loss is: [0.2562605], acc is: [0.875]
epoch: 7, batch_id: 0, loss is: [0.11516868], acc is: [0.96875]
epoch: 7, batch_id: 100, loss is: [0.10812954], acc is: [0.953125]
epoch: 7, batch_id: 200, loss is: [0.08551887], acc is: [0.984375]
epoch: 7, batch_id: 300, loss is: [0.34660637], acc is: [0.875]
epoch: 7, batch_id: 400, loss is: [0.1358226], acc is: [0.953125]
epoch: 7, batch_id: 500, loss is: [0.13017304], acc is: [0.96875]
epoch: 7, batch_id: 600, loss is: [0.46732765], acc is: [0.828125]
epoch: 7, batch_id: 700, loss is: [0.19166371], acc is: [0.9375]
epoch: 8, batch_id: 0, loss is: [0.07097062], acc is: [0.96875]
epoch: 8, batch_id: 100, loss is: [0.03216531], acc is: [0.984375]
epoch: 8, batch_id: 200, loss is: [0.1407477], acc is: [0.9375]
epoch: 8, batch_id: 300, loss is: [0.06303663], acc is: [0.984375]
epoch: 8, batch_id: 400, loss is: [0.07170855], acc is: [0.96875]
epoch: 8, batch_id: 500, loss is: [0.1672141], acc is: [0.90625]
epoch: 8, batch_id: 600, loss is: [0.05715706], acc is: [1.]
epoch: 8, batch_id: 700, loss is: [0.2710502], acc is: [0.9375]
epoch: 9, batch_id: 0, loss is: [0.06866629], acc is: [0.984375]
epoch: 9, batch_id: 100, loss is: [0.04419067], acc is: [0.984375]
epoch: 9, batch_id: 200, loss is: [0.07565454], acc is: [0.96875]
epoch: 9, batch_id: 300, loss is: [0.15519926], acc is: [0.953125]
epoch: 9, batch_id: 400, loss is: [0.01717619], acc is: [1.]
epoch: 9, batch_id: 500, loss is: [0.08571189], acc is: [0.96875]
epoch: 9, batch_id: 600, loss is: [0.05870717], acc is: [0.96875]
epoch: 9, batch_id: 700, loss is: [0.07783952], acc is: [0.984375]

Training DA_ResNet50 on CIFAR-10

# Model preparation: ResNet50 with DoubleAtten in every bottleneck block
da_res50 = resnet50(num_classes=10)
da_res50.train()

# Training setup
epoch_num = 10
optim = paddle.optimizer.Adam(learning_rate=0.001,parameters=da_res50.parameters())
loss_fn = paddle.nn.CrossEntropyLoss()

da_res50_loss = []
da_res50_acc = []

for epoch in range(epoch_num):
    for batch_id, data in enumerate(train_loader):
        inputs = data[0]            
        labels = data[1].unsqueeze(1)            
        predicts = da_res50(inputs)    

        loss = loss_fn(predicts, labels)
        acc = paddle.metric.accuracy(predicts, labels)
        loss.backward()

        if batch_id % 100 == 0: 
            print("epoch: {}, batch_id: {}, loss is: {}, acc is: {}".format(epoch, batch_id, loss.numpy(), acc.numpy()))
            
        if batch_id % 20 == 0:
            da_res50_loss.append(loss.numpy())
            da_res50_acc.append(acc.numpy())

        optim.step()
        optim.clear_grad()


epoch: 0, batch_id: 0, loss is: [2.6978078], acc is: [0.15625]
epoch: 0, batch_id: 100, loss is: [1.7868621], acc is: [0.359375]
epoch: 0, batch_id: 200, loss is: [1.4539795], acc is: [0.4375]
epoch: 0, batch_id: 300, loss is: [1.3904684], acc is: [0.515625]
epoch: 0, batch_id: 400, loss is: [1.0738535], acc is: [0.578125]
epoch: 0, batch_id: 500, loss is: [1.276068], acc is: [0.546875]
epoch: 0, batch_id: 600, loss is: [0.9131775], acc is: [0.671875]
epoch: 0, batch_id: 700, loss is: [0.9701067], acc is: [0.65625]
epoch: 1, batch_id: 0, loss is: [1.0521469], acc is: [0.6875]
epoch: 1, batch_id: 100, loss is: [0.8612659], acc is: [0.6875]
epoch: 1, batch_id: 200, loss is: [0.88844216], acc is: [0.703125]
epoch: 1, batch_id: 300, loss is: [0.80324924], acc is: [0.65625]
epoch: 1, batch_id: 400, loss is: [0.7941488], acc is: [0.734375]
epoch: 1, batch_id: 500, loss is: [0.72946227], acc is: [0.734375]
epoch: 1, batch_id: 600, loss is: [0.69075954], acc is: [0.78125]
epoch: 1, batch_id: 700, loss is: [0.49935198], acc is: [0.8125]
epoch: 2, batch_id: 0, loss is: [0.6254372], acc is: [0.796875]
epoch: 2, batch_id: 100, loss is: [0.6929467], acc is: [0.75]
epoch: 2, batch_id: 200, loss is: [0.6014596], acc is: [0.765625]
epoch: 2, batch_id: 300, loss is: [0.58570975], acc is: [0.796875]
epoch: 2, batch_id: 400, loss is: [0.56933963], acc is: [0.828125]
epoch: 2, batch_id: 500, loss is: [0.7758298], acc is: [0.703125]
epoch: 2, batch_id: 600, loss is: [0.7611967], acc is: [0.765625]
epoch: 2, batch_id: 700, loss is: [0.6714103], acc is: [0.84375]
epoch: 3, batch_id: 0, loss is: [0.4566578], acc is: [0.8125]
epoch: 3, batch_id: 100, loss is: [0.4512155], acc is: [0.84375]
epoch: 3, batch_id: 200, loss is: [0.79440355], acc is: [0.765625]
epoch: 3, batch_id: 300, loss is: [0.40544614], acc is: [0.859375]
epoch: 3, batch_id: 400, loss is: [0.3677606], acc is: [0.84375]
epoch: 3, batch_id: 500, loss is: [0.5795288], acc is: [0.765625]
epoch: 3, batch_id: 600, loss is: [0.3732043], acc is: [0.890625]
epoch: 3, batch_id: 700, loss is: [0.752839], acc is: [0.71875]
epoch: 4, batch_id: 0, loss is: [0.48794582], acc is: [0.828125]
epoch: 4, batch_id: 100, loss is: [0.27974662], acc is: [0.875]
epoch: 4, batch_id: 200, loss is: [0.35176855], acc is: [0.84375]
epoch: 4, batch_id: 300, loss is: [0.43276876], acc is: [0.8125]
epoch: 4, batch_id: 400, loss is: [0.45843226], acc is: [0.796875]
epoch: 4, batch_id: 500, loss is: [0.22956792], acc is: [0.921875]
epoch: 4, batch_id: 600, loss is: [0.3693059], acc is: [0.875]
epoch: 4, batch_id: 700, loss is: [0.466375], acc is: [0.8125]
epoch: 5, batch_id: 0, loss is: [0.19983137], acc is: [0.890625]
epoch: 5, batch_id: 100, loss is: [0.34767026], acc is: [0.890625]
epoch: 5, batch_id: 200, loss is: [0.31151807], acc is: [0.890625]
epoch: 5, batch_id: 300, loss is: [0.3181343], acc is: [0.84375]
epoch: 5, batch_id: 400, loss is: [0.36661047], acc is: [0.828125]
epoch: 5, batch_id: 500, loss is: [0.20549971], acc is: [0.921875]
epoch: 5, batch_id: 600, loss is: [0.3835126], acc is: [0.875]
epoch: 5, batch_id: 700, loss is: [0.29854366], acc is: [0.9375]
epoch: 6, batch_id: 0, loss is: [0.21725702], acc is: [0.953125]
epoch: 6, batch_id: 100, loss is: [0.10980862], acc is: [0.953125]
epoch: 6, batch_id: 200, loss is: [0.24021214], acc is: [0.90625]
epoch: 6, batch_id: 300, loss is: [0.17704055], acc is: [0.953125]
epoch: 6, batch_id: 400, loss is: [0.21923174], acc is: [0.921875]
epoch: 6, batch_id: 500, loss is: [0.17121044], acc is: [0.9375]
epoch: 6, batch_id: 600, loss is: [0.10075981], acc is: [1.]
epoch: 6, batch_id: 700, loss is: [0.2989301], acc is: [0.921875]
epoch: 7, batch_id: 0, loss is: [0.19572908], acc is: [0.921875]
epoch: 7, batch_id: 100, loss is: [0.2877647], acc is: [0.9375]
epoch: 7, batch_id: 200, loss is: [0.2735257], acc is: [0.875]
epoch: 7, batch_id: 300, loss is: [0.16631857], acc is: [0.96875]
epoch: 7, batch_id: 400, loss is: [0.19605], acc is: [0.953125]
epoch: 7, batch_id: 500, loss is: [0.11403949], acc is: [0.984375]
epoch: 7, batch_id: 600, loss is: [0.2396015], acc is: [0.9375]
epoch: 7, batch_id: 700, loss is: [0.22148657], acc is: [0.921875]
epoch: 8, batch_id: 0, loss is: [0.0501596], acc is: [0.984375]
epoch: 8, batch_id: 100, loss is: [0.07064686], acc is: [0.96875]
epoch: 8, batch_id: 200, loss is: [0.1000964], acc is: [0.9375]
epoch: 8, batch_id: 300, loss is: [0.21059893], acc is: [0.9375]
epoch: 8, batch_id: 400, loss is: [0.1722869], acc is: [0.9375]
epoch: 8, batch_id: 500, loss is: [0.1418913], acc is: [0.953125]
epoch: 8, batch_id: 600, loss is: [0.09865286], acc is: [0.953125]
epoch: 8, batch_id: 700, loss is: [0.18124923], acc is: [0.953125]
epoch: 9, batch_id: 0, loss is: [0.07974362], acc is: [0.96875]
epoch: 9, batch_id: 100, loss is: [0.09133874], acc is: [0.96875]
epoch: 9, batch_id: 200, loss is: [0.18680792], acc is: [0.9375]
epoch: 9, batch_id: 300, loss is: [0.09332339], acc is: [0.96875]
epoch: 9, batch_id: 400, loss is: [0.05375797], acc is: [0.96875]
epoch: 9, batch_id: 500, loss is: [0.02048387], acc is: [1.]
epoch: 9, batch_id: 600, loss is: [0.06067649], acc is: [0.984375]
epoch: 9, batch_id: 700, loss is: [0.0408816], acc is: [1.]
import matplotlib.pyplot as plt

plt.figure(figsize=(18,12))
plt.subplot(211)

plt.xlabel('iter')
plt.ylabel('loss')
plt.title('train loss')

x = range(len(da_res50_loss))
plt.plot(x, res50_loss, color='b', label='ResNet50')
plt.plot(x, da_res50_loss, color='r', label='ResNet50 + DA')

plt.legend()
plt.grid()

plt.subplot(212)
plt.xlabel('iter')
plt.ylabel('acc')
plt.title('train acc')

x = range(len(da_res50_acc))
plt.plot(x, res50_acc, color='b', label='ResNet50')
plt.plot(x, da_res50_acc, color='r', label='ResNet50 + DA')

plt.legend()
plt.grid()

plt.show()

Training curves of ResNet50 and DA_ResNet50


Summary

  • Proposes a generic formulation that captures long-range feature correlations through universal gather and distribute functions.
  • Proposes the double attention block for gathering and distributing long-range features: an efficient architecture combining second-order feature statistics with adaptive feature assignment. The block models long-range interdependencies at low computational and memory cost while significantly improving image/video recognition performance.
  • Investigates the effect of the proposed A2-Net through extensive ablation studies, and demonstrates its superior performance against the state of the art on several public benchmarks for image recognition and video action recognition.
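In the paper's notation, the generic formulation is z_i = F_distr(G_gather(X), v_i): the gather step computes the bilinear (second-order) pooling G_gather(X) = A · softmax(B)^T over the outputs A and B of two 1x1 convolutions, and the distribute step hands each location i the descriptor mixture G_gather(X) · v_i, weighted by its softmax-normalized attention vector v_i.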

Special thanks to 仰世而来丶 (this article draws on https://aistudio.baidu.com/aistudio/projectdetail/1884947?channelType=0&channel=0).
