Masked Autoencoders Are Scalable Vision Learners

MAE提出了一种自监督的训练方法,可以有效地对模型进行预训练,从而提升模型性能。本项目实现了其中的自监督预训练部分,并可视化了训练过程。

网络结构

Network
MAE的结构较为简单,它由编码器和解码器组成,这里编码器和解码器都采用了Transformer结构。对于输入图片,将其划分为patches后,对一定比例的patch进行masked(论文中比例为75%),将unmasked patches送入encoder得到encoded patches,引入masked tokens和encoded patches结合,送入decoder,decoder的输出目标是原图像,损失仅在masked patches上计算。

需要注意的细节:
1、Masking:将图像划分为不重叠的patches后,按均匀分布随机选取需要mask的patches(不放回采样);
2、Encoder:encoder仅作用于unmasked patches,patch embeddings需要加上position embeddings;
3、Decoder:decoder的输入由encoded patches和mask tokens组成,mask token是一个所有被mask位置共享的可学习向量,同时需要为mask tokens加上position embeddings以表示位置信息;
4、重构目标:decoder输出目标图片(输入原图)的每个像素值,损失仅在masked patches计算;
5、实现:
(1)对每个patch生成token;
(2)对所有token进行shuffle,然后按照masking ratio移除一部分token;
(3)得到encoded tokens后,将mask tokens和encoded tokens合并。由于每个token都已带有位置编码,decoder的输入不需要unshuffle,简单concat就可以;但若要把decoder的输出拼回完整图像,则需要先按原始patch顺序还原(unshuffle);

本项目的内容

本项目在ImageNet-1K的验证集上训练,将其中5万张图片里的4万张用作训练数据,剩下的1万张留作验证。由于训练比较慢,这里只对MAE进行预训练(没有fine-tuning过程),masking ratio为0.5,仅训练了200个epoch。由于数据少、epoch少,效果并不太好,但可以看到MAE输出的变化过程。

输出变化

左图是原图,中间是masked image,右图是MAE的预测结果。
epoch 1:
1

epoch 10:
10

epoch 200:

200

其他

参考:vit-pytorch
非常感谢朱欤老师的课程(朱老师牛逼):从零开始学视觉Transformer
Aistudio个人主页:https://aistudio.baidu.com/aistudio/personalcenter/thirdview/312316
路过的老爷们,求求点一下喜欢,给老弟加点战斗力。

# Extract the ILSVRC2012mini archive inside its data directory, then return to the home directory.
%cd ~/data/data89857/
!tar -xf ILSVRC2012mini.tar
%cd ~/

# The dataset's train_list.txt entries carry a leading directory component
# (the part before the first "/"), while ImageNetDataset joins image names with
# the train/ directory itself. Strip that prefix from every line.
# Lines without a "/" are kept unchanged, so re-running this cell is harmless.
import os

train_file_path = '/home/aistudio/data/data89857/ILSVRC2012mini/train_list.txt'

with open(train_file_path, 'r', encoding='utf-8') as f:
    lines = f.readlines()

# split('/', 1)[-1] drops everything up to and including the first "/" when
# present and returns the whole line otherwise, which makes the fix idempotent.
data = [line.split('/', 1)[-1] for line in lines]

with open(train_file_path, 'w', encoding='utf-8') as f:
    f.writelines(data)

ViT

ViT的实现不做过多解释,需要注意:由于MAE的重构目标是原图patch的像素值,这里不使用卷积来做patch embedding,而是先将原图显式划分为patches(这些patches同时作为重构目标),再使用linear embedding。

# VIT
import paddle 
from paddle import nn


class PreNorm(nn.Layer):
    """Pre-normalisation wrapper: LayerNorm the input, then call the wrapped layer ``fn``."""

    def __init__(self, dim, fn):
        super().__init__()
        # LayerNorm over the last axis, whose size is ``dim``.
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x):
        normed = self.norm(x)
        return self.fn(normed)


class Mlp(nn.Layer):
    """Feed-forward sub-layer: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Maps features of size ``dim`` to ``hidden_dim`` and back to ``dim``.
    """

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        layers = [
            nn.Linear(dim, hidden_dim),   # expand
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),   # project back to ``dim``
            nn.Dropout(dropout),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        return self.mlp(x)


