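[Medical Imaging] ChestXRay Pneumonia Classification Based on KiUnet
This project reproduces the KiU-Net architecture and demonstrates how to use the ChestXRay pneumonia classification dataset. The code is fully commented and clearly structured; feel free to fork it and learn from it.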
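0. Motivation
As the COVID-19 epidemic keeps recurring, recognizing COVID-19 pneumonia has become a research hotspot. However, existing datasets mainly address distinguishing normal cases from COVID-19 pneumonia.
Pneumonia comes in several types, so this project uses the ChestXRay dataset to distinguish three classes: normal, viral pneumonia, and bacterial pneumonia.
Unlike the classification networks already available on the platform, this project explores using a segmentation network architecture for classification.
In addition, this project studies and reproduces the 2D KiU-Net segmentation network.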
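1. Dataset
The ChestXRay2017 dataset contains 5,856 chest X-ray images. The diagnoses (i.e., the class labels) are normal and pneumonia, and pneumonia is further divided into bacterial and viral pneumonia.
The images were selected from retrospective cohorts of pediatric patients aged one to five at Guangzhou Women and Children's Medical Center. All chest X-rays were acquired as part of the patients' routine clinical care.
Before analysis, all chest X-rays were screened and any low-quality or unreadable scans were removed to guarantee image quality. The diagnoses were then graded by two expert physicians, and, to reduce grading errors,
the test set was additionally checked by a third expert.
The data is split into two subfolders, train and test, used for training and testing respectively. Each subfolder contains two classes, NORMAL and PNEUMONIA.
The PNEUMONIA folder contains both bacterial and viral pneumonia; the type can be told from the file name (e.g. person276_bacteria_1296.jpeg vs. person478_virus_975.jpeg).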
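Because the bacterial/viral distinction is encoded only in the file name, the labelling rule used by the project's dataset class can be written as a small standalone function. This is a minimal sketch; `label_from_path` and `LABELS` are hypothetical names, and the label convention (0 normal, 1 bacterial, 2 viral) is the one `ChestXRayDataset` uses.

```python
import os

# Label convention used by ChestXRayDataset: 0 normal, 1 bacterial, 2 viral
LABELS = {"normal": 0, "bacteria": 1, "virus": 2}


def label_from_path(path):
    """Infer the 3-class label from a ChestXRay2017 image path."""
    folder = os.path.basename(os.path.dirname(path))  # NORMAL or PNEUMONIA
    name = os.path.basename(path)                      # e.g. person276_bacteria_1296.jpeg
    if folder == "NORMAL":
        return LABELS["normal"]
    if folder == "PNEUMONIA":
        # The pneumonia type is only encoded in the file name
        if "bacteria" in name:
            return LABELS["bacteria"]
        if "virus" in name:
            return LABELS["virus"]
    raise ValueError(f"cannot infer a label from {path!r}")


print(label_from_path("train/PNEUMONIA/person478_virus_975.jpeg"))  # 2
```

The glob patterns `*bacteria*.jpeg` and `*virus*.jpeg` in the dataset class apply the same rule to whole folders at once.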
Examples of the three classes:
# Unzip the dataset (only needed on the first run)
!unzip -o /home/aistudio/data/data106874/ChestXRay2017.zip -d /home/aistudio/work/
## View sample images
import matplotlib.image as mpimg # mpimg can read images (PIL is used instead below)
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
# Pick one image from each class
_bacteria = 'work/ChestXRay2017/chest_xray/train/PNEUMONIA/person276_bacteria_1296.jpeg'
_virus = 'work/ChestXRay2017/chest_xray/train/PNEUMONIA/person478_virus_975.jpeg'
_normal = 'work/ChestXRay2017/chest_xray/train/NORMAL/NORMAL2-IM-1442-0001.jpeg'
# Load
_bacteria = Image.open(_bacteria).convert('RGB')
_virus = Image.open(_virus).convert('RGB')
_normal = Image.open(_normal).convert('RGB')
# Plot
plt.figure(figsize=(9, 5))
plt.subplot(1,3,1),plt.xticks([]),plt.yticks([]),plt.title('bacteria'),plt.imshow(_bacteria)
plt.subplot(1,3,2),plt.xticks([]),plt.yticks([]),plt.title('virus'),plt.imshow(_virus)
plt.subplot(1,3,3),plt.xticks([]),plt.yticks([]),plt.title('normal'),plt.imshow(_normal)
plt.show()
# Install the required libraries
!pip install paddleseg
!pip install opencv-python
import os
import glob
import paddle
from paddle.io import Dataset
import numpy as np
import random
from PIL import Image
# Custom dataset class
class ChestXRayDataset(Dataset):
    def __init__(self, mode='train', transform=None):
        ### Collect the image paths
        rootPath = 'work/ChestXRay2017/chest_xray'
        trainPath = os.path.join(rootPath, 'train')
        testPath = os.path.join(rootPath, 'test')
        self.transforms = transform
        self.mode = mode
        if self.mode == 'train':
            ## Normal (non-pneumonia) images
            normalTrainPath = os.path.join(trainPath, 'NORMAL')
            path_ = normalTrainPath + '/*.jpeg'  # match with a wildcard
            normalTrainList_ = glob.glob(path_)
            ## Pneumonia images
            pneumoniaTrainPath = os.path.join(trainPath, 'PNEUMONIA')
            # Bacterial pneumonia
            path_ = pneumoniaTrainPath + '/*bacteria*.jpeg'
            bacteriaTrainList_ = glob.glob(path_)
            # Viral pneumonia
            path_ = pneumoniaTrainPath + '/*virus*.jpeg'
            virusTrainList_ = glob.glob(path_)
            # Attach labels: 0 = normal, 1 = bacterial pneumonia, 2 = viral pneumonia
            normalTrainList = [[item, 0] for item in normalTrainList_]
            bacteriaTrainList = [[item, 1] for item in bacteriaTrainList_]
            virusTrainList = [[item, 2] for item in virusTrainList_]
            self.jpeg_list = normalTrainList + bacteriaTrainList + virusTrainList
            random.shuffle(self.jpeg_list)
        else:  # test
            ## Normal (non-pneumonia) images
            normalTestPath = os.path.join(testPath, 'NORMAL')
            path_ = normalTestPath + '/*.jpeg'
            normalTestList_ = glob.glob(path_)
            ## Pneumonia images
            pneumoniaTestPath = os.path.join(testPath, 'PNEUMONIA')
            # Bacterial pneumonia
            path_ = pneumoniaTestPath + '/*bacteria*.jpeg'
            bacteriaTestList_ = glob.glob(path_)
            # Viral pneumonia
            path_ = pneumoniaTestPath + '/*virus*.jpeg'
            virusTestList_ = glob.glob(path_)
            # Attach labels: 0 = normal, 1 = bacterial pneumonia, 2 = viral pneumonia
            normalTestList = [[item, 0] for item in normalTestList_]
            bacteriaTestList = [[item, 1] for item in bacteriaTestList_]
            virusTestList = [[item, 2] for item in virusTestList_]
            self.jpeg_list = normalTestList + bacteriaTestList + virusTestList
            random.shuffle(self.jpeg_list)

    def __getitem__(self, index):
        pic, label = self.jpeg_list[index]
        # Read the JPEG image; the X-rays are converted to 3-channel RGB
        data = Image.open(pic).convert('RGB')
        if self.transforms:
            data = self.transforms(data)
        return data, label

    def __len__(self):
        return len(self.jpeg_list)
# ## Compute the channel mean and standard deviation used for normalization
# import os
# import cv2
# import numpy as np
# import glob
# from tqdm import tqdm
# def getMeanStd(allJpegList_):
#     '''
#     input: List, paths of all images to include
#     return: array with one [mean, std] row per RGB channel, scaled to [0, 1]
#     '''
#     means, stdevs = [], []
#     img_dict = {}
#     # Per-image statistics are collected in a dict and converted to a list at the end;
#     # the author reports that this roughly halved the run time here compared with list.append
#     for idx, imgs_path in enumerate(allJpegList_):
#         img = cv2.cvtColor(cv2.imread(imgs_path), cv2.COLOR_BGR2RGB)
#         temp_ = img.reshape(-1, 3)
#         mean_ = np.expand_dims(np.mean(temp_, 0), 1)
#         std_ = np.expand_dims(np.std(temp_, 0), 1)
#         img_dict[idx] = np.concatenate((mean_, std_), axis=1)
#     # Collect the dict values
#     img_list = list(img_dict.values())
#     stas_ = np.mean(img_list, axis=0) / 255
#     return stas_
# # Read both the training and the test images
# # The images are loaded as RGB
# path_ = 'work/ChestXRay2017/chest_xray/*/*/*.jpeg'
# allJpegList_ = glob.glob(path_)
# statistic_result = getMeanStd(allJpegList_)
# # One row per R, G, B channel: the first column is the mean, the second column
# # is the standard deviation (both averaged over images)
# print(statistic_result)
# '''
# [0.48151479 0.22348982]
# [0.48151479 0.22348982]
# [0.48151479 0.22348982]
# '''
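The commented code above averages per-image means and per-image standard deviations. The averaged means equal the pixel-weighted dataset mean only when all images have the same number of pixels, and the average of per-image standard deviations is only an approximation of the dataset-wide standard deviation, because it ignores how much the image means differ from one another. A minimal NumPy sketch of exact pooled statistics using running sums; `pooled_mean_std` and `mean_of_image_stds` are hypothetical helpers, not part of the project. Multiplying the reported values by 255 gives the constants used in `T.Normalize` later (0.48151479 × 255 ≈ 122.786 and 0.22348982 × 255 ≈ 56.990).

```python
import numpy as np


def pooled_mean_std(images):
    """Exact per-channel mean/std over all pixels of all images, scaled to [0, 1].

    images: iterable of (H, W, 3) uint8 arrays; sizes may differ.
    """
    n = 0
    total = np.zeros(3)
    total_sq = np.zeros(3)
    for img in images:
        px = img.reshape(-1, 3).astype(np.float64)
        n += px.shape[0]
        total += px.sum(axis=0)
        total_sq += (px ** 2).sum(axis=0)
    mean = total / n
    var = np.maximum(total_sq / n - mean ** 2, 0.0)  # E[x^2] - E[x]^2, clipped at 0
    return mean / 255, np.sqrt(var) / 255


def mean_of_image_stds(images):
    """The approximation used above: average of per-image stds, scaled to [0, 1]."""
    return np.mean([img.reshape(-1, 3).std(axis=0) for img in images], axis=0) / 255


# Two flat images, one black and one white
imgs = [np.zeros((2, 2, 3), np.uint8), np.full((2, 2, 3), 255, np.uint8)]
print(pooled_mean_std(imgs))     # mean 0.5 and std 0.5 per channel
print(mean_of_image_stds(imgs))  # 0 per channel, because each image is flat
```

On real X-rays the gap is much smaller than in this extreme example, but the pooled version is the quantity `T.Normalize` assumes.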
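2. Paper overview
2.0 KiU-Net: Towards Accurate Segmentation of Biomedical Images using Over-complete Representations
Most medical image segmentation methods use U-Net or one of its variants, which have been applied successfully in most medical scenarios.
These "traditional" encoder-decoder approaches, however, make large errors when segmenting small structures and boundaries.
Reference [1] argues that this happens because U-Net concentrates on deep (high-level) features and pays less attention to shallow (low-level) features.
To overcome this, [1] proposes an over-complete convolutional architecture that projects the input image into a higher spatial dimension, which keeps the receptive field from growing in the deeper layers.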
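To see why an over-complete branch limits the receptive field, the standard receptive-field recursion (r ← r + (k − 1)·j and j ← j·s, for kernel size k, stride s, and output spacing j in input pixels) can be applied to the two encoders used in this project: a 3×3 convolution followed by 2×2 max-pooling for U-Net, and a 3×3 convolution followed by 2× upsampling for Ki-Net. This sketch assumes nearest-neighbour upsampling (kernel 1, stride 1/2); the reproduction actually uses bicubic interpolation, whose small support is ignored here. `level_receptive_fields` is a hypothetical helper.

```python
def level_receptive_fields(num_levels, resample_kernel, resample_stride):
    """Receptive field (in input pixels) after each 'conv3x3 + resample' level.

    Uses r <- r + (k - 1) * j and j <- j * s, where r is the receptive field,
    j the spacing of output positions in input pixels, k the kernel, s the stride.
    """
    r, j = 1.0, 1.0
    fields = []
    for _ in range(num_levels):
        r += (3 - 1) * j                # 3x3 convolution, stride 1
        r += (resample_kernel - 1) * j  # resampling window
        j *= resample_stride            # pooling spreads outputs apart, upsampling packs them
        fields.append(r)
    return fields


unet = level_receptive_fields(3, resample_kernel=2, resample_stride=2)     # 2x2 max-pool
kinet = level_receptive_fields(3, resample_kernel=1, resample_stride=0.5)  # nearest 2x upsampling
print(unet)   # [4.0, 10.0, 22.0]
print(kinet)  # [3.0, 4.0, 4.5]
```

The U-Net receptive field roughly doubles at every level, while the Ki-Net one stays within a few pixels, which is what lets the Ki-Net branch focus on fine structures and edges.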
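2.1 Highlights
- Explores an over-complete network architecture, Ki-Net
- Combines under-complete and over-complete deep networks into a new architecture, KiU-Net
- Achieves faster convergence and better performance in segmentation
To better combine the features of each convolutional block, the authors propose a cross residual fusion block (CRFB).
The block takes the feature maps of the two branches at the same level as input and produces two outputs, which serve as the inputs to the next level of each branch.
The U-Net branch's feature map is passed through a convolution and ReLU (plus batch normalization in this reproduction), resampled to the Ki-Net resolution, and added to the Ki-Net feature map to form the input of the next Ki-Net level. Likewise,
the Ki-Net branch's feature map is passed through a convolution and ReLU, resampled to the U-Net resolution, and added to the U-Net feature map to form the input of the next U-Net level.
Finally, the feature maps of the two branches are added and passed through a 1x1 convolution to obtain the output segmentation map. [2]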
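A framework-free sketch of one CRFB step. It follows the order used in the reproduction (convolution, ReLU, resampling to the other branch's resolution, then addition) but swaps in simpler pieces: a 1×1 channel-mixing matrix instead of the 3×3 convolution with batch normalization, and average pooling / nearest-neighbour repetition instead of bicubic interpolation. All function names are hypothetical.

```python
import numpy as np


def relu(x):
    return np.maximum(x, 0.0)


def channel_mix(x, weight):
    """1x1 convolution on a (C, H, W) map with a (C_out, C_in) weight matrix."""
    return np.einsum("oc,chw->ohw", weight, x)


def upsample(x, r):
    """Nearest-neighbour upsampling of a (C, H, W) map by an integer factor r."""
    return x.repeat(r, axis=1).repeat(r, axis=2)


def downsample(x, r):
    """r x r average pooling of a (C, H, W) map (H and W divisible by r)."""
    c, h, w = x.shape
    return x.reshape(c, h // r, r, w // r, r).mean(axis=(2, 4))


def crfb(u, k, w_k_to_u, w_u_to_k):
    """One cross residual fusion step.

    u: U-Net branch map (C, h, w); k: Ki-Net branch map (C, h*r, w*r).
    Both outputs are computed from the original inputs, as in kiunet.forward.
    """
    r = k.shape[1] // u.shape[1]  # resolution ratio between the branches
    u_next = u + downsample(relu(channel_mix(k, w_k_to_u)), r)  # Ki-Net -> U-Net
    k_next = k + upsample(relu(channel_mix(u, w_u_to_k)), r)    # U-Net -> Ki-Net
    return u_next, k_next


rng = np.random.default_rng(0)
u = rng.normal(size=(16, 8, 8))    # e.g. level 1 of the U-Net branch
k = rng.normal(size=(16, 32, 32))  # level 1 of the Ki-Net branch (ratio 4)
w = rng.normal(size=(16, 16)) * 0.1
u_next, k_next = crfb(u, k, w, w)
print(u_next.shape, k_next.shape)  # (16, 8, 8) (16, 32, 32)
```

Because the fusion is residual, each branch keeps its own features and only adds information from the other; with zero weights the block is an identity.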
2.2 Network architecture
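In the reproduction used in this project, each U-Net encoder level halves the spatial size with max-pooling while each Ki-Net level doubles it with interpolation, so the two branches at level i differ by a factor of 4^i. This is why the CRFBs in `kiunet.forward` resample by 4, 16 and 64 (scale factors 0.25/4, 0.0625/16 and 0.015625/64). A small sketch tabulating the encoder shapes for a 128×128 input; `branch_shapes` is a hypothetical helper, and the channel widths 16/32/64 are taken from the model code.

```python
def branch_shapes(h, w, channels=(16, 32, 64)):
    """(level, U-Net shape, Ki-Net shape, CRFB factor) for each encoder level."""
    rows = []
    for level, c in enumerate(channels, start=1):
        unet = (c, h // 2 ** level, w // 2 ** level)  # max-pooling halves H and W
        kinet = (c, h * 2 ** level, w * 2 ** level)   # interpolation doubles H and W
        factor = kinet[1] // unet[1]                  # CRFB resampling factor
        rows.append((level, unet, kinet, factor))
    return rows


for row in branch_shapes(128, 128):
    print(row)
# (1, (16, 64, 64), (16, 256, 256), 4)
# (2, (32, 32, 32), (32, 512, 512), 16)
# (3, (64, 16, 16), (64, 1024, 1024), 64)
```

The decoder mirrors these shapes in reverse, and the two branches meet again at the input resolution, where they are added before the final 1x1 convolution.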
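2.3 Results
The paper reports faster training, higher accuracy, and a smaller model.
(Author's note) The downside is that, in practice, training takes a very long time and needs much more memory.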
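The memory cost comes mainly from the Ki-Net branch: in the reproduction, its deepest encoder map has 64 channels at 8× the input height and width. Back-of-the-envelope arithmetic, assuming float32 activations and counting only that single tensor per image (training keeps many more for backpropagation, such as the convolution, BatchNorm, interpolation and ReLU outputs around it); `activation_mib` is a hypothetical helper.

```python
def activation_mib(channels, h, w, bytes_per_value=4):
    """Size in MiB of one float32 (C, H, W) feature map for a single image."""
    return channels * h * w * bytes_per_value / 2 ** 20


# Deepest Ki-Net encoder map: 64 channels at 8x the input resolution
for side in (128, 256):
    print(side, activation_mib(64, 8 * side, 8 * side))
# 128 256.0
# 256 1024.0
```

Doubling the input side quadruples this cost, which is relevant because the training transforms below resize to 128×128 while the test transforms and `model.summary` use 256×256.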
2.4 References
[1] Valanarasu J M J, Sindagi V A, Hacihaliloglu I, et al. Kiu-net: Towards accurate segmentation of biomedical images using over-complete representations[C]//International Conference on Medical Image Computing and Computer-Assisted Intervention. Springer, Cham, 2020: 363-373.
# Define KiU-Net
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import initializer
def init_weights(init_type='kaiming'):
    # Return a ParamAttr that uses the requested initializer
    if init_type == 'normal':
        return paddle.framework.ParamAttr(initializer=paddle.nn.initializer.Normal())
    elif init_type == 'xavier':
        return paddle.framework.ParamAttr(initializer=paddle.nn.initializer.XavierNormal())
    elif init_type == 'kaiming':
        # The initializer must be instantiated, not passed as a class
        return paddle.framework.ParamAttr(initializer=paddle.nn.initializer.KaimingNormal())
    else:
        raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
class kiunet(nn.Layer):
    def __init__(self ,in_channels = 3, n_classes =3):
        super(kiunet,self).__init__()
        self.in_channels = in_channels
        self.n_class = n_classes
        self.encoder1 = nn.Conv2D(self.in_channels, 16, 3, stride=1, padding=1)  # first U-Net layer: in_channels is 1 for grayscale, 3 for RGB input
        self.en1_bn = nn.BatchNorm(16)
        self.encoder2= nn.Conv2D(16, 32, 3, stride=1, padding=1)
        self.en2_bn = nn.BatchNorm(32)
        self.encoder3= nn.Conv2D(32, 64, 3, stride=1, padding=1)
        self.en3_bn = nn.BatchNorm(64)
        self.decoder1 = nn.Conv2D(64, 32, 3, stride=1, padding=1)
        self.de1_bn = nn.BatchNorm(32)
        self.decoder2 = nn.Conv2D(32,16, 3, stride=1, padding=1)
        self.de2_bn = nn.BatchNorm(16)
        self.decoder3 = nn.Conv2D(16, 8, 3, stride=1, padding=1)
        self.de3_bn = nn.BatchNorm(8)
        self.decoderf1 = nn.Conv2D(64, 32, 3, stride=1, padding=1)
        self.def1_bn = nn.BatchNorm(32)
        self.decoderf2= nn.Conv2D(32, 16, 3, stride=1, padding=1)
        self.def2_bn = nn.BatchNorm(16)
        self.decoderf3 = nn.Conv2D(16, 8, 3, stride=1, padding=1)
        self.def3_bn = nn.BatchNorm(8)
        self.encoderf1 = nn.Conv2D(in_channels, 16, 3, stride=1, padding=1)  # first Ki-Net layer: in_channels is 1 for grayscale, 3 for RGB input
        self.enf1_bn = nn.BatchNorm(16)
        self.encoderf2= nn.Conv2D(16, 32, 3, stride=1, padding=1)
        self.enf2_bn = nn.BatchNorm(32)
        self.encoderf3 = nn.Conv2D(32, 64, 3, stride=1, padding=1)
        self.enf3_bn = nn.BatchNorm(64)
        self.intere1_1 = nn.Conv2D(16,16,3, stride=1, padding=1)
        self.inte1_1bn = nn.BatchNorm(16)
        self.intere2_1 = nn.Conv2D(32,32,3, stride=1, padding=1)
        self.inte2_1bn = nn.BatchNorm(32)
        self.intere3_1 = nn.Conv2D(64,64,3, stride=1, padding=1)
        self.inte3_1bn = nn.BatchNorm(64)
        self.intere1_2 = nn.Conv2D(16,16,3, stride=1, padding=1)
        self.inte1_2bn = nn.BatchNorm(16)
        self.intere2_2 = nn.Conv2D(32,32,3, stride=1, padding=1)
        self.inte2_2bn = nn.BatchNorm(32)
        self.intere3_2 = nn.Conv2D(64,64,3, stride=1, padding=1)
        self.inte3_2bn = nn.BatchNorm(64)
        self.interd1_1 = nn.Conv2D(32,32,3, stride=1, padding=1)
        self.intd1_1bn = nn.BatchNorm(32)
        self.interd2_1 = nn.Conv2D(16,16,3, stride=1, padding=1)
        self.intd2_1bn = nn.BatchNorm(16)
        self.interd3_1 = nn.Conv2D(64,64,3, stride=1, padding=1)
        self.intd3_1bn = nn.BatchNorm(64)
        self.interd1_2 = nn.Conv2D(32,32,3, stride=1, padding=1)
        self.intd1_2bn = nn.BatchNorm(32)
        self.interd2_2 = nn.Conv2D(16,16,3, stride=1, padding=1)
        self.intd2_2bn = nn.BatchNorm(16)
        self.interd3_2 = nn.Conv2D(64,64,3, stride=1, padding=1)
        self.intd3_2bn = nn.BatchNorm(64)
        # Classification head: 1x1 conv to n_classes channels, then global average pooling
        self.final = nn.Sequential(
            nn.Conv2D(8,self.n_class,1,stride=1,padding=0),
            nn.AdaptiveAvgPool2D(output_size=1))
        # Initialise weights. Assigning m.weight_attr / m.bias_attr after a layer has been
        # built has no effect in Paddle, so the initializers are applied to the existing
        # parameters directly: Kaiming-normal conv weights and zero conv biases.
        # BatchNorm layers keep their default initialization (weight 1, bias 0).
        for m in self.sublayers():
            if isinstance(m, nn.Conv2D):
                initializer.KaimingNormal()(m.weight)
                if m.bias is not None:
                    initializer.Constant(0.0)(m.bias)

    def forward(self, x):
        # input: c * h * w -> 16 * h/2 * w/2
        out = F.relu(self.en1_bn(F.max_pool2d(self.encoder1(x),2,2)))  #U-Net branch
        # c * h * w -> 16 * 2h * 2w
        out1 = F.relu(self.enf1_bn(F.interpolate(self.encoderf1(x),scale_factor=(2,2),mode ='bicubic')))  #Ki-Net branch
        # 16 * h/2 * w/2
        tmp = out
        # 16 * 2h * 2w -> 16 * h/2 * w/2
        out = paddle.add(out,F.interpolate(F.relu(self.inte1_1bn(self.intere1_1(out1))),scale_factor=(0.25,0.25),mode ='bicubic'))  #CRFB
        # 16 * h/2 * w/2 -> 16 * 2h * 2w
        out1 = paddle.add(out1,F.interpolate(F.relu(self.inte1_2bn(self.intere1_2(tmp))),scale_factor=(4,4),mode ='bicubic'))  #CRFB
        # 16 * h/2 * w/2
        u1 = out  #skip conn
        # 16 * 2h * 2w
        o1 = out1  #skip conn
        # 16 * h/2 * w/2 -> 32 * h/4 * w/4
        out = F.relu(self.en2_bn(F.max_pool2d(self.encoder2(out),2,2)))
        # 16 * 2h * 2w -> 32 * 4h * 4w
        out1 = F.relu(self.enf2_bn(F.interpolate(self.encoderf2(out1),scale_factor=(2,2),mode ='bicubic')))
        # 32 * h/4 * w/4
        tmp = out
        # 32 * 4h * 4w -> 32 * h/4 *w/4
        out = paddle.add(out,F.interpolate(F.relu(self.inte2_1bn(self.intere2_1(out1))),scale_factor=(0.0625,0.0625),mode ='bicubic'))
        # 32 * h/4 * w/4 -> 32 *4h *4w
        out1 = paddle.add(out1,F.interpolate(F.relu(self.inte2_2bn(self.intere2_2(tmp))),scale_factor=(16,16),mode ='bicubic'))
        # 32 * h/4 *w/4
        u2 = out
        # 32 *4h *4w
        o2 = out1
        # 32 * h/4 *w/4 -> 64 * h/8 *w/8
        out = F.relu(self.en3_bn(F.max_pool2d(self.encoder3(out),2,2)))
        # 32 *4h *4w -> 64 * 8h *8w
        out1 = F.relu(self.enf3_bn(F.interpolate(self.encoderf3(out1),scale_factor=(2,2),mode ='bicubic')))
        # 64 * h/8 *w/8
        tmp = out
        # 64 * 8h *8w -> 64 * h/8 * w/8
        out = paddle.add(out,F.interpolate(F.relu(self.inte3_1bn(self.intere3_1(out1))),scale_factor=(0.015625,0.015625),mode ='bicubic'))
        # 64 * h/8 *w/8 -> 64 * 8h * 8w
        out1 = paddle.add(out1,F.interpolate(F.relu(self.inte3_2bn(self.intere3_2(tmp))),scale_factor=(64,64),mode ='bicubic'))
        ### End of encoder block
        ### Start Decoder
        # 64 * h/8 * w/8 -> 32 * h/4 * w/4
        out = F.relu(self.de1_bn(F.interpolate(self.decoder1(out),scale_factor=(2,2),mode ='bicubic')))  #U-NET
        # 64 * 8h * 8w -> 32 * 4h * 4w
        out1 = F.relu(self.def1_bn(F.max_pool2d(self.decoderf1(out1),2,2)))  #Ki-NET
        # 32 * h/4 * w/4
        tmp = out
        # 32 * 4h * 4w -> 32 * h/4 * w/4
        out = paddle.add(out,F.interpolate(F.relu(self.intd1_1bn(self.interd1_1(out1))),scale_factor=(0.0625,0.0625),mode ='bicubic'))
        # 32 * h/4 * w/4 -> 32 * 4h * 4w
        out1 = paddle.add(out1,F.interpolate(F.relu(self.intd1_2bn(self.interd1_2(tmp))),scale_factor=(16,16),mode ='bicubic'))
        # 32 * h/4 * w/4
        out = paddle.add(out,u2)  #skip conn
        # 32 * 4h * 4w
        out1 = paddle.add(out1,o2)  #skip conn
        # 32 * h/4 * w/4 -> 16 * h/2 * w/2
        out = F.relu(self.de2_bn(F.interpolate(self.decoder2(out),scale_factor=(2,2),mode ='bicubic')))
        # 32 * 4h * 4w -> 16 * 2h * 2w
        out1 = F.relu(self.def2_bn(F.max_pool2d(self.decoderf2(out1),2,2)))
        # 16 * h/2 * w/2
        tmp = out
        # 16 * 2h * 2w -> 16 * h/2 * w/2
        out = paddle.add(out,F.interpolate(F.relu(self.intd2_1bn(self.interd2_1(out1))),scale_factor=(0.25,0.25),mode ='bicubic'))
        # 16 * h/2 * w/2 -> 16 * 2h * 2w
        out1 = paddle.add(out1,F.interpolate(F.relu(self.intd2_2bn(self.interd2_2(tmp))),scale_factor=(4,4),mode ='bicubic'))
        # 16 * h/2 * w/2
        out = paddle.add(out,u1)
        # 16 * 2h * 2w
        out1 = paddle.add(out1,o1)
        # 16 * h/2 * w/2 -> 8 * h * w
        out = F.relu(self.de3_bn(F.interpolate(self.decoder3(out),scale_factor=(2,2),mode ='bicubic')))
        # 16 * 2h * 2w -> 8 * h * w
        out1 = F.relu(self.def3_bn(F.max_pool2d(self.decoderf3(out1),2,2)))
        # 8 * h * w
        out = paddle.add(out,out1)  # fusion of both branches
        # Sigmoid activation on the final layer: n_classes * 1 * 1 scores
        out = F.sigmoid(self.final(out))  #1*1 conv
        return out
# Print a summary of the KiU-Net structure
KIunet = kiunet(in_channels = 3, n_classes =3)
model = paddle.Model(KIunet)
model.summary((2,3, 256, 256))
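3. Training code
- Training takes a very long time, so only two epochs are run for demonstration.
- The last layer of the network is modified: an AdaptiveAvgPool2D is appended so that the segmentation network can be used for classification.
- A range of data augmentation methods is added.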
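The modified head turns the 8-channel, full-resolution output of the fused branches into three class scores: a 1×1 convolution, `AdaptiveAvgPool2D(output_size=1)`, then a sigmoid. Because a 1×1 convolution is linear, averaging before or after it gives the same result, so this head is equivalent to global average pooling followed by a fully connected layer. A NumPy sketch that checks this equivalence; `classification_head` is a hypothetical name, not part of the project.

```python
import numpy as np


def classification_head(features, weight, bias):
    """1x1 conv -> global average pooling -> sigmoid.

    features: (C, H, W); weight: (K, C); bias: (K,). Returns K class scores.
    """
    logits_map = np.einsum("kc,chw->khw", weight, features) + bias[:, None, None]
    logits = logits_map.mean(axis=(1, 2))  # AdaptiveAvgPool2D(output_size=1)
    return 1.0 / (1.0 + np.exp(-logits))   # sigmoid, as in kiunet.forward


rng = np.random.default_rng(0)
feats = rng.normal(size=(8, 16, 16))  # 8 channels, as produced by the last decoder level
w, b = rng.normal(size=(3, 8)), rng.normal(size=3)
# Pooling first and then applying a linear layer gives the same scores
pooled_first = 1.0 / (1.0 + np.exp(-(w @ feats.mean(axis=(1, 2)) + b)))
print(np.allclose(classification_head(feats, w, b), pooled_first))  # True
```

In other words, all spatial detail produced by the decoder is averaged away before the class decision; only the channel-wise means reach the classifier.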
model = kiunet(in_channels = 3, n_classes =3)
# Switch the model to training mode
model.train()
# Adam optimizer; the initial learning rate is 0.1, and StepDecay multiplies it
# by gamma=0.1 every step_size=20 calls to scheduler.step()
scheduler = paddle.optimizer.lr.StepDecay(learning_rate=0.1, step_size=20, gamma=0.1, verbose=False)
optimizer = paddle.optimizer.Adam(learning_rate=scheduler, parameters=model.parameters())
EPOCH_NUM = 2   # number of epochs
BATCH_SIZE = 2  # batch size
from paddle.vision import transforms as T
Transforms_train = T.Compose([
    T.RandomHorizontalFlip(0.5),        # random horizontal flip
    T.RandomVerticalFlip(0.5),          # random vertical flip
    T.RandomRotation(15),               # random rotation in [-15, 15] degrees
    T.Resize((128, 128)),               # resize
    T.ColorJitter(0.2, 0.2, 0.2, 0.2),  # randomly adjust brightness, contrast, saturation and hue
    T.Transpose(),                      # HWC -> CHW
    T.Normalize(
        [122.78627145, 122.78627145, 122.78627145],
        [56.9899041, 56.9899041, 56.9899041]),  # normalize with the dataset mean/std (0-255 scale)
])
Transforms_test = T.Compose([
    T.Resize((256, 256)),  # resize (note: training uses 128 x 128)
    T.Transpose(),         # HWC -> CHW
    T.Normalize(
        [122.78627145, 122.78627145, 122.78627145],
        [56.9899041, 56.9899041, 56.9899041]),  # normalize
])
train_dataset = ChestXRayDataset(mode='train',transform = Transforms_train)
test_dataset = ChestXRayDataset(mode='test',transform = Transforms_test)
# paddle.io.DataLoader batches the samples returned by the Dataset objects
# (the training list is already shuffled once inside ChestXRayDataset)
data_loader = paddle.io.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_data_loader = paddle.io.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
# Binary cross-entropy loss on the sigmoid outputs and one-hot targets
loss_BCEloss = paddle.nn.BCELoss()
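The loss pairs the sigmoid outputs of `kiunet` with one-hot targets, so each of the three classes is treated as an independent yes/no decision rather than as one choice among three. A minimal NumPy sketch of what `paddle.nn.BCELoss` computes with its default `reduction='mean'`, i.e. averaged over all B × 3 entries; `bce_one_hot` is a hypothetical helper, and the clipping constant `eps` is an assumption added only to avoid `log(0)`.

```python
import numpy as np


def bce_one_hot(probs, labels, num_classes=3, eps=1e-7):
    """Mean binary cross-entropy between sigmoid outputs and one-hot targets.

    probs: (B, K) values in (0, 1); labels: (B,) integer class ids.
    """
    targets = np.eye(num_classes)[labels]  # one-hot encoding, shape (B, K)
    p = np.clip(probs, eps, 1 - eps)       # avoid log(0)
    loss = -(targets * np.log(p) + (1 - targets) * np.log(1 - p))
    return loss.mean()                     # average over all B * K entries


probs = np.array([[0.9, 0.05, 0.05],   # confident and correct for label 0
                  [0.4, 0.3, 0.6]])    # label 1, but class 2 scores higher
print(bce_one_hot(probs, np.array([0, 1])))
```

Because the three sigmoid scores are not forced to sum to 1, the predicted class is taken with `argmax` at test time.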
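# Outer loop over epochs
for epoch_id in range(EPOCH_NUM):
    # Inner loop over batches
    for iter_id, data in enumerate(data_loader()):
        x, y = data  # x: images, y: integer labels
        y = paddle.reshape(y, [-1])  # labels as a 1-D tensor of shape [B]
        # Convert the labels to one-hot vectors
        one_hot = paddle.nn.functional.one_hot(y, num_classes=3)  # three classes in total
        # Convert to float32 dynamic-graph tensors
        x = paddle.to_tensor(x,dtype='float32')
        y = paddle.to_tensor(one_hot,dtype='float32')
        # Forward pass: [B, 3, 1, 1] -> [B, 3]; squeezing only the spatial axes
        # keeps the batch axis even when the last batch has a single image
        predicts = model(x)
        predicts = paddle.squeeze(predicts, axis=[2, 3])
        # Compute the loss
        loss = loss_BCEloss(predicts, y)
        # Clear the gradients
        optimizer.clear_grad()
        # Backpropagation
        loss.backward()
        # Update the parameters to minimize the loss
        optimizer.step()
        # The scheduler is stepped after every iteration, so the learning rate
        # is divided by 10 every 20 iterations
        scheduler.step()
        print("epoch: {}, iter: {}, loss is: {}".format(epoch_id+1, iter_id+1, loss.numpy()))
# Save the model parameters to kiunet_model.pdparams
paddle.save(model.state_dict(), 'work/kiunet_model.pdparams')
print("Model saved; the parameters are stored in kiunet_model.pdparams")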
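Paddle's `StepDecay` multiplies the learning rate by `gamma` every `step_size` calls to `scheduler.step()`, i.e. lr = learning_rate × gamma^⌊steps / step_size⌋. Since the training loop calls `scheduler.step()` after every batch, `step_size=20` counts iterations rather than epochs: with `learning_rate=0.1` and `gamma=0.1`, the rate is already about 1e-6 after 100 iterations, a small fraction of one epoch at batch size 2. The arithmetic as a plain-Python sketch; `step_decay_lr` is a hypothetical helper that mirrors the formula above.

```python
def step_decay_lr(base_lr, step_size, gamma, num_steps):
    """Learning rate after num_steps calls to scheduler.step() under StepDecay."""
    return base_lr * gamma ** (num_steps // step_size)


# Settings from the training cell: learning_rate=0.1, step_size=20, gamma=0.1
for steps in (0, 20, 40, 100, 200):
    print(steps, step_decay_lr(0.1, 20, 0.1, steps))
```

Calling `scheduler.step()` once per epoch, after the inner loop, is the more common way to use an epoch-based `step_size`.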
4. Test code
- Compute the mean accuracy on the test set
import paddle
import numpy as np
from sklearn.metrics import accuracy_score
# Model evaluation
print("Start testing")
# Load the previously trained model parameters
para_state_dict = paddle.load('work/kiunet_model.pdparams')
model = kiunet(in_channels = 3, n_classes =3)
model.set_dict(para_state_dict)
model.eval()  # evaluation mode: BatchNorm uses its running statistics
accs = []  # accuracy of each test batch
with paddle.no_grad():
    for iter_id, data in enumerate(test_data_loader()):
        x, y = data  # x: images, y: integer labels
        # Convert to a float32 dynamic-graph tensor
        x = paddle.to_tensor(x,dtype='float32')
        # Forward pass: [B, 3, 1, 1] -> [B, 3]
        predicts = model(x)
        predicts = paddle.squeeze(predicts, axis=[2, 3]).numpy()
        y = y.numpy().reshape(-1)
        predLabel = np.argmax(predicts, 1)
        accs.append(accuracy_score(y, predLabel))  # arguments: (y_true, y_pred)
print("Mean accuracy on the test set:", np.mean(accs))
Start testing
Mean accuracy on the test set: 0.38782051282051283
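A single accuracy figure does not show which classes the model mixes up, for example bacterial versus viral pneumonia. A minimal NumPy sketch of a confusion matrix built from the same `predLabel` and `y` arrays collected in the test loop; the helper names are hypothetical, and `sklearn.metrics.confusion_matrix(y_true, y_pred)` computes the same matrix.

```python
import numpy as np


def confusion_matrix_3(y_true, y_pred, num_classes=3):
    """Rows: true class, columns: predicted class (0 normal, 1 bacterial, 2 viral)."""
    cm = np.zeros((num_classes, num_classes), dtype=int)
    for t, p in zip(y_true, y_pred):
        cm[t, p] += 1
    return cm


def accuracy_from_cm(cm):
    """Overall accuracy: correct predictions (diagonal) over all predictions."""
    return np.trace(cm) / cm.sum()


def per_class_recall(cm):
    """Fraction of each true class predicted correctly (assumes every class occurs)."""
    return np.diag(cm) / cm.sum(axis=1)


y_true = np.array([0, 1, 2, 1])
y_pred = np.array([0, 2, 2, 1])
cm = confusion_matrix_3(y_true, y_pred)
print(cm)
print(accuracy_from_cm(cm), per_class_recall(cm))
```

Note that the test loop averages per-batch accuracies; this equals the overall accuracy computed from the confusion matrix only when every batch has the same size.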
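5. Summary
- This project implements the KiU-Net algorithm and adapts a U-Net-style segmentation network to a classification task.
- KiU-Net trains slowly and needs a lot of memory. Readers interested in this dataset may prefer to switch to a dedicated classification network; the platform already hosts many well-commented classification projects.
- The project mainly demonstrates how to use this dataset and how to reproduce the KiU-Net architecture. The code is fully commented and suits readers interested in segmentation and classification.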