


这里提供一篇知乎上的讲解文章:https://zhuanlan.zhihu.com/p/162545685 ,适合有经验的读者快速了解。




GANloss.py 是对 GAN 对抗损失(adversarial loss)的封装。




  1. AnimeGAN 主要通过分组卷积(group conv)来减少参数量。
  2. 判别器把动画图像的灰度图判别为假(false),以督促生成器生成高质量、颜色鲜艳的图片;因此测试时即使输入的是灰度图,也可以生成颜色鲜艳的图片。
  3. 模型的主要架构图如下:
# !unzip -oq /home/aistudio/data/data112828/dataset.zip  -d data/  #数据集解压

import cv2
from matplotlib import image
import numpy as np
import os
import paddle
import paddle.optimizer
import paddle.nn as nn
from tqdm import tqdm
from paddle.io import Dataset
from paddle.io import DataLoader
import paddle.nn.functional as F
import paddle.tensor as tensor
from generater import AnimeGenerator #这就是生成器存放地方
from discriminators import AnimeDiscriminator
from paddle.vision.datasets import ImageFolder
from paddle.vision.transforms import Compose, ColorJitter, Resize
# real_image_folder[0]
from VGG_MODEL import VGG19
from GANloss import GANLoss

# Smoke-test the VGG model on a random batch. Only the last VGG feature map
# is used in this project; intermediate feature maps are not compared.
m = np.random.random([10, 3,311,321])
x = VGG19()(paddle.to_tensor(m,dtype="float32"))
W0203 17:46:30.696004  4388 device_context.cc:404] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.0, Runtime API Version: 10.1
W0203 17:46:30.700654  4388 device_context.cc:422] device: 0, cuDNN Version: 7.6.
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/tensor/creation.py:125: DeprecationWarning: `np.object` is a deprecated alias for the builtin `object`. To silence this warning, use `object` by itself. Doing this will not modify any behavior and is safe. 
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:

[10, 512, 38, 40]
class AnimeGANV2Dataset(paddle.io.Dataset):
    """Unpaired dataset yielding (real photo, anime image, smoothed anime image).

    Samples are drawn from a pre-shuffled list of (real_idx, anime_idx) pairs
    rather than from the ``index`` handed in by the data loader. The smaller
    folder is padded with randomly repeated indices so that both folders
    contribute ``len(self)`` entries per pass.
    """

    def __init__(self):
        """Build the three image folders and the first shuffled index list."""
        # Image folders for the three data sources. Only the real photos are
        # resized here; the anime/smooth images are loaded at their stored size.
        self.real_image_folder = ImageFolder(
            "data/train_photo",
            transform=Compose([Resize(size=(248, 248))]),
            loader=AnimeGANV2Dataset.loader)
        self.anime_image_folder = ImageFolder(
            "data/Hayao/style", loader=AnimeGANV2Dataset.loader)
        self.smooth_image_folder = ImageFolder(
            "data/Hayao/smooth", loader=AnimeGANV2Dataset.loader)
        self.sizes = [
            len(fold) for fold in [self.real_image_folder, self.anime_image_folder]
        ]
        self.size = max(self.sizes)
        # __getitem__ pops from self.indexs, so it must exist before first use.
        self.reshuffle()

    @staticmethod
    def loader(path):
        """Read an image from ``path`` and return it in RGB channel order."""
        # cv2.imread returns BGR, so swap the channels to RGB.
        return cv2.cvtColor(cv2.imread(path, flags=cv2.IMREAD_COLOR),
                            cv2.COLOR_BGR2RGB)

    def reshuffle(self):
        """Rebuild ``self.indexs`` as a shuffled list of (real_idx, anime_idx)."""
        indexs = []
        for cur_size in self.sizes:
            x = np.arange(0, cur_size)
            np.random.shuffle(x)
            if cur_size != self.size:
                # Pad the smaller folder with random repeats up to self.size.
                pad_num = self.size - cur_size
                pad = np.random.choice(cur_size, pad_num, replace=True)
                x = np.concatenate((x, pad))
                np.random.shuffle(x)
            indexs.append(x.tolist())
        self.indexs = list(zip(*indexs))

    def __getitem__(self, index):
        """Return the next (real, anime, smooth) sample; ``index`` is ignored."""
        try:
            index = self.indexs.pop()
        except IndexError:
            # All pairs of the current pass were consumed: start a new pass.
            self.reshuffle()
            index = self.indexs.pop()

        real_idx, anime_idx = index
        real_image = self.real_image_folder[real_idx]
        anime_image = self.anime_image_folder[anime_idx]
        # NOTE(review): the smooth folder is indexed with the anime index, which
        # assumes both folders hold the same files in the same order - confirm.
        smooth_image = self.smooth_image_folder[anime_idx]
        return (real_image, anime_image, smooth_image)

    def __len__(self):
        """Return the size of the larger of the real and anime folders."""
        return self.size
