Spring Festival Special: Arbitrary Style Transfer

Original paper: Parameter-Free Style Projection for Arbitrary Style Transfer


Note: this article holds nothing back. It reproduces the original paper with some changes (let's generously call them improvements). I originally planned to write a paper of my own on top of this, but decided my foundations were still too thin and dropped the idea; even so, the whole process is something I'm a little proud of.

Rough outline of this project (treat it as a table of contents):

  1. Part 1: results showcase
  2. Part 2: everything you need to play with the model right away (plus a script that turns the output images into a video)
  3. Part 3: a walkthrough of the paper, based on my understanding of its core (if I get something wrong, go easy on me, I'm only a sophomore)
  4. Part 4: my improvements (if they seem underwhelming, go easy on me, I'm a beginner)
  5. All the code is at the end; take it, no thanks needed

Part 1: Results

Results first, so you get a feel for it (and I provide the trained model so you can try it yourself):

  1. Original image (content image):
  2. Style image:
  3. My stylized output:
  4. Output of the same model from the original paper's PaddleGAN project:
  5. The 纹理 (texture) folder ships with plenty of textures; test them however you like
# Imports
import cv2
from matplotlib import image
import numpy as np
import os
import paddle
import paddle.optimizer
import paddle.nn as nn
import math
from tqdm import tqdm
from paddle.io import Dataset
from paddle.io import DataLoader
import paddle.nn.functional as F
import paddle.tensor as tensor
from Generater import Generater, VGG19
from discriminator import AnimeDiscriminator
from paddle.vision.datasets import ImageFolder
from paddle.vision.transforms import Compose, ColorJitter, Resize
import random
from visualdl import LogWriter
from GANloss import GANLoss

# blur kernel: every entry 1/9, shape [3, 3, 3, 3] (out_ch, in_ch, kH, kW);
# used to blur the style images for the discriminator's "blurred is fake" branch
smooth_kernel = tensor.to_tensor([[[[1 / 9 for _ in range(3)] for _ in range(3)] for _ in range(3)] for _ in range(3)])
log_writer = LogWriter("./log/gnet")

Part 2: Try the Style Transfer Yourself

A quick note for anyone who just wants to play with the style-transfer results. There is exactly one knob to tune:

  1. In the code block below, the call generator(a, b, c) takes three arguments: the content image, the style image, and the stylization strength (normally between 0 and 1; you can push it higher, but heavily over-stylized results usually stop looking good).
  2. There is no 2.

The code block below is fully commented, so it should be easy to follow; the only thing worth noting is where the output images are saved.

generator = Generater()
# oslist = os.listdir("纹理")
# print(oslist)
G_path = 'Gmodel_state33003.pdparams'
layer_state_dictg = paddle.load(G_path)
generator.set_state_dict(layer_state_dictg)  # load the trained weights

generator.eval()
img_A = cv2.imread("venice-boat.jpg")  # content image
g_input = img_A.astype('float32') / 127.5 - 1             # normalize to [-1, 1]
g_input = g_input[np.newaxis, ...].transpose(0, 3, 1, 2)  # NHWC -> NCHW
g_input = paddle.to_tensor(g_input)                       # numpy -> tensor

img_B = cv2.imread("feathers.jpg")  # style image
g_input_s = img_B.astype('float32') / 127.5 - 1               # normalize to [-1, 1]
g_input_s = g_input_s[np.newaxis, ...].transpose(0, 3, 1, 2)  # NHWC -> NCHW
g_input_s = paddle.to_tensor(g_input_s)                       # numpy -> tensor

for i in range(0, 51, 10):
    alpha = paddle.to_tensor([i / 50])  # stylization strength, swept from 0 to 1
    g_output = generator(g_input, g_input_s, alpha)
    g_output = g_output.detach().numpy()          # tensor -> numpy
    g_output = g_output.transpose(0, 2, 3, 1)[0]  # NCHW -> NHWC
    g_output = (g_output + 1) * 127.5             # denormalize
    g_output = g_output.astype(np.uint8)
    cv2.imwrite(os.path.join("./test", str(alpha.numpy()[0]) + 'qt.png'), g_output)  # save to disk

To make things even easier, here is a bonus script that turns the output images into a video; follow my example and it just works.

import os
import cv2
import numpy as np

fps = 1  # 1 frame per second
size = (512, 512)  # Note 1: the frame size; make sure your images match it (cv2.resize() can fix them)
path = './test'
# filelist = os.listdir(path)
filelist = ["test/" + str(i / 50) + 'qt.png' for i in range(0, 51, 10)]  # Note 2: prefer ASCII-only paths
# saved in the current directory with the motion-JPEG codec, which distorts colors relatively little
video = cv2.VideoWriter("Video.avi", cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), fps, size)

for item in filelist:
    print(item)
    if item.endswith('.png'):
        img = cv2.imread(item)
        img = cv2.resize(img, size)  # VideoWriter silently drops frames whose size does not match
        video.write(img)

video.release()
cv2.destroyAllWindows()
print('Video has been made.')


Next up is the core of the paper; on to Part 3.

Part 3: Main Network Architecture

  1. The backbone is a standard Encoder-Decoder. The pretrained encoder is VGG; I use VGG19 with the fully connected layers and the last few conv layers removed. Build the decoder however you like; the original paper keeps it very simple, and I borrowed some generator ideas from AnimeGAN hoping to cut the parameter count and improve quality. Not the focus here.
  2. Feeding the two images through the encoder gives two feature maps, c_feature and s_feature. First reproduction detail: the content and style images come in different sizes, so c_feature and s_feature have different shapes; I simply resize s_feature to c_feature's shape.
  3. With the two feature maps in hand we reach style projection, the paper's core contribution.
    First, this step is parameter-free: it needs no learnable parameters.
    With the style feature and the content feature now the same shape, the values of the style feature are re-sorted to follow the ordering of the content feature: every value in the output comes from the style feature, but their relative order is that of the content feature. Work through the example given in the paper and it clicks; the idea is not hard. The implementation is the tricky part: it needs NumPy's fancy (integer-array) indexing, which Paddle tensors did not support at the time, so NumPy it is. (A sketch follows this list.)
  4. The output of style projection is then upsampled by the decoder. That is the whole basic pipeline.
  5. About the content skip connections in the architecture diagram: my modified reimplementation does not use them. They are purely auxiliary, and connecting encoder and decoder features is a very common idea; look it up or ask a colleague if it is new to you. Not the focus.
  6. Finally, the loss design: three losses, split into content losses and a style loss.
    1. Style loss: the stylized output and the input style image are passed through VGG and compared layer by layer as a perceptual loss; the original takes the 2-norms of the differences in means and in standard deviations and sums them. Nothing novel. I replaced it with a Gram-matrix loss, also borrowed from AnimeGAN.
    2. Perceptual loss, the content loss: the stylized output and the input content image are passed through VGG and compared.
    3. KL loss, also a content loss: the original compares the VGG features of the stylized output and the content image as two distributions via KL divergence. My reimplementation drops it: the first time I tried KL, the two distributions were essentially unrelated, so the loss provided no gradient and training quality collapsed; a real poison pill. I replaced it with the mean of the squared difference of the two tensors.
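To make the reordering concrete, here is a minimal NumPy sketch of how style projection can be implemented, treating each channel independently (the function name style_projection and the per-channel handling are my choices; the paper only fixes the reordering idea itself):

import numpy as np

def style_projection(c_feat, s_feat):
    """Reorder the style values to follow the content ranking, per channel.

    c_feat, s_feat: float arrays of identical shape [C, H, W]
    (resize s_feat to c_feat's size beforehand, as described above).
    Every output value comes from s_feat; the ordering comes from c_feat.
    """
    out = np.empty_like(s_feat)
    for ch in range(c_feat.shape[0]):
        c = c_feat[ch].ravel()
        s = s_feat[ch].ravel()
        order = np.argsort(c)  # positions of c, smallest value first
        o = np.empty_like(s)
        o[order] = np.sort(s)  # the k-th smallest style value lands where
                               # the k-th smallest content value sits
        out[ch] = o.reshape(c_feat[ch].shape)
    return out

# tiny check: the output is a permutation of the style values
# whose ordering matches the content values
c = np.random.rand(4, 8, 8).astype(np.float32)
s = np.random.rand(4, 8, 8).astype(np.float32)
p = style_projection(c, s)
assert np.allclose(np.sort(p[0].ravel()), np.sort(s[0].ravel()))
assert (np.argsort(p[0].ravel()) == np.argsort(c[0].ravel())).all()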

If you have read this far, how about a like? Happy New Year, everyone; may things keep getting better for you.

If you are still following along, let's go deeper. I promised an improved version, not just a reproduction, so next I will walk through my changes step by step. Reproducing is not enough; improving and innovating are what is worth chasing.

Part 4: Pursuing Excellence: Innovations

First point.

The first idea is to bring in GANs: add a discriminator so that a learned, more expressive loss pushes the generator to pick up the style image's look. One design choice deserves a highlight. To make the generated images more vivid and colorful, the discriminator is trained to classify the grayscale version of a real style image as false (so even a black-and-white style image yields a colorful output; presumably you want the texture of such a style image, not its lack of color) and to classify a blurred style image as false (so the generated images stay sharp). This part, too, is borrowed from AnimeGAN, I admit.
# excerpt from the training code further down, for readability
def backward_D(anime, anime_gray, fake_1, smooth_gray):
    real_logit = Discriminator(anime)            # real style image
    gray_logit = Discriminator(anime_gray)       # grayscale version of a style image
    fake_logit = Discriminator(fake_1.detach())  # generated (fake) image
    smooth_logit = Discriminator(smooth_gray)    # real style image slightly blurred by a conv
    d_real_loss = (300 * 1.2 * GANLoss()(real_logit, True)) / 4
    d_gray_loss = (300 * 1.2 * GANLoss()(gray_logit, False)) / 4
    d_fake_loss = (300 * 1.2 * GANLoss()(fake_logit, False)) / 4
    d_blur_loss = (300 * 0.8 * GANLoss()(smooth_logit, False)) / 4

But I did innovate on the GAN loss itself. I started from WGAN, found it did not work well for me, and redesigned it. This is my genuine contribution:

  1. Core idea
    1. Train the discriminator on what a real style image looks like (color plus texture).
    2. The generator's adversarial loss feeds both the generated image and the style image through the discriminator and pulls the two responses together: the squared differences are summed and normalized by the size of the logit map (a sketch follows below).
    3. The point of these two steps: it no longer matters if the discriminator races ahead and becomes a very strong teacher, because the generator simply keeps learning from it. As long as g_loss and d_loss both keep falling (which I confirmed with the logged curves), this neatly sidesteps the training instability of the original GAN formulation.
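Here is a minimal sketch of that "student follows teacher" generator loss, distilled from my backward_G_kl below (the hand-tuned scaling constants like 6/8/16 are dropped, and student_teacher_g_loss is a name I am coining for this write-up):

import math
import paddle

def student_teacher_g_loss(discriminator, fake_img, style_img):
    """Pull the discriminator's response to the fake image toward its
    response to a real style image, instead of toward a fixed 'real'
    label as in a vanilla GAN. style_img must first be resized to
    fake_img's spatial size, as backward_G_kl does below."""
    fake_logit = discriminator(fake_img)
    style_logit = discriminator(style_img)  # the teacher signal
    _, c, h, w = fake_logit.shape
    return paddle.sum(paddle.square(style_logit - fake_logit)) / math.sqrt(c * h * w)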


[Figure: original paper]
[Figure: mine]

Second point (let me make one more).

Now consider: style projection can be generalized to a*style_feature + (1-a)*content_feature, a blend of the projected style feature and the untouched content feature. Then a = 0 is exactly the setting we train the content loss on, and a = 1 the setting we train the style loss on, which is more principled. A tiny refinement, hardly an innovation; a sketch follows.
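As a sketch, reusing the hypothetical style_projection helper from the Part 3 sketch (both names are mine, not the paper's):

def blended_projection(c_feat, s_feat, a):
    """a = 0 returns the content feature unchanged (used when training the
    content losses); a = 1 returns the pure style projection (used when
    training the style losses); values in between set the stylization
    strength, which is exactly the third argument of generator(a, b, c)
    in Part 2."""
    return a * style_projection(c_feat, s_feat) + (1.0 - a) * c_feat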

Summary of the changes

  1. A GAN is introduced
  2. A GAN loss of my own design. I have not yet surveyed whether this exact loss already exists; if it does not, I will call it STGAN (Student-Teacher GAN): the student learns from the teacher, and the better the teacher, the better the student, which greatly relieves the instability of GAN training.
  3. The a*style_feature+(1-a)*content_feature blend

Overall Summary and Some Thoughts

PS: during training you can pretrain the generator and the discriminator separately before the joint run; the code for all of it is already written below, and a wiring sketch follows.
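One simple way to wire the warm-up inside the training loop below, assuming the functions defined there (PRETRAIN_STEPS is a name I made up; pick your own budget):

PRETRAIN_STEPS = 10000  # I pretrained for roughly 10k batches
if i < PRETRAIN_STEPS:
    backward_G_predictor(real_data, fake_data_0, fake_data_1)  # content-only warm-up, no GAN
else:
    backward_G_kl(real_data, anime_data, anime_gray_data, fake_data_0, fake_data_1)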

1. To summarize my view of images here: image information splits into two parts, content information and style information. From the content image we keep the content; from the style image we keep the style. The highly representative AdaIN family treats style as the mean and variance of a distribution: normalize the content feature map, then re-inject the mean and variance of the style feature map. That is the other mainstream of style transfer, and at its core that kind of model is parametric; the parameter-free style projection this paper proposes really is striking by contrast. (A sketch of the AdaIN idea follows, for comparison.)
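For contrast, a minimal NumPy sketch of the AdaIN idea (the textbook per-channel formulation, not anything from this paper):

import numpy as np

def adain(c_feat, s_feat, eps=1e-5):
    """Normalize the content feature per channel, then re-inject the style
    feature's per-channel mean and std. c_feat, s_feat: [C, H, W]."""
    c_mean = c_feat.mean(axis=(1, 2), keepdims=True)
    c_std = c_feat.std(axis=(1, 2), keepdims=True)
    s_mean = s_feat.mean(axis=(1, 2), keepdims=True)
    s_std = s_feat.std(axis=(1, 2), keepdims=True)
    return (c_feat - c_mean) / (c_std + eps) * s_std + s_mean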

2. If you want to beat my results: for one, just train longer than I did. I pretrained for 10k batches and the main training ran for under 30k, so there is plenty of headroom. For another, as the code below makes obvious, I used a public Paddle dataset whose style images are all van Gogh paintings, so the style distribution during training was extremely narrow, and I did no data augmentation either. That the model still trains this well on such a simple dataset is a decent result, and it also suggests that reproducing this paper (with my modifications) is not a hard training problem. For comparison, I am currently training SPADE, which honestly is hard, although I am doing clothing generation, a different task from its paper, and that is a confound of its own. Finally, a multi-scale discriminator could easily replace the extremely simple one I copied over; a sketch of that idea follows.
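If you want to try the multi-scale idea, a minimal wrapper might look like this (AnimeDiscriminator is the class used in my code below; the wrapper itself is my sketch, not something I trained):

class MultiScaleD(nn.Layer):
    """Run the same discriminator architecture at several resolutions
    and return one logit map per scale."""
    def __init__(self, num_scales=2):
        super().__init__()
        self.ds = nn.LayerList([AnimeDiscriminator() for _ in range(num_scales)])

    def forward(self, x):
        logits = []
        for d in self.ds:
            logits.append(d(x))
            x = F.avg_pool2d(x, kernel_size=3, stride=2, padding=1)  # halve the resolution
        return logits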

3. What I did not get to, but is well worth doing: blending multiple styles into a single content image, with the styles adapting to the image so that each style is applied where it fits best. Worth a try.

4. A style image's style information still carries some of its content information; we cannot yet truly disentangle the two. So, to show the method at its best: first, prefer pure texture images as style inputs, which is a solid choice; second, pick style and content images whose scenes are roughly similar.

5. Here is the style-transfer view of SPADE, a very meaningful paper: the semantic segmentation map partitions the image into regions, and the generator is pushed to pick the most suitable style to fill each region, so semantic maps control style transfer down at the pixel level. That is my take; discussion welcome.

Like and fork, thank you all!
The training code follows; feel free to use it as a reference.

Thanks for reading. Stay hungry, stay foolish.

# quick smoke test of the truncated VGG19 encoder
m = np.random.random([1, 3, 256, 256])
real_image = paddle.to_tensor(m, dtype="float32")
VGG19()(real_image)






Tensor(shape=[1, 512, 32, 32], dtype=float32, place=CUDAPlace(0), stop_gradient=False,
       [[[[-504.92623901, -681.18713379, -491.53030396, ...,  78.78646088 ,  230.80319214,  56.87818146 ],
          ...,
          [ 593.02130127,  344.37402344, -12.78419399 , ...,  338.39352417,  547.33068848,  651.10540771]]]])
# !unzip data/data105513/photo2vangogh.zip -d data  # unpack the dataset
# Dataset
class AnimeGANV2Dataset(paddle.io.Dataset):
    """Pairs a real-photo folder with a style-image folder, padding the
    shorter one and reshuffling the index pairs every epoch."""
    def __init__(self):
        """Initialize this dataset class."""
        self.real_image_folder = ImageFolder("data/trainA", transform=Compose([Resize(size=(248, 248))]), loader=cv2.imread)
        self.anime_image_folder = ImageFolder("data/trainB", loader=cv2.imread)
        # self.smooth_image_folder = ImageFolder("data/Hayao/smooth", loader=cv2.imread)
        self.sizes = [
            len(fold) for fold in [self.real_image_folder, self.anime_image_folder]
        ]
        self.size = max(self.sizes)
        self.reshuffle()

    @staticmethod
    def loader(path):
        return cv2.cvtColor(cv2.imread(path, flags=cv2.IMREAD_COLOR),
                            cv2.COLOR_BGR2RGB)

    def reshuffle(self):
        # shuffle each folder's indices; pad the smaller folder by sampling
        # with replacement so both index lists reach the same length
        indexs = []
        for cur_size in self.sizes:
            x = np.arange(0, cur_size)
            np.random.shuffle(x)
            if cur_size != self.size:
                pad_num = self.size - cur_size
                pad = np.random.choice(cur_size, pad_num, replace=True)
                x = np.concatenate((x, pad))
                np.random.shuffle(x)
            indexs.append(x.tolist())
        self.indexs = list(zip(*indexs))

    def __getitem__(self, index):
        try:
            index = self.indexs.pop()
        except IndexError:
            self.reshuffle()
            index = self.indexs.pop()

        real_idx, anime_idx = index
        real_image = self.real_image_folder[real_idx]
        anime_image = self.anime_image_folder[anime_idx]
        return (real_image, anime_image)

    def __len__(self):
        return self.size
BATCH_SIZE = 4
dataset = AnimeGANV2Dataset()
data_loader = paddle.io.DataLoader(dataset, batch_size=BATCH_SIZE)

# sanity check: one batch, normalized to [-1, 1] and transposed to NCHW
for data in data_loader:
    real_data, anime_data = [d[0] / 127.5 - 1 for d in data]
    real_data = paddle.transpose(x=real_data, perm=[0, 3, 1, 2])
    anime_data = paddle.transpose(x=anime_data, perm=[0, 3, 1, 2])
    print(real_data.shape)
    print(anime_data.shape)
    break
W0121 14:40:14.417394   335 device_context.cc:404] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.0, Runtime API Version: 10.1
W0121 14:40:14.421514   335 device_context.cc:422] device: 0, cuDNN Version: 7.6.


[4, 3, 248, 248]
[4, 3, 256, 256]

A quick refresher on NumPy fancy indexing, the feature the style projection implementation relies on:

arr7 = np.arange(35).reshape(5, 7)  # a 5x7 array

arr7
# array([[ 0,  1,  2,  3,  4,  5,  6],
#        [ 7,  8,  9, 10, 11, 12, 13],
#        [14, 15, 16, 17, 18, 19, 20],
#        [21, 22, 23, 24, 25, 26, 27],
#        [28, 29, 30, 31, 32, 33, 34]])

arr7[[1, 3, 2, 4], [2, 0, 6, 5]]  # picks the elements at (1,2), (3,0), (2,6), (4,5)
# array([ 9, 21, 20, 33])
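This is exactly the mechanism the style projection sketch in Part 3 builds on; for one flattened content channel c and style channel s, the whole trick fits in three lines (variable names mine):

order = np.argsort(c)    # index array: positions of c, smallest value first
out = np.empty_like(s)
out[order] = np.sort(s)  # fancy indexing scatters the sorted style values into the content ordering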

generator = Generater()
Discriminator = AnimeDiscriminator()

# smoke test: content and style inputs of different sizes
m = np.random.random([1, 3, 520, 410])
real_image = paddle.to_tensor(m, dtype="float32")
n = np.random.random([1, 3, 310, 320])
style_image = paddle.to_tensor(n, dtype="float32")

gen = generator(real_image, style_image, paddle.to_tensor([0.7]))
print(gen.shape)

G_path = 'g_model_kl_predict/Gmodel_state2803.pdparams'
# G_path = "g_model_kl/Gmodel_state9503.pdparams"
layer_state_dictg = paddle.load(G_path)
generator.set_state_dict(layer_state_dictg)
D_path = 'd_model_kl/Dmodel_state1003.pdparams'
layer_state_dictd = paddle.load(D_path)
Discriminator.set_state_dict(layer_state_dictd)
scheduler_G = paddle.optimizer.lr.StepDecay(learning_rate=0.0002, step_size=3, gamma=0.8, verbose=True)
scheduler_D = paddle.optimizer.lr.StepDecay(learning_rate=0.0001, step_size=3, gamma=0.8, verbose=True)
# optimizer_G = paddle.optimizer.SGD(learning_rate=0.001, parameters=generator.parameters())
# optimizer_D = paddle.optimizer.SGD(learning_rate=0.0004, parameters=Discriminator.parameters())
optimizer_G = paddle.optimizer.Adam(learning_rate=scheduler_G, parameters=generator.parameters(), beta1=0.5)
optimizer_D = paddle.optimizer.Adam(learning_rate=scheduler_D, parameters=Discriminator.parameters(), beta1=0.5)
VGG = VGG19()
[1, 3, 520, 408]
Epoch 0: StepDecay set learning rate to 0.0002.
Epoch 0: StepDecay set learning rate to 0.0001.
def variation_loss(image, ksize=1):
    # total-variation loss: penalizes differences between neighboring pixels
    dh = image[:, :, :-ksize, :] - image[:, :, ksize:, :]
    dw = image[:, :, :, :-ksize] - image[:, :, :, ksize:]
    return (paddle.mean(paddle.abs(dh)) + paddle.mean(paddle.abs(dw)))

def gram(x):
    b, c, h, w = x.shape
    x_tmp = x.reshape((b, c, (h * w)))
    gram = paddle.matmul(x_tmp, x_tmp, transpose_y=True)
    return gram / (c * h * w)

def style_loss(style, fake):
    # the mean/std losses follow the original paper; they are logged for
    # comparison, but only the Gram loss (borrowed from AnimeGAN) is returned
    mean_loss = paddle.sqrt(paddle.abs(paddle.square(paddle.mean(style)) - paddle.square(paddle.mean(fake)))) * 0.5
    std_loss = paddle.sqrt(paddle.abs(paddle.square(paddle.std(style)) - paddle.square(paddle.std(fake)))) * 0.5
    gram_loss = nn.L1Loss()(gram(style), gram(fake)) * 3
    if i % 2 == 0:
        log_writer.add_scalar(tag='train/s_mean_loss', step=i, value=mean_loss.numpy()[0])
        log_writer.add_scalar(tag='train/s_std_loss', step=i, value=std_loss.numpy()[0])
        log_writer.add_scalar(tag='train/s_gram_loss', step=i, value=gram_loss.numpy()[0])
    return gram_loss

def klloss(real, fake):
    # originally a KL divergence between the two VGG feature distributions;
    # early in training the distributions barely overlap and the gradient
    # vanishes, so it is replaced by a plain MSE between the feature maps
    real0 = real.reshape([BATCH_SIZE, -1])
    fake0 = fake.reshape([BATCH_SIZE, -1])
    loss = paddle.mean(paddle.square(real0 - fake0))
    return loss

def con_sty_kl_loss(real, anime, fake_0, fake_1):
    real_feature_map = VGG(real)
    fake_feature_map_0 = VGG(fake_0)  # generated with alpha = 0 (content)
    fake_feature_map_1 = VGG(fake_1)  # generated with alpha = 1 (style)
    anime_feature_map = VGG(anime)
    c_loss = nn.L1Loss()(real_feature_map, fake_feature_map_0)
    s_loss = style_loss(anime_feature_map, fake_feature_map_1)
    kl_loss = klloss(real_feature_map, fake_feature_map_0)  # MSE stand-in for KL
    return c_loss, s_loss, kl_loss
def rgb2yuv(rgb):
    kernel = paddle.to_tensor([[0.299, -0.14714119, 0.61497538],
                               [0.587, -0.28886916, -0.51496512],
                               [0.114, 0.43601035, -0.10001026]],
                              dtype='float32')
    rgb = paddle.transpose(rgb, (0, 2, 3, 1))
    yuv = paddle.matmul(rgb, kernel)
    return yuv


def denormalize(image):
    return image * 0.5 + 0.5

def color_loss(con, fake):
    # compare luminance with L1 and chrominance with SmoothL1 in YUV space
    con = F.interpolate(x=con, size=fake.shape[-2:], mode="BILINEAR")
    con = rgb2yuv(denormalize(con))
    fake = rgb2yuv(denormalize(fake))
    return (nn.L1Loss()(con[:, :, :, 0], fake[:, :, :, 0]) +
            nn.SmoothL1Loss()(con[:, :, :, 1], fake[:, :, :, 1]) +
            nn.SmoothL1Loss()(con[:, :, :, 2], fake[:, :, :, 2]))
def backward_G_kl(real, anime, anime_gray, fake_0, fake_1):
    # generator step: content/KL losses on fake_0 (alpha = 0), style/color
    # losses on fake_1 (alpha = 1), plus the student-teacher adversarial loss
    Discriminator.eval()
    fake_logit_1 = Discriminator(fake_1)
    # the style/Gram loss uses the grayscale style image; color is handled by color_loss below
    c_loss, s_loss, kl_loss = con_sty_kl_loss(real, anime_gray, fake_0, fake_1)
    # hand-tuned loss weights
    c_loss = 2.5 * c_loss / 8 / 8
    s_loss = 1.5 * s_loss / 8 * 2 / 5
    kl_loss = 2.5 * kl_loss / 8 / 8 / 200
    tv_loss = 1 * variation_loss(fake_0) + variation_loss(fake_1) / 8 * 100
    col_loss0 = 200 * color_loss(real, fake_0) / 8 * 1.25
    col_loss1 = 600 * color_loss(anime, fake_1) / 8 * 1.3
    b, c, h, w = fake_logit_1.shape
    anime = F.interpolate(x=anime, size=fake_1.shape[-2:], mode="BILINEAR")
    anime_logit = Discriminator(anime)
    col_loss = col_loss0 + col_loss1
    # student-teacher adversarial loss: pull the discriminator's response to
    # the fake toward its response to the real style image
    g_loss = paddle.sum(paddle.square(anime_logit - fake_logit_1)) / math.sqrt(c * h * w) * 6 / 8 / 16
    loss_G = col_loss + c_loss + s_loss + kl_loss + tv_loss + g_loss
    loss_dict["c_loss"] = c_loss.numpy()[0]
    loss_dict["G"] = loss_G.numpy()[0]
    loss_dict["s_loss"] = s_loss.numpy()[0]
    loss_dict["kl_loss"] = kl_loss.numpy()[0]
    loss_dict["col_loss"] = col_loss.numpy()[0]
    loss_dict["tv_loss"] = tv_loss.numpy()[0]
    loss_dict["g_loss"] = g_loss.numpy()[0]
    loss_dict["g_lr"] = optimizer_G.get_lr()
    if i % 2 == 0:
        log_writer.add_scalar(tag='train/c_loss', step=i, value=loss_dict["c_loss"])
        log_writer.add_scalar(tag='train/s_loss', step=i, value=loss_dict["s_loss"])
        log_writer.add_scalar(tag='train/kl_loss', step=i, value=loss_dict["kl_loss"])
        log_writer.add_scalar(tag='train/col_loss', step=i, value=loss_dict["col_loss"])
        log_writer.add_scalar(tag='train/G_loss', step=i, value=loss_dict["G"])
        log_writer.add_scalar(tag='train/g_loss', step=i, value=loss_dict["g_loss"])
    loss_G.backward()
    optimizer_G.step()
    Discriminator.train()
def backward_G_predictor(real, fake_0, fake_1):
    # generator pretraining step: content reconstruction only, no GAN;
    # fake_1 is unused here, kept for signature parity with backward_G_kl
    real_feature_map = VGG(real)
    fake_feature_map_0 = VGG(fake_0)
    kl_loss = klloss(real_feature_map, fake_feature_map_0) / 10000 * 30
    init_c_loss = nn.L1Loss()(real_feature_map, fake_feature_map_0)
    loss = init_c_loss + kl_loss
    loss_dict["G"] = loss.numpy()[0]
    if i % 2 == 0:
        log_writer.add_scalar(tag='train/G_loss', step=i, value=loss.numpy()[0])
        log_writer.add_scalar(tag='train/kl_loss', step=i, value=kl_loss.numpy()[0])
        log_writer.add_scalar(tag='train/init_c_loss', step=i, value=init_c_loss.numpy()[0])
    loss.backward()
    optimizer_G.step()

def backward_D(anime, anime_gray, fake_1, smooth_gray):
    real_logit = Discriminator(anime)            # real style image
    gray_logit = Discriminator(anime_gray)       # grayscale version of a style image
    fake_logit = Discriminator(fake_1.detach())  # generated (fake) image
    smooth_logit = Discriminator(smooth_gray)    # real style image slightly blurred by a conv
    d_real_loss = (300 * 1.2 * GANLoss()(real_logit, True)) / 4
    d_gray_loss = (300 * 1.2 * GANLoss()(gray_logit, False)) / 4
    d_fake_loss = (300 * 1.2 * GANLoss()(fake_logit, False)) / 4
    d_blur_loss = (300 * 0.8 * GANLoss()(smooth_logit, False)) / 4
    loss_D = d_real_loss + d_gray_loss + d_blur_loss + d_fake_loss
    loss_dict["D"] = loss_D.numpy()[0]
    loss_dict["d_real_loss"] = d_real_loss.numpy()[0]
    loss_dict["d_fake_loss"] = d_fake_loss.numpy()[0]
    loss_dict["d_lr"] = optimizer_D.get_lr()
    loss_D.backward()
    optimizer_D.step()
    if i % 2 == 0:
        log_writer.add_scalar(tag='train/d_real_loss', step=i, value=d_real_loss.numpy()[0])
        log_writer.add_scalar(tag='train/d_fake_loss', step=i, value=d_fake_loss.numpy()[0])
        log_writer.add_scalar(tag='train/D_loss', step=i, value=loss_D.numpy()[0])

epoches = 100
i = 0
save_dir_generator = "g_model_kl"
save_dir_Discriminator = "d_model_kl"
for epoch in range(epoches):
    print("epoch", epoch)
    for data in tqdm(data_loader):
        try:
            real_data, anime_data = [d[0] / 127.5 - 1 for d in data]  # normalize to [-1, 1]
            loss_dict = {}
            real_data = paddle.transpose(x=real_data, perm=[0, 3, 1, 2])
            anime_data = paddle.transpose(x=anime_data, perm=[0, 3, 1, 2])
            # blurred style images for the discriminator's blur branch
            smooth_data = F.conv2d(anime_data, weight=smooth_kernel, stride=1, padding=1)
            # grayscale style images for the discriminator's gray branch
            anime_gray_data = paddle.expand(paddle.mean(anime_data, keepdim=True, axis=1),
                                            [BATCH_SIZE, 3, anime_data.shape[-2], anime_data.shape[-1]])
            fake_data_0 = generator(real_data, anime_data, paddle.to_tensor([0.]))  # alpha = 0: content
            fake_data_1 = generator(real_data, anime_data, paddle.to_tensor([1.]))  # alpha = 1: style
            optimizer_D.clear_grad()
            backward_D(anime_data, anime_gray_data, fake_data_1, smooth_data)

            optimizer_G.clear_grad()
            # backward_G_predictor(real_data, fake_data_0, fake_data_1)  # use this line for pretraining
            backward_G_kl(real_data, anime_data, anime_gray_data, fake_data_0, fake_data_1)

            i += 1
            if i % 300 == 3:
                print(i, loss_dict)
                # periodically stylize a fixed test pair to eyeball progress
                generator.eval()
                img_A = cv2.imread("老虎.jpg")
                g_input = img_A.astype('float32') / 127.5 - 1             # normalize
                g_input = g_input[np.newaxis, ...].transpose(0, 3, 1, 2)  # NHWC -> NCHW
                g_input = paddle.to_tensor(g_input)                       # numpy -> tensor
                img_B = cv2.imread("style.jpg")
                g_input_s = img_B.astype('float32') / 127.5 - 1               # normalize
                g_input_s = g_input_s[np.newaxis, ...].transpose(0, 3, 1, 2)  # NHWC -> NCHW
                g_input_s = paddle.to_tensor(g_input_s)                       # numpy -> tensor

                g_output = generator(g_input, g_input_s, paddle.to_tensor([1.]))
                g_output = g_output.detach().numpy()          # tensor -> numpy
                g_output = g_output.transpose(0, 2, 3, 1)[0]  # NCHW -> NHWC
                g_output = (g_output + 1) * 127.5             # denormalize
                g_output = g_output.astype(np.uint8)
                cv2.imwrite(os.path.join("./kl_result", 'epoch' + str(i).zfill(3) + '.png'), g_output)
                generator.train()

            if i % 1000 == 3:
                save_param_path_g = os.path.join(save_dir_generator, 'Gmodel_state' + str(i) + '.pdparams')
                paddle.save(generator.state_dict(), save_param_path_g)
                save_param_path_d = os.path.join(save_dir_Discriminator, 'Dmodel_state' + str(i) + '.pdparams')
                paddle.save(Discriminator.state_dict(), save_param_path_d)
        except Exception:
            pass  # skip problem batches (e.g. the short last batch) and keep training
    scheduler_G.step()
    scheduler_D.step()
epoch 0

3 {'D': 23.263124, 'd_real_loss': 6.634072, 'd_fake_loss': 15.863004, 'd_lr': 0.0001, 'c_loss': 6.146783, 'G': 90.09673, 's_loss': 46.35534, 'kl_loss': 8.516734, 'col_loss': 21.139446, 'tv_loss': 0.85263306, 'g_loss': 7.0857983, 'g_lr': 0.0002}
# Batch inference over every texture in the 纹理 folder, kept for reference:
# generator = Generater()
# oslist = os.listdir("纹理")
# print(oslist)
# G_path = 'Gmodel_state33003.pdparams'
# layer_state_dictg = paddle.load(G_path)
# generator.set_state_dict(layer_state_dictg)
# generator.eval()
# img_A = cv2.imread("Cache_-71658739148bde80..jpg")
# g_input = img_A.astype('float32') / 127.5 - 1             # normalize
# g_input = g_input[np.newaxis, ...].transpose(0, 3, 1, 2)  # NHWC -> NCHW
# g_input = paddle.to_tensor(g_input)                       # numpy -> tensor
# for j in oslist:
#     img_B = cv2.imread(os.path.join("./纹理", j))
#     g_input_s = img_B.astype('float32') / 127.5 - 1               # normalize
#     g_input_s = g_input_s[np.newaxis, ...].transpose(0, 3, 1, 2)  # NHWC -> NCHW
#     g_input_s = paddle.to_tensor(g_input_s)                       # numpy -> tensor
#     for k in range(10, 21, 30):
#         alpha = paddle.to_tensor([k / 10])
#         g_output = generator(g_input, g_input_s, alpha)
#         g_output = g_output.detach().numpy()          # tensor -> numpy
#         g_output = g_output.transpose(0, 2, 3, 1)[0]  # NCHW -> NHWC
#         g_output = (g_output + 1) * 127.5             # denormalize
#         g_output = g_output.astype(np.uint8)
#         cv2.imwrite(os.path.join("./验证", str(alpha.numpy()[0]) + '纹理老虎' + j + 'tanh.png'), g_output)
