背景

流年似水,青春永怀。所以,我们会用照片和视频记录下曾经的青葱岁月。但是,镜头前神出鬼没的“金手指”、表情违和的路人甲,以及照片中光辉的我们脚下的垃圾袋,都能轻易将那些本应是瞬间定格的“永久美好”变成了“永久遗憾”。现在有了深度学习模型的加持,去除这些图片、视频中的“不速之客”越来越易如反掌,更不用说修复划痕、折痕这些小问题了。只需将图片中不想保留的内容“扣掉”,模型就会根据擦除内容周围的上下文信息自动补全图片背景。专业领域管这种图像处理叫做“Image Inpainting”,即“图像补全”。本文就来唠唠如今深度学习加持的“图像 Inpainting” 为什么行、怎么行以及谁最行。(当然是下面要介绍的主角AOT-GAN最行了~)

图像补全这个活儿,如果是修复比较细的划痕,一些基于插值的传统方法已经取得了非常不错的修复效果。而一些比较大的窟窿,插值方法修复起来就力不从心了。因为,如果图像缺损的的部分较大的话,修复的部分会包含比划痕更多的结构信息。插值方法只能通过临近像素近似缺损像素,显然是很难补全缺损的结构信息的。这时,使用深度学习模型中的生成对抗网络(Generative Adversarial Network),修补图像的大面积缺失往往能取得更好的修复效果。当然,使用生成对抗网络GAN也不是能够一劳永逸的。各路大佬们通仍在通过调整模型结构、优化损失函数等手段,不断提高模型的图像补全能力,平衡生成内容的丰富性和忠实度等等。这里介绍的 AOT-GAN 就是大佬们最新出品的图像补全 SOTA 模型。

本文主要针对 AOT-GAN 所做的改进进行介绍,如果有小伙伴对 GAN 模型还不太熟悉可参考我原来写的一些项目,可以从这一篇开始踏坑~~
《一文搞懂生成对抗网络之经典GAN》

论文介绍

在这里插入图片描述

1.问题的提出与解决

现今,使用GAN网络对高分变率图像进行补全时产生的结构扭曲和细节模糊还不是处理得很好。Aggregated Contextual Transformations forHigh-Resolution Image Inpainting这篇文章提出的 AOT-GAN (上下文信息聚合转换生成对抗网络)主要从以下这两个方面着手解决这个问题:

  1. 有效利用远距离信息增强填充图像的合理性。
  2. 提高图像大面积缺失部分的填充质量。

为此模型主要进行了两个方面的针对改进:

  1. 使用聚合上下文特征转换生成对抗网络(Aggregated COntextual-Transformation GAN),聚合多尺度上下文特征以增强对远距离特征和丰富结构细节的捕捉。这是通过改进网络结构提出 AOT 模块实现的。
  2. 采用量身定制的“mask 预测”增强判别器,以使其能够更好的区分生成的部分和原图片部分。这是通过修改 Loss 函数的计算实现的。

在这里插入图片描述

上图可见 AOT-GAN 尤其是在大面积填充的结构和细节上超越了以往的模型。

在这里插入图片描述

FID、SSIM 和 PSNR 指标也都较以往的模型有显著提高。

下面,就从模型结构和loss两方面介绍一下这些改进。

2.模型结构

AOT-GAN 是一个比较典型的 GAN 模型,由生成器和判别器组成。

在这里插入图片描述

1)判别器

判别器由5层2D卷积层堆叠而成,通过卷积的 stride 步幅和 padding 逐步进行下采样,使用了谱归一化稳定参数更新。激活使用 LeakyRELU 避免神经元停止更新。输入是形状为 3 × 512 × 512 (c, h, w)的数据集图片(判别真图片时)或“生成的合成图片”(判别甲图片时,由生成器生成的缺失部分和数据集图片的其余部分拼接而成)。输出形状为 1 × 30 × 30,实际上就是一个 PatchGAN 的判别器。

感觉这个判别器是比较轻量级的,应该是因为这是个 Inpainting 任务的缘故。这里使用判别器计算对抗 loss 的目的主要是提高图片的清晰度,避免其他部分的均方差 loss 导致的图片模糊。所以,在图片补全任务中,对抗 loss 的权重设得比较小,判别器参数也比较少。不然的话,GAN 的“想象力”太丰富了,反而会导致补全部分的还原忠实度降低。

2)生成器

生成器的首尾两端分别是三层的编码器和解码器结构,中间层的特征提取模块(上图中 Generator 结构中的蓝色部分)由8层 AOT Block 堆叠而成。输入由训练集图片挖去待补全部分(3 × 512 × 512)叠加一层 mask (1 × 512 × 512)组成。3通道的RGB图片加上一层1通道的mask层,所以输入的形状是 4 × 512 × 512 (c, h, w)。输出是形状为 3 × 512 × 512 (c, h, w)的补全后图片。

这里生成器没有像现在好多 GAN 模型那样也加上谱归一化,应该是因为 Loss 中对抗损失的比重较小,更新还是挺稳定的。激活用的是 RELU,我在后来鼓捣模型时使用换成 LeakyRELU(其它好多GAN都是全用LeakyRELU的),但效果反而降低。明白的大佬请留言指点!

3)AOT Block

聚合长距离上下文信息的活儿就是靠这生成器里的8层 AOT Block 搞定的,这也是 AOT-GAN 法术之奥义所在~~ 。下面,我们就掰开包子品品馅,看看这 AOT Block 是怎么聚合上下文信息的。

在这里插入图片描述