# Samples per batch. The recorded run shows batches of shape [4, 3, 248, 248].
BATCH_SIZE = 4
dataset = AnimeGANV2Dataset()
data_loader = paddle.io.DataLoader(dataset,batch_size=BATCH_SIZE)
import matplotlib.pyplot as plt
# cv2.imwrite does not create missing directories, so make sure it exists.
os.makedirs('test', exist_ok=True)
# Dump the first sample of one batch to disk as a visual sanity check.
# Images are RGB in memory, so convert to BGR before cv2.imwrite.
for real,anime,smooth in data_loader:
    cv2.imwrite('test/real.jpg',cv2.cvtColor( real[0].numpy()[0].astype(np.uint8),cv2.COLOR_RGB2BGR))
    cv2.imwrite('test/anime.jpg',cv2.cvtColor( anime[0].numpy()[0].astype(np.uint8),cv2.COLOR_RGB2BGR))
    cv2.imwrite('test/smooth.jpg', cv2.cvtColor( smooth[0].numpy()[0].astype(np.uint8),cv2.COLOR_RGB2BGR))
    # One batch is enough; further batches would only overwrite the same files.
    break
# Walk the loader once, scaling every batch from [0, 255] to [-1, 1] and
# moving the channel axis forward (perm [0, 3, 1, 2]).
for data in data_loader:
    real_data, anime_data, smooth = (
        paddle.transpose(x=images[0] / 127.5 - 1, perm=[0, 3, 1, 2])
        for images in data
    )
[4, 3, 248, 248]

Generator = AnimeGenerator()
Discriminator = AnimeDiscriminator()
# Resume the generator from a saved checkpoint. The loaded state dict has to
# be applied explicitly, otherwise training restarts from random weights.
G_path ='generator_model/Gmodel_state7003.pdparams'
layer_state_dictg = paddle.load(G_path)
Generator.set_state_dict(layer_state_dictg)
# To resume the discriminator as well, uncomment the three lines below.
# D_path ='discriminator_model/Dmodel_state1003.pdparams'
# layer_state_dictd = paddle.load(D_path)
# Discriminator.set_state_dict(layer_state_dictd)
optimizer_G = paddle.optimizer.Adam(learning_rate=0.00008,parameters=Generator.parameters(),beta1=0.5)
optimizer_D = paddle.optimizer.Adam(learning_rate=0.00016,parameters=Discriminator.parameters(),beta1=0.5)
# Feature extractor used by the content/style losses.
VGG = VGG19()

/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/tensor/creation.py:125: DeprecationWarning: `np.object` is a deprecated alias for the builtin `object`. To silence this warning, use `object` by itself. Doing this will not modify any behavior and is safe. 
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:


import random
# 51223
# 3x3x3x3 tensor with every entry equal to 1/9.
# NOTE(review): presumably meant as a blur kernel for F.conv2d to produce
# smoothed anime images - confirm before relying on it.
smooth_knerl = tensor.to_tensor([[[[1/9 for i in range(3)]for i in range(3)]for i in range(3)]for i in range(3)])
def variation_loss(image, ksize=1):
    """Return mean |shifted difference| along axis 2 plus the same along axis 3.

    ``ksize`` is the shift, in elements, between the two slices being compared.
    Assumes a 4-D input; the training loop passes NCHW batches.
    """
    height_diff = paddle.abs(image[:, :, :-ksize, :] - image[:, :, ksize:, :])
    width_diff = paddle.abs(image[:, :, :, :-ksize] - image[:, :, :, ksize:])
    return (paddle.mean(height_diff) + paddle.mean(width_diff))
