A2-Nets: Double Attention Networks
本文基于Paddle复现了NeurIPS2018的一篇关于双重注意力机制的论文,CV领域的涨点利器。
A2-Nets: Double Attention Networks
- 论文的名字很好,反映了本文的核心想法:首先使用second-order attention pooling将整幅图的所有关键的特征搜集到了一个集合里,然后用另一种attention机制将这些特征分别分配给图像的每个location。
paper:https://arxiv.org/pdf/1810.11579.pdf
github:https://github.com/pijiande/A2Net-DoubleAttentionlayer
前言
-
这次是复现NeurIPS2018的一篇论文,本论文也是非常的简单呢,是一个涨点神器,也是一个即插即用的小模块。
-
本文的主要创新点是提出了一个新的注意力机制,你可以看做SE的进化版本,在各CV任务测试性能如下
相关代码
-
作者提出的A2-Net的核心思想是首先将整个空间的关键特征收集到一个紧凑的集合中,然后自适应地将其分布到每个位置,这样后续的卷积层即使没有很大的接收域也可以感知整个空间的特征。
第一级的注意力集中操作有选择地从整个空间中收集关键特征,而第二级的注意力集中操作采用另一种注意力机制,自适应地分配关键特征的子集,这些特征有助于补充高级任务的每个时空位置。 -
A2-Net与SENet、协方差池化、Non-local、Transformer有点类似,但是不同点在于它的第一个注意力操作隐式地计算池化特征的二阶统计,并能捕获SENet中使用的全局平均池化无法捕获的复杂外观和运动相关性;它的第二注意力操作从一个紧凑的袋子中自适应地分配特征,这比 Non-local、Transformer中将所有位置的特征与每个特定位置进行穷举关联更有效。
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
class DoubleAtten(nn.Layer):
    """Double attention block from "A2-Nets: Double Attention Networks" (NIPS 2018).

    Three 1x1 convolutions (A, B, V) are applied to the same 4-D input.
    Softmax-normalised B gathers the features of A into a ``[b, in_c, in_c]``
    set of global descriptors, and softmax-normalised V distributes those
    descriptors back to every spatial location. The output has the same
    shape as the input.
    """

    def __init__(self, in_c):
        """
        :param in_c: number of channels of the feature map to refine.
            The channel reduction / expansion described in the paper is
            not used here: all three convolutions keep ``in_c`` channels.
        """
        super(DoubleAtten, self).__init__()
        self.in_c = in_c
        # Three 1x1 convolutions over the same input produce three feature
        # maps of identical shape; these are A, B and V in the paper.
        self.convA = nn.Conv2D(in_c, in_c, kernel_size=1)
        self.convB = nn.Conv2D(in_c, in_c, kernel_size=1)
        self.convV = nn.Conv2D(in_c, in_c, kernel_size=1)

    def forward(self, input):
        """Refine ``input`` (a ``[b, in_c, h, w]`` tensor) and return a tensor of the same shape."""
        feature_maps = self.convA(input)
        atten_map = self.convB(input)
        b, c, h, w = feature_maps.shape
        n = h * w

        # Flatten the spatial dimensions of A and B.
        feature_maps = feature_maps.reshape([b, self.in_c, n])
        # Attention maps: softmax of B over the spatial positions.
        atten_map = F.softmax(atten_map.reshape([b, self.in_c, n]), axis=-1)

        # Global descriptors G[b, i, j] = mean_k softmax(B)[b, i, k] * A[b, j, k].
        # Computed as a batched matrix product instead of broadcasting to a
        # [b, in_c, in_c, h*w] tensor and averaging, which is mathematically
        # the same but needs h*w times less memory.
        global_descriptors = paddle.bmm(
            atten_map, feature_maps.transpose([0, 2, 1])) / n

        # Attention vectors: softmax of V over the spatial positions.
        v = self.convV(input)
        atten_vectors = F.softmax(v.reshape([b, self.in_c, n]), axis=-1)

        # Distribute the global descriptors to every location:
        # [b, n, in_c] x [b, in_c, in_c] -> [b, n, in_c] -> [b, in_c, n].
        out = paddle.bmm(atten_vectors.transpose([0, 2, 1]),
                         global_descriptors).transpose([0, 2, 1])
        return out.reshape([b, c, h, w])
W1130 04:50:41.878121 104 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1
W1130 04:50:41.882817 104 device_context.cc:465] device: 0, cuDNN Version: 7.6.
[4, 512, 16, 16]
验证
input size = 4,512,16,16 --> DA --> output size = 4,512,16,16
if __name__ == "__main__":
    # Smoke test: push a random [4, 512, 16, 16] tensor through the block
    # and print the shape of the result.
    sample = paddle.rand([4, 512, 16, 16])
    refined = DoubleAtten(512)(sample)
    print(refined.shape)
对DA性能进行验证
论文中,作者在Resnet26测试,但是对于ResNet50深层网络作者没有做相关实验,我们这次搭建一个ResNet50网络来验证性能,DA模块插入位置如下。
DA_ResNet50 搭建
import paddle
import paddle.nn as nn
from paddle.utils.download import get_weights_path_from_url
class BasicBlock(nn.Layer):
    """Two-convolution residual block (3x3 -> 3x3) with an identity / downsample shortcut.

    ``groups`` and ``base_width`` are accepted for signature parity with
    ``BottleneckBlock`` but are not used by this block.
    """

    # Output channels = planes * expansion.
    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 downsample=None,
                 groups=1,
                 base_width=64,
                 dilation=1,
                 norm_layer=None):
        super().__init__()
        if dilation > 1:
            raise NotImplementedError(
                "Dilation > 1 not supported in BasicBlock")
        # Fall back to 2-D batch normalisation when no norm layer is given.
        norm_layer = nn.BatchNorm2D if norm_layer is None else norm_layer

        self.conv1 = nn.Conv2D(
            inplanes, planes, 3, padding=1, stride=stride, bias_attr=False)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2D(planes, planes, 3, padding=1, bias_attr=False)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Apply conv-bn-relu, conv-bn, add the shortcut, then a final ReLU."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # The shortcut is projected only when a downsample layer was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.relu(out)