对比普通 CNN 网络常用的 Residual Block 残差块, AOT Block 做了以下改进:

  1. 使用膨胀率分别为1、2、3、4的4组空洞卷积,在通道维度上拼接在一起,代替普通残差块里的第一层 3 × 3 卷积层。每组膨胀卷积的通道数是原来的四分之一,拼接在一起后保持通道数不变。这样,AOT Block 就能将不同空间距离的上下文特征聚合在一起。
  2. 使用了一组 3 × 3 卷积,代替了普通残差块中的 Identity 跨层连接。而且,AOT Block 在聚合模块和添加的那组 3 × 3 卷积构成的两组通道之间添加了一个 Gate 门限,让模型在空间维度上自动选择是否使用聚合通道。这样做主要是为了降低图片的填充区域与原保留区域之间的颜色偏差。

文章作者在这些改进上做了详实的消融实验,以验证其有效性:

在这里插入图片描述

上图对比了使用 AOT Block 前后的效果。

使用前: 由卷积核尺寸为 3 × 3 的残差块堆叠起来的效果可见,大片填充区域的结构信息严重缺失。虽说理论上 3 × 3 卷积堆叠起来可见范围是能够覆盖远距特征的,但从实践上看填充时还是主要参考了临近特征。

使用后: 加持了 AOT Block 模块带来的空间注意力能力后,肉眼可见,生成效果大幅提升。

而且,作者对前述 AOT Block 中,用 Gate 门限代替 Identity 连接的改进,也分别进行了采用 Identity 连接、GatedConv.、1 × 1 Conv.以及本文 Gate 方案的消融实验:

在这里插入图片描述
在这里插入图片描述

结果,无论是图片直观还是指标看, AOT 方案的效果都是最好的,必须uuuuu的~~

觉得 AOT Block 的设计是挺巧妙的,而且非常能涨点,就是相对与常见的 Residual Block 来说,如果想对参数进行裁剪优化,还是比较有挑战性的。

3.损失函数

AOT-GAN 的损失函数使用了均方差损失,即“最小二乘GAN”的损失函数。

1)判别器损失

在这里插入图片描述

这个 Loss 函数的作用就是:将判别器真实图片的输出拉向“真值1”,将图片的生成部分的判别值拉向“0”。因为,从上式看当 D(X) 越接近1,D(Z)越接近经过 σ 缩放的标签反转后的向量 (1-M) 时,整个 Loss 的值越小。(普通的 GAN 的判别器会直接将 D(Z) 拉向 “0”,这里的处理是 AOT-GAN 做的 Soft Mask PatchGAN 的改进,后面还会详述)

这是 GAN 模型的基本套路,这里就不在赘述。在这部分知识上需要进一步启发的小伙伴可从本文开头提到的那个项目《一文搞懂生成对抗网络之经典GAN》开始进补~~

这里需要着重介绍一下的是 AOT-GAN 对采用的 PatchGAN 判别器的改进版本-- Soft Mask-Guided PatchGAN (SM-PatchGAN)

在这里插入图片描述

首先,如上图所示。典型的 PatchGAN 判别器的做法是将全部的真图片判别输出特征图拉向=“真值1”,相应的将生成图片的判别输出拉向“假值0”。但在 AOT-GAN 中,作者为了突出优化图片的补全部分,所以在判别生成的补全图片时,如上图中间 HM-PatchGAN 所示,分别将判别输出中的补全部分拉向“0”,而将判别输出中生成的原图片部分的拉向“1”。但是,这样处理虽然强化了补全部分的生成质量,但也会导致生成补全部分与原图片部分的特征融合不够充分。因此,作者又使用得软标签得 SM-PatchGAN 来进一步提升效果。SM-PatchGAN 的处理也比较方便,只要给硬标签 Hard Mask 加上高斯滤波进行模糊处理即可。

在这里插入图片描述

视觉效果上来看 SM-PatchGAN 是最好的

在这里插入图片描述

指标上看 SM-PatchGAN 也是综合最佳的

2)生成器损失

AOT-GAN 的生成器损失由 重建(Reconstruct) Loss、Perceptual Loss、Style Loss、对抗(Adv.) Loss 四部分组成。

  • 重建(Reconstruct) Loss 就是直接计算生成图片与原图片的像素误差,也就是 L1 Loss,公式如下:

在这里插入图片描述

  • Perceptual Loss 计算生成图片与原图片经过在 ImageNet 上预训练的 VGG19 模型输出的各层特征图的 L1 Loss,公式如下:

在这里插入图片描述

  • Style Loss 计算生成图片与原图片的 SSIM (Structure Similarity Index Measure),公式如下:

在这里插入图片描述

  • 对抗(Adv.) Loss,也即 GAN Loss 与判别器一样使用均方差损失(最小二乘损失),公式如下:

在这里插入图片描述

  • 生成器总的损失为这四部分 Loss 加权相加,加权比例为:λadv = 0.01, λrec = 1,λper = 0.1, λsty = 250。

公式如下:

在这里插入图片描述

现在 GAN 模型的 Loss 组成也都差不太多,基本上就是: 重建 L1 Loss 引导生成宏观结构, 对抗 Loss 提升细节清晰度,后来又利用大数据集上的预训练 VGG19 模型计算 Perceptual Loss,以对齐生成图片与原图片各个层次尺度上的特征。现在,直接直接将生成图片的评价指标 SSIM 拿来作为目标函数,去优化生成图片的色调,风格等结构化特征了。一个 GAN 模型开始用,然后各个都跟着用,反正就是为了涨点么~~ 。

从我玩过的几个GAN模型上发现,相对于其他几种常见的GAN损失函数,如 VanillaGAN 的负对数似然损失、WGAN 的沃森斯坦距离损失以及 HingeLoss(合页损失等)来说,最小二乘损失更加适合在 AOT-GAN 中使用。因为,经验上,即使最小二乘损失在数值已经被优化得非常小得情况下,GAN 模型仍然能比较稳定得训练,并持续善生成效果。这就非常适合这里 对抗 Loss 占比和绝对数值都很小得情况。