def gram(x):
    """Return the per-sample Gram matrix of a 4-D feature map.

    The last two axes are flattened, the result is multiplied by its own
    transpose, and the product is divided by (c * h * w).
    """
    batch, channels, height, width = x.shape
    flat = x.reshape((batch, channels, (height * width)))
    product = paddle.matmul(flat, flat, transpose_y=True)
    return product / (channels * height * width)
def style_loss(style, fake):
    """Return the L1 distance between the Gram matrices of two feature maps."""
    l1 = nn.L1Loss()
    style_gram = gram(style)
    fake_gram = gram(fake)
    return l1(style_gram, fake_gram)

def con_sty_loss(real, anime, fake):
    """Return (content loss, style loss) computed on VGG feature maps.

    Content loss is the L1 distance between the features of ``real`` and
    ``fake``; style loss compares the Gram matrices of ``anime`` and ``fake``.
    """
    real_feat, fake_feat, anime_feat = VGG(real), VGG(fake), VGG(anime)
    content = nn.L1Loss()(real_feat, fake_feat)
    style = style_loss(anime_feat, fake_feat)
    return content, style
def rgb2yuv(rgb):
    """Convert a channel-first RGB batch to a channel-last YUV batch.

    The input is transposed with perm (0, 2, 3, 1) and multiplied by a 3x3
    conversion matrix, so the Y, U and V components end up on the last axis.
    """
    # Columns of the matrix produce Y, U and V respectively.
    kernel = paddle.to_tensor([[0.299, -0.14714119, 0.61497538],
                               [0.587, -0.28886916, -0.51496512],
                               [0.114, 0.43601035, -0.10001026]],
                              dtype="float32")
    rgb = paddle.transpose(rgb, (0, 2, 3, 1))
    yuv = paddle.matmul(rgb, kernel)
    return yuv

def denormalize(image):
    """Rescale values by ``image * 0.5 + 0.5`` (e.g. -1 -> 0 and 1 -> 1)."""
    half_scaled = image * 0.5
    return half_scaled + 0.5

def color_loss( con, fake):
    """Return the colour loss between two images, computed in YUV space.

    Both inputs are denormalized and converted to YUV. Channel 0 is compared
    with L1 loss; channels 1 and 2 are compared with smooth-L1 loss.
    """
    con_yuv = rgb2yuv(denormalize(con))
    fake_yuv = rgb2yuv(denormalize(fake))
    y_term = nn.L1Loss()(con_yuv[:, :, :, 0], fake_yuv[:, :, :, 0])
    u_term = nn.SmoothL1Loss()(con_yuv[:, :, :, 1], fake_yuv[:, :, :, 1])
    v_term = nn.SmoothL1Loss()(con_yuv[:, :, :, 2], fake_yuv[:, :, :, 2])
    return (y_term + u_term + v_term)
def backward_G(real,anime_gray,fake):
    """Run one generator update and record the loss in ``loss_dict["G"]``.

    The total loss is the weighted sum of content (1.5), style (4.5),
    total-variation (1), colour (10) and adversarial (300) terms.

    Args:
        real: batch of real photos that was fed to the generator.
        anime_gray: grayscale anime batch used as the style reference.
        fake: generator output for ``real``.

    Returns:
        The total generator loss tensor.
    """
    fake_logit = Discriminator(fake)
    c_loss, s_loss = con_sty_loss(real,anime_gray,fake)
    c_loss = 1.5 * c_loss
    s_loss = 4.5 * s_loss
    tv_loss = 1 * variation_loss(fake)
    col_loss = 10 * color_loss(real,fake)
    g_loss = (300 * GANLoss()(fake_logit, True))
    loss_G = c_loss + s_loss + col_loss + g_loss + tv_loss
    # Without these three calls the loss is only computed and the generator
    # never learns. Clear first so stale gradients are not accumulated.
    optimizer_G.clear_grad()
    loss_G.backward()
    optimizer_G.step()
    loss_dict["G"] = loss_G.numpy()[0]
    return loss_G
