Paper Reproduction Challenge #6: Championship Solution for Reproducing SimSiam
SimSiam is a landmark paper by Kaiming He published at CVPR 2021. It shows that in contrastive learning, collapsing solutions do exist as far as the loss and network structure are concerned, but the "stop-gradient" operation plays a critical role in avoiding them.
A PaddlePaddle implementation of SimSiam (Exploring Simple Siamese Representation Learning).
Note: the root directory of this project is /home/aistudio/work/PASSL.
1. Paper Overview
This project reproduces the SimSiam model with the Paddle framework. Among the many recent models for unsupervised visual representation learning, Siamese networks have become a common structure; existing models avoid the collapse problem by maximizing the similarity between two augmentations of the same image. The authors find that a simple Siamese network can learn useful representations without using (1) negative sample pairs, (2) large batches, or (3) momentum encoders. Their experiments show that, for the loss and the network structure, collapsing solutions do exist, but a "stop-gradient" operation is crucial for avoiding them. The authors propose this novel stop-gradient idea and validate it experimentally; the resulting SimSiam model achieves competitive results on ImageNet and on downstream tasks.
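At the heart of the method is a symmetrized negative cosine similarity loss in which the target branch is detached from the computation graph. The sketch below mirrors the pseudocode in the paper, written with Paddle; the function names are illustrative:

import paddle.nn.functional as F

def negative_cosine(p, z):
    # stop-gradient: the target z is treated as a constant
    z = z.detach()
    return -F.cosine_similarity(p, z, axis=1).mean()

def simsiam_loss(p1, p2, z1, z2):
    # symmetrized loss: L = D(p1, z2)/2 + D(p2, z1)/2
    return 0.5 * negative_cosine(p1, z2) + 0.5 * negative_cosine(p2, z1)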
Paper:
- [1] X. Chen and K. He, “Exploring Simple Siamese Representation Learning”, CVPR, 2021.
Reference project:
- SimSiam (official implementation)
2. Code Reproduction
2.1 Dataset
This project uses the ImageNet2012 dataset: 1,000 classes, roughly 1.2 million training images, and 50,000 validation images.
# Unpack the dataset archives
tar xf /home/aistudio/.jupyter/lab/workspaces/data/data114241/Light_ILSVRC2012_part_0.tar
tar xf /home/aistudio/.jupyter/lab/workspaces/data/data114746/Light_ILSVRC2012_part_1.tar
2.2 Environment Dependencies
- Hardware: CPU, GPU
- Software:
  - Python 3.7
  - PaddlePaddle-GPU == 2.3.0
cd /home/aistudio/work/PASSL/
pip install -r requirements.txt
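After installation, a quick sanity check (optional) verifies the installed Paddle version and that the GPU build works:

import paddle

print(paddle.__version__)                     # expect 2.3.0
print(paddle.device.is_compiled_with_cuda())  # True for the GPU build
paddle.utils.run_check()                      # runs a small built-in self-test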
2.3 Building the Model
import paddle
import paddle.nn as nn

from ...modules.init import init_backbone_weight
from ...modules import freeze_batchnorm_statictis
from .builder import MODELS
from ..backbones import build_backbone
from ..necks import build_neck
from ..heads import build_head


@MODELS.register()
class SimSiam(nn.Layer):
    """
    Build a SimSiam model.
    https://arxiv.org/abs/2011.10566
    """

    def __init__(self,
                 backbone,
                 head=None,
                 predictor=None,
                 dim=2048,
                 use_synch_bn=True):
        """
        Args:
            backbone (dict): config of backbone.
            head (dict): config of head.
            predictor (dict): config of predictor.
            use_synch_bn (bool): whether to apply sync BN.
        """
        super(SimSiam, self).__init__()

        # Create the encoder: backbone plus a 3-layer projection MLP.
        self.encoder = build_backbone(backbone)

        # Build the 3-layer projector on top of the backbone's fc layer.
        prev_dim = self.encoder.fc.weight.shape[0]
        self.encoder.fc = nn.Sequential(
            nn.Linear(prev_dim, prev_dim, bias_attr=False),
            nn.BatchNorm1D(prev_dim),
            nn.ReLU(),
            nn.Linear(prev_dim, prev_dim, bias_attr=False),
            nn.BatchNorm1D(prev_dim),
            nn.ReLU(),
            self.encoder.fc,
            nn.BatchNorm1D(dim, weight_attr=False, bias_attr=False))
        # Freeze the bias of the last Linear layer: it is immediately
        # followed by a BatchNorm layer, which makes the bias redundant.
        self.encoder.fc[6].bias.stop_gradient = True

        self.predictor = build_neck(predictor)
        self.head = build_head(head)

        # Convert BatchNorm*D to SyncBatchNorm*D for multi-GPU training.
        if use_synch_bn:
            self.encoder = nn.SyncBatchNorm.convert_sync_batchnorm(
                self.encoder)
            self.predictor = nn.SyncBatchNorm.convert_sync_batchnorm(
                self.predictor)

    def train_iter(self, *inputs, **kwargs):
        x1, x2 = inputs

        # Compute projections and predictions for both views.
        z1 = self.encoder(x1)    # NxC
        z2 = self.encoder(x2)    # NxC
        p1 = self.predictor(z1)  # NxC
        p2 = self.predictor(z2)  # NxC

        # detach() implements the stop-gradient on the target branch.
        outputs = self.head(p1, p2, z1.detach(), z2.detach())
        return outputs

    def forward(self, *inputs, mode='train', **kwargs):
        if mode == 'train':
            return self.train_iter(*inputs, **kwargs)
        else:
            raise Exception("No such mode: {}".format(mode))
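The predictor created by build_neck(predictor) is, per the paper, a two-layer bottleneck MLP (2048 → 512 → 2048) with BN and ReLU on the hidden layer and a plain Linear output. A minimal sketch of such a neck follows; the class name is illustrative, and PASSL's registered neck may differ in detail:

import paddle.nn as nn

class PredictionMLP(nn.Layer):
    """Two-layer bottleneck predictor h from the SimSiam paper."""

    def __init__(self, in_dim=2048, hidden_dim=512, out_dim=2048):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, hidden_dim, bias_attr=False),
            nn.BatchNorm1D(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim))  # no BN or ReLU on the output

    def forward(self, x):
        return self.mlp(x)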
2.4 Training
python tools/train.py -c configs/simsiam/simsiam_r50.yaml
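For multi-GPU training, the script can also be driven by Paddle's distributed launcher; the command below is an illustrative single-node 8-GPU launch (GPU ids depend on your machine):

python -m paddle.distributed.launch --gpus="0,1,2,3,4,5,6,7" tools/train.py -c configs/simsiam/simsiam_r50.yaml

The entry script builds a Trainer from the YAML config and calls its train() method; the core of that class is shown below.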
import logging
import math
import random
from collections import OrderedDict

import numpy as np
import paddle
import paddle.distributed as dist
from paddle.distributed import fleet
from tqdm import tqdm

# build_model, build_dataloader, build_lr_scheduler, build_lr_scheduler_simclr,
# build_optimizer, build_hook, Hook, IterLoader, AverageMeter and
# set_hyrbid_parallel_seed are PASSL utilities; their imports are omitted here.


class Trainer:
    r"""
    # trainer calling logic:
    #
    # build_model                                       ||  model(BaseModel)
    #     |                                             ||
    # build_dataloader                                  ||  dataloader
    #     |                                             ||
    # build_lr_scheduler                                ||  lr_scheduler
    #     |                                             ||
    # build_optimizer                                   ||  optimizers
    #     |                                             ||
    # build_train_hooks                                 ||  train hooks
    #     |                                             ||
    # build_custom_hooks                                ||  custom hooks
    #     |                                             ||
    # train loop                                        ||  train loop
    #     |                                             ||
    # hook(print log, checkpoint, evaluate, adjust lr)  ||  call hook
    #     |                                             ||
    # end                                               \/
    """
    def __init__(self, cfg):
        # base config
        self.logger = logging.getLogger(__name__)
        self.cfg = cfg
        self.output_dir = cfg.output_dir

        dp_rank = dist.get_rank()
        self.log_interval = cfg.log_config.interval

        # set seed
        seed = cfg.get('seed', False)
        if seed:
            seed += dp_rank
            paddle.seed(seed)
            np.random.seed(seed)
            random.seed(seed)

        # set device
        assert cfg['device'] in ['cpu', 'gpu', 'xpu', 'npu']
        self.device = paddle.set_device(cfg['device'])
        self.logger.info('train with paddle {} on {} device'.format(
            paddle.__version__, self.device))

        self.start_epoch = 0
        self.current_epoch = 0
        self.current_iter = 0
        self.inner_iter = 0
        self.batch_id = 0
        self.global_steps = 0

        use_byol_iters = cfg.get('use_byol_iters', False)
        self.use_byol_iters = use_byol_iters
        use_simclr_iters = cfg.get('use_simclr_iters', False)
        self.use_simclr_iters = use_simclr_iters

        self.epochs = cfg.get('epochs', None)
        self.timestamp = cfg.timestamp
        self.logs = OrderedDict()
        # Ensure that the vdl log file can be closed normally

        # build model
        self.model = build_model(cfg.model)
        n_parameters = sum(p.numel() for p in self.model.parameters()
                           if not p.stop_gradient).item()
        i = int(math.log(n_parameters, 10) // 3)
        size_unit = ['', 'K', 'M', 'B', 'T', 'Q']
        self.logger.info("Number of Parameters is {:.2f}{}.".format(
            n_parameters / math.pow(1000, i), size_unit[i]))

        # build train dataloader
        self.train_dataloader, self.mixup_fn = build_dataloader(
            cfg.dataloader.train, self.device)
        self.iters_per_epoch = len(self.train_dataloader)

        # BYOL-style schedulers are stepped by total iterations
        if self.use_byol_iters:
            self.global_batch_size = cfg.global_batch_size
            self.byol_total_iters = self.epochs * cfg.total_images // self.global_batch_size

        if self.use_byol_iters:
            self.lr_scheduler = build_lr_scheduler(cfg.lr_scheduler,
                                                   self.byol_total_iters)
        elif self.use_simclr_iters:
            self.batch_size = cfg.dataloader.train.sampler.batch_size
            self.global_batch_size = cfg.global_batch_size
            self.epochs = cfg.epochs
            self.lr_scheduler = build_lr_scheduler_simclr(
                cfg.lr_scheduler, self.iters_per_epoch, self.batch_size * 8,
                cfg.epochs, self.current_iter)
        else:
            self.lr_scheduler = build_lr_scheduler(cfg.lr_scheduler,
                                                   self.iters_per_epoch)

        self.optimizer = build_optimizer(cfg.optimizer, self.lr_scheduler,
                                         [self.model])

        # distributed settings
        if dist.get_world_size() > 1:
            strategy = fleet.DistributedStrategy()
            # Hybrid Parallel Training
            strategy.hybrid_configs = cfg.pop(
                'hybrid') if 'hybrid' in cfg else {}
            fleet.init(is_collective=True, strategy=strategy)
            hcg = fleet.get_hybrid_communicate_group()
            mp_rank = hcg.get_model_parallel_rank()
            pp_rank = hcg.get_stage_id()
            dp_rank = hcg.get_data_parallel_rank()
            set_hyrbid_parallel_seed(
                seed, 0, mp_rank, pp_rank, device=self.device)

        # amp training
        self.use_amp = cfg.get('use_amp', False)
        if self.use_amp:
            amp_cfg = cfg.pop('AMP')
            self.auto_cast = amp_cfg.pop('auto_cast')
            scale_loss = amp_cfg.pop('scale_loss')
            self.scaler = paddle.amp.GradScaler(init_loss_scaling=scale_loss)
            amp_cfg['models'] = self.model
            self.model = paddle.amp.decorate(**amp_cfg)  # decorate for level O2

        # ZeRO-style sharding
        self.sharding_strategies = cfg.get('sharding', False)
        if self.sharding_strategies:
            from paddle.distributed.fleet.meta_parallel.sharding.sharding_utils import ShardingScaler
            from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage2 import ShardingStage2
            from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.sharding_optimizer_stage2 import ShardingOptimizerStage2

            self.sharding_stage = self.sharding_strategies['sharding_stage']
            accumulate_grad = self.sharding_strategies['accumulate_grad']
            offload = self.sharding_strategies['offload']

            if self.sharding_stage == 2:
                self.optimizer = ShardingOptimizerStage2(
                    params=self.model.parameters(),
                    optim=self.optimizer,
                    offload=offload)
                self.model = ShardingStage2(
                    self.model,
                    self.optimizer,
                    accumulate_grads=accumulate_grad)
                self.scaler = ShardingScaler(self.scaler)
            else:
                raise NotImplementedError()
        # data parallel
        elif dist.get_world_size() > 1:
            self.model = fleet.distributed_model(self.model)

        # build hooks
        self.hooks = []
        self.add_train_hooks()
        self.add_custom_hooks()
        self.hooks = sorted(self.hooks, key=lambda x: x.priority)

        if self.epochs:
            self.total_iters = self.epochs * self.iters_per_epoch
            self.by_epoch = True
        else:
            self.by_epoch = False
            self.total_iters = cfg.total_iters
    def add_train_hooks(self):
        optim_cfg = self.cfg.get('optimizer_config', None)
        if optim_cfg is not None:
            self.add_hook(build_hook(optim_cfg))
        else:
            self.add_hook(build_hook({'name': 'OptimizerHook'}))

        timer_cfg = self.cfg.get('timer_config', None)
        if timer_cfg is not None:
            self.add_hook(build_hook(timer_cfg))
        else:
            self.add_hook(build_hook({'name': 'IterTimerHook'}))

        ckpt_cfg = self.cfg.get('checkpoint', None)
        if ckpt_cfg is not None:
            self.add_hook(build_hook(ckpt_cfg))
        else:
            self.add_hook(build_hook({'name': 'CheckpointHook'}))

        log_cfg = self.cfg.get('log_config', None)
        if log_cfg is not None:
            self.add_hook(build_hook(log_cfg))
        else:
            self.add_hook(build_hook({'name': 'LogHook'}))

        lr_cfg = self.cfg.get('lr_config', None)
        if lr_cfg is not None:
            self.add_hook(build_hook(lr_cfg))
        else:
            self.add_hook(build_hook({'name': 'LRSchedulerHook'}))

    def add_custom_hooks(self):
        custom_cfgs = self.cfg.get('custom_config', None)
        if custom_cfgs is None:
            return
        for custom_cfg in custom_cfgs:
            cfg_ = custom_cfg.copy()
            insert_index = cfg_.pop('insert_index', None)
            self.add_hook(build_hook(cfg_), insert_index)

    def add_hook(self, hook, insert_index=None):
        assert isinstance(hook, Hook)
        if insert_index is None:
            self.hooks.append(hook)
        elif isinstance(insert_index, int):
            self.hooks.insert(insert_index, hook)

    def call_hook(self, fn_name):
        for hook in self.hooks:
            getattr(hook, fn_name)(self)
    def train(self):
        self.mode = 'train'
        self.model.train()
        iter_loader = IterLoader(self.train_dataloader, self.current_epoch)
        self.call_hook('run_begin')

        while self.current_iter < self.total_iters:
            if self.current_iter % self.iters_per_epoch == 0:
                self.call_hook('train_epoch_begin')
            self.inner_iter = self.current_iter % self.iters_per_epoch
            self.current_iter += 1
            self.current_epoch = iter_loader.epoch

            data = next(iter_loader)
            self.call_hook('train_iter_begin')

            if self.use_amp:
                with paddle.amp.auto_cast(**self.auto_cast):
                    if self.use_byol_iters:
                        self.outputs = self.model(
                            *data,
                            total_iters=self.byol_total_iters,
                            current_iter=self.current_iter,
                            mixup_fn=self.mixup_fn)
                    else:
                        self.outputs = self.model(
                            *data,
                            total_iters=self.total_iters,
                            current_iter=self.current_iter,
                            mixup_fn=self.mixup_fn)
            else:
                if self.use_byol_iters:
                    self.outputs = self.model(
                        *data,
                        total_iters=self.byol_total_iters,
                        current_iter=self.current_iter,
                        mixup_fn=self.mixup_fn)
                else:
                    self.outputs = self.model(
                        *data,
                        total_iters=self.total_iters,
                        current_iter=self.current_iter,
                        mixup_fn=self.mixup_fn)
            self.call_hook('train_iter_end')

            if self.current_iter % self.iters_per_epoch == 0:
                self.call_hook('train_epoch_end')
                self.current_epoch += 1

        self.call_hook('run_end')
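Putting the pieces together, tools/train.py roughly does the following; get_config here is a hypothetical stand-in for PASSL's actual config-loading helper:

# Simplified sketch of the training entry point.
# `get_config` is a hypothetical helper; PASSL's real config loader may differ.
cfg = get_config('configs/simsiam/simsiam_r50.yaml')

trainer = Trainer(cfg)
trainer.train()  # runs the hook-driven loop shown above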
2.5 Evaluation
python tools/train.py -c configs/simsiam/simsiam_clas_r50.yaml --load ${CLS_WEIGHT_FILE} --evaluate-only
    def val(self, **kwargs):
        if not hasattr(self, 'val_dataloader'):
            self.val_dataloader, mixup_fn = build_dataloader(
                self.cfg.dataloader.val, self.device)

        self.logger.info('start evaluate on epoch {} ..'.format(
            self.current_epoch + 1))
        rank = dist.get_rank()
        world_size = dist.get_world_size()

        model = self.model
        total_samples = len(self.val_dataloader.dataset)
        self.logger.info('Evaluate total samples {}'.format(total_samples))

        # only rank 0 shows a progress bar
        if rank == 0:
            dataloader = tqdm(self.val_dataloader)
        else:
            dataloader = self.val_dataloader

        accum_samples = 0
        self.model.eval()
        outs = OrderedDict()
        for data in dataloader:
            if isinstance(data, paddle.Tensor):
                batch_size = data.shape[0]
            elif isinstance(data, (list, tuple)):
                batch_size = data[0].shape[0]
            else:
                raise TypeError('unknown type of data')
            labels = data[-1]

            if self.use_amp:
                with paddle.amp.auto_cast(**self.auto_cast):
                    pred = model(*data, mode='test')
            else:
                pred = model(*data, mode='test')

            current_samples = batch_size * world_size
            accum_samples += current_samples

            if world_size > 1:
                # gather predictions and labels from all ranks
                pred_list = []
                dist.all_gather(pred_list, pred)
                pred = paddle.concat(pred_list, 0)
                label_list = []
                dist.all_gather(label_list, labels)
                labels = paddle.concat(label_list, 0)
                if accum_samples > total_samples:
                    # drop samples duplicated by distributed batch padding
                    self.logger.info('total samples {} {} {}'.format(
                        total_samples, accum_samples,
                        total_samples + current_samples - accum_samples))
                    pred = pred[:total_samples + current_samples -
                                accum_samples]
                    labels = labels[:total_samples + current_samples -
                                    accum_samples]
                    current_samples = total_samples + current_samples - accum_samples

            res = self.val_dataloader.dataset.evaluate(pred, labels, **kwargs)
            for k, v in res.items():
                if k not in outs:
                    outs[k] = AverageMeter(k, ':6.3f')
                outs[k].update(float(v), current_samples)

        log_str = f'Validate Epoch [{self.current_epoch + 1}] '
        log_items = []
        for name, val in outs.items():
            if isinstance(val, AverageMeter):
                string = '{} ({' + val.fmt + '})'
                val = string.format(val.name, val.avg)
            log_items.append(val)
        log_str += ', '.join(log_items)
        self.logger.info(log_str)
        self.model.train()
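The self.val_dataloader.dataset.evaluate(pred, labels) call above returns a dict of metrics; for classification these are typically top-1/top-5 accuracy. A minimal sketch of such a metric function using paddle.metric.accuracy (illustrative; the actual PASSL implementation may differ):

import paddle

def classification_metrics(pred, labels):
    """pred: [N, num_classes] logits; labels: [N] ground-truth class ids."""
    labels = labels.reshape([-1, 1])
    acc1 = paddle.metric.accuracy(pred, labels, k=1)  # top-1 accuracy in [0, 1]
    acc5 = paddle.metric.accuracy(pred, labels, k=5)  # top-5 accuracy in [0, 1]
    return {'acc1': float(acc1) * 100, 'acc5': float(acc5) * 100}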
2.6 Reproduction Results
The reproduced model is evaluated on the downstream image-classification task on ImageNet2012; the resulting accuracy (%) is:

Model | Official | PASSL
---|---|---
SimSiam | 68.3 | 68.4
3. References
- https://github.com/facebookresearch/simsiam
4. Summary
- The PaddlePaddle team is excellent! If you hit a problem, open an issue promptly or ask the Paddle engineers in the dedicated group chat; this is far more efficient than struggling on your own.