ChestXRay Pneumonia Classification with KiU-Net

This project reproduces the KiU-Net network architecture and demonstrates how to use the ChestXRay pneumonia classification dataset.

0. Motivation

With the epidemic's repeated waves, COVID-19 recognition has become a research hotspot, yet existing datasets mostly address only the binary distinction between normal cases and COVID-19.

Pneumonia, however, comes in several varieties. Based on the ChestXRay dataset, this project studies three-way recognition of normal, bacterial-pneumonia, and viral-pneumonia cases.

Unlike the classification networks already available on the platform, this project explores using a segmentation architecture for classification.

In addition, it studies and reproduces the 2D KiU-Net segmentation network.

1. Dataset Introduction

The ChestXRay2017 dataset contains 5,856 chest X-ray images. The diagnostic labels (i.e., classification labels) are mainly normal and pneumonia, and pneumonia is further subdivided into bacterial pneumonia and viral pneumonia.

The chest X-ray images were selected from retrospective cohorts of pediatric patients aged one to five at the Guangzhou Women and Children's Medical Center. All imaging was performed as part of the patients' routine clinical care.

Before analysis, all chest radiographs were first screened to remove low-quality or unreadable scans, ensuring image quality. The diagnoses were then graded by two expert physicians, and, to further reduce diagnostic errors, the test set was additionally checked by a third expert.

The data are split into two top-level folders, train and test, used for model training and testing respectively. Each split contains two classes: NORMAL and PNEUMONIA.

Within the PNEUMONIA folder, bacterial and viral cases are distinguished by the image file naming convention.
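Since the labels are encoded in the folder layout and file names, the split sizes can be checked directly with glob. A minimal sketch (run it after the unzip step below):

# Count images per class from the directory layout (assumes the dataset
# has been unzipped to work/ as in the cell below)
import glob

root = 'work/ChestXRay2017/chest_xray'
for split in ('train', 'test'):
    n_normal = len(glob.glob(f'{root}/{split}/NORMAL/*.jpeg'))
    n_bacteria = len(glob.glob(f'{root}/{split}/PNEUMONIA/*bacteria*.jpeg'))
    n_virus = len(glob.glob(f'{root}/{split}/PNEUMONIA/*virus*.jpeg'))
    print(f'{split}: normal={n_normal}, bacteria={n_bacteria}, virus={n_virus}')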

Sample images from the three classes:

# Unzip the dataset; only needs to run the first time
!unzip  -o  /home/aistudio/data/data106874/ChestXRay2017.zip -d /home/aistudio/work/
## Display sample images
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np

# Pick three sample images, one per class
_bacteria = 'work/ChestXRay2017/chest_xray/train/PNEUMONIA/person276_bacteria_1296.jpeg'
_virus = 'work/ChestXRay2017/chest_xray/train/PNEUMONIA/person478_virus_975.jpeg'
_normal = 'work/ChestXRay2017/chest_xray/train/NORMAL/NORMAL2-IM-1442-0001.jpeg'

# Load the images
_bacteria =  Image.open(_bacteria).convert('RGB')
_virus =   Image.open(_virus).convert('RGB') 
_normal =  Image.open(_normal).convert('RGB')

# Plot
plt.figure(figsize=(9, 5))
plt.subplot(1,3,1),plt.xticks([]),plt.yticks([]),plt.title('bacteria'),plt.imshow(_bacteria)
plt.subplot(1,3,2),plt.xticks([]),plt.yticks([]),plt.title('virus'),plt.imshow(_virus)
plt.subplot(1,3,3),plt.xticks([]),plt.yticks([]),plt.title('normal'),plt.imshow(_normal)
plt.show()

![Sample chest X-rays: bacteria, virus, normal](output_3_0.png)

# Install the required libraries
!pip install paddleseg
!pip install opencv-python
import os
import glob
import paddle
from paddle.io import Dataset
import paddleseg.transforms as T
import numpy as np
import random
from PIL import Image
# Custom dataset class
class ChestXRayDataset(Dataset):
    def __init__(self,mode = 'train',transform =None):
       
        ### Data paths
        rootPath = 'work/ChestXRay2017/chest_xray'
        trainPath = os.path.join(rootPath,'train')
        testPath = os.path.join(rootPath,'test')
        self.transforms = transform
        self.mode = mode

        if self.mode == 'train':

            ## Normal (non-pneumonia) images
            normalTrainPath = os.path.join(trainPath,'NORMAL')
            path_ = normalTrainPath + '/*.jpeg' # match files with a wildcard
            normalTrainList_ = glob.glob(path_)

            ## Pneumonia images
            pneumoniaTrainPath = os.path.join(trainPath,'PNEUMONIA')
            # bacterial pneumonia
            path_ = pneumoniaTrainPath + '/*bacteria*.jpeg'
            bacteriaTrainList_ = glob.glob(path_)

            # viral pneumonia
            path_ = pneumoniaTrainPath + '/*virus*.jpeg'
            virusTrainList_ = glob.glob(path_) 

            # Attach labels: 0 = normal, 1 = bacterial pneumonia, 2 = viral pneumonia
            normalTrainList = [[item,0] for item in normalTrainList_]
            bacteriaTrainList = [[item,1] for item in bacteriaTrainList_]
            virusTrainList = [[item,2] for item in virusTrainList_]
            self.jpeg_list = normalTrainList + bacteriaTrainList + virusTrainList
            random.shuffle( self.jpeg_list )
        else: # test

            ## Normal (non-pneumonia) images
            normalTestPath = os.path.join(testPath,'NORMAL')
            path_ = normalTestPath + '/*.jpeg'
            normalTestList_ = glob.glob(path_)

            ## Pneumonia images
            pneumoniaTestPath = os.path.join(testPath,'PNEUMONIA')
            # bacterial pneumonia
            path_ = pneumoniaTestPath + '/*bacteria*.jpeg'
            bacteriaTestList_ = glob.glob(path_)

            # viral pneumonia
            path_ = pneumoniaTestPath + '/*virus*.jpeg'
            virusTestList_ = glob.glob(path_)

            # Attach labels: 0 = normal, 1 = bacterial pneumonia, 2 = viral pneumonia
            normalTestList = [[item,0] for item in normalTestList_]
            bacteriaTestList = [[item,1] for item in bacteriaTestList_]
            virusTestList = [[item,2] for item in virusTestList_]

            self.jpeg_list = normalTestList + bacteriaTestList + virusTestList
            random.shuffle( self.jpeg_list )

    def __getitem__(self, index):

        pic,label = self.jpeg_list[index]

        # Load the JPEG image
        data = Image.open(pic).convert('RGB')
        #data = data.transpose((2,0,1))
        
        if self.transforms:
            data = self.transforms(data)
        return data,label

    def __len__(self):

        return len(self.jpeg_list)  
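A quick smoke test of the class (a minimal sketch; without a transform, __getitem__ returns a PIL image plus an integer label):

# Instantiate the dataset and inspect one sample
ds = ChestXRayDataset(mode='train')
print('train samples:', len(ds))
img, label = ds[0]
print('first sample size:', img.size, '| label:', label)  # 0 normal, 1 bacteria, 2 virus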

# ## Compute the per-channel mean and std for data normalization
# import os
# import cv2
# import numpy as np
# import glob
# from tqdm import tqdm

# def getMeanStd(allJpegList_):
#     '''
#     input: list of paths of all images to process
#     return: array [means, stdevs]
#     '''
#     means, stdevs = [], []
#     img_dict = {}
#     # Collecting into a dict and extracting the values afterwards turned out
#     # faster than appending to a list; in this example it roughly halved the
#     # runtime, with larger savings as the data grows.

#     for idx,imgs_path in enumerate(allJpegList_):
#         img = cv2.cvtColor(cv2.imread(imgs_path), cv2.COLOR_BGR2RGB)
#         temp_ = img.reshape(-1,3)
#         mean_ = np.expand_dims(np.mean(temp_,0),1)
#         std_ = np.expand_dims(np.std(temp_,0),1)
#         img_dict[idx] = np.concatenate((mean_,std_),axis=1)

#     # extract the dict values
#     img_list = list(img_dict.values())
#     stats_ = np.mean(img_list, axis=0)/255
#     return stats_

# # Gather all training and test images (RGB data)
# path_ ='work/ChestXRay2017/chest_xray/*/*/*.jpeg'
# allJpegList_ = glob.glob(path_)
# statistic_result = getMeanStd(allJpegList_)
# # Result: first column is the per-channel (RGB) mean, second column the std
# print(statistic_result)

# '''
#  [0.48151479 0.22348982]
#  [0.48151479 0.22348982]
#  [0.48151479 0.22348982]
# '''
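These statistics are reused in the normalization pipeline of Section 3: T.Normalize there takes the mean and std on the 0-255 pixel scale, and 0.48151479 × 255 = 122.78627145 while 0.22348982 × 255 = 56.9899041, exactly the values passed to it.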

2. The Paper

2.0 KiU-Net: Towards Accurate Segmentation of Biomedical Images using Over-complete Representations

Most medical image segmentation work uses U-Net or one of its variants, which have been applied successfully across most medical scenarios.

This "traditional" encoder-decoder approach, however, makes large errors when segmenting small structures and boundaries.

Reference [1] argues that when U-Net attends too strongly to deep features, it pays less attention to shallow ones.

To overcome this, [1] proposes an over-complete convolutional architecture that maps the input image to a higher (spatial) dimension, restraining how quickly the receptive field grows in deeper layers.


2.1 Highlights

  1. Explores Ki-Net, an over-complete network architecture

  2. Combines under-complete and over-complete deep networks into a new architecture, KiU-Net

  3. Achieves faster convergence and better performance on segmentation tasks

To better combine the features of each convolution block, the authors propose a cross residual fusion block (CRFB).
The same-level feature maps of the two branches are taken as inputs, and two outputs are produced to feed the next level of each branch:
the U-Net branch's feature map is passed through a conv layer and ReLU and added to the Ki-Net feature map to form the input of Ki-Net's next level, and, symmetrically,
the Ki-Net branch's feature map is passed through a conv layer and ReLU and added to the U-Net feature map to form the input of U-Net's next level.
Finally, the feature maps of the two branches are added and passed through a 1x1 convolution to obtain the output segmentation map. [2]
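To make the CRFB concrete, below is a minimal sketch of one fusion block in Paddle. It is illustrative only: the class name CRFB and its scale argument are our own, and the full model later in this project inlines this logic rather than using a separate module.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

class CRFB(nn.Layer):
    """Cross residual fusion block (sketch). `scale` is the ratio of the
    Ki-Net feature-map size to the U-Net feature-map size at this level."""
    def __init__(self, channels, scale):
        super().__init__()
        self.conv_k2u = nn.Conv2D(channels, channels, 3, padding=1)
        self.bn_k2u = nn.BatchNorm(channels)
        self.conv_u2k = nn.Conv2D(channels, channels, 3, padding=1)
        self.bn_u2k = nn.BatchNorm(channels)
        self.scale = scale

    def forward(self, u_feat, k_feat):
        # Ki-Net -> U-Net: conv + BN + ReLU, downsample to U-Net scale, residual add
        u_out = paddle.add(u_feat, F.interpolate(
            F.relu(self.bn_k2u(self.conv_k2u(k_feat))),
            scale_factor=1.0 / self.scale, mode='bicubic'))
        # U-Net -> Ki-Net: conv + BN + ReLU, upsample to Ki-Net scale, residual add
        k_out = paddle.add(k_feat, F.interpolate(
            F.relu(self.bn_u2k(self.conv_u2k(u_feat))),
            scale_factor=float(self.scale), mode='bicubic'))
        return u_out, k_out

For the first level, CRFB(16, 4) would match the model below, where the Ki-Net map (16 × 2h × 2w) is 4× the size of the U-Net map (16 × h/2 × w/2), i.e. the scale factors 4 and 0.25.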


2.2 Network Structure

KiU-Net runs two parallel branches: an under-complete U-Net branch whose encoder downsamples at each level, and an over-complete Ki-Net branch whose encoder upsamples instead. At every level the two branches exchange features through CRFBs, and at the end the two feature maps are added and projected by a 1x1 convolution.


2.3 Results

The experiments show faster convergence, higher accuracy, and a smaller model.

(Author's note) The drawback is that in practice training takes extremely long and requires considerably more memory.


2.4 References

[1] Valanarasu J M J, Sindagi V A, Hacihaliloglu I, et al. Kiu-net: Towards accurate segmentation of biomedical images using over-complete representations[C]//International Conference on Medical Image Computing and Computer-Assisted Intervention. Springer, Cham, 2020: 363-373.

[2] https://blog.csdn.net/Qy1997/article/details/108356319

# Define KiU-Net
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import initializer

def init_weights(init_type='kaiming'):
    if init_type == 'normal':
        return paddle.framework.ParamAttr(initializer=paddle.nn.initializer.Normal())
    elif init_type == 'xavier':
        return paddle.framework.ParamAttr(initializer=paddle.nn.initializer.XavierNormal())
    elif init_type == 'kaiming':
        return paddle.framework.ParamAttr(initializer=paddle.nn.initializer.KaimingNormal())
    else:
        raise NotImplementedError('initialization method [%s] is not implemented' % init_type)


class kiunet(nn.Layer):
    def __init__(self ,in_channels = 3, n_classes =3):
        super(kiunet,self).__init__()

        self.in_channels = in_channels
        self.n_class = n_classes

        self.encoder1 = nn.Conv2D(self.in_channels, 16, 3, stride=1, padding=1)  # first layer; in_channels is 3 for RGB input
        self.en1_bn = nn.BatchNorm(16)
        self.encoder2=   nn.Conv2D(16, 32, 3, stride=1, padding=1)  
        self.en2_bn = nn.BatchNorm(32)
        self.encoder3=   nn.Conv2D(32, 64, 3, stride=1, padding=1)
        self.en3_bn = nn.BatchNorm(64)

        self.decoder1 =   nn.Conv2D(64, 32, 3, stride=1, padding=1)   
        self.de1_bn = nn.BatchNorm(32)
        self.decoder2 =   nn.Conv2D(32,16, 3, stride=1, padding=1)
        self.de2_bn = nn.BatchNorm(16)
        self.decoder3 =   nn.Conv2D(16, 8, 3, stride=1, padding=1)
        self.de3_bn = nn.BatchNorm(8)

        self.decoderf1 =   nn.Conv2D(64, 32, 3, stride=1, padding=1)
        self.def1_bn = nn.BatchNorm(32)
        self.decoderf2=   nn.Conv2D(32, 16, 3, stride=1, padding=1)
        self.def2_bn = nn.BatchNorm(16)
        self.decoderf3 =   nn.Conv2D(16, 8, 3, stride=1, padding=1)
        self.def3_bn = nn.BatchNorm(8)

        self.encoderf1 =   nn.Conv2D(in_channels, 16, 3, stride=1, padding=1)  # first layer; in_channels is 3 for RGB input
        self.enf1_bn = nn.BatchNorm(16)
        self.encoderf2=   nn.Conv2D(16, 32, 3, stride=1, padding=1)
        self.enf2_bn = nn.BatchNorm(32)
        self.encoderf3 =   nn.Conv2D(32, 64, 3, stride=1, padding=1)
        self.enf3_bn = nn.BatchNorm(64)

        self.intere1_1 = nn.Conv2D(16,16,3, stride=1, padding=1)
        self.inte1_1bn = nn.BatchNorm(16)
        self.intere2_1 = nn.Conv2D(32,32,3, stride=1, padding=1)
        self.inte2_1bn = nn.BatchNorm(32)
        self.intere3_1 = nn.Conv2D(64,64,3, stride=1, padding=1)
        self.inte3_1bn = nn.BatchNorm(64)

        self.intere1_2 = nn.Conv2D(16,16,3, stride=1, padding=1)
        self.inte1_2bn = nn.BatchNorm(16)
        self.intere2_2 = nn.Conv2D(32,32,3, stride=1, padding=1)
        self.inte2_2bn = nn.BatchNorm(32)
        self.intere3_2 = nn.Conv2D(64,64,3, stride=1, padding=1)
        self.inte3_2bn = nn.BatchNorm(64)

        self.interd1_1 = nn.Conv2D(32,32,3, stride=1, padding=1)
        self.intd1_1bn = nn.BatchNorm(32)
        self.interd2_1 = nn.Conv2D(16,16,3, stride=1, padding=1)
        self.intd2_1bn = nn.BatchNorm(16)
        self.interd3_1 = nn.Conv2D(64,64,3, stride=1, padding=1)
        self.intd3_1bn = nn.BatchNorm(64)

        self.interd1_2 = nn.Conv2D(32,32,3, stride=1, padding=1)
        self.intd1_2bn = nn.BatchNorm(32)
        self.interd2_2 = nn.Conv2D(16,16,3, stride=1, padding=1)
        self.intd2_2bn = nn.BatchNorm(16)
        self.interd3_2 = nn.Conv2D(64,64,3, stride=1, padding=1)
        self.intd3_2bn = nn.BatchNorm(64)

        self.final = nn.Sequential(
            nn.Conv2D(8,self.n_class,1,stride=1,padding=0),
            nn.AdaptiveAvgPool2D(output_size=1))

        # initialise weights
        # (ParamAttr only takes effect when a layer is constructed, so the
        #  initializers are applied directly to the existing parameters here)
        kaiming_init = paddle.nn.initializer.KaimingNormal()
        ones_init = paddle.nn.initializer.Constant(1.0)
        zeros_init = paddle.nn.initializer.Constant(0.0)
        for m in self.sublayers():
            if isinstance(m, nn.Conv2D):
                kaiming_init(m.weight)
                if m.bias is not None:
                    zeros_init(m.bias)
            elif isinstance(m, nn.BatchNorm):
                ones_init(m.weight)
                zeros_init(m.bias)

    def forward(self, x):
        # input: c * h * w -> 16 * h/2 * w/2
        out = F.relu(self.en1_bn(F.max_pool2d(self.encoder1(x),2,2)))  #U-Net branch
        # c * h * w -> 16 * 2h * 2w
        out1 = F.relu(self.enf1_bn(F.interpolate(self.encoderf1(x),scale_factor=(2,2),mode ='bicubic'))) #Ki-Net branch
        # 16 * h/2 * w/2
        tmp = out
        # 16 * 2h * 2w -> 16 * h/2 * w/2
        out = paddle.add(out,F.interpolate(F.relu(self.inte1_1bn(self.intere1_1(out1))),scale_factor=(0.25,0.25),mode ='bicubic')) #CRFB
        # 16 * h/2 * w/2 -> 16 * 2h * 2w
        out1 = paddle.add(out1,F.interpolate(F.relu(self.inte1_2bn(self.intere1_2(tmp))),scale_factor=(4,4),mode ='bicubic')) #CRFB
        
        # 16 * h/2 * w/2
        u1 = out  #skip conn
        # 16 * 2h * 2w
        o1 = out1  #skip conn

        # 16 * h/2 * w/2 -> 32 * h/4 * w/4
        out = F.relu(self.en2_bn(F.max_pool2d(self.encoder2(out),2,2)))
        # 16 * 2h * 2w -> 32 * 4h * 4w
        out1 = F.relu(self.enf2_bn(F.interpolate(self.encoderf2(out1),scale_factor=(2,2),mode ='bicubic')))
        #  32 * h/4 * w/4
        tmp = out
        # 32 * 4h * 4w -> 32 * h/4 *w/4
        out = paddle.add(out,F.interpolate(F.relu(self.inte2_1bn(self.intere2_1(out1))),scale_factor=(0.0625,0.0625),mode ='bicubic'))
        # 32 * h/4 * w/4 -> 32 *4h *4w
        out1 = paddle.add(out1,F.interpolate(F.relu(self.inte2_2bn(self.intere2_2(tmp))),scale_factor=(16,16),mode ='bicubic'))
        
        #  32 * h/4 *w/4
        u2 = out
        #  32 *4h *4w
        o2 = out1
        
        # 32 * h/4 *w/4 -> 64 * h/8 *w/8
        out = F.relu(self.en3_bn(F.max_pool2d(self.encoder3(out),2,2)))
        # 32 *4h *4w -> 64 * 8h *8w
        out1 = F.relu(self.enf3_bn(F.interpolate(self.encoderf3(out1),scale_factor=(2,2),mode ='bicubic')))
        #  64 * h/8 *w/8 
        tmp = out
        #  64 * 8h *8w -> 64 * h/8 * w/8
        out = paddle.add(out,F.interpolate(F.relu(self.inte3_1bn(self.intere3_1(out1))),scale_factor=(0.015625,0.015625),mode ='bicubic'))
        #  64 * h/8 *w/8 -> 64 * 8h * 8w
        out1 = paddle.add(out1,F.interpolate(F.relu(self.inte3_2bn(self.intere3_2(tmp))),scale_factor=(64,64),mode ='bicubic'))
        
        ### End of encoder block

        ### Start Decoder
        
        # 64 * h/8 * w/8 -> 32 * h/4 * w/4 
        out = F.relu(self.de1_bn(F.interpolate(self.decoder1(out),scale_factor=(2,2),mode ='bicubic')))  #U-NET
        # 64 * 8h * 8w -> 32 * 4h * 4w 
        out1 = F.relu(self.def1_bn(F.max_pool2d(self.decoderf1(out1),2,2))) #Ki-NET
        # 32 * h/4 * w/4 
        tmp = out
        # 32 * 4h * 4w  -> 32 * h/4 * w/4 
        out = paddle.add(out,F.interpolate(F.relu(self.intd1_1bn(self.interd1_1(out1))),scale_factor=(0.0625,0.0625),mode ='bicubic'))
        # 32 * h/4 * w/4  -> 32 * 4h * 4w 
        out1 = paddle.add(out1,F.interpolate(F.relu(self.intd1_2bn(self.interd1_2(tmp))),scale_factor=(16,16),mode ='bicubic'))
        
        # 32 * h/4 * w/4 
        out = paddle.add(out,u2)  #skip conn
        # 32 * 4h * 4w 
        out1 = paddle.add(out1,o2)  #skip conn

        # 32 * h/4 * w/4 -> 16 * h/2 * w/2 
        out = F.relu(self.de2_bn(F.interpolate(self.decoder2(out),scale_factor=(2,2),mode ='bicubic')))
        # 32 * 4h * 4w  -> 16 * 2h * 2w
        out1 = F.relu(self.def2_bn(F.max_pool2d(self.decoderf2(out1),2,2)))
        # 16 * h/2 * w/2 
        tmp = out
        # 16 * 2h * 2w -> 16 * h/2 * w/2
        out = paddle.add(out,F.interpolate(F.relu(self.intd2_1bn(self.interd2_1(out1))),scale_factor=(0.25,0.25),mode ='bicubic'))
        # 16 * h/2 * w/2 -> 16 * 2h * 2w
        out1 = paddle.add(out1,F.interpolate(F.relu(self.intd2_2bn(self.interd2_2(tmp))),scale_factor=(4,4),mode ='bicubic'))
        
        # 16 * h/2 * w/2
        out = paddle.add(out,u1)
        # 16 * 2h * 2w
        out1 = paddle.add(out1,o1)

        # 16 * h/2 * w/2 -> 8 * h * w
        out = F.relu(self.de3_bn(F.interpolate(self.decoder3(out),scale_factor=(2,2),mode ='bicubic')))
        # 16 * 2h * 2w -> 8 * h * w
        out1 = F.relu(self.def3_bn(F.max_pool2d(self.decoderf3(out1),2,2)))

        # 8 * h * w
        out = paddle.add(out,out1) # fusion of both branches

        # sigmoid activation on the final layer
        out = F.sigmoid(self.final(out))  #1*1 conv
        
        return out

# Visualize the KiU-Net structure
KIunet = kiunet(in_channels = 3, n_classes =3)

model = paddle.Model(KIunet)
model.summary((2,3, 256, 256))

3. Training Code

  • Training takes a long time, so only two epochs are run for demonstration.

  • The last layer of the network is modified: an AdaptiveAvgPool2D is added so the segmentation network can be used for classification (see the note below).

  • A rich set of data augmentations is applied.
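Because the classification head ends in AdaptiveAvgPool2D(output_size=1), the output shape is independent of the input resolution, which is why the training pipeline below can resize images to 128×128 while the test pipeline uses 256×256.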

model =  kiunet(in_channels = 3, n_classes =3)
# switch the model to training mode
model.train()
# Optimizer: Adam with a StepDecay learning-rate schedule
# (initial lr 0.1, multiplied by 0.1 every 20 epochs; with EPOCH_NUM = 2 the decay never triggers)
scheduler = paddle.optimizer.lr.StepDecay(learning_rate=0.1, step_size=20, gamma=0.1, verbose=False)
optimizer = paddle.optimizer.Adam(learning_rate=scheduler, parameters=model.parameters())

EPOCH_NUM = 2  # number of epochs
BATCH_SIZE = 2  # batch size

from paddle.vision import transforms as T

Transforms_train = T.Compose([
                        T.RandomHorizontalFlip(0.5),  # random horizontal flip
                        T.RandomVerticalFlip(0.5),  # random vertical flip
                        T.RandomRotation(15),  # random rotation
                        T.Resize((128, 128)),  # resize
                        T.ColorJitter(0.2, 0.2, 0.2, 0.2),  # randomly jitter brightness, contrast, saturation and hue
                        T.Transpose(),  # HWC -> CHW
                        T.Normalize(
                            [122.78627145, 122.78627145, 122.78627145], 
                            [56.9899041, 56.9899041, 56.9899041]),  # normalize (dataset mean/std on the 0-255 scale)
])

Transforms_test = T.Compose([
                        T.Resize((256, 256)),  # resize
                        T.Transpose(),  # HWC -> CHW
                        T.Normalize(
                            [122.78627145, 122.78627145, 122.78627145], 
                            [56.9899041, 56.9899041, 56.9899041]),  # normalize

])

train_dataset =  ChestXRayDataset(mode='train',transform = Transforms_train)
test_dataset =   ChestXRayDataset(mode='test',transform = Transforms_test)

# Use paddle.io.DataLoader to batch and load the datasets
# (shuffle=False is fine here: the dataset already shuffles its file list in __init__)
data_loader = paddle.io.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_data_loader = paddle.io.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# BCE loss (the network's final layer is a sigmoid)
loss_BCEloss = paddle.nn.BCELoss()
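Since the final activation is a sigmoid, BCELoss is computed against one-hot labels, treating each of the three class scores as an independent binary target. For single-label three-class classification, a softmax output paired with paddle.nn.CrossEntropyLoss would be the more conventional choice; the BCE formulation is kept here, matching the sigmoid head carried over from the segmentation network.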

# outer loop over epochs
for epoch_id in range(EPOCH_NUM):
    # inner loop over batches
    for iter_id, data in enumerate(data_loader()):
        x, y = data  # x: images, y: labels
        y = paddle.squeeze(y)
        # convert labels to one-hot encoding (three classes)
        one_hot = paddle.nn.functional.one_hot(y, num_classes=3)
        # convert to Paddle tensors
        x = paddle.to_tensor(x, dtype='float32')
        y = paddle.to_tensor(one_hot, dtype='float32')

        # forward pass
        predicts = model(x)
        predicts = paddle.squeeze(predicts)
        # compute the loss
        loss = loss_BCEloss(predicts, y)

        # clear gradients
        optimizer.clear_grad()
        # backpropagation
        loss.backward()
        # update the parameters
        optimizer.step()
    scheduler.step()
    print("epoch: {}, iter: {}, loss is: {}".format(epoch_id+1, iter_id+1, loss.numpy()))

# Save the model parameters to kiunet_model.pdparams
paddle.save(model.state_dict(), 'work/kiunet_model.pdparams')
print("Model saved; parameters stored in kiunet_model.pdparams")

4. Test Code

  • Compute the mean classification accuracy on the test set

import paddle
import numpy as np
from sklearn.metrics import accuracy_score
# Model evaluation
print("Start testing")
# load the previously trained model parameters
para_state_dict = paddle.load('work/kiunet_model.pdparams')
model = kiunet(in_channels = 3, n_classes =3)
model.set_dict(para_state_dict)
model.eval()  # switch to inference mode (affects batch-norm statistics)
accs = []
for iter_id, data in enumerate(test_data_loader()):
    x, y = data  # x: images, y: labels
    # convert to a Paddle tensor
    x = paddle.to_tensor(x, dtype='float32')

    # forward pass
    predicts = model(x)
    predicts = paddle.squeeze(predicts)
    predicts = predicts.cpu().numpy()
    y = y.cpu().numpy()
    predLabel = np.argmax(predicts, 1)

    accs.append(accuracy_score(y, predLabel))
print("模型测试集平均定位误差为:",np.mean(Error))
开始测试
模型测试集平均定位误差为: 0.38782051282051283
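Beyond the single accuracy number, a per-class breakdown is informative. A minimal sketch (reusing model and test_data_loader from above) that accumulates predictions and prints a confusion matrix:

# Confusion matrix over the test set
from sklearn.metrics import confusion_matrix
import numpy as np

all_true, all_pred = [], []
for x, y in test_data_loader():
    x = paddle.to_tensor(x, dtype='float32')
    scores = paddle.squeeze(model(x)).cpu().numpy()
    all_pred.extend(np.argmax(scores, 1).tolist())
    all_true.extend(y.cpu().numpy().reshape(-1).tolist())

# rows = true class, columns = predicted class; order: 0 normal, 1 bacteria, 2 virus
print(confusion_matrix(all_true, all_pred, labels=[0, 1, 2]))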

5. Summary

  • This project implements the KiU-Net algorithm and shows how a U-Net-style segmentation network can be adapted to a classification task.

  • KiU-Net is slow to train and memory-hungry; readers interested in this dataset are encouraged to swap in a dedicated classification network, as the platform already hosts many well-commented classification projects.

  • The project mainly demonstrates how to use this dataset and reproduces the KiU-Net architecture; the code is fully commented and suitable for anyone interested in segmentation and classification.
