animegan完美复现
基于 AnimeGAN,本项目面向新手,结合代码一步步讲解,步骤清晰,可以帮助大家了解 GAN。
·
标题:animegan复现
论文地址:https://github.com/TachibanaYoshino/AnimeGAN/blob/master/doc/Chen2020_Chapter_AnimeGAN.pdf
这篇论文其实讲解的人挺多的,但是一般人没结合代码讲解,而是纯纯的论文讲解,或许对于新手不太友好,这里我结合代码进行讲解,顺便给初入深度学习的新手一个友好的参考和借鉴经验。
这里我提供知乎上的讲解网址,https://zhuanlan.zhihu.com/p/162545685 ,适合老手快速了解
首先大致讲解一下:这篇论文做的是一种定向风格迁移,即把真实图片转换为训练数据集所代表的风格(动漫风格)。这个模型比较容易训练出效果,原因有两点:第一,训练时输入的是真实图片,输出的是动漫风格化的图片,而真实图片包含的信息比动漫图片多得多,所以整个过程可以看作内容信息部分衰减的过程;第二,因为是定向风格迁移,判别器只需要判别一种风格,压力不大。
模型主体在generater.py
判别器在discriminators.py
GANloss.py为GAN 对抗loss封装
生成器模型参数文件保存在generator_model
判别器参数文件保存在discriminator_model
核心解读
- 主要这个animegan就是通过group conv来缩小参数量。
- 然后判别器把动画图像灰度图判别为false,督促生成器生成高质量颜色鲜艳的图片,最后测试如果输入的是灰度图,也可以生成出颜色鲜艳的图。
- 模型主要架构图,如下:
# !unzip -oq /home/aistudio/data/data112828/dataset.zip -d data/ #数据集解压
#导包
import cv2
from matplotlib import image
import numpy as np
import os
import paddle
import paddle.optimizer
import paddle.nn as nn
from tqdm import tqdm
from paddle.io import Dataset
from paddle.io import DataLoader
import paddle.nn.functional as F
import paddle.tensor as tensor
from generater import AnimeGenerator #这就是生成器存放地方
from discriminators import AnimeDiscriminator
from paddle.vision.datasets import ImageFolder
from paddle.vision.transforms import Compose, ColorJitter, Resize
# real_image_folder[0]
from VGG_MODEL import VGG19
from GANloss import GANLoss
# Smoke-test the VGG model. Only the final VGG feature map is used in this
# project; intermediate feature maps are not compared.
m = np.random.random([10, 3,311,321])
x = VGG19()(paddle.to_tensor(m,dtype="float32"))
# Notebook-style trailing expression: displays the shape of the VGG output.
x.shape
W0203 17:46:30.696004 4388 device_context.cc:404] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.0, Runtime API Version: 10.1
W0203 17:46:30.700654 4388 device_context.cc:422] device: 0, cuDNN Version: 7.6.
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/tensor/creation.py:125: DeprecationWarning: `np.object` is a deprecated alias for the builtin `object`. To silence this warning, use `object` by itself. Doing this will not modify any behavior and is safe.
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
if data.dtype == np.object:
[10, 512, 38, 40]
class AnimeGANV2Dataset(paddle.io.Dataset):
    """Dataset yielding (real photo, anime image, smoothed anime image) triples.

    Real photos come from ``data/train_photo``; anime images and their
    smoothed versions come from ``data/Hayao/style`` and ``data/Hayao/smooth``.
    The anime and smooth folders are read with the same position index, so
    they are expected to list matching files in the same order.
    """

    def __init__(self):
        """Create the three image folders and the first shuffled index list."""
        read_rgb = AnimeGANV2Dataset.loader
        # Only the real photos get a Resize transform on load.
        self.real_image_folder = ImageFolder(
            "data/train_photo",
            transform=Compose([Resize(size=(248, 248))]),
            loader=read_rgb,
        )
        self.anime_image_folder = ImageFolder("data/Hayao/style", loader=read_rgb)
        self.smooth_image_folder = ImageFolder("data/Hayao/smooth", loader=read_rgb)
        # The smooth folder is not counted: it is always indexed with the
        # anime index.
        self.sizes = [len(self.real_image_folder), len(self.anime_image_folder)]
        # One epoch is as long as the larger of the two folders.
        self.size = max(self.sizes)
        self.reshuffle()

    @staticmethod
    def loader(path):
        """Read the image at ``path`` with OpenCV and return it in RGB order.

        ``cv2.imread`` yields BGR channel order, hence the conversion.
        """
        bgr = cv2.imread(path, flags=cv2.IMREAD_COLOR)
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    def reshuffle(self):
        """Rebuild ``self.indexs`` as a list of (real_idx, anime_idx) pairs.

        Each folder gets a random permutation of its own indices. A folder
        shorter than ``self.size`` is padded with indices drawn with
        replacement and shuffled again, so both index lists have
        ``self.size`` entries before being zipped together.
        """
        per_folder = []
        for count in self.sizes:
            order = np.arange(0, count)
            np.random.shuffle(order)
            shortfall = self.size - count
            if shortfall != 0:
                extra = np.random.choice(count, shortfall, replace=True)
                order = np.concatenate((order, extra))
                np.random.shuffle(order)
            per_folder.append(order.tolist())
        self.indexs = list(zip(*per_folder))

    def __getitem__(self, index):
        """Return the next (real, anime, smooth) sample.

        The ``index`` argument is not used for lookup: the next pre-shuffled
        index pair is popped from ``self.indexs`` instead, and the list is
        regenerated once it has been exhausted.
        """
        if not self.indexs:
            self.reshuffle()
        real_idx, anime_idx = self.indexs.pop()
        real_image = self.real_image_folder[real_idx]
        anime_image = self.anime_image_folder[anime_idx]
        # Same index as the anime image, so the pair stays aligned.
        smooth_image = self.smooth_image_folder[anime_idx]
        return (real_image, anime_image, smooth_image)

    def __len__(self):
        """Return the number of samples per epoch (size of the larger folder)."""
        return self.size
