Introduction
Tangles of kudzu overwhelm trees in Georgia, while invasive vines threaten habitats in more than a dozen countries worldwide. These are just two of the many invasive species that can have damaging effects on the environment, the economy, and even human health. Despite their widespread impact, efforts to track the location and spread of invasive species are so costly that they are difficult to undertake at scale.
At present, monitoring of ecosystems and plant distributions depends on expert knowledge. Trained scientists visit designated areas and record the species living there. Relying on such highly qualified labor is expensive, time-inefficient, and insufficient, since humans cannot cover large areas when sampling.
Because scientists cannot sample a large number of areas, machine-learning algorithms are used to predict the presence or absence of invasive species in areas that have not been sampled. The accuracy of such methods is far from optimal, but they are still a useful tool for tackling ecological problems.
In this competition, Kagglers are challenged to develop algorithms that more accurately identify whether images of forests and foliage contain invasive hydrangea. Techniques from computer vision, together with other current technologies such as aerial imaging, can make invasive-species monitoring cheaper, faster, and more reliable.

  1. Extract the dataset
    1.1 About the dataset
    The dataset contains forest-canopy images from the Brazilian rainforest, some of which contain hydrangea. The task is to train a convolutional neural network that processes these images and predicts whether hydrangea is present, an approach intended to reduce the manual effort and trained labor required for fieldwork. There are roughly 3,800 images in total, split into a training set of 2,295 images and a test set of 1,531.

In [1]

!unzip data/data152227/sample_submission.csv.zip -d work/data/

In [2]

!unzip data/data152227/train.zip -d work/data/

In [3]

!unzip data/data152227/test.zip -d work/data/

In [4]

!unzip data/data152227/train_labels.csv.zip -d work/data/
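After extraction, a quick sanity check can confirm the expected image counts. This is only a sketch: the folder layout `work/data/train/train` and `work/data/test` is the one used by the loader cells later in this notebook.

```python
import os

# Extensions treated as images (matching the IMG_EXTENSIONS list used below).
IMG_EXTS = (".jpg", ".jpeg", ".png", ".bmp", ".ppm")

def count_images(folder):
    """Count image files directly inside `folder` (0 if it does not exist)."""
    if not os.path.isdir(folder):
        return 0
    return sum(1 for f in os.listdir(folder) if f.lower().endswith(IMG_EXTS))

# Expect 2295 training images and 1531 test images after unzipping.
for folder in ("work/data/train/train", "work/data/test"):
    print(folder, count_images(folder))
```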

  2. Import the required packages
    In [ ]
    !pip install paddlex
    In [ ]
    %matplotlib inline
    import paddle
    import numpy as np
    import matplotlib.pyplot as plt
    from paddle.vision.transforms import Transpose
    from paddle.io import Dataset, DataLoader
    from paddle import nn
    import paddle.vision.transforms as transforms
    from paddle.vision.models import resnet50, resnet34
    import os
    import paddle.nn.functional as F
    from matplotlib.pyplot import figure
    from PIL import Image
    import cv2
    import pandas as pd
    import paddlex
  3. Split the data and build the datasets
    Here the first 2,000 samples of the training set are used for training and the last 296 for validation.

In [7]

labels = pd.read_csv("work/data/train_labels.csv", index_col=0)

In [8]

train_label = labels.loc[:2000]  # note: .loc slicing is inclusive on both ends,

test_label = labels.loc[2000:]   # so sample 2000 lands in both splits; .loc[2001:] would avoid the overlap

In [9]

train_label.to_csv("work/data/train.txt", sep='\t', mode='w', header=False)

In [10]

test_label.to_csv("work/data/val.txt", sep='\t', mode='w', header=False)

In [11]

test_labels = pd.read_csv("work/data/sample_submission.csv", index_col=0)

In [12]

test_labels.to_csv("work/data/test.txt", sep='\t', mode='w', header=False)

Build the dataset

In [13]
IMG_EXTENSIONS = [
    ".jpg", ".JPG", ".jpeg", ".JPEG",
    ".png", ".PNG", ".ppm", ".PPM", ".bmp", ".BMP",
]

def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)

def default_loader(path):
    return Image.open(path).convert("RGB")

def make_dataset(root, label, test_flag):
    images = []
    labeltxt = open(label)
    for line in labeltxt:
        data = line.strip().split("\t")
        if is_image_file(data[0] + ".jpg"):
            path = os.path.join(root, data[0] + ".jpg")
            gt = np.float32(data[1])
            item = (path, gt)
            images.append(item)
    return images

class ImageSet(Dataset):
    def __init__(self, root, label, transform=None, loader=default_loader, test_flag=0):
        imgs = make_dataset(root, label, test_flag)
        self.root = root
        self.label = label
        self.samples = imgs
        self.transform = transform
        self.loader = loader

    def __getitem__(self, index):
        path, gt = self.samples[index]
        img = self.loader(path)

        if self.transform is not None:
            img = self.transform(img)

        return img, gt

    def __len__(self):
        return len(self.samples)

In [14]

Data augmentation

train_tfm = transforms.Compose([
    transforms.Resize((804, 804)),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.RandomResizedCrop(800, scale=(0.6, 1.0)),
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomRotation(20),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

test_tfm = transforms.Compose([
    transforms.Resize((800, 800)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
In [15]

Set the batch size

batch_size = 32
data_root = 'work/data'
train_root = os.path.join(data_root, 'train/train')
train_label_file = os.path.join(data_root, 'train.txt')
train_set = ImageSet(train_root, loader=lambda x: Image.open(x), label=train_label_file, transform=train_tfm)
val_root = os.path.join(data_root, 'train/train')
val_label_file = os.path.join(data_root, 'val.txt')
val_set = ImageSet(val_root, loader=lambda x: Image.open(x), label=val_label_file, transform=test_tfm)
test_root = os.path.join(data_root, 'test')
test_label_file = os.path.join(data_root, 'test.txt')
test_set = ImageSet(test_root, loader=lambda x: Image.open(x), label=test_label_file, transform=test_tfm, test_flag=1)
In [16]
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=True)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, drop_last=False)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, drop_last=False)
In [17]
def init_weight(m):
    if isinstance(m, (nn.Conv2D, nn.Linear)):
        nn.initializer.KaimingNormal()(m.weight)  # instantiate the initializer, then apply it to the weight
In [18]

Set the learning rate and fix the random seeds

learning_rate = 0.0001
n_epochs = 5
paddle.seed(42)
np.random.seed(42)
4. Model and metric
Model selection
This experiment fine-tunes a pretrained resnet34. Since this is a binary classification problem, a linear layer mapping the 1000-dimensional output to a single unit is appended, followed by a sigmoid activation.

In [19]

Build the network

net = nn.Sequential()
net.add_sublayer('resnet34', resnet34(pretrained=True))
net.add_sublayer('class', nn.Sequential(nn.Dropout(0.5), nn.Linear(1000, 1), nn.Sigmoid(), nn.Flatten(0)))
W0614 15:21:55.936321 17966 gpu_context.cc:278] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 10.1
W0614 15:21:55.940794 17966 gpu_context.cc:306] device: 0, cuDNN Version: 7.6.
Sequential(
(0): Dropout(p=0.5, axis=None, mode=upscale_in_train)
(1): Linear(in_features=1000, out_features=1, dtype=float32)
(2): Sigmoid()
(3): Flatten()
)
In [20]

Print the network structure

paddle.summary(net,(batch_size,3,224,224))

Layer (type) Input Shape Output Shape Param #

Conv2D-1 [[32, 3, 224, 224]] [32, 64, 112, 112] 9,408

BatchNorm2D-1 [[32, 64, 112, 112]] [32, 64, 112, 112] 256
ReLU-5 [[32, 64, 112, 112]] [32, 64, 112, 112] 0
MaxPool2D-1 [[32, 64, 112, 112]] [32, 64, 56, 56] 0
Conv2D-2 [[32, 64, 56, 56]] [32, 64, 56, 56] 36,864
BatchNorm2D-2 [[32, 64, 56, 56]] [32, 64, 56, 56] 256
ReLU-6 [[32, 64, 56, 56]] [32, 64, 56, 56] 0
Conv2D-3 [[32, 64, 56, 56]] [32, 64, 56, 56] 36,864
BatchNorm2D-3 [[32, 64, 56, 56]] [32, 64, 56, 56] 256
BasicBlock-1 [[32, 64, 56, 56]] [32, 64, 56, 56] 0
Conv2D-4 [[32, 64, 56, 56]] [32, 64, 56, 56] 36,864
BatchNorm2D-4 [[32, 64, 56, 56]] [32, 64, 56, 56] 256
ReLU-7 [[32, 64, 56, 56]] [32, 64, 56, 56] 0
Conv2D-5 [[32, 64, 56, 56]] [32, 64, 56, 56] 36,864
BatchNorm2D-5 [[32, 64, 56, 56]] [32, 64, 56, 56] 256
BasicBlock-2 [[32, 64, 56, 56]] [32, 64, 56, 56] 0
Conv2D-6 [[32, 64, 56, 56]] [32, 64, 56, 56] 36,864
BatchNorm2D-6 [[32, 64, 56, 56]] [32, 64, 56, 56] 256
ReLU-8 [[32, 64, 56, 56]] [32, 64, 56, 56] 0
Conv2D-7 [[32, 64, 56, 56]] [32, 64, 56, 56] 36,864
BatchNorm2D-7 [[32, 64, 56, 56]] [32, 64, 56, 56] 256
BasicBlock-3 [[32, 64, 56, 56]] [32, 64, 56, 56] 0
Conv2D-9 [[32, 64, 56, 56]] [32, 128, 28, 28] 73,728
BatchNorm2D-9 [[32, 128, 28, 28]] [32, 128, 28, 28] 512
ReLU-9 [[32, 128, 28, 28]] [32, 128, 28, 28] 0
Conv2D-10 [[32, 128, 28, 28]] [32, 128, 28, 28] 147,456
BatchNorm2D-10 [[32, 128, 28, 28]] [32, 128, 28, 28] 512
Conv2D-8 [[32, 64, 56, 56]] [32, 128, 28, 28] 8,192
BatchNorm2D-8 [[32, 128, 28, 28]] [32, 128, 28, 28] 512
BasicBlock-4 [[32, 64, 56, 56]] [32, 128, 28, 28] 0
Conv2D-11 [[32, 128, 28, 28]] [32, 128, 28, 28] 147,456
BatchNorm2D-11 [[32, 128, 28, 28]] [32, 128, 28, 28] 512
ReLU-10 [[32, 128, 28, 28]] [32, 128, 28, 28] 0
Conv2D-12 [[32, 128, 28, 28]] [32, 128, 28, 28] 147,456
BatchNorm2D-12 [[32, 128, 28, 28]] [32, 128, 28, 28] 512
BasicBlock-5 [[32, 128, 28, 28]] [32, 128, 28, 28] 0
Conv2D-13 [[32, 128, 28, 28]] [32, 128, 28, 28] 147,456
BatchNorm2D-13 [[32, 128, 28, 28]] [32, 128, 28, 28] 512
ReLU-11 [[32, 128, 28, 28]] [32, 128, 28, 28] 0
Conv2D-14 [[32, 128, 28, 28]] [32, 128, 28, 28] 147,456
BatchNorm2D-14 [[32, 128, 28, 28]] [32, 128, 28, 28] 512
BasicBlock-6 [[32, 128, 28, 28]] [32, 128, 28, 28] 0
Conv2D-15 [[32, 128, 28, 28]] [32, 128, 28, 28] 147,456
BatchNorm2D-15 [[32, 128, 28, 28]] [32, 128, 28, 28] 512
ReLU-12 [[32, 128, 28, 28]] [32, 128, 28, 28] 0
Conv2D-16 [[32, 128, 28, 28]] [32, 128, 28, 28] 147,456
BatchNorm2D-16 [[32, 128, 28, 28]] [32, 128, 28, 28] 512
BasicBlock-7 [[32, 128, 28, 28]] [32, 128, 28, 28] 0
Conv2D-18 [[32, 128, 28, 28]] [32, 256, 14, 14] 294,912
BatchNorm2D-18 [[32, 256, 14, 14]] [32, 256, 14, 14] 1,024
ReLU-13 [[32, 256, 14, 14]] [32, 256, 14, 14] 0
Conv2D-19 [[32, 256, 14, 14]] [32, 256, 14, 14] 589,824
BatchNorm2D-19 [[32, 256, 14, 14]] [32, 256, 14, 14] 1,024
Conv2D-17 [[32, 128, 28, 28]] [32, 256, 14, 14] 32,768
BatchNorm2D-17 [[32, 256, 14, 14]] [32, 256, 14, 14] 1,024
BasicBlock-8 [[32, 128, 28, 28]] [32, 256, 14, 14] 0
Conv2D-20 [[32, 256, 14, 14]] [32, 256, 14, 14] 589,824
BatchNorm2D-20 [[32, 256, 14, 14]] [32, 256, 14, 14] 1,024
ReLU-14 [[32, 256, 14, 14]] [32, 256, 14, 14] 0
Conv2D-21 [[32, 256, 14, 14]] [32, 256, 14, 14] 589,824
BatchNorm2D-21 [[32, 256, 14, 14]] [32, 256, 14, 14] 1,024
BasicBlock-9 [[32, 256, 14, 14]] [32, 256, 14, 14] 0
Conv2D-22 [[32, 256, 14, 14]] [32, 256, 14, 14] 589,824
BatchNorm2D-22 [[32, 256, 14, 14]] [32, 256, 14, 14] 1,024
ReLU-15 [[32, 256, 14, 14]] [32, 256, 14, 14] 0
Conv2D-23 [[32, 256, 14, 14]] [32, 256, 14, 14] 589,824
BatchNorm2D-23 [[32, 256, 14, 14]] [32, 256, 14, 14] 1,024
BasicBlock-10 [[32, 256, 14, 14]] [32, 256, 14, 14] 0
Conv2D-24 [[32, 256, 14, 14]] [32, 256, 14, 14] 589,824
BatchNorm2D-24 [[32, 256, 14, 14]] [32, 256, 14, 14] 1,024
ReLU-16 [[32, 256, 14, 14]] [32, 256, 14, 14] 0
Conv2D-25 [[32, 256, 14, 14]] [32, 256, 14, 14] 589,824
BatchNorm2D-25 [[32, 256, 14, 14]] [32, 256, 14, 14] 1,024
BasicBlock-11 [[32, 256, 14, 14]] [32, 256, 14, 14] 0
Conv2D-26 [[32, 256, 14, 14]] [32, 256, 14, 14] 589,824
BatchNorm2D-26 [[32, 256, 14, 14]] [32, 256, 14, 14] 1,024
ReLU-17 [[32, 256, 14, 14]] [32, 256, 14, 14] 0
Conv2D-27 [[32, 256, 14, 14]] [32, 256, 14, 14] 589,824
BatchNorm2D-27 [[32, 256, 14, 14]] [32, 256, 14, 14] 1,024
BasicBlock-12 [[32, 256, 14, 14]] [32, 256, 14, 14] 0
Conv2D-28 [[32, 256, 14, 14]] [32, 256, 14, 14] 589,824
BatchNorm2D-28 [[32, 256, 14, 14]] [32, 256, 14, 14] 1,024
ReLU-18 [[32, 256, 14, 14]] [32, 256, 14, 14] 0
Conv2D-29 [[32, 256, 14, 14]] [32, 256, 14, 14] 589,824
BatchNorm2D-29 [[32, 256, 14, 14]] [32, 256, 14, 14] 1,024
BasicBlock-13 [[32, 256, 14, 14]] [32, 256, 14, 14] 0
Conv2D-31 [[32, 256, 14, 14]] [32, 512, 7, 7] 1,179,648
BatchNorm2D-31 [[32, 512, 7, 7]] [32, 512, 7, 7] 2,048
ReLU-19 [[32, 512, 7, 7]] [32, 512, 7, 7] 0
Conv2D-32 [[32, 512, 7, 7]] [32, 512, 7, 7] 2,359,296
BatchNorm2D-32 [[32, 512, 7, 7]] [32, 512, 7, 7] 2,048
Conv2D-30 [[32, 256, 14, 14]] [32, 512, 7, 7] 131,072
BatchNorm2D-30 [[32, 512, 7, 7]] [32, 512, 7, 7] 2,048
BasicBlock-14 [[32, 256, 14, 14]] [32, 512, 7, 7] 0
Conv2D-33 [[32, 512, 7, 7]] [32, 512, 7, 7] 2,359,296
BatchNorm2D-33 [[32, 512, 7, 7]] [32, 512, 7, 7] 2,048
ReLU-20 [[32, 512, 7, 7]] [32, 512, 7, 7] 0
Conv2D-34 [[32, 512, 7, 7]] [32, 512, 7, 7] 2,359,296
BatchNorm2D-34 [[32, 512, 7, 7]] [32, 512, 7, 7] 2,048
BasicBlock-15 [[32, 512, 7, 7]] [32, 512, 7, 7] 0
Conv2D-35 [[32, 512, 7, 7]] [32, 512, 7, 7] 2,359,296
BatchNorm2D-35 [[32, 512, 7, 7]] [32, 512, 7, 7] 2,048
ReLU-21 [[32, 512, 7, 7]] [32, 512, 7, 7] 0
Conv2D-36 [[32, 512, 7, 7]] [32, 512, 7, 7] 2,359,296
BatchNorm2D-36 [[32, 512, 7, 7]] [32, 512, 7, 7] 2,048
BasicBlock-16 [[32, 512, 7, 7]] [32, 512, 7, 7] 0
AdaptiveAvgPool2D-1 [[32, 512, 7, 7]] [32, 512, 1, 1] 0
Linear-1 [[32, 512]] [32, 1000] 513,000
ResNet-1 [[32, 3, 224, 224]] [32, 1000] 0
Dropout-1 [[32, 1000]] [32, 1000] 0
Linear-2 [[32, 1000]] [32, 1] 1,001
Sigmoid-2 [[32, 1]] [32, 1] 0
Flatten-1 [[32, 1]] [32] 0

Total params: 21,815,697
Trainable params: 21,781,649
Non-trainable params: 34,048

Input size (MB): 18.38
Forward/backward pass size (MB): 2744.86
Params size (MB): 83.22
Estimated Total Size (MB): 2846.45

{'total_params': 21815697, 'trainable_params': 21781649}
Binary Accuracy
Since paddle does not ship a Binary Accuracy metric for binary classification, it is implemented here.
The usual idea of Binary Accuracy is to threshold the sigmoid output: above the threshold the prediction is taken as label 1, otherwise as 0.
The implementation here is slightly different: a prediction is counted as correct when the absolute difference between the true label and the sigmoid output is at most the threshold. The threshold in this experiment is 0.5.

In [21]
class Binary_Accuracy(paddle.metric.Metric):
    def __init__(self, threshold=0.5, name=None, *args, **kwargs):
        super(Binary_Accuracy, self).__init__(*args, **kwargs)
        self._init_name(name)
        self.threshold = threshold  # confidence threshold
        self.reset()

    def compute(self, pred, label, *args):
        """
        Mark correctly predicted positions: 1 means correct, 0 means wrong.
        """
        correct = paddle.abs(pred - label) <= self.threshold  # positions where |pred - label| <= threshold
        return paddle.cast(correct, dtype='float32')

    def update(self, correct, *args):
        """
        Update the running totals with one batch of results.
        """
        if isinstance(correct, (paddle.Tensor, paddle.fluid.core.eager.Tensor)):
            correct = correct.numpy()
        num_samples = correct.shape[0]  # number of samples in the batch
        num_corrects = correct.sum()  # number of correct predictions
        self.total += num_corrects
        self.count += num_samples
        accs = float(num_corrects) / num_samples  # batch accuracy
        return accs

    def reset(self):
        """
        Reset the accumulated state.
        """
        self.total = 0.0
        self.count = 0

    def accumulate(self):
        """
        Return the accuracy accumulated so far.
        """
        return self.total / self.count

    def _init_name(self, name):
        """
        Initialize the metric name.
        """
        self._name = name or 'acc'

    def name(self):
        """
        Return the metric name.
        """
        return self._name

Test the Binary Accuracy implemented above
In [22]
import numpy as np
import paddle

x = paddle.to_tensor(np.array([
[0.6],
[0.6],
[0.2],
[0.6]]))
y = paddle.to_tensor(np.array([[0.0], [1.0], [1.0], [1.0]]))

m = Binary_Accuracy()
correct = m.compute(x, y)
print(correct)
m.update(correct)
res = m.accumulate()
print(res) # 0.5
Tensor(shape=[4, 1], dtype=float32, place=Place(gpu:0), stop_gradient=True,
[[0.],
[1.],
[0.],
[1.]])
0.5
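For hard 0/1 labels, the |pred - label| <= 0.5 formulation above agrees with the usual threshold-then-compare accuracy (the two differ only when a prediction is exactly 0.5). A small standalone NumPy check on the same values:

```python
import numpy as np

# Sigmoid outputs and binary labels from the example above.
pred = np.array([0.6, 0.6, 0.2, 0.6])
label = np.array([0.0, 1.0, 1.0, 1.0])

# Standard formulation: threshold the probability, then compare with the label.
standard = (pred > 0.5).astype(np.float64) == label

# This notebook's formulation: correct when the prediction lies within 0.5 of the label.
distance = np.abs(pred - label) <= 0.5

assert (standard == distance).all()
print(distance.mean())  # 0.5, matching the accuracy printed above
```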
5. Training the model
In [23]
scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=learning_rate, T_max=len(train_set) // batch_size * n_epochs, verbose=False)
model = paddle.Model(net)
optimizer = paddle.optimizer.AdamW(parameters=model.parameters(), learning_rate=scheduler, weight_decay=2e-7)  # pass the scheduler so the cosine decay actually takes effect
model.prepare(optimizer, paddle.nn.BCELoss(), Binary_Accuracy())  # the network already ends in Sigmoid, so use BCELoss; BCEWithLogitsLoss would apply a second sigmoid
model.fit(train_data=train_loader, eval_data=val_loader, epochs=n_epochs, verbose=2)
The loss value printed in the log is the current step, and the metric is the average value of previous steps.
Epoch 1/5
step 10/62 - loss: 0.5026 - acc: 0.7812 - 3s/step
step 20/62 - loss: 0.4985 - acc: 0.8375 - 3s/step
step 30/62 - loss: 0.4879 - acc: 0.8635 - 3s/step
step 40/62 - loss: 0.4821 - acc: 0.8773 - 3s/step
step 50/62 - loss: 0.4614 - acc: 0.8888 - 3s/step
step 60/62 - loss: 0.4599 - acc: 0.8958 - 3s/step
step 62/62 - loss: 0.4609 - acc: 0.8977 - 3s/step
Eval begin…
step 10/10 - loss: 0.3934 - acc: 0.9561 - 1s/step
Eval samples: 296
Epoch 2/5
step 10/62 - loss: 0.5536 - acc: 0.9062 - 3s/step
step 20/62 - loss: 0.4676 - acc: 0.9406 - 3s/step
step 30/62 - loss: 0.4533 - acc: 0.9531 - 3s/step
step 40/62 - loss: 0.4862 - acc: 0.9563 - 3s/step
step 50/62 - loss: 0.4682 - acc: 0.9606 - 3s/step
step 60/62 - loss: 0.4952 - acc: 0.9620 - 3s/step
step 62/62 - loss: 0.4680 - acc: 0.9627 - 3s/step
Eval begin…
step 10/10 - loss: 0.3607 - acc: 0.9831 - 1s/step
Eval samples: 296
Epoch 3/5
step 10/62 - loss: 0.4795 - acc: 0.9812 - 3s/step
step 20/62 - loss: 0.4676 - acc: 0.9781 - 3s/step
step 30/62 - loss: 0.4936 - acc: 0.9781 - 3s/step
step 40/62 - loss: 0.4863 - acc: 0.9781 - 3s/step
step 50/62 - loss: 0.4403 - acc: 0.9762 - 3s/step
step 60/62 - loss: 0.4378 - acc: 0.9760 - 3s/step
step 62/62 - loss: 0.4693 - acc: 0.9743 - 3s/step
Eval begin…
step 10/10 - loss: 0.3608 - acc: 0.9797 - 1s/step
Eval samples: 296
Epoch 4/5
step 10/62 - loss: 0.4558 - acc: 0.9750 - 3s/step
step 20/62 - loss: 0.4201 - acc: 0.9750 - 3s/step
step 30/62 - loss: 0.4803 - acc: 0.9698 - 3s/step
step 40/62 - loss: 0.4682 - acc: 0.9711 - 3s/step
step 50/62 - loss: 0.4805 - acc: 0.9725 - 3s/step
step 60/62 - loss: 0.4806 - acc: 0.9745 - 3s/step
step 62/62 - loss: 0.4686 - acc: 0.9748 - 3s/step
Eval begin…
step 10/10 - loss: 0.3607 - acc: 0.9932 - 1s/step
Eval samples: 296
Epoch 5/5
step 10/62 - loss: 0.4699 - acc: 0.9750 - 3s/step
step 20/62 - loss: 0.5559 - acc: 0.9703 - 3s/step
step 30/62 - loss: 0.4950 - acc: 0.9729 - 3s/step
step 40/62 - loss: 0.4679 - acc: 0.9781 - 3s/step
step 50/62 - loss: 0.4677 - acc: 0.9794 - 3s/step
step 60/62 - loss: 0.5045 - acc: 0.9792 - 3s/step
step 62/62 - loss: 0.4388 - acc: 0.9788 - 3s/step
Eval begin…
step 10/10 - loss: 0.3607 - acc: 0.9899 - 1s/step
Eval samples: 296
6. Testing the model
In [24]
preds = model.predict(test_set)
Predict begin…
step 1531/1531 [==============================] - 41ms/step
Predict samples: 1531
In [ ]

Predict on the test set and write the results to a csv file

sub = pd.read_csv("work/data/sample_submission.csv")

print(sub["invasive"])

preds_np = np.array(preds).reshape((1, -1))

print(preds_np.shape)

sub["invasive"] = preds_np[0]
sub.to_csv("submission.csv", index=False)
7. Kaggle submission result
8. Summary
The hydrangea dataset is relatively easy: fine-tuning a pretrained model for only 5 epochs reaches a best validation accuracy of 0.9932, and the test-set predictions submitted to Kaggle score 0.99245. This experiment did not use model ensembling, so a gap to the strongest Kaggle entries remains (0.99245 vs 0.99770); trying multi-model ensembling to improve performance further is left as future work.
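As a sketch of the ensembling idea mentioned above (the two model outputs below are hypothetical), the simplest fusion strategy is to average the sigmoid outputs of several trained models before writing the submission:

```python
import numpy as np

def average_ensemble(prob_lists):
    """Average per-model probability vectors into a single prediction.

    `prob_lists` holds one 1-D array per model, each containing sigmoid
    outputs for the same test images in the same order.
    """
    stacked = np.stack([np.asarray(p, dtype=np.float64) for p in prob_lists])
    return stacked.mean(axis=0)

# Toy example with two hypothetical models' probabilities for three images.
model_a = [0.9, 0.2, 0.6]
model_b = [0.7, 0.4, 0.8]
print(average_ensemble([model_a, model_b]))  # [0.8 0.3 0.7]
```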
