FS-CS: Integrative Few-Shot Learning for Classification and Segmentation

Paper: Integrative Few-Shot Learning for Classification and Segmentation

Paper's GitHub project: https://github.com/dahyun-kang/ifsl

The paper's source code is in the ./ifsl folder; ifsl/fs-cs/model/ifsl.py contains the loss computations and some training details.

1. A bit of background on few-shot learning

Here I have to strongly recommend the AI Studio project 小样本学习 (Few-Shot Learning); I learned few-shot learning from scratch through that project. Highly recommended!

In few-shot learning, the query is the input image to be recognized, and the support set is what the model compares the query against to decide which class it looks like. An N-way K-shot support set contains N classes with K images per class, where K is small; the model is not primarily trained on the support set, which only serves as a light form of fine-tuning/conditioning.
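To make the terminology concrete, here is a minimal sketch (hypothetical shapes and key names, not the paper's data format) of what one N-way K-shot episode might look like as tensors:

import paddle

N, K = 2, 1                                                  # a 2-way 1-shot episode
episode = {
    "query_img": paddle.randn([1, 3, 400, 400]),             # the single image to be recognized
    "support_imgs": paddle.randn([1, N * K, 3, 400, 400]),   # N classes x K labeled shots each
    "support_classes": paddle.to_tensor([[0, 1]]),           # class id of each shot (hypothetical field name)
}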

2. A new task: FS-CS

This is the paper's first contribution: "We introduce the task of integrative few-shot classification and segmentation (FS-CS), which combines few-shot classification and few-shot segmentation into an integrative task by addressing their limitations."

Given a query image and an N-way K-shot support set, the goal is to decide, for each class in the support set, whether it is present in the query image, and to predict its foreground mask from the query. This is called integrative few-shot classification and segmentation (FS-CS).

How does this task differ from ordinary few-shot learning? Ordinary few-shot classification assumes the query image contains exactly one of the support classes, and that the support classes do not overlap (no class may contain another).

Why does this matter? Suppose the support set has 5 classes: {apple, orange, fruit, sky, watermelon}. Note the design: apple, orange and watermelon are all fruit, so in this new task classes are allowed to contain one another. Now the query image is a plate of fruit, and the fruit on the plate happen to be oranges and apples. The model must compare the query against every class in the support set one by one and conclude that the query contains three classes, orange, apple and fruit, rather than only one of them. In a second case the query image is a bare floor, and the model should decide that the query contains zero of the support classes instead of being forced to assign one. A tiny illustration follows.
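Putting the fruit example in code form (purely illustrative, not from the paper), the FS-CS ground truth for a query is a multi-hot presence vector over the support classes rather than a single label:

# support classes: {apple, orange, fruit, sky, watermelon}
gt_presence_fruit_plate = [1, 1, 1, 0, 0]  # apple, orange and fruit are all present
gt_presence_floor       = [0, 0, 0, 0, 0]  # none of the support classes is present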

The second point is coupling segmentation with classification. For example, if the query is a plate with oranges and apples, we not only want the model to recognize the three classes, we also want it to output a mask for each class, segmenting the oranges, the apples, and the fruit as a whole. You might wonder whether this forces the support set to carry segmentation annotations, and what happens if you only have class labels. Don't worry: you can choose whether or not to use the segmentation information.

A small summary: FS-CS is a more general formulation of the task, and it couples classification and segmentation smoothly.

3. An effective learning framework for FS-CS: Integrative Few-Shot Learning (iFSL)

To solve FS-CS, the paper proposes an effective learning framework called integrative few-shot learning (iFSL). The iFSL framework is designed to jointly solve few-shot classification and few-shot segmentation using either a class tag or segmentation supervision. The integrative few-shot learner f takes the query image x and the support set S as input and produces the class-wise foreground maps as output. The set of class-wise foreground maps consists of one map Y(n) ∈ R^{H×W} for each of the N classes:

(figure: equation defining the set of class-wise foreground maps Y(n) ∈ R^{H×W}, n = 1..N, produced by f from x and S)

Classification and segmentation results are then derived from this output y, where y.shape = [batch_size, n_ways, 2, h, w]. For every position p and each of the n_ways classes there are two possibilities, absent or present, hence the extra dimension of size 2 (index 0 = absent, index 1 = present).

Inference stage:
(figure: inference-stage equations from the paper)

  1. Classification: if, for some class, there is a position p whose foreground probability exceeds 0.5, that class is judged present in the query. In other words the decision uses the maximum over positions, not the average, because the paper notes that averaging makes the model miss objects that occupy a small fraction of the image.

(figure)

  2. Segmentation: a background map is obtained as y[:, :, 0].mean(axis=1) with shape [bs, 1, h, w], while the N class maps are y[:, :, 1] with shape [bs, n_ways, h, w]; concatenating them along dim 1 gives N+1 maps, and each position takes whichever of the N+1 candidates has the highest score. Both rules are sketched in code right after this list.
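The two inference rules above can be summarized in a short sketch (my own paraphrase with the shapes assumed in this section, not the paper's exact code):

import paddle

def ifsl_inference(y):
    # y: [bs, n_ways, 2, h, w], softmax-normalized along dim 2 (index 0 = absent, 1 = present)
    fg = y[:, :, 1]                                          # [bs, n_ways, h, w] foreground probability
    # classification: a class is present if ANY position exceeds 0.5 (max, not mean)
    presence = fg.flatten(start_axis=2).max(axis=-1) > 0.5   # [bs, n_ways]
    # segmentation: background score = mean of the per-class "absent" maps
    bg = y[:, :, 0].mean(axis=1, keepdim=True)               # [bs, 1, h, w]
    seg = paddle.concat([bg, fg], axis=1).argmax(axis=1)     # [bs, h, w], 0 = background
    return presence, seg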

The iFSL framework allows the learner to be trained with either class tags or segmentation annotations, the former supervising per image and the latter per pixel. With class-tag supervision the loss is: (figure: classification loss equation)

With segmentation supervision the loss is: (figure: segmentation loss equation). A rough code sketch of both objectives follows.
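Since the two loss figures are not reproduced here, the following is a rough sketch of the two objectives (hedged: the 1-way classification loss mirrors compute_cls_objective defined in the training section below, with a small epsilon added for numerical safety, and the segmentation loss is written as a plain per-pixel cross-entropy over the N+1 combined maps, an approximation rather than the paper's exact formula):

import paddle
import paddle.nn.functional as F

def cls_objective(y, gt_presence):
    # image-level (class tag) supervision for 1-way: y [bs, 1, 2, h, w], gt_presence [bs] in {0, 1} (int64)
    prob = y.mean(axis=[-1, -2]).squeeze(1)          # [bs, 2], spatially averaged probabilities
    return F.nll_loss(paddle.log(prob + 1e-8), gt_presence)

def seg_objective(y, gt_mask):
    # pixel-level (mask) supervision: gt_mask [bs, h, w] (int64) with values in {0, ..., n_ways}, 0 = background
    bg = y[:, :, 0].mean(axis=1, keepdim=True)       # [bs, 1, h, w]
    fg = y[:, :, 1]                                  # [bs, n_ways, h, w]
    logp = paddle.log(paddle.concat([bg, fg], axis=1) + 1e-8)
    return F.nll_loss(logp, gt_mask)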

4. The framework proposed in the paper

(figure: overall architecture of the proposed model)

Reading the figure from left to right: first, the query image [batch_size, 3, 400, 400] and the support images [batch_size*way, 3, 400, 400] are fed into a pretrained ResNet-101, and 30 intermediate feature maps (a feature pyramid) are taken from each. The cosine similarity is then computed between every query feature map and the corresponding support feature map, so the correlation between the query and each support image is represented by the cosine similarities of their two feature pyramids. The correlation tensors are then split into 3 groups and concatenated, giving 3 inputs. I know this verbal description is not very intuitive, and the paper's description is only a bit more formal without being much clearer, so I suggest going straight to the code blocks below.

Now the overall design idea: the input is [batch_size*n_ways, channels, h_query, w_query, h_support, w_support]; we want every position of the query map to attend to the support map, and the desired output is [batch_size*n_ways, channels, h_query, w_query, 1, 1]. How (h_support, w_support) is gradually downsampled to (1, 1) is exactly what the strides (>1) inside the AS layers achieve.

A few points I want to emphasize:

  1. The implementation is extremely sensitive to sizes. For example, the 400×400 input size cannot be changed; if you shrink it, the later AS layers (essentially self-attention with downsampling) will complain that an output size is smaller than 0, because the size arithmetic depends on every single stride and kernel size.

  2. Another point: pool=True when extracting the support features (line 59 of the original notebook) makes the blue square a1 in the bottom-left of the framework figure the same size as the blue square a2 above it, rather than the same size as the red square b1 attached to a1. This also matters and must not be changed casually, because of the following line of code:

hypercorr_encoded = hypercorr_mix432.reshape([bsz, ch, ha, wa, -1]).squeeze(-1)
The last dimension can only be squeezed away if it equals exactly 1.

  3. So the model really is strict about sizes; don't change them. I spent a whole day wondering why it would not run before realizing how strict the size requirements are. A rough size check is sketched right after this list.
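As promised, here is a back-of-the-envelope check (my own calculation, not from the paper) of why the support spatial size collapses to exactly 1×1 only for the 400×400 input, using the convolution output-size formula quoted later inside the Attention class. The starting value of 25 assumes the pooled layer-2 support feature is 25×25 for a 400×400 image, which you can verify from the printed support_feats shapes further below.

def conv_out(size, kernel, stride, padding):
    # size of a conv output = floor((input + 2*pad - kernel) / stride) + 1
    return (size + 2 * padding - kernel) // stride + 1

sizes = [25]
for kernel, stride in [(5, 4), (5, 4),    # encoder_layer2: two AS layers with padding = 5 // 2 = 2
                       (1, 1), (2, 1)]:   # encoder_layer3to2: kernel <= 2, so padding = 0
    padding = kernel // 2 if kernel > 2 else 0
    sizes.append(conv_out(sizes[-1], kernel, stride, padding))
print(sizes)   # [25, 7, 2, 2, 1] -> ends at exactly 1, so the final squeeze(-1) succeeds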



'''
correlation.py: computing the feature-pyramid correlations

support_feat = support_feat / (support_feat.norm(axis=1, p=2, keepdim=True) + eps)
query_feat = query_feat / (query_feat.norm(axis=1, p=2, keepdim=True) + eps)
    divide each feature by its own L2 norm, so the matmul below yields cosine similarities

query_feat = query_feat.repeat_interleave(way, axis=0)
    repeats the query so it lines up with every image in the support set, letting us correlate
    the two feature pyramids; way is 1 during training, but not necessarily 1 at inference

stack_ids = paddle.to_tensor([3, 26, 30])
    specific to ResNet-101; the split points are chosen where the feature H/W change

corr = corr.clip(min=0)
    acts as a ReLU on the correlations
'''

class Correlation:

    @classmethod
    def multilayer_correlation(cls, query_feats, support_feats, stack_ids, way=1):
        eps = 1e-5

        corrs = []
        for idx, (query_feat, support_feat) in enumerate(zip(query_feats, support_feats)):
            bszs, ch, hb, wb = support_feat.shape
            support_feat = support_feat.reshape([bszs, ch, -1])
            support_feat = support_feat / (support_feat.norm(axis=1, p=2, keepdim=True) + eps)

            bszq, ch, ha, wa = query_feat.shape
            query_feat = query_feat.reshape([bszq, ch,-1])
            query_feat = query_feat / (query_feat.norm(axis=1, p=2, keepdim=True) + eps)

            if way > 1:
                query_feat = query_feat.repeat_interleave(way, axis=0)

            corr = paddle.matmul(query_feat.transpose([0,2,1]), support_feat).reshape([bszs, ha, wa, hb, wb])
            corr = corr.clip(min=0)
            corrs.append(corr)

        corr_l4 = paddle.stack(corrs[-stack_ids[0]:]).transpose([1,0,2,3,4,5])
        corr_l3 = paddle.stack(corrs[-stack_ids[1]:-stack_ids[0]]).transpose([1,0,2,3,4,5])
        corr_l2 = paddle.stack(corrs[-stack_ids[2]:-stack_ids[1]]).transpose([1,0,2,3,4,5])

        return [corr_l4, corr_l3, corr_l2]

# The PyTorch original uses the einops library; this wheel is a Paddle-compatible build (made by 肖佬), install it with pip
!pip install einops-0.3.0-py3-none-any.whl

from einops import rearrange
import paddle
hypercorr = paddle.randn([3,2,4,2,4,5])
hypercorr = rearrange(hypercorr, 'b c d t h w -> (b h w) c d t')

print(hypercorr)

import paddle.nn.functional as F
from paddle.vision.models import resnet101
from operator import add
from functools import reduce
import paddle
from correlation import Correlation
def extract_feat_res(img, backbone, feat_ids, bottleneck_ids, lids, pool=False, pool_thr=50):
    r""" Extract intermediate features from ResNet"""
    feats = []

    # Layer 0
    feat = backbone.conv1.forward(img)
    feat = backbone.bn1.forward(feat)
    feat = backbone.relu.forward(feat)
    feat = backbone.maxpool.forward(feat)

    # Layer 1-4
    for hid, (bid, lid) in enumerate(zip(bottleneck_ids, lids)):
        res = feat
        feat = backbone.__getattr__('layer%d' % lid)[bid].conv1.forward(feat)
        feat = backbone.__getattr__('layer%d' % lid)[bid].bn1.forward(feat)
        feat = backbone.__getattr__('layer%d' % lid)[bid].relu.forward(feat)
        feat = backbone.__getattr__('layer%d' % lid)[bid].conv2.forward(feat)
        feat = backbone.__getattr__('layer%d' % lid)[bid].bn2.forward(feat)
        feat = backbone.__getattr__('layer%d' % lid)[bid].relu.forward(feat)
        feat = backbone.__getattr__('layer%d' % lid)[bid].conv3.forward(feat)
        feat = backbone.__getattr__('layer%d' % lid)[bid].bn3.forward(feat)

        if bid == 0:
            res = backbone.__getattr__('layer%d' % lid)[bid].downsample.forward(res)

        feat += res

        if hid + 1 in feat_ids:
            if pool and feat.shape[-1] >= pool_thr:
                feats.append(F.avg_pool2d(feat.clone(), kernel_size=3, stride=2, padding=1))
            else:
                feats.append(feat.clone())

        feat = backbone.__getattr__('layer%d' % lid)[bid].relu.forward(feat)

    return feats

backbone = resnet101(pretrained=True)
feat_ids = list(range(4, 34))
extract_feats = extract_feat_res
nbottlenecks = [3, 4, 23, 3]
way = 1
batch_size = 6
query_img = paddle.randn([batch_size,3,400,400])
support_img = paddle.randn([batch_size*way,3,400,400])

bottleneck_ids = reduce(add, list(map(lambda x: list(range(x)), nbottlenecks)))
lids = reduce(add, [[i + 1] * x for i, x in enumerate(nbottlenecks)])
# print(lids)
# stack_ids = paddle.tensor(lids).bincount().__reversed__().cumsum(dim=0)[:3]
stack_ids = paddle.to_tensor([ 3, 26, 30])
query_feats = extract_feats(query_img, backbone, feat_ids, bottleneck_ids, lids)
support_feats =  extract_feats(support_img, backbone, feat_ids, bottleneck_ids, lids,pool =True)
corr = Correlation.multilayer_correlation(query_feats, support_feats,stack_ids, way)
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/nn/layer/norm.py:654: UserWarning: When training, we now always track global mean and variance.
  "When training, we now always track global mean and variance.")
'''
It is worth printing the feature shapes: they make it easier to see which layers' features are extracted above.
'''

len(query_feats)#30

for i in query_feats:
    print(i.shape)

print("-----------------------------")


for i in support_feats:
    print(i.shape)

print("-----------------------------")

for i in corr:
    print(i.shape)

5. Attentive squeeze layer (AS layer)

The overall structure is simply self-attention followed by an MLP, familiar components applied in a new place. A small shape-check example follows the code below.

(figure: AS layer architecture)

import math
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from einops import rearrange


class AttentiveSqueezeLayer(nn.Layer):
    """
    Attentive squeeze layer consisting of a global self-attention layer followed by a feed-forward MLP
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias=True, heads=8, groups=4, pool_kv=False):
        super(AttentiveSqueezeLayer, self).__init__()
        self.attn = Attention(in_channels, out_channels, kernel_size, stride, padding, bias, heads, groups, pool_kv)
        self.ff = FeedForward(out_channels, groups)

    def forward(self, input):
        x, support_mask = input
        batch, c, qh, qw, sh, sw = x.shape
        x = rearrange(x, 'b c d t h w -> b c (d t) h w')
        out = self.attn((x, support_mask))
        out = self.ff(out)
        out = rearrange(out, 'b c (d t) h w -> b c d t h w', d=qh, t=qw)
        return out, support_mask


class Attention(nn.Layer):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias=True, heads=8, groups=4, pool_kv=False):
        super(Attention, self).__init__()
        self.heads = heads
        '''
        Size of conv output = floor((input  + 2 * pad - kernel) / stride) + 1
        The second condition of `retain_dim` checks the spatial size consistency by setting input=output=0;
        Use this term with caution to check the size consistency for generic cases!
        '''
        retain_dim = in_channels == out_channels and math.floor((2 * padding - kernel_size) / stride) == -1
        hidden_channels = out_channels // 2
        assert hidden_channels % self.heads == 0, "out_channels should be divided by heads. (example: out_channels: 40, heads: 4)"

        ksz_q = (1, kernel_size, kernel_size)
        str_q = (1, stride, stride)
        pad_q = (0, padding, padding)
        # print(ksz_q,str_q,pad_q)
        self.short_cut = nn.Sequential(
            nn.Conv3D(in_channels, out_channels, kernel_size=ksz_q, stride=str_q, padding=pad_q, bias_attr=False),
            nn.GroupNorm(groups, out_channels),
            nn.ReLU()
        ) if not retain_dim else nn.Identity()

        # Convolutional embeddings for (q, k, v)
        self.qhead = nn.Conv3D(in_channels, hidden_channels, kernel_size=ksz_q, stride=str_q, padding=pad_q, bias_attr=bias)

        ksz = (1, kernel_size, kernel_size) if pool_kv else (1, 1, 1)
        str = (1, stride, stride) if pool_kv else (1, 1, 1)
        pad = (0, padding, padding) if pool_kv else (0, 0, 0)

        self.khead = nn.Conv3D(in_channels, hidden_channels, kernel_size=ksz, stride=str, padding=pad, bias_attr=bias)
        self.vhead = nn.Conv3D(in_channels, hidden_channels, kernel_size=ksz, stride=str, padding=pad, bias_attr=bias)

        self.agg = nn.Sequential(
            nn.GroupNorm(groups, hidden_channels),
            nn.ReLU(),
            nn.Conv3D(hidden_channels, out_channels, kernel_size=1, stride=1, bias_attr=False),
            nn.GroupNorm(groups, out_channels),
            nn.ReLU()
        )
        self.out_norm = nn.GroupNorm(groups, out_channels)

    def forward(self, input):
        x, support_mask = input  # [batch, channels, query_h*query_w, support_h, support_w]; the query dim is flattened, so kernel/stride/padding are 1/1/0 along that axis
        # print("x.shape",x.shape)
        x_ = self.short_cut(x)
        q_out = self.qhead(x)
        k_out = self.khead(x)
        v_out = self.vhead(x)

        q_h, q_w = q_out.shape[-2:]
        k_h, k_w = k_out.shape[-2:]

        q_out = rearrange(q_out, 'b (g c) t h w -> b g c t (h w)', g=self.heads)
        k_out = rearrange(k_out, 'b (g c) t h w -> b g c t (h w)', g=self.heads)
        v_out = rearrange(v_out, 'b (g c) t h w -> b g c t (h w)', g=self.heads)
        # print("qout.shape",q_out.shape,"k_out.shape",k_out.shape)
        out = paddle.einsum('b g c t l, b g c t m -> b g t l m', q_out, k_out)
        if support_mask is not None:
            out = self.attn_mask(out, support_mask, spatial_size=(k_h, k_w))
        out = F.softmax(out, axis=-1)
        out = paddle.einsum('b g t l m, b g c t m -> b g c t l', out, v_out)
        out = rearrange(out, 'b g c t (h w) -> b (g c) t h w', h=q_h, w=q_w)
        out = self.agg(out)

        return self.out_norm(out + x_)

    def attn_mask(self, x, mask, spatial_size):
        # Push attention logits at background support positions to a large negative value.
        # (Only exercised when segmentation supervision provides a support_mask.)
        assert mask is not None
        mask = F.interpolate(mask.astype('float32').unsqueeze(1), spatial_size, mode='bilinear', align_corners=True)
        mask = rearrange(mask, 'b 1 h w -> b 1 1 1 (h w)')
        keep = (mask != 0).astype(x.dtype)
        out = x * keep + (1. - keep) * (-1e9)
        return out


class FeedForward(nn.Layer):
    def __init__(self, out_channels, groups=4, size=2):
        super(FeedForward, self).__init__()
        hidden_channels = out_channels // size
        self.ff = nn.Sequential(
            nn.Conv3D(out_channels, hidden_channels, kernel_size=1, stride=1, padding=0, bias_attr=False),
            nn.GroupNorm(groups, hidden_channels),
            nn.ReLU(),
            nn.Conv3D(hidden_channels, out_channels, kernel_size=1, stride=1, padding=0, bias_attr=False),
        )
        self.out_norm = nn.GroupNorm(groups, out_channels)

    def forward(self, x):
        x_ = x
        out = self.ff(x)
        return self.out_norm(out + x_)
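
As mentioned above, here is a quick shape check (my own toy example, assuming the cell above has been run so that paddle and AttentiveSqueezeLayer are in scope) showing that a single AS layer downsamples only the support dimensions:

layer = AttentiveSqueezeLayer(in_channels=3, out_channels=32, kernel_size=5, stride=4, padding=2)
x = paddle.randn([2, 3, 13, 13, 13, 13])   # [B, C, h_query, w_query, h_support, w_support]
out, _ = layer((x, None))
print(out.shape)                           # [2, 32, 13, 13, 4, 4]: query dims unchanged, support dims 13 -> 4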

# from paddle_aslayer import AttentiveSqueezeLayer
import paddle.nn as nn
from einops import rearrange
class AttentionLearner(nn.Layer):
    def __init__(self, inch, way):
        super(AttentionLearner, self).__init__()
        self.way = way

        def make_building_attentive_block(in_channel, out_channels, kernel_sizes, spt_strides, pool_kv=False):
            building_block_layers = []
            for idx, (outch, ksz, stride) in enumerate(zip(out_channels, kernel_sizes, spt_strides)):
                inch = in_channel if idx == 0 else out_channels[idx - 1]
                padding = ksz // 2 if ksz > 2 else 0
                building_block_layers.append(AttentiveSqueezeLayer(inch, outch, ksz, stride, padding, pool_kv=pool_kv))

            return nn.Sequential(*building_block_layers)

        self.encoder_layer4 = make_building_attentive_block(inch[0], [32, 128], [5, 3], [4, 2])
        self.encoder_layer3 = make_building_attentive_block(inch[1], [32, 128], [5, 5], [4, 4], pool_kv=True)
        self.encoder_layer2 = make_building_attentive_block(inch[2], [32, 128], [5, 5], [4, 4], pool_kv=True)

        self.encoder_layer4to3 = make_building_attentive_block(128, [128, 128], [1, 2], [1, 1])
        self.encoder_layer3to2 = make_building_attentive_block(128, [128, 128], [1, 2], [1, 1])

        # Decoder layers
        self.decoder1 = nn.Sequential(nn.Conv2D(128, 128, (3, 3), padding=(1, 1), bias_attr=True),
                                      nn.ReLU(),
                                      nn.Conv2D(128, 64, (3, 3), padding=(1, 1), bias_attr=True),
                                      nn.ReLU())

        self.decoder2 = nn.Sequential(nn.Conv2D(64, 64, (3, 3), padding=(1, 1), bias_attr=True),
                                      nn.ReLU(),
                                      nn.Conv2D(64, 2, (3, 3), padding=(1, 1), bias_attr=True))

    def interpolate_query_dims(self, hypercorr, spatial_size):
        bsz, ch, ha, wa, hb, wb = hypercorr.shape
        hypercorr = rearrange(hypercorr, 'b c d t h w -> (b h w) c d t')
        # (B H W) C D T -> (B H W) C * spatial_size
        hypercorr = F.interpolate(hypercorr, spatial_size, mode='bilinear', align_corners=True)
        return rearrange(hypercorr, '(b h w) c d t -> b c d t h w', b=bsz, h=hb, w=wb)

    def forward(self, hypercorr_pyramid, support_mask):
        # print(hypercorr_pyramid[0].shape)
        hypercorr_sqz4 = self.encoder_layer4((hypercorr_pyramid[0], support_mask))[0]
        hypercorr_sqz3 = self.encoder_layer3((hypercorr_pyramid[1], support_mask))[0]
        hypercorr_sqz2 = self.encoder_layer2((hypercorr_pyramid[2], support_mask))[0]
        # print(hypercorr_sqz4.shape,hypercorr_sqz3.shape,hypercorr_sqz2.shape)
        hypercorr_sqz4 = hypercorr_sqz4.mean(axis = [-1, -2], keepdim=True)
        hypercorr_sqz4 = self.interpolate_query_dims(hypercorr_sqz4, hypercorr_sqz3.shape[-4:-2])
        hypercorr_mix43 = hypercorr_sqz4 + hypercorr_sqz3
        # print(hypercorr_mix43.shape)
        hypercorr_mix43 = self.encoder_layer4to3((hypercorr_mix43, support_mask))[0]

        hypercorr_mix43 = self.interpolate_query_dims(hypercorr_mix43, hypercorr_sqz2.shape[-4:-2])
        hypercorr_mix432 = hypercorr_mix43 + hypercorr_sqz2
        hypercorr_mix432 = self.encoder_layer3to2((hypercorr_mix432, support_mask))[0]

        bsz, ch, ha, wa, hb, wb = hypercorr_mix432.shape
        # print("hypercorr_mix432.shape",hypercorr_mix432.shape)
        hypercorr_encoded = hypercorr_mix432.reshape([bsz, ch, ha, wa, -1]).squeeze(-1)

        hypercorr_decoded = self.decoder1(hypercorr_encoded)
        upsample_size = (hypercorr_decoded.shape[-1] * 2,) * 2
        hypercorr_decoded = F.interpolate(hypercorr_decoded, upsample_size, mode='bilinear', align_corners=True)
        logit_mask = self.decoder2(hypercorr_decoded)
        logit_mask = logit_mask.reshape([-1, self.way, *logit_mask.shape[1:]])

        # B, N, 2, H, W
        return logit_mask
m = AttentionLearner(list(reversed(nbottlenecks[-3:])),way)(corr,None)
print(m.shape)
[6, 1, 2, 100, 100]
import paddle.nn.functional as F
class AttentiveSqueezeNetwork(nn.Layer):
    def __init__(self,way = 1):
        super(AttentiveSqueezeNetwork, self).__init__()

        self.backbone = resnet101(pretrained=True)
        self.feat_ids = list(range(4, 34))
        self.extract_feats = extract_feat_res
        self.nbottlenecks = [3, 4, 23, 3]
        self.way = way
        self.weak = True  # True: train with class tags only (weak supervision); False: use segmentation masks
        self.bottleneck_ids = reduce(add, list(map(lambda x: list(range(x)), self.nbottlenecks)))
        self.lids = reduce(add, [[i + 1] * x for i, x in enumerate(self.nbottlenecks)])
        self.stack_ids = paddle.to_tensor([ 3, 26, 30])

        self.backbone.eval()
        self.learner = AttentionLearner(list(reversed(self.nbottlenecks[-3:])), self.way)

    def forward(self, batch):
        '''
        query_img.shape : [bsz, 3, H, W]
        support_imgs.shape : [bsz, way, 3, H, W]
        support_masks.shape : [bsz, way, H, W]
        '''

        support_img = rearrange(batch['support_imgs'], 'b n c h w -> (b n) c h w')
        support_mask = None if self.weak else rearrange(batch['support_masks'], 'b n h w -> (b n) h w')
        query_img = batch['query_img']

        # The backbone is frozen, so the feature-pyramid correlations are computed under no_grad()
        with paddle.no_grad():
            query_feats = self.extract_feats(query_img, self.backbone,  self.feat_ids, self.bottleneck_ids, self.lids)
            support_feats =  self.extract_feats(support_img, self.backbone,  self.feat_ids, self.bottleneck_ids, self.lids,pool =True)
            corr = Correlation.multilayer_correlation(query_feats, support_feats,self.stack_ids, self.way)
        # The learner is the only part of the model that is actually trained
        shared_masks = self.learner(corr, support_mask)

        # B, N, 2, H, W
        shared_masks = F.softmax(shared_masks, axis=2)
        
        return shared_masks

way = 1
batch_size = 6
query_img = paddle.randn([batch_size,3,400,400])
support_img = paddle.randn([batch_size,way,3,400,400])
batch = {"query_img":query_img,"support_imgs":support_img}
AttentiveSqueezeNetwork()(batch).shape
[6, 1, 2, 100, 100]

6. Training

What follows is only a toy example; for the real training procedure, see the original paper and code.

model = AttentiveSqueezeNetwork()
from paddle.vision.transforms import Compose, Resize, ToTensor

def compute_cls_objective(shared_masks, gt_presence):
    ''' supports 1-way training '''
    # nll_loss expects log-probabilities, so take the log first (shared_masks is already softmax-normalized)
    shared_masks = paddle.log(shared_masks)
    # B, N, 2, H, W -> B, N, 2 -> B, 2
    prob_avg = shared_masks.mean(axis=[-1, -2]).squeeze(1)
    # print(prob_avg)
    return F.nll_loss(prob_avg, gt_presence)
import cv2
orange1 = cv2.imread("orange1.png")
orange2 = cv2.imread("orange2.png")
apple1 = cv2.imread("apple1.png")
apple2 = cv2.imread("apple2.png")
transform = Compose([Resize(size=(400,400)),ToTensor()])
orange1 = transform(orange1)
orange2 = transform(orange2)
apple1 = transform(apple1)
apple2 = transform(apple2)

batch_query = [orange1,apple1,apple1,orange2,apple2]
batch_support = [orange2.unsqueeze(0),apple2.unsqueeze(0),orange1.unsqueeze(0),apple1.unsqueeze(0),orange2.unsqueeze(0)]
batch_label = paddle.to_tensor([1,1,0,0,0])
batch = {"query_img":paddle.stack(batch_query,axis=0),"support_imgs":paddle.stack(batch_support,axis=0),"label":batch_label}



predictor = AttentiveSqueezeNetwork()
predictor.train()

# Optimize the learner; the frozen backbone receives no gradients because features are extracted under no_grad()
opt = paddle.optimizer.Adam(learning_rate=1e-4,
                            parameters=predictor.parameters())
epochs = 40
use_entropy_regularization = True

for epoch in range(epochs):
    # Use the whole support set (the batch built above) as the training data
    input_datas = []
    labels = batch["label"]

    logits = predictor(batch)
    loss = compute_cls_objective(logits, labels)
    # print(loss)
    if use_entropy_regularization:
        # Entropy regularization: the entropy of a distribution p equals the cross-entropy of p with itself
        prob_avg = logits.mean(axis=[-1, -2]).squeeze(1)
        entropy_regularization = paddle.nn.functional.cross_entropy(prob_avg, prob_avg,
                                                              use_softmax=False, soft_label=True)
        # print(entropy_regularization)
        loss += entropy_regularization

    if (epoch + 1) % 20 == 0:
        print("epoch: {}, loss is: {}".format(epoch+1, loss.numpy()))
        paddle.save(predictor.state_dict(), 'models/predictor.pdparams')

    loss.backward()
    opt.step()
    opt.clear_grad()

# Save the parameters after training
paddle.save(predictor.state_dict(), 'models/predictor.pdparams')
epoch: 20, loss is: [0.09264192]

7. Testing

  1. Since only apples and oranges were used for training, I first test with an apple as the query and an orange and an apple as the support set: the model classifies it correctly.
  2. I then test with a banana as the query, keeping the same support set: both presence predictions are False, which nicely demonstrates the generality of the iFSL framework.
apple_test = cv2.imread("apple_test.png")
apple_test = transform(apple_test)
banana_test = cv2.imread("banana_test.png")
banana_test = transform(banana_test)


predictor = AttentiveSqueezeNetwork(2)
predictor.set_state_dict(paddle.load("models/predictor.pdparams"))
predictor.eval()


batch_query = [apple_test]
batch_support = [orange2.unsqueeze(0),apple2.unsqueeze(0)]
batch_label = paddle.to_tensor([1,1])
batch = {"query_img":paddle.stack(batch_query,axis=0),"support_imgs":paddle.stack(batch_support,axis=0)}




logits = predictor(batch)
result = logits[:, :, 1].flatten(2).mean(axis =-1) >= 0.5
print(result)

##########################################################

batch_query = [banana_test]
batch_support = [orange2.unsqueeze(0),apple2.unsqueeze(0)]
batch_label = paddle.to_tensor([1,1])
batch = {"query_img":paddle.stack(batch_query,axis=0),"support_imgs":paddle.stack(batch_support,axis=0)}



logits = predictor(batch)
result = logits[:, :, 1].flatten(2).mean(axis =-1) >= 0.5
print(result)
Tensor(shape=[1, 2], dtype=bool, place=Place(gpu:0), stop_gradient=False,
       [[False, True ]])
Tensor(shape=[1, 2], dtype=bool, place=Place(gpu:0), stop_gradient=False,
       [[False, False]])

8. Summary

The accuracy reported in the paper: (figure: results table from the paper)

The model is fairly large (about 175 MB of parameters) and uses a Transformer-style design, so GPU memory can be an issue; PoolFormer might be worth trying as a lighter alternative. Still, the iFSL framework, which compares the query against each support-set class one by one instead of computing, as in the traditional setting, a single probability distribution of the query over all support classes at once, clearly improves generality. I only used the model's classification output here; feel free to try the segmentation side yourselves.

This article is a repost.
Original project link