def backward_G_predictor(real,fake):
    """Run one content-only generator update (generator pre-training step).

    The loss is the L1 distance between the VGG features of ``real`` and
    ``fake``; it is also recorded in ``loss_dict["G"]``.

    Returns:
        The content loss tensor.
    """
    real_feature_map = VGG(real)
    fake_feature_map = VGG(fake)
    init_c_loss = nn.L1Loss()(real_feature_map, fake_feature_map)
    loss = 1 * init_c_loss
    # Apply the update; previously the loss was computed but never used.
    optimizer_G.clear_grad()
    loss.backward()
    optimizer_G.step()
    loss_dict["G"] = loss.numpy()[0]
    return loss

def backward_D(anime,anime_gray,fake,smooth_gray):
    """Run one discriminator update and record the loss in ``loss_dict["D"]``.

    Only the real anime batch is labelled True. The grayscale anime batch, the
    generated batch and the smoothed grayscale batch are all labelled False.

    Args:
        anime: real anime images.
        anime_gray: grayscale version of the anime images.
        fake: generator output; detached so no gradient reaches the generator.
        smooth_gray: grayscale version of the smoothed (blurred) anime images.

    Returns:
        The total discriminator loss tensor.
    """
    real_logit = Discriminator(anime)            # real anime images
    gray_logit = Discriminator(anime_gray)       # anime images as grayscale
    fake_logit = Discriminator(fake.detach())    # generated images
    smooth_logit = Discriminator(smooth_gray)    # blurred anime images
    d_real_loss = (300 * 1.2 * GANLoss()(real_logit, True))
    d_gray_loss = (300 * 1.2 * GANLoss()(gray_logit, False))
    d_fake_loss = (300 * 1.2 * GANLoss()(fake_logit, False))
    d_blur_loss = (300 * 0.8 * GANLoss()(smooth_logit, False))
    loss_D = d_real_loss + d_gray_loss + d_fake_loss + d_blur_loss
    # Clear first: back-propagating the generator loss through the
    # discriminator also leaves gradients on the discriminator parameters.
    optimizer_D.clear_grad()
    loss_D.backward()
    optimizer_D.step()
    loss_dict["D"] = (loss_D.numpy()[0])
    return loss_D


epoches =100
i = 0  # global step counter, shared across epochs
save_dir_generator = "./generator_model"
save_dir_discriminator ="./discriminator_model"
# cv2.imwrite does not create missing directories, so make sure it exists.
os.makedirs("./result", exist_ok=True)
for epoch in range(epoches):
    print("epoch", epoch)
    for data in tqdm(data_loader):
        # Scale every batch from [0, 255] to [-1, 1].
        real_data,anime_data,smooth_data = [item[0]/127.5-1 for item in data]

        loss_dict = {}
        # NHWC -> NCHW
        real_data = paddle.transpose(x=real_data,perm=[0,3,1,2])
        anime_data = paddle.transpose(x=anime_data,perm=[0,3,1,2])
        smooth_data = paddle.transpose(x=smooth_data,perm=[0,3,1,2])
        # Grayscale copies: channel mean repeated over 3 channels. The batch
        # dimension is read from the tensor so a short last batch still works.
        anime_gray_data = paddle.expand(
            paddle.mean(anime_data,keepdim=True,axis=1),
            [anime_data.shape[0],3,anime_data.shape[-2],anime_data.shape[-1]])
        smooth_data_gray = paddle.expand(
            paddle.mean(smooth_data,keepdim=True,axis=1),
            [smooth_data.shape[0],3,smooth_data.shape[-2],smooth_data.shape[-1]])
        fake_data = Generator(real_data)
        d_loss = backward_D(anime_data,anime_gray_data,fake_data,smooth_data_gray)
        # For generator pre-training use backward_G_predictor(real_data, fake_data) instead.
        g_loss = backward_G(real_data,anime_gray_data,fake_data)
        i += 1
        if i%100 == 0:
            print(i, "D_LOSS", loss_dict["D"], "G_LOSS", loss_dict["G"])

        if i%1000 == 3:
            # Checkpoint both networks and write a preview image.
            save_param_path_g = os.path.join(save_dir_generator, 'Gmodel_state'+str(i)+'.pdparams')
            paddle.save(Generator.state_dict(), save_param_path_g)
            save_param_path_d = os.path.join(save_dir_discriminator, 'Dmodel_state'+str(i)+'.pdparams')
            paddle.save(Discriminator.state_dict(), save_param_path_d)
            img_A = cv2.imread("test/real.jpg")
            # cv2.imread returns BGR but the generator is trained on RGB.
            g_input = cv2.cvtColor(img_A, cv2.COLOR_BGR2RGB)
            g_input = g_input.astype('float32') / 127.5 - 1           # normalize to [-1, 1]
            g_input = g_input[np.newaxis, ...].transpose(0, 3, 1, 2)  # NHWC -> NCHW
            g_input = paddle.to_tensor(g_input)                       # numpy -> tensor
            g_output = Generator(g_input)
            g_output = g_output.detach().numpy()                      # tensor -> numpy
            g_output = g_output.transpose(0, 2, 3, 1)[0]              # NCHW -> NHWC
            g_output = g_output * 127.5 + 127.5                       # back to [0, 255]
            g_output = g_output.astype(np.uint8)
            cv2.imwrite(os.path.join("./result", 'epoch'+str(i).zfill(3)+'.png'), cv2.cvtColor(g_output,cv2.COLOR_RGB2BGR))

  0%|          | 0/1664 [00:00<?, ?it/s]

