Reposted from AI Studio. Original link: Reference-based sketch colorization project V2 (SCFT), no hype, no shade, you deserve it - PaddlePaddle AI Studio

1. The exciting showcase

Training results:

2. Task analysis

First of all, what is reference-based sketch colorization?

Given a sketch A and an already-colored image B, the model has to transfer B's colors onto A's line structure in a plausible way to produce the final image C. For C, A supplies the main structural information and B supplies the color information, and the model must adaptively warp B's colors to fit A's structure.

This is where our paired dataset comes in: sketches and their corresponding colored images. Here is an example:

Both the colored image and the sketch are 512×512 and stored side by side in a single image with h = 512 and w = 1024. The color reference image is the first 512 columns along W, which is then warped with TPS interpolation; the sketch is the last 512 columns along W. To make training easier, both the color reference and the sketch are resized to 256×256.
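To make this concrete, here is a minimal sketch of how one 512×1024 training image could be split and resized (the filename is a placeholder; the real dataset class is in section 4.3 and the TPS warp in section 4.1):

import cv2
from paddle.vision.transforms import Resize

resize = Resize((256, 256))
# one training sample is a single 512x1024 image: colored image on the left, sketch on the right
img = cv2.cvtColor(cv2.imread("sample_pair.png", cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)  # hypothetical filename
color_ref = resize(img[:, :512, :])   # left half  -> color reference (TPS warping is applied later, on the tensor)
sketch = resize(img[:, 512:, :])      # right half -> sketch
print(color_ref.shape, sketch.shape)  # (256, 256, 3) (256, 256, 3)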

That is our dataset. Note that if we used the colored ground-truth image directly as the color reference, the model would rely heavily on the reference and never learn to warp the colors onto the sketch structure. But at test time the reference never matches the sketch that closely, so this shortcut breaks down and the following failures appear. Here are the wrong results:

3. Paper walkthrough

This project mainly reproduces and demonstrates the following paper:

论文题目:Reference-Based Sketch Image Colorization using Augmented-Self Reference and Dense Semantic Correspondence

Overall framework of the paper:

Below I pick out the key points and walk through them:

1. The color-fixation problem (the model has a color bias):

We must not let the model learn that an apple is always red; in reality an apple can just as well be yellow or green, and we want to avoid this kind of color bias. The paper's fix is to apply a color perturbation a(·) to the image I and use the result I_gt as the ground truth, which greatly alleviates the bias. You can think of it as a form of data augmentation.

In practice I figured my dataset was large enough (over 10,000 images), so I skipped this step. Honestly, I was mostly too lazy to dig up that part of the code, haha.
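For completeness, a minimal sketch of what the color perturbation a(·) could look like, assuming a simple random per-channel offset in RGB space (this is my own guess at an implementation, not the paper's exact augmentation):

import paddle

def color_jitter(img, strength=0.3):
    # img: (N, 3, H, W) tensor scaled to [-1, 1]
    # add a random per-channel offset, so e.g. a red apple sometimes turns yellowish or greenish
    offset = paddle.uniform([img.shape[0], 3, 1, 1], min=-strength, max=strength)
    return paddle.clip(img + offset, -1.0, 1.0)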

2. The large gap between the reference images seen at training and at test time (important)

If the input reference image (which provides the color information) is identical to the ground truth, the model will lazily rely on the reference's structural information and never learn to separate the reference's structure from its color. This is a major pain point, because at test time the spatial structure of the reference differs greatly from that of the sketch; in other words, we need the model to be able to warp the reference onto the sketch's structure. Therefore we cannot use I_gt directly as the reference. Instead we apply a spatial deformation to I_gt and use the result as the reference; this spatial perturbation s(·) is implemented with TPS (thin plate splines).
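In code this is a single line in the training loop of section 4.3; a minimal sketch, assuming the TPS_SpatialTransformerNetwork defined in section 4.1 and that appearance_img is already a tensor:

# appearance_img: the ground-truth colored image I_gt, shape (N, 3, 256, 256), scaled to [-1, 1]
tps = TPS_SpatialTransformerNetwork(F=10, I_size=(256, 256), I_r_size=(256, 256), I_channel_num=3)
reference_img = tps(appearance_img).detach()  # spatially warped copy of I_gt, used as the color reference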

3. Feature concatenation: every intermediate feature map is downsampled to the spatial size of the deepest (core) feature map and then concatenated along the channel dimension.
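A minimal sketch of this idea with dummy tensors (the real version is the Encoder in section 4.2, which uses fixed AvgPool2D strides instead of adaptive pooling):

import paddle
import paddle.nn.functional as F

f1 = paddle.randn([4, 16, 256, 256])   # shallow feature map
f2 = paddle.randn([4, 64, 64, 64])     # intermediate feature map
f3 = paddle.randn([4, 256, 16, 16])    # deepest ("core") feature map
# pool everything down to the deepest spatial size, then concat along the channel dimension
v = paddle.concat([F.adaptive_avg_pool2d(f, 16) for f in (f1, f2, f3)], axis=1)
print(v.shape)  # [4, 336, 16, 16]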

4. Letting the reference image's information merge into the sketch's structural information at the architecture level (SCFT).

It is simply the Q/K/V idea: Q comes from the sketch's structural features, while K and V come from the reference image's features. In effect, the color information gets warped to follow the sketch's structure.
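Stripped down to its core, the SCFT module is just the following (a sketch; w_q, w_k, w_v stand for the three Linear projections in the SCFT class of section 4.2):

import math
import paddle
import paddle.nn.functional as F

def scft_attention(Vs, Vr, w_q, w_k, w_v):
    # Vs: sketch features    (N, HW, C) -> queries
    # Vr: reference features (N, HW, C) -> keys and values
    q, k, v = w_q(Vs), w_k(Vr), w_v(Vr)
    attn = F.softmax(paddle.matmul(q, k.transpose([0, 2, 1])) / math.sqrt(q.shape[-1]), axis=-1)
    return paddle.matmul(attn, v) + Vs  # warped color features fused back onto the sketch structure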

Losses

The triplet part constrains q and k so that they also carry the information of the spatial warp s(·); to be honest, I did not fully understand this part.

 

The rest are VGG feature alignment (a perceptual loss), a pixel reconstruction loss between the generated image and the GT (the implementation below uses L1), a Gram-matrix style loss, and the GAN adversarial loss.

The GitHub code I referenced (PyTorch): https://github.com/seungjae-won/paper-to-code-Reference-Based-Sketch-Image-Colorization/blob/master/model.py

4. Code

Everything can be run with one click; the last code cell is the test cell.

4.1 TPS (thin plate spline) warping

In [1]

import paddle.nn.functional as F
import os
import numpy as np
import paddle
import paddle.nn as nn
import math




class LocalizationNetwork(nn.Layer):
    """ Localization Network of RARE, which predicts C' (K x 2) from I (I_width x I_height) """

    def __init__(self, F, I_channel_num):
        super(LocalizationNetwork, self).__init__()
        self.F = F
        self.I_channel_num = I_channel_num
        self.conv = nn.Sequential(
            nn.Conv2D(in_channels=self.I_channel_num, out_channels=64, kernel_size=3, stride=1, padding=1,
                      bias_attr=False), nn.BatchNorm2D(64), nn.ReLU(),
            nn.MaxPool2D(2, 2),  # batch_size x 64 x I_height/2 x I_width/2
            nn.Conv2D(64, 128, 3, 1, 1, bias_attr=False), nn.BatchNorm2D(128), nn.ReLU(),
            nn.MaxPool2D(2, 2),  # batch_size x 128 x I_height/4 x I_width/4
            nn.Conv2D(128, 256, 3, 1, 1, bias_attr=False), nn.BatchNorm2D(256), nn.ReLU(),
            nn.MaxPool2D(2, 2),  # batch_size x 256 x I_height/8 x I_width/8
            nn.Conv2D(256, 512, 3, 1, 1, bias_attr=False), nn.BatchNorm2D(512), nn.ReLU(),
            nn.AdaptiveAvgPool2D(1)  # batch_size x 512
        )

        self.localization_fc1 = nn.Sequential(nn.Linear(512, 256), nn.ReLU())
        self.localization_fc2 = nn.Linear(256, self.F * 2)

        # Init fc2 in LocalizationNetwork
        # self.localization_fc2.weight.set_value(paddle.zeros([256,20]))   # in the original code this line is not commented out
        """ see RARE paper Fig. 6 (a) """
        ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2))
        ctrl_pts_y_top = np.linspace(0.0, -1.0, num=int(F / 2))
        ctrl_pts_y_bottom = np.linspace(1.0, 0.0, num=int(F / 2))
        ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
        ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
        initial_bias_attr = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0)
        # use set_value so the bias parameter is actually initialized with the fiducial points
        self.localization_fc2.bias.set_value(paddle.to_tensor(initial_bias_attr).astype("float32").reshape([-1]))

    def forward(self, batch_I):
        """
        input:     batch_I : Batch Input Image [batch_size x I_channel_num x I_height x I_width]
        output:    batch_C_prime : Predicted coordinates of fiducial points for input batch [batch_size x F x 2]
        """
        batch_size = batch_I.shape[0]
        features = self.conv(batch_I).reshape([batch_size, -1])
        batch_C_prime = self.localization_fc2(self.localization_fc1(features)).reshape([batch_size, self.F, 2])
        return batch_C_prime


class GridGenerator(nn.Layer):
    """ Grid Generator of RARE, which produces P_prime by multipling T with P """

    def __init__(self, F, I_r_size):
        """ Generate P_hat and inv_delta_C for later """
        super(GridGenerator, self).__init__()
        self.eps = 1e-6
        self.I_r_height, self.I_r_width = I_r_size
        self.F = F
        self.C = self._build_C(self.F)  # F x 2
        self.P = self._build_P(self.I_r_width, self.I_r_height)
        a = paddle.to_tensor(self._build_inv_delta_C(self.F, self.C)).astype("float32")
        self.register_buffer("inv_delta_C",a )  # F+3 x F+3
        self.register_buffer("P_hat", paddle.to_tensor(self._build_P_hat(self.F, self.C, self.P)).astype("float32"))  # n x F+3

    def _build_C(self, F):
        """ Return coordinates of fiducial points in I_r; C """
        ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2))
        ctrl_pts_y_top = -1 * np.ones(int(F / 2))
        ctrl_pts_y_bottom = np.ones(int(F / 2))
        ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
        ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
        C = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0)
        return C  # F x 2

    def _build_inv_delta_C(self, F, C):
        """ Return inv_delta_C which is needed to calculate T """
        hat_C = np.zeros((F, F), dtype=float)  # F x F
        for i in range(0, F):
            for j in range(i, F):
                r = np.linalg.norm(C[i] - C[j])
                hat_C[i, j] = r
                hat_C[j, i] = r
        np.fill_diagonal(hat_C, 1)
        hat_C = (hat_C ** 2) * np.log(hat_C)
        # print(C.shape, hat_C.shape)
        delta_C = np.concatenate(  # F+3 x F+3
            [
                np.concatenate([np.ones((F, 1)), C, hat_C], axis=1),  # F x F+3
                np.concatenate([np.zeros((2, 3)), np.transpose(C)], axis=1),  # 2 x F+3
                np.concatenate([np.zeros((1, 3)), np.ones((1, F))], axis=1)  # 1 x F+3
            ],
            axis=0
        )
        inv_delta_C = np.linalg.inv(delta_C)
        return inv_delta_C  # F+3 x F+3

    def _build_P(self, I_r_width, I_r_height):
        I_r_grid_x = (np.arange(-I_r_width, I_r_width, 2) + 1.0) / I_r_width  # self.I_r_width
        I_r_grid_y = (np.arange(-I_r_height, I_r_height, 2) + 1.0) / I_r_height  # self.I_r_height
        P = np.stack(  # self.I_r_width x self.I_r_height x 2
            np.meshgrid(I_r_grid_x, I_r_grid_y),
            axis=2
        )
        return P.reshape([-1, 2])  # n (= self.I_r_width x self.I_r_height) x 2

    def _build_P_hat(self, F, C, P):
        n = P.shape[0]  # n (= self.I_r_width x self.I_r_height)
        P_tile = np.tile(np.expand_dims(P, axis=1), (1, F, 1))  # n x 2 -> n x 1 x 2 -> n x F x 2
        C_tile = np.expand_dims(C, axis=0)  # 1 x F x 2
        P_diff = P_tile - C_tile  # n x F x 2
        rbf_norm = np.linalg.norm(P_diff, ord=2, axis=2, keepdims=False)  # n x F
        rbf = np.multiply(np.square(rbf_norm), np.log(rbf_norm + self.eps))  # n x F
        P_hat = np.concatenate([np.ones((n, 1)), P, rbf], axis=1)
        return P_hat  # n x F+3

    def build_P_prime(self, batch_C_prime):
        """ Generate Grid from batch_C_prime [batch_size x F x 2] """
        batch_size = batch_C_prime.shape[0]
        batch_inv_delta_C = self.inv_delta_C.tile([batch_size, 1, 1])
        batch_P_hat = self.P_hat.tile([batch_size, 1, 1])
        batch_C_prime_with_zeros = paddle.concat((batch_C_prime, paddle.zeros([
            batch_size, 3, 2]).astype("float32")), axis=1)  # batch_size x F+3 x 2
        batch_T = paddle.matmul(batch_inv_delta_C, batch_C_prime_with_zeros)  # batch_size x F+3 x 2
        batch_P_prime = paddle.matmul(batch_P_hat, batch_T)  # batch_size x n x 2
        return batch_P_prime  # batch_size x n x 2



class TPS_SpatialTransformerNetwork(nn.Layer):
    """ Rectification Network of RARE, namely TPS based STN """

    def __init__(self, F =10 , I_size = (256,256), I_r_size = (256,256), I_channel_num= 3):
        """ Based on RARE TPS
        input:
            batch_I: Batch Input Image [batch_size x I_channel_num x I_height x I_width]
            I_size : (height, width) of the input image I
            I_r_size : (height, width) of the rectified image I_r
            I_channel_num : the number of channels of the input image I
        output:
            batch_I_r: rectified image [batch_size x I_channel_num x I_r_height x I_r_width]
        """
        super(TPS_SpatialTransformerNetwork, self).__init__()
        self.F = F
        self.I_size = I_size
        self.I_r_size = I_r_size  # = (I_r_height, I_r_width)
        self.I_channel_num = I_channel_num
        self.LocalizationNetwork = LocalizationNetwork(self.F, self.I_channel_num)
        self.GridGenerator = GridGenerator(self.F, self.I_r_size)

    def forward(self, batch_I):
        batch_C_prime = self.LocalizationNetwork(batch_I)  # batch_size x K x 2
        build_P_prime = self.GridGenerator.build_P_prime(batch_C_prime)  # batch_size x n (= I_r_width x I_r_height) x 2
        build_P_prime_reshape = build_P_prime.reshape([build_P_prime.shape[0], self.I_r_size[0], self.I_r_size[1], 2])
        batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border', align_corners = False)

        return batch_I_r

In [2]

import cv2
from PIL import Image

from paddle.vision.transforms import CenterCrop,Resize
from paddle.vision.transforms import RandomRotation
# transform0 = RandomRotation(90)
transform1 = CenterCrop((256,256))

'''
Scaling to [0, 1] or to [-1, 1] both work; judge by the actual results yourself.

'''
# img = paddle.randn([1,3,256,256])
img_A = cv2.imread("QQ截图20220416214500.jpg")#内容图
img_A = transform1(img_A)
cv2.imwrite("before3.jpg",img_A)
img_A = cv2.cvtColor(img_A,cv2.COLOR_BGR2RGB)

g_input = img_A.astype('float32')/255         # normalize to [0, 1]
g_input = g_input[np.newaxis, ...].transpose(0, 3, 1, 2)  # NHWC -> NCHW
g_input = paddle.to_tensor(g_input)                       # numpy -> tensor
x = (TPS_SpatialTransformerNetwork()(g_input)*255).clip(0,255)[0].transpose([1,2,0]).numpy()
# print(x)
cv2.imwrite("after3.jpg",cv2.cvtColor(x,cv2.COLOR_RGB2BGR))



img_A = cv2.imread("QQ截图20220416214500.jpg")#内容图
img_A = transform1(img_A)
cv2.imwrite("before5.jpg",img_A)
img_A = cv2.cvtColor(img_A,cv2.COLOR_BGR2RGB)

g_input = img_A.astype('float32')/127.5-1         # normalize to [-1, 1]
g_input = g_input[np.newaxis, ...].transpose(0, 3, 1, 2)  # NHWC -> NCHW
g_input = paddle.to_tensor(g_input)                       # numpy -> tensor
x = ((TPS_SpatialTransformerNetwork()(g_input)+1)*127.5).clip(0,255)[0].transpose([1,2,0]).numpy()
# print(x)
cv2.imwrite("after5.jpg",cv2.cvtColor(x,cv2.COLOR_RGB2BGR))
True

TPS test using the code cell above

before:

after:

 

4.2 The main model architecture

In [3]

import paddle
import paddle.nn as nn
class ResBlock(nn.Layer):
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResBlock, self).__init__()
        
        def block(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False):
            layers = []
            layers += [nn.Conv2D(in_channels=in_channels, out_channels=out_channels,
                                 kernel_size=kernel_size, stride=stride, padding=padding,
                                 bias_attr =bias)]
            layers += [nn.BatchNorm2D(num_features=out_channels)]
            
            layers += [nn.ReLU()]
            layers += [nn.Conv2D(in_channels=out_channels, out_channels=out_channels,
                                 kernel_size=kernel_size, stride=stride, padding=padding,
                                 bias_attr =bias)]
            layers += [nn.BatchNorm2D(num_features=out_channels)]

            cbr = nn.Sequential(*layers)

            return cbr
        
        self.block_1 = block(in_channels,out_channels)
        self.block_2 = block(out_channels,out_channels)
        self.block_3 = block(out_channels,out_channels)
        self.block_4 = block(out_channels,out_channels)
        
        self.relu = nn.ReLU()
        
    def forward(self, x):
        
        # block 1
        residual = x
        out = self.block_1(x)
        out = self.relu(out)
        
        # block 2
        residual = out
        out = self.block_2(out)
        out += residual
        out = self.relu(out)
        
        # block 3
        residual = out
        out = self.block_3(out)
        out += residual
        out = self.relu(out)
        
        # block 4
        residual = out
        out = self.block_4(out)
        out += residual
        out = self.relu(out)
        
        return out

x = paddle.randn([4,3,256,256])
ResBlock(3,7)(x).shape
[4, 7, 256, 256]

In [4]

import paddle
import paddle.nn as nn
class Encoder(nn.Layer):
    
    def __init__(self, in_channels = 3):
        super(Encoder, self).__init__()
        
        def CL2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True, LR_negative_slope=0.2):
            layers = []
            layers += [nn.Conv2D(in_channels=in_channels, out_channels=out_channels,
                                kernel_size=kernel_size, stride=stride, padding=padding,
                                bias_attr = bias)]
            layers += [nn.LeakyReLU(LR_negative_slope)]
            cbr = nn.Sequential(*layers)
            return cbr
        
        # conv_layer
        self.conv1 = CL2d(in_channels,16)
        self.conv2 = CL2d(16,16)
        self.conv3 = CL2d(16,32,stride=2)
        self.conv4 = CL2d(32,32)
        self.conv5 = CL2d(32,64,stride=2)
        self.conv6 = CL2d(64,64)
        self.conv7 = CL2d(64,128,stride=2)
        self.conv8 = CL2d(128,128)
        self.conv9 = CL2d(128,256,stride=2)
        self.conv10 = CL2d(256,256)
        
        # downsample_layer
        self.downsample1 = nn.AvgPool2D(kernel_size=16, stride=16)
        self.downsample2 = nn.AvgPool2D(kernel_size=8, stride=8)
        self.downsample3 = nn.AvgPool2D(kernel_size=4, stride=4)
        self.downsample4 = nn.AvgPool2D(kernel_size=2, stride=2)
        
    def forward(self, x):

        f1 = self.conv1(x)
        f2 = self.conv2(f1)
        f3 = self.conv3(f2)
        f4 = self.conv4(f3)
        f5 = self.conv5(f4)
        f6 = self.conv6(f5)
        f7 = self.conv7(f6)
        f8 = self.conv8(f7)
        f9 = self.conv9(f8)
        f10 = self.conv10(f9)
        
        F = [f9, f8, f7, f6, f5, f4, f3, f2 ,f1]
        
        v1 = self.downsample1(f1)
        v2 = self.downsample1(f2)
        v3 = self.downsample2(f3)
        v4 = self.downsample2(f4)
        v5 = self.downsample3(f5)
        v6 = self.downsample3(f6)
        v7 = self.downsample4(f7)
        v8 = self.downsample4(f8)

        V = paddle.concat((v1,v2,v3,v4,v5,v6,v7,v8,f9,f10), axis=1)
        h,w = V.shape[2],V.shape[3]
        V = paddle.reshape(V,(V.shape[0],V.shape[1],h*w))
        V = paddle.transpose(V,[0,2,1])
        
        return V,F,(h,w)
x = paddle.randn([4,3,256,256])
a,b,_ = Encoder()(x)
print(a.shape)
[4, 256, 992]

In [5]


class UNetDecoder(nn.Layer):
    def __init__(self):
        super(UNetDecoder, self).__init__()

        def CBR2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True):
            layers = []
            layers += [nn.Conv2D(in_channels=in_channels, out_channels=out_channels,
                                 kernel_size=kernel_size, stride=stride, padding=padding,
                                 bias_attr=bias)]
            # layers += [nn.BatchNorm2D(num_features=out_channels)]
            layers += [nn.InstanceNorm2D(num_features=out_channels)]
            layers += [nn.ReLU()]

            cbr = nn.Sequential(*layers)

            return cbr
        

        self.dec5_1 = CBR2d(in_channels=992+992, out_channels=256)
        self.unpool4 = nn.Conv2DTranspose(in_channels=512, out_channels=512,
                                          kernel_size=2, stride=2, padding=0, bias_attr=True)

        self.dec4_2 = CBR2d(in_channels=512+128, out_channels=128)
        self.dec4_1 = CBR2d(in_channels=128+128, out_channels=128)
        self.unpool3 = nn.Conv2DTranspose(in_channels=128, out_channels=128,
                                          kernel_size=2, stride=2, padding=0, bias_attr=True)

        self.dec3_2 = CBR2d(in_channels=128+64, out_channels=64)
        self.dec3_1 = CBR2d(in_channels=64+64, out_channels=64)
        self.unpool2 = nn.Conv2DTranspose(in_channels=64, out_channels=64,
                                          kernel_size=2, stride=2, padding=0, bias_attr=True)

        self.dec2_2 = CBR2d(in_channels=64+32, out_channels=32)
        self.dec2_1 = CBR2d(in_channels=32+32, out_channels=32)
        self.unpool1 = nn.Conv2DTranspose(in_channels=32, out_channels=32,
                                          kernel_size=2, stride=2, padding=0, bias_attr=True)

        self.dec1_2 = CBR2d(in_channels=32+16, out_channels=16)
        self.dec1_1 = CBR2d(in_channels=16+16, out_channels=16)

        self.fc = nn.Conv2D(in_channels=16, out_channels=3, kernel_size=1, stride=1, padding=0, bias_attr=True)

    def forward(self, x, F):
        
        dec5_1 = self.dec5_1(x)
        unpool4 = self.unpool4(paddle.concat((dec5_1,F[0]),axis=1))

        dec4_2 = self.dec4_2(paddle.concat((unpool4,F[1]),axis=1))
        dec4_1 = self.dec4_1(paddle.concat((dec4_2,F[2]),axis=1))
        unpool3 = self.unpool3(dec4_1)

        dec3_2 = self.dec3_2(paddle.concat((unpool3,F[3]),axis=1))
        dec3_1 = self.dec3_1(paddle.concat((dec3_2,F[4]),axis=1))
        unpool2 = self.unpool2(dec3_1)

        dec2_2 = self.dec2_2(paddle.concat((unpool2,F[5]),axis=1))
        dec2_1 = self.dec2_1(paddle.concat((dec2_2,F[6]),axis=1))
        unpool1 = self.unpool1(dec2_1)
        
        dec1_2 = self.dec1_2(paddle.concat((unpool1,F[7]),axis=1))
        dec1_1 = self.dec1_1(paddle.concat((dec1_2, F[8]),axis=1))

        x = self.fc(dec1_1)
        x = nn.Tanh()(x)
        
        return x

In [6]

import math
import paddle.nn.functional as F
class SCFT(nn.Layer):
    
    def __init__(self, sketch_channels, reference_channels, dv=992):
        super(SCFT, self).__init__()
        
        self.dv = paddle.to_tensor(dv).astype("float32")
        
        self.w_q = nn.Linear(dv,dv)
        self.w_k = nn.Linear(dv,dv)
        self.w_v = nn.Linear(dv,dv)
        
    def forward(self, Vs, Vr,shape):
        h,w = shape
        quary = self.w_q(Vs)
        key = self.w_k(Vr)
        value = self.w_v(Vr)

        c = paddle.add(self.scaled_dot_product(quary,key,value), Vs)
        
        c = paddle.transpose(c,[0,2,1])
        c = paddle.reshape(c,(c.shape[0],c.shape[1],h,w))
        
        return c, quary, key, value

    def masked_fill(self,x, mask, value):
        y = paddle.full(x.shape, value, x.dtype)
        return paddle.where(mask, y, x)


    # https://www.quantumdl.com/entry/11%EC%A3%BC%EC%B0%A82-Attention-is-All-You-Need-Transformer
    def scaled_dot_product(self, query, key, value, mask=None, dropout=None):
        "Compute 'Scaled Dot Product Attention'"
        d_k = query.shape[-1]
        # print(key.shape)
        scores = paddle.matmul(query, key.transpose([0,2, 1])) \
                / math.sqrt(d_k)
        if mask is not None:
            scores = self.masked_fill(scores,mask == 0, -1e9)
        p_attn = F.softmax(scores, axis = -1)

        if dropout is not None:
            p_attn = nn.Dropout(0.2)(p_attn)
        return paddle.matmul(p_attn, value)

In [7]

import paddle
import paddle.nn as nn
class Generator(nn.Layer):
    
    def __init__(self, sketch_channels=1, reference_channels=3, LR_negative_slope=0.2):
        super(Generator, self).__init__()
        
        self.encoder_sketch = Encoder(sketch_channels)
        self.encoder_reference = Encoder(reference_channels)
        self.scft = SCFT(sketch_channels, reference_channels)
        self.resblock = ResBlock(992, 992)
        self.unet_decoder = UNetDecoder()
    
    def forward(self, sketch_img, reference_img):
        
        # encoder 
        Vs, F,shape = self.encoder_sketch(sketch_img)
        Vr, _ ,_= self.encoder_reference(reference_img)
        
        # scft
        c, quary, key, value = self.scft(Vs,Vr,shape)
        
        # resblock
        c_out = self.resblock(c)
        
        # unet decoder
        I_gt = self.unet_decoder(paddle.concat((c,c_out),axis=1), F)

        return I_gt, quary, key, value

In [8]

s = paddle.randn([4,1,512,512])
r = paddle.randn([4,3,512,512])
out,q,k,v = Generator()(s,r)
print(out.shape)
[4, 3, 512, 512]

In [9]

import paddle
import paddle.nn as nn

# https://github.com/meliketoy/LSGAN.pytorch/blob/master/networks/Discriminator.py
# LSGAN Discriminator
class Discriminator(nn.Layer):
    def __init__(self, ndf, nChannels):
        super(Discriminator, self).__init__()
        # input : (batch * nChannels * image width * image height)
        # Discriminator will be consisted with a series of convolution networks

        self.layer1 = nn.Sequential(
            # Input size : input image with dimension (nChannels)*64*64
            # Output size: output feature vector with (ndf)*32*32
            nn.Conv2D(
                in_channels = nChannels,
                out_channels = ndf,
                kernel_size = 4,
                stride = 2,
                padding = 1,
                bias_attr = False
            ),
            nn.BatchNorm2D(ndf),
            nn.LeakyReLU(0.2)
        )

        self.layer2 = nn.Sequential(
            # Input size : input feature vector with (ndf)*32*32
            # Output size: output feature vector with (ndf*2)*16*16
            nn.Conv2D(
                in_channels = ndf,
                out_channels = ndf*2,
                kernel_size = 4,
                stride = 2,
                padding = 1,
                bias_attr = False
            ),
            nn.BatchNorm2D(ndf*2),
            nn.LeakyReLU(0.2)
        )

        self.layer3 = nn.Sequential(
            # Input size : input feature vector with (ndf*2)*16*16
            # Output size: output feature vector with (ndf*4)*8*8
            nn.Conv2D(
                in_channels = ndf*2,
                out_channels = ndf*4,
                kernel_size = 4,
                stride = 2,
                padding = 1,
                bias_attr = False
            ),
            nn.BatchNorm2D(ndf*4),
            nn.LeakyReLU(0.2)
        )

        self.layer4 = nn.Sequential(
            # Input size : input feature vector with (ndf*4)*8*8
            # Output size: output feature vector with (ndf*8)*4*4
            nn.Conv2D(
                in_channels = ndf*4,
                out_channels = ndf*8,
                kernel_size = 4,
                stride = 2,
                padding = 1,
                bias_attr = False
            ),
            nn.BatchNorm2D(ndf*8),
            nn.LeakyReLU(0.2)
        )

        self.layer5 = nn.Sequential(
            # Input size : input feature vector with (ndf*8)*4*4
            # Output size: output probability of fake/real image
            nn.Conv2D(
                in_channels = ndf*8,
                out_channels = 1,
                kernel_size = 4,
                stride = 1,
                padding = 0,
                bias_attr = False
            ),
            # nn.Sigmoid() -- Replaced with Least Square Loss
        )

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.layer5(out)

        return out
x = paddle.randn([4,3,256,256])
Discriminator(64,3)(x).shape
[4, 1, 13, 13]

4.3 Dataset, losses, and other odds and ends

In [10]


# Unzip the dataset; this only needs to be run once
import os
if not os.path.isdir("./data/d"):
    os.mkdir("./data/d")
! unzip -qo data/data128161/archive.zip -d ./data/d

In [11]

from VGG_Model import VGG19
import paddle
VGG = VGG19()
x = paddle.randn([4,3,256,256])
b = VGG(x)
for i in b:
    print(i.shape)
[4, 64, 256, 256]
[4, 128, 128, 128]
[4, 256, 64, 64]
[4, 512, 32, 32]
[4, 512, 16, 16]

In [12]

from visualdl import LogWriter
log_writer = LogWriter("./log/gnet")

In [13]


from paddle.vision.transforms import CenterCrop,Resize
transform = Resize((256,256))
# build the dataset
IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]
import paddle
import cv2
import os
def data_maker(dir):
    images = []
    assert os.path.isdir(dir), '%s is not a valid directory' % dir

    for root, _, fnames in sorted(os.walk(dir)):
        for fname in fnames:
            if is_image_file(fname) and ("outfit" not in fname):
                path = os.path.join(root, fname)
                images.append(path)

    return sorted(images)

def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


class AnimeDataset(paddle.io.Dataset):
    """
    """
    def __init__(self):
        super(AnimeDataset,self).__init__()
        self.anime_image_dirs =data_maker("data/d/data/train")
        self.size = len(self.anime_image_dirs)
    # cv2.imread reads BGR; convert the channels to RGB
    @staticmethod
    def loader(path):
        return cv2.cvtColor(cv2.imread(path, flags=cv2.IMREAD_COLOR),
                            cv2.COLOR_BGR2RGB)
    def __getitem__(self, index):
        img = AnimeDataset.loader(self.anime_image_dirs[index])
        img_a = img[:,:512,:]
        img_a =transform(img_a)
        img_b = img[:,512:,:]
        img_b = transform(img_b)
        
        return img_a,img_b

    def __len__(self):
        return self.size

In [14]

for a,b in AnimeDataset():
    print(a.shape,b.shape)
    break
(256, 256, 3) (256, 256, 3)

In [15]

batch_size = 16
datas = AnimeDataset()
data_loader =  paddle.io.DataLoader(datas,batch_size=batch_size,shuffle =True,drop_last=True)
for input_img,masks in data_loader:
    print(input_img.shape,masks.shape)
    break
[16, 256, 256, 3] [16, 256, 256, 3]

In [16]

generator = Generator()
discriminator = Discriminator(16,4)
tps_transformation = TPS_SpatialTransformerNetwork(F=10, I_size=(256, 256), 
                                                               I_r_size=(256, 256), 
                                                               I_channel_num=3)

In [17]


scheduler_G = paddle.optimizer.lr.StepDecay(learning_rate=1e-4, step_size=3, gamma=0.9, verbose=True)
scheduler_D = paddle.optimizer.lr.StepDecay(learning_rate=2e-4, step_size=3, gamma=0.9, verbose=True)

optimizer_G = paddle.optimizer.Adam(learning_rate=scheduler_G,parameters=generator.parameters(),beta1=0.5, beta2 =0.999)
optimizer_D = paddle.optimizer.Adam(learning_rate=scheduler_D,parameters=discriminator.parameters(),beta1=0.5, beta2 =0.999)
Epoch 0: StepDecay set learning rate to 0.0001.
Epoch 0: StepDecay set learning rate to 0.0002.

In [18]

# # Load the saved generator and discriminator parameter files (uncomment to resume)
# M_path ='model_params/Mmodel_state3.pdparams'
# layer_state_dictm = paddle.load(M_path)
# generator.set_state_dict(layer_state_dictm)


# D_path ='discriminator_params/Dmodel_state3.pdparams'
# layer_state_dictD = paddle.load(D_path)
# discriminator.set_state_dict(layer_state_dictD)

In [19]

EPOCHEES = 30
i = 0

In [20]

save_dir_model = "model_params"
save_dir_Discriminator = "discriminator_params"

In [21]

def gram(x):
    b, c, h, w = x.shape
    x_tmp = x.reshape((b, c, (h * w)))
    gram = paddle.matmul(x_tmp, x_tmp, transpose_y=True)
    return gram / (c * h * w)

def style_loss(fake, style):

    gram_loss = nn.L1Loss()(gram(fake), gram(style))
    return gram_loss

In [22]

def scaled_dot_product(query, key, mask=None, dropout=None):
    "Compute 'Scaled Dot Product Attention'"
    d_k = query.shape[-1]
    scores = paddle.matmul(query, key.transpose([0,2, 1])) \
                / math.sqrt(d_k)
    return scores

triplet_margin = 12
def similarity_based_triple_loss(anchor, positive, negative):
    distance = scaled_dot_product(anchor, positive) - scaled_dot_product(anchor, negative) + triplet_margin
    loss = paddle.mean( paddle.maximum(distance, paddle.zeros_like(distance)))
    return loss

In [23]

from tqdm import tqdm

In [24]

# # Training code; uncomment this cell if you want to train
# adversarial_loss = paddle.nn.MSELoss()
# l1_loss = nn.L1Loss()
# step =0
# for epoch in range(EPOCHEES):
#     # if(step >1000):
#         # break
#     for appearance_img, sketch_img in tqdm(data_loader):
#         # try:
#             # if(step >1000):
#             #     break
#             # print(input_img.shape,mask.shape)
#             appearance_img =paddle.transpose(x=appearance_img.astype("float32")/127.5-1,perm=[0,3,1,2])
#             # color_noise = paddle.tanh(paddle.randn(shape = appearance_img.shape))
#             # appearance_img += color_noise
#             # appearance_img = paddle.tanh(appearance_img)

#             sketch_img = paddle.max( paddle.transpose(x=sketch_img.astype("float32")/255,perm=[0,3,1,2]),axis=1,keepdim=True)

#             reference_img = tps_transformation(appearance_img).detach()

#             # ---------------------
#             #  Train Generator
#             # ---------------------
#             fake_I_gt, quary, key, value = generator(sketch_img,reference_img)
#             fake_output = discriminator(paddle.concat((fake_I_gt,sketch_img), axis=1))
#             g_adversarial_loss = adversarial_loss(fake_output,paddle.ones_like(fake_output))
#             g_l1_loss = l1_loss(fake_I_gt, appearance_img)*20
#             g_triplet_loss = similarity_based_triple_loss(quary, key, value)
#             g_vggloss = paddle.to_tensor(0.)
#             g_style_loss= paddle.to_tensor(0.)
#             rates = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
#             # _, fake_features = VGG( paddle.multiply (img_fake,loss_mask))
#             # _, real_features = VGG(paddle.multiply (input_img,loss_mask))

#             fake_features = VGG(fake_I_gt)
#             real_features = VGG(appearance_img)

#             for i in range(len(fake_features)):
#                 a,b = fake_features[i], real_features[i]
#                 # if i ==len(fake_features)-1:
#                 #     a = paddle.multiply( a,F.interpolate(loss_mask,a.shape[-2:]))
#                 #     b = paddle.multiply( b,F.interpolate(loss_mask,b.shape[-2:]))
#                 g_vggloss += rates[i] * l1_loss(a,b)
#                 g_style_loss += rates[i] *  style_loss(a,b)
                

#             g_vggloss /=30

#             g_style_loss/=10
#             # print(step,"g_adversarial_loss",g_adversarial_loss.numpy()[0],"g_triplet_loss",g_triplet_loss.numpy()[0],"g_vggloss",g_vggloss.numpy()[0],"g_styleloss", \
#                 # g_style_loss.numpy()[0],"g_l1_loss",g_l1_loss.numpy()[0],"g_loss",g_loss.numpy()[0])
#             g_loss = g_l1_loss + g_triplet_loss + g_adversarial_loss + g_style_loss + g_vggloss
            
#             g_loss.backward()
#             optimizer_G.step()
#             optimizer_G.clear_grad()
            
#             # ---------------------
#             #  Train Discriminator
#             # ---------------------
#             fake_output = discriminator(paddle.concat((fake_I_gt.detach(),sketch_img), axis=1))
#             real_output = discriminator(paddle.concat((appearance_img,sketch_img), axis=1))
#             d_real_loss = adversarial_loss(real_output, paddle.ones_like(real_output))
#             d_fake_loss = adversarial_loss(fake_output, paddle.zeros_like(fake_output))
#             d_loss = d_real_loss+d_fake_loss
            
#             d_loss.backward()
#             optimizer_D.step()
#             optimizer_D.clear_grad()


#             if step%2==0:
#                 log_writer.add_scalar(tag='train/d_real_loss', step=step, value=d_real_loss.numpy()[0])
#                 log_writer.add_scalar(tag='train/d_fake_loss', step=step, value=d_fake_loss.numpy()[0])
                
#                 log_writer.add_scalar(tag='train/d_all_loss', step=step, value=d_loss.numpy()[0])

                
#                 # log_writer.add_scalar(tag='train/col_loss', step=step, value=col_loss.numpy()[0])

#                 log_writer.add_scalar(tag='train/g_adversarial_loss', step=step, value=g_adversarial_loss.numpy()[0])
#                 log_writer.add_scalar(tag='train/g_triplet_loss', step=step, value=g_triplet_loss.numpy()[0])
#                 log_writer.add_scalar(tag='train/g_vggloss', step=step, value=g_vggloss.numpy()[0])
#                 log_writer.add_scalar(tag='train/g_style_loss', step=step, value=g_style_loss.numpy()[0])
#                 log_writer.add_scalar(tag='train/g_l1_loss', step=step, value=g_l1_loss.numpy()[0])
#                 log_writer.add_scalar(tag='train/g_loss', step=step, value=g_loss.numpy()[0])


#             step+=1
#             # print(i)
#             if step%100 == 3:
#                 print(step,"g_adversarial_loss",g_adversarial_loss.numpy()[0],"g_triplet_loss",g_triplet_loss.numpy()[0],"g_vggloss",g_vggloss.numpy()[0],"g_styleloss", \
#                 g_style_loss.numpy()[0],"g_l1_loss",g_l1_loss.numpy()[0],"g_loss",g_loss.numpy()[0])
#                 print(step,"dreal_loss",d_real_loss.numpy()[0],"dfake_loss",d_fake_loss.numpy()[0],"d_all_loss",d_loss.numpy()[0])

#                 # img_fake = paddle.multiply (img_fake,loss_mask)
#                 appearance_img = (appearance_img+1)*127.5
#                 reference_img = (reference_img+1)*127.5
#                 fake_I_gt = (fake_I_gt+1)*127.5


#                 g_output = paddle.concat([appearance_img,reference_img,fake_I_gt],axis = 3).detach().numpy()                      # tensor -> numpy
#                 g_output = g_output.transpose(0, 2, 3, 1)[0]             # NCHW -> NHWC
#                 g_output = g_output.astype(np.uint8)
#                 cv2.imwrite(os.path.join("./result", 'epoch'+str(step).zfill(3)+'.png'),cv2.cvtColor(g_output,cv2.COLOR_RGB2BGR))
#                 # generator.train()
            
#             if step%100 == 3:
#                 # save_param_path_g = os.path.join(save_dir_generator, 'Gmodel_state'+str(step)+'.pdparams')
#                 # paddle.save(model.generator.state_dict(), save_param_path_g)
#                 save_param_path_d = os.path.join(save_dir_Discriminator, 'Dmodel_state'+str(3)+'.pdparams')
#                 paddle.save(discriminator.state_dict(), save_param_path_d)
#                 # save_param_path_e = os.path.join(save_dir_encoder, 'Emodel_state'+str(1)+'.pdparams')
#                 # paddle.save(model.encoder.state_dict(), save_param_path_e)
#                 save_param_path_m = os.path.join(save_dir_model, 'Mmodel_state'+str(3)+'.pdparams')
#                 paddle.save(generator.state_dict(), save_param_path_m)
#             # break
#         # except:
#         #     pass
#         # break
#     scheduler_G.step()
#     scheduler_D.step()

4.4 Test code

In [35]


model = Generator()
M_path ='Mmodel_state3.pdparams'
layer_state_dictm = paddle.load(M_path)
model.set_state_dict(layer_state_dictm)


path1 ="data/d/data/train/10007.png"
# path1 = "before3.jpg"
img = cv2.cvtColor(cv2.imread(path1, flags=cv2.IMREAD_COLOR),cv2.COLOR_BGR2RGB)

# cv2.imwrite(os.path.join("./test", "2955026.png"), cv2.cvtColor(img,cv2.COLOR_RGB2BGR))
from paddle.vision.transforms import CenterCrop,Resize
transform = Resize((256,256))
img_a = img[:,:512,:]
img_a =transform(img_a)


reference_img =paddle.transpose(x=paddle.to_tensor(img_a).unsqueeze(0).astype("float32")/127.5-1,perm=[0,3,1,2]) #style

path2 ="data/d/data/train/10001.png"
img = cv2.cvtColor(cv2.imread(path2, flags=cv2.IMREAD_COLOR),cv2.COLOR_BGR2RGB)
cv2.imwrite(os.path.join("./test", "2970114.png"), cv2.cvtColor(img,cv2.COLOR_RGB2BGR))
# from paddle.vision.transforms import CenterCrop,Resize
# transform = Resize((256,256))
img_b = img[:,512:,:]
img_b = transform(img_b)
sketch_img0 =paddle.transpose(x=paddle.to_tensor(img_b).unsqueeze(0).astype("float32"),perm=[0,3,1,2])#content
sketch_img = paddle.max( sketch_img0/255,axis=1,keepdim=True)
img_fake,_,_,_= model(sketch_img,reference_img)
print('img_fake',img_fake.shape)
# print(img_fake.shape)
# g_output = paddle.concat([img_fake,g_input1,g_input2],axis = 3).detach()                      # tensor -> numpy
img_fake = img_fake.transpose([0, 2, 3, 1])[0].numpy()           # NCHW -> NHWC
print(img_fake.shape)
img_fake = (img_fake+1) *127.5
reference_img = (reference_img+1)*127.5

sketch_img0 = sketch_img0.transpose([0, 2, 3, 1])[0].numpy()
reference_img = reference_img.transpose([0, 2, 3, 1])[0].numpy()

g_output = np.concatenate((sketch_img0,reference_img,img_fake),axis =1)
g_output = g_output.astype(np.uint8)
cv2.imwrite(os.path.join("./test", " 01.png"), cv2.cvtColor(g_output,cv2.COLOR_RGB2BGR))
img_fake [1, 3, 256, 256]
(256, 256, 3)
True

5. Loss visualization

After around step 3000 the results start to improve and the outputs become noticeably more colorful.

6. Outlook

艾佬's suggestion was to reduce the warping a little so the model can learn which regions are skin and which are clothing. Beyond that, if you want even better results, the key is to extract color cues from the reference image that correspond to each position of the sketch structure that needs coloring.

You've read this far and still haven't bookmarked it? I'm genuinely puzzled??
