Reproducing StyleGAN Feature Interpolation

This tutorial makes small modifications to the StyleGANv2 network to blend two images together.

Reference paper: StyleGAN of All Trades: Image Manipulation with
Only Pretrained StyleGAN

1. A Brief Introduction to StyleGAN

Before GANs appeared, computer graphics already had an important branch called texture transfer.
Adobe has contributed a number of papers in this area; the typical approach is to do texture transfer first and then image reconstruction.
Gatys then brought neural networks into texture transfer, founding neural style transfer.

For face images, the "style" in StyleGAN refers to the main attributes of the faces in the dataset, such as pose, expression, face orientation, and hairstyle, as well as texture-level details like skin tone and lighting. StyleGAN uses style to control pose, identity, and similar attributes, and uses noise to control fine details such as hair strands, wrinkles, and skin texture.

The StyleGAN network has two parts. The first is the mapping network (the left part of figure (b) below), which maps the latent variable z to an intermediate latent variable w; this w is what controls the style of the generated image. The second is the synthesis network, which generates the image. Its novelty is that every layer of the sub-network is fed both A and B: A is derived from w and controls the style of the generated image, while B is transformed random noise that enriches the image's details. In other words, every convolutional layer can adjust its "style" according to the input A.
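
To make the z → w mapping concrete, here is a minimal NumPy sketch of a mapping-network-style MLP. The 8 fully connected layers and 512-dimensional latents follow the StyleGAN paper, but the `mapping_network` helper and its weight initialization are illustrative assumptions, not ppgan's actual implementation:

```python
import numpy as np

def mapping_network(z, weights, biases, alpha=0.2):
    """Map latent z to intermediate latent w through an MLP
    with leaky-ReLU activations (illustration only)."""
    # Normalize z first (PixelNorm-style, as in StyleGAN).
    x = z / np.sqrt((z ** 2).mean() + 1e-8)
    for W, b in zip(weights, biases):
        x = x @ W + b
        x = np.where(x > 0, x, alpha * x)  # leaky ReLU
    return x

rng = np.random.default_rng(0)
dim, n_layers = 512, 8  # sizes from the StyleGAN paper
weights = [rng.normal(0, 0.01, (dim, dim)) for _ in range(n_layers)]
biases = [np.zeros(dim) for _ in range(n_layers)]
z = rng.normal(size=(1, dim))
w = mapping_network(z, weights, biases)
print(w.shape)  # (1, 512)
```

The resulting w is then broadcast (as A) to every synthesis layer, which is exactly what the per-layer latent indexing in the code below relies on.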

Improvements in StyleGANv2

StyleGANv2's improvements over StyleGAN cut out some unnecessary computation, and include:

  1. An improved network structure, as shown in the figure below

  2. Lazy regularization

The regularization terms are optimized only once every 16 minibatches rather than on every iteration, which reduces computation.

  3. No progressive growing

When StyleGAN is trained on high-resolution images, training starts at low resolution and the resolution is increased step by step once training stabilizes, with each resolution producing its own output. This causes high-frequency details to stick to fixed image locations instead of varying smoothly under movement. Progressive growing was used in the first place because generating high-resolution images requires a large, deep network that is hard to train; skip connections, however, make deep networks trainable and avoid the artifacts above.
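
The lazy-regularization schedule from item 2 can be sketched as follows. The helper names `apply_r1_this_step` and `effective_r1_weight` are hypothetical; rescaling the penalty weight by the interval (so the expected gradient matches applying it every step) follows the StyleGAN2 paper:

```python
def apply_r1_this_step(step, reg_interval=16):
    """Lazy regularization: only run the (expensive) R1 gradient
    penalty once every `reg_interval` minibatches."""
    return step % reg_interval == 0

def effective_r1_weight(gamma, reg_interval=16):
    """Scale the penalty weight up by the interval so its average
    contribution over training is unchanged."""
    return gamma * reg_interval

# Over 64 training steps, the penalty runs only 4 times.
steps_with_reg = [s for s in range(64) if apply_r1_this_step(s)]
print(steps_with_reg)  # [0, 16, 32, 48]
```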

2. Feature Interpolation

StyleGAN performs remarkably well on a variety of image processing and editing tasks. However, retraining it from scratch for each different task costs a great deal of time and resources. It has therefore been proposed that, starting from a pretrained StyleGAN model, a few small changes to the latent variable w mentioned above are enough to apply StyleGAN directly to many tasks, including panorama generation, generation from a single image, feature interpolation, and image-to-image translation.

Pasting together StyleGAN intermediate layers can mix information from two images, but when the two images differ too much, the results are often unsatisfactory. Feature interpolation is the method we use to address this problem.

The method: at each StyleGAN layer, generate intermediate features $f_i^A$ and $f_i^B$ from different noise inputs, then smoothly blend the two with the following formula before passing the result on to the next convolutional layer.

\begin{align}
f_i = (1-\alpha)f_i^A+\alpha f_i^B
\end{align}

where $\alpha \in [0, 1]$ is a mask; for horizontal blending, the mask grows from left to right.
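
A minimal NumPy sketch of this interpolation for horizontal blending, assuming NCHW feature maps. The `horizontal_blend` helper and the quarter-width flat regions at each side (matching the mask built in the `blend` method later) are illustrative choices:

```python
import numpy as np

def horizontal_blend(feat_a, feat_b, pad_frac=0.25):
    """Blend two NCHW feature maps with a mask alpha that is 0 on the
    left quarter, 1 on the right quarter, and ramps linearly between,
    implementing f_i = (1 - alpha) * f_A + alpha * f_B."""
    width = feat_a.shape[3]
    pad = int(width * pad_frac)
    alpha = np.zeros(width)
    alpha[-pad:] = 1.0
    alpha[pad:-pad] = np.linspace(0, 1, width - 2 * pad)
    alpha = alpha.reshape(1, 1, 1, -1)  # broadcast over N, C, H
    return (1 - alpha) * feat_a + alpha * feat_b

fa = np.zeros((1, 3, 8, 8))  # stand-in for f_i^A
fb = np.ones((1, 3, 8, 8))   # stand-in for f_i^B
out = horizontal_blend(fa, fb)
print(out[0, 0, 0, 0], out[0, 0, 0, -1])  # 0.0 on the far left, 1.0 on the far right
```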

The results from the paper are shown below.

This method blends the two images well and, compared with other published results, shows no obvious artifacts. The paper also reports that 87.6% of surveyed users preferred this method.

This tutorial implements direct blending of two images, and, along the way, also implements the style mixing part. The two are similar: both come down to applying some processing to the latent variable w.

3. Installing Dependencies and Setting Up the Environment

Here we use the pretrained face-generation StyleGANv2 model from ppgan, so ppgan must be installed first.

!pip install ppgan
import os
import cv2
import random
import numpy as np
import paddle
from ppgan.apps.styleganv2_predictor import StyleGANv2Predictor

4. Modifying the Model Network

This part implements both style mixing and feature interpolation. In both methods, the inputs latent1 and latent2 are the latent variables w. mixing blends the two latents layer by layer with per-layer weights, while blend interpolates the intermediate features at every synthesis layer with a spatial mask, using the formula above.

def make_image(tensor):
    # Convert an NCHW tensor in [-1, 1] to a uint8 NHWC numpy image.
    return (((tensor.detach() + 1) / 2 * 255).clip(min=0, max=255).transpose(
        (0, 2, 3, 1)).numpy().astype('uint8'))

class StyleGANv2FeatureInterpolation(StyleGANv2Predictor):
    @paddle.no_grad()
    def mixing(self, latent1, latent2, alpha):
        """
        Style mixing: blend two latents w layer by layer with
        per-layer weights alpha, then generate all three images.
        """
        assert latent1.shape[1] == latent2.shape[1] == len(
            alpha
        ), 'latents and their weights (alpha) should have the same level nums.'
        # Blend the latents level by level: weight a for latent1,
        # (1 - a) for latent2.
        mix_latent = []
        for i, a in enumerate(alpha):
            mix_latent.append(latent1[:, i:i + 1] * a +
                              latent2[:, i:i + 1] * (1 - a))
        mix_latent = paddle.concat(mix_latent, 1)
        # Batch the two source latents and the mixed latent together.
        latent_n = paddle.concat([latent1, latent2, mix_latent], 0)
        img_gen, _ = self.generator([latent_n],
                                    input_is_latent=True,
                                    randomize_noise=False)
        print('------------generate done-----------------------')
        imgs = make_image(img_gen)
        src_img1 = imgs[0]
        src_img2 = imgs[1]
        dst_img = imgs[2]

        os.makedirs(self.output_path, exist_ok=True)
        save_src_path = os.path.join(self.output_path, 'src1.mixing2.png')
        cv2.imwrite(save_src_path, cv2.cvtColor(src_img1, cv2.COLOR_RGB2BGR))
        save_src_path = os.path.join(self.output_path, 'src2.mixing2.png')
        cv2.imwrite(save_src_path, cv2.cvtColor(src_img2, cv2.COLOR_RGB2BGR))
        save_dst_path = os.path.join(self.output_path, 'dst.mixing2.png')
        cv2.imwrite(save_dst_path, cv2.cvtColor(dst_img, cv2.COLOR_RGB2BGR))

        return src_img1, src_img2, dst_img

    @paddle.no_grad()
    def blend(self, latent1, latent2, mode='horizontal'):
        """
        Blend two images by interpolating the intermediate features
        at every synthesis layer with a spatial mask.
        """
        generator = self.generator
        noise = [
            getattr(generator.noises, f'noise_{i}')
            for i in range(generator.num_layers)
        ]

        assert mode in ('vertical', 'horizontal')
        if mode == 'vertical':
            view_size = (1, 1, -1, 1)
        else:
            view_size = (1, 1, 1, -1)

        def interpolate(feat1, feat2):
            # Mask that is 0 in the first quarter, 1 in the last
            # quarter, and ramps linearly in between. StyleGAN feature
            # maps are square, so shape[2] works for both modes.
            size = feat1.shape[2]
            alpha = paddle.zeros([size])
            pad = size // 4
            alpha[-pad:] = 1
            alpha[pad:-pad] = paddle.linspace(0, 1, size - 2 * pad)
            alpha = alpha.reshape(view_size).expand_as(feat1)
            return (1 - alpha) * feat1 + alpha * feat2

        out = generator.input(latent1)

        # First convolution and first to-RGB layer.
        out1 = generator.conv1(out, latent1[:, 0], noise=noise[0])
        out2 = generator.conv1(out, latent2[:, 0], noise=noise[0])
        out = interpolate(out1, out2)

        skip1 = generator.to_rgb1(out, latent1[:, 1])
        skip2 = generator.to_rgb1(out, latent2[:, 1])
        skip = interpolate(skip1, skip2)

        # Remaining layers: two convolutions plus a to-RGB per block,
        # interpolating the two feature streams after each one.
        i = 1
        for conv1, conv2, noise1, noise2, to_rgb in zip(
                generator.convs[::2], generator.convs[1::2], noise[1::2],
                noise[2::2], generator.to_rgbs):
            out1 = conv1(out, latent1[:, i], noise=noise1)
            out2 = conv1(out, latent2[:, i], noise=noise1)
            out = interpolate(out1, out2)

            out1 = conv2(out, latent1[:, i + 1], noise=noise2)
            out2 = conv2(out, latent2[:, i + 1], noise=noise2)
            out = interpolate(out1, out2)

            skip1 = to_rgb(out, latent1[:, i + 2], skip)
            skip2 = to_rgb(out, latent2[:, i + 2], skip)
            skip = interpolate(skip1, skip2)

            i += 2
        return skip

5. Image Generation

def save_img(img, output_path, file_name):
    dst_img = make_image(img)[0]
    os.makedirs(output_path, exist_ok=True)
    save_dst_path = os.path.join(output_path, file_name)
    cv2.imwrite(save_dst_path, cv2.cvtColor(dst_img, cv2.COLOR_RGB2BGR))


def mixing(weight=[0.5]*18):
    output_path = './mixing'
    model_type = 'ffhq-config-f'
    size = 1024
    style_dim = 512
    n_mlp = 8
    channel_multiplier = 2
    device = paddle.device.get_device()
    paddle.device.set_device(device)
    predictor = StyleGANv2FeatureInterpolation(
        output_path=output_path,
        weight_path=None,
        model_type=model_type,
        seed=None,
        size=size,
        style_dim=style_dim,
        n_mlp=n_mlp,
        channel_multiplier=channel_multiplier)
    batch = 1

    noise1 = [paddle.randn([batch, style_dim])]
    noise2 = [paddle.randn([batch, style_dim])]
    img1, latent1 = predictor.generator(noise1, return_latents=True, randomize_noise=False)
    img2, latent2 = predictor.generator(noise2, return_latents=True, randomize_noise=False)
    predictor.mixing(latent1, latent2, weight)

def merge_two_image():
    output_path = './merge'
    model_type = 'ffhq-config-f'
    size = 1024
    style_dim = 512
    n_mlp = 8
    channel_multiplier = 2
    device = paddle.device.get_device()
    paddle.device.set_device(device)
    predictor = StyleGANv2FeatureInterpolation(
        output_path=output_path,
        weight_path=None,
        model_type=model_type,
        seed=None,
        size=size,
        style_dim=style_dim,
        n_mlp=n_mlp,
        channel_multiplier=channel_multiplier)

    batch = 1

    noise1 = [paddle.randn([batch, style_dim])]
    noise2 = [paddle.randn([batch, style_dim])]
    img1, latent1 = predictor.generator(noise1, return_latents=True, randomize_noise=False)
    img2, latent2 = predictor.generator(noise2, return_latents=True, randomize_noise=False)

    mix_img = predictor.blend(latent1, latent2)
    save_img(img1, output_path, 'src_img1.png')
    save_img(img2, output_path, 'src_img2.png')
    save_img(mix_img, output_path, 'mix_img.png')

if __name__ == "__main__":
    weight = np.random.uniform(0, 1, 18).tolist()
    print(weight)
    mixing(weight)
    merge_two_image()
    
[12/14 17:02:56] ppgan INFO: Found /home/aistudio/.cache/ppgan/stylegan2-ffhq-config-f.pdparams
------------generate done-----------------------
[12/14 17:03:00] ppgan INFO: Found /home/aistudio/.cache/ppgan/stylegan2-ffhq-config-f.pdparams

6. Results

Image 1

Image 2

Style-mixed image

Results of blending two images:

Image 1

Image 2

Horizontally blended image

7. Summary

The paper describes several other methods that are not implemented one by one here; interested readers can read the paper and try implementing them.

The main difficulty in implementing this method is probably not knowing the dimensions of the features after each layer; dimension errors came up frequently during implementation. Pinning down these details takes repeated rereading of the paper and repeated debugging.

References:

  1. In-depth understanding of StyleGAN and StyleGAN2
  2. Someone finally made a grand roundup of StyleGANs of every kind | Hot on Reddit
  3. StyleGAN of All Trades: Image Manipulation with Only Pretrained StyleGAN

This article is a repost.
Link to the original project