# Batch size used by the data loader (and referenced again in the training loop).
BATCH_SIZE =4
dataset = AnimeGANV2Dataset()
# No shuffle flag is passed here; the dataset class shuffles its own index list.
data_loader = paddle.io.DataLoader(dataset,batch_size=BATCH_SIZE)
import matplotlib.pyplot as plt
# Sanity check: dump the first real / anime / smooth image of one batch so
# that the anime image and its smoothed counterpart can be compared by eye.
# cv2.imwrite does not create missing directories (it only returns False),
# so make sure the output folder exists before writing.
os.makedirs('test', exist_ok=True)
for real, anime, smooth in data_loader:
    for out_path, batch in (('test/real.jpg', real),
                            ('test/anime.jpg', anime),
                            ('test/smooth.jpg', smooth)):
        # Batches hold RGB pixels; OpenCV expects BGR when writing.
        first_rgb = batch[0].numpy()[0].astype(np.uint8)
        cv2.imwrite(out_path, cv2.cvtColor(first_rgb, cv2.COLOR_RGB2BGR))
    break
# Exercise the data loader once to show how a batch is preprocessed:
# each part is scaled with x / 127.5 - 1 and then permuted with
# perm=[0, 3, 1, 2] (channels moved in front of the two spatial axes).
for data in data_loader:
    real_data, anime_data, smooth = (
        paddle.transpose(x=part[0] / 127.5 - 1, perm=[0, 3, 1, 2])
        for part in data
    )
    print(real_data.shape)
    break
[4, 3, 248, 248]
# Instantiate the generator and the discriminator.
Generator = AnimeGenerator()
Discriminator = AnimeDiscriminator()
# Resume the generator from a saved checkpoint.
# NOTE(review): this path must already exist; on a first run there is no
# checkpoint yet, so these three lines presumably need to be skipped — confirm.
G_path ='generator_model/Gmodel_state7003.pdparams'
layer_state_dictg = paddle.load(G_path)
Generator.set_state_dict(layer_state_dictg)
# Optional: resume the discriminator as well (disabled in the original).
# D_path ='discriminator_model/Dmodel_state1003.pdparams'
# layer_state_dictd = paddle.load(D_path)
# Discriminator.set_state_dict(layer_state_dictd)
# Set up the optimizers; the discriminator uses twice the generator's learning rate.
optimizer_G = paddle.optimizer.Adam(learning_rate=0.00008,parameters=Generator.parameters(),beta1=0.5)
optimizer_D = paddle.optimizer.Adam(learning_rate=0.00016,parameters=Discriminator.parameters(),beta1=0.5)
# VGG network used by the content/style losses; it is not handed to either optimizer.
VGG = VGG19()
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/tensor/creation.py:125: DeprecationWarning: `np.object` is a deprecated alias for the builtin `object`. To silence this warning, use `object` by itself. Doing this will not modify any behavior and is safe.
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
if data.dtype == np.object:
LOSS
import random
#494414
# 51223
# 3x3x3x3 tensor with every entry equal to 1/9. It is only referenced by the
# commented-out F.conv2d fallback in the training loop (blurring the anime
# batch when the dataset ships no "smooth" images).
# NOTE(review): if used as a conv2d weight, the weights of each output channel
# sum to 3 rather than 1 — confirm that is intended before enabling it.
smooth_knerl = tensor.to_tensor(
    [[[[1 / 9] * 3 for _ in range(3)] for _ in range(3)] for _ in range(3)]
)
def variation_loss(image, ksize=1):
    """Return a total-variation style penalty for a 4-D image batch.

    The result is the mean absolute difference between entries ``ksize``
    apart along axis 2 plus the same quantity along axis 3. Lower values
    mean neighbouring pixels are more alike.

    Args:
        image: 4-D tensor (callers pass channel-first image batches).
        ksize: Offset between the compared entries; defaults to 1.

    Returns:
        The sum of the two mean absolute differences.
    """
    axis2_diff = image[:, :, :-ksize, :] - image[:, :, ksize:, :]
    axis3_diff = image[:, :, :, :-ksize] - image[:, :, :, ksize:]
    axis2_term = paddle.mean(paddle.abs(axis2_diff))
    axis3_term = paddle.mean(paddle.abs(axis3_diff))
    return (axis2_term + axis3_term)
