Lightweight Medical Video Classification on the MedMNIST Dataset with a Video Vision Transformer

Author: WangXi2016

Date: 2022.11.18

Description: A Transformer-based architecture for video classification.

1. Introduction

A video is a sequence of images. Suppose we have an image classification model (a CNN, ViT, etc.) and a sequence model (an RNN, LSTM, etc.).

To adapt these models for video classification, the simplest approach is to extract the features of each frame with the image model, learn the sequence of frame features with the sequence model, and then classify.

Alternatively, we can build a hybrid Transformer-based model for video classification.

In this example, we minimally implement ViViT: A Video Vision Transformer by Arnab et al., a Transformer-based video classification model.

The authors propose a novel embedding scheme and a number of Transformer variants.

1.1 Downloading the dataset

!wget https://zenodo.org/record/5208230/files/organmnist3d.npz?download=1 -O organmnist3d.npz
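
If wget is unavailable (for example on Windows), the same file can be fetched from Python; a minimal sketch using only the standard library:

import urllib.request

# Download OrganMNIST3D from Zenodo (same URL as the wget command above).
urllib.request.urlretrieve(
    "https://zenodo.org/record/5208230/files/organmnist3d.npz?download=1",
    "organmnist3d.npz",
)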

2. Importing the libraries

import os
import io
import imageio
import ipywidgets
import numpy as np
import cv2
from glob import glob
import matplotlib.pyplot as plt
from PIL import Image as PilImage
import paddle
import paddle.nn as nn
from paddle.io import Dataset
paddle.__version__
'2.4.0-rc0'

3. Dataset class definition

PaddlePaddle's unified data loading scheme combines Dataset (dataset definition) with DataLoader (multi-process data loading).

First we define the dataset: implement a new Dataset class that subclasses paddle.io.Dataset and implements its two abstract methods, __getitem__ and __len__:

class MyDataset(Dataset):
    def __init__(self):
        super().__init__()
        self.samples = []  # load or index the data here

    # Return one sample and its label for a given index
    def __getitem__(self, idx):
        x, y = self.samples[idx]
        return x, y

    # Return the total number of samples
    def __len__(self):
        return len(self.samples)

Inside the dataset class, you can use the image preprocessing APIs to preprocess the images (resizing, flipping, format conversion, and so on), as sketched below.
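
As a hedged illustration (our 3D volumes below are only scaled to [0, 1] instead), preprocessing for a 2D image dataset might be composed with paddle.vision.transforms like this:

import paddle.vision.transforms as T

# Hypothetical 2D pipeline: resize, random horizontal flip, HWC -> CHW layout.
transform = T.Compose([
    T.Resize((32, 32)),
    T.RandomHorizontalFlip(),
    T.Transpose(),
])
# Inside __getitem__: image = transform(image)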

For our example, we use MedMNIST v2: A Large-Scale Lightweight Benchmark for 2D and 3D Biomedical Image Classification. These video datasets are lightweight and easy to train on.

class MNDataset(Dataset):

    def __init__(self, mode='train', path="./organmnist3d.npz"):

        assert mode in ['train', 'val', 'test'], \
            "mode should be 'train', 'val' or 'test', but got {}".format(mode)

        splits = self.prepare_dataset(path)

        if mode == 'train':
            self.data = splits[0]
        elif mode == 'val':
            self.data = splits[1]
        else:
            self.data = splits[2]

        self.num_samples = len(self.data[0])

    def prepare_dataset(self,data_path):

        with np.load(data_path) as data:
            # Get videos
            train_videos = data["train_images"].astype('float32') / 255.0
            valid_videos = data["val_images"].astype('float32') / 255.0
            test_videos = data["test_images"].astype('float32') / 255.0

            # Get labels
            train_labels = data["train_labels"].flatten()
            valid_labels = data["val_labels"].flatten()
            test_labels = data["test_labels"].flatten()

        return (
            (train_videos, train_labels),
            (valid_videos, valid_labels),
            (test_videos, test_labels),
        )
    

    def __getitem__(self, idx):
        # Add a channel dimension: (28, 28, 28) -> (1, 28, 28, 28)
        image = np.expand_dims(self.data[0][idx], axis=0)
        label = np.array([self.data[1][idx]], dtype="int64")
        return image, label

    def __len__(self):
        return self.num_samples
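
As an optional sanity check (assuming organmnist3d.npz was downloaded as above), we can instantiate the dataset and wrap it in the DataLoader mentioned earlier:

# Each sample is a (1, 28, 28, 28) volume with a (1,) integer label.
ds = MNDataset(mode='train')
video, label = ds[0]
print(video.shape, label.shape)  # (1, 28, 28, 28) (1,)

# DataLoader takes care of batching and multi-process loading.
loader = paddle.io.DataLoader(ds, batch_size=32, shuffle=True)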

4. Model construction

4.1 TubeletEmbedding

In ViT, an image is divided into patches, which are then spatially flattened; this process is called tokenization.
(Figure: image tokenization in ViT.)

For video, the TubeletEmbedding first uses a Conv3D to extract features and temporal information from the video, and then flattens the resulting volumes to build the video tokens.
(Figure: tubelet embedding.)

class TubeletEmbedding(nn.Layer):
    def __init__(self, in_channels, embed_dim, patch_size):
        super(TubeletEmbedding, self).__init__()
        # A Conv3D with kernel_size == stride extracts non-overlapping tubelets
        # and projects each one to an embed_dim-dimensional vector.
        self.projection = nn.Conv3D(in_channels=in_channels,
                                    out_channels=embed_dim,
                                    kernel_size=patch_size,
                                    stride=patch_size,
                                    padding="VALID")
        self.embed_dim = embed_dim

    def forward(self, videos):
        projected_patches = self.projection(videos)  # (N, embed_dim, T', H', W')
        # Flatten the spatio-temporal grid and move channels last, so that each
        # token is one tubelet's embedding: (N, T'*H'*W', embed_dim).
        flattened_patches = projected_patches.flatten(start_axis=2).transpose((0, 2, 1))
        return flattened_patches
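
For our 28x28x28 single-channel volumes, an 8x8x8 tubelet with stride 8 and VALID padding yields 28 // 8 = 3 patches per axis, i.e. 3**3 = 27 tokens. A quick shape check (a sketch, assuming the class above):

te = TubeletEmbedding(in_channels=1, embed_dim=128, patch_size=(8, 8, 8))
dummy = paddle.randn([2, 1, 28, 28, 28])  # (batch, channels, depth, height, width)
print(te(dummy).shape)  # [2, 27, 128]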

4.2 Positional Encoder

This layer adds positional information to the encoded video tokens.

class PositionalEncoder(nn.Layer):
    def __init__(self, embed_dim, num_tokens):
        super(PositionalEncoder, self).__init__()
        self.embed_dim = embed_dim
        # One learnable embedding per token position.
        self.position_embedding = nn.Embedding(num_embeddings=num_tokens,
                                               embedding_dim=self.embed_dim)
        self.positions = paddle.arange(start=0, end=num_tokens, step=1)

    def forward(self, encoded_tokens):
        # The (num_tokens, embed_dim) position embeddings broadcast over the batch.
        encoded_positions = self.position_embedding(self.positions)
        encoded_tokens = encoded_tokens + encoded_positions
        return encoded_tokens
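
Because the position embeddings broadcast over the batch dimension, the output shape matches the input; a quick check under the same assumptions:

pe = PositionalEncoder(embed_dim=128, num_tokens=27)
tokens = paddle.zeros([2, 27, 128])
print(pe(tokens).shape)  # [2, 27, 128]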

4.3 Video Vision Transformer

The authors propose four variants of the Vision Transformer for video:

  • Spatio-temporal attention
  • Factorized encoder
  • Factorized self-attention
  • Factorized dot-product attention

In this example, we implement the spatio-temporal attention model for simplicity.

class TransfromerLayer(nn.Layer):
    def __init__(self, input_shape, embed_dim, num_heads):
        super(TransfromerLayer, self).__init__()
        self.ln1 = nn.LayerNorm(normalized_shape=input_shape, epsilon=1e-6)
        dim = input_shape[1]
        self.mha = nn.MultiHeadAttention(dim, num_heads, kdim=embed_dim, dropout=0.1)
        self.ln2 = nn.LayerNorm(normalized_shape=input_shape, epsilon=1e-6)
        self.gelu = nn.GELU()
        self.line1 = nn.Linear(dim, dim * 4)
        self.line2 = nn.Linear(dim * 4, dim)

    def forward(self, inputs):
        # Pre-norm multi-head self-attention with a residual connection.
        x = self.ln1(inputs)
        x = self.mha(x, x)
        x1 = x + inputs

        # Pre-norm MLP block with a residual connection (note that GELU is
        # applied after both linear layers here).
        x2 = self.ln2(x1)
        x2 = self.gelu(self.line1(x2))
        x2 = self.gelu(self.line2(x2))
        return x2 + x1


class VideoVisionTransformer(nn.Layer):
    def __init__(self, num_layers, num_class):
        super(VideoVisionTransformer, self).__init__()
        # 28x28x28 volumes with 8x8x8 tubelets -> 3*3*3 = 27 tokens of dim 128.
        self.te = TubeletEmbedding(1, 128, (8, 8, 8))
        self.pe = PositionalEncoder(128, 27)
        self.tls = nn.Sequential(*[(str(i), TransfromerLayer((27, 128), 128, 8))
                                   for i in range(num_layers)])
        self.ln = nn.LayerNorm(normalized_shape=(27, 128), epsilon=1e-6)
        self.line = nn.Linear(128, num_class)

    def forward(self, inputs):
        x = self.te(inputs)   # tokenize: (N, 27, 128)
        x = self.pe(x)        # add positional embeddings
        x = self.tls(x)       # stacked Transformer layers
        x = self.ln(x)
        x = x.mean(axis=1)    # average pool over tokens
        x = self.line(x)      # classification head
        return x
vivi = VideoVisionTransformer(6,20)
model = paddle.Model(vivi)

4.4 Model visualization

model.summary((1,1,28,28,28))
------------------------------------------------------------------------------------------
     Layer (type)              Input Shape              Output Shape         Param #    
==========================================================================================
      Conv3D-22            [[1, 1, 28, 28, 28]]      [1, 128, 3, 3, 3]       65,664     
 TubeletEmbedding-22       [[1, 1, 28, 28, 28]]         [1, 27, 128]            0       
     Embedding-22                 [[27]]                 [27, 128]            3,456     
 PositionalEncoder-22         [[1, 27, 128]]            [1, 27, 128]            0       
    LayerNorm-322             [[1, 27, 128]]            [1, 27, 128]          6,912     
      Linear-922              [[1, 27, 128]]            [1, 27, 128]         16,512     
      Linear-923              [[1, 27, 128]]            [1, 27, 128]         16,512     
      Linear-924              [[1, 27, 128]]            [1, 27, 128]         16,512     
      Linear-925              [[1, 27, 128]]            [1, 27, 128]         16,512     
MultiHeadAttention-151 [[1, 27, 128], [1, 27, 128]]     [1, 27, 128]            0       
    LayerNorm-323             [[1, 27, 128]]            [1, 27, 128]          6,912     
      Linear-926              [[1, 27, 128]]            [1, 27, 512]         66,048     
       GELU-151               [[1, 27, 128]]            [1, 27, 128]            0       
      Linear-927              [[1, 27, 512]]            [1, 27, 128]         65,664     
 TransfromerLayer-151         [[1, 27, 128]]            [1, 27, 128]            0       
    LayerNorm-324             [[1, 27, 128]]            [1, 27, 128]          6,912     
      Linear-928              [[1, 27, 128]]            [1, 27, 128]         16,512     
      Linear-929              [[1, 27, 128]]            [1, 27, 128]         16,512     
      Linear-930              [[1, 27, 128]]            [1, 27, 128]         16,512     
      Linear-931              [[1, 27, 128]]            [1, 27, 128]         16,512     
MultiHeadAttention-152 [[1, 27, 128], [1, 27, 128]]     [1, 27, 128]            0       
    LayerNorm-325             [[1, 27, 128]]            [1, 27, 128]          6,912     
      Linear-932              [[1, 27, 128]]            [1, 27, 512]         66,048     
       GELU-152               [[1, 27, 128]]            [1, 27, 128]            0       
      Linear-933              [[1, 27, 512]]            [1, 27, 128]         65,664     
 TransfromerLayer-152         [[1, 27, 128]]            [1, 27, 128]            0       
    LayerNorm-326             [[1, 27, 128]]            [1, 27, 128]          6,912     
      Linear-934              [[1, 27, 128]]            [1, 27, 128]         16,512     
      Linear-935              [[1, 27, 128]]            [1, 27, 128]         16,512     
      Linear-936              [[1, 27, 128]]            [1, 27, 128]         16,512     
      Linear-937              [[1, 27, 128]]            [1, 27, 128]         16,512     
MultiHeadAttention-153 [[1, 27, 128], [1, 27, 128]]     [1, 27, 128]            0       
    LayerNorm-327             [[1, 27, 128]]            [1, 27, 128]          6,912     
      Linear-938              [[1, 27, 128]]            [1, 27, 512]         66,048     
       GELU-153               [[1, 27, 128]]            [1, 27, 128]            0       
      Linear-939              [[1, 27, 512]]            [1, 27, 128]         65,664     
 TransfromerLayer-153         [[1, 27, 128]]            [1, 27, 128]            0       
    LayerNorm-328             [[1, 27, 128]]            [1, 27, 128]          6,912     
      Linear-940              [[1, 27, 128]]            [1, 27, 128]         16,512     
      Linear-941              [[1, 27, 128]]            [1, 27, 128]         16,512     
      Linear-942              [[1, 27, 128]]            [1, 27, 128]         16,512     
      Linear-943              [[1, 27, 128]]            [1, 27, 128]         16,512     
MultiHeadAttention-154 [[1, 27, 128], [1, 27, 128]]     [1, 27, 128]            0       
    LayerNorm-329             [[1, 27, 128]]            [1, 27, 128]          6,912     
      Linear-944              [[1, 27, 128]]            [1, 27, 512]         66,048     
       GELU-154               [[1, 27, 128]]            [1, 27, 128]            0       
      Linear-945              [[1, 27, 512]]            [1, 27, 128]         65,664     
 TransfromerLayer-154         [[1, 27, 128]]            [1, 27, 128]            0       
    LayerNorm-330             [[1, 27, 128]]            [1, 27, 128]          6,912     
      Linear-946              [[1, 27, 128]]            [1, 27, 128]         16,512     
      Linear-947              [[1, 27, 128]]            [1, 27, 128]         16,512     
      Linear-948              [[1, 27, 128]]            [1, 27, 128]         16,512     
      Linear-949              [[1, 27, 128]]            [1, 27, 128]         16,512     
MultiHeadAttention-155 [[1, 27, 128], [1, 27, 128]]     [1, 27, 128]            0       
    LayerNorm-331             [[1, 27, 128]]            [1, 27, 128]          6,912     
      Linear-950              [[1, 27, 128]]            [1, 27, 512]         66,048     
       GELU-155               [[1, 27, 128]]            [1, 27, 128]            0       
      Linear-951              [[1, 27, 512]]            [1, 27, 128]         65,664     
 TransfromerLayer-155         [[1, 27, 128]]            [1, 27, 128]            0       
    LayerNorm-332             [[1, 27, 128]]            [1, 27, 128]          6,912     
      Linear-952              [[1, 27, 128]]            [1, 27, 128]         16,512     
      Linear-953              [[1, 27, 128]]            [1, 27, 128]         16,512     
      Linear-954              [[1, 27, 128]]            [1, 27, 128]         16,512     
      Linear-955              [[1, 27, 128]]            [1, 27, 128]         16,512     
MultiHeadAttention-156 [[1, 27, 128], [1, 27, 128]]     [1, 27, 128]            0       
    LayerNorm-333             [[1, 27, 128]]            [1, 27, 128]          6,912     
      Linear-956              [[1, 27, 128]]            [1, 27, 512]         66,048     
       GELU-156               [[1, 27, 128]]            [1, 27, 128]            0       
      Linear-957              [[1, 27, 512]]            [1, 27, 128]         65,664     
 TransfromerLayer-156         [[1, 27, 128]]            [1, 27, 128]            0       
    LayerNorm-334             [[1, 27, 128]]            [1, 27, 128]          6,912     
      Linear-958                [[1, 128]]                [1, 20]             2,580     
==========================================================================================
Total params: 1,348,116
Trainable params: 1,348,116
Non-trainable params: 0
------------------------------------------------------------------------------------------
Input size (MB): 0.08
Forward/backward pass size (MB): 2.35
Params size (MB): 5.14
Estimated Total Size (MB): 7.57
------------------------------------------------------------------------------------------

{'total_params': 1348116, 'trainable_params': 1348116}

5. Model training

train_dataset = MNDataset()
val_dataset = MNDataset(mode='val')
test_dataset = MNDataset(mode='test')
# A custom callback must subclass the base class paddle.callbacks.Callback
class LossCallback(paddle.callbacks.Callback):

    def __init__(self):
        self.trainlosses = []
        self.trainaccs = []
        self.vallosses = []
        self.valacc = []
        self.maxacc = 0

    def on_train_begin(self, logs={}):
        # Reset the histories before fit() starts.
        self.trainlosses = []
        self.trainaccs = []
        self.vallosses = []
        self.valacc = []
        self.maxacc = 0

    def on_epoch_end(self, step, logs={}):
        # Called at the end of each training epoch; record loss and accuracy.
        self.trainlosses.append(logs.get('loss')[0])
        self.trainaccs.append(logs.get('acc'))

    def on_train_end(self, logs={}):
        self.model.save("finishmodel")

    def on_eval_end(self, logs={}):
        # Record validation metrics and checkpoint the best model so far.
        self.vallosses.append(logs.get('loss')[0])
        self.valacc.append(logs.get('acc'))
        if self.maxacc < logs.get('acc'):
            self.maxacc = logs.get('acc')
            self.model.save("bestmodel")
            print(f'\n{self.maxacc:.3f} Save the best model.\n')

tranloss_log = LossCallback()

model.prepare(optimizer=paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters()),
              loss=paddle.nn.CrossEntropyLoss(),
              metrics=[paddle.metric.Accuracy(), paddle.metric.Accuracy(topk=(5,))])

visualdl = paddle.callbacks.VisualDL(log_dir='visual_log')  # enable training visualization with VisualDL


model.fit(
    train_data=train_dataset, 
    eval_data=val_dataset, 
    batch_size=200, 
    epochs=80, 
    verbose=1,
    shuffle=True,
    callbacks=[tranloss_log,visualdl] 
)
The loss value printed in the log is the current step, and the metric is the average value of previous steps.
Epoch 1/80
step 5/5 [==============================] - loss: 2.5846 - acc: 0.1080 - acc_top5: 0.5051 - 39ms/step
Eval begin...
step 1/1 [==============================] - loss: 2.2706 - acc: 0.2857 - acc_top5: 0.6770 - 27ms/step
Eval samples: 161

0.286 Save the best model.

Epoch 2/80
step 5/5 [==============================] - loss: 1.9559 - acc: 0.2726 - acc_top5: 0.7685 - 33ms/step
Eval begin...
step 1/1 [==============================] - loss: 1.6688 - acc: 0.4783 - acc_top5: 0.8944 - 27ms/step
Eval samples: 161

0.478 Save the best model.

Epoch 3/80
step 5/5 [==============================] - loss: 1.6695 - acc: 0.4311 - acc_top5: 0.8961 - 35ms/step
Eval begin...
step 1/1 [==============================] - loss: 1.3570 - acc: 0.6273 - acc_top5: 0.9379 - 27ms/step
Eval samples: 161

0.627 Save the best model.

Epoch 4/80
step 5/5 [==============================] - loss: 1.3275 - acc: 0.5113 - acc_top5: 0.9352 - 34ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.9739 - acc: 0.7578 - acc_top5: 0.9876 - 29ms/step
Eval samples: 161

0.758 Save the best model.

Epoch 5/80
step 5/5 [==============================] - loss: 1.2761 - acc: 0.6039 - acc_top5: 0.9486 - 37ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.8973 - acc: 0.7019 - acc_top5: 0.9814 - 30ms/step
Eval samples: 161
Epoch 6/80
step 5/5 [==============================] - loss: 1.0162 - acc: 0.5772 - acc_top5: 0.9619 - 36ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.6353 - acc: 0.7826 - acc_top5: 0.9938 - 54ms/step
Eval samples: 161

0.783 Save the best model.

Epoch 7/80
step 5/5 [==============================] - loss: 0.8544 - acc: 0.6944 - acc_top5: 0.9660 - 36ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.4340 - acc: 0.8944 - acc_top5: 0.9938 - 29ms/step
Eval samples: 161

0.894 Save the best model.

Epoch 8/80
step 5/5 [==============================] - loss: 0.8280 - acc: 0.7397 - acc_top5: 0.9702 - 34ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.4033 - acc: 0.8820 - acc_top5: 0.9938 - 27ms/step
Eval samples: 161
Epoch 9/80
step 5/5 [==============================] - loss: 0.7864 - acc: 0.7315 - acc_top5: 0.9763 - 34ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.3294 - acc: 0.9006 - acc_top5: 1.0000 - 29ms/step
Eval samples: 161

0.901 Save the best model.

Epoch 10/80
step 5/5 [==============================] - loss: 0.7972 - acc: 0.7613 - acc_top5: 0.9805 - 34ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.3145 - acc: 0.9130 - acc_top5: 0.9938 - 28ms/step
Eval samples: 161

0.913 Save the best model.

Epoch 11/80
step 5/5 [==============================] - loss: 0.6125 - acc: 0.7747 - acc_top5: 0.9887 - 38ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.3151 - acc: 0.8944 - acc_top5: 1.0000 - 34ms/step
Eval samples: 161
Epoch 12/80
step 5/5 [==============================] - loss: 0.8789 - acc: 0.7726 - acc_top5: 0.9856 - 50ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.2463 - acc: 0.9255 - acc_top5: 1.0000 - 42ms/step
Eval samples: 161

0.925 Save the best model.

Epoch 13/80
step 5/5 [==============================] - loss: 0.5250 - acc: 0.8097 - acc_top5: 0.9907 - 38ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.4652 - acc: 0.8447 - acc_top5: 0.9814 - 30ms/step
Eval samples: 161
Epoch 14/80
step 5/5 [==============================] - loss: 0.4603 - acc: 0.8220 - acc_top5: 0.9907 - 34ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.1978 - acc: 0.9255 - acc_top5: 1.0000 - 28ms/step
Eval samples: 161
Epoch 15/80
step 5/5 [==============================] - loss: 0.4152 - acc: 0.8529 - acc_top5: 0.9938 - 35ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.2073 - acc: 0.9317 - acc_top5: 1.0000 - 27ms/step
Eval samples: 161

0.932 Save the best model.

Epoch 16/80
step 5/5 [==============================] - loss: 0.3363 - acc: 0.8724 - acc_top5: 0.9949 - 35ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.2929 - acc: 0.9193 - acc_top5: 0.9938 - 30ms/step
Eval samples: 161
Epoch 17/80
step 5/5 [==============================] - loss: 0.3467 - acc: 0.8745 - acc_top5: 0.9959 - 35ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.2010 - acc: 0.9379 - acc_top5: 0.9938 - 27ms/step
Eval samples: 161

0.938 Save the best model.

Epoch 18/80
step 5/5 [==============================] - loss: 0.2109 - acc: 0.9012 - acc_top5: 0.9949 - 36ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.2695 - acc: 0.9130 - acc_top5: 0.9938 - 29ms/step
Eval samples: 161
Epoch 19/80
step 5/5 [==============================] - loss: 0.3469 - acc: 0.9002 - acc_top5: 1.0000 - 32ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.2121 - acc: 0.9379 - acc_top5: 1.0000 - 28ms/step
Eval samples: 161
Epoch 20/80
step 5/5 [==============================] - loss: 0.2835 - acc: 0.9126 - acc_top5: 0.9969 - 34ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.2374 - acc: 0.9255 - acc_top5: 1.0000 - 29ms/step
Eval samples: 161
Epoch 21/80
step 5/5 [==============================] - loss: 0.2632 - acc: 0.9321 - acc_top5: 1.0000 - 34ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.2033 - acc: 0.9503 - acc_top5: 0.9938 - 27ms/step
Eval samples: 161

0.950 Save the best model.

Epoch 22/80
step 5/5 [==============================] - loss: 0.1645 - acc: 0.9383 - acc_top5: 1.0000 - 40ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.1987 - acc: 0.9379 - acc_top5: 1.0000 - 31ms/step
Eval samples: 161
Epoch 23/80
step 5/5 [==============================] - loss: 0.0996 - acc: 0.9455 - acc_top5: 1.0000 - 35ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.2222 - acc: 0.9317 - acc_top5: 1.0000 - 28ms/step
Eval samples: 161
Epoch 24/80
step 5/5 [==============================] - loss: 0.2608 - acc: 0.9527 - acc_top5: 1.0000 - 34ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.1894 - acc: 0.9379 - acc_top5: 1.0000 - 31ms/step
Eval samples: 161
Epoch 25/80
step 5/5 [==============================] - loss: 0.0597 - acc: 0.9578 - acc_top5: 1.0000 - 35ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.2582 - acc: 0.9068 - acc_top5: 1.0000 - 30ms/step
Eval samples: 161
Epoch 26/80
step 5/5 [==============================] - loss: 0.1500 - acc: 0.9588 - acc_top5: 0.9990 - 40ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.2544 - acc: 0.9379 - acc_top5: 0.9876 - 27ms/step
Eval samples: 161
Epoch 27/80
step 5/5 [==============================] - loss: 0.1003 - acc: 0.9660 - acc_top5: 1.0000 - 36ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.2151 - acc: 0.9503 - acc_top5: 0.9938 - 28ms/step
Eval samples: 161
Epoch 28/80
step 5/5 [==============================] - loss: 0.0769 - acc: 0.9784 - acc_top5: 0.9990 - 35ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.3080 - acc: 0.9193 - acc_top5: 0.9938 - 27ms/step
Eval samples: 161
Epoch 29/80
step 5/5 [==============================] - loss: 0.0770 - acc: 0.9753 - acc_top5: 0.9990 - 36ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.2331 - acc: 0.9317 - acc_top5: 0.9938 - 29ms/step
Eval samples: 161
Epoch 30/80
step 5/5 [==============================] - loss: 0.0557 - acc: 0.9856 - acc_top5: 1.0000 - 37ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.2108 - acc: 0.9317 - acc_top5: 1.0000 - 27ms/step
Eval samples: 161
Epoch 31/80
step 5/5 [==============================] - loss: 0.1013 - acc: 0.9815 - acc_top5: 1.0000 - 35ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.2438 - acc: 0.9193 - acc_top5: 0.9938 - 28ms/step
Eval samples: 161
Epoch 32/80
step 5/5 [==============================] - loss: 0.0235 - acc: 0.9918 - acc_top5: 1.0000 - 35ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.2555 - acc: 0.9317 - acc_top5: 0.9938 - 29ms/step
Eval samples: 161
Epoch 33/80
step 5/5 [==============================] - loss: 0.0331 - acc: 0.9949 - acc_top5: 1.0000 - 35ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.1961 - acc: 0.9503 - acc_top5: 1.0000 - 29ms/step
Eval samples: 161
Epoch 34/80
step 5/5 [==============================] - loss: 0.0210 - acc: 0.9928 - acc_top5: 1.0000 - 35ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.1811 - acc: 0.9255 - acc_top5: 1.0000 - 29ms/step
Eval samples: 161
Epoch 35/80
step 5/5 [==============================] - loss: 0.0258 - acc: 0.9949 - acc_top5: 1.0000 - 36ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.2118 - acc: 0.9441 - acc_top5: 0.9938 - 30ms/step
Eval samples: 161
Epoch 36/80
step 5/5 [==============================] - loss: 0.0217 - acc: 0.9938 - acc_top5: 1.0000 - 37ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.2487 - acc: 0.9379 - acc_top5: 1.0000 - 30ms/step
Eval samples: 161
Epoch 37/80
step 5/5 [==============================] - loss: 0.0141 - acc: 0.9979 - acc_top5: 1.0000 - 35ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.2679 - acc: 0.9255 - acc_top5: 1.0000 - 28ms/step
Eval samples: 161
Epoch 38/80
step 5/5 [==============================] - loss: 0.0203 - acc: 0.9969 - acc_top5: 1.0000 - 35ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.2329 - acc: 0.9379 - acc_top5: 1.0000 - 28ms/step
Eval samples: 161
Epoch 39/80
step 5/5 [==============================] - loss: 0.0105 - acc: 0.9979 - acc_top5: 1.0000 - 34ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.2481 - acc: 0.9379 - acc_top5: 0.9938 - 29ms/step
Eval samples: 161
Epoch 40/80
step 5/5 [==============================] - loss: 0.0101 - acc: 0.9990 - acc_top5: 1.0000 - 35ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.2860 - acc: 0.9441 - acc_top5: 1.0000 - 28ms/step
Eval samples: 161
Epoch 41/80
step 5/5 [==============================] - loss: 0.0094 - acc: 1.0000 - acc_top5: 1.0000 - 34ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.2448 - acc: 0.9441 - acc_top5: 0.9938 - 27ms/step
Eval samples: 161
Epoch 42/80
step 5/5 [==============================] - loss: 0.0137 - acc: 0.9979 - acc_top5: 1.0000 - 38ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.3240 - acc: 0.9255 - acc_top5: 1.0000 - 28ms/step
Eval samples: 161
Epoch 43/80
step 5/5 [==============================] - loss: 0.0432 - acc: 0.9918 - acc_top5: 1.0000 - 39ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.2774 - acc: 0.9317 - acc_top5: 1.0000 - 31ms/step
Eval samples: 161
Epoch 44/80
step 5/5 [==============================] - loss: 0.0241 - acc: 0.9928 - acc_top5: 1.0000 - 34ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.2627 - acc: 0.9255 - acc_top5: 0.9938 - 28ms/step
Eval samples: 161
Epoch 45/80
step 5/5 [==============================] - loss: 0.1323 - acc: 0.9815 - acc_top5: 1.0000 - 38ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.2399 - acc: 0.9130 - acc_top5: 1.0000 - 27ms/step
Eval samples: 161
Epoch 46/80
step 5/5 [==============================] - loss: 0.0204 - acc: 0.9846 - acc_top5: 1.0000 - 36ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.2870 - acc: 0.9068 - acc_top5: 0.9938 - 27ms/step
Eval samples: 161
Epoch 47/80
step 5/5 [==============================] - loss: 0.0233 - acc: 0.9835 - acc_top5: 1.0000 - 32ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.5399 - acc: 0.8634 - acc_top5: 0.9938 - 25ms/step
Eval samples: 161
Epoch 48/80
step 5/5 [==============================] - loss: 0.0370 - acc: 0.9681 - acc_top5: 1.0000 - 32ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.2843 - acc: 0.9441 - acc_top5: 0.9938 - 27ms/step
Eval samples: 161
Epoch 49/80
step 5/5 [==============================] - loss: 0.0321 - acc: 0.9763 - acc_top5: 1.0000 - 35ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.2182 - acc: 0.9317 - acc_top5: 1.0000 - 30ms/step
Eval samples: 161
Epoch 50/80
step 5/5 [==============================] - loss: 0.0373 - acc: 0.9866 - acc_top5: 1.0000 - 36ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.2274 - acc: 0.9379 - acc_top5: 0.9876 - 29ms/step
Eval samples: 161
Epoch 51/80
step 5/5 [==============================] - loss: 0.0435 - acc: 0.9897 - acc_top5: 1.0000 - 36ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.2601 - acc: 0.9255 - acc_top5: 1.0000 - 28ms/step
Eval samples: 161
Epoch 52/80
step 5/5 [==============================] - loss: 0.0785 - acc: 0.9877 - acc_top5: 1.0000 - 34ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.2628 - acc: 0.9255 - acc_top5: 1.0000 - 28ms/step
Eval samples: 161
Epoch 53/80
step 5/5 [==============================] - loss: 0.0455 - acc: 0.9897 - acc_top5: 1.0000 - 33ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.2311 - acc: 0.9317 - acc_top5: 1.0000 - 28ms/step
Eval samples: 161
Epoch 54/80
step 5/5 [==============================] - loss: 0.0383 - acc: 0.9949 - acc_top5: 0.9990 - 35ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.2259 - acc: 0.9317 - acc_top5: 1.0000 - 30ms/step
Eval samples: 161
Epoch 55/80
step 5/5 [==============================] - loss: 0.0079 - acc: 0.9959 - acc_top5: 1.0000 - 36ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.1935 - acc: 0.9441 - acc_top5: 1.0000 - 30ms/step
Eval samples: 161
Epoch 56/80
step 5/5 [==============================] - loss: 0.0055 - acc: 0.9969 - acc_top5: 1.0000 - 35ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.2084 - acc: 0.9379 - acc_top5: 1.0000 - 30ms/step
Eval samples: 161
Epoch 57/80
step 5/5 [==============================] - loss: 0.0042 - acc: 0.9979 - acc_top5: 1.0000 - 35ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.2144 - acc: 0.9565 - acc_top5: 1.0000 - 31ms/step
Eval samples: 161

0.957 Save the best model.

Epoch 58/80
step 5/5 [==============================] - loss: 0.0037 - acc: 1.0000 - acc_top5: 1.0000 - 36ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.1776 - acc: 0.9565 - acc_top5: 1.0000 - 29ms/step
Eval samples: 161
Epoch 59/80
step 5/5 [==============================] - loss: 0.0067 - acc: 1.0000 - acc_top5: 1.0000 - 34ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.1925 - acc: 0.9565 - acc_top5: 1.0000 - 27ms/step
Eval samples: 161
Epoch 60/80
step 5/5 [==============================] - loss: 0.0067 - acc: 1.0000 - acc_top5: 1.0000 - 34ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.2061 - acc: 0.9441 - acc_top5: 1.0000 - 27ms/step
Eval samples: 161
Epoch 61/80
step 5/5 [==============================] - loss: 0.0086 - acc: 1.0000 - acc_top5: 1.0000 - 34ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.2062 - acc: 0.9441 - acc_top5: 1.0000 - 28ms/step
Eval samples: 161
Epoch 62/80
step 5/5 [==============================] - loss: 0.0011 - acc: 1.0000 - acc_top5: 1.0000 - 34ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.2250 - acc: 0.9379 - acc_top5: 1.0000 - 27ms/step
Eval samples: 161
Epoch 63/80
step 5/5 [==============================] - loss: 0.0030 - acc: 1.0000 - acc_top5: 1.0000 - 41ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.2298 - acc: 0.9503 - acc_top5: 1.0000 - 28ms/step
Eval samples: 161
Epoch 64/80
step 5/5 [==============================] - loss: 0.0013 - acc: 1.0000 - acc_top5: 1.0000 - 34ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.2069 - acc: 0.9379 - acc_top5: 1.0000 - 29ms/step
Eval samples: 161
Epoch 65/80
step 5/5 [==============================] - loss: 0.0012 - acc: 1.0000 - acc_top5: 1.0000 - 35ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.2021 - acc: 0.9503 - acc_top5: 1.0000 - 29ms/step
Eval samples: 161
Epoch 66/80
step 5/5 [==============================] - loss: 9.8032e-04 - acc: 1.0000 - acc_top5: 1.0000 - 35ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.2066 - acc: 0.9441 - acc_top5: 1.0000 - 32ms/step
Eval samples: 161
Epoch 67/80
step 5/5 [==============================] - loss: 7.5654e-04 - acc: 1.0000 - acc_top5: 1.0000 - 34ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.2112 - acc: 0.9503 - acc_top5: 1.0000 - 27ms/step
Eval samples: 161
Epoch 68/80
step 5/5 [==============================] - loss: 8.3809e-04 - acc: 1.0000 - acc_top5: 1.0000 - 35ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.2110 - acc: 0.9503 - acc_top5: 1.0000 - 29ms/step
Eval samples: 161
Epoch 69/80
step 5/5 [==============================] - loss: 6.3972e-04 - acc: 1.0000 - acc_top5: 1.0000 - 34ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.2074 - acc: 0.9503 - acc_top5: 1.0000 - 27ms/step
Eval samples: 161
Epoch 70/80
step 5/5 [==============================] - loss: 8.0242e-04 - acc: 1.0000 - acc_top5: 1.0000 - 34ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.2055 - acc: 0.9503 - acc_top5: 1.0000 - 28ms/step
Eval samples: 161
Epoch 71/80
step 5/5 [==============================] - loss: 7.9931e-04 - acc: 1.0000 - acc_top5: 1.0000 - 34ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.2048 - acc: 0.9503 - acc_top5: 1.0000 - 28ms/step
Eval samples: 161
Epoch 72/80
step 5/5 [==============================] - loss: 7.1046e-04 - acc: 1.0000 - acc_top5: 1.0000 - 34ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.2047 - acc: 0.9503 - acc_top5: 1.0000 - 28ms/step
Eval samples: 161
Epoch 73/80
step 5/5 [==============================] - loss: 6.8809e-04 - acc: 1.0000 - acc_top5: 1.0000 - 35ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.2044 - acc: 0.9503 - acc_top5: 1.0000 - 31ms/step
Eval samples: 161
Epoch 74/80
step 5/5 [==============================] - loss: 6.0546e-04 - acc: 1.0000 - acc_top5: 1.0000 - 36ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.2057 - acc: 0.9503 - acc_top5: 1.0000 - 31ms/step
Eval samples: 161
Epoch 75/80
step 5/5 [==============================] - loss: 7.0117e-04 - acc: 1.0000 - acc_top5: 1.0000 - 36ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.2061 - acc: 0.9503 - acc_top5: 1.0000 - 29ms/step
Eval samples: 161
Epoch 76/80
step 5/5 [==============================] - loss: 7.0604e-04 - acc: 1.0000 - acc_top5: 1.0000 - 38ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.2058 - acc: 0.9503 - acc_top5: 1.0000 - 29ms/step
Eval samples: 161
Epoch 77/80
step 5/5 [==============================] - loss: 5.0498e-04 - acc: 1.0000 - acc_top5: 1.0000 - 35ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.2064 - acc: 0.9503 - acc_top5: 1.0000 - 29ms/step
Eval samples: 161
Epoch 78/80
step 5/5 [==============================] - loss: 5.5190e-04 - acc: 1.0000 - acc_top5: 1.0000 - 37ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.2070 - acc: 0.9503 - acc_top5: 1.0000 - 34ms/step
Eval samples: 161
Epoch 79/80
step 5/5 [==============================] - loss: 5.5681e-04 - acc: 1.0000 - acc_top5: 1.0000 - 38ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.2084 - acc: 0.9503 - acc_top5: 1.0000 - 28ms/step
Eval samples: 161
Epoch 80/80
step 5/5 [==============================] - loss: 9.9728e-04 - acc: 1.0000 - acc_top5: 1.0000 - 35ms/step
Eval begin...
step 1/1 [==============================] - loss: 0.2103 - acc: 0.9503 - acc_top5: 1.0000 - 29ms/step
Eval samples: 161
param_state_dict = paddle.load("finishmodel.pdparams")
vivi.set_dict(param_state_dict)
model.evaluate(test_dataset,verbose=1)
Eval begin...
step 610/610 [==============================] - loss: 1.5317e-04 - acc: 0.7770 - acc_top5: 0.9623 - 8ms/step      
Eval samples: 610

{'loss': [0.00015317221],
 'acc': 0.7770491803278688,
 'acc_top5': 0.9622950819672131}
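Since validation accuracy peaked at 0.957 with the bestmodel checkpoint, it may also be worth evaluating that checkpoint rather than the final weights; a sketch, assuming the callback above saved bestmodel.pdparams:

# Optionally evaluate the best validation checkpoint instead of the final weights.
best_state = paddle.load("bestmodel.pdparams")
vivi.set_dict(best_state)
model.evaluate(test_dataset, verbose=1)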
fig, ax = plt.subplots(2, 1, figsize=(8, 12))
ax = ax.ravel()

ax[0].plot(tranloss_log.trainlosses)
ax[0].plot(tranloss_log.vallosses)
ax[0].set_title("Model Loss")
ax[0].set_xlabel("epochs")
ax[0].legend(["train", "val"])

ax[1].plot(tranloss_log.trainaccs)
ax[1].plot(tranloss_log.valacc)
ax[1].set_title("Model Accuracy")
ax[1].set_xlabel("epochs")
ax[1].legend(["train", "val"])
<matplotlib.legend.Legend at 0x7f2bfc527ed0>

(Figure: training and validation loss and accuracy curves.)

6. Model prediction

NUM_SAMPLES_VIZ = 25
ground_truths = []
preds = []
videos = []
labels = {
    0: "liver",
    1: "kidney-right",
    2: "kidney-left",
    3: "femur-right",
    4: "femur-left",
    5: "bladder",
    6: "heart",
    7: "lung-right",
    8: "lung-left",
    9: "spleen",
    10: "pancreas",
}
vivi.eval()
for i, (testsample, label) in enumerate(test_dataset):
    # Render the volume's slices as an animated GIF for display.
    with io.BytesIO() as gif:
        imageio.mimsave(gif, (np.squeeze(testsample, axis=0) * 255).astype("uint8"), "GIF", fps=5)
        videos.append(gif.getvalue())

    # Get the model prediction.
    output = vivi(paddle.to_tensor(testsample).unsqueeze(0))
    pred = np.argmax(output, axis=-1)[0]

    ground_truths.append(label.astype("int"))
    preds.append(pred)
    if i + 1 >= NUM_SAMPLES_VIZ:
        break
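
As an optional check, the collected predictions can be compared against the ground truth directly (a small sketch over the samples gathered above):

correct = sum(int(p == g[0]) for p, g in zip(preds, ground_truths))
print(f"{correct}/{len(preds)} correct on the visualized samples")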


def make_box_for_grid(image_widget, fit):
    """Make a VBox to hold caption/image for demonstrating option_fit values.

    Source: https://ipywidgets.readthedocs.io/en/latest/examples/Widget%20Styling.html
    """
    # Make the caption
    if fit is not None:
        fit_str = "'{}'".format(fit)
    else:
        fit_str = str(fit)

    h = ipywidgets.HTML(value=str(fit_str))

    # Make the green box with the image widget inside it
    boxb = ipywidgets.widgets.Box()
    boxb.children = [image_widget]

    # Compose into a vertical box
    vb = ipywidgets.widgets.VBox()
    vb.layout.align_items = "center"
    vb.children = [h, boxb]
    return vb


boxes = []
for i in range(NUM_SAMPLES_VIZ):
    ib = ipywidgets.widgets.Image(value=videos[i], width=100, height=100)
    true_class = labels[ground_truths[i][0]]
    pred_class = labels[preds[i]]
    caption = f"T: {true_class} | P: {pred_class}"

    boxes.append(make_box_for_grid(ib, caption))

ipywidgets.widgets.GridBox(
    boxes, layout=ipywidgets.widgets.Layout(grid_template_columns="repeat(5, 200px)")
)
GridBox(children=(VBox(children=(HTML(value="'T: pancreas | P: pancreas'"), Box(children=(Image(value=b'GIF89a…

(Figure: grid of test-sample GIFs with true (T) and predicted (P) labels.)
This article is a repost of the original project.