代码

读万卷书,码万里路…Talk is cheap, Show me the code ~~

0.准备数据集

%cd /home/aistudio/
import os
if not os.path.exists('data/aot'):
    os.mkdir('data/aot')
!unzip -qa -d data/ data/data89198/test_mask.zip
%mv data/mask/testing_mask_dataset/ data/aot/train_mask
%rm -r data/mask/
!tar -xf /home/aistudio/data/data89198/val_large.tar -C data
%mv data/val_large/ data/aot/train_img

if not os.path.exists('data/aot'):
    os.mkdir('data/aot')
!unzip -qa -d data/aot/ val_img.zip
!unzip -qa -d data/aot/ val_mask.zip
/home/aistudio

1.训练

1)设置全局变量、超参

class OPT():
    def __init__(self):
        super(OPT, self).__init__()
        # 在AI Studio上用A100单卡训练时,设置为8(bs=8);在使用V100四卡训练时设为6(bs=6x4=24)
        # self.batch_size = 6 # V100单卡、多卡训练
        # self.batch_size = 8 # A100单卡训练
        self.batch_size = 1
        self.img_size = 512 # 生成图片尺寸
        self.rates = [1, 2, 4, 8] # 各个尺度空洞卷积的膨胀率
        self.block_num = 8 # 生成器中AOT模块的层数
        self.l1_weight = 1 # L1 Loss的加权
        self.style_weight = 250 # Style Loss的加权
        self.perceptual_weight = .1 # Perceptu Loss的加权
        self.adversal_weight = .01 # GAN Loss的加权
        self.lrg = 1e-4 # 生成器学习率
        self.lrd = 1e-4 # 判别器学习率
        self.beta1 = .5 # Adam优化器超参
        self.beta2 = .999 # Adam优化器超参

        self.dataset_path = 'data/aot' # 训练、验证数据集存放路径
        self.output_path = 'output' # chenk point,log等存放路径
        self.vgg_weight_path = 'data/data89198/vgg19feats.pdparams' # vgg19 预训练参数存放路径
        

opt = OPT()

2)定义数据集处理过程

from PIL import Image, ImageOps
import os
import numpy as np

from paddle.io import Dataset, DataLoader
from paddle.vision.transforms import Compose, RandomResizedCrop, RandomHorizontalFlip, RandomRotation, ColorJitter, Resize

# 定义数据集对象
class PlaceDateset(Dataset):
    def __init__(self, opt, istrain=True):
        super(PlaceDateset, self).__init__()

        self.image_path = []
        def get_all_sub_dirs(root_dir): # 递归读取全部子文件夹图片文件
            file_list = []
            def get_sub_dirs(r_dir):
                for root, dirs, files in os.walk(r_dir):
                    if len(files) > 0:
                        for f in files:
                            file_list.append(os.path.join(root, f))
                    if len(dirs) > 0:
                        for d in dirs:
                            get_sub_dirs(os.path.join(root, d))
                    break
            get_sub_dirs(root_dir)
            return file_list

        # 设置训练集、验证集数据(图片和mask)存放路径
        if istrain:
            self.img_list = get_all_sub_dirs(os.path.join(opt.dataset_path, 'train_img'))
            self.mask_dir = os.path.join(opt.dataset_path, 'train_mask')
        else:
            self.img_list = get_all_sub_dirs(os.path.join(opt.dataset_path, 'val_img'))
            self. mask_dir = os.path.join(opt.dataset_path, 'val_mask')
        self.img_list = np.sort(np.array(self.img_list))
        _, _, mask_list = next(os.walk(self.mask_dir))
        self.mask_list = np.sort(mask_list)


        self.istrain = istrain
        self.opt = opt

        # 训练阶段分别应用至图片和mask的数据增强
        if istrain:
            self.img_trans = Compose([
                RandomResizedCrop(opt.img_size),
                RandomHorizontalFlip(),
                ColorJitter(0.05, 0.05, 0.05, 0.05),
            ])
            self.mask_trans = Compose([
                Resize([opt.img_size, opt.img_size], interpolation='nearest'),
                RandomHorizontalFlip(),
            ])
        else:
            self.img_trans = Compose([
                Resize([opt.img_size, opt.img_size], interpolation='bilinear'),
            ])
            self.mask_trans = Compose([
                Resize([opt.img_size, opt.img_size], interpolation='nearest'),
            ])
        
        self.istrain = istrain

    # 送入模型的RGB图片数据归一化到(-1,+1)区间,形状为[n, c, h, w]
    # mask尺寸与图片一致,0对应原图片像素,1对应缺失像素
    def __getitem__(self, idx):
        img = Image.open(self.img_list[idx])
        mask = Image.open(os.path.join(self.mask_dir, self.mask_list[np.random.randint(0, self.mask_list.shape[0])]))
        img = self.img_trans(img)
        mask = self.mask_trans(mask)

        mask = mask.rotate(np.random.randint(0, 45))
        img = img.convert('RGB')
        mask = mask.convert('L')

        img = np.array(img).astype('float32')
        img = (img / 255.) * 2. - 1.
        img = np.transpose(img, (2, 0, 1))
        mask = np.array(mask).astype('float32') / 255.
        mask = np.expand_dims(mask, 0)

        return img, mask, self.img_list[idx]

    def __len__(self):
        return len(self.img_list)

3)定义模型结构,包括 AOT Block、生成器、判别器和计算 Perceptual Loss 的VGG19 模型

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn.utils import spectral_norm

