Reposted from AI Studio
Project link: https://aistudio.baidu.com/aistudio/projectdetail/3449604

XCiT: Cross-Covariance Image Transformers

This is the winning (first-place) code for "XCiT: Cross-Covariance Image Transformers" in the 5th PaddlePaddle Paper Reproduction Challenge. GitHub: https://github.com/BrilliantYuKaimin/XCiT-PaddlePaddle. The official PyTorch implementation: https://github.com/facebookresearch/xcit

Because the XCiT code has been merged into PaddlePaddle's PASSL library, this project runs on top of PASSL. The XCiT network definition is in PASSL/passl/modeling/backbones/xcit.py.

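This project provides two PaddlePaddle weight files:

  • xcit_nano_12_p8_224_dist.pdparams: weights retrained with PaddlePaddle. They reach 77.28% top-1 accuracy on the ImageNet-1k validation set, above the official result of 76.3%.
  • regnety_160.pdparams: weights of the teacher model used for knowledge distillation, converted from the corresponding PyTorch weights.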

Introduction

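Following their success in natural language processing, Transformers have also shown great promise in computer vision. The self-attention operation in a Transformer creates global interactions between all items of a sequence (words in a sentence, small image patches in an image), which allows image data to be modeled flexibly beyond the local interactions of convolutions. This flexibility, however, comes at the cost of time and memory complexity that is quadratic in the sequence length, which hinders the use of Transformers on long sequences and high-resolution images. The authors propose a "transposed" version of self-attention that operates across feature channels rather than across sequence items, with interactions based on the cross-covariance matrix between keys and queries. The resulting cross-covariance attention (XCA) has linear complexity in the sequence length and allows efficient processing of high-resolution images. Built on XCA, the Cross-Covariance Image Transformer (XCiT) combines the accuracy of conventional Transformers with the scalability of convolutional architectures.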

XCiT

Cross-covariance attention

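For an input $\bm X$ of shape $N\times d$, where $N$ is the sequence length, three different linear projections give three matrices
$$\bm Q=\bm X\bm W_{\mathrm q},\quad \bm K=\bm X\bm W_{\mathrm k},\quad \bm V=\bm X\bm W_{\mathrm v},$$
which are still of shape $N\times d$. The original self-attention computes
$$\mathrm{softmax}\left(\bm Q\bm K^{\mathrm T}/d^{1/2}\right)\bm V,$$
whereas cross-covariance attention computes
$$\bm V\,\mathrm{softmax}\left(\bm K^{\mathrm T}\bm Q/\tau\right),$$
where $\tau$ plays a role similar to $d^{1/2}$; in XCiT it is a learnable temperature rather than a fixed constant. (In the paper, $\bm Q$ and $\bm K$ are also $\ell_2$-normalized column by column, i.e. along the sequence dimension, before this product, which keeps every entry of $\bm K^{\mathrm T}\bm Q$ in $[-1,1]$; this changes neither the shapes nor the cost.) A notable difference between the two is that computing $\bm Q\bm K^{\mathrm T}$ (an $N\times N$ matrix) takes $N^2d$ multiplications, whereas computing $\bm K^{\mathrm T}\bm Q$ (a $d\times d$ matrix) takes $Nd^2$; likewise, the subsequent product with $\bm V$ costs $N^2d$ versus $Nd^2$. The complexity of cross-covariance attention is therefore linear in the sequence length.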
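To make the shape and cost argument concrete, here is a minimal pure-Python sketch (no PaddlePaddle, small random matrices) that computes both attention forms and counts the scalar multiplications in the two matrix products of each. It follows the paper's version of XCA, $\ell_2$-normalizing the columns of $\bm Q$ and $\bm K$, with a fixed $\tau = 1.0$ standing in for the learned temperature. Applying softmax row-wise in both cases is an assumption of this sketch, and the helper names (`matmul`, `softmax_rows`, `cross_covariance_attention`, ...) are illustrative, not PASSL APIs.

```python
import math
import random


def transpose(a):
    return [list(col) for col in zip(*a)]


def matmul(a, b, counter):
    """Naive (n x k) @ (k x m) product; adds the n*m*k scalar multiplications to counter[0]."""
    n, k, m = len(a), len(b), len(b[0])
    out = [[sum(a[i][t] * b[t][j] for t in range(k)) for j in range(m)] for i in range(n)]
    counter[0] += n * m * k
    return out


def softmax_rows(a):
    """Numerically stable softmax applied to each row independently."""
    out = []
    for row in a:
        top = max(row)
        exps = [math.exp(x - top) for x in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def l2_normalize_columns(a):
    """Scale each column (length N, the sequence dimension) to unit L2 norm."""
    norms = [math.sqrt(sum(x * x for x in col)) or 1.0 for col in zip(*a)]
    return [[x / norms[j] for j, x in enumerate(row)] for row in a]


def self_attention(q, k, v, counter):
    d = len(q[0])
    scores = matmul(q, transpose(k), counter)              # N x N: N^2 d mults
    scores = [[s / math.sqrt(d) for s in row] for row in scores]
    return matmul(softmax_rows(scores), v, counter)        # N x d: N^2 d mults


def cross_covariance_attention(q, k, v, tau, counter):
    qn, kn = l2_normalize_columns(q), l2_normalize_columns(k)
    scores = matmul(transpose(kn), qn, counter)            # d x d: N d^2 mults
    scores = [[s / tau for s in row] for row in scores]
    return matmul(v, softmax_rows(scores), counter)        # N x d: N d^2 mults


def random_matrix(rows, cols):
    return [[random.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


random.seed(0)
N, d = 64, 8                       # sequence length, channels
X = random_matrix(N, d)
Wq, Wk, Wv = (random_matrix(d, d) for _ in range(3))
not_counted = [0]                  # the projections are the same for both variants
Q, K, V = (matmul(X, W, not_counted) for W in (Wq, Wk, Wv))

sa_mults, xca_mults = [0], [0]
out_sa = self_attention(Q, K, V, sa_mults)
out_xca = cross_covariance_attention(Q, K, V, 1.0, xca_mults)
print(len(out_sa), len(out_sa[0]), sa_mults[0])     # 64 8 65536  (= 2 * N^2 * d)
print(len(out_xca), len(out_xca[0]), xca_mults[0])  # 64 8 8192   (= 2 * N * d^2)
```

Both outputs have the input's $N\times d$ shape, but with $N = 64$ and $d = 8$ XCA already needs $N/d = 8$ times fewer multiplications, and that ratio grows with image resolution because $N$ grows while $d$ stays fixed.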

On the other hand,
$$\bm K^{\mathrm T}\bm Q=\bm W_{\mathrm k}^{\mathrm T}\bm X^{\mathrm T}\bm X\bm W_{\mathrm q},$$
where $\bm X^{\mathrm T}\bm X$ is the (unnormalized, uncentered) covariance matrix of the row vectors of $\bm X$. This is where the name cross-covariance attention comes from.
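The identity above can be checked numerically. This small pure-Python sketch uses hypothetical integer matrices (so the comparison is exact) and confirms both that $\bm K^{\mathrm T}\bm Q=\bm W_{\mathrm k}^{\mathrm T}(\bm X^{\mathrm T}\bm X)\bm W_{\mathrm q}$ and that $\bm X^{\mathrm T}\bm X$ is the sum over tokens of the outer products $\bm x_i\bm x_i^{\mathrm T}$ of the rows of $\bm X$.

```python
def transpose(a):
    return [list(col) for col in zip(*a)]


def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


# Hypothetical small example: N = 4 tokens (rows), d = 3 channels (columns).
X = [[1, 2, 0],
     [0, 1, 3],
     [2, -1, 1],
     [1, 1, 1]]
W_q = [[1, 0, 2], [0, 1, 0], [1, 1, 1]]
W_k = [[2, 1, 0], [0, 1, 1], [1, 0, 1]]

Q, K = matmul(X, W_q), matmul(X, W_k)
gram = matmul(transpose(X), X)                  # X^T X, a d x d matrix

# K^T Q equals W_k^T (X^T X) W_q exactly (integer arithmetic, no rounding).
assert matmul(transpose(K), Q) == matmul(matmul(transpose(W_k), gram), W_q)

# X^T X is the sum over the N row vectors x_i of the outer products x_i x_i^T,
# i.e. the unnormalized, uncentered covariance of the rows.
outer_sum = [[sum(x[i] * x[j] for x in X) for j in range(3)] for i in range(3)]
assert outer_sum == gram
print(gram)  # [[6, 1, 3], [1, 7, 3], [3, 3, 11]]
```

Dividing `gram` by $N$ after subtracting the per-channel mean would give the usual sample covariance; XCiT skips both steps and relies on the $\ell_2$ normalization and the temperature instead.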

Quick start

To train or evaluate, first run the commands below to extract and organize the dataset and to install the required libraries.

!tar -xf data/data114241/Light_ILSVRC2012_part_0.tar --directory data/
!rm -rf data/data114241/
!tar -xf data/data114746/Light_ILSVRC2012_part_1.tar --directory data/
!rm -rf data/data114746/
!mv data/Light_ILSVRC2012/ data/ILSVRC2012
!cd data/ILSVRC2012/val && bash ~/valprep.sh
!pip install einops
!pip install ftfy
!pip install regex
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Collecting einops
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/66/6f/fb90ccb765bc521d363f605aaddb4c4169891d431b9c6fed0451c5a533f5/einops-0.4.0-py3-none-any.whl (28 kB)
Installing collected packages: einops
Successfully installed einops-0.4.0
WARNING: You are using pip version 21.3.1; however, version 22.0.3 is available.
You should consider upgrading via the '/opt/conda/envs/python35-paddle120-env/bin/python -m pip install --upgrade pip' command.
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Collecting ftfy
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/e1/1e/bf736f9576a8979752b826b75cbd83663ff86634ea3055a766e2d8ad3ee5/ftfy-6.1.1-py3-none-any.whl (53 kB)
     |████████████████████████████████| 53 kB 1.7 MB/s             
Collecting wcwidth>=0.2.5
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/59/7c/e39aca596badaf1b78e8f547c807b04dae603a433d3e7a7e04d67f2ef3e5/wcwidth-0.2.5-py2.py3-none-any.whl (30 kB)
Installing collected packages: wcwidth, ftfy
  Attempting uninstall: wcwidth
    Found existing installation: wcwidth 0.1.7
    Uninstalling wcwidth-0.1.7:
      Successfully uninstalled wcwidth-0.1.7
Successfully installed ftfy-6.1.1 wcwidth-0.2.5
WARNING: You are using pip version 21.3.1; however, version 22.0.3 is available.
You should consider upgrading via the '/opt/conda/envs/python35-paddle120-env/bin/python -m pip install --upgrade pip' command.
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Collecting regex
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/82/b9/09143a2072af5571227f1687e44fd9041cc5933fffaf2fbc30394c720141/regex-2022.1.18-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (748 kB)
     |████████████████████████████████| 748 kB 1.3 MB/s            
Installing collected packages: regex
Successfully installed regex-2022.1.18
WARNING: You are using pip version 21.3.1; however, version 22.0.3 is available.
You should consider upgrading via the '/opt/conda/envs/python35-paddle120-env/bin/python -m pip install --upgrade pip' command.
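After the setup commands, a quick Python check can confirm that the data sits where the configs expect it (`data/ILSVRC2012/train/` and `data/ILSVRC2012/val`). This sketch assumes an ImageFolder-style layout with one sub-directory per class, which is what `valprep.sh` presumably produces for the validation split; the accepted file extensions are also an assumption, and `count_images` is a hypothetical helper, not part of PASSL. The evaluation log further down reports 50000 validation samples, so that is the number to expect for `val`.

```python
import os

# Assumed image extensions; ImageNet files normally end in ".JPEG".
IMAGE_EXTS = {".jpeg", ".jpg", ".png"}


def count_images(split_dir):
    """Return (number of class sub-directories, number of image files inside them)."""
    num_classes, num_images = 0, 0
    for entry in os.scandir(split_dir):
        if not entry.is_dir():
            continue  # loose files at the top level are not part of any class
        num_classes += 1
        for name in os.listdir(entry.path):
            if os.path.splitext(name)[1].lower() in IMAGE_EXTS:
                num_images += 1
    return num_classes, num_images


if __name__ == "__main__":
    for split in ("train", "val"):
        path = os.path.join("data", "ILSVRC2012", split)
        if os.path.isdir(path):
            print(split, count_images(path))
        else:
            print(split, "missing:", path)
```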

Training

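The two commands below run knowledge-distillation training and regular training, respectively. On a single machine with four GPUs, a full distillation run takes about 15 days. The logs shown here come from runs that were stopped manually (Ctrl+C) after a few iterations.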

# Knowledge distillation
!python PASSL/tools/train.py -c PASSL/configs/xcit/xcit_nano_12_p8_224_dist.yaml
[02/27 20:07:19] passl INFO: Configs: {'epochs': 400, 'output_dir': 'outputs', 'seed': 0, 'device': 'gpu', 'model': {'name': 'DistillationWrapper', 'models': [{'Teacher': {'name': 'Classification', 'backbone': {'name': 'RegNet', 'w_a': 106.23, 'w_0': 200, 'w_m': 2.48, 'd': 18, 'group_w': 112, 'bot_mul': 1.0, 'q': 8, 'se_on': True}, 'head': {'name': 'ClasHead', 'in_channels': 3024, 'num_classes': 1000}}}, {'Student': {'name': 'SwinWrapper', 'architecture': {'name': 'XCiT', 'img_size': 224, 'patch_size': 8, 'embed_dim': 128, 'depth': 12, 'num_heads': 4, 'eta': 1.0, 'tokens_norm': False}, 'head': {'name': 'SwinTransformerClsHead', 'in_channels': 128, 'num_classes': 1000}}}], 'pretrained_list': ['regnety_160.pdparams', None], 'freeze_params_list': [True, False], 'infer_model_key': 'Student', 'dml_loss_weight': 0.5, 'head_loss_weight': 0.5}, 'dataloader': {'train': {'loader': {'num_workers': 8, 'use_shared_memory': True}, 'sampler': {'batch_size': 128, 'shuffle': True, 'drop_last': True}, 'dataset': {'name': 'ImageNet', 'dataroot': 'data/ILSVRC2012/train/', 'return_label': True, 'transforms': [{'name': 'RandomResizedCrop', 'size': 224, 'scale': [0.08, 1.0], 'interpolation': 'bicubic'}, {'name': 'RandomHorizontalFlip'}, {'name': 'AutoAugment', 'config_str': 'rand-m9-mstd0.5-inc1', 'interpolation': 'bicubic', 'img_size': 224}, {'name': 'Transpose'}, {'name': 'Normalize', 'mean': [123.675, 116.28, 103.53], 'std': [58.395, 57.12, 57.375]}, {'name': 'RandomErasing', 'prob': 0.25, 'mode': 'pixel', 'max_count': 1}], 'batch_transforms': [{'name': 'Mixup', 'mixup_alpha': 0.8, 'prob': 1.0, 'switch_prob': 0.5, 'mode': 'batch', 'cutmix_alpha': 1.0}]}}, 'val': {'loader': {'num_workers': 8, 'use_shared_memory': True}, 'sampler': {'batch_size': 128, 'shuffle': False, 'drop_last': False}, 'dataset': {'name': 'ImageNet', 'dataroot': 'data/ILSVRC2012/val', 'return_label': True, 'transforms': [{'name': 'Resize', 'size': 224, 'interpolation': 'bicubic'}, {'name': 'Transpose'}, {'name': 
'Normalize', 'mean': [123.675, 116.28, 103.53], 'std': [58.395, 57.12, 57.375]}]}}}, 'lr_scheduler': {'name': 'LinearWarmup', 'learning_rate': {'name': 'CosineAnnealingDecay', 'learning_rate': 0.0005, 'T_max': 400, 'eta_min': 1e-05}, 'warmup_steps': 5, 'start_lr': 1e-06, 'end_lr': 0.0005}, 'optimizer': {'name': 'AdamW', 'beta1': 0.9, 'beta2': 0.999, 'weight_decay': 0.05, 'exclude_from_weight_decay': ['temperature', 'pos_embed', 'cls_token', 'dist_token']}, 'log_config': {'name': 'LogHook', 'interval': 10}, 'checkpoint': {'name': 'CheckpointHook', 'by_epoch': True, 'interval': 1, 'max_keep_ckpts': 50}, 'custom_config': [{'name': 'EvaluateHook'}], 'is_train': True, 'timestamp': '-2022-02-27-20-07'}
[02/27 20:07:19] passl.engine.trainer INFO: train with paddle 2.2.2 on CUDAPlace(0) device
W0227 20:07:19.224476 22323 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1
W0227 20:07:19.229701 22323 device_context.cc:465] device: 0, cuDNN Version: 7.6.
[02/27 20:07:25] passl.engine.trainer INFO: Number of Parameters is 3.05M.
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/dygraph/math_op_patch.py:253: UserWarning: The dtype of left and right variables are not the same, left dtype is paddle.float64, but right dtype is paddle.float32, the right dtype will convert to paddle.float64
  format(lhs_dtype, rhs_dtype, lhs_dtype))
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/dygraph/math_op_patch.py:253: UserWarning: The dtype of left and right variables are not the same, left dtype is paddle.float32, but right dtype is paddle.float64, the right dtype will convert to paddle.float32
  format(lhs_dtype, rhs_dtype, lhs_dtype))
[02/27 20:07:36] passl.engine.trainer INFO: Epoch [1/400][0/10009]	lr: 1.000e-06, eta: 250 days, 5:42:54, time: 5.400, data_time: 3.590, loss 4.9992e+00 (4.9992e+00), acc1  0.000 ( 0.000), acc5  0.781 ( 0.781)
[02/27 20:07:47] passl.engine.trainer INFO: Epoch [1/400][10/10009]	lr: 1.100e-06, eta: 69 days, 10:49:50, time: 1.499, data_time: 0.445, loss 5.4043e+00 (4.9075e+00), acc1  0.000 ( 0.000), acc5  0.000 ( 0.426)
^C
Traceback (most recent call last):
  File "PASSL/tools/train.py", line 52, in <module>
    main(args, cfg)
  File "PASSL/tools/train.py", line 46, in main
    trainer.train()
  File "/home/aistudio/PASSL/passl/engine/trainer.py", line 326, in train
    self.call_hook('train_iter_end')
  File "/home/aistudio/PASSL/passl/engine/trainer.py", line 282, in call_hook
    getattr(hook, fn_name)(self)
  File "/home/aistudio/PASSL/passl/hooks/log_hook.py", line 151, in train_iter_end
    trainer.logs[k].update(float(v))
  File "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/dygraph/math_op_patch.py", line 119, in _float_
    return float(var.numpy().flatten()[0])
KeyboardInterrupt
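The `lr` values in these logs can be reproduced from the `lr_scheduler` entry of the config: 'LinearWarmup' with `warmup_steps: 5` from `start_lr` 1e-06 to `end_lr` 0.0005, wrapping a 'CosineAnnealingDecay'. The sketch below assumes the schedule is evaluated at a fractional epoch (iteration / 10009), so the warmup is linear over 5 × 10009 iterations; under that assumption it matches the logged 1.000e-06, 1.100e-06, 1.199e-06, ... to the printed precision. The cosine branch is the textbook closed form with `T_max = 400` and `eta_min = 1e-05` and is not checked against any log here; `lr_at` is a hypothetical helper, not PASSL code.

```python
import math

# Values taken from the logged config.
ITERS_PER_EPOCH = 10009          # the "/10009" in the log
EPOCHS = 400                     # 'epochs'
WARMUP_EPOCHS = 5                # 'warmup_steps'
START_LR, BASE_LR = 1e-6, 5e-4   # 'start_lr'; 'end_lr' == cosine 'learning_rate'
ETA_MIN, T_MAX = 1e-5, 400       # CosineAnnealingDecay 'eta_min', 'T_max'


def lr_at(iteration):
    """Learning rate at a global iteration, assuming a fractional-epoch schedule."""
    epoch = iteration / ITERS_PER_EPOCH
    if epoch < WARMUP_EPOCHS:
        # Linear warmup from START_LR to BASE_LR.
        return START_LR + (BASE_LR - START_LR) * epoch / WARMUP_EPOCHS
    # Closed-form cosine annealing, counted from the end of warmup.
    t = epoch - WARMUP_EPOCHS
    return ETA_MIN + (BASE_LR - ETA_MIN) * (1 + math.cos(math.pi * t / T_MAX)) / 2


for it in (0, 10, 20, 60):
    print(it, f"{lr_at(it):.3e}")
# 0 1.000e-06
# 10 1.100e-06
# 20 1.199e-06
# 60 1.598e-06
```

Because warmup spans about 50,000 iterations, the learning rate grows by only about 1e-07 every 10 iterations, which is why the first log lines show such tiny values.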
# Regular training
!python PASSL/tools/train.py -c PASSL/configs/xcit/xcit_nano_12_p8_224.yaml
[02/27 19:54:25] passl INFO: Configs: {'epochs': 400, 'output_dir': 'outputs', 'seed': 0, 'device': 'gpu', 'model': {'name': 'SwinWrapper', 'architecture': {'name': 'XCiT', 'patch_size': 8, 'embed_dim': 128, 'depth': 12, 'num_heads': 4, 'eta': 1.0, 'tokens_norm': False}, 'head': {'name': 'SwinTransformerClsHead', 'in_channels': 128, 'num_classes': 1000}}, 'dataloader': {'train': {'loader': {'num_workers': 8, 'use_shared_memory': True}, 'sampler': {'batch_size': 128, 'shuffle': True, 'drop_last': True}, 'dataset': {'name': 'ImageNet', 'dataroot': 'data/ILSVRC2012/train/', 'return_label': True, 'transforms': [{'name': 'RandomResizedCrop', 'size': 224, 'scale': [0.08, 1.0], 'interpolation': 'bicubic'}, {'name': 'RandomHorizontalFlip'}, {'name': 'AutoAugment', 'config_str': 'rand-m9-mstd0.5-inc1', 'interpolation': 'bicubic', 'img_size': 224}, {'name': 'Transpose'}, {'name': 'Normalize', 'mean': [123.675, 116.28, 103.53], 'std': [58.395, 57.12, 57.375]}, {'name': 'RandomErasing', 'prob': 0.25, 'mode': 'pixel', 'max_count': 1}], 'batch_transforms': [{'name': 'Mixup', 'mixup_alpha': 0.8, 'prob': 1.0, 'switch_prob': 0.5, 'mode': 'batch', 'cutmix_alpha': 1.0}]}}, 'val': {'loader': {'num_workers': 8, 'use_shared_memory': True}, 'sampler': {'batch_size': 128, 'shuffle': False, 'drop_last': False}, 'dataset': {'name': 'ImageNet', 'dataroot': 'data/ILSVRC2012/val', 'return_label': True, 'transforms': [{'name': 'Resize', 'size': 224, 'interpolation': 'bicubic'}, {'name': 'CenterCrop', 'size': 224}, {'name': 'Transpose'}, {'name': 'Normalize', 'mean': [123.675, 116.28, 103.53], 'std': [58.395, 57.12, 57.375]}]}}}, 'lr_scheduler': {'name': 'LinearWarmup', 'learning_rate': {'name': 'CosineAnnealingDecay', 'learning_rate': 0.0005, 'T_max': 400, 'eta_min': 1e-05}, 'warmup_steps': 5, 'start_lr': 1e-06, 'end_lr': 0.0005}, 'optimizer': {'name': 'AdamW', 'beta1': 0.9, 'beta2': 0.999, 'weight_decay': 0.05, 'exclude_from_weight_decay': ['temperature', 'pos_embed', 'cls_token', 
'dist_token']}, 'log_config': {'name': 'LogHook', 'interval': 10}, 'checkpoint': {'name': 'CheckpointHook', 'by_epoch': True, 'interval': 1, 'max_keep_ckpts': 50}, 'custom_config': [{'name': 'EvaluateHook'}], 'is_train': True, 'timestamp': '-2022-02-27-19-54'}
[02/27 19:54:25] passl.engine.trainer INFO: train with paddle 2.2.2 on CUDAPlace(0) device
W0227 19:54:25.295194 21195 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1
W0227 19:54:25.300439 21195 device_context.cc:465] device: 0, cuDNN Version: 7.6.
[02/27 19:54:30] passl.engine.trainer INFO: Number of Parameters is 3.05M.
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/dygraph/math_op_patch.py:253: UserWarning: The dtype of left and right variables are not the same, left dtype is paddle.float64, but right dtype is paddle.float32, the right dtype will convert to paddle.float64
  format(lhs_dtype, rhs_dtype, lhs_dtype))
[02/27 19:54:40] passl.engine.trainer INFO: Epoch [1/400][0/10009]	lr: 1.000e-06, eta: 221 days, 16:42:20, time: 4.784, data_time: 3.609, loss 6.9207e+00 (6.9207e+00), acc1  0.000 ( 0.000), acc5  0.781 ( 0.781)
[02/27 19:54:48] passl.engine.trainer INFO: Epoch [1/400][10/10009]	lr: 1.100e-06, eta: 53 days, 11:22:07, time: 1.154, data_time: 0.444, loss 6.9192e+00 (6.9154e+00), acc1  0.000 ( 0.142), acc5  0.000 ( 0.781)
[02/27 19:54:55] passl.engine.trainer INFO: Epoch [1/400][20/10009]	lr: 1.199e-06, eta: 42 days, 11:55:29, time: 0.917, data_time: 0.300, loss 6.9059e+00 (6.9142e+00), acc1  0.000 ( 0.260), acc5  0.000 ( 0.744)
[02/27 19:55:01] passl.engine.trainer INFO: Epoch [1/400][30/10009]	lr: 1.299e-06, eta: 38 days, 14:04:44, time: 0.833, data_time: 0.250, loss 6.9211e+00 (6.9134e+00), acc1  0.000 ( 0.176), acc5  0.000 ( 0.605)
[02/27 19:55:08] passl.engine.trainer INFO: Epoch [1/400][40/10009]	lr: 1.399e-06, eta: 36 days, 18:12:11, time: 0.793, data_time: 0.224, loss 6.9336e+00 (6.9138e+00), acc1  0.000 ( 0.152), acc5  0.000 ( 0.591)
[02/27 19:55:14] passl.engine.trainer INFO: Epoch [1/400][50/10009]	lr: 1.499e-06, eta: 35 days, 10:56:52, time: 0.765, data_time: 0.208, loss 6.9081e+00 (6.9138e+00), acc1  0.781 ( 0.138), acc5  2.344 ( 0.551)
[02/27 19:55:21] passl.engine.trainer INFO: Epoch [1/400][60/10009]	lr: 1.598e-06, eta: 34 days, 10:30:28, time: 0.743, data_time: 0.197, loss 6.9240e+00 (6.9133e+00), acc1  0.000 ( 0.115), acc5  0.000 ( 0.564)
^C
Traceback (most recent call last):
  File "PASSL/tools/train.py", line 52, in <module>
    main(args, cfg)
  File "PASSL/tools/train.py", line 46, in main
    trainer.train()
  File "/home/aistudio/PASSL/passl/engine/trainer.py", line 325, in train
    mixup_fn=self.mixup_fn)
  File "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py", line 917, in __call__
    return self._dygraph_call_func(*inputs, **kwargs)
  File "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py", line 907, in _dygraph_call_func
    outputs = self.forward(*inputs, **kwargs)
  File "/home/aistudio/PASSL/passl/modeling/architectures/SwinWrapper.py", line 66, in forward
    return self.train_iter(*inputs, **kwargs)
  File "/home/aistudio/PASSL/passl/modeling/architectures/SwinWrapper.py", line 50, in train_iter
    x = self.backbone_forward(img)
  File "/home/aistudio/PASSL/passl/modeling/architectures/SwinWrapper.py", line 41, in backbone_forward
    x = self.backbone(x)
  File "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py", line 917, in __call__
    return self._dygraph_call_func(*inputs, **kwargs)
  File "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py", line 907, in _dygraph_call_func
    outputs = self.forward(*inputs, **kwargs)
  File "/home/aistudio/PASSL/passl/modeling/backbones/xcit.py", line 487, in forward
    x = self.forward_features(x)
  File "/home/aistudio/PASSL/passl/modeling/backbones/xcit.py", line 475, in forward_features
    x = blk(x, Hp, Wp)
  File "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py", line 917, in __call__
    return self._dygraph_call_func(*inputs, **kwargs)
  File "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py", line 907, in _dygraph_call_func
    outputs = self.forward(*inputs, **kwargs)
  File "/home/aistudio/PASSL/passl/modeling/backbones/xcit.py", line 346, in forward
    x = x + self.drop_path(self.gamma1 * self.attn(self.norm1(x)))
  File "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py", line 917, in __call__
    return self._dygraph_call_func(*inputs, **kwargs)
  File "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py", line 907, in _dygraph_call_func
    outputs = self.forward(*inputs, **kwargs)
  File "/home/aistudio/PASSL/passl/modeling/backbones/xcit.py", line 288, in forward
    q = nn.functional.normalize(q, axis=-1)
  File "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/nn/functional/norm.py", line 82, in normalize
    eps = fluid.dygraph.base.to_variable([epsilon], dtype=x.dtype)
  File "<decorator-gen-46>", line 2, in to_variable
  File "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/wrapped_decorator.py", line 25, in __impl__
    return wrapped_func(*args, **kwargs)
  File "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/framework.py", line 229, in __impl__
    return func(*args, **kwargs)
  File "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/dygraph/base.py", line 728, in to_variable
    name=name if name else '')
KeyboardInterrupt
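The `[i/10009]` counters and the `eta` field follow from simple arithmetic on the config. This sketch assumes the training split holds the standard 1,281,167 ImageNet-1k images (any count from 1,281,152 to 1,281,279 gives the same result). With `batch_size: 128` and `drop_last: True`, an epoch has ⌊1281167/128⌋ = 10009 iterations; the validation split (50000 images, `drop_last: False`) needs ⌈50000/128⌉ = 391 batches, matching the `391/391` progress bar in the evaluation log. `eta_days` is a hypothetical helper, not PASSL's ETA code.

```python
import math

TRAIN_IMAGES = 1281167   # assumption: standard ImageNet-1k training set size
VAL_IMAGES = 50000       # "Evaluate total samples 50000" in the evaluation log
BATCH_SIZE = 128         # sampler 'batch_size' in the config
EPOCHS = 400             # 'epochs' in the config

train_iters = TRAIN_IMAGES // BATCH_SIZE        # drop_last=True: partial batch dropped
val_iters = math.ceil(VAL_IMAGES / BATCH_SIZE)  # drop_last=False: partial batch kept


def eta_days(done_iters, sec_per_iter):
    """Remaining wall-clock time in days for the full run at a constant speed."""
    remaining = EPOCHS * train_iters - done_iters
    return remaining * sec_per_iter / 86400


print(train_iters, val_iters)         # 10009 391
print(round(eta_days(61, 0.743), 1))  # 34.4
```

With the 0.743 s/iteration reported at iteration 60 of the regular run (61 iterations done), this gives about 34.4 days for 400 epochs on one GPU, in line with the logged `eta: 34 days, 10:30:28`; the small gap comes from the rounded per-iteration time.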

Evaluation

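Run the command below to evaluate a trained model. Here the distilled weights are loaded with the regular XCiT config, whose network matches the student in the distillation config.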

!python PASSL/tools/train.py -c PASSL/configs/xcit/xcit_nano_12_p8_224.yaml \
                             --load xcit_nano_12_p8_224_dist.pdparams \
                             --evaluate-only
[03/01 19:47:20] passl INFO: Configs: {'epochs': 400, 'output_dir': 'outputs', 'seed': 0, 'device': 'gpu', 'model': {'name': 'SwinWrapper', 'architecture': {'name': 'XCiT', 'patch_size': 8, 'embed_dim': 128, 'depth': 12, 'num_heads': 4, 'eta': 1.0, 'tokens_norm': False}, 'head': {'name': 'SwinTransformerClsHead', 'in_channels': 128, 'num_classes': 1000}}, 'dataloader': {'train': {'loader': {'num_workers': 8, 'use_shared_memory': True}, 'sampler': {'batch_size': 128, 'shuffle': True, 'drop_last': True}, 'dataset': {'name': 'ImageNet', 'dataroot': 'data/ILSVRC2012/train/', 'return_label': True, 'transforms': [{'name': 'RandomResizedCrop', 'size': 224, 'scale': [0.08, 1.0], 'interpolation': 'bicubic'}, {'name': 'RandomHorizontalFlip'}, {'name': 'AutoAugment', 'config_str': 'rand-m9-mstd0.5-inc1', 'interpolation': 'bicubic', 'img_size': 224}, {'name': 'Transpose'}, {'name': 'Normalize', 'mean': [123.675, 116.28, 103.53], 'std': [58.395, 57.12, 57.375]}, {'name': 'RandomErasing', 'prob': 0.25, 'mode': 'pixel', 'max_count': 1}], 'batch_transforms': [{'name': 'Mixup', 'mixup_alpha': 0.8, 'prob': 1.0, 'switch_prob': 0.5, 'mode': 'batch', 'cutmix_alpha': 1.0}]}}, 'val': {'loader': {'num_workers': 8, 'use_shared_memory': True}, 'sampler': {'batch_size': 128, 'shuffle': False, 'drop_last': False}, 'dataset': {'name': 'ImageNet', 'dataroot': 'data/ILSVRC2012/val', 'return_label': True, 'transforms': [{'name': 'Resize', 'size': 224, 'interpolation': 'bicubic'}, {'name': 'CenterCrop', 'size': 224}, {'name': 'Transpose'}, {'name': 'Normalize', 'mean': [123.675, 116.28, 103.53], 'std': [58.395, 57.12, 57.375]}]}}}, 'lr_scheduler': {'name': 'LinearWarmup', 'learning_rate': {'name': 'CosineAnnealingDecay', 'learning_rate': 0.0005, 'T_max': 400, 'eta_min': 1e-05}, 'warmup_steps': 5, 'start_lr': 1e-06, 'end_lr': 0.0005}, 'optimizer': {'name': 'AdamW', 'beta1': 0.9, 'beta2': 0.999, 'weight_decay': 0.05, 'exclude_from_weight_decay': ['temperature', 'pos_embed', 'cls_token', 
'dist_token']}, 'log_config': {'name': 'LogHook', 'interval': 10}, 'checkpoint': {'name': 'CheckpointHook', 'by_epoch': True, 'interval': 1, 'max_keep_ckpts': 50}, 'custom_config': [{'name': 'EvaluateHook'}], 'is_train': False, 'timestamp': '-2022-03-01-19-47'}
[03/01 19:47:20] passl.engine.trainer INFO: train with paddle 2.2.2 on CUDAPlace(0) device
W0301 19:47:20.903514 23036 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1
W0301 19:47:20.908339 23036 device_context.cc:465] device: 0, cuDNN Version: 7.6.
[03/01 19:47:25] passl.engine.trainer INFO: Number of Parameters is 3.05M.
[03/01 19:47:30] passl.engine.trainer INFO: start evaluate on epoch 1 ..
[03/01 19:47:30] passl.engine.trainer INFO: Evaluate total samples 50000
100%|█████████████████████████████████████████| 391/391 [02:41<00:00,  2.43it/s]
[03/01 19:50:11] passl.engine.trainer INFO: Validate Epoch [1] acc1 (77.276), acc5 (93.248)
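The `acc1` and `acc5` values are top-1 and top-5 accuracy in percent. For reference, here is the standard definition as a small sketch: a sample counts as correct at k if its true label is among the k highest-scoring classes. This is not PASSL's implementation, which may batch the computation and break ties differently; `topk_accuracy` and the toy scores are hypothetical.

```python
def topk_accuracy(logits, labels, ks=(1, 5)):
    """Percentage of samples whose true label is among the k highest-scoring classes."""
    hits = {k: 0 for k in ks}
    for scores, label in zip(logits, labels):
        # Class indices sorted by descending score (stable sort for ties).
        ranked = sorted(range(len(scores)), key=lambda c: scores[c], reverse=True)
        for k in ks:
            if label in ranked[:k]:
                hits[k] += 1
    return {k: 100.0 * hits[k] / len(labels) for k in ks}


# Toy example: 3 samples, 6 classes.
logits = [
    [0.1, 0.9, 0.0, 0.0, 0.0, 0.0],      # label 1 ranked first: top-1 and top-5 hit
    [0.5, 0.2, 0.1, 0.05, 0.11, 0.04],   # label 4 ranked third: top-5 hit only
    [0.3, 0.2, 0.1, 0.15, 0.12, 0.13],   # label 2 ranked last: miss
]
labels = [1, 4, 2]
acc = topk_accuracy(logits, labels)
print({k: round(v, 2) for k, v in acc.items()})  # {1: 33.33, 5: 66.67}
```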