epoch 0

  6%|▌         | 100/1664 [00:22<05:56,  4.38it/s]

100 D_LOSS 268.2158 G_LOSS 1606.2267

 12%|█▏        | 200/1664 [00:45<05:33,  4.39it/s]

200 D_LOSS 278.59476 G_LOSS 1164.1039

 18%|█▊        | 300/1664 [01:08<05:11,  4.38it/s]

300 D_LOSS 299.0925 G_LOSS 1311.1616

 24%|██▍       | 400/1664 [01:31<04:46,  4.41it/s]

400 D_LOSS 236.26518 G_LOSS 1162.3252

 30%|███       | 500/1664 [01:54<04:29,  4.32it/s]

500 D_LOSS 246.95308 G_LOSS 1331.1404

 36%|███▌      | 600/1664 [02:16<04:01,  4.41it/s]

600 D_LOSS 208.72493 G_LOSS 1345.257

 42%|████▏     | 700/1664 [02:39<03:39,  4.40it/s]

700 D_LOSS 183.42891 G_LOSS 1042.9739

 48%|████▊     | 800/1664 [03:02<03:15,  4.42it/s]

800 D_LOSS 208.8155 G_LOSS 1266.8835

 54%|█████▍    | 900/1664 [03:24<02:53,  4.40it/s]

900 D_LOSS 213.62747 G_LOSS 1266.7932

 60%|██████    | 1000/1664 [03:47<02:31,  4.40it/s]

1000 D_LOSS 167.21292 G_LOSS 1153.8461

 66%|██████▌   | 1100/1664 [04:10<02:10,  4.32it/s]

1100 D_LOSS 170.24477 G_LOSS 1741.643

 72%|███████▏  | 1200/1664 [04:33<01:46,  4.34it/s]

1200 D_LOSS 154.22945 G_LOSS 1067.7212

 78%|███████▊  | 1300/1664 [04:55<01:22,  4.40it/s]

1300 D_LOSS 182.29414 G_LOSS 1226.3143

 84%|████████▍ | 1400/1664 [05:18<00:59,  4.42it/s]

1400 D_LOSS 174.83759 G_LOSS 1040.0066

 90%|█████████ | 1500/1664 [05:41<00:37,  4.41it/s]

1500 D_LOSS 159.94745 G_LOSS 1324.5468

 96%|█████████▌| 1600/1664 [06:04<00:14,  4.40it/s]

1600 D_LOSS 190.95378 G_LOSS 1079.9465

100%|██████████| 1664/1664 [06:18<00:00,  4.40it/s]
  0%|          | 0/1664 [00:00<?, ?it/s]

epoch 1

  2%|▏         | 36/1664 [00:08<06:09,  4.41it/s]

1700 D_LOSS 165.09032 G_LOSS 1424.0165

  8%|▊         | 136/1664 [00:30<05:47,  4.39it/s]

1800 D_LOSS 198.23642 G_LOSS 1292.516

 14%|█▍        | 236/1664 [00:53<05:21,  4.44it/s]

1900 D_LOSS 165.20651 G_LOSS 1502.531

 20%|██        | 336/1664 [01:16<05:01,  4.40it/s]

2000 D_LOSS 171.64952 G_LOSS 1187.1589

 26%|██▌       | 436/1664 [01:39<04:38,  4.41it/s]