#  Aggregated Contextual Transformations模块,是构成模型的核心结构,负责提取多尺度特征
class AOTBlock(nn.Layer):
    def __init__(self, dim, rates):
        super(AOTBlock, self).__init__()

        self.rates = rates
        for i, rate in enumerate(rates):
            self.__setattr__(
                'block{}'.format(str(i).zfill(2)), 
                nn.Sequential(
                    nn.Pad2D(rate, mode='reflect'),
                    nn.Conv2D(dim, dim//4, 3, 1, 0, dilation=int(rate)),
                    nn.ReLU()))
        self.fuse = nn.Sequential(
            nn.Pad2D(1, mode='reflect'),
            nn.Conv2D(dim, dim, 3, 1, 0, dilation=1))
        self.gate = nn.Sequential(
            nn.Pad2D(1, mode='reflect'),
            nn.Conv2D(dim, dim, 3, 1, 0, dilation=1))

    def forward(self, x):
        out = [self.__getattr__(f'block{str(i).zfill(2)}')(x) for i in range(len(self.rates))]
        out = paddle.concat(out, 1)
        out = self.fuse(out)
        mask = my_layer_norm(self.gate(x))
        mask = F.sigmoid(mask)
        return x * (1 - mask) + out * mask

def my_layer_norm(feat):
    mean = feat.mean((2, 3), keepdim=True)
    std = feat.std((2, 3), keepdim=True) + 1e-9
    feat = 2 * (feat - mean) / std - 1
    feat = 5 * feat
    return feat

class UpConv(nn.Layer):
    def __init__(self, inc, outc, scale=2):
        super(UpConv, self).__init__()
        self.scale = scale
        self.conv = nn.Conv2D(inc, outc, 3, 1, 1)

    def forward(self, x):
        return self.conv(F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True))

# 生成器
class InpaintGenerator(nn.Layer):
    def __init__(self, opt):
        super(InpaintGenerator, self).__init__()
        
        self.encoder = nn.Sequential(
            nn.Pad2D(3, mode='reflect'),
            nn.Conv2D(4, 64, 7, 1, 0),
            nn.ReLU(),
            nn.Conv2D(64, 128, 4, 2, 1),
            nn.ReLU(),
            nn.Conv2D(128, 256, 4, 2, 1),
            nn.ReLU()
        )

        self.middle = nn.Sequential(*[AOTBlock(256, opt.rates) for _ in range(opt.block_num)])

        self.decoder = nn.Sequential(
            UpConv(256, 128),
            nn.ReLU(),
            UpConv(128, 64),
            nn.ReLU(),
            nn.Conv2D(64, 3, 3, 1, 1)
        )

    def forward(self, x, mask):
        x = paddle.concat([x, mask], 1)
        x = self.encoder(x)
        x = self.middle(x)
        x = self.decoder(x)
        x = paddle.tanh(x)

        return x

# 判别器
class Discriminator(nn.Layer):
    def __init__(self, ):
        super(Discriminator, self).__init__()
        inc = 3
        self.conv = nn.Sequential(
            spectral_norm(nn.Conv2D(inc, 64, 4, 2, 1, bias_attr=False)),
            nn.LeakyReLU(0.2),
            spectral_norm(nn.Conv2D(64, 128, 4, 2, 1, bias_attr=False)),
            nn.LeakyReLU(0.2),
            spectral_norm(nn.Conv2D(128, 256, 4, 2, 1, bias_attr=False)),
            nn.LeakyReLU(0.2),
            spectral_norm(nn.Conv2D(256, 512, 4, 1, 1, bias_attr=False)),
            nn.LeakyReLU(0.2),
            nn.Conv2D(512, 1, 4, 1, 1)
        )

    def forward(self, x):
        feat = self.conv(x)
        return feat

# 用于计算Perceptual Loss和Style Loss的vgg19模型(使用ImageNet预训练权重)
class VGG19F(nn.Layer):
    def __init__(self):
        super(VGG19F, self).__init__()

        self.feature_0 = nn.Conv2D(3, 64, 3, 1, 1)
        self.relu_1 = nn.ReLU()
        self.feature_2 = nn.Conv2D(64, 64, 3, 1, 1)
        self.relu_3 = nn.ReLU()

        self.mp_4 = nn.MaxPool2D(2, 2, 0)
        self.feature_5 = nn.Conv2D(64, 128, 3, 1, 1)
        self.relu_6 = nn.ReLU()
        self.feature_7 = nn.Conv2D(128, 128, 3, 1, 1)
        self.relu_8 = nn.ReLU()

        self.mp_9 = nn.MaxPool2D(2, 2, 0)
        self.feature_10 = nn.Conv2D(128, 256, 3, 1, 1)
        self.relu_11 = nn.ReLU()
        self.feature_12 = nn.Conv2D(256, 256, 3, 1, 1)
        self.relu_13 = nn.ReLU()
        self.feature_14 = nn.Conv2D(256, 256, 3, 1, 1)
        self.relu_15 = nn.ReLU()
        self.feature_16 = nn.Conv2D(256, 256, 3, 1, 1)
        self.relu_17 = nn.ReLU()

        self.mp_18 = nn.MaxPool2D(2, 2, 0)
        self.feature_19 = nn.Conv2D(256, 512, 3, 1, 1)
        self.relu_20 = nn.ReLU()
        self.feature_21 = nn.Conv2D(512, 512, 3, 1, 1)
        self.relu_22 = nn.ReLU()
        self.feature_23 = nn.Conv2D(512, 512, 3, 1, 1)
        self.relu_24 = nn.ReLU()
        self.feature_25 = nn.Conv2D(512, 512, 3, 1, 1)
        self.relu_26 = nn.ReLU()

        self.mp_27 = nn.MaxPool2D(2, 2, 0)
        self.feature_28 = nn.Conv2D(512, 512, 3, 1, 1)
        self.relu_29 = nn.ReLU()
        self.feature_30 = nn.Conv2D(512, 512, 3, 1, 1)
        self.relu_31 = nn.ReLU()
        self.feature_32 = nn.Conv2D(512, 512, 3, 1, 1)
        self.relu_33 = nn.ReLU()
        self.feature_34 = nn.Conv2D(512, 512, 3, 1, 1)
        self.relu_35 = nn.ReLU()

    def forward(self, x):
        x = self.stand(x)
        feats = []
        group = []
        x = self.feature_0(x)
        x = self.relu_1(x)
        group.append(x)
        x = self.feature_2(x)
        x = self.relu_3(x)
        group.append(x)
        feats.append(group)
        
        group = []
        x = self.mp_4(x)
        x = self.feature_5(x)
        x = self.relu_6(x)
        group.append(x)
        x = self.feature_7(x)
        x = self.relu_8(x)
        group.append(x)
        feats.append(group)

        group = []
        x = self.mp_9(x)
        x = self.feature_10(x)
        x = self.relu_11(x)
        group.append(x)
        x = self.feature_12(x)
        x = self.relu_13(x)
        group.append(x)
        x = self.feature_14(x)
        x = self.relu_15(x)
        group.append(x)
        x = self.feature_16(x)
        x = self.relu_17(x)
        group.append(x)
        feats.append(group)

        group = []
        x = self.mp_18(x)
        x = self.feature_19(x)
        x = self.relu_20(x)
        group.append(x)
        x = self.feature_21(x)
        x = self.relu_22(x)
        group.append(x)
        x = self.feature_23(x)
        x = self.relu_24(x)
        group.append(x)
        x = self.feature_25(x)
        x = self.relu_26(x)
        group.append(x)
        feats.append(group)

        group = []
        x = self.mp_27(x)
        x = self.feature_28(x)
        x = self.relu_29(x)
        group.append(x)
        x = self.feature_30(x)
        x = self.relu_31(x)
        group.append(x)
        x = self.feature_32(x)
        x = self.relu_33(x)
        group.append(x)
        x = self.feature_34(x)
        x = self.relu_35(x)
        group.append(x)
        feats.append(group)

        return feats

    def stand(self, x):
        mean = paddle.to_tensor([0.485, 0.456, 0.406]).reshape([1, 3, 1, 1])
        std = paddle.to_tensor([0.229, 0.224, 0.225]).reshape([1, 3, 1, 1])
        y = (x + 1.) / 2.
        y = (y - mean) / std
        return y

