Title: AnimeGAN reproduction

Paper: https://github.com/TachibanaYoshino/AnimeGAN/blob/master/doc/Chen2020_Chapter_AnimeGAN.pdf

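Plenty of people have explained this paper. Most of them cover only the paper and never walk through the code, which can be hard going for beginners. Here I explain the paper alongside the code. I hope this also gives newcomers to deep learning a useful reference.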

Experienced readers who want a quick overview can read this explanation on Zhihu: https://zhuanlan.zhihu.com/p/162545685

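In short, the paper does one-directional (targeted) style transfer: it converts real photos into the style of the training dataset, here anime. Getting reasonable results is fairly easy, for two reasons:

- The input is a real photo and the output is an anime-style image. A real photo carries much more information than an anime image, so the task can be seen as partially discarding content detail.
- The transfer goes in one direction only, so the discriminator's job is light: it only has to recognize a single style.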

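The generator model is in generater.py.

The discriminator is in discriminators.py.

GANloss.py wraps the adversarial GAN loss.

Generator checkpoints are saved in generator_model.

Discriminator checkpoints are saved in discriminator_model.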

Key points

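  1. AnimeGAN keeps its parameter count small mainly by using grouped convolutions (depthwise convolutions, where the number of groups equals the number of channels).
  2. The discriminator is trained to label grayscale versions of the anime images as fake. This pushes the generator toward high-quality, vividly colored output. In testing, even a grayscale input image came out vividly colored.
  3. The main model architecture is shown in the following diagram: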
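Counting weights shows why grouped (depthwise) convolutions shrink the model. The sketch below is plain Python arithmetic. The channel count of 256 is an arbitrary illustrative value, not a number taken from generater.py, and `conv_weights` is a hypothetical helper.

```python
def conv_weights(c_in, c_out, k, groups=1):
    """Number of weights in a 2-D convolution (bias excluded).

    Each of the c_out filters only sees c_in // groups input channels.
    """
    assert c_in % groups == 0 and c_out % groups == 0
    return c_out * (c_in // groups) * k * k


c = 256  # illustrative channel count (assumption, not from generater.py)

# Ordinary 3x3 convolution, c -> c channels
standard = conv_weights(c, c, 3)

# Depthwise separable: 3x3 depthwise conv (groups = c) followed by a 1x1 pointwise conv
separable = conv_weights(c, c, 3, groups=c) + conv_weights(c, c, 1)

print(standard, separable, round(standard / separable, 2))  # 589824 67840 8.69
```

Almost all of the remaining cost sits in the 1x1 pointwise convolution. The depthwise part only costs `c * k * k` weights.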
# !unzip -oq /home/aistudio/data/data112828/dataset.zip  -d data/  # unzip the dataset
# imports

import cv2
from matplotlib import image
import numpy as np
import os
import paddle
import paddle.optimizer
import paddle.nn as nn
from tqdm import tqdm
from paddle.io import Dataset
from paddle.io import DataLoader
import paddle.nn.functional as F
import paddle.tensor as tensor
from generater import AnimeGenerator  # the generator is defined here
from discriminators import AnimeDiscriminator
from paddle.vision.datasets import ImageFolder
from paddle.vision.transforms import Compose, ColorJitter, Resize
from VGG_MODEL import VGG19
from GANloss import GANLoss

# Test the VGG model. Only VGG's last feature map is used; intermediate feature maps are not compared.
m = np.random.random([10, 3,311,321])
x = VGG19()(paddle.to_tensor(m,dtype="float32"))
x.shape
[10, 512, 38, 40]
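The printed shape fits a feature map with 512 channels taken after three 2x2 max-pooling layers, an overall stride of 8. In VGG19 that is the fourth convolution block. This is inferred from the shape alone, because VGG_MODEL itself is not shown. The arithmetic below uses `pooled_size`, a hypothetical helper:

```python
def pooled_size(h, w, num_pools=3):
    """Spatial size after num_pools 2x2 max-pool layers with stride 2 (floor division).

    VGG's 3x3 convolutions use padding 1 and keep the size, so only the pools matter.
    """
    for _ in range(num_pools):
        h, w = h // 2, w // 2
    return h, w


print(pooled_size(311, 321))  # (38, 40), matching [10, 512, 38, 40]
print(pooled_size(248, 248))  # (31, 31) for the 248x248 training photos
```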
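class AnimeGANV2Dataset(paddle.io.Dataset):
    """Yields (real photo, anime image, smoothed anime image) triples.

    Real photos and anime images are paired at random. The smoothed image
    always uses the same index as the anime image, so data/Hayao/style and
    data/Hayao/smooth must contain corresponding files in the same order.
    """
    def __init__(self):
        """Build the image folders and the first shuffled index list (takes no arguments)."""
        # One ImageFolder per data directory; only the real photos are resized (to 248x248)
        self.real_image_folder = ImageFolder("data/train_photo",transform=Compose([Resize(size =(248,248))]),loader=AnimeGANV2Dataset.loader)
        self.anime_image_folder = ImageFolder("data/Hayao/style",loader=AnimeGANV2Dataset.loader)
        self.smooth_image_folder = ImageFolder("data/Hayao/smooth",loader=AnimeGANV2Dataset.loader)
        # Index lists are built for real and anime only; smooth reuses the anime index
        self.sizes = [
            len(fold) for fold in [self.real_image_folder, self.anime_image_folder]
        ]
        self.size = max(self.sizes)
        self.reshuffle()
    # cv2.imread returns BGR; convert the channels to RGB
    @staticmethod
    def loader(path):
        return cv2.cvtColor(cv2.imread(path, flags=cv2.IMREAD_COLOR),
                            cv2.COLOR_BGR2RGB)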

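    def reshuffle(self):
        # One shuffled index list per folder. Shorter lists are padded with
        # random repeats so that every list has length self.size.
        indexs = []
        for cur_size in self.sizes:
            x = np.arange(0, cur_size)
            np.random.shuffle(x)
            if cur_size != self.size:
                pad_num = self.size - cur_size
                pad = np.random.choice(cur_size, pad_num, replace=True)
                x = np.concatenate((x, pad))
                np.random.shuffle(x)
            indexs.append(x.tolist())
        self.indexs = list(zip(*indexs))  # list of (real_idx, anime_idx) pairs

    def __getitem__(self, index):
        # The index argument is ignored: pairs are popped from the shuffled list,
        # and the list is rebuilt once it runs out
        try:
            index = self.indexs.pop()
        except IndexError:
            self.reshuffle()
            index = self.indexs.pop()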

        real_idx, anime_idx = index
        real_image = self.real_image_folder[real_idx]
        anime_image = self.anime_image_folder[anime_idx]
        smooth_image =self.smooth_image_folder[anime_idx]
        return (real_image,anime_image,smooth_image)

    def __len__(self):
        return self.size
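The photo folder and the anime folder have different sizes. To handle this, `reshuffle` pads the shorter index list with randomly repeated indices, so every photo is seen once per pass and anime images are reused as needed. Below is a NumPy-only sketch of the same logic. The function name `paired_indices` is hypothetical, and the sketch uses NumPy's `Generator` API instead of `np.random.shuffle`.

```python
import numpy as np


def paired_indices(sizes, rng):
    """Sketch of AnimeGANV2Dataset.reshuffle: one shuffled index list per folder,
    shorter lists padded with random repeats, then zipped into tuples."""
    target = max(sizes)
    columns = []
    for n in sizes:
        idx = rng.permutation(n)                          # every item exactly once
        if n < target:
            pad = rng.choice(n, target - n, replace=True)  # random repeats
            idx = rng.permutation(np.concatenate((idx, pad)))
        columns.append(idx.tolist())
    return list(zip(*columns))


rng = np.random.default_rng(0)
pairs = paired_indices([6, 4], rng)  # e.g. 6 real photos, 4 anime images
print(len(pairs), pairs[:3])
```

Because `__getitem__` pops from this list instead of using its `index` argument, the order of samples is set entirely by this list, not by the DataLoader's sampler.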
BATCH_SIZE =4
dataset = AnimeGANV2Dataset()
data_loader = paddle.io.DataLoader(dataset,batch_size=BATCH_SIZE)
import matplotlib.pyplot as plt
'''
Check that each anime image and its smooth image correspond
'''
os.makedirs('test', exist_ok=True)  # cv2.imwrite fails silently if the directory is missing
for real,anime,smooth in data_loader:

    cv2.imwrite('test/real.jpg',cv2.cvtColor( real[0].numpy()[0].astype(np.uint8),cv2.COLOR_RGB2BGR))
    cv2.imwrite('test/anime.jpg',cv2.cvtColor( anime[0].numpy()[0].astype(np.uint8),cv2.COLOR_RGB2BGR))
    cv2.imwrite('test/smooth.jpg', cv2.cvtColor( smooth[0].numpy()[0].astype(np.uint8),cv2.COLOR_RGB2BGR))
    break
# Test the data_loader and show how the images are preprocessed
for data in data_loader:
    # print(i[0][0])
    real_data,anime_data,smooth = [i[0]/127.5-1 for i in data]
    # print(type(real_data[0]))
    # print(real_data[0].shape)
    real_data =paddle.transpose(x=real_data,perm=[0,3,1,2])
    anime_data =paddle.transpose(x=anime_data,perm=[0,3,1,2])
    smooth =paddle.transpose(x=smooth,perm=[0,3,1,2])
    print(real_data.shape)
    # print(anime_data.shape)
    break
[4, 3, 248, 248]
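Every image in this pipeline is scaled from [0, 255] to [-1, 1] and moved from NHWC to NCHW. The test code later reverses both steps. Below is a NumPy sketch of the round trip; the helper names are hypothetical.

One detail differs from the notebook code. Casting `x * 127.5 + 127.5` straight to uint8 truncates, so a value such as 126.9999 becomes 126, and `astype` does not clip out-of-range values. The sketch rounds and clips before casting to avoid both problems.

```python
import numpy as np


def to_model_input(img_hwc_uint8):
    """uint8 RGB image (H, W, 3) -> float32 array (1, 3, H, W) in [-1, 1]."""
    x = img_hwc_uint8.astype(np.float32) / 127.5 - 1.0
    return x[np.newaxis].transpose(0, 3, 1, 2)              # NHWC -> NCHW


def to_uint8_image(x_nchw):
    """(1, 3, H, W) in [-1, 1] -> uint8 (H, W, 3); round and clip before casting."""
    img = x_nchw.transpose(0, 2, 3, 1)[0] * 127.5 + 127.5   # NCHW -> NHWC, back to [0, 255]
    return np.clip(np.rint(img), 0, 255).astype(np.uint8)


img = np.arange(0, 256, dtype=np.uint8).reshape(16, 16, 1).repeat(3, axis=2)
x = to_model_input(img)
print(x.shape, x.min(), x.max())   # (1, 3, 16, 16) -1.0 1.0
```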

# Instantiate the models
Generator = AnimeGenerator()
Discriminator = AnimeDiscriminator()
# Load a saved generator checkpoint to resume training
G_path ='generator_model/Gmodel_state7003.pdparams'
layer_state_dictg = paddle.load(G_path)
Generator.set_state_dict(layer_state_dictg)
# D_path ='discriminator_model/Dmodel_state1003.pdparams'
# layer_state_dictd = paddle.load(D_path)
# Discriminator.set_state_dict(layer_state_dictd)
# Optimizers
optimizer_G = paddle.optimizer.Adam(learning_rate=0.00008,parameters=Generator.parameters(),beta1=0.5)
optimizer_D = paddle.optimizer.Adam(learning_rate=0.00016,parameters=Discriminator.parameters(),beta1=0.5)
VGG = VGG19()


LOSS

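import random
# 3x3 box-blur kernel for the fallback "smooth" images (see the training loop).
# The weight shape is [out=3, in=3, 3, 3], so each output channel sums over all
# three input channels; 1/27 (rather than 1/9) makes each output channel the
# blurred channel mean instead of three times it.
smooth_knerl = tensor.to_tensor([[[[1/27 for i in range(3)]for i in range(3)]for i in range(3)]for i in range(3)])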
def variation_loss(image, ksize=1):  # total variation loss: penalizes differences between neighboring pixels to suppress noise
    dh = image[:, :, :-ksize, :] - image[:, :, ksize:, :]
    dw = image[:, :, :, :-ksize] - image[:, :, :, ksize:]
    return (paddle.mean(paddle.abs(dh)) + paddle.mean(paddle.abs(dw)))
def gram(x):
    b, c, h, w = x.shape
    x_tmp = x.reshape((b, c, (h * w)))
    gram = paddle.matmul(x_tmp, x_tmp, transpose_y=True)
    return gram / (c * h * w)
def style_loss(style, fake):
    return nn.L1Loss()(gram(style), gram(fake))

def con_sty_loss(real, anime, fake):  # content loss and style loss, both computed on VGG feature maps
    real_feature_map = VGG(real)
    fake_feature_map = VGG(fake)
    anime_feature_map = VGG(anime)
    c_loss = nn.L1Loss()(real_feature_map, fake_feature_map)
    s_loss = style_loss(anime_feature_map, fake_feature_map)
    return c_loss, s_loss
def rgb2yuv(rgb):
    kernel = paddle.to_tensor([[0.299, -0.14714119, 0.61497538],
                                   [0.587, -0.28886916, -0.51496512],
                                   [0.114, 0.43601035, -0.10001026]],
                                  dtype='float32')
    rgb = paddle.transpose(rgb, (0, 2, 3, 1))
    yuv = paddle.matmul(rgb, kernel)
    return yuv


def denormalize(image):
    return image * 0.5 + 0.5

def color_loss(con, fake):
    # Compare in YUV space: L1 on luminance (Y), smooth L1 on chrominance (U, V)
    con = rgb2yuv(denormalize(con))
    fake = rgb2yuv(denormalize(fake))
    return (nn.L1Loss()(con[:, :, :, 0], fake[:, :, :, 0]) +
            nn.SmoothL1Loss()(con[:, :, :, 1], fake[:, :, :, 1]) +
            nn.SmoothL1Loss()(con[:, :, :, 2], fake[:, :, :, 2]))
def backward_G(real,anime_gray,fake):    
    fake_logit = Discriminator(fake)
    c_loss, s_loss = con_sty_loss(real,anime_gray,fake)
    c_loss = 1.5 * c_loss
    s_loss = 4.5* s_loss#2.5
    tv_loss = 1* variation_loss(fake)
    col_loss = 10* color_loss(real,fake)
    g_loss = (300* GANLoss()(fake_logit, True))
    loss_G = c_loss + s_loss + col_loss + g_loss + tv_loss
    loss_dict["G"] = loss_G.numpy()[0]
    # print("lossg",loss_G.numpy())
    loss_G.backward()
def backward_G_predictor(real,fake):  # generator pre-training: content loss only
    real_feature_map = VGG(real)
    fake_feature_map = VGG(fake)
    init_c_loss = nn.L1Loss()(real_feature_map, fake_feature_map)
    loss = 1 * init_c_loss
    loss_dict["G"] = loss.numpy()[0]
    # print("lossg",loss.numpy()) 
    loss.backward()

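def backward_D(anime,anime_gray,fake,smooth_gray):
    real_logit = Discriminator(anime)  # real anime images -> labeled real
    gray_logit = Discriminator(anime_gray)  # grayscale anime images -> labeled fake
    fake_logit = Discriminator(fake.detach())  # generated images -> labeled fake
    smooth_logit = Discriminator(smooth_gray)  # edge-smoothed (slightly blurred) anime images, grayscale -> labeled fake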
    d_real_loss = (300 * 1.2 *GANLoss()(real_logit, True))
    d_gray_loss = (300 * 1.2 *GANLoss()(gray_logit, False))
    # print(fake_logit.shape)
    d_fake_loss = (300 * 1.2 *GANLoss()(fake_logit, False))
    d_blur_loss = (300 * 0.8 *GANLoss()(smooth_logit, False))
    loss_D = d_real_loss + d_gray_loss + d_fake_loss + d_blur_loss
    loss_dict["D"] = (loss_D.numpy()[0])
    # print("lossd",loss_D.numpy())
    loss_D.backward()
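The style and total-variation terms do different jobs because of two properties:

- The Gram matrix in `gram`/`style_loss` ignores where features are located, so it captures texture and style rather than layout.
- `variation_loss` only compares neighboring pixels and is zero for a flat image.

The NumPy re-implementation below makes both properties checkable. It follows the same formulas; the `_np` names are hypothetical, and `nn.L1Loss` is assumed to use its default mean reduction.

```python
import numpy as np


def gram_np(x):
    """(B, C, H, W) -> (B, C, C), normalized by C*H*W like gram() above."""
    b, c, h, w = x.shape
    f = x.reshape(b, c, h * w)
    return f @ f.transpose(0, 2, 1) / (c * h * w)


def style_loss_np(style_feat, fake_feat):
    # Mean absolute difference of the Gram matrices (L1 loss, mean reduction)
    return np.mean(np.abs(gram_np(style_feat) - gram_np(fake_feat)))


def variation_loss_np(img, k=1):
    # Mean absolute difference between vertically / horizontally adjacent pixels
    dh = img[:, :, :-k, :] - img[:, :, k:, :]
    dw = img[:, :, :, :-k] - img[:, :, :, k:]
    return np.mean(np.abs(dh)) + np.mean(np.abs(dw))


rng = np.random.default_rng(0)
feat = rng.standard_normal((2, 4, 5, 6))
# Shuffle spatial positions with the same permutation for every channel
perm = rng.permutation(5 * 6)
shuffled = feat.reshape(2, 4, 30)[:, :, perm].reshape(2, 4, 5, 6)
print(style_loss_np(feat, shuffled))              # ~0: same Gram matrix, different layout

checker = (np.indices((4, 4)).sum(axis=0) % 2).astype(np.float64)[None, None]
print(variation_loss_np(checker))                 # 2.0: every neighboring pair differs by 1
print(variation_loss_np(np.ones((1, 1, 4, 4))))   # 0.0 for a flat image
```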
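`color_loss` compares images in YUV space, so brightness errors (Y) and color errors (U, V) are penalized separately. Y gets an L1 loss and U and V get a smooth L1 loss. For a gray pixel (R = G = B), U and V are zero, because each chroma column of the conversion matrix sums to zero. A pure brightness change therefore shows up only in the Y term.

The NumPy sketch below uses the same matrix as `rgb2yuv`. Its `huber_np` assumes that `nn.SmoothL1Loss` with its default delta of 1.0 computes 0.5*d^2 for |d| < 1 and |d| - 0.5 otherwise. The helper names are hypothetical.

```python
import numpy as np

# Same coefficients as rgb2yuv above: rows = R, G, B inputs; columns = Y, U, V outputs
RGB2YUV = np.array([[0.299, -0.14714119, 0.61497538],
                    [0.587, -0.28886916, -0.51496512],
                    [0.114, 0.43601035, -0.10001026]])


def rgb2yuv_np(rgb_nchw):
    """RGB in [0, 1], shape (N, 3, H, W) -> YUV, shape (N, H, W, 3)."""
    return np.transpose(rgb_nchw, (0, 2, 3, 1)) @ RGB2YUV


def huber_np(a, b, delta=1.0):
    # Assumed smooth L1 / Huber form with mean reduction
    d = np.abs(a - b)
    return np.mean(np.where(d < delta, 0.5 * d ** 2, delta * d - 0.5 * delta ** 2))


def color_loss_np(con, fake):
    """con, fake in [-1, 1], NCHW. L1 on Y plus Huber on U and V."""
    c = rgb2yuv_np(con * 0.5 + 0.5)     # denormalize to [0, 1] first
    f = rgb2yuv_np(fake * 0.5 + 0.5)
    return (np.mean(np.abs(c[..., 0] - f[..., 0]))
            + huber_np(c[..., 1], f[..., 1])
            + huber_np(c[..., 2], f[..., 2]))


gray = np.zeros((1, 3, 4, 4))            # mid-gray image (0.5 after denormalizing)
brighter = np.full((1, 3, 4, 4), 0.2)    # brighter gray (0.6 after denormalizing)
print(color_loss_np(gray, brighter))     # ~0.1, all of it from the Y term
```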





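Training has two phases:

1. Pre-train the generator with backward_G_predictor (content loss only) so that it becomes reasonably strong.
2. Train the generator and discriminator together, using backward_G for the generator.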
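Both phases can live in one loop. The sketch below assumes a hypothetical `init_epochs` setting; the notebook instead switches phases by hand, commenting `backward_G_predictor` or `backward_G` in and out. The step functions stand in for the Paddle update code, so the sketch only shows the control flow.

```python
from collections import Counter


def run_schedule(num_epochs, init_epochs, batches, g_init_step, adv_step):
    """Call g_init_step for the first init_epochs epochs, then adv_step.

    g_init_step(batch): generator-only update with the content loss
                        (what backward_G_predictor computes); D is not trained.
    adv_step(batch):    discriminator update (backward_D), then a generator
                        update with the full loss (backward_G).
    Returns the phase used in each epoch.
    """
    phases = []
    for epoch in range(num_epochs):
        pretraining = epoch < init_epochs
        step = g_init_step if pretraining else adv_step
        for batch in batches:
            step(batch)
        phases.append("init" if pretraining else "adv")
    return phases


calls = Counter()
phases = run_schedule(3, 1, range(5),
                      g_init_step=lambda b: calls.update(["init"]),
                      adv_step=lambda b: calls.update(["adv"]))
print(phases, dict(calls))  # ['init', 'adv', 'adv'] {'init': 5, 'adv': 10}
```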

epoches = 100
i = 0  # global iteration counter (also used in checkpoint file names)
save_dir_generator = "./generator_model"
save_dir_discriminator = "./discriminator_model"
for d in (save_dir_generator, save_dir_discriminator, "./result"):
    os.makedirs(d, exist_ok=True)
for epoch in range(epoches):
    print("epoch",epoch)
    for data in tqdm(data_loader):
        # Each element of data is a one-item list from ImageFolder; scale to [-1, 1]
        real_data,anime_data,smooth_data = [x[0]/127.5-1 for x in data]

        loss_dict = {}
        # NHWC -> NCHW
        real_data = paddle.transpose(x=real_data,perm=[0,3,1,2])
        anime_data = paddle.transpose(x=anime_data,perm=[0,3,1,2])
        # If the dataset provides no smooth images, blur the anime images yourself instead:
        # smooth_data = F.conv2d(anime_data,weight=smooth_knerl,stride=1,padding=1)
        smooth_data = paddle.transpose(x=smooth_data,perm=[0,3,1,2])
        # Grayscale = channel mean copied to 3 channels; use the actual batch size so a short last batch also works
        anime_gray_data = paddle.expand(paddle.mean(anime_data,keepdim=True,axis=1),[anime_data.shape[0],3,anime_data.shape[-2],anime_data.shape[-1]])
        smooth_data_gray = paddle.expand(paddle.mean(smooth_data,keepdim=True,axis=1),[smooth_data.shape[0],3,smooth_data.shape[-2],smooth_data.shape[-1]])
        fake_data = Generator(real_data)
        # 1) Update the discriminator
        optimizer_D.clear_grad()
        backward_D(anime_data,anime_gray_data,fake_data,smooth_data_gray)
        optimizer_D.step()
        # 2) Update the generator (use backward_G_predictor instead during pre-training)
        optimizer_G.clear_grad()
        # backward_G_predictor(real_data,fake_data)
        backward_G(real_data,anime_gray_data,fake_data)
        optimizer_G.step()
        i += 1
        if i % 100 == 0:
            print(i,"D_LOSS",loss_dict["D"],"G_LOSS",loss_dict["G"])

        # Save checkpoints and a preview image every 1000 iterations (at i = 3, 1003, 2003, ...)
        if i % 1000 == 3:
            save_param_path_g = os.path.join(save_dir_generator, 'Gmodel_state'+str(i)+'.pdparams')
            paddle.save(Generator.state_dict(), save_param_path_g)
            save_param_path_d = os.path.join(save_dir_discriminator, 'Dmodel_state'+str(i)+'.pdparams')
            paddle.save(Discriminator.state_dict(), save_param_path_d)
            Generator.eval()
            # cv2 reads BGR, but the generator was trained on RGB
            img_A = cv2.cvtColor(cv2.imread("test/real.jpg"), cv2.COLOR_BGR2RGB)
            g_input = img_A.astype('float32') / 127.5 - 1             # normalize to [-1, 1]
            g_input = g_input[np.newaxis, ...].transpose(0, 3, 1, 2)  # NHWC -> NCHW
            g_input = paddle.to_tensor(g_input)                       # numpy -> tensor
            g_output = Generator(g_input)
            g_output = g_output.detach().numpy()                      # tensor -> numpy
            g_output = g_output.transpose(0, 2, 3, 1)[0]              # NCHW -> NHWC
            g_output = np.clip(g_output * 127.5 + 127.5, 0, 255)      # denormalize to [0, 255]
            g_output = g_output.astype(np.uint8)
            cv2.imwrite(os.path.join("./result", 'epoch'+str(i).zfill(3)+'.png'), cv2.cvtColor(g_output,cv2.COLOR_RGB2BGR))
            Generator.train()
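Two tensor operations in this loop are easy to get wrong:

- the grayscale conversion, which takes the channel mean and copies it back to three channels;
- the optional fallback blur, which applies `smooth_knerl` with `F.conv2d`.

The NumPy sketch below mirrors both with a naive 3x3 convolution (stride 1, zero padding 1). `conv2d_same` and `to_gray3` are hypothetical helpers. The sketch checks that a [3, 3, 3, 3] kernel whose entries are all equal mixes the channels. With entries of 1/27, every output channel equals the 3x3 box blur of the grayscale image; with entries of 1/9, the result is three times that. Note that this is a box filter, not a Gaussian blur.

```python
import numpy as np


def to_gray3(x):
    """(N, 3, H, W) -> channel mean repeated to 3 channels (like anime_gray_data)."""
    return np.repeat(x.mean(axis=1, keepdims=True), 3, axis=1)


def conv2d_same(x, weight):
    """Naive 2-D cross-correlation, stride 1, zero padding 1.

    x: (N, C_in, H, W); weight: (C_out, C_in, 3, 3), the layout F.conv2d expects.
    """
    n, _, h, w = x.shape
    xp = np.pad(x, ((0, 0), (0, 0), (1, 1), (1, 1)))   # zero padding
    out = np.zeros((n, weight.shape[0], h, w))
    for i in range(3):
        for j in range(3):
            patch = xp[:, :, i:i + h, j:j + w]           # input shifted by (i, j)
            out += np.einsum("nchw,oc->nohw", patch, weight[:, :, i, j])
    return out


# Depthwise box blur: each channel is blurred on its own
box_depthwise = np.zeros((3, 3, 3, 3))
for c in range(3):
    box_depthwise[c, c] = 1 / 9

rng = np.random.default_rng(0)
x = rng.uniform(-1, 1, size=(2, 3, 8, 8))

mixed_27 = conv2d_same(x, np.full((3, 3, 3, 3), 1 / 27))
mixed_9 = conv2d_same(x, np.full((3, 3, 3, 3), 1 / 9))
gray_blur = conv2d_same(to_gray3(x), box_depthwise)

print(np.allclose(mixed_27, gray_blur), np.allclose(mixed_9, 3 * gray_blur))  # True True
```

To blur each color channel separately in Paddle, use a [3, 1, 3, 3] kernel with `groups=3` in `F.conv2d`, the depthwise counterpart of `box_depthwise`.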

epoch 0
100 D_LOSS 268.2158 G_LOSS 1606.2267
200 D_LOSS 278.59476 G_LOSS 1164.1039
300 D_LOSS 299.0925 G_LOSS 1311.1616
400 D_LOSS 236.26518 G_LOSS 1162.3252
500 D_LOSS 246.95308 G_LOSS 1331.1404
600 D_LOSS 208.72493 G_LOSS 1345.257
700 D_LOSS 183.42891 G_LOSS 1042.9739
800 D_LOSS 208.8155 G_LOSS 1266.8835
900 D_LOSS 213.62747 G_LOSS 1266.7932
1000 D_LOSS 167.21292 G_LOSS 1153.8461
1100 D_LOSS 170.24477 G_LOSS 1741.643
1200 D_LOSS 154.22945 G_LOSS 1067.7212
1300 D_LOSS 182.29414 G_LOSS 1226.3143
1400 D_LOSS 174.83759 G_LOSS 1040.0066
1500 D_LOSS 159.94745 G_LOSS 1324.5468
1600 D_LOSS 190.95378 G_LOSS 1079.9465

epoch 1
1700 D_LOSS 165.09032 G_LOSS 1424.0165
1800 D_LOSS 198.23642 G_LOSS 1292.516
1900 D_LOSS 165.20651 G_LOSS 1502.531
2000 D_LOSS 171.64952 G_LOSS 1187.1589
2100 D_LOSS 79.426735 G_LOSS 1540.8289
2200 D_LOSS 146.80545 G_LOSS 1335.2792
2300 D_LOSS 140.99062 G_LOSS 1006.7901
2400 D_LOSS 164.46994 G_LOSS 1341.5206
2500 D_LOSS 161.2462 G_LOSS 1362.1427
2600 D_LOSS 111.48205 G_LOSS 1498.5966
2700 D_LOSS 131.80838 G_LOSS 1205.2937
2800 D_LOSS 364.6225 G_LOSS 1377.4451
2900 D_LOSS 142.73322 G_LOSS 1073.4952
3000 D_LOSS 103.9964 G_LOSS 1066.0243
3100 D_LOSS 174.45412 G_LOSS 1369.2009
3200 D_LOSS 193.95383 G_LOSS 1415.3037
3300 D_LOSS 106.95616 G_LOSS 1184.0751

epoch 2
3400 D_LOSS 176.8511 G_LOSS 1245.4153
3500 D_LOSS 171.86816 G_LOSS 1244.4407
3600 D_LOSS 173.22856 G_LOSS 1243.9589
3700 D_LOSS 136.96036 G_LOSS 1350.8328
3800 D_LOSS 192.17505 G_LOSS 1239.1552
3900 D_LOSS 121.99713 G_LOSS 1312.6243
4000 D_LOSS 145.77002 G_LOSS 1171.575
4100 D_LOSS 165.67157 G_LOSS 1180.1825
4200 D_LOSS 98.46747 G_LOSS 1369.8628
4300 D_LOSS 125.519394 G_LOSS 1235.0101
4400 D_LOSS 174.17241 G_LOSS 1338.9253
4500 D_LOSS 131.71597 G_LOSS 1256.3582
4600 D_LOSS 221.09642 G_LOSS 1309.5458
4700 D_LOSS 156.89586 G_LOSS 1288.591
4800 D_LOSS 148.09262 G_LOSS 1221.0227
4900 D_LOSS 196.10785 G_LOSS 1029.9869

epoch 3
5000 D_LOSS 186.04994 G_LOSS 1162.1355
5100 D_LOSS 146.7022 G_LOSS 1145.7506
5200 D_LOSS 191.23949 G_LOSS 1084.373
5300 D_LOSS 131.53249 G_LOSS 1384.2104
5400 D_LOSS 204.34352 G_LOSS 1146.7999
5500 D_LOSS 166.13495 G_LOSS 1386.4236
5600 D_LOSS 85.40863 G_LOSS 1302.3384
5700 D_LOSS 148.7696 G_LOSS 1342.1403
5800 D_LOSS 162.14708 G_LOSS 1047.0885
5900 D_LOSS 138.24677 G_LOSS 1305.394
6000 D_LOSS 94.707245 G_LOSS 1124.3021
6100 D_LOSS 111.90222 G_LOSS 1156.682
6200 D_LOSS 171.16931 G_LOSS 1222.4845
6300 D_LOSS 115.44682 G_LOSS 1114.112
6400 D_LOSS 166.84598 G_LOSS 1048.7089
6500 D_LOSS 119.276924 G_LOSS 1269.8976
6600 D_LOSS 81.460754 G_LOSS 1189.2717

epoch 4
6700 D_LOSS 132.34291 G_LOSS 1022.5383
6800 D_LOSS 172.09502 G_LOSS 1062.7837
6900 D_LOSS 175.25679 G_LOSS 1357.6571
7000 D_LOSS 115.33646 G_LOSS 1136.9133
7100 D_LOSS 109.653 G_LOSS 1140.4409
7200 D_LOSS 156.79759 G_LOSS 1333.0219
7300 D_LOSS 80.23892 G_LOSS 1114.3611
7400 D_LOSS 178.24619 G_LOSS 1648.8625
7500 D_LOSS 147.22025 G_LOSS 1461.1204
7600 D_LOSS 92.94795 G_LOSS 1225.1282
7700 D_LOSS 126.23758 G_LOSS 1319.5778
7800 D_LOSS 114.0454 G_LOSS 1203.5481
7900 D_LOSS 150.40987 G_LOSS 1317.204
8000 D_LOSS 176.27557 G_LOSS 1383.7046
8100 D_LOSS 165.56274 G_LOSS 1378.1512
8200 D_LOSS 134.8395 G_LOSS 1217.2303
8300 D_LOSS 137.5817 G_LOSS 1232.9722

epoch 5
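The printed D_LOSS / G_LOSS values jump around a lot from one report to the next, which makes trends hard to read. Instead of a dedicated logger, you can parse the printed lines directly. The sketch below does that and smooths each series with a moving average before plotting; `parse_losses` and `moving_average` are hypothetical helpers, and the sample lines are copied from the log above.

```python
import re

# Matches lines such as "100 D_LOSS 268.2158 G_LOSS 1606.2267"
LOSS_LINE = re.compile(r"^(\d+) D_LOSS (\S+) G_LOSS (\S+)$")


def parse_losses(lines):
    """Return (steps, d_losses, g_losses) from printed log lines; other lines are skipped."""
    steps, d_losses, g_losses = [], [], []
    for line in lines:
        m = LOSS_LINE.match(line.strip())
        if m:
            steps.append(int(m.group(1)))
            d_losses.append(float(m.group(2)))
            g_losses.append(float(m.group(3)))
    return steps, d_losses, g_losses


def moving_average(values, window):
    """Trailing moving average; the result is window - 1 items shorter than the input."""
    return [sum(values[i - window + 1:i + 1]) / window
            for i in range(window - 1, len(values))]


log = """epoch 0
100 D_LOSS 268.2158 G_LOSS 1606.2267
200 D_LOSS 278.59476 G_LOSS 1164.1039
300 D_LOSS 299.0925 G_LOSS 1311.1616
 18%|█▊        | 300/1664 [01:08<05:11,  4.38it/s]
400 D_LOSS 236.26518 G_LOSS 1162.3252""".splitlines()

steps, d_losses, g_losses = parse_losses(log)
print(steps)                        # [100, 200, 300, 400]
print(moving_average(d_losses, 2))
```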

Results

Image test: I only trained for about 5 epochs, so you can try training longer.

  1. Input image

  2. Output image

# Test code
model_state_dict = paddle.load("generator_model/Gmodel_state8003.pdparams")
Generator = AnimeGenerator()

Generator.set_state_dict(model_state_dict)
Generator.eval()
# Read the image; cv2 loads BGR, but the generator was trained on RGB
img_A = cv2.cvtColor(cv2.imread("src.png"), cv2.COLOR_BGR2RGB)
g_input = img_A.astype('float32') / 127.5 - 1             # normalize to [-1, 1]
g_input = g_input[np.newaxis, ...].transpose(0, 3, 1, 2)  # NHWC -> NCHW
g_input = paddle.to_tensor(g_input)                       # numpy -> tensor
print(g_input.shape)
g_output = Generator(g_input)
g_output = g_output.detach().numpy()                      # tensor -> numpy
g_output = g_output.transpose(0, 2, 3, 1)[0]              # NCHW -> NHWC
g_output = np.clip(g_output * 127.5 + 127.5, 0, 255)      # denormalize to [0, 255]
g_output = g_output.astype(np.uint8)
cv2.imwrite('t18003.png', cv2.cvtColor(g_output,cv2.COLOR_RGB2BGR))


[1, 3, 634, 996]





True

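To sum up:

  1. Training this project is not hard; getting rough results is easy. First pre-train the generator, then train the generator and discriminator together. It makes a good practice project.
  2. The main dataset consists of landscapes, so the model does not work well on people. Source photos that already have a light, soft, fresh look also give much better results.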


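Summary

  1. This model is easy to train.
  2. Test each part of the code as soon as you write it, rather than writing everything and running it all at once; otherwise there are too many bugs to track down.
  3. If you like, you can visualize the losses with LogWriter. I didn't, because this project is simple enough that it didn't seem necessary.
  4. If you found this helpful, please leave a like before you go!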