2100 D_LOSS 79.426735 G_LOSS 1540.8289

 32%|███▏      | 536/1664 [02:02<04:19,  4.35it/s]

2200 D_LOSS 146.80545 G_LOSS 1335.2792

 38%|███▊      | 636/1664 [02:24<03:53,  4.40it/s]

2300 D_LOSS 140.99062 G_LOSS 1006.7901

 44%|████▍     | 736/1664 [02:47<03:30,  4.41it/s]

2400 D_LOSS 164.46994 G_LOSS 1341.5206

 50%|█████     | 836/1664 [03:10<03:08,  4.39it/s]

2500 D_LOSS 161.2462 G_LOSS 1362.1427

 56%|█████▋    | 936/1664 [03:33<02:43,  4.45it/s]

2600 D_LOSS 111.48205 G_LOSS 1498.5966

 62%|██████▏   | 1036/1664 [03:55<02:22,  4.40it/s]

2700 D_LOSS 131.80838 G_LOSS 1205.2937

 68%|██████▊   | 1136/1664 [04:18<01:59,  4.40it/s]

2800 D_LOSS 364.6225 G_LOSS 1377.4451

 74%|███████▍  | 1236/1664 [04:41<01:37,  4.37it/s]

2900 D_LOSS 142.73322 G_LOSS 1073.4952

 80%|████████  | 1336/1664 [05:04<01:14,  4.41it/s]

3000 D_LOSS 103.9964 G_LOSS 1066.0243

 86%|████████▋ | 1436/1664 [05:27<00:51,  4.42it/s]

3100 D_LOSS 174.45412 G_LOSS 1369.2009

 92%|█████████▏| 1536/1664 [05:49<00:29,  4.41it/s]

3200 D_LOSS 193.95383 G_LOSS 1415.3037

 98%|█████████▊| 1636/1664 [06:12<00:06,  4.40it/s]

3300 D_LOSS 106.95616 G_LOSS 1184.0751

100%|██████████| 1664/1664 [06:18<00:00,  4.39it/s]
  0%|          | 0/1664 [00:00<?, ?it/s]

epoch 2

  4%|▍         | 72/1664 [00:16<06:06,  4.35it/s]

3400 D_LOSS 176.8511 G_LOSS 1245.4153

 10%|█         | 172/1664 [00:39<05:37,  4.41it/s]

3500 D_LOSS 171.86816 G_LOSS 1244.4407

 16%|█▋        | 272/1664 [01:01<05:17,  4.39it/s]

3600 D_LOSS 173.22856 G_LOSS 1243.9589

 22%|██▏       | 372/1664 [01:24<04:53,  4.40it/s]

3700 D_LOSS 136.96036 G_LOSS 1350.8328

 28%|██▊       | 472/1664 [01:47<04:29,  4.43it/s]

3800 D_LOSS 192.17505 G_LOSS 1239.1552

 34%|███▍      | 572/1664 [02:09<04:10,  4.36it/s]

3900 D_LOSS 121.99713 G_LOSS 1312.6243

 40%|████      | 672/1664 [02:32<03:48,  4.35it/s]

4000 D_LOSS 145.77002 G_LOSS 1171.575

 46%|████▋     | 772/1664 [02:55<03:25,  4.34it/s]

4100 D_LOSS 165.67157 G_LOSS 1180.1825

 52%|█████▏    | 872/1664 [03:18<03:00,  4.38it/s]

4200 D_LOSS 98.46747 G_LOSS 1369.8628

 58%|█████▊    | 972/1664 [03:41<02:37,  4.38it/s]

4300 D_LOSS 125.519394 G_LOSS 1235.0101

 64%|██████▍   | 1072/1664 [04:04<02:15,  4.37it/s]

4400 D_LOSS 174.17241 G_LOSS 1338.9253

 70%|███████   | 1172/1664 [04:27<01:52,  4.39it/s]

4500 D_LOSS 131.71597 G_LOSS 1256.3582

 76%|███████▋  | 1272/1664 [04:50<01:30,  4.35it/s]

4600 D_LOSS 221.09642 G_LOSS 1309.5458

 82%|████████▏ | 1372/1664 [05:13<01:06,  4.37it/s]

