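Reposted from AI Studio. Original link:

[6th Paper Reproduction Challenge: Change Detection] SNUNet-CD - PaddlePaddle AI Studio

[6th Paper Reproduction Challenge: Change Detection] A Densely Connected Siamese Network for Change Detection of VHR Images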

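1. Introduction

Paper overview

The structure of SNUNet-CD is shown in the figure below. Inspired by DenseNet and NestedUNet, the authors propose SNUNet-CD, a densely connected Siamese network for change detection that combines a Siamese network with NestedUNet. Dense skip connections between encoder and decoder, and between decoder and decoder, let it retain high-resolution, fine-grained representations. The authors also propose an Ensemble Channel Attention Module (ECAM) for deep supervision. ECAM refines the most representative features at different semantic levels and uses them for the final classification.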

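2. Network Structure

Note: the network analysis in this section draws on the CSDN blog post "A Change Detection Network Even a Dog Could Understand: Siam-NestedUNet Explained, Solving a Pain Point in Industrial Inspection".

  • The overall architecture is shown in the figure above. It is a typical encoder-decoder network that can be broken into three parts:
    • the backbone, which resembles UNet++
    • the Siamese structure that extracts the difference information between the two images
    • the ECAM module at the end of the network, which strengthens the information in the outputs from different levels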

2.1 Backbone

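  • The backbone is derived from UNet++. Setting the dual inputs aside and looking only at the backbone, the path that downsamples from $x^{0,0}$ to $x^{4,0}$ and then upsamples to $x^{0,4}$ forms a U shape. Like UNet, this is the classic encoder-decoder structure of image segmentation.
  • Convolution blocks are also linked by connections similar to those in residual networks. Following DenseNet, these are dense skip connections, and they address two problems:
    • shallow layers being hard to optimize during gradient backpropagation
    • weak feature fusion: the dense connections let deep layers draw on shallow-layer features, fusing low-level detail with high-level semantics and enlarging the receptive field of the low levels, so that they get more context when detecting small targets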

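2.2 Siamese Network

  • As shown in the figure above, a Siamese network takes two inputs. It was originally devised to address poor generalization on small datasets. Each input passes through its own network, giving two outputs that are high-dimensional features of the two inputs. Their simple difference can be viewed roughly as a loss between the two: the smaller it is, the smaller the difference between the inputs, and the larger it is, the larger the difference. The two branches usually share weights.
  • In SNUNet, the images from the two time phases are encoded separately, and the two sets of features are concatenated only at the skip connections, before the corresponding decoding step produces the output at that level.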
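To make the two bullets above concrete, here is a minimal pure-Python sketch of the Siamese idea. It is not code from SNUNet-Paddle: `SHARED_WEIGHTS`, `encode`, `l1_distance`, and `fuse` are hypothetical names, and a single 3-band pixel stands in for a whole image. Both inputs go through the same weights, their feature difference serves as a rough dissimilarity score, and SNUNet-style fusion keeps both features by concatenation instead of subtracting them.

```python
# Illustrative sketch only; all names here are hypothetical.

# One weight matrix shared by both branches (2 output features, 3 input bands).
SHARED_WEIGHTS = [
    [0.5, -0.2, 0.1],
    [0.3, 0.8, -0.5],
]


def encode(pixel):
    """Project a 3-band pixel to a 2-D feature with the shared weights and ReLU."""
    return [max(0.0, sum(w * v for w, v in zip(row, pixel)))
            for row in SHARED_WEIGHTS]


def l1_distance(feat_a, feat_b):
    """Simple feature difference: small means similar, large means different."""
    return sum(abs(a - b) for a, b in zip(feat_a, feat_b))


def fuse(feat_a, feat_b):
    """SNUNet-style fusion: concatenate the two features along the channel axis."""
    return feat_a + feat_b


t1_pixel = [0.2, 0.4, 0.6]
unchanged_pixel = [0.2, 0.4, 0.6]
changed_pixel = [0.9, 0.1, 0.3]

d_same = l1_distance(encode(t1_pixel), encode(unchanged_pixel))
d_diff = l1_distance(encode(t1_pixel), encode(changed_pixel))
fused = fuse(encode(t1_pixel), encode(changed_pixel))
```

Because the weights are shared, identical inputs always map to identical features, so any non-zero distance reflects a difference in the inputs and not in the encoders.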

The Paddle source code for the two modules above is as follows:

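import paddle
import paddle.nn as nn

# ConvBlockNested, Up, MaxPool2x2, Conv1x1 and KaimingInitMixin are helpers
# from the SNUNet-Paddle repository and are not shown in this excerpt.
class SNUNet(nn.Layer, KaimingInitMixin):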
    """
    Args:
        in_channels (int): The number of bands of the input images.
        num_classes (int): The number of target classes.
        width (int, optional): The output channels of the first convolutional layer. Default: 32.
    """

    def __init__(self, in_channels, num_classes, width=32):
        super(SNUNet, self).__init__()

        filters = (width, width * 2, width * 4, width * 8, width * 16)

        self.conv0_0 = ConvBlockNested(in_channels, filters[0], filters[0])
        self.conv1_0 = ConvBlockNested(filters[0], filters[1], filters[1])
        self.conv2_0 = ConvBlockNested(filters[1], filters[2], filters[2])
        self.conv3_0 = ConvBlockNested(filters[2], filters[3], filters[3])
        self.conv4_0 = ConvBlockNested(filters[3], filters[4], filters[4])
        self.down1 = MaxPool2x2()
        self.down2 = MaxPool2x2()
        self.down3 = MaxPool2x2()
        self.down4 = MaxPool2x2()
        self.up1_0 = Up(filters[1])
        self.up2_0 = Up(filters[2])
        self.up3_0 = Up(filters[3])
        self.up4_0 = Up(filters[4])

        self.conv0_1 = ConvBlockNested(filters[0] * 2 + filters[1], filters[0],
                                       filters[0])
        self.conv1_1 = ConvBlockNested(filters[1] * 2 + filters[2], filters[1],
                                       filters[1])
        self.conv2_1 = ConvBlockNested(filters[2] * 2 + filters[3], filters[2],
                                       filters[2])
        self.conv3_1 = ConvBlockNested(filters[3] * 2 + filters[4], filters[3],
                                       filters[3])
        self.up1_1 = Up(filters[1])
        self.up2_1 = Up(filters[2])
        self.up3_1 = Up(filters[3])

        self.conv0_2 = ConvBlockNested(filters[0] * 3 + filters[1], filters[0],
                                       filters[0])
        self.conv1_2 = ConvBlockNested(filters[1] * 3 + filters[2], filters[1],
                                       filters[1])
        self.conv2_2 = ConvBlockNested(filters[2] * 3 + filters[3], filters[2],
                                       filters[2])
        self.up1_2 = Up(filters[1])
        self.up2_2 = Up(filters[2])

        self.conv0_3 = ConvBlockNested(filters[0] * 4 + filters[1], filters[0],
                                       filters[0])
        self.conv1_3 = ConvBlockNested(filters[1] * 4 + filters[2], filters[1],
                                       filters[1])
        self.up1_3 = Up(filters[1])

        self.conv0_4 = ConvBlockNested(filters[0] * 5 + filters[1], filters[0],
                                       filters[0])

        self.ca_intra = ChannelAttention(filters[0], ratio=4)
        self.ca_inter = ChannelAttention(filters[0] * 4, ratio=16)

        self.conv_out = Conv1x1(filters[0] * 4, num_classes)

        self.init_weight()

    def forward(self, t1, t2):
        # Encode the first image (t1).
        x0_0_t1 = self.conv0_0(t1)
        x1_0_t1 = self.conv1_0(self.down1(x0_0_t1))
        x2_0_t1 = self.conv2_0(self.down2(x1_0_t1))
        x3_0_t1 = self.conv3_0(self.down3(x2_0_t1))

        # Encode the second image (t2) with the same layers, i.e. shared weights.
        x0_0_t2 = self.conv0_0(t2)
        x1_0_t2 = self.conv1_0(self.down1(x0_0_t2))
        x2_0_t2 = self.conv2_0(self.down2(x1_0_t2))
        x3_0_t2 = self.conv3_0(self.down3(x2_0_t2))
        # Only the t2 branch is carried down to the deepest level.
        x4_0_t2 = self.conv4_0(self.down4(x3_0_t2))

        x0_1 = self.conv0_1(
            paddle.concat([x0_0_t1, x0_0_t2, self.up1_0(x1_0_t2)], 1))
        x1_1 = self.conv1_1(
            paddle.concat([x1_0_t1, x1_0_t2, self.up2_0(x2_0_t2)], 1))
        x0_2 = self.conv0_2(
            paddle.concat([x0_0_t1, x0_0_t2, x0_1, self.up1_1(x1_1)], 1))

        x2_1 = self.conv2_1(
            paddle.concat([x2_0_t1, x2_0_t2, self.up3_0(x3_0_t2)], 1))
        x1_2 = self.conv1_2(
            paddle.concat([x1_0_t1, x1_0_t2, x1_1, self.up2_1(x2_1)], 1))
        x0_3 = self.conv0_3(
            paddle.concat([x0_0_t1, x0_0_t2, x0_1, x0_2, self.up1_2(x1_2)], 1))

        x3_1 = self.conv3_1(
            paddle.concat([x3_0_t1, x3_0_t2, self.up4_0(x4_0_t2)], 1))
        x2_2 = self.conv2_2(
            paddle.concat([x2_0_t1, x2_0_t2, x2_1, self.up3_1(x3_1)], 1))
        x1_3 = self.conv1_3(
            paddle.concat([x1_0_t1, x1_0_t2, x1_1, x1_2, self.up2_2(x2_2)], 1))
        x0_4 = self.conv0_4(
            paddle.concat(
                [x0_0_t1, x0_0_t2, x0_1, x0_2, x0_3, self.up1_3(x1_3)], 1))

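        # ECAM: stack the four full-resolution decoder outputs along the channel axis.
        out = paddle.concat([x0_1, x0_2, x0_3, x0_4], 1)

        # Intra-group attention is computed on the element-wise sum of the four outputs.
        intra = paddle.sum(paddle.stack([x0_1, x0_2, x0_3, x0_4]), axis=0)
        m_intra = self.ca_intra(intra)
        # Inter-group attention reweights the stacked features plus the tiled intra-group attention.
        out = self.ca_inter(out) * (out + paddle.tile(m_intra, (1, 4, 1, 1)))

        pred = self.conv_out(out)
        return [pred]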
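The input widths passed to `ConvBlockNested` in the constructor above follow a regular pattern, which the short sketch below reproduces. The helper name `nested_in_channels` is hypothetical and is not part of the repository. Node (i, j) concatenates the two encoder features at level i (t1 and t2), the j-1 earlier decoder nodes at level i, and one upsampled feature from level i+1. The sketch assumes, as the constructor arguments imply, that `Up` keeps the channel count unchanged.

```python
def nested_in_channels(width=32, depth=5):
    """Return {(i, j): input channels of conv{i}_{j}} for the dense decoder nodes."""
    filters = [width * 2 ** i for i in range(depth)]
    table = {}
    for j in range(1, depth):        # decoder column
        for i in range(depth - j):   # resolution level (0 = full resolution)
            # (j + 1) same-level features of width filters[i]:
            #   2 encoder features (t1, t2) + (j - 1) earlier decoder nodes,
            # plus 1 upsampled feature of width filters[i + 1].
            table[(i, j)] = filters[i] * (j + 1) + filters[i + 1]
    return table


channels = nested_in_channels(width=32)
```

For example, `channels[(0, 1)]` corresponds to `filters[0] * 2 + filters[1]` in the `conv0_1` definition, and `channels[(0, 4)]` to `filters[0] * 5 + filters[1]` in `conv0_4`.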

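2.3 Ensemble Channel Attention Module (ECAM)

  • After SNUNet's encode-decode pass, the network yields four outputs with the same spatial size as the input image. Although the sizes match, the outputs differ in semantic level and in how they express spatial position.
  • Shallow outputs carry accurate spatial position information, while deep outputs carry finer semantic information. Fusing these features therefore has to account for the differences in semantics and spatial position between levels, and SNUNet uses the ECAM module for this fusion.
  • The detailed structure of ECAM is part (b) of the overall architecture figure. Its logic can be viewed as a residual block plus two channel attention modules. The code above already covers the main ECAM logic; the channel attention code follows: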

import paddle.nn as nn
import paddle.nn.functional as F

class ChannelAttention(nn.Layer):
    """
    Args:
        in_ch (int): The number of channels of the input features.
        ratio (int, optional): The channel reduction ratio. Default: 8.
    """

    def __init__(self, in_ch, ratio=8):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2D(1)
        self.max_pool = nn.AdaptiveMaxPool2D(1)
        self.fc1 = Conv1x1(in_ch, in_ch // ratio, bias=False, act=True)
        self.fc2 = Conv1x1(in_ch // ratio, in_ch, bias=False)

    def forward(self, x):
        avg_out = self.fc2(self.fc1(self.avg_pool(x)))
        max_out = self.fc2(self.fc1(self.max_pool(x)))
        out = avg_out + max_out
        return F.sigmoid(out)
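For readers who want to trace the arithmetic by hand, the following pure-Python sketch mirrors the `forward` method above on a tiny feature map stored as nested lists. It is an illustration, not the repository code: the names `global_pool`, `linear`, and `channel_attention` are hypothetical, the weight values are arbitrary, and the sketch assumes that `act=True` in `Conv1x1` means a ReLU activation.

```python
import math


def mean(values):
    return sum(values) / len(values)


def global_pool(feature_map, reduce_fn):
    """Reduce every H x W channel to a single scalar (adaptive pooling to 1x1)."""
    return [reduce_fn([v for row in channel for v in row])
            for channel in feature_map]


def linear(weights, vector):
    """Bias-free 1x1 convolution on a pooled vector, i.e. a matrix-vector product."""
    return [sum(w * v for w, v in zip(row, vector)) for row in weights]


def channel_attention(feature_map, w1, w2):
    """Return one weight in (0, 1) per channel."""
    def shared_mlp(vec):
        hidden = [max(0.0, h) for h in linear(w1, vec)]  # fc1 + ReLU (assumed)
        return linear(w2, hidden)                        # fc2

    avg_out = shared_mlp(global_pool(feature_map, mean))
    max_out = shared_mlp(global_pool(feature_map, max))
    return [1.0 / (1.0 + math.exp(-(a + m))) for a, m in zip(avg_out, max_out)]


# 4 channels of size 2x2; reduction ratio 2 gives a hidden width of 2.
features = [
    [[1.0, 2.0], [3.0, 4.0]],
    [[0.0, 0.0], [0.0, 0.0]],
    [[-1.0, 1.0], [-1.0, 1.0]],
    [[0.5, 0.5], [0.5, 0.5]],
]
w1 = [[0.1, 0.2, 0.3, 0.4], [-0.1, 0.0, 0.1, 0.2]]      # shape 2 x 4
w2 = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [-1.0, 1.0]]  # shape 4 x 2

weights = channel_attention(features, w1, w2)
# Applying the attention: scale every value of a channel by that channel's weight.
reweighted = [[[v * w for v in row] for row in channel]
              for channel, w in zip(features, weights)]
```

The same two-layer MLP is applied to both the average-pooled and the max-pooled vector, which is why `fc1` and `fc2` are each called twice in the Paddle code.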

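3. Reproduction Accuracy

Results on the CDD test set are shown in the table below. They meet the acceptance target of F1-Score = 95.3%.

Network      opt     epoch   batch_size   dataset   F1-Score   mIOU
SNUNET-32    AdamW   100     16           CDD       95.54%     95.12%

Note: the model evaluated for acceptance is SNUNet-32.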

4. Environment and Data Preparation

  • Clone the repository

In [1]

!git clone https://github.com/kongdebug/SNUNet-Paddle.git
Cloning into 'SNUNet-Paddle'...
remote: Enumerating objects: 1027, done.
remote: Counting objects: 100% (1027/1027), done.
remote: Compressing objects: 100% (804/804), done.
remote: Total 1027 (delta 214), reused 960 (delta 189), pack-reused 0
Receiving objects: 100% (1027/1027), 12.22 MiB | 7.44 MiB/s, done.
Resolving deltas: 100% (214/214), done.
Checking connectivity... done.
  • Unzip and process the data

In [2]

# Unzip the data
!unzip -qo data/data29275/CDData.zip -d ./work/

In [ ]

# Install the required dependencies
%cd SNUNet-Paddle/
!pip install -r requirements.txt

In [4]

# Generate the .txt file lists needed for model training
!python ./data/process_cdd_data.py --data_dir=../work/Real/subset
Dataset split completed.
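The source of `process_cdd_data.py` is not reproduced in this post. As a rough idea of what such a step involves, the sketch below writes one file list per split for a folder layout with `A`, `B`, and `OUT` subfolders, which matches the image paths used later in this post. The function name `write_file_list` and the line format (three space-separated relative paths) are assumptions made for illustration; the real script may differ.

```python
import os
import tempfile


def write_file_list(data_dir, split):
    """Write <data_dir>/<split>.txt with one 'A-path B-path label-path' line per sample.

    The line format is an assumption for illustration, not the repository's spec.
    """
    names = sorted(os.listdir(os.path.join(data_dir, split, "A")))
    out_path = os.path.join(data_dir, split + ".txt")
    with open(out_path, "w") as f:
        for name in names:
            triple = ["/".join([split, sub, name]) for sub in ("A", "B", "OUT")]
            f.write(" ".join(triple) + "\n")
    return out_path


# Small self-contained demo on a temporary directory with empty placeholder files.
with tempfile.TemporaryDirectory() as root:
    for sub in ("A", "B", "OUT"):
        os.makedirs(os.path.join(root, "test", sub))
        for name in ("00001.jpg", "00002.jpg"):
            open(os.path.join(root, "test", sub, name), "w").close()
    list_path = write_file_list(root, "test")
    with open(list_path) as f:
        lines = f.read().splitlines()
```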

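5. Quick Start

  • Model training
    • Note: SNUNet-CD is trained without Normalize preprocessing, so the loss may be fairly large in the first few epochs; it typically settles to normal values between epochs 10 and 16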

In [ ]

!python tutorials/train/snunet.py --data_dir=../work/Real/subset --out_dir=./output/snunet/
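  • Model evaluation
    • The best model weights are already in the work/output/snunet/best_model folder
    • You can replace the --weight_path argument with the path to your own trained weights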

In [10]

!python tutorials/eval/snunet_eval.py --data_dir=../work/Real/subset \
                                      --weight_path=../work/output/snunet/best_model/model.pdparams
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/tensor/creation.py:130: DeprecationWarning: `np.object` is a deprecated alias for the builtin `object`. To silence this warning, use `object` by itself. Doing this will not modify any behavior and is safe. 
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
[04-26 02:26:41 MainThread @logger.py:242] Argv: tutorials/eval/snunet_eval.py --data_dir=../work/Real/subset --weight_path=../work/output/snunet/best_model/model.pdparams
[04-26 02:26:41 MainThread @utils.py:79] WRN paddlepaddle version: 2.2.2. The dynamic graph version of PARL is under development, not fully tested and supported
2022-04-26 02:26:42 [INFO]	10000 samples in file ../work/Real/subset/train.txt
2022-04-26 02:26:42 [INFO]	3000 samples in file ../work/Real/subset/test.txt
W0426 02:26:42.497771  4362 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1
W0426 02:26:42.502462  4362 device_context.cc:465] device: 0, cuDNN Version: 7.6.
2022-04-26 02:26:45 [INFO]	Loading pretrained model from ../work/output/snunet/best_model/model.pdparams
2022-04-26 02:26:45 [INFO]	There are 186/186 variables loaded into SNUNet.
2022-04-26 02:26:45 [INFO]	Start to evaluate(total_samples=3000, total_steps=3000)...
OrderedDict([('miou', 0.9511789327930941), ('category_iou', array([0.98762963, 0.91472823])), ('oacc', 0.989078862508138), ('category_acc', array([0.99284724, 0.96190638])), ('kappa', 0.9492419817634077), ('category_F1-score', array([0.99377632, 0.95546534]))])
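The numbers in the evaluation output above can be cross-checked against each other. For a single class, IoU = TP / (TP + FP + FN) and F1 = 2TP / (2TP + FP + FN), so F1 = 2 * IoU / (1 + IoU). The short sketch below applies this identity to the `category_iou` values from the log; the function name `f1_from_iou` is only an illustrative helper.

```python
def f1_from_iou(iou):
    """Per-class F1 score derived from the per-class IoU."""
    return 2.0 * iou / (1.0 + iou)


# Values copied from 'category_iou' in the evaluation log (unchanged, changed).
category_iou = [0.98762963, 0.91472823]

miou = sum(category_iou) / len(category_iou)
category_f1 = [f1_from_iou(v) for v in category_iou]
```

The mean of the two IoU values matches the logged `miou`, and the second F1 value (the changed class) matches the 95.54% F1-Score reported in the accuracy table.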
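  • Model prediction
    • Run prediction with the best model weights
    • Arguments:
      • weight: the trained weights
      • A, B: paths to the T1 image and the T2 image
      • pre: where the predicted image is saved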

In [11]

!python tutorials/predict/snunet_pred.py --weight=../work/output/snunet/best_model/model.pdparams \
                                         --A=../work/Real/subset/test/A/00002.jpg --B=../work/Real/subset/test/B/00002.jpg \
                                         --pre=../work/pre.png
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/tensor/creation.py:130: DeprecationWarning: `np.object` is a deprecated alias for the builtin `object`. To silence this warning, use `object` by itself. Doing this will not modify any behavior and is safe. 
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
[04-26 02:30:00 MainThread @logger.py:242] Argv: tutorials/predict/snunet_pred.py --weight=../work/output/snunet/best_model/model.pdparams --A=../work/Real/subset/test/A/00002.jpg --B=../work/Real/subset/test/B/00002.jpg --pre=../work/pre.png
[04-26 02:30:00 MainThread @utils.py:79] WRN paddlepaddle version: 2.2.2. The dynamic graph version of PARL is under development, not fully tested and supported
W0426 02:30:00.780316  4623 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1
W0426 02:30:00.785306  4623 device_context.cc:465] device: 0, cuDNN Version: 7.6.
finish!

In [13]

# Show the prediction result; the last image is the ground truth
import matplotlib.pyplot as plt
from PIL import Image

T1 = Image.open(r"../work/Real/subset/test/A/00002.jpg")
T2 = Image.open(r"../work/Real/subset/test/B/00002.jpg")
GT = Image.open(r"../work/Real/subset/test/OUT/00002.jpg")
pred = Image.open(r"../work/pre.png")

plt.figure(figsize=(16, 8))
plt.subplot(1,4,1), plt.title('T1')
plt.imshow(T1), plt.axis('off')
plt.subplot(1,4,2), plt.title('T2') 
plt.imshow(T2), plt.axis('off')
plt.subplot(1,4,3), plt.title('pred') 
plt.imshow(pred), plt.axis('off')
plt.subplot(1,4,4), plt.title('GT') 
plt.imshow(GT), plt.axis('off')
plt.show()

<Figure size 1152x576 with 4 Axes>
  • Export the SNUNet model

In [14]

!python deploy/export/export_model.py --model_dir=../work/output/snunet/best_model/ \
                                      --save_dir=./inference_model/ 
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/tensor/creation.py:130: DeprecationWarning: `np.object` is a deprecated alias for the builtin `object`. To silence this warning, use `object` by itself. Doing this will not modify any behavior and is safe. 
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
[04-26 02:34:02 MainThread @logger.py:242] Argv: deploy/export/export_model.py --model_dir=../work/output/snunet/best_model/ --save_dir=./inference_model/
[04-26 02:34:02 MainThread @utils.py:79] WRN paddlepaddle version: 2.2.2. The dynamic graph version of PARL is under development, not fully tested and supported
W0426 02:34:02.990654  4975 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1
W0426 02:34:02.995680  4975 device_context.cc:465] device: 0, cuDNN Version: 7.6.
2022-04-26 02:34:05 [INFO]	Model[SNUNet] loaded.
2022-04-26 02:34:08 [INFO]	The model for the inference deployment is saved in ./inference_model/.

6. TIPC Basic Chain Test

This part depends on auto_log, which must be installed first, as shown below.

For details on auto_log, see https://github.com/LDOUBLEV/AutoLog

In [ ]

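!git clone https://github.com/LDOUBLEV/AutoLog
# Build and install from inside the cloned AutoLog directory, then return
%cd AutoLog/
!pip3 install -r requirements.txt
!python3 setup.py bdist_wheel
!pip3 install ./dist/auto_log-1.0.0-py3-none-any.whl
%cd ..
  • Prepare the data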

In [ ]

!bash ./test_tipc/prepare.sh test_tipc/configs/SNUNET/train_infer_python.txt 'lite_train_lite_infer'
  • Run the test

In [ ]

!bash test_tipc/test_train_inference_python.sh test_tipc/configs/SNUNET/train_infer_python.txt 'lite_train_lite_infer'

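7. Project Summary

  • This project gives a brief introduction to SNUNet, covering the overall model structure and the details of the network, to help readers understand it better
  • It also shows how to use the Paddle reproduction repository for SNUNet to train, evaluate, and predict quickly, and how to export the SNUNet model
  • For remote sensing reproductions, the PaddleRS toolkit gives you a head start. The official Paper Reproduction Challenge Guide is also very instructive: it points out every key checkpoint in a reproduction and taught me how to reproduce a paper. Beyond this competition, it will help a great deal in my future studies

8. Acknowledgments

  • Once again, my sincere thanks to the PaddlePaddle team for the computing support, and to the RD engineer for her answers and help. I am also grateful that PaddlePaddle open-sources so many toolkits that support my study and research. Thanks again as well to 古代飞 and 奔向未来的样子 for their help.
  • One more ramble: I first came to AI Studio because I had to do my graduation project and the lab GPUs had a queue. When I saw online that Baidu offered free compute every day, I promptly joined the ranks of the freeloaders. Later, when my project needed GANs and I was struggling to get started, PaddlePaddle happened to launch its seven-day GAN bootcamp course. I am truly grateful to PaddlePaddle and hope it keeps getting better. Thank you!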