4)定义各部分 Loss,包括 Perceptual Loss、Style Loss 和 对抗Loss(即 GAN 的 Loss)

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

class L1(): 
    def __init__(self,):
        self.calc = nn.L1Loss()
    
    def __call__(self, x, y):
        return self.calc(x, y)

# 计算原图片和生成图片通过vgg19模型各个层输出的激活特征图的L1 Loss
class Perceptual():
    def __init__(self, vgg, weights=[1.0, 1.0, 1.0, 1.0, 1.0]):
        super(Perceptual, self).__init__()
        self.vgg = vgg
        self.criterion = nn.L1Loss()
        self.weights = weights

    def __call__(self, x, y):
        x = F.interpolate(x, (opt.img_size, opt.img_size), mode='bilinear', align_corners=True)
        y = F.interpolate(y, (opt.img_size, opt.img_size), mode='bilinear', align_corners=True)
        x_features = self.vgg(x)
        y_features = self.vgg(y)
        content_loss = 0.0
        for i in range(len(self.weights)):
            content_loss += self.weights[i] * self.criterion(x_features[i][0], y_features[i][0]) # 此vgg19预训练模型无bn层,所以尝试不用rate
        return content_loss

# 通过vgg19模型,计算原图片与生成图片风格相似性的Loss
class Style():
    def __init__(self, vgg):
        super(Style, self).__init__()
        self.vgg = vgg
        self.criterion = nn.L1Loss()

    def compute_gram(self, x):
        b, c, h, w = x.shape
        f = x.reshape([b, c, w * h])
        f_T = f.transpose([0, 2, 1])
        G = paddle.matmul(f, f_T) / (h * w * c)
        return G

    def __call__(self, x, y):
        x = F.interpolate(x, (opt.img_size, opt.img_size), mode='bilinear', align_corners=True)
        y = F.interpolate(y, (opt.img_size, opt.img_size), mode='bilinear', align_corners=True)
        x_features = self.vgg(x)
        y_features = self.vgg(y)
        style_loss = 0.0
        blocks = [2, 3, 4, 5]
        layers = [2, 4, 4, 2]
        for b, l in list(zip(blocks, layers)):
            b = b - 1
            l = l - 1
            style_loss += self.criterion(self.compute_gram(x_features[b][l]), self.compute_gram(y_features[b][l]))
        return style_loss