class BottleneckBlock(nn.Layer):
    """Bottleneck residual block (1x1 -> 3x3 -> 1x1) followed by a DoubleAtten module.

    The attention module is applied to the output of the last batch norm,
    before the shortcut is added.
    """

    # Output channels = planes * expansion.
    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 downsample=None,
                 groups=1,
                 base_width=64,
                 dilation=1,
                 norm_layer=None):
        super().__init__()
        norm_layer = nn.BatchNorm2D if norm_layer is None else norm_layer
        # Width of the inner 3x3 convolution, scaled by base_width and groups.
        width = int(planes * (base_width / 64.)) * groups
        out_channels = planes * self.expansion

        self.conv1 = nn.Conv2D(inplanes, width, 1, bias_attr=False)
        self.bn1 = norm_layer(width)
        self.conv2 = nn.Conv2D(
            width,
            width,
            3,
            padding=dilation,
            stride=stride,
            groups=groups,
            dilation=dilation,
            bias_attr=False)
        self.bn2 = norm_layer(width)
        self.conv3 = nn.Conv2D(width, out_channels, 1, bias_attr=False)
        self.bn3 = norm_layer(out_channels)
        self.relu = nn.ReLU()
        self.downsample = downsample
        self.stride = stride
        self.attention = DoubleAtten(out_channels)

    def forward(self, x):
        """Run the bottleneck path, refine it with attention, add the shortcut, apply ReLU."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out = self.attention(out)

        # The shortcut is projected only when a downsample layer was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.relu(out)
class ResNet(nn.Layer):
    """ResNet backbone assembled from ``block`` instances.

    :param block: residual block class; must expose an ``expansion`` attribute.
    :param depth: network depth; must be one of 18, 34, 50, 101, 152
        (any other value raises ``KeyError``).
    :param width: base width forwarded to every block.
    :param num_classes: size of the final linear layer; no classifier is
        created or applied when it is not positive.
    :param with_pool: whether to apply adaptive average pooling before the
        classifier.
    """

    def __init__(self,
                 block,
                 depth=50,
                 width=64,
                 num_classes=1000,
                 with_pool=True):
        super().__init__()
        # Number of residual blocks in each of the four stages, per depth.
        blocks_per_stage = {
            18: [2, 2, 2, 2],
            34: [3, 4, 6, 3],
            50: [3, 4, 6, 3],
            101: [3, 4, 23, 3],
            152: [3, 8, 36, 3]
        }
        n1, n2, n3, n4 = blocks_per_stage[depth]

        self.groups = 1
        self.base_width = width
        self.num_classes = num_classes
        self.with_pool = with_pool
        self._norm_layer = nn.BatchNorm2D

        # Running state consumed and updated by _make_layer.
        self.inplanes = 64
        self.dilation = 1

        # Stem: 7x7 stride-2 convolution, batch norm, ReLU, 3x3 stride-2 max pool.
        self.conv1 = nn.Conv2D(
            3,
            self.inplanes,
            kernel_size=7,
            stride=2,
            padding=3,
            bias_attr=False)
        self.bn1 = self._norm_layer(self.inplanes)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)

        # Four residual stages; every stage after the first uses stride 2.
        self.layer1 = self._make_layer(block, 64, n1)
        self.layer2 = self._make_layer(block, 128, n2, stride=2)
        self.layer3 = self._make_layer(block, 256, n3, stride=2)
        self.layer4 = self._make_layer(block, 512, n4, stride=2)

        if with_pool:
            self.avgpool = nn.AdaptiveAvgPool2D((1, 1))
        if num_classes > 0:
            self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one stage of ``blocks`` residual blocks and update ``self.inplanes``."""
        norm_layer = self._norm_layer
        previous_dilation = self.dilation
        if dilate:
            # Trade the stride for dilation.
            self.dilation *= stride
            stride = 1

        out_channels = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_channels:
            # Project the shortcut when resolution or channel count changes.
            downsample = nn.Sequential(
                nn.Conv2D(
                    self.inplanes,
                    out_channels,
                    1,
                    stride=stride,
                    bias_attr=False),
                norm_layer(out_channels), )

        # Only the first block of a stage receives the stride and the projection.
        stage = [
            block(self.inplanes, planes, stride, downsample, self.groups,
                  self.base_width, previous_dilation, norm_layer)
        ]
        self.inplanes = out_channels
        stage.extend(
            block(
                self.inplanes,
                planes,
                groups=self.groups,
                base_width=self.base_width,
                norm_layer=norm_layer) for _ in range(1, blocks))
        return nn.Sequential(*stage)

    def forward(self, x):
        """Run the stem, the four stages, then the optional pooling and classifier."""
        pipeline = (self.conv1, self.bn1, self.relu, self.maxpool,
                    self.layer1, self.layer2, self.layer3, self.layer4)
        for stage in pipeline:
            x = stage(x)

        if self.with_pool:
            x = self.avgpool(x)
        if self.num_classes > 0:
            x = self.fc(paddle.flatten(x, 1))
        return x
def _resnet(arch, Block, depth, pretrained, **kwargs):
    """Build a ``ResNet`` and optionally load pretrained weights into it.

    :param arch: architecture name used as the key into ``model_urls``.
    :param Block: residual block class passed to ``ResNet``.
    :param depth: network depth passed to ``ResNet``.
    :param pretrained: when true, download and load the weights registered
        for ``arch``.
    :param kwargs: extra keyword arguments forwarded to ``ResNet``.
    :return: the constructed model.
    :raises AssertionError: if ``pretrained`` is true and no weights are
        registered for ``arch``.
    """
    model = ResNet(Block, depth, **kwargs)
    if pretrained:
        # `model_urls` is not defined anywhere in this file, so referencing
        # it directly raised NameError instead of the intended message.
        # Look it up defensively and treat a missing table as empty.
        urls = globals().get("model_urls", {})
        assert arch in urls, "{} model do not have a pretrained model now, you should set pretrained=False".format(
            arch)
        # NOTE(review): entries are presumably (url, md5) pairs - confirm
        # against wherever model_urls gets defined.
        weight_path = get_weights_path_from_url(urls[arch][0],
                                                urls[arch][1])
        param = paddle.load(weight_path)
        model.set_dict(param)
    return model
def resnet18(pretrained=False, **kwargs):
    """Build an 18-layer ResNet from ``BasicBlock``; extra kwargs go to ``ResNet``."""
    arch, depth = 'resnet18', 18
    return _resnet(arch, BasicBlock, depth, pretrained, **kwargs)
def resnet34(pretrained=False, **kwargs):
    """Build a 34-layer ResNet from ``BasicBlock``; extra kwargs go to ``ResNet``."""
    arch, depth = 'resnet34', 34
    return _resnet(arch, BasicBlock, depth, pretrained, **kwargs)
def resnet50(pretrained=False, **kwargs):
    """Build a 50-layer ResNet from ``BottleneckBlock``; extra kwargs go to ``ResNet``."""
    arch, depth = 'resnet50', 50
    return _resnet(arch, BottleneckBlock, depth, pretrained, **kwargs)
def resnet101(pretrained=False, **kwargs):
    """Build a 101-layer ResNet from ``BottleneckBlock``; extra kwargs go to ``ResNet``."""
    arch, depth = 'resnet101', 101
    return _resnet(arch, BottleneckBlock, depth, pretrained, **kwargs)
