



“A group of Shiba Inu flying in the sky”


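  • DALL-E: The AI research organization OpenAI trained a neural network called DALL-E that lets users describe an image in natural-language text and generates images matching that description. The name DALL-E is a blend of Salvador Dalí, the Catalan surrealist painter from Spain, and WALL-E, the Pixar animated character.
  • CLIP: Released at the same time as DALL-E, CLIP maps images to categories described in text. It can be used for image retrieval and for ranking how well images match a piece of text.
  • VQGAN: Part of Taming Transformers. It encodes an image into discrete tokens and uses adversarial training so that the tokens decode back into a sharp, recognizable image.
  • RealESRGAN: Builds on ESRGAN and achieves better super-resolution results by synthesizing low-resolution images that imitate how real-world high-resolution images degrade.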
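
To make the "image as tokens" idea in the VQGAN item concrete, here is a minimal NumPy sketch of vector quantization: each feature vector is replaced by the index of its nearest codebook entry. This is an illustration under stated assumptions, not the VQGAN implementation used in this notebook; the function names, the tiny codebook, and the feature values are all hypothetical.

```python
import numpy as np

def quantize(features, codebook):
    """Map each feature vector to the index of its nearest codebook entry.

    features: (n, d) array of encoder outputs (hypothetical values here).
    codebook: (k, d) array of code vectors (hypothetical values here).
    """
    # Squared Euclidean distance between every feature and every code: (n, k)
    diff = features[:, None, :] - codebook[None, :, :]
    dist = (diff ** 2).sum(axis=-1)
    return dist.argmin(axis=1)

def dequantize(tokens, codebook):
    """Look the tokens back up in the codebook; a decoder would start from these."""
    return codebook[tokens]

codebook = np.array([[0.0, 0.0], [1.0, 1.0], [-1.0, 1.0]])
features = np.array([[0.9, 1.2], [0.1, -0.2], [-0.8, 0.7]])
tokens = quantize(features, codebook)
print(tokens)  # [1 0 2]
```

In a real model the encoder and decoder are convolutional networks and the codebook is learned; only the nearest-neighbour lookup is shown here.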









# Unzip model weights
!cd data && unzip -qq data116979/pretrained_models.zip
# Install dependencies
!pip install -r requirements.txt > /dev/null 2> /dev/null
!pip install ipywidgets translators==4.9.5 > /dev/null 2> /dev/null
# Check environment information
import multiprocessing
import paddle
from psutil import virtual_memory

ram_gb = round(virtual_memory().total / 1024**3, 1)

print('CPU:', multiprocessing.cpu_count())
print('RAM GB:', ram_gb)
print("PaddlePaddle version:", paddle.__version__)
print("CUDA version:", paddle.version.cuda())
print("cuDNN version:", paddle.device.get_cudnn_version())
device = 'cuda:0' if len(paddle.static.cuda_places()) > 0 else 'cpu'
print("device:", device)

CPU: 24
RAM GB: 110.2
PaddlePaddle version: 2.2.0
CUDA version: 10.1
cuDNN version: 7605
device: cuda:0
# Import module dependencies
from rudalle_paddle.pipelines import generate_images, show, super_resolution, cherry_pick_by_clip
from rudalle_paddle import get_rudalle_model, get_tokenizer, get_vae, get_realesrgan, get_ruclip
from rudalle_paddle.utils import seed_everything
# Load the DALL-E model
# fp16 can be enabled, but no speedup has been observed so far
device = 'cuda'
dalle = get_rudalle_model('Malevich-paddle', pretrained=True, fp16=False, device=device, cache_dir='data/pretrained_models')
W1123 10:22:52.852972  2335 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1
W1123 10:22:52.853024  2335 device_context.cc:465] device: 0, cuDNN Version: 7.6.

◼️ Malevich is 1.3 billion params model from the family GPT3-like, that uses Russian language and text+image multi-modality.
# Load the encoder/decoder (VAE) model and the super-resolution model
realesrgan = get_realesrgan('x2-paddle', device=device, cache_dir='data/pretrained_models') # x2/x4/x8
tokenizer = get_tokenizer(cache_dir='data/pretrained_models')
vae = get_vae('vqgan.gumbelf8-sber.paddle', cache_dir='data/pretrained_models').to(device) # dwt is not supported yet
ruclip, ruclip_processor = get_ruclip('ruclip-vit-base-patch32-v5-paddle', cache_dir='data/pretrained_models')
ruclip = ruclip.to(device)
x2-paddle --> ready
tokenizer --> ready
Working with z of shape (1, 256, 32, 32) = 262144 dimensions.
vae --> ready
ruclip --> ready

Main generation steps (generation by ruDALLE)

# Load the translator
import translators as ts

def translate(txt, backend):
    return getattr(ts, backend)(txt, from_language='auto', to_language='ru')
Using China server backend.
# Text description of the image

backend = 'google' # google/bing/alibaba/tencent/sogou
source_text = '鳄梨形椅子' # "avocado-shaped chair"; auto mode, you can type the description in any language

print('Source text:', source_text)
try:
    target_text = translate(source_text, backend)
except Exception as e:
    raise Exception(
        'Failed to call the translator, please try to replace the backend.'
    ) from e
print('Target text:', target_text)
Source text: 鳄梨形椅子
Target text: Стул авокадо
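
Since the error message above asks the reader to switch backends by hand, one option is to try several backends in order. The helper below is a hypothetical sketch and not part of rudalle_paddle or translators; it takes any callable shaped like the notebook's translate(txt, backend), and the demonstration uses a stand-in translator so that it runs without network access.

```python
def translate_with_fallback(txt, translate_fn, backends):
    """Try each backend in order and return (backend, translation) for the first success.

    translate_fn(txt, backend) is any callable with the same shape as the
    notebook's translate(); backends is an ordered list of backend names.
    """
    errors = {}
    for backend in backends:
        try:
            return backend, translate_fn(txt, backend)
        except Exception as exc:  # any backend failure: remember it, try the next
            errors[backend] = exc
    raise RuntimeError(f'All translator backends failed: {errors}')


# Stand-in translator for demonstration only: 'google' fails, 'bing' succeeds.
def fake_translate(txt, backend):
    if backend == 'google':
        raise ConnectionError('backend unreachable')
    return f'[{backend}] {txt}'

used, result = translate_with_fallback('avocado chair', fake_translate, ['google', 'bing'])
print(used, result)  # bing [bing] avocado chair
```

In the notebook this could be called as translate_with_fallback(source_text, translate, ['google', 'bing', 'alibaba']), using the backend names listed in the comment above.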
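# Random seed: call seed_everything with a fixed seed before generating; within one run, the same seed and the same text description produce the same images
# Main part of the generation
# Generation with a large model is slow, please be patient

text = target_text

pil_images = []
scores = []
for top_k, top_p, images_num in [
    (2048, 0.995, 3), # 8 generation rounds in total; comment out 6 of them if you are in a hurry
    (1536, 0.99, 3),
    (1024, 0.99, 3),
    (1024, 0.98, 3),
    (512, 0.97, 3),
    (384, 0.96, 3),
    (256, 0.95, 3),
    (128, 0.95, 3),
]:
    _pil_images, _scores = generate_images(text, tokenizer, dalle, vae, top_k=top_k, images_num=images_num, top_p=top_p)
    pil_images += _pil_images
    scores += _scores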
# Show the generated results, highest score first
show([pil_image for pil_image, score in sorted(zip(pil_images, scores), key=lambda x: -x[1])], 6)
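
The generation loop sweeps top_k and top_p from loose to strict settings. As a rough illustration of what these two parameters conventionally mean when sampling from a next-token distribution, here is a NumPy sketch. It is not the code inside generate_images; the function name and the toy probabilities are assumptions, and real implementations differ in details such as whether top-p is applied to the original or the renormalized probabilities.

```python
import numpy as np

def top_k_top_p_filter(probs, top_k, top_p):
    """Return a renormalized distribution restricted by top-k, then top-p."""
    probs = np.asarray(probs, dtype=float)
    order = np.argsort(-probs)              # token indices, most probable first
    order = order[:top_k]                   # top-k: keep the k most probable tokens
    cumulative = np.cumsum(probs[order])
    # top-p: keep the shortest prefix whose cumulative mass reaches top_p
    cutoff = int(np.searchsorted(cumulative, top_p)) + 1
    keep = order[:cutoff]
    filtered = np.zeros_like(probs)
    filtered[keep] = probs[keep]
    return filtered / filtered.sum()        # renormalize over the kept tokens

probs = [0.5, 0.25, 0.15, 0.07, 0.03]
filtered = top_k_top_p_filter(probs, top_k=4, top_p=0.8)
print(np.count_nonzero(filtered))  # 3
```

Larger top_k and top_p values leave more candidate tokens to sample from, which tends to give more varied images; smaller values restrict sampling to the most likely tokens.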
Select the images that best match the description with CLIP (auto-cherry-pick by ruCLIP)

top_images, clip_scores = cherry_pick_by_clip(pil_images, text, ruclip, ruclip_processor, device=device, count=6)
show(top_images, 3)
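
The cherry-picking step ranks the generated images by how well they match the text. CLIP-style models usually do this by embedding the text and each image into a shared vector space and comparing them by cosine similarity. The sketch below shows only that ranking step, using made-up two-dimensional embeddings; it is not the internals of cherry_pick_by_clip, and rank_by_similarity is a hypothetical name.

```python
import numpy as np

def rank_by_similarity(text_emb, image_embs, count):
    """Return indices and scores of the `count` images most similar to the text."""
    # Normalize so that the dot product equals the cosine similarity
    text_emb = text_emb / np.linalg.norm(text_emb)
    image_embs = image_embs / np.linalg.norm(image_embs, axis=1, keepdims=True)
    scores = image_embs @ text_emb
    best = np.argsort(-scores)[:count]      # highest similarity first
    return best, scores[best]

# Toy embeddings (assumed values); a real model would produce these with its encoders.
text_emb = np.array([1.0, 0.0])
image_embs = np.array([[0.0, 2.0], [3.0, 0.0], [1.0, 1.0]])
best, best_scores = rank_by_similarity(text_emb, image_embs, count=2)
print(best)  # [1 2]
```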
Upscale the images (super resolution)

sr_images = super_resolution(top_images, realesrgan)
show(sr_images, 3)