# 对叠加在图片上的mask边缘进行高斯模糊处理
def gaussian_blur(input, kernel_size, sigma):
    def get_gaussian_kernel(kernel_size: int, sigma: float) -> paddle.Tensor:
        def gauss_fcn(x, window_size, sigma):
            return -(x - window_size // 2)**2 / float(2 * sigma**2)
        gauss = paddle.stack([paddle.exp(paddle.to_tensor(gauss_fcn(x, kernel_size, sigma)))for x in range(kernel_size)])
        return gauss / gauss.sum()


    b, c, h, w = input.shape
    ksize_x, ksize_y = kernel_size
    sigma_x, sigma_y = sigma
    kernel_x = get_gaussian_kernel(ksize_x, sigma_x)
    kernel_y = get_gaussian_kernel(ksize_y, sigma_y)
    kernel_2d = paddle.matmul(kernel_x, kernel_y, transpose_y=True)
    kernel = kernel_2d.reshape([1, 1, ksize_x, ksize_y])
    kernel = kernel.repeat_interleave(c, 0)
    padding = [(k - 1) // 2 for k in kernel_size]
    return F.conv2d(input, kernel, padding=padding, stride=1, groups=c)

# GAN Loss,采用最小二乘Loss
class Adversal():
    def __init__(self, ksize=71): 
        self.ksize = ksize
        self.loss_fn = nn.MSELoss()
    
    def __call__(self, netD, fake, real, masks): 
        fake_detach = fake.detach()

        g_fake = netD(fake)
        d_fake  = netD(fake_detach)
        d_real = netD(real)

        _, _, h, w = g_fake.shape
        b, c, ht, wt = masks.shape
        
        # 对齐判别器输出特征图与mask的尺寸
        if h != ht or w != wt:
            masks = F.interpolate(masks, size=(h, w), mode='bilinear', align_corners=True)
        d_fake_label = gaussian_blur(1 - masks, (self.ksize, self.ksize), (10, 10)).detach()
        d_real_label = paddle.ones_like(d_real)
        g_fake_label = paddle.ones_like(g_fake)

        dis_loss = [self.loss_fn(d_fake, d_fake_label).mean(), self.loss_fn(d_real, d_real_label).mean()]
        gen_loss = (self.loss_fn(g_fake, g_fake_label) * masks / paddle.mean(masks)).mean()

        return dis_loss, gen_loss

注:上面计算对抗损失 Adversal()的代码,作者进行了一些优化,这里为了方便小伙伴们理解,我将其改为和前面介绍的论文中计算 Adv. Loss 公式一致的写法。

5)定义输出文件夹初始化过程

import os
import numpy as np

# 初始化输出文件夹(默认为项目路径下的output/文件夹
def init_output(output_path):
    if not os.path.exists(output_path):
        os.mkdir(output_path)
        # 记录当前迭代步数
        current_step = np.array([0])
        np.save(os.path.join(output_path, "current_step"), current_step)
        print('训练输出目录['+output_path+']初始化完成')
    # 存储生成器、判别器check point
    if not os.path.exists(os.path.join(output_path, "model")):
        os.mkdir(os.path.join(output_path, "model"))
    # 存储训练时生成的图片
    if not os.path.exists(os.path.join(output_path, 'pic')):
        os.mkdir(os.path.join(output_path, 'pic'))
    # 存储预测时生成的图片
    if not os.path.exists(os.path.join(output_path, 'pic_val')):
        os.mkdir(os.path.join(output_path, 'pic_val'))

6)定义训练循环、记录日志、打印日志等处理过程

import numpy as np
import time
import os
import math
import cv2
import matplotlib.pyplot as plt
%matplotlib inline

import paddle
import paddle.distributed as dist
from paddle.io import DistributedBatchSampler
from paddle.io import Dataset, DataLoader