def resnet152(pretrained=False, **kwargs):
    """Build a 152-layer ResNet from ``BottleneckBlock``; extra kwargs go to ``ResNet``."""
    arch, depth = 'resnet152', 152
    return _resnet(arch, BottleneckBlock, depth, pretrained, **kwargs)
def wide_resnet50_2(pretrained=False, **kwargs):
    """Build a 50-layer ResNet from ``BottleneckBlock`` with the base width doubled.

    Any ``width`` supplied by the caller is overwritten.
    """
    kwargs.update(width=64 * 2)
    arch, depth = 'wide_resnet50_2', 50
    return _resnet(arch, BottleneckBlock, depth, pretrained, **kwargs)
def wide_resnet101_2(pretrained=False, **kwargs):
    """Build a 101-layer ResNet from ``BottleneckBlock`` with the base width doubled.

    Any ``width`` supplied by the caller is overwritten.
    """
    kwargs.update(width=64 * 2)
    arch, depth = 'wide_resnet101_2', 101
    return _resnet(arch, BottleneckBlock, depth, pretrained, **kwargs)
# Build the 10-class ResNet50 (bottleneck blocks with DoubleAtten) and print
# a per-layer summary for a single 3x224x224 input.
da_res50 = resnet50(num_classes=10)
summary_input_shape = (1, 3, 224, 224)
paddle.Model(da_res50).summary(summary_input_shape)
-------------------------------------------------------------------------------
Layer (type) Input Shape Output Shape Param #
===============================================================================
Conv2D-105 [[1, 3, 224, 224]] [1, 64, 112, 112] 9,408
BatchNorm2D-54 [[1, 64, 112, 112]] [1, 64, 112, 112] 256
ReLU-18 [[1, 64, 112, 112]] [1, 64, 112, 112] 0
MaxPool2D-2 [[1, 64, 112, 112]] [1, 64, 56, 56] 0
Conv2D-107 [[1, 64, 56, 56]] [1, 64, 56, 56] 4,096
BatchNorm2D-56 [[1, 64, 56, 56]] [1, 64, 56, 56] 256
ReLU-19 [[1, 256, 56, 56]] [1, 256, 56, 56] 0
Conv2D-108 [[1, 64, 56, 56]] [1, 64, 56, 56] 36,864
BatchNorm2D-57 [[1, 64, 56, 56]] [1, 64, 56, 56] 256
Conv2D-109 [[1, 64, 56, 56]] [1, 256, 56, 56] 16,384
BatchNorm2D-58 [[1, 256, 56, 56]] [1, 256, 56, 56] 1,024
Conv2D-110 [[1, 256, 56, 56]] [1, 256, 56, 56] 65,792
Conv2D-111 [[1, 256, 56, 56]] [1, 256, 56, 56] 65,792
Conv2D-112 [[1, 256, 56, 56]] [1, 256, 56, 56] 65,792
DoubleAtten-18 [[1, 256, 56, 56]] [1, 256, 56, 56] 0
Conv2D-106 [[1, 64, 56, 56]] [1, 256, 56, 56] 16,384
BatchNorm2D-55 [[1, 256, 56, 56]] [1, 256, 56, 56] 1,024
BottleneckBlock-17 [[1, 64, 56, 56]] [1, 256, 56, 56] 0
Conv2D-113 [[1, 256, 56, 56]] [1, 64, 56, 56] 16,384
BatchNorm2D-59 [[1, 64, 56, 56]] [1, 64, 56, 56] 256
ReLU-20 [[1, 256, 56, 56]] [1, 256, 56, 56] 0
Conv2D-114 [[1, 64, 56, 56]] [1, 64, 56, 56] 36,864
BatchNorm2D-60 [[1, 64, 56, 56]] [1, 64, 56, 56] 256
Conv2D-115 [[1, 64, 56, 56]] [1, 256, 56, 56] 16,384
BatchNorm2D-61 [[1, 256, 56, 56]] [1, 256, 56, 56] 1,024
Conv2D-116 [[1, 256, 56, 56]] [1, 256, 56, 56] 65,792
Conv2D-117 [[1, 256, 56, 56]] [1, 256, 56, 56] 65,792
Conv2D-118 [[1, 256, 56, 56]] [1, 256, 56, 56] 65,792
DoubleAtten-19 [[1, 256, 56, 56]] [1, 256, 56, 56] 0
BottleneckBlock-18 [[1, 256, 56, 56]] [1, 256, 56, 56] 0
Conv2D-119 [[1, 256, 56, 56]] [1, 64, 56, 56] 16,384
BatchNorm2D-62 [[1, 64, 56, 56]] [1, 64, 56, 56] 256
ReLU-21 [[1, 256, 56, 56]] [1, 256, 56, 56] 0
Conv2D-120 [[1, 64, 56, 56]] [1, 64, 56, 56] 36,864
BatchNorm2D-63 [[1, 64, 56, 56]] [1, 64, 56, 56] 256
Conv2D-121 [[1, 64, 56, 56]] [1, 256, 56, 56] 16,384
BatchNorm2D-64 [[1, 256, 56, 56]] [1, 256, 56, 56] 1,024
Conv2D-122 [[1, 256, 56, 56]] [1, 256, 56, 56] 65,792
Conv2D-123 [[1, 256, 56, 56]] [1, 256, 56, 56] 65,792
Conv2D-124 [[1, 256, 56, 56]] [1, 256, 56, 56] 65,792
DoubleAtten-20 [[1, 256, 56, 56]] [1, 256, 56, 56] 0
BottleneckBlock-19 [[1, 256, 56, 56]] [1, 256, 56, 56] 0
Conv2D-126 [[1, 256, 56, 56]] [1, 128, 56, 56] 32,768
BatchNorm2D-66 [[1, 128, 56, 56]] [1, 128, 56, 56] 512
ReLU-22 [[1, 512, 28, 28]] [1, 512, 28, 28] 0
Conv2D-127 [[1, 128, 56, 56]] [1, 128, 28, 28] 147,456
BatchNorm2D-67 [[1, 128, 28, 28]] [1, 128, 28, 28] 512
Conv2D-128 [[1, 128, 28, 28]] [1, 512, 28, 28] 65,536
BatchNorm2D-68 [[1, 512, 28, 28]] [1, 512, 28, 28] 2,048
Conv2D-129 [[1, 512, 28, 28]] [1, 512, 28, 28] 262,656
Conv2D-130 [[1, 512, 28, 28]] [1, 512, 28, 28] 262,656
Conv2D-131 [[1, 512, 28, 28]] [1, 512, 28, 28] 262,656
DoubleAtten-21 [[1, 512, 28, 28]] [1, 512, 28, 28] 0
Conv2D-125 [[1, 256, 56, 56]] [1, 512, 28, 28] 131,072
BatchNorm2D-65 [[1, 512, 28, 28]] [1, 512, 28, 28] 2,048
BottleneckBlock-20 [[1, 256, 56, 56]] [1, 512, 28, 28] 0
Conv2D-132 [[1, 512, 28, 28]] [1, 128, 28, 28] 65,536
BatchNorm2D-69 [[1, 128, 28, 28]] [1, 128, 28, 28] 512
ReLU-23 [[1, 512, 28, 28]] [1, 512, 28, 28] 0
Conv2D-133 [[1, 128, 28, 28]] [1, 128, 28, 28] 147,456
BatchNorm2D-70 [[1, 128, 28, 28]] [1, 128, 28, 28] 512
Conv2D-134 [[1, 128, 28, 28]] [1, 512, 28, 28] 65,536
BatchNorm2D-71 [[1, 512, 28, 28]] [1, 512, 28, 28] 2,048
Conv2D-135 [[1, 512, 28, 28]] [1, 512, 28, 28] 262,656
Conv2D-136 [[1, 512, 28, 28]] [1, 512, 28, 28] 262,656
Conv2D-137 [[1, 512, 28, 28]] [1, 512, 28, 28] 262,656
DoubleAtten-22 [[1, 512, 28, 28]] [1, 512, 28, 28] 0
BottleneckBlock-21 [[1, 512, 28, 28]] [1, 512, 28, 28] 0
Conv2D-138 [[1, 512, 28, 28]] [1, 128, 28, 28] 65,536
BatchNorm2D-72 [[1, 128, 28, 28]] [1, 128, 28, 28] 512
ReLU-24 [[1, 512, 28, 28]] [1, 512, 28, 28] 0
Conv2D-139 [[1, 128, 28, 28]] [1, 128, 28, 28] 147,456
BatchNorm2D-73 [[1, 128, 28, 28]] [1, 128, 28, 28] 512
Conv2D-140 [[1, 128, 28, 28]] [1, 512, 28, 28] 65,536
BatchNorm2D-74 [[1, 512, 28, 28]] [1, 512, 28, 28] 2,048
Conv2D-141 [[1, 512, 28, 28]] [1, 512, 28, 28] 262,656
Conv2D-142 [[1, 512, 28, 28]] [1, 512, 28, 28] 262,656
Conv2D-143 [[1, 512, 28, 28]] [1, 512, 28, 28] 262,656
DoubleAtten-23 [[1, 512, 28, 28]] [1, 512, 28, 28] 0
BottleneckBlock-22 [[1, 512, 28, 28]] [1, 512, 28, 28] 0
Conv2D-144 [[1, 512, 28, 28]] [1, 128, 28, 28] 65,536
BatchNorm2D-75 [[1, 128, 28, 28]] [1, 128, 28, 28] 512
ReLU-25 [[1, 512, 28, 28]] [1, 512, 28, 28] 0
Conv2D-145 [[1, 128, 28, 28]] [1, 128, 28, 28] 147,456
BatchNorm2D-76 [[1, 128, 28, 28]] [1, 128, 28, 28] 512
Conv2D-146 [[1, 128, 28, 28]] [1, 512, 28, 28] 65,536
BatchNorm2D-77 [[1, 512, 28, 28]] [1, 512, 28, 28] 2,048
Conv2D-147 [[1, 512, 28, 28]] [1, 512, 28, 28] 262,656
Conv2D-148 [[1, 512, 28, 28]] [1, 512, 28, 28] 262,656
Conv2D-149 [[1, 512, 28, 28]] [1, 512, 28, 28] 262,656
DoubleAtten-24 [[1, 512, 28, 28]] [1, 512, 28, 28] 0
BottleneckBlock-23 [[1, 512, 28, 28]] [1, 512, 28, 28] 0
Conv2D-151 [[1, 512, 28, 28]] [1, 256, 28, 28] 131,072
BatchNorm2D-79 [[1, 256, 28, 28]] [1, 256, 28, 28] 1,024
ReLU-26 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
Conv2D-152 [[1, 256, 28, 28]] [1, 256, 14, 14] 589,824
BatchNorm2D-80 [[1, 256, 14, 14]] [1, 256, 14, 14] 1,024
Conv2D-153 [[1, 256, 14, 14]] [1, 1024, 14, 14] 262,144
BatchNorm2D-81 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 4,096
Conv2D-154 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 1,049,600
Conv2D-155 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 1,049,600
Conv2D-156 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 1,049,600
DoubleAtten-25 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
Conv2D-150 [[1, 512, 28, 28]] [1, 1024, 14, 14] 524,288
BatchNorm2D-78 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 4,096
BottleneckBlock-24 [[1, 512, 28, 28]] [1, 1024, 14, 14] 0
Conv2D-157 [[1, 1024, 14, 14]] [1, 256, 14, 14] 262,144
BatchNorm2D-82 [[1, 256, 14, 14]] [1, 256, 14, 14] 1,024
ReLU-27 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
Conv2D-158 [[1, 256, 14, 14]] [1, 256, 14, 14] 589,824
BatchNorm2D-83 [[1, 256, 14, 14]] [1, 256, 14, 14] 1,024
Conv2D-159 [[1, 256, 14, 14]] [1, 1024, 14, 14] 262,144
BatchNorm2D-84 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 4,096
Conv2D-160 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 1,049,600
Conv2D-161 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 1,049,600
Conv2D-162 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 1,049,600
DoubleAtten-26 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
BottleneckBlock-25 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
Conv2D-163 [[1, 1024, 14, 14]] [1, 256, 14, 14] 262,144
BatchNorm2D-85 [[1, 256, 14, 14]] [1, 256, 14, 14] 1,024
ReLU-28 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
Conv2D-164 [[1, 256, 14, 14]] [1, 256, 14, 14] 589,824
BatchNorm2D-86 [[1, 256, 14, 14]] [1, 256, 14, 14] 1,024
Conv2D-165 [[1, 256, 14, 14]] [1, 1024, 14, 14] 262,144
BatchNorm2D-87 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 4,096
Conv2D-166 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 1,049,600
Conv2D-167 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 1,049,600
Conv2D-168 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 1,049,600
DoubleAtten-27 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
BottleneckBlock-26 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
Conv2D-169 [[1, 1024, 14, 14]] [1, 256, 14, 14] 262,144
BatchNorm2D-88 [[1, 256, 14, 14]] [1, 256, 14, 14] 1,024
ReLU-29 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
Conv2D-170 [[1, 256, 14, 14]] [1, 256, 14, 14] 589,824
BatchNorm2D-89 [[1, 256, 14, 14]] [1, 256, 14, 14] 1,024
Conv2D-171 [[1, 256, 14, 14]] [1, 1024, 14, 14] 262,144
BatchNorm2D-90 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 4,096
Conv2D-172 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 1,049,600
Conv2D-173 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 1,049,600
Conv2D-174 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 1,049,600
DoubleAtten-28 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
BottleneckBlock-27 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
Conv2D-175 [[1, 1024, 14, 14]] [1, 256, 14, 14] 262,144
BatchNorm2D-91 [[1, 256, 14, 14]] [1, 256, 14, 14] 1,024
ReLU-30 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
Conv2D-176 [[1, 256, 14, 14]] [1, 256, 14, 14] 589,824
BatchNorm2D-92 [[1, 256, 14, 14]] [1, 256, 14, 14] 1,024
Conv2D-177 [[1, 256, 14, 14]] [1, 1024, 14, 14] 262,144
BatchNorm2D-93 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 4,096
Conv2D-178 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 1,049,600
Conv2D-179 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 1,049,600
Conv2D-180 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 1,049,600
DoubleAtten-29 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
BottleneckBlock-28 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
Conv2D-181 [[1, 1024, 14, 14]] [1, 256, 14, 14] 262,144
BatchNorm2D-94 [[1, 256, 14, 14]] [1, 256, 14, 14] 1,024
ReLU-31 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
Conv2D-182 [[1, 256, 14, 14]] [1, 256, 14, 14] 589,824
BatchNorm2D-95 [[1, 256, 14, 14]] [1, 256, 14, 14] 1,024
Conv2D-183 [[1, 256, 14, 14]] [1, 1024, 14, 14] 262,144
BatchNorm2D-96 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 4,096
Conv2D-184 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 1,049,600
Conv2D-185 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 1,049,600
Conv2D-186 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 1,049,600
DoubleAtten-30 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
BottleneckBlock-29 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
Conv2D-188 [[1, 1024, 14, 14]] [1, 512, 14, 14] 524,288
BatchNorm2D-98 [[1, 512, 14, 14]] [1, 512, 14, 14] 2,048
ReLU-32 [[1, 2048, 7, 7]] [1, 2048, 7, 7] 0
Conv2D-189 [[1, 512, 14, 14]] [1, 512, 7, 7] 2,359,296
BatchNorm2D-99 [[1, 512, 7, 7]] [1, 512, 7, 7] 2,048
Conv2D-190 [[1, 512, 7, 7]] [1, 2048, 7, 7] 1,048,576
BatchNorm2D-100 [[1, 2048, 7, 7]] [1, 2048, 7, 7] 8,192
Conv2D-191 [[1, 2048, 7, 7]] [1, 2048, 7, 7] 4,196,352
Conv2D-192 [[1, 2048, 7, 7]] [1, 2048, 7, 7] 4,196,352
Conv2D-193 [[1, 2048, 7, 7]] [1, 2048, 7, 7] 4,196,352
DoubleAtten-31 [[1, 2048, 7, 7]] [1, 2048, 7, 7] 0
Conv2D-187 [[1, 1024, 14, 14]] [1, 2048, 7, 7] 2,097,152
BatchNorm2D-97 [[1, 2048, 7, 7]] [1, 2048, 7, 7] 8,192
BottleneckBlock-30 [[1, 1024, 14, 14]] [1, 2048, 7, 7] 0
Conv2D-194 [[1, 2048, 7, 7]] [1, 512, 7, 7] 1,048,576
BatchNorm2D-101 [[1, 512, 7, 7]] [1, 512, 7, 7] 2,048
ReLU-33 [[1, 2048, 7, 7]] [1, 2048, 7, 7] 0
Conv2D-195 [[1, 512, 7, 7]] [1, 512, 7, 7] 2,359,296
BatchNorm2D-102 [[1, 512, 7, 7]] [1, 512, 7, 7] 2,048
Conv2D-196 [[1, 512, 7, 7]] [1, 2048, 7, 7] 1,048,576
BatchNorm2D-103 [[1, 2048, 7, 7]] [1, 2048, 7, 7] 8,192
Conv2D-197 [[1, 2048, 7, 7]] [1, 2048, 7, 7] 4,196,352
Conv2D-198 [[1, 2048, 7, 7]] [1, 2048, 7, 7] 4,196,352
Conv2D-199 [[1, 2048, 7, 7]] [1, 2048, 7, 7] 4,196,352
DoubleAtten-32 [[1, 2048, 7, 7]] [1, 2048, 7, 7] 0
BottleneckBlock-31 [[1, 2048, 7, 7]] [1, 2048, 7, 7] 0
Conv2D-200 [[1, 2048, 7, 7]] [1, 512, 7, 7] 1,048,576
BatchNorm2D-104 [[1, 512, 7, 7]] [1, 512, 7, 7] 2,048
ReLU-34 [[1, 2048, 7, 7]] [1, 2048, 7, 7] 0
Conv2D-201 [[1, 512, 7, 7]] [1, 512, 7, 7] 2,359,296
BatchNorm2D-105 [[1, 512, 7, 7]] [1, 512, 7, 7] 2,048
Conv2D-202 [[1, 512, 7, 7]] [1, 2048, 7, 7] 1,048,576
BatchNorm2D-106 [[1, 2048, 7, 7]] [1, 2048, 7, 7] 8,192
Conv2D-203 [[1, 2048, 7, 7]] [1, 2048, 7, 7] 4,196,352
Conv2D-204 [[1, 2048, 7, 7]] [1, 2048, 7, 7] 4,196,352
Conv2D-205 [[1, 2048, 7, 7]] [1, 2048, 7, 7] 4,196,352
DoubleAtten-33 [[1, 2048, 7, 7]] [1, 2048, 7, 7] 0
BottleneckBlock-32 [[1, 2048, 7, 7]] [1, 2048, 7, 7] 0
AdaptiveAvgPool2D-2 [[1, 2048, 7, 7]] [1, 2048, 1, 1] 0
Linear-2 [[1, 2048]] [1, 10] 20,490
===============================================================================
Total params: 83,985,610
Trainable params: 83,879,370
Non-trainable params: 106,240
-------------------------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 429.91
Params size (MB): 320.38
Estimated Total Size (MB): 750.87
-------------------------------------------------------------------------------
{'total_params': 83985610, 'trainable_params': 83879370}
Cifar10数据准备
import paddle.vision.transforms as T
from paddle.vision.datasets import Cifar10
paddle.set_device('gpu')