class Identity(nn.Layer):
    """Pass-through layer: ``forward`` returns its input unchanged."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Used in place of an optional projection that is not needed.
        return x


class Attention(nn.Layer):
    """Multi-head scaled dot-product self-attention.

    Args:
        dim: input and output feature size.
        heads: number of attention heads.
        dim_head: feature size of each head.
        dropout: dropout rate applied after the output projection.
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = heads * dim_head
        # The output projection is only skippable when a single head already
        # yields exactly ``dim`` features.
        needs_projection = heads != 1 or dim_head != dim

        self.heads = heads
        self.scale = dim_head ** (-0.5)

        self.attend = nn.Softmax(axis=-1)
        self.to_qkv = nn.Linear(dim, 3 * inner_dim, bias_attr=False)

        if needs_projection:
            self.to_out = nn.Sequential(
                nn.Linear(inner_dim, dim),
                nn.Dropout(dropout),
            )
        else:
            self.to_out = Identity()

    def forward(self, x):
        batch, seq_len, _ = x.shape

        def split_heads(t):
            # [B, N, heads * dim_head] -> [B, heads, N, dim_head]
            return t.reshape([batch, seq_len, self.heads, -1]).transpose([0, 2, 1, 3])

        q, k, v = (split_heads(t) for t in self.to_qkv(x).chunk(3, axis=-1))

        scores = paddle.matmul(q, k.transpose([0, 1, 3, 2])) * self.scale
        weights = self.attend(scores)

        merged = paddle.matmul(weights, v)
        # [B, heads, N, dim_head] -> [B, N, heads * dim_head]
        merged = merged.transpose([0, 2, 1, 3]).flatten(2)
        return self.to_out(merged)


class Transformer(nn.Layer):
    """Stack of ``depth`` pre-norm blocks, each a residual attention layer followed by a residual MLP."""

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        blocks = []
        for _ in range(depth):
            attn_block = PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout))
            ff_block = PreNorm(dim, Mlp(dim, mlp_dim, dropout=dropout))
            blocks.append(nn.LayerList([attn_block, ff_block]))
        self.layers = nn.LayerList(blocks)

    def forward(self, x):
        for attn_block, ff_block in self.layers:
            x = x + attn_block(x)
            x = x + ff_block(x)
        return x


