用文字创造无尽图像可能,俄语版Dall-E预训练模型上手

(内含自动翻译,可输中文)

瞧一瞧,看一看,这dalle真奇妙,写入你所念,绘出你所想。

脑洞大开的你,可千万别错过。

“鳄梨形椅子”
“漂亮动漫女孩”
“A group of Shiba Inu flying in the sky”(一群柴犬在天上飞)

引入

  • DALL-E:非营利人工智能研究组织OpenAI训练了一个称为DALL-E的神经网络,可让用户以自然语言文本注解,创建内容相符的图像。而DALLE的名称由来,来自西班牙加泰隆尼亚超现实主义画家萨尔瓦多达利(Salvador Dalí),以及皮克斯动画人物瓦力(WALLE)的混合词。
  • CLIP:这是和DALL-E同时放出的论文,能够将图像映射到文本描述的类别中。可用于图像检索和文字图像匹配度排名。
  • VQGAN:Taming Transformer的一部分,将图像编码为Tokens,并通过对抗训练达到Tokens解码为高清可识别图片。
  • RealESRGAN:通过可微方式生成低清图片作为现实中高清图片的退化,在ESRGAN的基础上实现更加优秀的超分效果。

而本项目使用的RuDalle源Repo就是基于这些实现文字生成图像、进行筛选、超分,因为使用了VQGAN和RealESRGAN而达到了比OpenAI原版更漂亮的生成效果。其中训练方式也对CogView进行了参考(没错,就是相当于加了些改进的中文版的DALL-E,整体思路没有变)。OpenAI的生成模型没有开源,而CogView虽然放出模型了但是没有像RuDalle这样集成多个模型,更加工程化的达到更好地生成效果。

如果觉得本项目有趣的话记得来Github上点Star哦!

模型架构

DALL-E的模型整体架构其实并不复杂,除了编解码器部分将图像编码为Tokens并能够解码以外,其它部分基本上就类似于著名的文本生成模型GPT-3,相当于是把图片变成了一个个字符,由类GPT的模型推理下一个字符,而对图片的描述文字的字符串作为条件前置在图片的字符串之前。

推理

不过,众所周知,GPT-3这类模型要有好的生成效果,参数量就会很大,即会有很多隐藏层,并且每一层的Size也很大。所以RuDalle的模型主体也达到了惊人的两个G,一般的显卡也带不动,只好求助强劲的V100来帮帮忙了。

即使强悍如V100,进行一次完整推理也比较费劲,咱们也不期望在平台上训练了。

让我们按顺序开始实验吧!

# 解压模型权重 Unzip model weights
!cd data && unzip -qq data116979/pretrained_models.zip
# 安装依赖
!pip install -r requirements.txt > /dev/null 2> /dev/null
!pip install ipywidgets translators==4.9.5 > /dev/null 2> /dev/null
# 查看环境信息
import multiprocessing
import paddle
from psutil import virtual_memory

ram_gb = round(virtual_memory().total / 1024**3, 1)

print('CPU:', multiprocessing.cpu_count())
print('RAM GB:', ram_gb)
print("PaddlePaddle version:", paddle.__version__)
print("CUDA version:", paddle.version.cuda())
print("cuDNN version:", paddle.device.get_cudnn_version())
device = 'cuda:0' if len(paddle.static.cuda_places()) > 0 else 'cpu'
print("device:", device)

!nvidia-smi
CPU: 24
RAM GB: 110.2
PaddlePaddle version: 2.2.0
CUDA version: 10.1
cuDNN version: 7605
device: cuda:0
Tue Nov 23 10:22:33 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 418.67       Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  Tesla V100-SXM2...  On   | 00000000:05:00.0 Off |                    0 |
| N/A   61C    P0   208W / 300W |      0MiB / 16384MiB |    100%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
# 导入模块依赖
from rudalle_paddle.pipelines import generate_images, show, super_resolution, cherry_pick_by_clip
from rudalle_paddle import get_rudalle_model, get_tokenizer, get_vae, get_realesrgan, get_ruclip
from rudalle_paddle.utils import seed_everything
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/__init__.py:107: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  from collections import MutableMapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/rcsetup.py:20: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  from collections import Iterable, Mapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/colors.py:53: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  from collections import Sized
INFO:matplotlib.font_manager:font search path ['/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/mpl-data/fonts/ttf', '/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/mpl-data/fonts/afm', '/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/mpl-data/fonts/pdfcorefonts']
INFO:matplotlib.font_manager:generated new fontManager
# 载入Dalle模型
# fp16可启用,但目前没看到有什么加速效果
device = 'cuda'
dalle = get_rudalle_model('Malevich-paddle', pretrained=True, fp16=False, device=device, cache_dir='data/pretrained_models')
W1123 10:22:52.852972  2335 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1
W1123 10:22:52.853024  2335 device_context.cc:465] device: 0, cuDNN Version: 7.6.