# Data preparation.
# Normalize runs BEFORE ToTensor here, i.e. on the raw HWC image whose pixel
# values are still on the 0-255 scale, and ToTensor only rescales uint8
# input. The ImageNet mean/std must therefore be given on the 0-255 scale
# (0.485 * 255 = 123.675, 0.229 * 255 = 58.395, ...); the previous 0-1 scale
# values left the network inputs roughly 255x too large.
IMAGENET_MEAN_255 = [123.675, 116.28, 103.53]
IMAGENET_STD_255 = [58.395, 57.12, 57.375]
transform = T.Compose([
    T.Resize(size=(224, 224)),
    T.Normalize(mean=IMAGENET_MEAN_255, std=IMAGENET_STD_255, data_format='HWC'),
    T.ToTensor(),
])
train_dataset = Cifar10(mode='train', transform=transform)
train_loader = paddle.io.DataLoader(train_dataset, batch_size=64, shuffle=True, drop_last=True)
ResNet50在Cifar10训练
# Model preparation.
# NOTE(review): despite the `res50` name and the "ResNet50" plot label, the
# baseline built here is paddle's stock resnet18 - confirm this is intended.
res50 = paddle.vision.models.resnet18(num_classes=10)
res50.train()

# Training setup.
epoch_num = 10
optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=res50.parameters())
loss_fn = paddle.nn.CrossEntropyLoss()