4700 D_LOSS 156.89586 G_LOSS 1288.591

 88%|████████▊ | 1472/1664 [05:35<00:43,  4.39it/s]

4800 D_LOSS 148.09262 G_LOSS 1221.0227

 94%|█████████▍| 1572/1664 [05:58<00:20,  4.38it/s]

4900 D_LOSS 196.10785 G_LOSS 1029.9869

100%|██████████| 1664/1664 [06:19<00:00,  4.38it/s]
  0%|          | 0/1664 [00:00<?, ?it/s]

epoch 3

  0%|          | 8/1664 [00:01<06:16,  4.39it/s]

5000 D_LOSS 186.04994 G_LOSS 1162.1355

  6%|▋         | 108/1664 [00:24<05:55,  4.37it/s]

5100 D_LOSS 146.7022 G_LOSS 1145.7506

 12%|█▎        | 208/1664 [00:47<05:34,  4.35it/s]

5200 D_LOSS 191.23949 G_LOSS 1084.373

 19%|█▊        | 308/1664 [01:10<05:10,  4.37it/s]

5300 D_LOSS 131.53249 G_LOSS 1384.2104

 25%|██▍       | 408/1664 [01:33<04:45,  4.40it/s]

5400 D_LOSS 204.34352 G_LOSS 1146.7999

 31%|███       | 508/1664 [01:56<04:28,  4.30it/s]

5500 D_LOSS 166.13495 G_LOSS 1386.4236

 37%|███▋      | 608/1664 [02:19<04:01,  4.38it/s]

5600 D_LOSS 85.40863 G_LOSS 1302.3384

 43%|████▎     | 708/1664 [02:41<03:37,  4.39it/s]

5700 D_LOSS 148.7696 G_LOSS 1342.1403

 49%|████▊     | 808/1664 [03:04<03:15,  4.37it/s]

5800 D_LOSS 162.14708 G_LOSS 1047.0885

 55%|█████▍    | 908/1664 [03:27<02:54,  4.32it/s]

5900 D_LOSS 138.24677 G_LOSS 1305.394

 61%|██████    | 1008/1664 [03:50<02:30,  4.36it/s]

6000 D_LOSS 94.707245 G_LOSS 1124.3021

 67%|██████▋   | 1108/1664 [04:13<02:06,  4.39it/s]

6100 D_LOSS 111.90222 G_LOSS 1156.682

 73%|███████▎  | 1208/1664 [04:36<01:43,  4.39it/s]

6200 D_LOSS 171.16931 G_LOSS 1222.4845

 79%|███████▊  | 1308/1664 [04:59<01:21,  4.39it/s]

6300 D_LOSS 115.44682 G_LOSS 1114.112

 85%|████████▍ | 1408/1664 [05:22<00:58,  4.39it/s]

6400 D_LOSS 166.84598 G_LOSS 1048.7089

 91%|█████████ | 1508/1664 [05:45<00:35,  4.38it/s]

6500 D_LOSS 119.276924 G_LOSS 1269.8976

 97%|█████████▋| 1608/1664 [06:07<00:12,  4.41it/s]

6600 D_LOSS 81.460754 G_LOSS 1189.2717

100%|██████████| 1664/1664 [06:20<00:00,  4.37it/s]
  0%|          | 0/1664 [00:00<?, ?it/s]

epoch 4

  3%|▎         | 44/1664 [00:10<06:11,  4.37it/s]

6700 D_LOSS 132.34291 G_LOSS 1022.5383

  9%|▊         | 144/1664 [00:32<05:45,  4.40it/s]

6800 D_LOSS 172.09502 G_LOSS 1062.7837

 15%|█▍        | 244/1664 [00:55<05:22,  4.40it/s]

6900 D_LOSS 175.25679 G_LOSS 1357.6571

 21%|██        | 344/1664 [01:18<04:59,  4.40it/s]

7000 D_LOSS 115.33646 G_LOSS 1136.9133

 27%|██▋       | 444/1664 [01:41<04:39,  4.37it/s]

7100 D_LOSS 109.653 G_LOSS 1140.4409

 33%|███▎      | 544/1664 [02:04<04:15,  4.38it/s]

7200 D_LOSS 156.79759 G_LOSS 1333.0219

 39%|███▊      | 644/1664 [02:27<03:52,  4.39it/s]

