Paper: Style-Aware Normalized Loss for Improving Arbitrary Style Transfer

1. Paper Introduction

Reference GitHub project (unofficial): https://github.com/Neroxn/Style-Aware-Normalized/tree/main/Project_Kargi

  1. Neural Style Transfer (NST) has quickly evolved from single-style to infinite-style models, also known as Arbitrary Style Transfer (AST)

  2. Neural Style Transfer (NST) refers to generating a pastiche image P from a content image C and a style image S with a neural network, where P shares its content with C but carries the style of S. In this article, NST also refers to AST.

  3. This is the original NST loss:

[Equation: the original NST loss]

The original NST loss consists of two parts, a content term and a style term, where β is a trade-off factor (a hyperparameter that balances the two). A VGG [45] network F pretrained on ImageNet [8] is used to extract features from C, S, and P. The content loss is computed by comparing the features of P and C, and the style loss by comparing the Gram matrices g of the features of P and S; the Gram matrix is widely regarded as an effective carrier of style information. In practice, both terms are computed from features of several layers and then aggregated with a weighted sum (the weights are usually set to 1). MSE denotes mean squared error.
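To make the formula concrete, here is a minimal sketch of the original NST loss, assuming the per-layer feature lists of C, S, and P have already been extracted by the pretrained VGG network F; the layer weights are set to 1 and the value of β is an arbitrary placeholder, not a value from the paper:

import paddle
import paddle.nn as nn

def gram(x):
    b, c, h, w = x.shape
    x = x.reshape((b, c, h * w))
    return paddle.matmul(x, x, transpose_y=True) / (c * h * w)

def nst_loss(content_feats, style_feats, pastiche_feats, beta=10.0):
    """Original NST loss: content MSE + beta * Gram-based style MSE, summed over layers."""
    mse = nn.MSELoss()
    content_loss = paddle.to_tensor(0.)
    style_loss = paddle.to_tensor(0.)
    for fc, fs, fp in zip(content_feats, style_feats, pastiche_feats):
        content_loss += mse(fp, fc)            # content term: compare P and C features
        style_loss += mse(gram(fp), gram(fs))  # style term: compare P and S Gram matrices
    return content_loss + beta * style_loss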

The paper then raises a question about the degree of stylization, see the figure below:

[Figure: examples of under-stylization and over-stylization produced with the classic loss]

With the classic loss, i.e. the original NST loss above, two failure cases appear: under-stylization and over-stylization. If the result is under-stylized, the generated image P lacks the texture of S; if it is over-stylized, P loses much of the content information of C. The Style-Aware Normalized Loss proposed in this paper is designed to alleviate both cases, which are caused by an inappropriate degree of stylization.

The style term of the original arbitrary style transfer loss has further problems, see the figure below:

[Figure: distributions of the classic Gram-based style loss for four AST methods [14,19,29,37]]

This figure shows the distribution of the classic Gram-matrix-based style loss for four AST methods [14,19,29,37]. A smaller loss does not guarantee a better stylization (the two images on the left), a high-quality result can have a large style loss (the two images in the middle), and, counter-intuitively, over-stylized images obtain the largest style loss.

Let us now look more closely at how the original style loss is actually used in training. This is the style loss originally used during training:
[Equation: the classic per-batch style loss, where every image in the batch gets the same weight]

where B is the batch size. In other words, within each batch of training data every image's style loss is given the same weight. This is not very reasonable: different style images call for different degrees of stylization, so the weight of the style loss should also differ from image to image.

'''
The classic (per-batch, equally weighted) style loss
'''
import paddle
import paddle.nn as nn

def gram(x):
    b, c, h, w = x.shape
    x_tmp = x.reshape((b, c, (h * w)))
    gram = paddle.matmul(x_tmp, x_tmp, transpose_y=True)
    return gram / (c * h * w)

def style_loss(style, fake):
    # a single MSE over the whole batch: every image implicitly gets the same weight
    return nn.MSELoss()(gram(style), gram(fake))

It therefore seems more reasonable to design the training-time style loss as follows, so that each image in a batch has its own style-loss weight, treating the batch as a multi-task problem:
[Equation: the per-batch style loss with a separate task weight per image]

This multi-task view goes a long way toward solving the ISR problem: within a batch we assign each style transfer task (the style transfer of each image is treated as its own arbitrary style transfer task) its own "correct" task weight. See below:

[Equation: per-task weights λ_k applied to the per-batch style loss]

where the denominator V_l(S, P) on the right-hand side is specific to each individual task (with λ_k = 1/V_l(S, P)).
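Before continuing with the derivation, here is a minimal sketch, with hypothetical function names, of what this per-image task weighting looks like in code: each image in the batch gets its own weight λ_k applied to its own style loss. The weights below are placeholders, not yet the paper's 1/V_l(S, P):

import paddle

def per_image_style_loss(style_gram, fake_gram):
    # squared Gram difference averaged per image -> shape [B]
    diff = paddle.square(style_gram - fake_gram)
    return diff.flatten(1).mean(axis=1)

def weighted_batch_style_loss(style_gram, fake_gram, weights):
    # weights: shape [B], one task weight lambda_k per image in the batch
    return (per_image_style_loss(style_gram, fake_gram) * weights).mean()

# example: Gram matrices for a batch of 4, with arbitrary placeholder task weights
g_style = paddle.randn([4, 64, 64])
g_fake = paddle.randn([4, 64, 64])
lam = paddle.to_tensor([1.0, 0.5, 2.0, 1.0])
print(weighted_batch_style_loss(g_style, g_fake, lam))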

Since the classic per-layer AST style loss is

[Equation: the classic per-layer AST style loss, the MSE between the Gram matrices of S and P at layer l, i.e. ||G_S - G_P||_F^2 / N_l]

we first derive the theoretical upper and lower bounds of this per-layer loss, as shown below:

[Equation: theoretical lower and upper bounds of the per-layer style loss; the upper bound is proportional to (||G_S||_F^2 + ||G_P||_F^2) / N_l]

The paper then uses the theoretical upper bound of the loss as V_l(S, P), which gives the final result:

[Equation: the final style-aware normalized loss, the per-layer Gram MSE divided by (||G_S||_F^2 + ||G_P||_F^2) / N_l]
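A quick numerical sanity check of the bound used here, assuming the upper bound has the form 2(||G_S||_F^2 + ||G_P||_F^2) / N_l (my reading of the derivation; the 1/N_l factor cancels on both sides): the squared Frobenius norm of a difference of two matrices never exceeds twice the sum of their squared norms, so the normalized per-layer loss stays in a fixed range regardless of feature scale.

import paddle

def gram(x):
    b, c, h, w = x.shape
    x = x.reshape((b, c, h * w))
    return paddle.matmul(x, x, transpose_y=True) / (c * h * w)

gs = gram(paddle.randn([4, 64, 32, 32]))   # "style" Gram matrices
gp = gram(paddle.randn([4, 64, 32, 32]))   # "pastiche" Gram matrices

diff_sq = paddle.square(gs - gp).flatten(1).sum(axis=1)                       # ||G_S - G_P||_F^2 per image
bound = 2 * (paddle.square(gs) + paddle.square(gp)).flatten(1).sum(axis=1)    # 2(||G_S||_F^2 + ||G_P||_F^2)
print(bool((diff_sq <= bound).all()))  # expected: True for any inputs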

Moreover, the classic per-layer AST style loss and its theoretical upper bound are positively, linearly correlated, as the two figures below show:

[Figures: scatter plots of the classic per-layer style loss against its theoretical upper bound]

For the new style-balanced loss proposed in the paper, a smaller value corresponds to a stronger degree of stylization.

[Figure: stylization degree increases as the style-balanced loss decreases]

The loss as reproduced in the GitHub project mentioned above:

import torch

def weighted_mse_loss(input, target, weights=None):
    assert input.size() == target.size()
    size = input.size()
    if weights is None:
        # default to equal weights, one per batch element
        weights = torch.ones(size[0], device=input.device)

    if len(size) == 3:  # gram matrix is B,C,C
        se = (input.view(size[0], -1) - target.view(size[0], -1)) ** 2
        return (se.mean(dim=1) * weights).mean()

def gram_matrix(x, normalize=True):
    '''
    Generate gram matrices of the representations of content and style images.
    '''
    (b, ch, h, w) = x.size()
    features = x.view(b, ch, w * h)
    features_t = features.transpose(1, 2)
    gram = features.bmm(features_t)
    if normalize:
        gram /= ch * h * w
    return gram

def calc_ast_style_loss_normalized(self, input, target):
    # this replaces the style loss for all AST models in the training
    # (excerpted from a model class, so `self` is kept but unused here)
    G1 = gram_matrix(input, False)
    G2 = gram_matrix(target, False).detach()  # we don't need the gradient of the target

    size = input.size()
    assert (len(size) == 4)

    g1_norm = torch.linalg.norm(G1, dim=(1, 2))
    g2_norm = torch.linalg.norm(G2, dim=(1, 2))

    size = G1.size()
    Nl = size[1] * size[2]  # or C x C = C^2
    normalize_term = (torch.square(g1_norm) + torch.square(g2_norm)) / Nl

    weights = 1 / normalize_term
    return weighted_mse_loss(G1, G2, weights)
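For reference, a quick hypothetical smoke test of the reproduced functions on random features; since calc_ast_style_loss_normalized was lifted out of a class, None is passed for the unused self argument:

import torch

feat_fake = torch.randn(4, 64, 256, 256)
feat_style = torch.randn(4, 64, 256, 256)
print(calc_ast_style_loss_normalized(None, feat_fake, feat_style))  # a scalar tensor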

My understanding of the paper's loss design:

import paddle

def gram(x):
    b, c, h, w = x.shape
    x_tmp = x.reshape((b, c, (h * w)))
    gram = paddle.matmul(x_tmp, x_tmp, transpose_y=True)
    return gram / (c * h * w)

def style_loss(style, fake):
    style_gram = gram(style)
    fake_gram = gram(fake)
    # per-image normalization term: (||G_S||_F^2 + ||G_P||_F^2) / N_l, shape [B]
    weight = paddle.mean((paddle.square(style_gram) + paddle.square(fake_gram)).flatten(1), axis=1)
    print("weight.shape", weight.shape)
    # per-image Gram MSE divided by its own normalization term, then averaged over the batch
    loss = (paddle.mean(paddle.square(style_gram - fake_gram).flatten(1), axis=1) / weight).mean()
    return loss

style_feat = paddle.randn([4,64,256,256])
fake_feat = paddle.randn([4,64,256,256])

style_loss(style_feat,fake_feat)
weight.shape [4]
Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True,
       [0.00099442])

2. Verifying the paper's conclusion with its loss function in practice

2.1 Putting together a simple AdaIN model

from VGG_Model import VGG19
import paddle
import paddle.nn as nn 
VGG = VGG19()
x = paddle.randn([4,3,256,256])
b = VGG(x)
for i in b:
    print(i.shape)
[4, 64, 256, 256]
[4, 128, 128, 128]
[4, 256, 64, 64]
[4, 512, 32, 32]
def calc_mean_std(feat, eps=1e-5):
    # eps is a small value added to the variance to avoid divide-by-zero.
    size = feat.shape
    assert (len(size) == 4)
    N, C = size[:2]
    feat_var = feat.reshape([N, C, -1]).var(axis=2) + eps
    feat_std = feat_var.sqrt().reshape([N, C, 1, 1])
    feat_mean = feat.reshape([N, C, -1]).mean(axis=2).reshape([N, C, 1, 1])
    return feat_mean, feat_std


def adaptive_instance_normalization(content_feat, style_feat):
    assert (content_feat.shape[:2] == style_feat.shape[:2])
    size = content_feat.shape
    style_mean, style_std = calc_mean_std(style_feat)
    content_mean, content_std = calc_mean_std(content_feat)

    normalized_feat = (content_feat - content_mean.expand(
        size)) / content_std.expand(size)
    return normalized_feat * style_std.expand(size) + style_mean.expand(size)
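As a quick sanity check of what AdaIN does here (not part of the original notebook): after re-normalization, the channel-wise mean and std of the output should match those of the style features, up to the eps term:

# sanity check with random features (shapes chosen arbitrarily)
c_feat = paddle.randn([2, 512, 32, 32])
s_feat = paddle.randn([2, 512, 32, 32])
out = adaptive_instance_normalization(c_feat, s_feat)
out_mean, out_std = calc_mean_std(out)
s_mean, s_std = calc_mean_std(s_feat)
print(float(paddle.abs(out_mean - s_mean).max()), float(paddle.abs(out_std - s_std).max()))  # both close to 0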

class Model(nn.Layer):
    def __init__(self):
        super().__init__()
        self.vgg = VGG19()
        self.decoder = nn.Sequential(
            nn.Pad2D((1, 1, 1, 1), mode='reflect'),
            nn.Conv2D(512, 256, (3, 3)),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Pad2D((1, 1, 1, 1), mode='reflect'),
            nn.Conv2D(256, 256, (3, 3)),
            nn.ReLU(),
            nn.Pad2D((1, 1, 1, 1), mode='reflect'),
            nn.Conv2D(256, 256, (3, 3)),
            nn.ReLU(),
            nn.Pad2D((1, 1, 1, 1), mode='reflect'),
            nn.Conv2D(256, 256, (3, 3)),
            nn.ReLU(),
            nn.Pad2D((1, 1, 1, 1), mode='reflect'),
            nn.Conv2D(256, 128, (3, 3)),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Pad2D((1, 1, 1, 1), mode='reflect'),
            nn.Conv2D(128, 128, (3, 3)),
            nn.ReLU(),
            nn.Pad2D((1, 1, 1, 1), mode='reflect'),
            nn.Conv2D(128, 64, (3, 3)),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Pad2D((1, 1, 1, 1), mode='reflect'),
            nn.Conv2D(64, 64, (3, 3)),
            nn.ReLU(),
            nn.Pad2D((1, 1, 1, 1), mode='reflect'),
            nn.Conv2D(64, 3, (3, 3)),
            nn.Tanh()
        )
        self.adain = adaptive_instance_normalization
    def forward(self,content,style):
        content_feat = self.vgg(content)
        style_feat = self.vgg(style)
        x = self.adain(content_feat[-1],style_feat[-1])
        x = self.decoder(x)
        return x
x = paddle.randn([4,3,256,256])
Model()(x,x).shape
[4, 3, 256, 256]
import cv2
from matplotlib import image
import numpy as np
import os
import paddle
import paddle.optimizer
import paddle.nn as nn
import math
from tqdm import tqdm
from paddle.io import Dataset
from paddle.io import DataLoader
import json
import paddle.nn.functional as F
import paddle.tensor as tensor
from paddle.vision.datasets import ImageFolder
from paddle.vision.transforms import Compose, ColorJitter, Resize
import random
from visualdl import LogWriter
smooth_knerl = tensor.to_tensor([[[[1/9 for i in range(3)]for i in range(3)]for i in range(3)]for i in range(3)])
log_writer = LogWriter("./log/gnet")

2.2 Data preparation


import os
if not os.path.isdir("./data/d"):
    os.mkdir("./data/d")
! unzip -qo data/data105513/photo2vangogh.zip -d ./data/d
from paddle.vision.transforms import CenterCrop,Resize
transform = Resize((256,256))
# build the dataset
IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]
import paddle
import cv2
import os
def data_maker(dir):
    images = []
    assert os.path.isdir(dir), '%s is not a valid directory' % dir

    for root, _, fnames in sorted(os.walk(dir)):
        for fname in fnames:
            if is_image_file(fname) and ("outfit" not in fname):
                path = os.path.join(root, fname)
                images.append(path)

    return sorted(images)

def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


class AnimeDataset(paddle.io.Dataset):
    """
    """
    def __init__(self):
        super(AnimeDataset,self).__init__()
        real_image_dirs =data_maker( "data/d/trainA")
        anime_image_dirs =data_maker("data/d/trainB")
        self.size =min( len(anime_image_dirs),len(real_image_dirs))

        self.real_image_dirs =real_image_dirs[:self.size]
        self.anime_image_dirs =anime_image_dirs[:self.size]
    # cv2.imread reads images as BGR; convert the channels to RGB
    @staticmethod
    def loader(path):
        return cv2.cvtColor(cv2.imread(path, flags=cv2.IMREAD_COLOR),
                            cv2.COLOR_BGR2RGB)
    def __getitem__(self, index):
        img_a = AnimeDataset.loader(self.real_image_dirs[index])
        img_a =transform(img_a)
        img_b = AnimeDataset.loader(self.anime_image_dirs[index])
        img_b = transform(img_b)

        return img_a,img_b

    def __len__(self):
        return self.size

dataset = AnimeDataset()
for real,anime in dataset:
    print(real.shape)
    print(anime.shape)
    break



BATCH_SIZE =2
dataset = AnimeDataset()
data_loader = paddle.io.DataLoader(dataset,shuffle= True,batch_size=BATCH_SIZE,drop_last =True)
(256, 256, 3)
(256, 256, 3)
A quick check of the map/zip pattern used later to record (classic loss, upper bound) pairs per image:

x = paddle.randn([4])
y = paddle.randn([4])
list(map(lambda a:(float(a[0]),float(a[1])),zip(x,y)))
[(-0.27411192655563354, -0.7972860336303711),
 (-0.8477393984794617, -0.09127818793058395),
 (0.6167448163032532, -0.9844359159469604),
 (0.02306453324854374, 0.392347127199173)]

2.3 Training


generator = Model()
# load the saved model (and discriminator) parameter files
M_path ='model_params/Mmodel_state1.pdparams'
layer_state_dictm = paddle.load(M_path)
generator.set_state_dict(layer_state_dictm)

scheduler_G = paddle.optimizer.lr.StepDecay(learning_rate=0.0002, step_size=400, gamma=0.95, verbose=True)
# scheduler_D = paddle.optimizer.lr.StepDecay(learning_rate=0.0001, step_size=3, gamma=0.99, verbose=True)

optimizer_G = paddle.optimizer.Adam(learning_rate=scheduler_G,parameters=generator.parameters(),beta1=0.5)
VGG = VGG19()
Epoch 0: StepDecay set learning rate to 0.0002.

To verify that the classic style loss and its theoretical upper bound are linearly related, style_loss is modified to also return jilu_list, which records, for each image, the pair (classic per-layer style loss, its upper-bound term):

def gram(x):
    b, c, h, w = x.shape
    x_tmp = x.reshape((b, c, (h * w)))
    gram = paddle.matmul(x_tmp, x_tmp, transpose_y=True)
    return gram / (c * h * w)
def style_loss(style, fake):
    style_gram = gram(style)
    fake_gram = gram(fake)
    # per-image upper-bound term (||G_S||_F^2 + ||G_P||_F^2) / N_l
    weight = paddle.mean((paddle.square(style_gram) + paddle.square(fake_gram)).flatten(1), axis=1)
    # per-image classic Gram MSE
    a = paddle.mean(paddle.square(style_gram - fake_gram).flatten(1), axis=1)
    # record (classic loss, upper-bound term) pairs for the linearity check
    jilu_list = list(map(lambda a: (float(a[0]), float(a[1])), zip(a, weight)))
    loss = (a / weight).mean()
    return loss, jilu_list

mse_loss = nn.MSELoss()
from tqdm import tqdm
step = 0
save_dir_model =  "model_params"
EPOCHES = 10000
all_jilu_list= [] 
for epoch in  range(EPOCHES):
    for data in tqdm(data_loader):
        try:
            real_data,anime_data = [i/127.5-1 for i in data]
            real_data =paddle.transpose(x=real_data,perm=[0,3,1,2])
            anime_data =paddle.transpose(x=anime_data,perm=[0,3,1,2])
            # print(real_data.shape)
            # print(anime_data.shape)
            img_fake = generator(anime_data,real_data)
            img_cc = generator(real_data,real_data)
            img_ss = generator(anime_data,anime_data)
            
            
            g_vggloss = paddle.to_tensor(0.)
            g_styleloss= paddle.to_tensor(0.)
            g_ccloss= paddle.to_tensor(0.)
            g_ssloss= paddle.to_tensor(0.)
            
            rates = [1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]

            fake_features = VGG(img_fake)
            real_features = VGG(real_data)
            anime_features = VGG(anime_data)
            cc_features = VGG(img_cc)
            ss_features = VGG(img_ss)

            for i in range(len(fake_features)):
                a, b = fake_features[i], real_features[i]
                b.stop_gradient = True
                g_vggloss += rates[i] * mse_loss(a, b)
            # apply the scale factor once, after accumulating over all layers
            g_vggloss /= 3000

            for i in range(len(anime_features)):
                a, b = fake_features[i], anime_features[i]
                b.stop_gradient = True
                styleloss, jilu_list = style_loss(a, b)
                all_jilu_list += jilu_list
                g_styleloss += rates[i] * styleloss
            g_styleloss *= 3

            for i in range(len(cc_features)):
                a, b = cc_features[i], real_features[i]
                b.stop_gradient = True
                g_ccloss += rates[i] * mse_loss(a, b)
            g_ccloss /= 10000

            for i in range(len(ss_features)):
                a, b = ss_features[i], anime_features[i]
                b.stop_gradient = True
                g_ssloss += rates[i] * mse_loss(a, b)
            g_ssloss /= 10000
            
            id1 = (mse_loss(img_cc,real_data)+mse_loss(img_ss,anime_data))*5
            mse = nn.MSELoss()(img_fake,real_data)*30   
            # init_loss = mse + g_vggloss
            init_loss = mse + g_vggloss + g_styleloss + g_ccloss + g_ssloss + id1
            # init_loss =  g_ccloss + g_ssloss + id1

            init_loss.backward()
            optimizer_G.step()
            optimizer_G.clear_grad()


            if step%2==0:

                log_writer.add_scalar(tag='train/g_vggloss', step=step, value=g_vggloss.numpy()[0])
                log_writer.add_scalar(tag='train/mse', step=step, value=mse.numpy()[0])
                log_writer.add_scalar(tag='train/g_ssloss', step=step, value=g_ssloss.numpy()[0])
                log_writer.add_scalar(tag='train/g_ccloss', step=step, value=g_ccloss.numpy()[0])
                log_writer.add_scalar(tag='train/g_id1loss', step=step, value=id1.numpy()[0])
                log_writer.add_scalar(tag='train/g_loss', step=step, value=init_loss.numpy()[0])
                log_writer.add_scalar(tag='train/g_styleloss', step=step, value=g_styleloss.numpy()[0])



            step+=1
            # print(i)
            if step%100 == 3:
                print(step,"g_vggloss",g_vggloss.numpy()[0],"g_mseloss",mse.numpy()[0],"g_styleloss",g_styleloss.numpy()[0],"id1",id1.numpy()[0] ,\
                "g_ccloss",g_ccloss.numpy(),"g_ssloss",g_ssloss.numpy(),"g_loss",init_loss.numpy()[0])
                # print(step,"dreal_loss",dr_ganloss.numpy()[0],"dfake_loss",df_ganloss.numpy()[0],"dseg_ganloss",dseg_ganloss.numpy()[0],"d_all_loss",d_loss.numpy()[0])

                real_data = (real_data+1)*127.5
                anime_data = (anime_data+1)*127.5
                img_fake = (img_fake+1)*127.5

                g_output = paddle.clip(paddle.concat([img_fake,real_data,anime_data],axis = 3),0,255).detach().numpy()                      # tensor -> numpy
                g_output = g_output.transpose(0, 2, 3, 1)[0]             # NCHW -> NHWC
                # g_output = (g_output+1) * 127.5                      # de-normalize (already applied above)
                g_output = g_output.astype(np.uint8)
                cv2.imwrite(os.path.join("./result", 'epoch'+str(step).zfill(3)+'.png'),cv2.cvtColor(g_output,cv2.COLOR_RGB2BGR))


            
            if step%100 == 3:

                # save_param_path_d = os.path.join(save_dir_Discriminator, 'Dmodel_state'+str(3)+'.pdparams')
                # paddle.save(discriminator.state_dict(), save_param_path_d)

                save_param_path_m = os.path.join(save_dir_model, 'Mmodel_state'+str(1)+'.pdparams')
                paddle.save(generator.state_dict(), save_param_path_m)
        except Exception:
            # skip batches that raise errors (e.g. unreadable images) instead of aborting training
            pass
    jilu_str = json.dumps(all_jilu_list)
    f = open("jilu.txt","w+")
    f.writelines(jilu_str)
    f.close()
    scheduler_G.step()

2.4 Visualization

import matplotlib.pyplot as plt
import json

f = open("jilu.txt")


shuju = json.loads(f.read())
x = [i[0] for i in shuju]  # classic per-layer style loss
y = [i[1] for i in shuju]  # its theoretical upper-bound term

plt.title("data")
plt.scatter(x, y, s=100, c='b', marker='+', label='data1')
plt.show()

[Figure: scatter plot of the classic style loss against its theoretical upper bound]

3. Summary

As the figure above shows, the classic style loss is indeed, overall, linearly related to its theoretical upper bound. Since the style loss can be normalized this way, could the content loss be normalized as well? That is a point worth exploring in the future.
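Purely as a speculative sketch of that question (this is not from the paper), a "normalized" content loss might, by analogy, divide each image's content feature MSE by a per-image magnitude term:

import paddle

def normalized_content_loss(content_feat, fake_feat):
    # hypothetical analogue of the style-aware normalization:
    # per-image feature MSE divided by a per-image magnitude term
    diff = paddle.square(content_feat - fake_feat).flatten(1).mean(axis=1)
    scale = (paddle.square(content_feat) + paddle.square(fake_feat)).flatten(1).mean(axis=1)
    return (diff / scale).mean()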

This article is a repost.
Original project link