# Loss / accuracy sampled every 20 batches, used later for plotting.
res50_loss, res50_acc = [], []

for epoch in range(epoch_num):
    for batch_id, data in enumerate(train_loader):
        # Labels get a trailing axis before being passed to loss and metric.
        inputs, labels = data[0], data[1].unsqueeze(1)

        predicts = res50(inputs)
        loss = loss_fn(predicts, labels)
        acc = paddle.metric.accuracy(predicts, labels)

        # One optimisation step.
        loss.backward()
        optim.step()
        optim.clear_grad()

        # Console log every 100 batches, curve sample every 20 batches.
        if batch_id % 100 == 0:
            print("epoch: {}, batch_id: {}, loss is: {}, acc is: {}".format(epoch, batch_id, loss.numpy(), acc.numpy()))
        if batch_id % 20 == 0:
            res50_loss.append(loss.numpy())
            res50_acc.append(acc.numpy())
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/nn/layer/norm.py:653: UserWarning: When training, we now always track global mean and variance.
"When training, we now always track global mean and variance.")
epoch: 0, batch_id: 0, loss is: [2.8489058], acc is: [0.1875]
epoch: 0, batch_id: 100, loss is: [1.5169208], acc is: [0.453125]
epoch: 0, batch_id: 200, loss is: [1.4419072], acc is: [0.453125]
epoch: 0, batch_id: 300, loss is: [1.303793], acc is: [0.5625]
epoch: 0, batch_id: 400, loss is: [1.1453766], acc is: [0.5]
epoch: 0, batch_id: 500, loss is: [0.97646654], acc is: [0.609375]
epoch: 0, batch_id: 600, loss is: [0.9759238], acc is: [0.6875]
epoch: 0, batch_id: 700, loss is: [0.8780391], acc is: [0.609375]
epoch: 1, batch_id: 0, loss is: [1.042079], acc is: [0.703125]
epoch: 1, batch_id: 100, loss is: [0.83035773], acc is: [0.703125]
epoch: 1, batch_id: 200, loss is: [0.8466539], acc is: [0.765625]
epoch: 1, batch_id: 300, loss is: [0.93592036], acc is: [0.71875]
epoch: 1, batch_id: 400, loss is: [0.90168834], acc is: [0.71875]
epoch: 1, batch_id: 500, loss is: [0.79210997], acc is: [0.6875]
epoch: 1, batch_id: 600, loss is: [0.76145744], acc is: [0.734375]
epoch: 1, batch_id: 700, loss is: [0.8608778], acc is: [0.765625]
epoch: 2, batch_id: 0, loss is: [0.71813166], acc is: [0.75]
epoch: 2, batch_id: 100, loss is: [0.66491127], acc is: [0.765625]
epoch: 2, batch_id: 200, loss is: [0.5883076], acc is: [0.84375]
epoch: 2, batch_id: 300, loss is: [0.49795514], acc is: [0.8125]
epoch: 2, batch_id: 400, loss is: [0.85554576], acc is: [0.75]
epoch: 2, batch_id: 500, loss is: [0.6621829], acc is: [0.8125]
epoch: 2, batch_id: 600, loss is: [0.5962929], acc is: [0.796875]
epoch: 2, batch_id: 700, loss is: [0.42166245], acc is: [0.875]
epoch: 3, batch_id: 0, loss is: [0.44783327], acc is: [0.8125]
epoch: 3, batch_id: 100, loss is: [0.51038265], acc is: [0.796875]
epoch: 3, batch_id: 200, loss is: [0.5499249], acc is: [0.828125]
epoch: 3, batch_id: 300, loss is: [0.3780781], acc is: [0.859375]
epoch: 3, batch_id: 400, loss is: [0.5779753], acc is: [0.828125]
epoch: 3, batch_id: 500, loss is: [0.52967703], acc is: [0.828125]
epoch: 3, batch_id: 600, loss is: [0.34549445], acc is: [0.921875]
epoch: 3, batch_id: 700, loss is: [0.80611515], acc is: [0.734375]
epoch: 4, batch_id: 0, loss is: [0.21761498], acc is: [0.953125]
epoch: 4, batch_id: 100, loss is: [0.558988], acc is: [0.828125]
epoch: 4, batch_id: 200, loss is: [0.4339955], acc is: [0.84375]
epoch: 4, batch_id: 300, loss is: [0.45421124], acc is: [0.84375]
epoch: 4, batch_id: 400, loss is: [0.55842537], acc is: [0.78125]
epoch: 4, batch_id: 500, loss is: [0.46884495], acc is: [0.859375]
epoch: 4, batch_id: 600, loss is: [0.29323775], acc is: [0.875]
epoch: 4, batch_id: 700, loss is: [0.24813098], acc is: [0.9375]
epoch: 5, batch_id: 0, loss is: [0.26143038], acc is: [0.890625]
epoch: 5, batch_id: 100, loss is: [0.28214324], acc is: [0.90625]
epoch: 5, batch_id: 200, loss is: [0.28322464], acc is: [0.90625]
epoch: 5, batch_id: 300, loss is: [0.5579711], acc is: [0.796875]
epoch: 5, batch_id: 400, loss is: [0.2829948], acc is: [0.921875]
epoch: 5, batch_id: 500, loss is: [0.22569823], acc is: [0.921875]
epoch: 5, batch_id: 600, loss is: [0.36216933], acc is: [0.859375]
epoch: 5, batch_id: 700, loss is: [0.39805618], acc is: [0.84375]
epoch: 6, batch_id: 0, loss is: [0.17924805], acc is: [0.953125]
epoch: 6, batch_id: 100, loss is: [0.29918194], acc is: [0.90625]
epoch: 6, batch_id: 200, loss is: [0.18964693], acc is: [0.90625]
epoch: 6, batch_id: 300, loss is: [0.36167118], acc is: [0.90625]
epoch: 6, batch_id: 400, loss is: [0.2007184], acc is: [0.9375]
epoch: 6, batch_id: 500, loss is: [0.15255482], acc is: [0.90625]
epoch: 6, batch_id: 600, loss is: [0.33478457], acc is: [0.84375]
epoch: 6, batch_id: 700, loss is: [0.2562605], acc is: [0.875]
epoch: 7, batch_id: 0, loss is: [0.11516868], acc is: [0.96875]
epoch: 7, batch_id: 100, loss is: [0.10812954], acc is: [0.953125]
epoch: 7, batch_id: 200, loss is: [0.08551887], acc is: [0.984375]
epoch: 7, batch_id: 300, loss is: [0.34660637], acc is: [0.875]
epoch: 7, batch_id: 400, loss is: [0.1358226], acc is: [0.953125]
epoch: 7, batch_id: 500, loss is: [0.13017304], acc is: [0.96875]
epoch: 7, batch_id: 600, loss is: [0.46732765], acc is: [0.828125]
epoch: 7, batch_id: 700, loss is: [0.19166371], acc is: [0.9375]
epoch: 8, batch_id: 0, loss is: [0.07097062], acc is: [0.96875]
epoch: 8, batch_id: 100, loss is: [0.03216531], acc is: [0.984375]
epoch: 8, batch_id: 200, loss is: [0.1407477], acc is: [0.9375]
epoch: 8, batch_id: 300, loss is: [0.06303663], acc is: [0.984375]
epoch: 8, batch_id: 400, loss is: [0.07170855], acc is: [0.96875]
epoch: 8, batch_id: 500, loss is: [0.1672141], acc is: [0.90625]
epoch: 8, batch_id: 600, loss is: [0.05715706], acc is: [1.]
epoch: 8, batch_id: 700, loss is: [0.2710502], acc is: [0.9375]
epoch: 9, batch_id: 0, loss is: [0.06866629], acc is: [0.984375]
epoch: 9, batch_id: 100, loss is: [0.04419067], acc is: [0.984375]
epoch: 9, batch_id: 200, loss is: [0.07565454], acc is: [0.96875]
epoch: 9, batch_id: 300, loss is: [0.15519926], acc is: [0.953125]
epoch: 9, batch_id: 400, loss is: [0.01717619], acc is: [1.]
epoch: 9, batch_id: 500, loss is: [0.08571189], acc is: [0.96875]
epoch: 9, batch_id: 600, loss is: [0.05870717], acc is: [0.96875]
epoch: 9, batch_id: 700, loss is: [0.07783952], acc is: [0.984375]
DA_ResNet50在Cifar10数据集训练
# Model preparation.
# NOTE(review): this calls the local `resnet18` factory, which builds the
# network from BasicBlock. BasicBlock contains no DoubleAtten module (only
# BottleneckBlock does), so the model trained here has no attention in it
# even though it is named `da_res50` and plotted as the DA variant.
# Confirm whether `resnet50` was intended.
da_res50 = resnet18(num_classes=10)
da_res50.train()