def gram(x):
    """Return the normalised Gram matrix of a 4-D feature map.

    The two trailing axes are flattened, the result is multiplied by its own
    transpose per batch item, and the product is divided by the number of
    elements in one batch item (c * h * w).

    Args:
        x: Tensor of shape (b, c, h, w).

    Returns:
        Tensor of shape (b, c, c).
    """
    batch, channels, height, width = x.shape
    flat = x.reshape((batch, channels, (height * width)))
    products = paddle.matmul(flat, flat, transpose_y=True)
    return products / (channels * height * width)
def style_loss(style, fake):
    """Return the L1 distance between the Gram matrices of two feature maps.

    Args:
        style: Feature map of the style reference.
        fake: Feature map of the generated image.
    """
    l1 = nn.L1Loss()
    style_gram = gram(style)
    fake_gram = gram(fake)
    return l1(style_gram, fake_gram)
def con_sty_loss(real, anime, fake):
    """Compute the content loss and the style loss on VGG features.

    Args:
        real: Input photo batch; its features are the content target.
        anime: Anime reference batch; its features are the style target.
        fake: Generator output batch.

    Returns:
        ``(c_loss, s_loss)``: L1 distance between the real and fake feature
        maps, and the Gram-matrix style loss between anime and fake features.
    """
    real_features = VGG(real)
    fake_features = VGG(fake)
    anime_features = VGG(anime)
    content = nn.L1Loss()(real_features, fake_features)
    style = style_loss(anime_features, fake_features)
    return content, style
def rgb2yuv(rgb):
    """Convert a channel-first RGB batch into a channel-last YUV batch.

    Args:
        rgb: Tensor laid out as (batch, channel, height, width).

    Returns:
        Tensor laid out as (batch, height, width, 3); the last axis is the
        product of the RGB triple with the 3x3 conversion matrix below.
    """
    # Row k holds the contribution of input channel k to the three outputs.
    conversion = paddle.to_tensor(
        [[0.299, -0.14714119, 0.61497538],
         [0.587, -0.28886916, -0.51496512],
         [0.114, 0.43601035, -0.10001026]],
        dtype='float32',
    )
    # Move channels last so matmul acts on the colour triple of every pixel.
    channels_last = paddle.transpose(rgb, (0, 2, 3, 1))
    return paddle.matmul(channels_last, conversion)
def denormalize(image):
    """Map values from the [-1, 1] range to [0, 1] via ``image * 0.5 + 0.5``."""
    halved = image * 0.5
    return halved + 0.5
def color_loss(con, fake):
    """Return the colour reconstruction loss between two image batches.

    Both batches are denormalised and converted to YUV. The first channel is
    compared with an L1 loss and the other two with a smooth-L1 loss; the
    three terms are summed.

    Args:
        con: Content (input photo) batch.
        fake: Generator output batch.
    """
    con_yuv = rgb2yuv(denormalize(con))
    fake_yuv = rgb2yuv(denormalize(fake))
    y_term = nn.L1Loss()(con_yuv[:, :, :, 0], fake_yuv[:, :, :, 0])
    u_term = nn.SmoothL1Loss()(con_yuv[:, :, :, 1], fake_yuv[:, :, :, 1])
    v_term = nn.SmoothL1Loss()(con_yuv[:, :, :, 2], fake_yuv[:, :, :, 2])
    return (y_term + u_term + v_term)
