PASSL文字画图——你写啥他画啥

一、前言

还在为不会画画而苦恼吗?请发挥你的超强想象力,随意输入一句话,就能为你画出一张图片。本文采用的是CLIP,CLIP是一个图文多模态预训练模型,如何实现请参考PASSL,下面教程手把手教你如何使用,学会了,可替换自己的数据集画出你专属的图片哦。

二、算法介绍

CLIP(Contrastive Language-Image Pre-training)1是 openai 提出的图文对比预训练模型,该模型在 4 亿(400 million)互联网收集的图像文本数据对上完成自监督预训练,在多模态大数据、大模型以及大 batch-size 的加持下,CLIP 模型 zero-shot 性能在 30 多个视觉公开数据集上取得了足以匹敌有监督学习的效果,在部分数据集上甚至超越了监督学习的效果;

1)模型结构介绍

如下图所示,为了利用自然语言信息作为和图像表示学习的监督信息,CLIP 模型由 Vision Transformer 和 Bert-Encoder 双塔结构组成,图像部分由 Vision Transformer 进行编码,文本部分由 Transformer-Encoder2 进行编码;

2)Contrastive Loss 计算

对比学习简介:假定一个 batch 有 N 个图像-文本对组成,该 batch 理论上可以产生 N^2 对样本,其中包括 N 对正样本和 N^2 - N 对负样本;对比学习的目标是减小正样本对之间的余弦相似度同时增大负样本对之间的余弦相似度;

CLIP 训练流程:首先将 N 对图像-文本对输入图像和文本编码器,然后产生 N^2 对样本余弦相似度得分 S,最终将 S 输入到交叉熵损失函数来计算最终的 loss 从而优化模型参数;

三、安装 Paddle 自监督学习库 PASSL

PASSL 是一个世界领先的自监督算法库,旨在加速研究人员使用 paddle 开发自监督算法;

!pip install passl==0.0.4 -U -i https://pypi.tuna.tsinghua.edu.cn/simple
!pip install scikit-image -U -i https://pypi.tuna.tsinghua.edu.cn/simple 

四、加载模型和参数

import os
import numpy as np
import paddle 
from passl import SimpleTokenizer

print("Paddle version:", paddle.__version__)
# Downloading the model
if not os.path.exists('ViT-B-32.pdparams'):
    os.system('wget https://passl.bj.bcebos.com/models/ViT-B-32.pdparams')

# Load Model
from passl.modeling.architectures import CLIPWrapper
arch = {'name': 'CLIP', 'embed_dim':512, 
        'image_resolution': 224, 'vision_layers': 12,
        'vision_width': 768, 'vision_patch_size': 32,
        'context_length': 77, 'vocab_size': 49408,
        'transformer_width': 512, 'transformer_heads': 8,
        'transformer_layers': 12,'qkv_bias': True}
head = {'name': 'CLIPHead'}
model = CLIPWrapper(architecture=arch, head=head)
tokenizer = SimpleTokenizer()

with paddle.no_grad():
    state_dict = paddle.load("ViT-B-32.pdparams")['state_dict']
    model.set_state_dict(state_dict)
Paddle version: 2.1.2

五、图像预处理初始化

# Image Preprocessing
'''We resize the input images and center-crop them to conform with the image resolution that the model expects. 
   Before doing so, we will normalize the pixel intensity using the dataset mean and standard deviation.'''
from paddle.vision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from passl.datasets.preprocess.transforms import ToRGB

preprocess = Compose([Resize(224,interpolation='bicubic'),
                     CenterCrop(224),
                     ToTensor(),
                     ])
image_mean = paddle.to_tensor([0.48145466, 0.4578275, 0.40821073])
image_std = paddle.to_tensor([0.26862954, 0.26130258, 0.27577711])
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/tensor/creation.py:125: DeprecationWarning: `np.object` is a deprecated alias for the builtin `object`. To silence this warning, use `object` by itself. Doing this will not modify any behavior and is safe. 
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:

六、文字画图

from PIL import Image as PilImage

def precompute_image_features():
    image_features = []
    dataset = []
    im_tensors = []
    unsplash = 'unsplash-25k-photos.zip'
    unsplash_dir = 'data/data108238'
    if not os.path.exists(unsplash_dir + '/' + unsplash):
        os.makedirs(unsplash_dir)
        os.system('wget -P data/data108238/ http://sbert.net/datasets/' + unsplash)
    
    os.system('unzip data/data108238/unsplash-25k-photos.zip -d data/data108238/')    
  # Downloading the features
    if not os.path.exists('feats.npy'):
        os.system('wget https://passl.bj.bcebos.com/aisutio/feats.npy')
        os.system('wget https://passl.bj.bcebos.com/aisutio/names.npy')
    feats = np.load('feats.npy')
    namelist = np.load('names.npy')
    return feats, list(namelist)

def find_image(text_query, datatset, image_features, n=1):
    from passl import SimpleTokenizer
    tokenizer = SimpleTokenizer()
    text_tokens = [tokenizer.encode(text_query)]

    text_input = paddle.zeros((len(text_tokens), 77), dtype="int64")
    sot_token = tokenizer.encoder['<|startoftext|>']
    eot_token = tokenizer.encoder['<|endoftext|>']

    for i, tokens in enumerate(text_tokens):
        tokens = [sot_token] + tokens + [eot_token]
        text_input[i, :len(tokens)] = paddle.to_tensor(tokens)
    
    zeroshot_weights = model.model.encode_text(text_input)
    zeroshot_weights /= zeroshot_weights.norm(axis=-1, keepdim=True)

    distances = np.dot(zeroshot_weights, image_features.T)
    
    file_paths = []
    for i in range(1, n+1):
        idx = np.argsort(distances, axis=1)[0, -i]
        file_paths.append('data/data108238/' + dataset[idx])
    return file_paths

import matplotlib.pyplot as plt
from IPython.display import display, Image

def show_images(image_list):
    for im_path in image_list:
        display(Image(filename=im_path))

image_features, dataset = precompute_image_features()   

def draw(text, out_num=1):
    image_paths = find_image(text, dataset, image_features, n=out_num)
    show_images(image_paths)
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/__init__.py:107: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  from collections import MutableMapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/rcsetup.py:20: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  from collections import Iterable, Mapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/colors.py:53: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  from collections import Sized

数数 Count

text = "one cat"
draw(text)

在这里插入图片描述

text = "two cats"
draw(text)

在这里插入图片描述

颜色 Color

text = "one black dog"
draw(text)
text = "a boy in the forest"
draw(text)

在这里插入图片描述

物体之间的逻辑关系

text = "a person wear a headset"
draw(text, out_num=2)
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/tensor/creation.py:125: DeprecationWarning: `np.object` is a deprecated alias for the builtin `object`. To silence this warning, use `object` by itself. Doing this will not modify any behavior and is safe. 
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:

在这里插入图片描述

在这里插入图片描述

参考

如果想了解更多 CLIP 或 DALLE 原理,请点亮你的小星星 https://github.com/PaddlePaddle/PASSL 或者在 github issue 区留言

Logo

学大模型,用大模型上飞桨星河社区!每天8点V100G算力免费领!免费领取ERNIE 4.0 100w Token >>>

更多推荐