class PatchEmbedding(nn.Layer):
    """Split images into non-overlapping patches and embed each patch with a Linear layer.

    Patches are extracted explicitly (rather than with a strided convolution) so
    that ``forward`` can also return the raw pixel patches, which MAE uses as
    its reconstruction target.

    Args:
        image_size: int, or (height, width) as a tuple or list.
        patch_size: int, or (height, width) as a tuple or list.
        embed_dim: size of each patch embedding.
        in_channels: number of image channels.
    """

    def __init__(self, image_size, patch_size, embed_dim=768, in_channels=3):
        super().__init__()
        image_height, image_width = self._pair(image_size)
        self.patch_height, self.patch_width = self._pair(patch_size)

        assert image_height % self.patch_height == 0 and image_width % self.patch_width == 0, "Image dimensions must be divisible by the patch size."
        # Number of patches along the height (p1) and width (p2) axes.
        self.p1, self.p2 = (image_height // self.patch_height), (image_width // self.patch_width)
        self.num_patches = self.p1 * self.p2

        self.patch_embed = nn.Linear(in_channels * self.patch_height * self.patch_width, embed_dim)

    @staticmethod
    def _pair(value):
        """Return ``value`` as a 2-tuple: ints are duplicated, tuples/lists are unpacked."""
        if isinstance(value, (tuple, list)):
            first, second = value
            return first, second
        return value, value

    def forward(self, x):
        """Embed a batch of images ``x`` of shape [N, C, H, W].

        Returns:
            (embeddings, patches): embeddings of shape [N, num_patches, embed_dim]
            and raw pixel patches of shape [N, num_patches, C * patch_h * patch_w],
            both in row-major patch order.

        Raises:
            ValueError: if H or W differs from the configured image size. Without
                this check, a same-sized-but-differently-shaped input would be
                reshaped into wrong patches without any error.
        """
        N, C, H, W = x.shape
        expected_h = self.p1 * self.patch_height
        expected_w = self.p2 * self.patch_width
        if H != expected_h or W != expected_w:
            raise ValueError(
                'Input size ({}, {}) does not match the configured image size ({}, {}).'.format(
                    H, W, expected_h, expected_w))
        # [N, C, p1, ph, p2, pw] -> [N, p1, p2, C, ph, pw] -> [N, p1 * p2, C * ph * pw]
        patches = x.reshape([N, C, self.p1, self.patch_height, self.p2, self.patch_width])
        patches = patches.transpose([0, 2, 4, 1, 3, 5]).reshape([N, self.num_patches, -1])
        # The Linear output is already [N, num_patches, embed_dim]; no flatten needed.
        embeddings = self.patch_embed(patches)
        return embeddings, patches


class ViT(nn.Layer):
    """Vision Transformer classifier; its patch embedding and transformer also serve as the MAE encoder.

    Args:
        image_size, patch_size: forwarded to ``PatchEmbedding``.
        num_classes: size of the classification head output.
        depth, heads, mlp_dim, dim_head, dropout: ``Transformer`` settings.
        embed_dim: token feature size.
        pool: 'cls' classifies from the cls token; 'mean' averages all tokens
            (including the cls token).
        channels: number of input image channels.
        embed_dropout: dropout applied to the token sequence after adding
            position embeddings.
    """

    def __init__(
        self, 
        image_size, 
        patch_size, 
        num_classes, 
        depth, 
        heads,
        mlp_dim, 
        embed_dim=768,
        pool='cls',
        channels=3,
        dim_head=64,
        dropout=0,
        embed_dropout=0.,
        ):
        super().__init__()

        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling).'
        self.embed_dim = embed_dim
        self.patch_embedding = PatchEmbedding(image_size, patch_size, embed_dim=embed_dim, in_channels=channels)
        self.num_patches = self.patch_embedding.num_patches

        # Learnable position embeddings, one extra slot (index 0) for the cls token.
        # Initialised from a truncated normal with std 0.02. Note that
        # KaimingNormal's first positional argument is ``fan_in``, so
        # KaimingNormal(0.02) would give std = sqrt(2 / 0.02) ~= 10, not 0.02.
        self.pos_embedding = self.create_parameter(
            shape=[1, self.num_patches + 1, embed_dim],
            default_initializer=nn.initializer.TruncatedNormal(std=0.02))
        self.cls_token = self.create_parameter(shape=[1, 1, embed_dim])
        self.dropout = nn.Dropout(embed_dropout)

        self.transformer = Transformer(embed_dim, depth, heads, dim_head, mlp_dim, dropout)
        self.pool = pool
        self.to_latent = Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, num_classes),
        )

    def forward(self, x):
        """Return class logits of shape [B, num_classes] for images ``x`` of shape [B, C, H, W]."""
        x, _ = self.patch_embedding(x)  # raw pixel patches are only needed by MAE

        B, N, _ = x.shape
        cls_tokens = paddle.tile(self.cls_token, [B, 1, 1])
        x = paddle.concat([cls_tokens, x], axis=1)
        x = x + self.pos_embedding[:, :(N + 1)]
        x = self.dropout(x)

        x = self.transformer(x)
        x = x.mean(axis=1) if self.pool == 'mean' else x[:, 0]
        x = self.to_latent(x)
        return self.mlp_head(x)



# if __name__ == '__main__':
#     model = ViT(image_size=(256,256), 
#         patch_size=(32,32), 
#         num_classes=1000, 
#         embed_dim=1024,
#         heads=8,
#         depth=6, 
#         mlp_dim=2048, )
#     x = paddle.randn([2, 3, 256, 256])
    
#     y = model(x)
#     print(x.shape, y.shape)
#     paddle.summary(model, (4, 3, 256, 256))

MAE

MAE的encoder就是ViT,decoder是一个transformer模型。

import paddle
from paddle import nn
import paddle.nn.functional as F


