论文复现赛第六期】nnU-Net: Self-adapting Framework for U-Net-Based Medical Image Segmentation

paper:nnU-Net: Self-adapting Framework for U-Net-Based Medical Image Segmentation
github:https://github.com/MIC-DKFZ/nnUNet
复现repo:https://github.com/justld/nnunet_paddle

nnUNet(no-new-UNet)没有提出新的技巧,而是通过数据处理、自适应训练配置生成、自适应网络结构、多折训练、模型ensemble等各种技巧,将医疗图像分割任务进行了统一,下面一起来看看nnUNet的细节吧。

PS:现在的nnUNet经过多次更新,可能paper中部分内容与源码不一致,一切以官方源码为复现基准。此外,由于代码量很大,将陆续PR到MedicalSeg(目前已包含数据预处理功能)。

一、医学图像分割十项全能比赛(MSD)

Medical Segmentation Decathlon(MSD)挑战赛的参赛者要求参赛者使用一个语义分割算法完成人体10个部位的器官分割,要求不允许手动调整参数(可以自动调整)。
具体比赛内容可以参考:医学图像分割十项全能比赛(MSD)回顾,由于内容过多,在此不再过多介绍。

本次复现赛使用的数据集是肺部数据集,共有64个训练图像,32个测试图像,训练集示例如下:
在这里插入图片描述

上图为ITK-SNAP可视化结果。

二、网络结构

nnUNet作者认为,图像预处理、网络拓扑和后处理比网络结构更重要,nnUNet使用了初始的UNet结构,并做了几点修改:
(1)使用leakyrelu取代relu;
(2)用instance normalization替代batch normalization;

作者设计了3中结构:**2D-UNet,3D-UNet,UNet-cascade **

2D-UNet:虽然大部分医疗图像都是3d图像,但是对于某些数据集来说,使用3D卷积的效果并不比2D卷积好(比如MSD比赛的Prostate数据集);

3D-UNet:对于3D图像来说,3D-UNet是个很合适的方法,但是如果图像尺寸很大,那么由于GPU显存限制,网络可能学习不到充足的语义信息,对于大的图片,使用patch训练;(可能这里容易引起困惑,稍微解释一下,仅代表个人观点:nnUNet会设定一个显存限制,然后再该限制下计算网络结构,且要求最小的batch size大于等于2,那么当输入图片很大的时候,显存就不够用了。于是使用patch训练,但是patch只包含图片的部分信息,无法得到足够的上下文信息)