def train(show_interval=100, save_interval=500, total_iter=1000000, epoch_num=1000000):
    # 初始化训练输出路径
    init_output(opt.output_path)
    
    dist.init_parallel_env()

    # 读取当前训练进度
    current_step = np.load(os.path.join(opt.output_path, 'current_step.npy'))[0]
    print('已经完成 ['+str(current_step)+'] 步训练,开始继续训练...')

    # 定义数据读取用的DataLoader
    pds = PlaceDateset(opt)
    batchsamp = DistributedBatchSampler(pds, shuffle=True, batch_size=opt.batch_size, drop_last=True)
    loader = DataLoader(pds, batch_sampler=batchsamp, num_workers=4)
    data_total_num = pds.__len__()

    # 初始化生成器、判别器、计算Perceptu Loss用的VGG19模型(权重迁移自PyTorch)
    # vgg模型不参与训练,设为预测模式
    g = InpaintGenerator(opt)
    g = paddle.DataParallel(g)
    d = Discriminator()
    d = paddle.DataParallel(d)
    vgg19 = VGG19F()
    vgg_state_dict = paddle.load(opt.vgg_weight_path)
    vgg19.set_state_dict(vgg_state_dict)
    g.train()
    d.train()
    vgg19.eval()

    # 定义优化器
    opt_g = paddle.optimizer.Adam(learning_rate=opt.lrg, beta1=opt.beta1, beta2=opt.beta2, parameters=g.parameters())
    opt_d = paddle.optimizer.Adam(learning_rate=opt.lrd, beta1=opt.beta1, beta2=opt.beta2, parameters=d.parameters())

    # 读取保存的模型权重、优化器参数
    if current_step > 0:
        print('读取存储的模型权重、优化器参数...')
        time.sleep(.1)
        para = paddle.load(os.path.join(opt.output_path, "model/g.pdparams"))
        time.sleep(.1)
        g.set_state_dict(para)
        time.sleep(.1)
        para = paddle.load(os.path.join(opt.output_path, "model/d.pdparams"))
        time.sleep(.1)
        d.set_state_dict(para)
        time.sleep(.1)
        para = paddle.load(os.path.join(opt.output_path, "model/g.pdopt"))
        time.sleep(.1)
        opt_g.set_state_dict(para)
        time.sleep(.1)
        para = paddle.load(os.path.join(opt.output_path, "model/d.pdopt"))
        time.sleep(.1)
        opt_d.set_state_dict(para)
        time.sleep(.1)

    # 定义各部分loss
    l1_loss = L1()
    perceptual_loss = Perceptual(vgg19)
    style_loss = Style(vgg19)
    adv_loss = Adversal()

    # 设置训练时生成图片的存储路径
    pic_path = os.path.join(opt.output_path, 'pic')
              
    # 训练循环
    for epoch in range(epoch_num):
        start = time.time()
        if current_step >= total_iter:
            break
        for step, data in enumerate(loader):
            if current_step >= total_iter:
                break
            current_step += 1

            # 给图片加上mask
            img, mask, fname = data
            img_masked = (img * (1 - mask)) + mask
            pred_img = g(img_masked, mask)
            comp_img = (1 - mask) * img + mask * pred_img

            # 模型参数更新过程
            loss_g = {}
            loss_g['l1'] = l1_loss(img, pred_img) * opt.l1_weight
            loss_g['perceptual'] = perceptual_loss(img, pred_img) * opt.perceptual_weight
            loss_g['style'] = style_loss(img, pred_img) * opt.style_weight
            dis_loss, gen_loss = adv_loss(d, comp_img, img, mask)
            loss_g['adv_g'] = gen_loss * opt.adversal_weight
            loss_g_total = loss_g['l1'] + loss_g['perceptual'] + loss_g['style'] + loss_g['adv_g']
            loss_d_fake = dis_loss[0]
            loss_d_real = dis_loss[1]
            loss_d_total = loss_d_fake + loss_d_real
            opt_g.clear_grad()
            opt_d.clear_grad()
            loss_g_total.backward()
            loss_d_total.backward()
            opt_g.step()
            opt_d.step()

            # 写log文件,保存生成的图片,定期保存模型check point
            log_interval = 1 if current_step < 10000 else 100
            if dist.get_rank() == 0: # 只在主进程执行
                if current_step % log_interval == 0:
                    logfn = 'log.txt'
                    f = open(os.path.join(opt.output_path, logfn), 'a')
                    logtxt = 'current_step:[' + str(current_step) +                             ']\t' + 'g_l1:' + str(loss_g['l1'].numpy()) +                             '\t' + 'g_perceptual:' + str(loss_g['perceptual'].numpy()) +                             '\t' + 'g_style:' + str(loss_g['style'].numpy()) +                             '\t' + 'g_adversal:' + str(loss_g['adv_g'].numpy()) +                             '\t' + 'g_total:' + str(loss_g_total.numpy()) +                             '\t' + 'd_fake:' + str(loss_d_fake.numpy()) +                             '\t' + 'd_real:' + str(loss_d_real.numpy()) +                             '\t' + 'd_total:' + str(loss_d_total.numpy()) +                             '\t' + 'filename:[' + fname[0] +                             ']\t' + 'time:[' + time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())) + ']\n'
                    f.write(logtxt)
                    f.close()    

                # show img
                if current_step % show_interval == 0:
                    print('current_step:', current_step, 'epoch:', epoch,                         'step:['+str(step)+'/'+str(math.ceil(data_total_num / opt.batch_size))+']'                         'g_l1:', loss_g['l1'].numpy(),                         'g_perceptual:', loss_g['perceptual'].numpy(),                         'g_style:', loss_g['style'].numpy(),                         'g_adversal:', loss_g['adv_g'].numpy(),                         'g_total:', loss_g_total.numpy(),                         'd_fake:', loss_d_fake.numpy(),                         'd_real:', loss_d_real.numpy(),                         'd_total:', loss_d_total.numpy(),                         'filename:', fname[0],                         time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())))
                    img_show1 = (img.numpy()[0].transpose((1,2,0)) + 1.) / 2.
                    img_show2 = (pred_img.numpy()[0].transpose((1,2,0)) + 1.) / 2.
                    img_show3 = (comp_img.numpy()[0].transpose((1,2,0)) + 1.) / 2.
                    img_show4 = mask.numpy()[0][0]
                    plt.figure(figsize=(12,4),dpi=80)
                    plt.subplot(1, 4, 1)
                    plt.imshow(img_show1)
                    plt.subplot(1, 4, 2)
                    plt.imshow(img_show2)
                    plt.subplot(1, 4, 3)
                    plt.imshow(img_show3)
                    plt.subplot(1, 4, 4)
                    plt.imshow(img_show4)
                    plt.show()

                    img_show2 = (pred_img.numpy()[0].transpose((1,2,0)) + 1.) / 2.
                    img_show2 = (img_show2 * 256).astype('uint8')
                    img_show2 = cv2.cvtColor(img_show2, cv2.COLOR_RGB2BGR)
                    img_show4 = (mask.numpy()[0][0] * 255).astype('uint8')
                    cv2.imwrite(os.path.join(pic_path, os.path.split(fname[0])[1]), img_show2)
                    cv2.imwrite(os.path.join(pic_path, os.path.split(fname[0])[1].replace('.', '_mask.')), img_show4)

                # 定时存盘
                if current_step % save_interval == 0:
                    time.sleep(.1)
                    para = g.state_dict()
                    time.sleep(.1)
                    paddle.save(para, os.path.join(opt.output_path, "model/g.pdparams"))
                    time.sleep(.1)
                    para = d.state_dict()
                    time.sleep(.1)
                    paddle.save(para, os.path.join(opt.output_path, "model/d.pdparams"))
                    time.sleep(.1)
                    para = opt_g.state_dict()
                    time.sleep(.1)
                    paddle.save(para, os.path.join(opt.output_path, "model/g.pdopt"))
                    time.sleep(.1)
                    para = opt_d.state_dict()
                    time.sleep(.1)
                    paddle.save(para, os.path.join(opt.output_path, "model/d.pdopt"))
                    time.sleep(.1)
                    np.save(os.path.join(opt.output_path, 'current_step'), np.array([current_step]))
                    print('第['+str(current_step)+']步模型保存。保存路径:', os.path.join(opt.output_path, "model"))
            
            # 存储clock
            if current_step % 10 == 0:
                clock = np.array([str(current_step), time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time()))])
                np.savetxt(os.path.join(opt.output_path, 'clock.txt'), clock, fmt='%s', delimiter='\t')

    # 训练迭代完成时保存模型参数
    time.sleep(1)
    para = g.state_dict()
    time.sleep(1)
    paddle.save(para, os.path.join(opt.output_path, "model/g.pdparams"))
    time.sleep(1)
    para = d.state_dict()
    time.sleep(1)
    paddle.save(para, os.path.join(opt.output_path, "model/d.pdparams"))
    time.sleep(1)
    para = opt_g.state_dict()
    time.sleep(1)
    paddle.save(para, os.path.join(opt.output_path, "model/g.pdopt"))
    time.sleep(1)
    para = opt_d.state_dict()
    time.sleep(1)
    paddle.save(para, os.path.join(opt.output_path, "model/d.pdopt"))
    time.sleep(1)
    np.save(os.path.join(opt.output_path, 'current_step'), np.array([current_step]))
    print('第['+str(current_step)+']步模型保存。保存路径:', os.path.join(opt.output_path, "model"))
    print('Finished training! Total Iteration:', current_step)