class MAE(nn.Layer):
    """Masked Autoencoder: a ViT encoder on visible patches plus a Transformer decoder.

    Args:
        encoder: a ``ViT``; its patch embedding, position embedding and
            transformer are reused as the MAE encoder.
        decoder_dim: decoder feature size.
        masking_ratio: fraction of patches to mask, in (0, 1).
        decoder_depth, decoder_heads, decoder_dim_head: decoder ``Transformer`` settings.

    ``forward`` in training mode returns the MSE loss computed on the masked
    patches only. In eval mode it returns ``(pred, image)``, both of shape
    [B, num_patches, patch_dim] and in original row-major patch order:
    ``pred`` holds the predicted pixels for every patch, and ``image`` holds the
    input patches with the masked ones set to 0.
    """

    def __init__(self, encoder, decoder_dim, masking_ratio=0.75, decoder_depth=1, decoder_heads=8, decoder_dim_head=64):
        super().__init__()
        assert masking_ratio > 0 and masking_ratio < 1, 'masking ratio must be in range (0, 1), but got {}.'.format(masking_ratio)
        self.masking_ratio = masking_ratio
        self.encoder = encoder

        # Paddle's Linear weight is [in_features, out_features], so shape[0] is
        # the flattened pixel size of one patch.
        patch_dim = self.encoder.patch_embedding.patch_embed.weight.shape[0]

        self.enc_to_dec = nn.Linear(encoder.embed_dim, decoder_dim) if encoder.embed_dim != decoder_dim else Identity()
        self.mask_token = self.create_parameter(shape=(1, 1, decoder_dim))  # learnable token shared by all masked positions
        self.decoder = Transformer(dim=decoder_dim, depth=decoder_depth, heads=decoder_heads, dim_head=decoder_dim_head, mlp_dim=decoder_dim*4)
        self.decoder_pos_emb = nn.Embedding(encoder.num_patches, decoder_dim)  # decoder position embedding
        self.to_pixels = nn.Linear(decoder_dim, patch_dim)

    def forward(self, x):
        # ``patches`` are the raw pixel patches of the input, used as targets.
        tokens, patches = self.encoder.patch_embedding(x)

        batch, num_patches, _ = tokens.shape
        tokens = tokens + self.encoder.pos_embedding[:, 1:(num_patches + 1)]  # skip the cls slot

        # Uniform random masking: argsort of uniform noise is a random permutation per sample.
        num_masked = int(self.masking_ratio * num_patches)
        rand_indices = paddle.rand(shape=[batch, num_patches]).argsort(axis=-1)
        masked_indices, unmasked_indices = rand_indices[:, :num_masked], rand_indices[:, num_masked:]

        # Only the visible (unmasked) tokens are encoded.
        batch_range = paddle.arange(batch)[:, None]
        tokens = tokens[batch_range, unmasked_indices]

        # Targets: the loss is computed on the masked patches only.
        masked_patches = patches[batch_range, masked_indices]

        encoded_tokens = self.encoder.transformer(tokens)
        decoder_tokens = self.enc_to_dec(encoded_tokens)

        # One shared learnable mask token per masked position, plus its decoder position embedding.
        mask_tokens = paddle.tile(self.mask_token, [batch, num_masked, 1])
        mask_tokens = mask_tokens + self.decoder_pos_emb(masked_indices)

        # Positions are already encoded in the tokens and the decoder has no
        # order-dependent ops, so the decoder input need not be unshuffled.
        # The sequence order is [masked_indices..., unmasked_indices...], which
        # is exactly ``rand_indices``.
        decoder_tokens = paddle.concat([mask_tokens, decoder_tokens], axis=1)
        decoded_tokens = self.decoder(decoder_tokens)

        if self.training:
            pred = self.to_pixels(decoded_tokens[:, :num_masked])  # [B, num_masked, patch_dim]
            return F.mse_loss(pred, masked_patches)

        image = patches.clone()  # input patches with the masked ones zeroed
        image.stop_gradient = True
        image[batch_range, masked_indices] = 0
        pred = self.to_pixels(decoded_tokens)  # [B, num_patches, patch_dim], in shuffled order
        # Undo the shuffle so that pred[:, i] is the prediction for patch i and
        # the patches can be reassembled into an image in row-major order.
        restore_indices = rand_indices.argsort(axis=-1)
        pred = pred[batch_range, restore_indices]
        return pred, image




# if __name__ == '__main__':
#     encoder = ViT(image_size=256, 
#         patch_size=32, 
#         num_classes=1000, 
#         embed_dim=1024,
#         heads=8,
#         depth=6, 
#         mlp_dim=2048)
#     model = MAE(encoder, masking_ratio=0.75, decoder_dim=512, decoder_depth=6)
#     x = paddle.randn([4, 3, 256, 256])
#     y = model(x)
#     print(x.shape, y.shape)
#     paddle.summary(model, (4, 3, 256, 256))

dataset

# 构建dataset
from paddle.io import Dataset, DataLoader
import paddle.vision.transforms as T
import cv2
import os