def backward_G(real, anime_gray, fake):
    """Compute the full generator loss, record it, and back-propagate it.

    The loss is the weighted sum of the content loss (x1.5), style loss
    (x4.5), colour loss (x10), adversarial loss (x300) and total-variation
    loss (x1).

    Args:
        real: Input photo batch fed to the generator.
        anime_gray: Grayscale anime batch used as the style reference.
        fake: Generator output for ``real``.

    Returns:
        The scalar loss value, identical to what is stored in
        ``loss_dict["G"]``. The original returned None even though the
        training loop binds the result to ``g_loss``.

    Side effects:
        Writes ``loss_dict["G"]`` (a module-level dict that must already
        exist) and accumulates gradients through ``loss_G.backward()``.
    """
    fake_logit = Discriminator(fake)
    c_loss, s_loss = con_sty_loss(real, anime_gray, fake)
    c_loss = 1.5 * c_loss
    s_loss = 4.5 * s_loss  # the original notes 2.5 as an alternative weight
    tv_loss = 1 * variation_loss(fake)
    col_loss = 10 * color_loss(real, fake)
    # The generator wants the discriminator to label its output as real.
    g_loss = 300 * GANLoss()(fake_logit, True)
    loss_G = c_loss + s_loss + col_loss + g_loss + tv_loss
    loss_value = loss_G.numpy()[0]
    loss_dict["G"] = loss_value
    loss_G.backward()
    return loss_value
def backward_G_predictor(real, fake):
    """Generator pre-training step: back-propagate the content loss only.

    Args:
        real: Input photo batch fed to the generator.
        fake: Generator output for ``real``.

    Returns:
        The scalar loss value, identical to what is stored in
        ``loss_dict["G"]``. The original returned None even though the
        (commented-out) call site binds the result to ``g_loss``.

    Side effects:
        Writes ``loss_dict["G"]`` (a module-level dict that must already
        exist) and accumulates gradients through ``loss.backward()``.
    """
    real_feature_map = VGG(real)
    fake_feature_map = VGG(fake)
    # L1 distance between the VGG features of the photo and of the output.
    init_c_loss = nn.L1Loss()(real_feature_map, fake_feature_map)
    loss = 1 * init_c_loss
    loss_value = loss.numpy()[0]
    loss_dict["G"] = loss_value
    loss.backward()
    return loss_value
def backward_D(anime, anime_gray, fake, smooth_gray):
    """Compute the discriminator loss, record it, and back-propagate it.

    Only the colour anime batch is labelled real. The grayscale anime batch,
    the generator output and the smoothed grayscale batch are all labelled
    fake. The first three terms are weighted 300 * 1.2 and the smoothed term
    300 * 0.8.

    Args:
        anime: Anime image batch (target: real).
        anime_gray: Grayscale version of the anime batch (target: fake).
        fake: Generator output; detached so no gradient reaches the generator.
        smooth_gray: Grayscale smoothed anime batch (target: fake).

    Returns:
        The scalar loss value, identical to what is stored in
        ``loss_dict["D"]``. The original returned None even though the
        training loop binds the result to ``d_loss``.

    Side effects:
        Writes ``loss_dict["D"]`` (a module-level dict that must already
        exist) and accumulates gradients through ``loss_D.backward()``.
    """
    real_logit = Discriminator(anime)
    gray_logit = Discriminator(anime_gray)
    fake_logit = Discriminator(fake.detach())
    smooth_logit = Discriminator(smooth_gray)
    d_real_loss = 300 * 1.2 * GANLoss()(real_logit, True)
    d_gray_loss = 300 * 1.2 * GANLoss()(gray_logit, False)
    d_fake_loss = 300 * 1.2 * GANLoss()(fake_logit, False)
    d_blur_loss = 300 * 0.8 * GANLoss()(smooth_logit, False)
    loss_D = d_real_loss + d_gray_loss + d_fake_loss + d_blur_loss
    loss_value = loss_D.numpy()[0]
    loss_dict["D"] = loss_value
    loss_D.backward()
    return loss_value
先用 backward_G_predictor 单独预训练生成器,让生成器先强一点;然后再用 backward_G 和 backward_D 让生成器和判别器一起训练。
# Training script: alternate one discriminator step and one generator step per
# batch, print the losses every 100 iterations, and save checkpoints plus a
# preview image whenever the iteration counter satisfies i % 1000 == 3.
epoches = 100
i = 0  # global iteration counter, shared across epochs
save_dir_generator = "./generator_model"
save_dir_discriminator = "./discriminator_model"
preview_dir = "./result"
# cv2.imwrite does not create missing directories (it only returns False).
os.makedirs(preview_dir, exist_ok=True)


def _to_gray3(batch):
    """Return the channel mean of a channel-first batch, repeated on 3 channels."""
    gray = paddle.mean(batch, keepdim=True, axis=1)
    # Use the actual batch size instead of BATCH_SIZE so that a shorter final
    # batch does not make paddle.expand fail.
    return paddle.expand(
        gray, [batch.shape[0], 3, batch.shape[-2], batch.shape[-1]])


