Anime-style face translation with a cycle-loss variant of AniGAN

A note up front: this paper has no public code that successfully reproduces it, and the authors did not release their source. So I hacked away at it based on my own experience and retrained it in combination with CycleGAN. Don't judge the training results too harshly; getting even this far cost me several hundred compute credits. It was really hard to train, revision after revision. The more I say the sadder I get, so please give this project a free fork to help me recover!!!!

1. Results

(I picked a few of the more human-looking samples.) From right to left: anime face A, real face B, anime face A made photorealistic from real face B (an auxiliary output), and real face B animefied from anime face A (the target output).

(figures: five sample results in the layout described above)

These results come from fewer than 10 epochs of training; I don't have my own GPU, so I was just playing around on AI Studio.

2. AniGAN architecture and core ideas

2.1 Overall architecture

(figure: AniGAN overall architecture)

2.2 PoLIN and AdaPoLIN

These are the normalization layers proposed in the paper. PoLIN normalizes a feature map using only its own statistics, while AdaPoLIN injects the target domain's mean and variance into the normalized source-domain features (anime statistics into real-face features, in the case of real-face animefication).

(figures: PoLIN and AdaPoLIN definitions from the paper)

Now compare with the normalization layer proposed in the U-GAT-IT paper:

(U-GAT-IT also translates between two domains, e.g. animefying real faces, but it takes a single image as input: for real-face animefication you feed in one real-face photo and the model outputs its anime version, with no finer-grained control.)

(figure: AdaLIN definition from U-GAT-IT)

AdaLIN: the U-GAT-IT authors argue that IN and LN each have advantages for translation tasks. IN is a channel-wise norm, so it preserves source-domain details better, but it captures style characteristics poorly; LN is a layer-wise norm, so it learns the style statistics better. Hence the feature encoder prefers IN to keep source details, while the decoder should transform those details toward the target style, which is why the upsampling part uses LN. Before upsampling, AdaLIN absorbs the strengths of both norms, letting the data distributions of the two domains decide how much to lean toward IN or LN, which suits unpaired image translation. ----> This passage is adapted from https://blog.csdn.net/qq_34914551/article/details/112641492

In other words, AdaPoLIN replaces AdaLIN's learned interpolation ρ·IN + (1−ρ)·LN with a 1×1 convolution over the concatenated IN and LN outputs, so the two norms are combined more tightly.
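Written out as implemented in this project (Section 3.1), with [·, ·] denoting channel-wise concatenation and γ, β predicted from the style features:

PoLIN(x) = Conv1x1([IN(x), LN(x)])
AdaPoLIN(x, s) = γ(s) · Conv1x1([IN(x), LN(x)]) + β(s)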

2.3 The paper's discriminator design

(figure: the paper's discriminator architecture)

2.4 The paper's loss design (note: this is the paper's design, not this project's!)

Loss design usually embodies the training idea behind a network. X is the source domain, i.e. real faces; Y is the target domain, i.e. anime faces.

2.4.1 Adversarial loss

(figure: adversarial loss from the paper)

Looking at this loss, the paper asks one set of generator weights to handle not only real→anime but also anime→real. I think that puts too much pressure on training. In theory, mapping the anime domain back to real faces lets the model learn structural information about faces, but that remained theoretical for me: I never felt the benefit, only how hard the original scheme is to train. (Below are results from the original paper's training scheme.)

(figures: results under the original paper's single-generator training scheme)

In this project my approach is to use the LSGAN loss instead, with one model for anime-face realization and a separate model trained for real-face animefication.
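For reference, a minimal sketch of the LSGAN objective as it is used in Section 3.6 (the tensor names here are illustrative):

import paddle

mse = paddle.nn.MSELoss()

def lsgan_g_loss(d_fake):
    # the generator wants the discriminator to output 1 on fakes
    return mse(d_fake, paddle.ones_like(d_fake))

def lsgan_d_loss(d_real, d_fake):
    # the discriminator wants 1 on real images and 0 on fakes
    return (mse(d_real, paddle.ones_like(d_real))
            + mse(d_fake, paddle.zeros_like(d_fake)))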

2.4.2 Feature matching loss and domain-aware feature matching loss

The generator aligns its outputs with real images in the feature space of the discriminator's shared layers (a sketch follows the figures below).

(figures: feature matching and domain-aware feature matching losses from the paper)
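Concretely, in this project's training loop (Section 3.6) these terms reduce to plain L1 distances between the discriminator's intermediate activations for real images and for the cycle-reconstructed ones. A minimal sketch, assuming the discriminator returns (u0, u1, f, out) as the NetD in Section 3.2 does:

import paddle

l1 = paddle.nn.L1Loss()

def feature_matching_loss(discriminator, real, fake, domain):
    # u0/u1 come from the shared layers, f from the domain-specific branch
    u0_r, u1_r, f_r, _ = discriminator(real, domain)
    u0_f, u1_f, f_f, _ = discriminator(fake, domain)
    return l1(u0_r, u0_f) + l1(u1_r, u1_f) + l1(f_r, f_f)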

2.4.3 Reconstruction loss

(figure: reconstruction loss from the paper)
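In this project the reconstruction term becomes a cycle-style L1 loss: translate to the other domain and back, then compare against the input. A sketch, where G_r2a and G_a2r are the two generators instantiated in Section 3.4:

import paddle

l1 = paddle.nn.L1Loss()

def cycle_reconstruction_loss(G_r2a, G_a2r, real, anime):
    # real -> anime -> real should recover the real input
    fake_anime = G_r2a(real, anime)     # content = real, style = anime
    rec_real = G_a2r(fake_anime, real)  # translate back
    return l1(rec_real, real)           # the symmetric term swaps the roles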

3. My code

I will explain my design choices step by step alongside the code.

3.1 Building the generator

  1. Use Conv2DTranspose instead of nn.Upsample; this is a habit from my experience (see the short sketch after this list).
  2. At first my generator had only about 25 MB of parameters, which was too few, so I added residual blocks and an MLP in the middle, growing it to about 45 MB. With too few parameters the model has trouble learning.
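The swap in point 1 just replaces fixed upsampling followed by a convolution with a learned transposed convolution, as the FST_block below does (shapes here are illustrative):

import paddle
import paddle.nn as nn

# fixed nearest-neighbour upsampling plus a conv (the variant I dropped)
up_fixed = nn.Sequential(nn.Upsample(scale_factor=2), nn.Conv2D(256, 256, 3, 1, 1))
# learned upsampling with a transposed conv (what this project uses)
up_learned = nn.Conv2DTranspose(256, 256, kernel_size=4, stride=2, padding=1)

x = paddle.randn([4, 256, 32, 32])
print(up_fixed(x).shape, up_learned(x).shape)  # both [4, 256, 64, 64]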
import paddle
import paddle.nn as nn
import numpy as np
import paddle.nn.functional as F

class ConvLayer(nn.Layer):
    '''
    ResNeXtBottleneck
    '''
    def __init__(self, in_channels=256, out_channels=256, stride=1, cardinality=32, dilate=1):
        super(ConvLayer, self).__init__()
        D = out_channels // 2
        self.out_channels = out_channels
        self.conv_reduce = nn.Conv2D(in_channels, D, kernel_size=1, stride=1, padding=0, bias_attr=False)
        self.conv_conv = nn.Conv2D(D, D, kernel_size=3, stride=stride, padding=dilate, dilation=dilate,
                                   groups=cardinality,
                                   bias_attr=False)
        self.conv_expand = nn.Conv2D(D, out_channels, kernel_size=1, stride=1, padding=0, bias_attr=False)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            # project the input so it matches the residual branch
            self.shortcut = nn.Conv2D(in_channels, out_channels, 3, stride=stride, padding=1)

    def forward(self, x):
        bottleneck = self.conv_reduce(x)
        bottleneck = F.leaky_relu(bottleneck, 0.2)
        bottleneck = self.conv_conv(bottleneck)
        bottleneck = F.leaky_relu(bottleneck, 0.2)
        bottleneck = self.conv_expand(bottleneck)
        return self.shortcut(x) + bottleneck


x = paddle.randn([4,256,64,64])
ConvLayer()(x).shape
[4, 256, 64, 64]
class PoLIN(nn.Layer):
    '''
    Point-wise Layer Instance Normalization: a 1x1 conv fuses the IN- and
    LN-normalized versions of the input.
    '''
    def __init__(self, dim):
        super(PoLIN, self).__init__()
        self.conv1x1 = nn.Conv2D(dim*2, dim, 1, 1, 0, bias_attr=False)

    def forward(self, input):
        # parameter-free IN and LN (weight_attr=False disables the affine params)
        IN = nn.InstanceNorm2D(input.shape[1], weight_attr=False, bias_attr=False)(input)
        LN = nn.LayerNorm(input.shape[1:], weight_attr=False, bias_attr=False)(input)
        LIN = paddle.concat((IN, LN), axis=1)
        return self.conv1x1(LIN)


class Ada_PoLIN(nn.Layer):
    '''
    Adaptive PoLIN: like PoLIN, but the affine parameters gamma/beta are
    predicted from the style features passed in as `params`.
    '''
    def __init__(self, dim):
        super(Ada_PoLIN, self).__init__()
        self.Conv1x1 = nn.Conv2D(dim*2, dim, 1, 1, 0, bias_attr=False)

    def forward(self, input, params):
        IN = nn.InstanceNorm2D(input.shape[1], weight_attr=False, bias_attr=False)(input)
        LN = nn.LayerNorm(input.shape[1:], weight_attr=False, bias_attr=False)(input)
        LIN = paddle.concat((IN, LN), axis=1)
        # pool the style features to [B, 2*dim, 1, 1] and split into gamma/beta
        params = nn.AdaptiveAvgPool2D(1)(params)
        mid = params.shape[1] // 2
        gamma = params[:, :mid]
        beta = params[:, mid:]
        # fuse IN and LN with a 1x1 conv, then apply the style-dependent affine
        return gamma * self.Conv1x1(LIN) + beta

x = paddle.randn([4,256,64,64])

y = paddle.randn([4,512,64,64])
PoLIN(256)(x)
Ada_PoLIN(256)(x,y).shape
[4, 256, 64, 64]
import paddle
import paddle.nn as nn
class ResBlock(nn.Layer):
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResBlock, self).__init__()
        
        def block(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False):
            layers = []
            layers += [nn.Conv2D(in_channels=in_channels, out_channels=out_channels,
                                 kernel_size=kernel_size, stride=stride, padding=padding,
                                 bias_attr =bias)]
            layers += [nn.InstanceNorm2D(num_features=out_channels)]
            
            layers += [nn.ReLU()]
            layers += [nn.Conv2D(in_channels=out_channels, out_channels=out_channels,
                                 kernel_size=kernel_size, stride=stride, padding=padding,
                                 bias_attr =bias)]
            layers += [nn.InstanceNorm2D(num_features=out_channels)]

            cbr = nn.Sequential(*layers)

            return cbr
        
        self.block_1 = block(in_channels,out_channels)
        self.block_2 = block(out_channels,out_channels)
        self.block_3 = block(out_channels,out_channels)
        self.block_4 = block(out_channels,out_channels)
        
        self.relu = nn.ReLU()
        
    def forward(self, x):
        
        # block 1: the channel count changes here, so no skip connection
        out = self.block_1(x)
        out = self.relu(out)
        
        # block 2
        residual = out
        out = self.block_2(out)
        out += residual
        out = self.relu(out)
        
        # block 3
        residual = out
        out = self.block_3(out)
        out += residual
        out = self.relu(out)
        
        # block 4
        residual = out
        out = self.block_4(out)
        out += residual
        out = self.relu(out)
        
        return out

x = paddle.randn([4,3,256,256])
ResBlock(3,7)(x).shape
[4, 7, 256, 256]
A quick sanity check, relevant to the x0 += x pattern used in MLP below: Paddle's += on an alias does not mutate the original variable, so x stays [2, 3].

x = paddle.to_tensor([2,3])
x0 = x
x0 += 1
x
Tensor(shape=[2], dtype=int64, place=CUDAPlace(0), stop_gradient=True,
       [2, 3])
conv_initializer=paddle.nn.initializer.Normal(mean=0.0, std=0.02)



class ASC_block(nn.Layer):
    def __init__(self, input_dim = 256, dim = 256, num_ASC_layers = 4):
        super(ASC_block, self).__init__()
        self.input_dim = input_dim
        self.num_ASC_layers = num_ASC_layers
        self.ConvLayer = []
        self.NormLayer = []
        self.Ada_PoLINLayer = []
        for _ in range(self.num_ASC_layers):
            self.ConvLayer += [ConvLayer(self.input_dim, dim, stride = 1)]
            self.NormLayer += [Ada_PoLIN(dim)]
            self.input_dim = dim
        self.ConvLayer = nn.LayerList(self.ConvLayer)
        self.NormLayer = nn.LayerList(self.NormLayer)
        self.res_block = ResBlock(dim,dim)


    def forward(self, x, Ada_PoLIN_params): 
        for ConvLayer, NormLayer in zip(self.ConvLayer, self.NormLayer):
            x = ConvLayer(x)
            x = NormLayer(x,Ada_PoLIN_params)    
        x1 = self.res_block(x)
        x+=x1           
        return x



class Decoder_ConvLayer(nn.Layer):
    '''
    D = in_channels // 2
    '''
    def __init__(self, in_channels=256, out_channels=256, stride=1, cardinality=32, dilate=1):
        super(Decoder_ConvLayer, self).__init__()
        D = in_channels // 2
        self.out_channels = out_channels
        self.conv_reduce = nn.Conv2D(in_channels, D, kernel_size=1, stride=1, padding=0, bias_attr=False)
        self.conv_conv = nn.Conv2D(D, D, kernel_size=3, stride=stride, padding=dilate, dilation=dilate,
                                   groups=cardinality,
                                   bias_attr=False)
        self.conv_expand = nn.Conv2D(D, out_channels, kernel_size=1, stride=1, padding=0, bias_attr=False)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            # project the input so it matches the residual branch
            self.shortcut = nn.Conv2D(in_channels, out_channels, 3, stride=stride, padding=1)

    def forward(self, x):
        bottleneck = self.conv_reduce(x)
        bottleneck = F.leaky_relu(bottleneck, 0.2)
        bottleneck = self.conv_conv(bottleneck)
        bottleneck = F.leaky_relu(bottleneck, 0.2)
        bottleneck = self.conv_expand(bottleneck)
        return self.shortcut(x) + bottleneck

class MLP(nn.Layer):
    def __init__(self, dim):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(dim,dim*4)
        self.fc2 = nn.Linear(dim*4,dim)

    def forward(self, x0):
        x = self.fc1(x0)
        x = F.leaky_relu(x, 0.2)
        x = self.fc2(x)
        x0 += x  # residual connection; += leaves the caller's tensor unmodified
        return x0
class FST_block(nn.Layer):
    def __init__(self, input_dim, dim,kernel = False):
        super(FST_block, self).__init__()
        self.input_dim = input_dim
        self.block = []
        # self.block += [nn.Upsample(scale_factor = 2)]
        self.block += [nn.Conv2DTranspose(
            input_dim,input_dim,4,2,1,
            bias_attr=False,weight_attr=paddle.ParamAttr(initializer=conv_initializer)
            )]
        self.block += [Decoder_ConvLayer(self.input_dim, dim,1,8)]       
        self.block += [PoLIN(dim)]

        self.block += [Decoder_ConvLayer(dim, dim,1,8)]       
        self.block += [PoLIN(dim)]
        self.block += [Decoder_ConvLayer(dim, dim,1,8)]       
        self.block += [PoLIN(dim)]
        self.block += [Decoder_ConvLayer(dim, dim,1,8,2)]       
        self.block += [PoLIN(dim)]
        self.block += [Decoder_ConvLayer(dim, dim,1,8,4)]       
        self.block += [PoLIN(dim)]
        self.block += [Decoder_ConvLayer(dim, dim,1,8,2)]       
        self.block += [PoLIN(dim)]        
        self.block += [Decoder_ConvLayer(dim, dim,1,8)]
        if kernel:
            for i in range(12):
                self.block += [PoLIN(dim)]        
                self.block += [Decoder_ConvLayer(dim, dim,1,8)]                                   
        self.block = nn.Sequential(*self.block)
        self.Ada_PoLIN = Ada_PoLIN(dim)
        self.mlp = MLP(dim)                                                                     
                      
    def forward(self, x, Ada_PoLIN_params): 
        x = self.block(x)
        # print(x.shape)
        x = self.Ada_PoLIN(x, Ada_PoLIN_params)
        bs,c,h,w = x.shape
        x = x.flatten(2)
        x = x.transpose([0,2,1])
        x = self.mlp(x)
        x = x.transpose([0,2,1])
        x = x.reshape([bs,c,h,w])
        return x

x = paddle.randn([4,256,64,64])
y = paddle.randn([4,512,64,64])
ASC_block(256,256)(x,y).shape
x = paddle.randn([4,256,32,32])
y = paddle.randn([4,256,64,64])
FST_block(256,128)(x,y).shape
[4, 128, 64, 64]
The style-extraction Encoder
class Style_Encoder(nn.Layer):
    def __init__(self):
        super(Style_Encoder,self).__init__()
        self.conv0 = nn.Sequential(
            ConvLayer(3,32,1,cardinality=4, dilate= 1),
            ConvLayer(32,64,2,cardinality=8, dilate= 2),
            ConvLayer(64,64,1,cardinality=8, dilate= 4),
            ConvLayer(64,64,1,cardinality=8, dilate= 2),
            ConvLayer(64,64,1,cardinality = 8,dilate = 1))


        self.conv1 = nn.Sequential(
            ConvLayer(64,64,1,cardinality=8, dilate= 1),
            ConvLayer(64,128,2,cardinality=16, dilate= 2),
            ConvLayer(128,128,1,cardinality = 16,dilate = 1),
            ConvLayer(128,128,1,cardinality=16, dilate= 1))

        self.conv2 = nn.Sequential(
            ConvLayer(128,128,1,cardinality=16, dilate= 1),
            ConvLayer(128,256,2,cardinality=32, dilate= 2),
            ConvLayer(256,256,1,cardinality=32, dilate= 4),
            ConvLayer(256,256,1,cardinality = 32,dilate = 2))
        
        
        self.conv3 = nn.Sequential(
            ConvLayer(256,256,1,cardinality=16, dilate= 1),
            ConvLayer(256,512,2,cardinality=32, dilate= 1),
            ConvLayer(512,512,1,cardinality=32, dilate= 1),
        )
    def forward(self,x):
        out_list = []
        x = self.conv0(x)
        out_list.append(x)

        x = self.conv1(x)
        out_list.append(x)

        x = self.conv2(x)
        out_list.append(x)

        x = self.conv3(x)
        out_list.append(x)

        return out_list

x = paddle.randn([4,3,256,256])
style_list = Style_Encoder()(x)
for i in style_list:
    print(i.shape)
[4, 64, 128, 128]
[4, 128, 64, 64]
[4, 256, 32, 32]
[4, 512, 16, 16]
The content-extraction Encoder

class Content_Encoder(nn.Layer):
    def __init__(self):
        super(Content_Encoder,self).__init__()
        self.conv0 = nn.Sequential(
            ConvLayer(3,32,1,cardinality=4, dilate= 1),
            ConvLayer(32,32,2,cardinality=8, dilate= 2),

            ConvLayer(32,64,1,cardinality=8, dilate= 4),
            ConvLayer(64,64,2,cardinality=8, dilate= 2),

            ConvLayer(64,64,1,cardinality = 8,dilate = 1),
            ConvLayer(64,64,1,cardinality=8, dilate= 1),
            ConvLayer(64,128,2,cardinality=16, dilate= 2),

            ConvLayer(128,128,1,cardinality = 16,dilate = 1),
            ConvLayer(128,128,1,cardinality=16, dilate= 2),
            ConvLayer(128,128,1,cardinality=16, dilate= 4),
            ConvLayer(128,256,2,cardinality=32, dilate= 2),

            ConvLayer(256,256,1,cardinality=32, dilate= 1),
            ConvLayer(256,256,1,cardinality = 32,dilate = 1))
        
        

    def forward(self,x):
        x = self.conv0(x)
        return x

x = paddle.randn([4,3,256,256])
Content_Encoder()(x).shape
[4, 256, 16, 16]
class Model(nn.Layer):
    def __init__(self):
        super(Model,self).__init__()
        self.content_encoder = Content_Encoder()
        self.style_encoder = Style_Encoder()
        self.asc_block = ASC_block(256,256)

        self.num_FST_layers = 3
        self.FST_block_list = []
        self.FST_block_list += [FST_block(256,128,True)]
        self.FST_block_list += [FST_block(128,64)]
        self.FST_block_list += [FST_block(64,32)]
        self.FST_block_list = nn.LayerList(self.FST_block_list)

        self.out_convlayers = nn.Sequential(
            nn.Conv2DTranspose(
            32,32,4,2,1,
            bias_attr=False,weight_attr=paddle.ParamAttr(initializer=conv_initializer)
            ),
            Decoder_ConvLayer(32,32,1,8),
            nn.Conv2D(32,16,3,1,1),
            nn.LeakyReLU(0.2),
            nn.Conv2D(16,3,3,1,1),
            nn.Tanh()
        )
    def forward(self,content,style):
        content_feat = self.content_encoder(content)
        style_feat_list = self.style_encoder(style)
        cs_feat = self.asc_block(content_feat,style_feat_list[-1])
        i = -2
        for fst_block in self.FST_block_list:
            cs_feat = fst_block(cs_feat,style_feat_list[i])
            # print("cs_feat",cs_feat.shape)
            i -=1
        # print(i)
        cs_feat = self.out_convlayers(cs_feat)
        return cs_feat
x = paddle.randn([4,3,256,256])
y = paddle.randn([4,3,256,256])
Model()(x,y).shape
[4, 3, 256, 256]

3.2 Building the discriminator

from Normal import build_norm_layer
import paddle
import paddle.nn as nn
import paddle.nn.functional as F

SpectralNorm = build_norm_layer('spectral')
class Dis_ConvLayer(nn.Layer):
    '''
    '''
    def __init__(self, in_channels=256, out_channels=256, stride=1, cardinality=32, dilate=1):
        super(Dis_ConvLayer, self).__init__()
        D = out_channels // 2
        self.out_channels = out_channels
        self.conv_reduce = SpectralNorm(nn.Conv2D(in_channels, D, kernel_size=1, stride=1, padding=0, bias_attr=False))
        self.conv_conv = SpectralNorm(nn.Conv2D(D, D, kernel_size=3, stride=stride, padding=dilate, dilation=dilate,
                                   groups=cardinality,
                                   bias_attr=False))
        self.conv_expand = SpectralNorm(nn.Conv2D(D, out_channels, kernel_size=1, stride=1, padding=0, bias_attr=False))
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = SpectralNorm(nn.Conv2D(in_channels,out_channels,3,stride=stride,padding=1))
    def forward(self, x):
        bottleneck = self.conv_reduce(x)
        bottleneck = F.leaky_relu(bottleneck, 0.2)
        bottleneck = self.conv_conv(bottleneck)
        bottleneck = F.leaky_relu(bottleneck, 0.2)
        bottleneck = self.conv_expand(bottleneck)
        return self.shortcut(x) + bottleneck

import paddle
import paddle.nn as nn
class NetD(nn.Layer):
    def __init__(self, ndf=64):
        super(NetD, self).__init__()

        self.NetU0 = nn.Sequential(SpectralNorm(nn.Conv2D(3, ndf, kernel_size=7, stride=1, padding=3, bias_attr=False)),  # 512
                                  nn.LeakyReLU(0.2, True),
                                  SpectralNorm(nn.Conv2D(ndf, ndf, kernel_size=4, stride=2, padding=1, bias_attr=False)),  # 256
                                  nn.LeakyReLU(0.2, True))
        
        self.NetU1 = nn.Sequential(
                                  Dis_ConvLayer(ndf, ndf, cardinality=8, dilate=1),
                                  Dis_ConvLayer(ndf, ndf, cardinality=8, dilate=1, stride=2),
                                  )

        self.NetReal0 = nn.Sequential(
                                    SpectralNorm(nn.Conv2D(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias_attr=False)),
                                    nn.LeakyReLU(0.2, True),                           
                                    Dis_ConvLayer(ndf * 2, ndf * 2, cardinality=8, dilate=1),
                                    Dis_ConvLayer(ndf * 2, ndf * 2, cardinality=8, dilate=1, stride=2))  # 64
        self.NetReal1 = nn.Sequential(
                                    SpectralNorm(nn.Conv2D(ndf * 2, ndf * 4, kernel_size=1, stride=1, padding=0, bias_attr=False)),
                                    nn.LeakyReLU(0.2, True),
                                    Dis_ConvLayer(ndf * 4, ndf * 4, cardinality=8, dilate=1),
                                   Dis_ConvLayer(ndf * 4, ndf * 4, cardinality=8, dilate=1, stride=2),  # 32
                                   SpectralNorm(nn.Conv2D(ndf * 4, ndf * 8, kernel_size=3, stride=1, padding=1, bias_attr=False)),  # 32
                                   nn.LeakyReLU(0.2, True),
                                   Dis_ConvLayer(ndf * 8, ndf * 8, cardinality=8, dilate=1),
                                   Dis_ConvLayer(ndf * 8, ndf * 8, cardinality=8, dilate=1, stride=2),  # 16
                                   )

        self.NetAnime0 = nn.Sequential(
                                    SpectralNorm(nn.Conv2D(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias_attr=False)),
                                    nn.LeakyReLU(0.2, True),                           
                                    Dis_ConvLayer(ndf * 2, ndf * 2, cardinality=8, dilate=1),
                                    Dis_ConvLayer(ndf * 2, ndf * 2, cardinality=8, dilate=1, stride=2))  # 64
        self.NetAnime1 = nn.Sequential(
                                    SpectralNorm(nn.Conv2D(ndf * 2, ndf * 4, kernel_size=1, stride=1, padding=0, bias_attr=False)),
                                    nn.LeakyReLU(0.2, True),
                                    Dis_ConvLayer(ndf * 4, ndf * 4, cardinality=8, dilate=1),
                                   Dis_ConvLayer(ndf * 4, ndf * 4, cardinality=8, dilate=1, stride=2),  # 32
                                   SpectralNorm(nn.Conv2D(ndf * 4, ndf * 8, kernel_size=3, stride=1, padding=1, bias_attr=False)),  # 32
                                   nn.LeakyReLU(0.2, True),
                                   Dis_ConvLayer(ndf * 8, ndf * 8, cardinality=8, dilate=1),
                                   Dis_ConvLayer(ndf * 8, ndf * 8, cardinality=8, dilate=1, stride=2))  # 16

    def forward(self, x, domain="A"):
        # shared layers
        u0 = self.NetU0(x)
        u1 = self.NetU1(u0)
        # domain-specific branches
        if domain == "A":
            f = self.NetAnime0(u1)
            x = self.NetAnime1(f)
        else:
            f = self.NetReal0(u1)
            x = self.NetReal1(f)
        # shared features (u0, u1), branch features f, and the final score map
        return u0, u1, f, x

x = paddle.randn([4,3,256,256])
# feat = paddle.randn([4,512,32,32])
len(NetD()(x,"A"))
4

3.3 Data preparation

# Unzip the dataset (only needs to run once) to get the anime faces
import os
if not os.path.isdir("./data/anime_face"):
    os.mkdir("./data/anime_face")
!unzip -qo data/data110820/faces.zip  -d ./data/anime_face


# Unzip the dataset (only needs to run once) to get the real faces
import os
if not os.path.isdir("./data/real_face"):
    os.mkdir("./data/real_face")
!unzip -qo data/data79149/cartoon_A2B.zip  -d ./data/real_face



from paddle.vision.transforms import CenterCrop,Resize
transform = Resize((256,256))
# Build the dataset
IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]
# IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG']
import paddle
import numpy as np
import cv2
import os
def data_maker(dir):
    images = []
    assert os.path.isdir(dir), '%s is not a valid directory' % dir

    for root, _, fnames in sorted(os.walk(dir)):
        for fname in fnames:
            if is_image_file(fname) and ("outfit" not in fname):
                path = os.path.join(root, fname)
                images.append(path)

    return sorted(images)

def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


class AnimeDataset(paddle.io.Dataset):
    def __init__(self):
        super(AnimeDataset,self).__init__()
        self.real_image_dirs =data_maker("data/real_face/cartoon_A2B/train")
        self.anime_image_dirs = data_maker("data/anime_face/faces")
        self.sizes = [
            len(fold) for fold in [self.anime_image_dirs, self.real_image_dirs]
        ]
        self.size = max(self.sizes)
        self.repadding()
    # cv2.imread reads images as BGR; convert the channels to RGB
    @staticmethod
    def loader(path):
        return cv2.cvtColor(cv2.imread(path, flags=cv2.IMREAD_COLOR),
                            cv2.COLOR_BGR2RGB)

    def repadding(self):
        for one_dir in [self.real_image_dirs,self.anime_image_dirs]:
            cur_size = len(one_dir)
            if cur_size < self.size:
                pad_num = self.size - cur_size
                pad = np.random.choice(cur_size, pad_num, replace=True)
                for i in list(pad):
                    one_dir.append(one_dir[i])



    def __getitem__(self, index):
        anime_img = AnimeDataset.loader(self.anime_image_dirs[index])
        real_img = AnimeDataset.loader(self.real_image_dirs[index])
        # each A2B sample stores two 256-px images side by side; keep the left half
        real_img = real_img[:, :256, :]
        real_img = transform(real_img)
        anime_img = transform(anime_img)
        return real_img, anime_img

    def __len__(self):
        return self.size
for a,b in AnimeDataset():
    print(a.shape,b.shape)
    break
(256, 256, 3) (256, 256, 3)
batch_size =8
datas = AnimeDataset()
data_loader =  paddle.io.DataLoader(datas,batch_size=batch_size,shuffle =True,drop_last=True,num_workers=16)
for real_img,anime_img in data_loader:
    print(real_img.shape,anime_img.shape)
    break
[8, 256, 256, 3] [8, 256, 256, 3]

3.4 Instantiating the two generators

real2anime_generator = Model()

discriminator = NetD()
# generator = Model()
anime2real_generator = Model()
import paddle
import math
class LinearDecayWithWarmupKeep(paddle.optimizer.lr.LambdaDecay):
    def __init__(self,
                 learning_rate,
                 total_steps,
                 warmup,
                 decay,
                 last_epoch=-1,
                 verbose=False):
        warmup_steps = warmup if isinstance(warmup,int) else int(
            math.floor(warmup * total_steps))

        decay_steps = decay if isinstance(decay,int) else int(
            math.floor(decay * total_steps))

        def lr_lambda(current_step):
            if current_step < warmup_steps:
                return float(current_step) / float(max(1, warmup_steps))

            if warmup_steps <= current_step < decay_steps:
                return 1.0
            return max(0.0,
                       float(total_steps - current_step) /
                       float(max(1, total_steps - decay_steps)))

        super(LinearDecayWithWarmupKeep, self).__init__(learning_rate, lr_lambda,
                                                    last_epoch, verbose)
scheduler_G = paddle.optimizer.lr.StepDecay(learning_rate=1e-4, step_size=3, gamma=0.9, verbose=True)
scheduler_D = paddle.optimizer.lr.StepDecay(learning_rate=2e-4, step_size=3, gamma=0.9, verbose=True)
# scheduler_G = LinearDecayWithWarmupKeep(5e-4,100,5,20)
# scheduler_D = LinearDecayWithWarmupKeep(5e-4,100,5,20)
real2anime_optimizer_G = paddle.optimizer.Adam(learning_rate=scheduler_G,parameters=real2anime_generator.parameters(),beta1=0.5, beta2 =0.999)
anime2real_optimizer_G = paddle.optimizer.Adam(learning_rate=scheduler_G,parameters=anime2real_generator.parameters(),beta1=0.5, beta2 =0.999)
# optimizer_G = paddle.optimizer.Adam(learning_rate=scheduler_G,parameters=generator.parameters(),beta1=0.5, beta2 =0.999)
# optimizer_G = paddle.optimizer.RMSProp(learning_rate = scheduler_G, rho=0.95, epsilon=1e-04, momentum=0.0, centered=False, parameters=generator.parameters(), weight_decay=None, grad_clip=None, name=None)
optimizer_D = paddle.optimizer.RMSProp(learning_rate = scheduler_D, rho=0.95, epsilon=1e-04, momentum=0.0, centered=False, parameters=discriminator.parameters(), weight_decay=None, grad_clip=None, name=None)

# optimizer_D = paddle.optimizer.Adam(learning_rate=scheduler_D,parameters=discriminator.parameters(),beta1=0.5, beta2 =0.999)
Epoch 0: StepDecay set learning rate to 0.0001.
Epoch 0: StepDecay set learning rate to 0.0002.
'''
Load pretrained weight files here to resume training.
'''
# M_path ='Amodel_state2.pdparams'
# layer_state_dictm = paddle.load(M_path)
# real2anime_generator.set_state_dict(layer_state_dictm)

# M_path ='Bmodel_state2.pdparams'
# layer_state_dictm = paddle.load(M_path)
# anime2real_generator.set_state_dict(layer_state_dictm)


# D_path ='discriminator_params/Dmodel_state2.pdparams'
# layer_state_dictD = paddle.load(D_path)
# discriminator.set_state_dict(layer_state_dictD)
from paddle.nn.initializer import KaimingNormal, Constant

def weight_init(module):
    # NOTE: named_children() only visits direct children, so for composite
    # models most nested layers keep Paddle's default initialization.
    for n, m in module.named_children():
        if isinstance(m, (nn.Conv2D, nn.Conv2DTranspose, nn.Conv1D, nn.Linear)):
            KaimingNormal()(m.weight, m.weight.block)
            if m.bias is not None:
                Constant(0.0)(m.bias)
        elif isinstance(m, (nn.BatchNorm2D, nn.InstanceNorm2D)):
            Constant(1.0)(m.weight)
            if hasattr(m, '_variance'):  # running stats exist only on BatchNorm
                Constant(3000.0)(m._variance)
                Constant(3000.0)(m._mean)
            if m.bias is not None:
                Constant(0.0)(m.bias)
save_dir_model = "model_params"
save_dir_Discriminator = "discriminator_params"
import random
from visualdl import LogWriter
log_writer = LogWriter("./log/gnet")

3.5 Image cache: the discriminator scores generator outputs from the previous 20 iterations, which reduces the risk of mode collapse

Aimage_cache = []
CACHE_capacity = 20

def step_A_cache(A_image, i):
    # keep a rolling pool of the last CACHE_capacity fake batches; once the
    # pool is full, return a randomly chosen batch instead of the newest one
    if i < CACHE_capacity:
        Aimage_cache.append(A_image)
        return A_image
    else:
        Aimage_cache.append(A_image)
        Aimage_cache.pop(0)
        r = random.randint(0, CACHE_capacity-1)
        return Aimage_cache[r]


Bimage_cache = []
def step_B_cache(B_image,i):
    if i<CACHE_capacity:
        Bimage_cache.append(B_image)
        return B_image
    else :
        Bimage_cache.append(B_image)
        Bimage_cache.pop(0)
        r = random.randint(0,CACHE_capacity-1)
        return Bimage_cache[r]
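For context, this is essentially the image-pool (history buffer) trick familiar from CycleGAN-style training, here with a pool of 20 batches.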


3.6 Training loop

In every 3 iterations, 2 train real2anime_generator and 1 trains anime2real_generator; after all, the main goal is real-face animefication.

adversarial_loss = paddle.nn.MSELoss()
l1_loss = nn.L1Loss()


weight_init(anime2real_generator)
weight_init(real2anime_generator)
weight_init(discriminator)

from tqdm import tqdm
step = 0
EPOCHES = 100
for epoch in  range(EPOCHES):
    for data in tqdm(data_loader):
            real_data,anime_data = [i/127.5-1 for i in data]
            real_data =paddle.transpose(x=real_data,perm=[0,3,1,2])
            anime_data =paddle.transpose(x=anime_data,perm=[0,3,1,2])

            if step%3 != 0:
                for p in real2anime_generator.parameters():
                    p.stop_gradient = False
                for p in anime2real_generator.parameters():
                    p.stop_gradient = True
            else:
                for p in real2anime_generator.parameters():
                    p.stop_gradient = True
                for p in anime2real_generator.parameters():
                    p.stop_gradient = False
            for p in discriminator.parameters():
                p.stop_gradient = True
            img_fake_B2A = real2anime_generator(real_data,anime_data)
            _,_,_,fake_output_B2A = discriminator(img_fake_B2A,"A")
            g_adversarial_loss_B2A = adversarial_loss(fake_output_B2A,paddle.ones_like(fake_output_B2A))*10

            img_fake_A2B = anime2real_generator(anime_data,real_data)
            _,_,_,fake_output_A2B = discriminator(img_fake_A2B,"B")
            g_adversarial_loss_A2B = adversarial_loss(fake_output_A2B,paddle.ones_like(fake_output_A2B))*10

            img_fake_B2A2B = anime2real_generator(img_fake_B2A,real_data)
            img_fake_A2B2A = real2anime_generator(img_fake_A2B,anime_data)
            rec_loss_B = l1_loss(img_fake_B2A2B,real_data)*100
            rec_loss_A = l1_loss(img_fake_A2B2A,anime_data)*100

            u0a,u1a,fa,_ = discriminator(anime_data,"A")
            u0ga,u1ga,fga,_=discriminator(img_fake_A2B2A,"A")

            u0r,u1r,fr,_ = discriminator(real_data,"B")
            u0gr,u1gr,fgr,_=discriminator(img_fake_B2A2B,"B")

            feature_matching_loss = l1_loss(u0a,u0ga) +l1_loss(u1a,u1ga)+l1_loss(u0r,u0gr) +l1_loss(u1r,u1gr)+l1_loss(fa,fga) +l1_loss(fr,fgr)
            feature_matching_loss*=100

            if step%3 != 0:            
                g_loss = g_adversarial_loss_B2A + feature_matching_loss + rec_loss_B +rec_loss_A
            else:
                g_loss = g_adversarial_loss_A2B + feature_matching_loss + rec_loss_B +rec_loss_A
            # g_loss = g_adversarial_loss_B2A  + feature_matching_loss + rec_loss_B

            # g_loss = feature_matching_loss + rec_loss
            
            g_loss.backward()
            if step%3 != 0:
                real2anime_optimizer_G.step()
                real2anime_optimizer_G.clear_grad()
            else:
                anime2real_optimizer_G.step()
                anime2real_optimizer_G.clear_grad() 



            for p in real2anime_generator.parameters():
                p.stop_gradient = True
            for p in anime2real_generator.parameters():
                p.stop_gradient = True
            for p in discriminator.parameters():
                p.stop_gradient = False

            img_fake_A2B0 = step_B_cache(img_fake_A2B,step)
            img_fake_B2A0 = step_A_cache(img_fake_B2A,step)
            _,_,_,fake_output_B2A = discriminator(img_fake_B2A0.detach(),"A")
            _,_,_,fake_output_A2B = discriminator(img_fake_A2B0.detach(),"B")
            # _,_,_,fake_output_A2B2A = discriminator(img_fake_A2B2A.detach(),"A")
            # _,_,_,fake_output_B2A2B = discriminator(img_fake_B2A2B.detach(),"B")
            
            _,_,_,real_output = discriminator(real_data,"B")
            _,_,_,anime_output = discriminator(anime_data,"A")
            d_B2A_loss = adversarial_loss(fake_output_B2A, paddle.zeros_like(fake_output_B2A))*10
            d_A2B_loss = adversarial_loss(fake_output_A2B,paddle.zeros_like(fake_output_A2B))*10
            # d_A2B2A_loss = adversarial_loss(fake_output_A2B2A, paddle.zeros_like(fake_output_A2B2A))*10
            # d_B2A2B_loss = adversarial_loss(fake_output_B2A2B, paddle.zeros_like(fake_output_B2A2B))*10
            d_real_loss = adversarial_loss(real_output, paddle.ones_like(real_output))*10
            d_anime_loss = adversarial_loss(anime_output, paddle.ones_like(anime_output))*10

            d_loss = d_B2A_loss+d_A2B_loss+d_real_loss+d_anime_loss
            # d_loss = d_B2A_loss+d_A2B2A_loss+d_B2A2B_loss+d_real_loss+d_anime_loss
            
            d_loss.backward()
            optimizer_D.step()
            optimizer_D.clear_grad()

            if step%2==0:

                log_writer.add_scalar(tag='train/g_adversarial_loss_B2A', step=step, value=g_adversarial_loss_B2A.numpy()[0])
                log_writer.add_scalar(tag='train/g_adversarial_loss_A2B', step=step, value=g_adversarial_loss_A2B.numpy()[0])

                log_writer.add_scalar(tag='train/feature_matching_loss', step=step, value=feature_matching_loss.numpy()[0])
                log_writer.add_scalar(tag='train/rec_lossA', step=step, value=rec_loss_A.numpy()[0])
                log_writer.add_scalar(tag='train/rec_lossB', step=step, value=rec_loss_B.numpy()[0])
                log_writer.add_scalar(tag='train/g_loss', step=step, value=g_loss.numpy()[0])


                log_writer.add_scalar(tag='train/d_B2A_loss', step=step, value=d_B2A_loss.numpy()[0])
                log_writer.add_scalar(tag='train/d_A2B_loss', step=step, value=d_A2B_loss.numpy()[0])
                log_writer.add_scalar(tag='train/d_real_loss', step=step, value=d_real_loss.numpy()[0])
                log_writer.add_scalar(tag='train/d_anime_loss', step=step, value=d_anime_loss.numpy()[0])
                log_writer.add_scalar(tag='train/d_loss', step=step, value=d_loss.numpy()[0])

                




            step+=1
            # print(i)
            if step%100 == 3:
                print(step,"g_adversarial_loss_B2A",g_adversarial_loss_B2A.numpy()[0],"feature_matching_loss",feature_matching_loss.numpy()[0],"rec_lossA",rec_loss_A.numpy()[0],\
                "g_loss",g_loss.numpy()[0] ,"d_loss",d_loss.numpy()[0],)

                real_data = (real_data+1)*127.5
                anime_data = (anime_data+1)*127.5
                img_fake_B2A = (img_fake_B2A+1)*127.5
                img_fake_A2B = (img_fake_A2B+1)*127.5

                g_output = paddle.concat([img_fake_B2A,img_fake_A2B,real_data,anime_data],axis = 3).detach().numpy()                      # tensor -> numpy
                g_output = g_output.transpose(0, 2, 3, 1)[0]             # NCHW -> NHWC
                g_output = g_output.astype(np.uint8)
                cv2.imwrite(os.path.join("./result", 'epoch'+str(step).zfill(3)+'.png'),cv2.cvtColor(g_output,cv2.COLOR_RGB2BGR))


            
            if step%100 == 3:

                save_param_path_d = os.path.join(save_dir_Discriminator, 'Dmodel_state'+str(2)+'.pdparams')
                paddle.save(discriminator.state_dict(), save_param_path_d)

                save_param_path_mA = os.path.join(save_dir_model, 'Amodel_state'+str(2)+'.pdparams')
                paddle.save(real2anime_generator.state_dict(), save_param_path_mA)

                save_param_path_mB = os.path.join(save_dir_model, 'Bmodel_state'+str(2)+'.pdparams')
                paddle.save(anime2real_generator.state_dict(), save_param_path_mB)
    scheduler_G.step()
    scheduler_D.step()

3.7 Test code

from paddle.vision.transforms import CenterCrop,Resize
transform = Resize((256,256))
generator = Model()
generator.set_state_dict(paddle.load("model_params/Amodel_state2.pdparams"))
anime_path = "data/anime_face/faces/000215-01.jpg"
real_path = "data/real_face/cartoon_A2B/test/01434.png"
anime_data = cv2.cvtColor(cv2.imread(anime_path, flags=cv2.IMREAD_COLOR),cv2.COLOR_BGR2RGB)
real_data = cv2.cvtColor(cv2.imread(real_path, flags=cv2.IMREAD_COLOR),cv2.COLOR_BGR2RGB)[:,:256,:]

anime_data = transform(anime_data)
real_data = transform(real_data)


anime_data = paddle.to_tensor(anime_data,dtype=paddle.float32).unsqueeze(0)
real_data = paddle.to_tensor(real_data,dtype=paddle.float32).unsqueeze(0)

real_data,anime_data = (real_data/127.5-1,anime_data/127.5-1)
real_data =paddle.transpose(x=real_data,perm=[0,3,1,2])
anime_data =paddle.transpose(x=anime_data,perm=[0,3,1,2])
# real_data = paddle.zeros([1,3,256,256],dtype=paddle.float32)
img_fake_B2A = generator(real_data,anime_data)
anime_data = (anime_data+1)*127.5
real_data = (real_data+1)*127.5
img_fake_B2A = (img_fake_B2A+1)*127.5

g_output = paddle.concat([img_fake_B2A,real_data,anime_data],axis = 3).detach().numpy()                      # tensor -> numpy
g_output = g_output.transpose(0, 2, 3, 1)[0]             # NCHW -> NHWC
g_output = g_output.astype(np.uint8)
cv2.imwrite(os.path.join("./test", 'epoch32'+'.png'),cv2.cvtColor(g_output,cv2.COLOR_RGB2BGR))

4. Summary

Next I will try writing a Paddle version of U-GAT-IT and experiment with it myself. The trick of building a cache pool so the discriminator judges images the generator produced in earlier iterations, rather than in the same iteration, works very well. If you have a decent GPU of your own, feel free to continue training from here: the weight file I provide is not the best result (I accidentally deleted that one; the current file is an intermediate checkpoint). I recommend training from scratch to experience the whole process.

Fork away!!! Let me recover a little of the several hundred compute credits this training cost me. Thank you!!!

This article is a repost; original project: https://aistudio.baidu.com/aistudio/projectdetail/4355679