class ImageNetDataset(Dataset):
    """Image dataset described by a list file of "<image_name> <label>" lines.

    Args:
        data_dir: directory that the image names in ``info_txt`` are joined with.
        info_txt: path of the list file; each non-blank line is
            "<image_name> <integer label>".
        mode: 'train' makes ``__getitem__`` return ``(image, label)``; any other
            value returns only the image.
        transforms: optional callable applied to the array returned by
            ``cv2.imread`` (HWC, BGR channel order).
    """

    def __init__(self, data_dir, info_txt, mode='train', transforms=None):
        self.data_dir = data_dir
        self.image_paths, self.labels = self.get_info(info_txt)
        self.mode = mode
        self.transforms = transforms

    def get_info(self, file_path):
        """Parse ``file_path`` into parallel lists of image paths and integer labels."""
        paths = []
        labels = []
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue  # tolerate blank / trailing empty lines
                # The label is the last whitespace-separated field.
                image_name, label = line.rsplit(maxsplit=1)
                paths.append(os.path.join(self.data_dir, image_name))
                labels.append(int(label))
        return paths, labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread returns None instead of raising when a file cannot be read.
            raise FileNotFoundError('Failed to read image: {}'.format(image_path))
        if self.transforms:
            image = self.transforms(image)
        if self.mode == 'train':
            return image, self.labels[idx]
        return image

# mae_train_trans = T.Compose(
#     [
#         T.Resize((256, 256)),
#         T.RandomHorizontalFlip(),
#         T.RandomVerticalFlip(),
#         T.Transpose([2, 0, 1]),
#     ]
# )


# if __name__ == '__main__':
#     dataset = ImageNetDataset('/home/aistudio/data/data89857/ILSVRC2012mini/train', '/home/aistudio/data/data89857/ILSVRC2012mini/train_list.txt', mode='val', transforms=mae_train_trans)
#     print(len(dataset))
#     image = dataset[0]
#     import matplotlib.pyplot as plt
#     plt.imshow(image)
#     plt.show()

预训练MAE

# 辅助类
class AverageMeter:
    """Running average of scalar values such as per-batch losses.

    ``update(value, n)`` adds ``value`` to the running total and ``n`` to the
    count, so ``value`` is expected to be the *sum* over ``n`` samples (with the
    default ``n=1`` it is a single value). Calling the meter returns
    total / count, or 0.0 before the first update.
    """

    def __init__(self):
        self.reset()

    def update(self, value, n=1):
        self.val += value  # running total, not the most recent value
        self.count += n

    def reset(self):
        self.val = 0.
        self.count = 0.

    def __call__(self):
        if not self.count:
            return 0.0  # nothing recorded yet; avoid ZeroDivisionError
        return self.val / self.count
# Hyper-parameters and training setup
import time

epoches = 2000
batch_size = 256
learning_rate = 0.00001
grad_clip_value = 10

# encoder params
patch_size = (32, 32) 
image_size = (256, 256)
num_classes = 1000
encoder_embed_dim = 1024
encoder_heads = 8
encoder_depth = 6
encoder_mlp_dim = 2048

# decoder params
masking_ratio = 0.5
decoder_dim = 512
decoder_depth = 6