UNet-cascade:为了解决3D-UNet的缺点,作者提出了级联的方法,训练分为2个阶段,第一阶段对输入图片下采样,使用下采样的图片得到粗糙的分割结果(下采样来降低显存要求);第二阶段,将第一阶段的分割结果上采样,然后与原图concat,使用patch训练。网络结构如下:
PS:在阅读源码的时候,发现nnUNet的stage就是3D-UNet,在nnUNet常见问题最后一个问题“Why is no 3d_lowres model created?”有相关说明,说明中提到可以设置HOW_MUCH_OF_A_PATIENT_MUST_THE_NETWORK_SEE_AT_STAGE0 来构建低分辨率模型,然而在配置文件生成时并未用到该参数(官方repo教程指令并未使用),但是官方repo给出的配置生成替代策略有使用到该参数,此处遵循官方repo的教程,未手动指定该参数,即UNet-cascade 第一阶段的模型是3D-UNet。
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-WUe2kX0e-1656725000069)(https://ai-studio-static-online.cdn.bcebos.com/fab60e4211734c28a5da90910e404f22919270208cc54be6ac81d21c6449e5da)]

由于不同数据集的图像尺寸不同(Liver的图片中值尺寸是482x512x512,Hippocampus的中值尺寸是 36 × 50 × 35),故对于不同的数据集,patch和网络结构也会不同。网络的配置规则如下(这里可能和paper中说的不太一致,以下叙述结合官方repo代码,并非完全按照paper叙述,此外nnUNet的配置生成目前有2个版本,本文以experiment_planning目录下experiment_planner_xx_v21.py为准):

1、对于2D-UNet,遍历数据集,得到所有图片的shape,取shape中值作为patch(为了保证patch size是下采样次数的倍数,会对patch进行修正),网络配置保证最小的特征图高和宽小于8(但是下采样次数不超过6次),网络起始卷积层通道数为32(论文中是30,参考源码32),每次下采样通道数翻倍。

2、对于3D-UNet,由于其占用显存较多,故限制其输入在128x128x128以内,如果数据集的median shape小于该限制,则适当增加batch size,网络下采样次数最大值是5次,保证最小特征图的尺寸小于等于8。

此外,对于batch size做出了额外限制:batch size不大于数据集数据数量的5%不小于2,最终配置如下图(在MSD挑战赛7个数据集上的配置):

在这里插入图片描述

三、预处理

nnUNet预处理也是自动配置策略,具体内容如下。

1、Cropping

将数据集为0的区域裁剪丢弃,为0的区域对大部分数据集来说是无用的,然而裁剪后可以有效地降低计算成本。源码传送

2、Resampling

3D医疗图像和普通的2D图像不同,不同图片的像素间距(spacing)由于采集数据时的设置不同可能不同,对所有数据重采样,使voxel spacing为数据集的median voxel spacing。相关概念可参考这里,此处不多赘述。源码参考

3、Normalization

对于CT数据,统计整个数据集的像素均值和方差,对像素强度分布在[0.5,99.5]的像素进行归一化(这里可能容易引起误解,建议看一下源码),非CT数据,单个数据使用各自的均值标准差归一化,源码链接同CT。

四、训练

1、损失函数

损失函数采用Dice_loss + cross entropy loss(源码链接),源码中使用了多个辅助损失函数,例如输入是512x512,经过6次下采样,得到的特征图有256x256,128x128,64x64,32x32,16x16,8x8,那么对标签进行平均降低分辨率(注意:非最近邻插值,参考源码)。

2、数据增强

论文中提到使用的数据增强有:随机旋转、随机缩放、随机弹性变换、伽马校正、镜像,但是repo中还使用了别的(亮度、噪音、模糊等等),具体参考源码

3、Patch sampling

因为医疗图像中前景非常的小,使用patch训练很可能出现数据中不包含前景的情况,为了训练的稳定,强制一个batch数据中1/3的数据包含至少一个前景。源码链接

五、后处理、推理、ensemble、验证

1、后处理

大部分数据中都只有一个前景,故添加了可选的后处理:保留最大连通区域。这里可能会引入一个问题,保留最大连通区域一定会提高dice coefficient吗?其实后处理是可选的,可以在验证的时候查看使用后处理和不适用后处理的dice coef,然后决定是否使用后处理。

2、推理

在推理阶段,使用patch推理策略,patch移动步长是0.5 * patch_size,并且给予patch中心和边缘不同的权重,此外使用了测试时数据增强策略镜像。

3、ensemble

经过训练,得到了2D-UNet、3D-UNet、UNet-cascade 共3个模型,使用3个模型分别预测验证集(验证集的概念见下一小节),然后将预测结果分别两两融合,并查看dice coef,最后在多个模型及融合的模型中选择dice coef最好的模型提交。

4、验证

数据集中并未给出验证集,作者使用5折交叉验证的方法,训练5个模型,每个模型有对应的验证集,已2D-UNet为例,5个模型分别预测验证集数据,那么验证集的总数就等于训练集,在整个数据集上验证dice,并作为2D-UNet的模型性能。在预测阶段,将5个模型融合预测。

验证集划分代码链接
多折模型融合预测代码链接

PS:简单来说,一共需要训练3 * 5个模型(跑这个项目先看看算力卡够不够哦)。

六、实验结果

nnUNet在7个数据集上的性能如下,加粗表示该模型为提交的模型。(作者仅提交一次就能拿下第一,强无敌)
在这里插入图片描述

七、快速体验

经过一番介绍,来到了快速体验环节,需要声明几点:
1、自己测试了以下nnUNet官方repo在msd lung数据集上训练速度(付费版本BML平台),一个epoch约8~9分钟,一共1000个epoch,虽然有early stop策略,但是训练的时间成本依然无法接受,故降低了训练epoch;
2、本次复现不支持多卡训练,官方repo也不推荐多卡训练(参考README多卡训练部分);
3、由于aistudio平台100GB内存限制,无法同时训练2D-UNet和3D-UNet,每次训练后建议删除另一个模型的数据;
4、日志中val的精度请忽略,这是参考了官方repo的方法,使用了一些经过数据增强的假数据来验证,具体可以参考官方repo(训练代码);
5、其实官方设置的epoch也不准确,可以参考这里,训练集和验证集的训练批次都是自己设置的;

PS:看了论文,感觉nnUNet并没有什么太难的地方,都是一些容易理解的概念,但是实际复现起来完全不同,将一些简单的理论结合到一起,并且实现,是一个很复杂的过程,在整个复现过程中,无数次被官方repo绕晕(官方repo的代码真的有点乱),初次以外,nnUNet的代码量极大,此次复现大部分代码来自官方repo,且由于个人水平有限,可能存在错误,欢迎各位大佬批评指正。

本次复现的精度如下,权重请参考本文开头复现链接。

NetWorkpost_processingfoldsstepsoptimage_sizebatch_sizedatasetmemorycardavg diceconfigweightlog
nnUNet_2dFalse525k/30kAdam--MSD LUNG32G152.397%---
nnUNet_2dTrue525k/30kAdam--MSD LUNG32G153.549%---
nnUNet_cascade_stage1False525kAdam--MSD LUNG32G167.676%---
nnUNet_cascade_stage1True525kAdam--MSD LUNG32G168.281%---
nnUNet_cascade_stage2False520kAdam--MSD LUNG32G159.894%---
nnUNet_cascade_stage2True520kAdam--MSD LUNG32G167.996%---
ensemble_2d_3d_cascadeFalse5----MSD LUNG32G162.635%---
ensemble_2d_3d_cascadeTrue5----MSD LUNG32G164.355%---
# step 1: git clone 
%cd ~/
!git clone https://gitee.com/dudulang001/nnunet_paddle.git
# step 2: pip install requirements
%cd ~/nnunet_paddle/
!pip install -r requirements.txt
# step 3: 解压数据集
%cd ~/
!tar -xf ~/data/data125872/Task06_Lung.tar -C ~/data

1、2D UNet

第一次运行2D UNet训练,会有一个数据预处理的过程,包含crop、resample、normalization等过程,这个时间非常的长(Tesla V100约1-2H)。

此外,与官方repo一致,使用混合精度训练。

# 2d unet fold 0  train
%cd ~/nnunet_paddle/
!python train.py --config configs/msd/msd_lung_2d_fold_0.yml --use_vdl --do_eval --log_iters 10 --save_interval 1000 --seed 10000 --precision fp16 --save_dir output/nnunet_2d/fold_0

# 2d unet fold 0  val
%cd ~/nnunet_paddle/
!python nnunet_tools/nnunet_fold_val.py --config configs/msd/msd_lung_2d_fold_0.yml \
    --model_path output/nnunet_2d/fold_0/iter25000/model.pdparams --precision fp16 --save_dir ~/val_2d --val_save_folder ~/val_2d
# 2d unet fold 1  train
%cd ~/nnunet_paddle/
!python train.py --config configs/msd/msd_lung_2d_fold_1.yml --use_vdl --do_eval --log_iters 10 --save_interval 1000 --seed 10000 --precision fp16 --save_dir output/nnunet_2d/fold_1

# 2d unet fold 1  val
%cd ~/nnunet_paddle/
!python nnunet_tools/nnunet_fold_val.py --config configs/msd/msd_lung_2d_fold_1.yml \
    --model_path output/nnunet_2d/fold_1/iter25000/model.pdparams --precision fp16 --save_dir ~/val_2d --val_save_folder ~/val_2d
# 2d unet fold 2  train
%cd ~/nnunet_paddle/
!python train.py --config configs/msd/msd_lung_2d_fold_2.yml --use_vdl --do_eval --log_iters 10 --save_interval 1000 --seed 10000 --precision fp16 --save_dir output/nnunet_2d/fold_2

# 2d unet fold 2  val
%cd ~/nnunet_paddle/
!python nnunet_tools/nnunet_fold_val.py --config configs/msd/msd_lung_2d_fold_2.yml \
    --model_path output/nnunet_2d/fold_2/iter25000/model.pdparams --precision fp16 --save_dir ~/val_2d --val_save_folder ~/val_2d
# 2d unet fold 3  train
%cd ~/nnunet_paddle/
!python train.py --config configs/msd/msd_lung_2d_fold_3.yml --use_vdl --do_eval --log_iters 10 --save_interval 1000 --seed 10000 --precision fp16 --save_dir output/nnunet_2d/fold_3

# 2d unet fold 3  val
%cd ~/nnunet_paddle/
!python nnunet_tools/nnunet_fold_val.py --config configs/msd/msd_lung_2d_fold_3.yml \
    --model_path output/nnunet_2d/fold_3/iter25000/model.pdparams --precision fp16 --save_dir ~/val_2d --val_save_folder ~/val_2d
# 2d unet fold 4 train
%cd ~/nnunet_paddle/
!python train.py --config configs/msd/msd_lung_2d_fold_4.yml --use_vdl --do_eval --log_iters 10 --save_interval 1000 --seed 10000 --precision fp16 --save_dir output/nnunet_2d/fold_4

# 2d unet fold 4  val
%cd ~/nnunet_paddle/
!python nnunet_tools/nnunet_fold_val.py --config configs/msd/msd_lung_2d_fold_4.yml \
    --model_path output/nnunet_2d/fold_4/iter25000/model.pdparams --precision fp16 --save_dir ~/val_2d --val_save_folder ~/val_2d
# 5折总验证,请依次运行以上验证步骤后,在运行这一步
%cd ~/nnunet_paddle/
!python nnunet_tools/all_fold_eval.py --gt_dir ~/val_2d/gt_niftis --val_dir ~/val_2d

2、UNet—cascade 第一阶段

级联训练需要注意:
1、验证的时候加上–predict_next_stage,否则下一阶段没法训练;
2、记得删除之前2DUNet的数据集,否则内存不够用了。

# 3d unet cascade 第一阶段 fold 0 train
%cd ~/nnunet_paddle/
!python train.py --config configs/msd/msd_lung_3d_fold_0.yml --use_vdl --do_eval --log_iters 10 --save_interval 1000 --seed 10000 --precision fp16 --save_dir output/nnunet3d_stage0/fold_0
# 3d unet cascade 第一阶段 fold 0 val
python nnunet_tools/nnunet_fold_val.py --config configs/msd/msd_lung_3d_fold_0.yml --model_path output/nnunet3d_stage0/fold_0/iter_25000/model.pdparams --precision fp16 --save_dir ~/val_3d --val_save_folder ~/val_3d --predict_next_stage 
# 3d unet cascade 第一阶段 fold 1 train
%cd ~/nnunet_paddle/
!python train.py --config configs/msd/msd_lung_3d_fold_1.yml --use_vdl --do_eval --log_iters 10 --save_interval 1000 --seed 10000 --precision fp16 --save_dir output/nnunet3d_stage0/fold_1
# 3d unet cascade 第一阶段 fold 1 val
%cd ~/nnunet_paddle/
!python nnunet_tools/nnunet_fold_val.py --config configs/msd/msd_lung_3d_fold_1.yml --model_path output/nnunet3d_stage0/fold_1/iter_25000/model.pdparams --precision fp16 --save_dir ~/val_3d --val_save_folder ~/val_3d --predict_next_stage 
# 3d unet cascade 第一阶段 fold 2 train
%cd ~/nnunet_paddle/
!python train.py --config configs/msd/msd_lung_3d_fold_2.yml --use_vdl --do_eval --log_iters 10 --save_interval 1000 --seed 10000 --precision fp16 --save_dir output/nnunet3d_stage0/fold_2
# 3d unet cascade 第一阶段 fold 2 val
%cd ~/nnunet_paddle/
!python nnunet_tools/nnunet_fold_val.py --config configs/msd/msd_lung_3d_fold_2.yml --model_path output/nnunet3d_stage0/fold_2/iter_25000/model.pdparams --precision fp16 --save_dir ~/val_3d --val_save_folder ~/val_3d --predict_next_stage 
# 3d unet cascade 第一阶段 fold 3 train
%cd ~/nnunet_paddle/
!python train.py --config configs/msd/msd_lung_3d_fold_3.yml --use_vdl --do_eval --log_iters 10 --save_interval 1000 --seed 10000 --precision fp16 --save_dir output/nnunet3d_stage0/fold_3
# 3d unet cascade 第一阶段 fold 3 val
%cd ~/nnunet_paddle/
!python nnunet_tools/nnunet_fold_val.py --config configs/msd/msd_lung_3d_fold_3.yml --model_path output/nnunet3d_stage0/fold_3/iter_25000/model.pdparams --precision fp16 --save_dir ~/val_3d --val_save_folder ~/val_3d --predict_next_stage 
# 3d unet cascade 第一阶段 fold 4 train
%cd ~/nnunet_paddle/
!python train.py --config configs/msd/msd_lung_3d_fold_4.yml --use_vdl --do_eval --log_iters 10 --save_interval 1000 --seed 10000 --precision fp16 --save_dir output/nnunet3d_stage0/fold_4
# 3d unet cascade 第一阶段 fold 4 val
%cd ~/nnunet_paddle/
!python nnunet_tools/nnunet_fold_val.py --config configs/msd/msd_lung_3d_fold_4.yml --model_path output/nnunet3d_stage0/fold_4/iter_25000/model.pdparams --precision fp16 --save_dir ~/val_3d --val_save_folder ~/val_3d --predict_next_stage 

3、UNet—cascade 第二阶段

由于单卡训练太耗时,复现赛快结束,所以降低了训练iters。

# 3d unet cascade 第二阶段 fold 0 train
%cd ~/nnunet_paddle/
!python train.py --config configs/msd/msd_lung_3d_fold_0_stage1.yml --use_vdl --do_eval --log_iters 10 --save_interval 1000 --seed 10000 --precision fp16 --save_dir output/nnunet3d_stage1/fold_0

# 3d unet cascade 第二阶段 fold 0 val
%cd ~/nnunet_paddle/
!python nnunet_tools/nnunet_fold_val.py --config configs/msd/msd_lung_3d_fold_0_stage1.yml --model_path output/nnunet3d_stage1/fold_0/iter_16000/model.pdparams --precision fp16 --save_dir ~/val_3d_stage1 --val_save_folder ~/val_3d_stage1

# 3d unet cascade 第二阶段 fold 1 train
%cd ~/nnunet_paddle/
!python train.py --config configs/msd/msd_lung_3d_fold_1_stage1.yml --use_vdl --do_eval --log_iters 10 --save_interval 1000 --seed 10000 --precision fp16 --save_dir output/nnunet3d_stage1/fold_1

# 3d unet cascade 第二阶段 fold 1 val
%cd ~/nnunet_paddle/
!python nnunet_tools/nnunet_fold_val.py --config configs/msd/msd_lung_3d_fold_1_stage1.yml --model_path output/nnunet3d_stage1/fold_1/iter_16000/model.pdparams --precision fp16 --save_dir ~/val_3d_stage1 --val_save_folder ~/val_3d_stage1

# 3d unet cascade 第二阶段 fold 2 train
%cd ~/nnunet_paddle/
!python train.py --config configs/msd/msd_lung_3d_fold_2_stage1.yml --use_vdl --do_eval --log_iters 10 --save_interval 1000 --seed 10000 --precision fp16 --save_dir output/nnunet3d_stage1/fold_2

# 3d unet cascade 第二阶段 fold 2 val
%cd ~/nnunet_paddle/
!python nnunet_tools/nnunet_fold_val.py --config configs/msd/msd_lung_3d_fold_2_stage1.yml --model_path output/nnunet3d_stage1/fold_2/iter_16000/model.pdparams --precision fp16 --save_dir ~/val_3d_stage1 --val_save_folder ~/val_3d_stage1

# 3d unet cascade 第二阶段 fold 3 train
%cd ~/nnunet_paddle/
!python train.py --config configs/msd/msd_lung_3d_fold_3_stage1.yml --use_vdl --do_eval --log_iters 10 --save_interval 1000 --seed 10000 --precision fp16 --save_dir output/nnunet3d_stage1/fold_3

# 3d unet cascade 第二阶段 fold 3 val
%cd ~/nnunet_paddle/
!python nnunet_tools/nnunet_fold_val.py --config configs/msd/msd_lung_3d_fold_3_stage1.yml --model_path output/nnunet3d_stage1/fold_3/iter_16000/model.pdparams --precision fp16 --save_dir ~/val_3d_stage1 --val_save_folder ~/val_3d_stage1

# 3d unet cascade 第二阶段 fold 4 train
%cd ~/nnunet_paddle/
!python train.py --config configs/msd/msd_lung_3d_fold_4_stage1.yml --use_vdl --do_eval --log_iters 10 --save_interval 1000 --seed 10000 --precision fp16 --save_dir output/nnunet3d_stage1/fold_4

# 3d unet cascade 第二阶段 fold 4 val
%cd ~/nnunet_paddle/
!python nnunet_tools/nnunet_fold_val.py --config configs/msd/msd_lung_3d_fold_4_stage1.yml --model_path output/nnunet3d_stage1/fold_4/iter_16000/model.pdparams --precision fp16 --save_dir ~/val_3d_stage1 --val_save_folder ~/val_3d_stage1

# 5折总验证
%cd ~/nnunet_paddle/
!python nnunet_tools/all_fold_eval.py --gt_dir ~/val_3d/gt_niftis --val_dir ~/val_3d_stage1

4、ensemble

%cd ~/nnunet_paddle/
!python nnunet_tools/ensemble.py --nnunet_2d_val_dir ~/val_2d --nnunet_3d_cascade_val_dir ~/val_3d_stage1 --ensemble_output ~/nnunet_ensembles --plan_2d_path /home/aistudio/data/preprocessed/Task006_Lung/nnUNetPlansv2.1_plans_3D.pkl \
    --gt_dir /home/aistudio/data/preprocessed/Task006_Lung/gt_segmentations
    

5、预测示例

由于使用多个模型预测,故需要将权重目录格式改为:
|–model_dir
|—|-fold0
|—|–|-model.pdparams

%cd ~/nnunet_paddle/
!python nnunet_tools/predict.py -i ~/data/Task006_Lung/imagesTs \
    -o ~/nnunet_predict/predict_3d \
    -plan_path /home/aistudio/data/preprocessed/Task006_Lung/nnUNetPlansv2.1_plans_3D.pkl \
    -model_dir ~/MedicalSeg/output/nnunet3d_stage0  -folds 5 \
    -postprocessing_json_path ~/val_3d_stage0/postprocessing.json \

八、复现经验

1、这篇论文从了解到复现完成大概耗时5个月+(比赛之前就有了解),工作量是以前的复现工作的数倍,跑通官方repo,数据对其、损失函数对其等等过程掉了无数的头发,回首望去,总结下来就是坚持二字,在想放弃的时候,务必再坚持一下;
2、如果在复现过程中遇到问题,记得及时联系RD(nnUNet如果采用官方的1000epoch训练,单卡不触发early stop训练15个model大概需要83天左右,这是不大现实的);
3、不管遇到什么困难,一定要坚持,坚持才会成功。

九、致谢

感谢官方的算力支持,感谢RD 收费BML的算力支持,感谢RD小姐姐的耐心答疑。

作者仅为AiStudio搬运,原项目链接:https://aistudio.baidu.com/aistudio/projectdetail/4196840

Logo

学大模型,用大模型上飞桨星河社区!每天8点V100G算力免费领!免费领取ERNIE 4.0 100w Token >>>

更多推荐