# Training setup.
epoch_num = 10
optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=da_res50.parameters())
loss_fn = paddle.nn.CrossEntropyLoss()

# Loss / accuracy sampled every 20 batches, used later for plotting.
da_res50_loss, da_res50_acc = [], []

for epoch in range(epoch_num):
    for batch_id, data in enumerate(train_loader):
        # Labels get a trailing axis before being passed to loss and metric.
        inputs, labels = data[0], data[1].unsqueeze(1)

        predicts = da_res50(inputs)
        loss = loss_fn(predicts, labels)
        acc = paddle.metric.accuracy(predicts, labels)

        # One optimisation step.
        loss.backward()
        optim.step()
        optim.clear_grad()

        # Console log every 100 batches, curve sample every 20 batches.
        if batch_id % 100 == 0:
            print("epoch: {}, batch_id: {}, loss is: {}, acc is: {}".format(epoch, batch_id, loss.numpy(), acc.numpy()))
        if batch_id % 20 == 0:
            da_res50_loss.append(loss.numpy())
            da_res50_acc.append(acc.numpy())
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/nn/layer/norm.py:653: UserWarning: When training, we now always track global mean and variance.
"When training, we now always track global mean and variance.")
epoch: 0, batch_id: 0, loss is: [2.6978078], acc is: [0.15625]
epoch: 0, batch_id: 100, loss is: [1.7868621], acc is: [0.359375]
epoch: 0, batch_id: 200, loss is: [1.4539795], acc is: [0.4375]
epoch: 0, batch_id: 300, loss is: [1.3904684], acc is: [0.515625]
epoch: 0, batch_id: 400, loss is: [1.0738535], acc is: [0.578125]
epoch: 0, batch_id: 500, loss is: [1.276068], acc is: [0.546875]
epoch: 0, batch_id: 600, loss is: [0.9131775], acc is: [0.671875]
epoch: 0, batch_id: 700, loss is: [0.9701067], acc is: [0.65625]
epoch: 1, batch_id: 0, loss is: [1.0521469], acc is: [0.6875]
epoch: 1, batch_id: 100, loss is: [0.8612659], acc is: [0.6875]
epoch: 1, batch_id: 200, loss is: [0.88844216], acc is: [0.703125]
epoch: 1, batch_id: 300, loss is: [0.80324924], acc is: [0.65625]
epoch: 1, batch_id: 400, loss is: [0.7941488], acc is: [0.734375]
epoch: 1, batch_id: 500, loss is: [0.72946227], acc is: [0.734375]
epoch: 1, batch_id: 600, loss is: [0.69075954], acc is: [0.78125]
epoch: 1, batch_id: 700, loss is: [0.49935198], acc is: [0.8125]
epoch: 2, batch_id: 0, loss is: [0.6254372], acc is: [0.796875]
epoch: 2, batch_id: 100, loss is: [0.6929467], acc is: [0.75]
epoch: 2, batch_id: 200, loss is: [0.6014596], acc is: [0.765625]
epoch: 2, batch_id: 300, loss is: [0.58570975], acc is: [0.796875]
epoch: 2, batch_id: 400, loss is: [0.56933963], acc is: [0.828125]
epoch: 2, batch_id: 500, loss is: [0.7758298], acc is: [0.703125]
epoch: 2, batch_id: 600, loss is: [0.7611967], acc is: [0.765625]
epoch: 2, batch_id: 700, loss is: [0.6714103], acc is: [0.84375]
epoch: 3, batch_id: 0, loss is: [0.4566578], acc is: [0.8125]
epoch: 3, batch_id: 100, loss is: [0.4512155], acc is: [0.84375]
epoch: 3, batch_id: 200, loss is: [0.79440355], acc is: [0.765625]
epoch: 3, batch_id: 300, loss is: [0.40544614], acc is: [0.859375]
epoch: 3, batch_id: 400, loss is: [0.3677606], acc is: [0.84375]
epoch: 3, batch_id: 500, loss is: [0.5795288], acc is: [0.765625]
epoch: 3, batch_id: 600, loss is: [0.3732043], acc is: [0.890625]
epoch: 3, batch_id: 700, loss is: [0.752839], acc is: [0.71875]
epoch: 4, batch_id: 0, loss is: [0.48794582], acc is: [0.828125]
epoch: 4, batch_id: 100, loss is: [0.27974662], acc is: [0.875]
epoch: 4, batch_id: 200, loss is: [0.35176855], acc is: [0.84375]
epoch: 4, batch_id: 300, loss is: [0.43276876], acc is: [0.8125]
epoch: 4, batch_id: 400, loss is: [0.45843226], acc is: [0.796875]
epoch: 4, batch_id: 500, loss is: [0.22956792], acc is: [0.921875]
epoch: 4, batch_id: 600, loss is: [0.3693059], acc is: [0.875]
epoch: 4, batch_id: 700, loss is: [0.466375], acc is: [0.8125]
epoch: 5, batch_id: 0, loss is: [0.19983137], acc is: [0.890625]
epoch: 5, batch_id: 100, loss is: [0.34767026], acc is: [0.890625]
epoch: 5, batch_id: 200, loss is: [0.31151807], acc is: [0.890625]
epoch: 5, batch_id: 300, loss is: [0.3181343], acc is: [0.84375]
epoch: 5, batch_id: 400, loss is: [0.36661047], acc is: [0.828125]
epoch: 5, batch_id: 500, loss is: [0.20549971], acc is: [0.921875]
epoch: 5, batch_id: 600, loss is: [0.3835126], acc is: [0.875]
epoch: 5, batch_id: 700, loss is: [0.29854366], acc is: [0.9375]
epoch: 6, batch_id: 0, loss is: [0.21725702], acc is: [0.953125]
epoch: 6, batch_id: 100, loss is: [0.10980862], acc is: [0.953125]
epoch: 6, batch_id: 200, loss is: [0.24021214], acc is: [0.90625]
epoch: 6, batch_id: 300, loss is: [0.17704055], acc is: [0.953125]
epoch: 6, batch_id: 400, loss is: [0.21923174], acc is: [0.921875]
epoch: 6, batch_id: 500, loss is: [0.17121044], acc is: [0.9375]
epoch: 6, batch_id: 600, loss is: [0.10075981], acc is: [1.]
epoch: 6, batch_id: 700, loss is: [0.2989301], acc is: [0.921875]
epoch: 7, batch_id: 0, loss is: [0.19572908], acc is: [0.921875]
epoch: 7, batch_id: 100, loss is: [0.2877647], acc is: [0.9375]
epoch: 7, batch_id: 200, loss is: [0.2735257], acc is: [0.875]
epoch: 7, batch_id: 300, loss is: [0.16631857], acc is: [0.96875]
epoch: 7, batch_id: 400, loss is: [0.19605], acc is: [0.953125]
epoch: 7, batch_id: 500, loss is: [0.11403949], acc is: [0.984375]
epoch: 7, batch_id: 600, loss is: [0.2396015], acc is: [0.9375]
epoch: 7, batch_id: 700, loss is: [0.22148657], acc is: [0.921875]
epoch: 8, batch_id: 0, loss is: [0.0501596], acc is: [0.984375]
epoch: 8, batch_id: 100, loss is: [0.07064686], acc is: [0.96875]
epoch: 8, batch_id: 200, loss is: [0.1000964], acc is: [0.9375]
epoch: 8, batch_id: 300, loss is: [0.21059893], acc is: [0.9375]
epoch: 8, batch_id: 400, loss is: [0.1722869], acc is: [0.9375]
epoch: 8, batch_id: 500, loss is: [0.1418913], acc is: [0.953125]
epoch: 8, batch_id: 600, loss is: [0.09865286], acc is: [0.953125]
epoch: 8, batch_id: 700, loss is: [0.18124923], acc is: [0.953125]
epoch: 9, batch_id: 0, loss is: [0.07974362], acc is: [0.96875]
epoch: 9, batch_id: 100, loss is: [0.09133874], acc is: [0.96875]
epoch: 9, batch_id: 200, loss is: [0.18680792], acc is: [0.9375]
epoch: 9, batch_id: 300, loss is: [0.09332339], acc is: [0.96875]
epoch: 9, batch_id: 400, loss is: [0.05375797], acc is: [0.96875]
epoch: 9, batch_id: 500, loss is: [0.02048387], acc is: [1.]
epoch: 9, batch_id: 600, loss is: [0.06067649], acc is: [0.984375]
epoch: 9, batch_id: 700, loss is: [0.0408816], acc is: [1.]
import matplotlib.pyplot as plt