mae_train_trans = T.Compose(
    [
        T.Resize(image_size),  # must match the encoder's configured image_size
        T.RandomHorizontalFlip(),
        T.RandomVerticalFlip(),
        T.Transpose([2, 0, 1]),  # HWC -> CHW
    ]
)
# mode='val' makes the dataset return images only; MAE pre-training does not need labels.
mae_dataset = ImageNetDataset('/home/aistudio/data/data89857/ILSVRC2012mini/train', '/home/aistudio/data/data89857/ILSVRC2012mini/train_list.txt', mode='val', transforms=mae_train_trans)
mae_dataloader = DataLoader(
    mae_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

# MAE model
encoder = ViT(image_size=image_size, 
    patch_size=patch_size, 
    num_classes=num_classes, 
    embed_dim=encoder_embed_dim,
    heads=encoder_heads,
    depth=encoder_depth, 
    mlp_dim=encoder_mlp_dim)
model = MAE(encoder, masking_ratio=masking_ratio, decoder_dim=decoder_dim, decoder_depth=decoder_depth)
# paddle.summary(model, (4, 3, 256, 256))
clip = paddle.nn.ClipGradByValue(min=-grad_clip_value, max=grad_clip_value)
optimizer = paddle.optimizer.Momentum(learning_rate=learning_rate, parameters=model.parameters(), grad_clip=clip)

# 测试函数,用一张图片可视化训练过程
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

def reconstruct(x, image_size, patch_size):
    """Reassemble flattened patches into images.

    Inverse of the patch split in ``PatchEmbedding.forward``:
    [batch_size, num_patches, C * patch_h * patch_w] -> [batch_size, C, H, W],
    with patches in row-major order. Only ``reshape``/``transpose`` with list
    arguments are used, so paddle Tensors and NumPy arrays both work.

    Args:
        x: patches of shape [batch_size, num_patches, C * patch_h * patch_w].
        image_size: int, or (H, W) as a tuple or list.
        patch_size: int, or (patch_h, patch_w) as a tuple or list.
    """
    image_h, image_w = image_size if isinstance(image_size, (tuple, list)) else (image_size, image_size)
    patch_h, patch_w = patch_size if isinstance(patch_size, (tuple, list)) else (patch_size, patch_size)
    batch = x.shape[0]

    p1, p2 = image_h // patch_h, image_w // patch_w  # patches along height / width
    # [B, p1, p2, C, ph, pw] -> [B, C, p1, ph, p2, pw] -> [B, C, H, W]
    x = x.reshape([batch, p1, p2, -1, patch_h, patch_w])
    x = x.transpose([0, 3, 1, 4, 2, 5])
    return x.reshape([batch, -1, image_h, image_w])

def _to_hwc_uint8(batch_chw):
    """Return the first image of a [B, C, H, W] tensor as an HWC uint8 array clipped to [0, 255]."""
    img = np.clip(batch_chw[0].numpy(), 0, 255).astype('uint8')
    return np.transpose(img, [1, 2, 0])


def test(model, image_path='/home/aistudio/data/data89857/ILSVRC2012mini/val/ILSVRC2012_val_00040043.JPEG'):
    """
    Run ``model`` on one image and plot original / masked / predicted images
    side by side, to visualise how predictions change during training.

    Leaves ``model`` in eval mode; callers must switch it back to train mode.
    Returns the predicted image as an HWC uint8 array in the same (BGR) channel
    order the model was fed.
    """
    model.eval()
    source_image = cv2.imread(image_path)
    if source_image is None:
        raise FileNotFoundError('Failed to read image: {}'.format(image_path))
    trans = T.Compose(
        [
            T.Resize(image_size),
            T.Transpose([2, 0, 1]),  # HWC -> CHW
        ]
    )
    source_image = trans(source_image)
    image = paddle.to_tensor(source_image, dtype='float32').unsqueeze(0)
    # Inference only: no autograd graph is needed for visualisation.
    with paddle.no_grad():
        pred, masked_img = model(image)

    pred_img = _to_hwc_uint8(reconstruct(pred, image_size, patch_size))
    masked_img = _to_hwc_uint8(reconstruct(masked_img, image_size, patch_size))

    # cv2 loads BGR while matplotlib expects RGB, so flip channels for display only.
    panels = [source_image.transpose([1, 2, 0]), masked_img, pred_img]
    for position, panel in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(panel[..., ::-1])
    plt.show()
    return pred_img

# Training

model.train()
for epoch in range(1, epoches + 1):
    losses = AverageMeter()
    for batch_id, image in enumerate(mae_dataloader):
        image = image.astype('float32')
        loss = model(image)
        # float() works whether the loss is a 0-D tensor (newer Paddle) or a
        # shape-[1] tensor (older Paddle); indexing numpy()[0] fails on 0-D.
        losses.update(float(loss))

        loss.backward()
        optimizer.step()
        optimizer.clear_grad()
        lr = optimizer.get_lr()
        if batch_id % 50 == 0:
            print(time.asctime(time.localtime(time.time())), "Epoch: {}/{}, Batch id: {}, lr: {}, loss: {}".format(epoch, epoches, batch_id, lr, losses()))
    # Save the encoder weights alone and the full MAE model; both files are overwritten every epoch.
    paddle.save({'model': encoder.state_dict(), 'epoch': epoch}, 'model.pdparams')
    paddle.save({'model': model.state_dict(), 'epoch': epoch}, 'mae.pdparams')

    test(model)  # switches the model to eval mode
    model.train()  # back to train mode
  

请点击此处查看本环境基本用法.

Please click here for more detailed instructions.

Logo

学大模型,用大模型上飞桨星河社区!每天8点V100G算力免费领!免费领取ERNIE 4.0 100w Token >>>

更多推荐