Vision Graph Neural Networks: An Image Can Be Viewed as a Graph of Nodes

Paper: Vision GNN: An Image is Worth Graph of Nodes (PyTorch code available)

The paper proposes a general-purpose GNN vision model, a valuable exploration of universal vision backbones by researchers from the University of Chinese Academy of Sciences and Huawei Noah's Ark Lab, Beijing.

1. Background and Motivation

In modern computer vision, general-purpose vision models were originally dominated by CNNs. Recent progress on new backbones, represented by Vision Transformer and Vision MLP, has pushed universal vision models to unprecedented heights.

Different backbones process the input image differently. The figure below shows the grid, sequence, and graph representations of an image. Image data is usually represented as a regular pixel grid in Euclidean space; a CNN introduces translation invariance and locality by sliding windows over the image. Newer backbones such as Vision Transformer and Vision MLP instead treat the image as a sequence of patches: for example, a 224×224 image is typically split into 196 patches of size 16×16.
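As a quick sanity check of the numbers above (my own toy example, not code from the paper), splitting a 224×224 image into 16×16 patches with NumPy does yield 196 patches:

```python
import numpy as np

# A dummy 224x224 RGB image (pixel values are irrelevant here)
image = np.zeros((224, 224, 3), dtype=np.float32)

patch = 16
# Reshape into a 14x14 grid of patches, each patch being 16x16x3
patches = image.reshape(224 // patch, patch, 224 // patch, patch, 3)
patches = patches.transpose(0, 2, 1, 3, 4).reshape(-1, patch, patch, 3)

print(patches.shape)  # (196, 16, 16, 3): 196 patch tokens
```

This is exactly the sequence representation: the 196 patches are ordered purely by spatial position.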

However, with either the grid or the sequence representation, the image is modeled in a very rigid way: the "connections" between patches are fixed in advance. For example, the head of the "fish" in Figure 1 may span several patches; under a grid or sequence representation these patches have no special relation to each other, yet semantically they all belong to the "fish head". This is arguably a shortcoming of traditional image modeling.
[Figure 1: grid, sequence, and graph representations of an image]

2. Approach

The paper proposes handling images in a more flexible way. A basic task in computer vision is recognizing the objects in an image. Since objects are usually not regularly shaped rectangles, the classic grid or sequence representations are redundant and inflexible. An object can instead be viewed as a composition of parts: a person, for example, can be roughly divided into head, upper body, arms, and legs, and these joint-connected parts naturally form a graph structure.

In the grid representation, pixels or patches are ordered only by spatial position. In the sequence representation, the 2D image is cut into a series of patches. In the graph representation, nodes are linked by their content, unconstrained by local position. Both the grid and the sequence representation can be viewed as special cases of the graph representation, so treating an image as a graph is more flexible and effective than either.
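A minimal sketch (my own illustration, not the paper's code) of what "linked by content" means: build a k-nearest-neighbor graph over patch features, so each node connects to its k most similar nodes regardless of spatial position:

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, k = 196, 64, 9          # 196 patch nodes, 64-dim features, 9 neighbors
x = rng.standard_normal((n, d)).astype(np.float32)

# Pairwise squared Euclidean distances between node features
dist = ((x[:, None, :] - x[None, :, :]) ** 2).sum(-1)
np.fill_diagonal(dist, np.inf)  # exclude self-loops

# Each node links to its k nearest neighbors by feature similarity
neighbors = np.argsort(dist, axis=1)[:, :k]
print(neighbors.shape)  # (196, 9): the edge list of the graph
```

Two patches far apart in the image (e.g. two pieces of the "fish head") can end up directly connected if their features are close.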

Building on this view of an image as a graph, the paper proposes ViG, a new general-purpose vision architecture based on the graph representation. The input image is split into many patches, each treated as a node of a graph. After constructing this graph representation of the input, the ViG model exchanges information among all nodes. ViG's basic building block has two parts: a GCN (graph convolutional network) module for graph information processing and an FFN (feed-forward network) module for node feature transformation. The method's effectiveness is demonstrated on vision tasks such as image recognition and object detection.

3. Method

3.1 Representing an Image as a Graph

Since I cannot typeset formulas in AI Studio's Markdown, the principle is shown via images, as below.

[Figures: formulas for the graph representation of an image]
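For readers who want the equations in text form, the standard ViG formulation (my own transcription from the paper, which the images above render) is roughly:

```latex
% Split the image into N patches and embed each as a D-dimensional feature:
X = [x_1, x_2, \dots, x_N], \qquad x_i \in \mathbb{R}^{D}

% Connect each node to its K nearest neighbors to form a graph:
\mathcal{G} = G(X)

% A graph convolution layer aggregates neighbor features and updates nodes:
\mathcal{G}' = F(\mathcal{G}, \mathcal{W})
  = \mathrm{Update}\big(\mathrm{Aggregate}(\mathcal{G}, W_{\mathrm{agg}}),\; W_{\mathrm{update}}\big)
```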

3.2 A Graph Convolutional Network as the Backbone

[Figure: using graph convolution as the backbone]

3.3 Structure of Each GNN Block

[Figures: components of the GNN block]

4. Model Architecture

Since this post only reproduces the ViG-S model, which has a pyramid structure, we show here the structure of ViG-S along with the other variants.
[Figure: ViG-S pyramid architecture and its variants]

Enough theory; now let's get to the code!

This post reproduces the ViG-S model based on PaddleViT. If you are interested in PaddleViT, you can browse it on GitHub.

Code location: PaddleViT/imageclassification/VIG

5. Code Walkthrough of ViG's Core Building Block

The core of ViG is the graph neural network block. Below is its Paddle implementation.

class GCN_block(nn.Layer):
    def __init__(self,in_channels, kernel_size=9, dilation=1, conv='edge', act='relu', norm=None,
                 bias=True,  stochastic=False, epsilon=0.0, r=1, n=196, drop_path=0.0, relative_pos=False):
        super().__init__()
        self.channels = in_channels
        self.n = n  # number of nodes
        self.r = r  # reduction ratio of the graph
        self.fc1 = nn.Sequential(
            nn.Conv2D(in_channels, in_channels, 1, stride=1, padding=0),
            nn.BatchNorm2D(in_channels),
        )  # projection layer
        self.graph_conv = DyGraphConv2d(in_channels, in_channels * 2, kernel_size, dilation, conv,
                              act, norm, bias, stochastic, epsilon, r)  # graph convolution layer
        self.fc2 = nn.Sequential(
            nn.Conv2D(in_channels * 2, in_channels, 1, stride=1, padding=0),
            nn.BatchNorm2D(in_channels),
        )  # projection layer
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()  # DropPath, as in ViT
        self.relative_pos = None
        if relative_pos:  # whether to use relative positional relations
            print('using relative_pos')
            relative_pos_tensor = paddle.to_tensor(np.float32(get_2d_relative_pos_embed(in_channels,
                int(n**0.5)))).unsqueeze(0).unsqueeze(1)
            relative_pos_tensor = F.interpolate(
                    relative_pos_tensor, size=(n, n//(r*r)), mode='bicubic', align_corners=False)
            self.relative_pos = add_parameter(self,-relative_pos_tensor.squeeze(1))
            self.relative_pos.stop_gradient=True

    def _get_relative_pos(self, relative_pos, H, W): 
        """
        这个是得到每一个节点在整个图中的相对位置
        """
        if relative_pos is None or H * W == self.n:
            return relative_pos
        else:
            N = H * W
            N_reduced = N // (self.r * self.r)
            return F.interpolate(relative_pos.unsqueeze(0), size=(N, N_reduced), mode="bicubic").squeeze(0)

    def forward(self, x):
        """
        这里便是图神经网络块的前向传播
        可以对照前面原理理解一下
        """
        _tmp = x
        x = self.fc1(x)
        B, C, H, W = x.shape
        relative_pos = self._get_relative_pos(self.relative_pos, H, W)
        x = self.graph_conv(x, relative_pos)
        x = self.fc2(x)
        x = self.drop_path(x) + _tmp
        return x
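The `DyGraphConv2d` call above (with the default `conv='mr'`) performs max-relative graph convolution. A NumPy sketch of that aggregation, under my reading of the code (hypothetical and simplified to a static, un-batched graph):

```python
import numpy as np

rng = np.random.default_rng(1)
n, d, k = 196, 48, 9
x = rng.standard_normal((n, d)).astype(np.float32)

# k-NN graph over node features, as in the k-NN construction used by ViG
dist = ((x[:, None, :] - x[None, :, :]) ** 2).sum(-1)
np.fill_diagonal(dist, np.inf)
idx = np.argsort(dist, axis=1)[:, :k]            # (n, k) neighbor indices

# Max-relative aggregation: concat(x_i, max_j (x_j - x_i))
rel = x[idx] - x[:, None, :]                     # (n, k, d) relative features
agg = np.concatenate([x, rel.max(axis=1)], -1)   # (n, 2d): channels double
print(agg.shape)  # (196, 96)
```

Note how the channel count doubles, matching the `in_channels * 2` output of `self.graph_conv` that `self.fc2` then projects back down.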

Next is the full ViG network. Apart from the GNN blocks, the rest is similar to Vision Transformer.

Each individual module can be found under PaddleViT/imageclassification/VIG.

class DeepGCN(nn.Layer):
    def __init__(self,
                layers,
                k = 9,
                conv = 'mr',
                act = 'gelu',
                norm = 'batch',
                bias = True,
                dropout = 0.0,
                use_dilation = True,
                epsilon = 0.2,
                use_stochastic = False,
                drop_path = 0.0,
                channels = [48,96,240,384],
                n_classes = 1000,
                emb_dims = 1024,
                **kwargs):
        super().__init__()
        self.n_blocks = sum(layers)
        dpr = [x.item() for x in paddle.linspace(0, drop_path, self.n_blocks)]
        num_knn = [int(x.item()) for x in paddle.linspace(k, k, self.n_blocks)]
        max_dilation = 49 // max(num_knn)
        reduce_ratios = [4, 2, 1, 1]
        self.stem = Stem(out_dim=channels[0], act=act)
        self.pos_embed = add_parameter(self,paddle.zeros((1, channels[0], 224//4, 224//4)))
        HW = 224 // 4 * 224 // 4

        self.backbone = []
        idx = 0
        for i in range(len(layers)):
            if i > 0:
                self.backbone.append(Downsample(channels[i-1], channels[i]))
                HW = HW // 4
            for j in range(layers[i]):
                self.backbone+= [
                    nn.Sequential(GCN_block(channels[i], num_knn[idx], 
                                    min(idx // 4 + 1, max_dilation), conv, act, norm,
                                    bias, use_stochastic, epsilon, reduce_ratios[i], 
                                    n=HW, drop_path=dpr[idx],
                                    relative_pos=True),
                          FFN(channels[i], channels[i] * 4, act=act, drop_path=dpr[idx])
                         )]
                idx += 1
        self.backbone = nn.Sequential(*self.backbone)

        self.prediction = nn.Sequential(nn.Conv2D(channels[-1], 1024, 1),
                              nn.BatchNorm2D(1024),
                              act_layer(act),
                              nn.Dropout(dropout),
                              nn.Conv2D(1024, n_classes, 1))
        self.apply(self.cls_init_weights)

    def cls_init_weights(self, m):
        if isinstance(m, nn.Conv2D):
            kaiming(m.weight)
            if m.bias is not None:
                zeros_(m.bias)

    def forward(self, inputs):
        x = self.stem(inputs) + self.pos_embed
        B, C, H, W = x.shape
        for i in range(len(self.backbone)):
            x = self.backbone[i](x)

        x = F.adaptive_avg_pool2d(x, 1)
        return self.prediction(x).squeeze(-1).squeeze(-1)
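To see the pyramid structure concretely, here is a quick sketch of the node counts per stage (my own arithmetic, following the stride-4 stem and per-stage `Downsample` in the code above, with the `CHANNELS` and `LAYERS` values from the vig_s.yaml config):

```python
# ViG-S pyramid geometry for a 224x224 input:
# the stem downsamples by 4, and each later stage halves H and W.
hw = (224 // 4) * (224 // 4)    # 56*56 = 3136 nodes after the stem
channels = [80, 160, 400, 640]  # CHANNELS from vig_s.yaml
layers = [2, 2, 6, 2]           # LAYERS (blocks per stage)

nodes = []
for i in range(len(layers)):
    if i > 0:
        hw //= 4                # Downsample halves H and W -> /4 nodes
    nodes.append(hw)
print(nodes)  # [3136, 784, 196, 49]
```

The final stage has 49 nodes, which is where the `max_dilation = 49 // max(num_knn)` cap in the constructor comes from.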

6. Validation Accuracy on ImageNet

After all this talk, what ultimately matters is model performance, so we evaluate accuracy on the ImageNet validation set. The code is as follows.

# run this once
%cd data/
!tar -xf data105740/ILSVRC2012_val.tar
%cd /home/aistudio/PaddleViT/image_classification/VIG/
!pip install yacs pyyaml
!python -m paddle.distributed.launch --gpus 0 main_multi_gpu.py -cfg='./configs/vig_s.yaml' -dataset='imagenet2012' -batch_size=256 -data_path='/home/aistudio/data/ILSVRC2012_val' -pretrained='./vig_s.pdparams' -eval -amp
LAUNCH INFO 2022-07-02 23:17:43,153 -----------  Configuration  ----------------------
LAUNCH INFO 2022-07-02 23:17:43,153 devices: None
LAUNCH INFO 2022-07-02 23:17:43,154 elastic_level: -1
LAUNCH INFO 2022-07-02 23:17:43,154 elastic_timeout: 30
LAUNCH INFO 2022-07-02 23:17:43,154 gloo_port: 6767
LAUNCH INFO 2022-07-02 23:17:43,154 host: None
LAUNCH INFO 2022-07-02 23:17:43,154 job_id: default
LAUNCH INFO 2022-07-02 23:17:43,154 legacy: False
LAUNCH INFO 2022-07-02 23:17:43,154 log_dir: log
LAUNCH INFO 2022-07-02 23:17:43,154 log_level: INFO
LAUNCH INFO 2022-07-02 23:17:43,154 master: None
LAUNCH INFO 2022-07-02 23:17:43,154 max_restart: 3
LAUNCH INFO 2022-07-02 23:17:43,154 nnodes: 1
LAUNCH INFO 2022-07-02 23:17:43,154 nproc_per_node: None
LAUNCH INFO 2022-07-02 23:17:43,154 rank: -1
LAUNCH INFO 2022-07-02 23:17:43,154 run_mode: collective
LAUNCH INFO 2022-07-02 23:17:43,154 server_num: None
LAUNCH INFO 2022-07-02 23:17:43,154 servers: 
LAUNCH INFO 2022-07-02 23:17:43,154 trainer_num: None
LAUNCH INFO 2022-07-02 23:17:43,154 trainers: 
LAUNCH INFO 2022-07-02 23:17:43,154 training_script: 0
LAUNCH INFO 2022-07-02 23:17:43,154 training_script_args: ['main_multi_gpu.py', '-cfg=./configs/vig_s.yaml', '-dataset=imagenet2012', '-batch_size=256', '-data_path=/home/aistudio/data/ILSVRC2012_val', '-pretrained=./vig_s.pdparams', '-eval', '-amp']
LAUNCH INFO 2022-07-02 23:17:43,154 with_gloo: 0
LAUNCH INFO 2022-07-02 23:17:43,154 --------------------------------------------------
LAUNCH WARNING 2022-07-02 23:17:43,154 Compatible mode enable with args ['--gpus']
-----------  Configuration Arguments -----------
backend: auto
cluster_topo_path: None
elastic_pre_hook: None
elastic_server: None
enable_auto_mapping: False
force: False
gpus: 0
heter_devices: 
heter_worker_num: None
heter_workers: 
host: None
http_port: None
ips: 127.0.0.1
job_id: None
log_dir: log
np: None
nproc_per_node: None
rank_mapping_path: None
run_mode: None
scale: 0
server_num: None
servers: 
training_script: main_multi_gpu.py
training_script_args: ['-cfg=./configs/vig_s.yaml', '-dataset=imagenet2012', '-batch_size=256', '-data_path=/home/aistudio/data/ILSVRC2012_val', '-pretrained=./vig_s.pdparams', '-eval', '-amp']
worker_num: None
workers: 
------------------------------------------------
WARNING 2022-07-02 23:17:43,155 launch.py:519] Not found distinct arguments and compiled with cuda or xpu or npu or mlu. Default use collective mode
launch train in GPU mode!
INFO 2022-07-02 23:17:43,157 launch_utils.py:561] Local start 1 processes. First process distributed environment info (Only For Debug): 
    +=======================================================================================+
    |                        Distributed Envs                      Value                    |
    +---------------------------------------------------------------------------------------+
    |                       PADDLE_TRAINER_ID                        0                      |
    |                 PADDLE_CURRENT_ENDPOINT                 127.0.0.1:33289               |
    |                     PADDLE_TRAINERS_NUM                        1                      |
    |                PADDLE_TRAINER_ENDPOINTS                 127.0.0.1:33289               |
    |                     PADDLE_RANK_IN_NODE                        0                      |
    |                 PADDLE_LOCAL_DEVICE_IDS                        0                      |
    |                 PADDLE_WORLD_DEVICE_IDS                        0                      |
    |                     FLAGS_selected_gpus                        0                      |
    |             FLAGS_selected_accelerators                        0                      |
    +=======================================================================================+

INFO 2022-07-02 23:17:43,157 launch_utils.py:566] details about PADDLE_TRAINER_ENDPOINTS can be found in log/endpoints.log, and detail running logs maybe found in log/workerlog.0
launch proc_id:4013 idx:0
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/__init__.py:107: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  from collections import MutableMapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/rcsetup.py:20: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  from collections import Iterable, Mapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/colors.py:53: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  from collections import Sized
----- Imagenet2012 val_list.txt len = 50000
INFO 2022-07-02 23:17:45,105 cloud_utils.py:122] get cluster from args:job_server:None pods:['rank:0 id:None addr:127.0.0.1 port:None visible_gpu:[] trainers:["gpu:[\'0\'] endpoint:127.0.0.1:35773 rank:0"]'] job_stage_flag:None hdfs:None
2022-07-02 23:17:47,117 MASTER_LOG ----- world_size = 1, local_rank = 0 
----- AMP: False
BASE: ['']
DATA:
  BATCH_SIZE: 256
  BATCH_SIZE_EVAL: 256
  CROP_PCT: 0.9
  DATASET: imagenet2012
  DATA_PATH: /home/aistudio/data/ILSVRC2012_val
  IMAGENET_MEAN: [0.485, 0.456, 0.406]
  IMAGENET_STD: [0.229, 0.224, 0.225]
  IMAGE_CHANNELS: 3
  IMAGE_SIZE: 224
  NUM_WORKERS: 1
EVAL: True
MODEL:
  ATTENTION_DROPOUT: 0.0
  CHANNELS: [80, 160, 400, 640]
  DOWNSAMPLES: [True, True, True, True]
  DROPOUT: 0.0
  DROPPATH: 0.1
  EMBED_DIMS: 1024
  LAYERS: [2, 2, 6, 2]
  LAYER_SCALE_INIT_VALUE: 1e-05
  MLP_RATIOS: [4, 4, 4, 4]
  NAME: DeepGCN_s
  NUM_CLASSES: 1000
  PRETRAINED: ./vig_s.pdparams
  RESUME: None
  TYPE: DeepGCN
REPORT_FREQ: 20
SAVE: ./output/eval-20220702-23-17
SAVE_FREQ: 10
SEED: 0
TRAIN:
  ACCUM_ITER: 1
  AUTO_AUGMENT: False
  BASE_LR: 0.001
  COLOR_JITTER: 0.4
  CUTMIX_ALPHA: 1.0
  CUTMIX_MINMAX: None
  END_LR: 1e-05
  GRAD_CLIP: None
  LAST_EPOCH: 0
  LINEAR_SCALED_LR: 1024
  MIXUP_ALPHA: 0.8
  MIXUP_MODE: batch
  MIXUP_PROB: 1.0
  MIXUP_SWITCH_PROB: 0.5
  MODEL_EMA: False
  MODEL_EMA_DECAY: 0.99996
  MODEL_EMA_FORCE_CPU: True
  NUM_EPOCHS: 300
  OPTIMIZER:
    BETAS: (0.9, 0.999)
    EPS: 1e-08
    NAME: AdamW
  RANDOM_ERASE_COUNT: 1
  RANDOM_ERASE_MODE: pixel
  RANDOM_ERASE_PROB: 0.25
  RANDOM_ERASE_SPLIT: False
  RAND_AUGMENT: True
  RAND_AUGMENT_LAYERS: 2
  RAND_AUGMENT_MAGNITUDE: 9
  SMOOTHING: 0.1
  WARMUP_EPOCHS: 5
  WARMUP_START_LR: 1e-06
  WEIGHT_DECAY: 0.05
VALIDATE_FREQ: 1
W0702 23:17:47.117699  4028 gpu_context.cc:278] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 10.1
W0702 23:17:47.121506  4028 gpu_context.cc:306] device: 0, cuDNN Version: 7.6.
using relative_pos
using relative_pos
using relative_pos
using relative_pos
using relative_pos
using relative_pos
using relative_pos
using relative_pos
using relative_pos
using relative_pos
using relative_pos
using relative_pos
2022-07-02 23:17:51,458 MASTER_LOG ----- Total # of val batch (single gpu): 196
2022-07-02 23:17:51,737 MASTER_LOG ----- Pretrained: Load model state from ./vig_s.pdparams
2022-07-02 23:17:51,737 MASTER_LOG ----- Start Validation
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/distributed/parallel.py:158: UserWarning: Currently not a parallel execution environment, `paddle.distributed.init_parallel_env` will not do anything.
  "Currently not a parallel execution environment, `paddle.distributed.init_parallel_env` will not do anything."
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/dygraph/parallel.py:631: UserWarning: The program will return to single-card operation. Please check 1, whether you use spawn or fleetrun to start the program. 2, Whether it is a multi-card program. 3, Is the current environment multi-card.
  warnings.warn("The program will return to single-card operation. "
2022-07-02 23:17:57,053 MASTER_LOG Step[0000/0196], Avg Loss: 1.0437, Avg Acc@1: 0.7734, Avg Acc@5: 0.9414
2022-07-02 23:18:45,671 MASTER_LOG Step[0020/0196], Avg Loss: 0.9593, Avg Acc@1: 0.8013, Avg Acc@5: 0.9526
2022-07-02 23:19:33,569 MASTER_LOG Step[0040/0196], Avg Loss: 0.9620, Avg Acc@1: 0.8043, Avg Acc@5: 0.9524
2022-07-02 23:20:23,634 MASTER_LOG Step[0060/0196], Avg Loss: 0.9678, Avg Acc@1: 0.8021, Avg Acc@5: 0.9513
2022-07-02 23:21:14,096 MASTER_LOG Step[0080/0196], Avg Loss: 0.9643, Avg Acc@1: 0.8029, Avg Acc@5: 0.9518
2022-07-02 23:22:03,238 MASTER_LOG Step[0100/0196], Avg Loss: 0.9651, Avg Acc@1: 0.8027, Avg Acc@5: 0.9516
2022-07-02 23:22:53,329 MASTER_LOG Step[0120/0196], Avg Loss: 0.9654, Avg Acc@1: 0.8034, Avg Acc@5: 0.9512
2022-07-02 23:23:43,297 MASTER_LOG Step[0140/0196], Avg Loss: 0.9659, Avg Acc@1: 0.8028, Avg Acc@5: 0.9513
2022-07-02 23:24:33,482 MASTER_LOG Step[0160/0196], Avg Loss: 0.9675, Avg Acc@1: 0.8027, Avg Acc@5: 0.9509
2022-07-02 23:25:23,225 MASTER_LOG Step[0180/0196], Avg Loss: 0.9677, Avg Acc@1: 0.8025, Avg Acc@5: 0.9509
/opt/conda/envs/python35-paddle120-env/lib/python3.7/multiprocessing/semaphore_tracker.py:144: UserWarning: semaphore_tracker: There appear to be 1 leaked semaphores to clean up at shutdown
  len(cache))
2022-07-02 23:25:59,995 MASTER_LOG ----- Validation: Validation Loss: 0.9671, Validation Acc@1: 0.8030, Validation Acc@5: 0.9511, time: 488.26
INFO 2022-07-02 23:26:01,674 launch.py:402] Local processes completed.

From the validation results above, the ViG-S model reaches Validation Acc@1: 0.8030 and Validation Acc@5: 0.9511 on ImageNet.

There is still a gap from the official 82.1%, but the overall performance is solid.

7. Training Code

Of course, we can't stop at evaluation code, so here is the training code. Since the ImageNet dataset is very large, this is provided as an example: to train, just prepare the ImageNet dataset in the format the code expects.

!python -m paddle.distributed.launch --gpus 0 main_multi_gpu.py -cfg='./configs/vig_s.yaml' -dataset='imagenet2012' -batch_size=256 -data_path='/home/aistudio/data/ILSVRC2012_val' -pretrained='./vig_s.pdparams' -amp

8. Summary

General-purpose vision models usually process image information with a sequence or grid structure; the authors innovatively propose processing images as graphs. A basic task in computer vision is recognizing the objects in an image, and since objects are usually not regularly shaped rectangles, the classic grid or sequence representations are redundant and inflexible. The paper proposes ViG, a new graph-based general-purpose vision architecture: the input image is split into many patches, each treated as a node in a graph. Building a graph over these nodes better represents irregular, complex objects.

After constructing the graph representation of the input image, the authors use the ViG model to exchange information among all nodes. ViG's basic unit has two parts: a GCN (graph convolutional network) module for graph information processing and an FFN (feed-forward network) module for node feature transformation. Applying graph convolution directly to the image graph suffers from over-smoothing and performs poorly, so the authors introduce more feature transformations within each node to promote feature diversity. Effectiveness is demonstrated on vision tasks such as image recognition and object detection.

Some personal takeaways

1. Whether the graph structure is superior to self-attention remains an open question.

2. This is a novel application of graph neural networks to feature representation.

References

GNNs breaking into mainstream CV? USTC, Huawei and others jointly open-source ViG: the first GNN applied to vision tasks

Open-source link: original project https://aistudio.baidu.com/aistudio/projectdetail/4288323