/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/__init__.py:107: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  from collections import MutableMapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/rcsetup.py:20: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  from collections import Iterable, Mapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/colors.py:53: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  from collections import Sized

7)开始训练

train(show_interval=1, save_interval=10, total_iter=1000000)

2.使用预训练参数进行预测

预测用的图片存放在 ‘data/aot/val_img/’ 路径下, 预测用的 mask 存放在 ‘data/aot/val_mask/’ 路径下,顺序要与图片一一对应。预训练参数存放在 ‘data/data89198/g.pdparams’ 路径下,是在 Place365 Standard 数据集上训练的。输入图片和mask的尺寸为 512 × 512。

import numpy as np
import os
import time
from PIL import Image
import paddle
from paddle.vision.transforms import Resize
# from model.config import opt
# from model.model import InpaintGenerator
import matplotlib.pyplot as plt
%matplotlib inline
import warnings
warnings.filterwarnings('ignore')

img_path = 'data/aot/val_img/'
mask_path = 'data/aot/val_mask/'
model_path = 'data/data89198/g.pdparams'

for _, _, files in os.walk(img_path):
    pics = np.sort(np.array(files))
    break
for _, _, files in os.walk(mask_path):
    masks = np.sort(np.array(files))
    break

def predict(img_path, mask_path, g):
    # 读取原图片与mask掩码图片并进行resize、格式转换
    img = Image.open(img_path)
    mask = Image.open(mask_path)
    img = Resize([opt.img_size, opt.img_size], interpolation='bilinear')(img)
    mask = Resize([opt.img_size, opt.img_size], interpolation='nearest')(mask)
    img = img.convert('RGB')
    mask = mask.convert('L')
    img = np.array(img)
    mask = np.array(mask)
    img_show1 = img
    img_show3 = mask

    # 图片数据归一化到(-1, +1)区间,形状为[n, c, h, w], 取值为[1, 3, 512, 512]
    # mask图片数据归一化为0、1二值。0代表原图片像素,1代表缺失像素。形状为[n, c, h, w], 取值为[1, 1, 512, 512]
    img = (img.astype('float32') / 255.) * 2. - 1.
    img = np.transpose(img, (2, 0, 1))
    mask = np.expand_dims(mask.astype('float32') / 255., 0)
    img = paddle.to_tensor(np.expand_dims(img, 0))
    mask = paddle.to_tensor(np.expand_dims(mask, 0))
    
    # 预测
    img_masked = (img * (1 - mask)) + mask # 将掩码叠加到图片上
    pred_img = g(img_masked, mask) # 用加掩码后的图片和掩码生成预测图片
    comp_img = (1 - mask) * img + mask * pred_img # 使用原图片和预测图片合成最终的推理结果图片
    img_show2 = (comp_img.numpy()[0].transpose((1,2,0)) + 1.) / 2.

    # 显示
    plt.figure(figsize=(12,4),dpi=80)
    plt.subplot(1, 3, 1)
    plt.imshow(img_show1)
    plt.subplot(1, 3, 2)
    plt.imshow(img_show2)
    plt.subplot(1, 3, 3)
    plt.imshow(img_show3)
    plt.show()

# 初始化生成器,读取参数
g = InpaintGenerator(opt)
g.eval()
time.sleep(.1)
para = paddle.load(model_path)
time.sleep(.1)
g.set_state_dict(para)

for pic, mask in zip(pics, masks):
g, 0))
    mask = paddle.to_tensor(np.expand_dims(mask, 0))
    
    # 预测
    img_masked = (img * (1 - mask)) + mask # 将掩码叠加到图片上
    pred_img = g(img_masked, mask) # 用加掩码后的图片和掩码生成预测图片
    comp_img = (1 - mask) * img + mask * pred_img # 使用原图片和预测图片合成最终的推理结果图片
    img_show2 = (comp_img.numpy()[0].transpose((1,2,0)) + 1.) / 2.

    # 显示
    plt.figure(figsize=(12,4),dpi=80)
    plt.subplot(1, 3, 1)
    plt.imshow(img_show1)
    plt.subplot(1, 3, 2)
    plt.imshow(img_show2)
    plt.subplot(1, 3, 3)
    plt.imshow(img_show3)
    plt.show()

# 初始化生成器,读取参数
g = InpaintGenerator(opt)
g.eval()
time.sleep(.1)
para = paddle.load(model_path)
time.sleep(.1)
g.set_state_dict(para)

for pic, mask in zip(pics, masks):
    predict(os.path.join(img_path, pic), os.path.join(mask_path, mask), g)

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

可见,Logo 去除的效果还是很不错滴~~

后记

鼓捣这个模型的过程中,为了涨点,我对 AOT-Block 进行了改进,并添加了新的监督,在小数据集上跑确实能涨点~~ 。目前正上手进行模型的优化、裁剪。这个 AOT 结构真是不好下剪子啊(汗~~ )。

兄弟我继续鼓捣了,发现了什么好玩的东东再来和伙伴们分享吧,待续…

此文章为搬运
原项目链接

Logo

学大模型,用大模型上飞桨星河社区!每天8点V100G算力免费领!免费领取ERNIE 4.0 100w Token >>>

更多推荐