def _save_preview(generator, src_path, dst_path):
    """Stylise the image at ``src_path`` with ``generator`` and save it.

    Returns the result of ``cv2.imwrite``, or False when the source image
    cannot be read (a warning is printed and training continues).
    """
    bgr = cv2.imread(src_path)
    if bgr is None:
        print("preview skipped, cannot read", src_path)
        return False
    # The dataset loader feeds RGB images to the generator, and the output is
    # converted RGB -> BGR below, so the input must be RGB as well.
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    g_input = rgb.astype('float32') / 127.5 - 1  # same scaling as training
    g_input = g_input[np.newaxis, ...].transpose(0, 3, 1, 2)  # NHWC -> NCHW
    g_input = paddle.to_tensor(g_input)
    generator.eval()
    try:
        with paddle.no_grad():  # inference only, no graph needed
            g_output = generator(g_input)
    finally:
        generator.train()
    g_output = g_output.detach().numpy().transpose(0, 2, 3, 1)[0]  # NCHW -> HWC
    # Undo the scaling; clip so out-of-range values cannot wrap around in uint8.
    g_output = np.clip(g_output * 127.5 + 127.5, 0, 255).astype(np.uint8)
    return cv2.imwrite(dst_path, cv2.cvtColor(g_output, cv2.COLOR_RGB2BGR))


for epoch in range(epoches):
    print("epoch", epoch)
    for data in tqdm(data_loader):
        real_data, anime_data, smooth_data = [part[0] / 127.5 - 1 for part in data]
        # Module-level dict that backward_D / backward_G write their loss into.
        loss_dict = {}
        real_data = paddle.transpose(x=real_data, perm=[0, 3, 1, 2])
        anime_data = paddle.transpose(x=anime_data, perm=[0, 3, 1, 2])
        # If the dataset provides no smooth images, blur the anime batch instead:
        # smooth_data = F.conv2d(anime_data,weight=smooth_knerl,stride=1,padding=1)
        smooth_data = paddle.transpose(x=smooth_data, perm=[0, 3, 1, 2])
        anime_gray_data = _to_gray3(anime_data)
        smooth_data_gray = _to_gray3(smooth_data)
        fake_data = Generator(real_data)

        # Discriminator step.
        optimizer_D.clear_grad()
        d_loss = backward_D(anime_data, anime_gray_data, fake_data, smooth_data_gray)
        optimizer_D.step()

        # Generator step. For generator-only pre-training use
        # backward_G_predictor(real_data, fake_data) here instead.
        optimizer_G.clear_grad()
        g_loss = backward_G(real_data, anime_gray_data, fake_data)
        optimizer_G.step()

        i += 1
        if i % 100 == 0:
            print(i, "D_LOSS", loss_dict["D"], "G_LOSS", loss_dict["G"])
        if i % 1000 == 3:
            save_param_path_g = os.path.join(save_dir_generator, 'Gmodel_state'+str(i)+'.pdparams')
            paddle.save(Generator.state_dict(), save_param_path_g)
            save_param_path_d = os.path.join(save_dir_discriminator, 'Dmodel_state'+str(i)+'.pdparams')
            paddle.save(Discriminator.state_dict(), save_param_path_d)
            _save_preview(
                Generator,
                "test/real.jpg",
                os.path.join(preview_dir, 'epoch'+str(i).zfill(3)+'.png'),
            )