# Plot the sampled training loss (top) and accuracy (bottom) of the baseline
# run against the DA run. The curves recorded by the training loops are
# `res50_*` and `da_res50_*`; the previous code referenced undefined
# `ca_res50_*` names and carried a duplicated, syntactically broken tail.
plt.figure(figsize=(18, 12))
panels = (
    (211, 'loss', 'train loss', res50_loss, da_res50_loss),
    (212, 'acc', 'train acc', res50_acc, da_res50_acc),
)
for position, ylabel, title, baseline_curve, da_curve in panels:
    plt.subplot(position)
    plt.xlabel('iter')
    plt.ylabel(ylabel)
    plt.title(title)
    # Each curve gets its own x axis so differing lengths cannot break the plot.
    plt.plot(range(len(baseline_curve)), baseline_curve, color='b', label='ResNet50')
    plt.plot(range(len(da_curve)), da_curve, color='r', label='ResNet50 + DA')
    plt.legend()
    plt.grid()
plt.show()
绘制 ResNet50 和 DA_ResNet50 训练曲线
总结
- 提出了一个通用的公式,通过通用的收集和分布函数来捕获长期的特征相关性
- 提出了一种用于收集和分布长距离特征的双注意块,它是一种有效的二次特征统计和自适应特征分配的体系结构。该块可以用较低的计算和内存占用来建模长期的相互依赖关系,同时显著提高图像/视频识别性能
- 通过广泛的消融研究来调查提出的A2-Net的影响,并通过与当前技术水平的比较来证明它在图像识别和视频动作识别任务的一些公共基准上的优越性能。
特别感谢:仰世而来丶(本文参考了https://aistudio.baidu.com/aistudio/projectdetail/1884947?channelType=0&channel=0)
请点击此处查看本环境基本用法.
Please click here for more detailed instructions.
更多推荐
所有评论(0)