7300 D_LOSS 80.23892 G_LOSS 1114.3611

 45%|████▍     | 744/1664 [02:49<03:30,  4.37it/s]

7400 D_LOSS 178.24619 G_LOSS 1648.8625

 51%|█████     | 844/1664 [03:12<03:14,  4.22it/s]

7500 D_LOSS 147.22025 G_LOSS 1461.1204

 57%|█████▋    | 944/1664 [03:35<02:44,  4.37it/s]

7600 D_LOSS 92.94795 G_LOSS 1225.1282

 63%|██████▎   | 1044/1664 [03:58<02:20,  4.43it/s]

7700 D_LOSS 126.23758 G_LOSS 1319.5778

 69%|██████▉   | 1144/1664 [04:21<01:58,  4.37it/s]

7800 D_LOSS 114.0454 G_LOSS 1203.5481

 75%|███████▍  | 1244/1664 [04:44<01:36,  4.35it/s]

7900 D_LOSS 150.40987 G_LOSS 1317.204

 81%|████████  | 1344/1664 [05:07<01:13,  4.37it/s]

8000 D_LOSS 176.27557 G_LOSS 1383.7046

 87%|████████▋ | 1444/1664 [05:30<00:50,  4.39it/s]

8100 D_LOSS 165.56274 G_LOSS 1378.1512

 93%|█████████▎| 1544/1664 [05:53<00:27,  4.37it/s]

8200 D_LOSS 134.8395 G_LOSS 1217.2303

 99%|█████████▉| 1644/1664 [06:15<00:04,  4.37it/s]

8300 D_LOSS 137.5817 G_LOSS 1232.9722

100%|██████████| 1664/1664 [06:20<00:00,  4.37it/s]
  0%|          | 0/1664 [00:00<?, ?it/s]

epoch 5

  4%|▎         | 60/1664 [00:13<06:09,  4.34it/s]



  1. 输入图片

  2. 输出图片

model_state_dict = paddle.load("generator_model/Gmodel_state8003.pdparams")
Generator = AnimeGenerator()
# Apply the loaded weights; otherwise inference runs on a randomly initialised generator.
Generator.set_state_dict(model_state_dict)
Generator.eval()

# Read the input image.
img_A = cv2.imread("src.png")
if img_A is None:
    # cv2.imread returns None instead of raising when the file cannot be read.
    raise FileNotFoundError("src.png")
# cv2.imread returns BGR but the generator is trained on RGB.
g_input = cv2.cvtColor(img_A, cv2.COLOR_BGR2RGB)
g_input = g_input.astype('float32') / 127.5 - 1           # normalize to [-1, 1]
g_input = g_input[np.newaxis, ...].transpose(0, 3, 1, 2)  # NHWC -> NCHW
g_input = paddle.to_tensor(g_input)                       # numpy -> tensor
g_output = Generator(g_input)
g_output = g_output.detach().numpy()                      # tensor -> numpy
g_output = g_output.transpose(0, 2, 3, 1)[0]              # NCHW -> NHWC
g_output = g_output * 127.5 + 127.5                       # back to [0, 255]
# Clip before the cast so out-of-range values cannot wrap around in uint8.
g_output = np.clip(g_output, 0, 255).astype(np.uint8)
cv2.imwrite('t18003.png', cv2.cvtColor(g_output,cv2.COLOR_RGB2BGR))
W0203 18:19:08.339618  6252 device_context.cc:404] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.0, Runtime API Version: 10.1
W0203 18:19:08.344558  6252 device_context.cc:422] device: 0, cuDNN Version: 7.6.

[1, 3, 634, 996]



  1. 整个项目的训练并不难,大致跑出效果很容易:首先对生成器进行预训练,然后再让生成器和判别器一起训练,很适合练手。
  2. 这个数据集以风景为主,所以不太适合人像;另外,如果原始图片本身偏小清新风格,效果会好很多。


Please click here for more detailed instructions.


  1. 这个很好训练。
  2. 学会每写完一部分代码就测试一下,不要全部写完才直接运行,否则很容易因为 bug 太多而难以排查。
  3. 如果需要,可以使用 LogWriter 对 loss 进行可视化;我这里没有用,因为这个项目比较简单,觉得没有必要。
  4. 点个爱心再走呗,手有余香。