◼️ Malevich is 1.3 billion params model from the family GPT3-like, that uses Russian language and text+image multi-modality.
# 载入编解码模型和超分模型
realesrgan = get_realesrgan('x2-paddle', device=device, cache_dir='data/pretrained_models') # x2/x4/x8
tokenizer = get_tokenizer(cache_dir='data/pretrained_models')
vae = get_vae('vqgan.gumbelf8-sber.paddle', cache_dir='data/pretrained_models').to(device) # still not support dwt now
ruclip, ruclip_processor = get_ruclip('ruclip-vit-base-patch32-v5-paddle', cache_dir='data/pretrained_models')
ruclip = ruclip.to(device)
x2-paddle --> ready
tokenizer --> ready
Working with z of shape (1, 256, 32, 32) = 262144 dimensions.
vae --> ready
ruclip --> ready

生成的主要操作 (generation by ruDALLE)

# 载入翻译器
import translators as ts

def translate(txt, backend):
    return getattr(ts, backend)(txt, from_language='auto', to_language='ru')
Using China server backend.
# 对图像的文字描述

backend = 'google' # google/bing/alibaba/tencent/sogou
source_text = '鳄梨形椅子' # 自动模式,你可以输入任何语言作为文字描述 # auto mode, you can type any language

print('源文本 target text:', source_text)
try:
    target_text = translate(source_text, backend)
except:
    raise Exception(
        '调用翻译器失败,请尝试更换backend。'
        'Failed to call the translator, please try to replace the backend.'
    )
print('目标文本 target text:', target_text)
源文本 target text: 鳄梨形椅子
目标文本 target text: Стул авокадо
# 随机数种子,单次生成中,使用同一种子数,同一文字描述会生成相同图像
seed_everything(42)
# 模型生成的主体部分
# 大模型生成比较费时间,请耐心等待

text = target_text

pil_images = []
scores = []
for top_k, top_p, images_num in [
    (2048, 0.995, 3), # 一共进行8次生成,赶时间的话可以注释掉6个
    (1536, 0.99, 3),
    (1024, 0.99, 3),
    (1024, 0.98, 3),
    (512, 0.97, 3),
    (384, 0.96, 3),
    (256, 0.95, 3),
    (128, 0.95, 3), 
]:
    _pil_images, _scores = generate_images(text, tokenizer, dalle, vae, top_k=top_k, images_num=images_num, top_p=top_p)
    pil_images += _pil_images
    scores += _scores
HBox(children=(IntProgress(value=0, max=1024), HTML(value='')))





/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/dygraph/varbase_patch_methods.py:392: UserWarning: [93m
Warning:
tensor.grad will return the tensor value of the gradient. This is an incompatible upgrade for tensor.grad API.  It's return type changes from numpy.ndarray in version 2.0 to paddle.Tensor in version 2.1.0.  If you want to get the numpy value of the gradient, you can use :code:`x.grad.numpy()` [0m
  warnings.warn(warning_msg)



HBox(children=(IntProgress(value=0, max=1024), HTML(value='')))






HBox(children=(IntProgress(value=0, max=1024), HTML(value='')))






HBox(children=(IntProgress(value=0, max=1024), HTML(value='')))






HBox(children=(IntProgress(value=0, max=1024), HTML(value='')))






HBox(children=(IntProgress(value=0, max=1024), HTML(value='')))






HBox(children=(IntProgress(value=0, max=1024), HTML(value='')))






HBox(children=(IntProgress(value=0, max=1024), HTML(value='')))
# 显示生成结果
show([pil_image for pil_image, score in sorted(zip(pil_images, scores), key=lambda x: -x[1])] , 6)
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2349: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  if isinstance(obj, collections.Iterator):
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2366: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  return list(data) if isinstance(data, collections.MappingView) else data
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/figure.py:457: UserWarning: matplotlib is currently using a non-GUI backend, so cannot show the figure
  "matplotlib is currently using a non-GUI backend, "

在这里插入图片描述

使用CLIP筛选最符合描述的图像 (auto-cherry-pick by ruCLIP)

top_images, clip_scores = cherry_pick_by_clip(pil_images, text, ruclip, ruclip_processor, device=device, count=6)
show(top_images, 3)
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/tensor/creation.py:130: DeprecationWarning: `np.object` is a deprecated alias for the builtin `object`. To silence this warning, use `object` by itself. Doing this will not modify any behavior and is safe. 
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:

在这里插入图片描述

对图像进行超分 (super resolution)

sr_images = super_resolution(top_images, realesrgan)
show(sr_images, 3)

在这里插入图片描述

Logo

学大模型,用大模型上飞桨星河社区!每天8点V100G算力免费领!免费领取ERNIE 4.0 100w Token >>>

更多推荐