0%| | 0/1664 [00:00<?, ?it/s]
epoch 0
6%|▌ | 100/1664 [00:22<05:56, 4.38it/s]
100 D_LOSS 268.2158 G_LOSS 1606.2267
12%|█▏ | 200/1664 [00:45<05:33, 4.39it/s]
200 D_LOSS 278.59476 G_LOSS 1164.1039
18%|█▊ | 300/1664 [01:08<05:11, 4.38it/s]
300 D_LOSS 299.0925 G_LOSS 1311.1616
24%|██▍ | 400/1664 [01:31<04:46, 4.41it/s]
400 D_LOSS 236.26518 G_LOSS 1162.3252
30%|███ | 500/1664 [01:54<04:29, 4.32it/s]
500 D_LOSS 246.95308 G_LOSS 1331.1404
36%|███▌ | 600/1664 [02:16<04:01, 4.41it/s]
600 D_LOSS 208.72493 G_LOSS 1345.257
42%|████▏ | 700/1664 [02:39<03:39, 4.40it/s]
700 D_LOSS 183.42891 G_LOSS 1042.9739
48%|████▊ | 800/1664 [03:02<03:15, 4.42it/s]
800 D_LOSS 208.8155 G_LOSS 1266.8835
54%|█████▍ | 900/1664 [03:24<02:53, 4.40it/s]
900 D_LOSS 213.62747 G_LOSS 1266.7932
60%|██████ | 1000/1664 [03:47<02:31, 4.40it/s]
1000 D_LOSS 167.21292 G_LOSS 1153.8461
66%|██████▌ | 1100/1664 [04:10<02:10, 4.32it/s]
1100 D_LOSS 170.24477 G_LOSS 1741.643
72%|███████▏ | 1200/1664 [04:33<01:46, 4.34it/s]
1200 D_LOSS 154.22945 G_LOSS 1067.7212
78%|███████▊ | 1300/1664 [04:55<01:22, 4.40it/s]
1300 D_LOSS 182.29414 G_LOSS 1226.3143
84%|████████▍ | 1400/1664 [05:18<00:59, 4.42it/s]
1400 D_LOSS 174.83759 G_LOSS 1040.0066
90%|█████████ | 1500/1664 [05:41<00:37, 4.41it/s]
1500 D_LOSS 159.94745 G_LOSS 1324.5468
96%|█████████▌| 1600/1664 [06:04<00:14, 4.40it/s]
1600 D_LOSS 190.95378 G_LOSS 1079.9465
100%|██████████| 1664/1664 [06:18<00:00, 4.40it/s]
0%| | 0/1664 [00:00<?, ?it/s]
epoch 1
2%|▏ | 36/1664 [00:08<06:09, 4.41it/s]
1700 D_LOSS 165.09032 G_LOSS 1424.0165
8%|▊ | 136/1664 [00:30<05:47, 4.39it/s]
1800 D_LOSS 198.23642 G_LOSS 1292.516
14%|█▍ | 236/1664 [00:53<05:21, 4.44it/s]
1900 D_LOSS 165.20651 G_LOSS 1502.531
20%|██ | 336/1664 [01:16<05:01, 4.40it/s]
2000 D_LOSS 171.64952 G_LOSS 1187.1589
26%|██▌ | 436/1664 [01:39<04:38, 4.41it/s]
2100 D_LOSS 79.426735 G_LOSS 1540.8289
32%|███▏ | 536/1664 [02:02<04:19, 4.35it/s]
2200 D_LOSS 146.80545 G_LOSS 1335.2792
38%|███▊ | 636/1664 [02:24<03:53, 4.40it/s]
2300 D_LOSS 140.99062 G_LOSS 1006.7901
44%|████▍ | 736/1664 [02:47<03:30, 4.41it/s]
2400 D_LOSS 164.46994 G_LOSS 1341.5206
50%|█████ | 836/1664 [03:10<03:08, 4.39it/s]
2500 D_LOSS 161.2462 G_LOSS 1362.1427
56%|█████▋ | 936/1664 [03:33<02:43, 4.45it/s]
2600 D_LOSS 111.48205 G_LOSS 1498.5966
62%|██████▏ | 1036/1664 [03:55<02:22, 4.40it/s]
2700 D_LOSS 131.80838 G_LOSS 1205.2937
68%|██████▊ | 1136/1664 [04:18<01:59, 4.40it/s]
2800 D_LOSS 364.6225 G_LOSS 1377.4451
74%|███████▍ | 1236/1664 [04:41<01:37, 4.37it/s]
2900 D_LOSS 142.73322 G_LOSS 1073.4952
80%|████████ | 1336/1664 [05:04<01:14, 4.41it/s]
3000 D_LOSS 103.9964 G_LOSS 1066.0243
86%|████████▋ | 1436/1664 [05:27<00:51, 4.42it/s]
3100 D_LOSS 174.45412 G_LOSS 1369.2009
92%|█████████▏| 1536/1664 [05:49<00:29, 4.41it/s]
3200 D_LOSS 193.95383 G_LOSS 1415.3037
98%|█████████▊| 1636/1664 [06:12<00:06, 4.40it/s]
3300 D_LOSS 106.95616 G_LOSS 1184.0751
100%|██████████| 1664/1664 [06:18<00:00, 4.39it/s]
0%| | 0/1664 [00:00<?, ?it/s]
epoch 2
4%|▍ | 72/1664 [00:16<06:06, 4.35it/s]
3400 D_LOSS 176.8511 G_LOSS 1245.4153
10%|█ | 172/1664 [00:39<05:37, 4.41it/s]
3500 D_LOSS 171.86816 G_LOSS 1244.4407
16%|█▋ | 272/1664 [01:01<05:17, 4.39it/s]
3600 D_LOSS 173.22856 G_LOSS 1243.9589
22%|██▏ | 372/1664 [01:24<04:53, 4.40it/s]
3700 D_LOSS 136.96036 G_LOSS 1350.8328
28%|██▊ | 472/1664 [01:47<04:29, 4.43it/s]
3800 D_LOSS 192.17505 G_LOSS 1239.1552
34%|███▍ | 572/1664 [02:09<04:10, 4.36it/s]
3900 D_LOSS 121.99713 G_LOSS 1312.6243
40%|████ | 672/1664 [02:32<03:48, 4.35it/s]
4000 D_LOSS 145.77002 G_LOSS 1171.575
46%|████▋ | 772/1664 [02:55<03:25, 4.34it/s]
4100 D_LOSS 165.67157 G_LOSS 1180.1825
52%|█████▏ | 872/1664 [03:18<03:00, 4.38it/s]
4200 D_LOSS 98.46747 G_LOSS 1369.8628
58%|█████▊ | 972/1664 [03:41<02:37, 4.38it/s]
4300 D_LOSS 125.519394 G_LOSS 1235.0101
64%|██████▍ | 1072/1664 [04:04<02:15, 4.37it/s]
4400 D_LOSS 174.17241 G_LOSS 1338.9253
70%|███████ | 1172/1664 [04:27<01:52, 4.39it/s]
4500 D_LOSS 131.71597 G_LOSS 1256.3582
76%|███████▋ | 1272/1664 [04:50<01:30, 4.35it/s]
4600 D_LOSS 221.09642 G_LOSS 1309.5458
82%|████████▏ | 1372/1664 [05:13<01:06, 4.37it/s]
4700 D_LOSS 156.89586 G_LOSS 1288.591
88%|████████▊ | 1472/1664 [05:35<00:43, 4.39it/s]
4800 D_LOSS 148.09262 G_LOSS 1221.0227
94%|█████████▍| 1572/1664 [05:58<00:20, 4.38it/s]
4900 D_LOSS 196.10785 G_LOSS 1029.9869
100%|██████████| 1664/1664 [06:19<00:00, 4.38it/s]
0%| | 0/1664 [00:00<?, ?it/s]
epoch 3
0%| | 8/1664 [00:01<06:16, 4.39it/s]
5000 D_LOSS 186.04994 G_LOSS 1162.1355
6%|▋ | 108/1664 [00:24<05:55, 4.37it/s]
5100 D_LOSS 146.7022 G_LOSS 1145.7506
12%|█▎ | 208/1664 [00:47<05:34, 4.35it/s]
5200 D_LOSS 191.23949 G_LOSS 1084.373
19%|█▊ | 308/1664 [01:10<05:10, 4.37it/s]
5300 D_LOSS 131.53249 G_LOSS 1384.2104
25%|██▍ | 408/1664 [01:33<04:45, 4.40it/s]
5400 D_LOSS 204.34352 G_LOSS 1146.7999
31%|███ | 508/1664 [01:56<04:28, 4.30it/s]
5500 D_LOSS 166.13495 G_LOSS 1386.4236
37%|███▋ | 608/1664 [02:19<04:01, 4.38it/s]
5600 D_LOSS 85.40863 G_LOSS 1302.3384
43%|████▎ | 708/1664 [02:41<03:37, 4.39it/s]
5700 D_LOSS 148.7696 G_LOSS 1342.1403
49%|████▊ | 808/1664 [03:04<03:15, 4.37it/s]
5800 D_LOSS 162.14708 G_LOSS 1047.0885
55%|█████▍ | 908/1664 [03:27<02:54, 4.32it/s]
5900 D_LOSS 138.24677 G_LOSS 1305.394
61%|██████ | 1008/1664 [03:50<02:30, 4.36it/s]
6000 D_LOSS 94.707245 G_LOSS 1124.3021
67%|██████▋ | 1108/1664 [04:13<02:06, 4.39it/s]
6100 D_LOSS 111.90222 G_LOSS 1156.682
73%|███████▎ | 1208/1664 [04:36<01:43, 4.39it/s]
6200 D_LOSS 171.16931 G_LOSS 1222.4845
79%|███████▊ | 1308/1664 [04:59<01:21, 4.39it/s]
6300 D_LOSS 115.44682 G_LOSS 1114.112
85%|████████▍ | 1408/1664 [05:22<00:58, 4.39it/s]
6400 D_LOSS 166.84598 G_LOSS 1048.7089
91%|█████████ | 1508/1664 [05:45<00:35, 4.38it/s]
6500 D_LOSS 119.276924 G_LOSS 1269.8976
97%|█████████▋| 1608/1664 [06:07<00:12, 4.41it/s]
6600 D_LOSS 81.460754 G_LOSS 1189.2717
100%|██████████| 1664/1664 [06:20<00:00, 4.37it/s]
0%| | 0/1664 [00:00<?, ?it/s]
epoch 4
3%|▎ | 44/1664 [00:10<06:11, 4.37it/s]
6700 D_LOSS 132.34291 G_LOSS 1022.5383
9%|▊ | 144/1664 [00:32<05:45, 4.40it/s]
6800 D_LOSS 172.09502 G_LOSS 1062.7837
15%|█▍ | 244/1664 [00:55<05:22, 4.40it/s]
6900 D_LOSS 175.25679 G_LOSS 1357.6571
21%|██ | 344/1664 [01:18<04:59, 4.40it/s]
7000 D_LOSS 115.33646 G_LOSS 1136.9133
27%|██▋ | 444/1664 [01:41<04:39, 4.37it/s]
7100 D_LOSS 109.653 G_LOSS 1140.4409
33%|███▎ | 544/1664 [02:04<04:15, 4.38it/s]
7200 D_LOSS 156.79759 G_LOSS 1333.0219
39%|███▊ | 644/1664 [02:27<03:52, 4.39it/s]
7300 D_LOSS 80.23892 G_LOSS 1114.3611
45%|████▍ | 744/1664 [02:49<03:30, 4.37it/s]
7400 D_LOSS 178.24619 G_LOSS 1648.8625
51%|█████ | 844/1664 [03:12<03:14, 4.22it/s]
7500 D_LOSS 147.22025 G_LOSS 1461.1204
57%|█████▋ | 944/1664 [03:35<02:44, 4.37it/s]
7600 D_LOSS 92.94795 G_LOSS 1225.1282
63%|██████▎ | 1044/1664 [03:58<02:20, 4.43it/s]
7700 D_LOSS 126.23758 G_LOSS 1319.5778
69%|██████▉ | 1144/1664 [04:21<01:58, 4.37it/s]
7800 D_LOSS 114.0454 G_LOSS 1203.5481
75%|███████▍ | 1244/1664 [04:44<01:36, 4.35it/s]
7900 D_LOSS 150.40987 G_LOSS 1317.204
81%|████████ | 1344/1664 [05:07<01:13, 4.37it/s]
8000 D_LOSS 176.27557 G_LOSS 1383.7046
87%|████████▋ | 1444/1664 [05:30<00:50, 4.39it/s]
8100 D_LOSS 165.56274 G_LOSS 1378.1512
93%|█████████▎| 1544/1664 [05:53<00:27, 4.37it/s]
8200 D_LOSS 134.8395 G_LOSS 1217.2303
99%|█████████▉| 1644/1664 [06:15<00:04, 4.37it/s]
8300 D_LOSS 137.5817 G_LOSS 1232.9722
100%|██████████| 1664/1664 [06:20<00:00, 4.37it/s]
0%| | 0/1664 [00:00<?, ?it/s]
epoch 5
4%|▎ | 60/1664 [00:13<06:09, 4.34it/s]
效果展示
图片测试,我只是单纯跑了5epoch左右,你可以试试更长时间
-
输入图片
-
输出图片
# Inference demo: load a trained generator and stylise a single image.
model_state_dict = paddle.load("generator_model/Gmodel_state8003.pdparams")
Generator = AnimeGenerator()
Generator.load_dict(model_state_dict)
Generator.eval()
# Read the input image; cv2.imread returns None instead of raising on failure.
img_A = cv2.imread("src.png")
if img_A is None:
    raise FileNotFoundError("cannot read input image: src.png")
# cv2.imread yields BGR, but the generator is trained on RGB images (the
# dataset loader converts BGR -> RGB) and the output below is converted
# RGB -> BGR before saving, so the input has to be RGB too.
img_A = cv2.cvtColor(img_A, cv2.COLOR_BGR2RGB)
g_input = img_A.astype('float32') / 127.5 - 1  # same scaling as training
g_input = g_input[np.newaxis, ...].transpose(0, 3, 1, 2)  # NHWC -> NCHW
g_input = paddle.to_tensor(g_input)  # numpy -> tensor
print(g_input.shape)
with paddle.no_grad():  # inference only, no graph needed
    g_output = Generator(g_input)
g_output = g_output.detach().numpy()  # tensor -> numpy
g_output = g_output.transpose(0, 2, 3, 1)[0]  # NCHW -> HWC
# Undo the scaling; clip so out-of-range values cannot wrap around in uint8.
g_output = np.clip(g_output * 127.5 + 127.5, 0, 255)
g_output = g_output.astype(np.uint8)
cv2.imwrite('t18003.png', cv2.cvtColor(g_output, cv2.COLOR_RGB2BGR))
W0203 18:19:08.339618 6252 device_context.cc:404] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.0, Runtime API Version: 10.1
W0203 18:19:08.344558 6252 device_context.cc:422] device: 0, cuDNN Version: 7.6.
[1, 3, 634, 996]
True
总结一下:
- 整个项目训练其实不难,就是大致跑个效果是不难的,首先是要生成器预训练,然后再生成器和判别器一起训练,很适合练手。
- 另外这个主要数据集是风景,所以不适合人,另外就是原始图本身倾向于小清新,效果会好很多。
请点击此处查看本环境基本用法.
Please click here for more detailed instructions.
总结
- 这个很好训练。
- 要学会每写完一部分代码就测试一下,不要全部写完再直接跑,否则很容易因为 bug 太多而不好排查。
- 另外如果需要可以使用logwriter进行loss可视化,我这里没有用,觉得没必要,因为这个项目实在有点简单。
- 点个爱心再走呗,手有余香。
更多推荐
所有评论(0)