1. Introduction

SiamFC, a Siamese-network-based single-object tracker, is one of the most influential classic works in single-object tracking. This article briefly covers SiamFC from four angles: network architecture, data processing, loss function, and tracking procedure. These are personal study notes; corrections and discussion are welcome.

Paper: https://arxiv.org/pdf/1606.09549.pdf

This project uses the GOT-10k dataset:

Dataset website: http://got-10k.aitestunion.com/

The Paddle implementation draws on various PyTorch versions on GitHub:

PyTorch version: https://github.com/NieHa0ha0/siamfc-pytorch-got10k

# Unpack the dataset and install the dataset toolkit
# Only split 01-03 are used here for testing; list.txt is the index of the data
# !mkdir work/data
# !mkdir work/data/train
# !mv work/list.txt work/data/train/
# !unzip data/data126011/GOT-10k_Train_split_01.zip -d work/data/train/
# !unzip data/data126011/GOT-10k_Train_split_02.zip -d work/data/train/
# !unzip data/data126011/GOT-10k_Train_split_03.zip -d work/data/train/
# !pip install got10k
# Configuration parameters

class Config:
    # basic parameters
    out_scale = 0.001       # scale applied after cross-correlation; without it the large raw values make the gradients explode
    exemplar_sz = 127       # exemplar (template) image size
    instance_sz = 255       # instance (search) image size
    context = 0.5           # context ratio
    # inference parameters
    scale_num = 3           # number of scales searched during inference
    scale_step = 1.0375     # step between adjacent box scales
    scale_lr = 0.59         # learning rate for scale updates
    scale_penalty = 0.9745  # penalty applied to scale changes
    window_influence = 0.176    # weight of the Hanning-window penalty
    response_sz = 17        # response map size
    response_up = 16        # upsampling factor for the response map
    total_stride = 8        # total stride of the backbone
    # train parameters
    epoch_num = 50          
    batch_size = 64
    num_workers = 2
    initial_lr = 1e-2
    ultimate_lr = 1e-5
    weight_decay = 5e-4
    r_pos = 16
    r_neg = 0


cfg = Config()

2. Network Architecture

The SiamFC network takes two inputs: the template (exemplar) z and the search image x. Both pass through a feature-extraction network with shared weights, producing a 6x6x128 feature map for the template and a 22x22x128 feature map for the search image.
The backbone is a variant of AlexNet; the per-layer parameters and input/output sizes (reconstructed from the layer definitions below) are:

| Layer | Kernel / Stride | Out channels | Exemplar (127) | Instance (255) |
|-------|-----------------|--------------|----------------|----------------|
| conv1 | 11x11 / 2       | 96           | 59x59          | 123x123        |
| pool1 | 3x3 / 2         | 96           | 29x29          | 61x61          |
| conv2 | 5x5 / 1         | 256          | 25x25          | 57x57          |
| pool2 | 3x3 / 2         | 256          | 12x12          | 28x28          |
| conv3 | 3x3 / 1         | 384          | 10x10          | 26x26          |
| conv4 | 3x3 / 1         | 384          | 8x8            | 24x24          |
| conv5 | 3x3 / 1         | 128          | 6x6            | 22x22          |

A cross-correlation between the two feature maps then yields a 17x17 output, where each value measures how similar the template is to the search image at that position. The cross-correlation treats the template feature map as a convolution kernel and convolves it over the search feature map; the code implements this with paddle.nn.functional.conv2d. Beyond this plain convolution, later work introduced depthwise cross-correlation, which shares the idea of depthwise-separable convolution: correlating channel by channel reduces the computation.

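The depthwise cross-correlation mentioned above is not part of SiamFC itself; the following is a minimal sketch of my own (in the style of pysot/SiamRPN++) of how it can be written with a grouped convolution, where each channel of the template acts as a single-channel kernel for the matching channel of the search feature:

import paddle
import paddle.nn.functional as F

def xcorr_depthwise(x, z):
    # x: [b, c, Hx, Wx] search features; z: [b, c, Hz, Wz] template features
    b, c = z.shape[0], z.shape[1]
    x = paddle.reshape(x, [1, b * c, x.shape[2], x.shape[3]])  # fold the batch into channels
    z = paddle.reshape(z, [b * c, 1, z.shape[2], z.shape[3]])  # one single-channel kernel per channel
    out = F.conv2d(x, z, groups=b * c)  # each channel correlates only with its own kernel
    return paddle.reshape(out, [b, c, out.shape[2], out.shape[3]])

Unlike the plain correlation used below, the result keeps all c channels instead of collapsing them into one response map, so a small convolutional head is normally attached afterwards.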
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import time

paddle.set_device('cpu')
# backbone
class AlexNet(nn.Layer):
    def __init__(self, out_channels, init=False, conv_weight_init=None, bias_init=None):
        super(AlexNet, self).__init__()
        self.conv_weight_init = conv_weight_init
        self.bias_init = bias_init
        # if requested, set up the two initializer objects
        if init:
            self._init_weights()
        self.conv1 = nn.Sequential(
            nn.Conv2D(in_channels=3, out_channels=96, kernel_size=11, stride=2, padding=0
                      , weight_attr=self.conv_weight_init, bias_attr=self.bias_init),
            nn.BatchNorm2D(96),
            nn.ReLU(),
            nn.MaxPool2D(3, 2))
        self.conv2 = nn.Sequential(
            nn.Conv2D(in_channels=96, out_channels=256, kernel_size=5, stride=1, padding=0
                      , weight_attr=self.conv_weight_init, bias_attr=self.bias_init),
            nn.BatchNorm2D(256),
            nn.ReLU(),
            nn.MaxPool2D(3, 2))
        self.conv3 = nn.Sequential(
            nn.Conv2D(in_channels=256, out_channels=384, kernel_size=3, stride=1, padding=0
                      , weight_attr=self.conv_weight_init, bias_attr=self.bias_init),
            nn.BatchNorm2D(384),
            nn.ReLU())
        self.conv4 = nn.Sequential(
            nn.Conv2D(in_channels=384, out_channels=384, kernel_size=3, stride=1, padding=0
                      , weight_attr=self.conv_weight_init, bias_attr=self.bias_init),
            nn.BatchNorm2D(384),
            nn.ReLU())
        self.conv5 = nn.Sequential(
            nn.Conv2D(in_channels=384, out_channels=out_channels, kernel_size=3, stride=1, padding=0
                      , weight_attr=self.conv_weight_init, bias_attr=self.bias_init))

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        return x

    # parameter initialization: weights use XavierUniform, biases use Constant
    def _init_weights(self):
        self.conv_weight_init = nn.initializer.XavierUniform()
        self.bias_init = nn.initializer.Constant(value=0)


# RPN head (not used yet)
# class Rpn(nn.Layer):
#     def __init__(self, anchor_num):
#         super(Rpn, self).__init__()
#         self.anchor_num = anchor_num  # number of anchors
#         self.conv_x_cls = nn.Conv2D(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=0)
#         self.conv_x_reg = nn.Conv2D(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=0)
#         self.conv_z_cls = nn.Conv2D(in_channels=256, out_channels=2 * anchor_num * 256, kernel_size=3, stride=1,
#                                     padding=0)
#         self.conv_z_reg = nn.Conv2D(in_channels=256, out_channels=4 * anchor_num * 256, kernel_size=3, stride=1,
#                                     padding=0)

#     def Xcorr_train(self, x, z, type):
#         out = []
#         b = x.shape[0]
#         k = 2 if type == 'cls' else 4
#         for i in range(b):
#             out.append(
#                 F.conv2d(x[i, :, :, :].unsqueeze(0),
#                          paddle.reshape(z[i, :, :, :].unsqueeze(0), [k * self.anchor_num, -1, 0, 0]),
#                          # reshape the template tensor to [k*anchor_num, -1, 4, 4] so it acts as convolution kernels
#                          stride=1, padding=0))
#         out = paddle.concat(out, axis=0)
#         return out

#     def Xcorr_test(self, x, z, type):
#         k = 2 if type == 'cls' else 4
#         out = F.conv2d(x,
#                        paddle.reshape(z, [k * self.anchor_num, -1, 0, 0]),
#                        stride=1, padding=0)
#         return out

#     def forward(self, x, z):
#         x_reg = self.conv_x_reg(x)
#         z_reg = self.conv_z_reg(z)
#         x_cls = self.conv_x_cls(x)
#         z_cls = self.conv_z_cls(z)
#         cls_out = self.Xcorr_train(x_cls, z_cls, 'cls')
#         reg_out = self.Xcorr_train(x_reg, z_reg, 'reg')
#         return cls_out, reg_out

#     def track_init(self, z):
#         z_reg = self.conv_z_reg(z)
#         z_cls = self.conv_z_cls(z)
#         return z_reg, z_cls

#     def track_update(self, x, z_cls, z_reg):
#         x_reg = self.conv_x_reg(x)
#         x_cls = self.conv_x_cls(x)
#         cls_out = self.Xcorr_test(x_cls, z_cls, 'cls')
#         reg_out = self.Xcorr_test(x_reg, z_reg, 'reg')
#         return cls_out, reg_out
# ----------------------------------------------------------
# Cross-correlation with a Python loop over the batch; slow
def Xcorr(x, z):
    out = []
    b = x.shape[0]  # correlate each sample independently, then concatenate
    for i in range(b):
        out.append(
            nn.functional.conv2d(x[i, :, :, :].unsqueeze(0), z[i, :, :, :].unsqueeze(0), stride=1, padding=0))

    out = paddle.concat(out, axis=0)
    return out

# Faster correlation using grouped convolution, following pysot
def Xcorr_fast(x, z):
    b = z.shape[0]
    z = paddle.reshape(z, [-1, x.shape[1], 0, 0])  # one kernel per sample
    x = paddle.reshape(x, [1, -1, 0, 0])           # fold the batch into channels
    out = F.conv2d(x, z, groups=b)                 # each sample correlates only with its own template
    out = paddle.reshape(out, [b, -1, 0, 0])
    return out

# At inference the template z is fixed and its batch size is always 1
def Xcorr_test(x, z):
    out = []
    b = x.shape[0]  # correlate each sample independently, then concatenate
    for i in range(b):
        out.append(
            nn.functional.conv2d(x[i, :, :, :].unsqueeze(0), z, stride=1, padding=0))

    out = paddle.concat(out, axis=0)
    return out
# ----------------------------------------------------------
# The SiamFC model
class Siamfc(nn.Layer):
    def __init__(self, out_scale, init=False):
        super(Siamfc, self).__init__()
        self.out_scale = out_scale
        self.backbone = AlexNet(out_channels=128, init=init)

    def forward(self, x, z, mode='train'):  # x:detect z:template
        x = self.backbone(x)
        z = self.backbone(z)
        if mode == 'train':
            # out = Xcorr(x, z) * self.out_scale
            out = Xcorr_fast(x,z) * self.out_scale
            return out
        else:
            out = Xcorr_test(x, z) * self.out_scale
            return out


# class SiamRpn(nn.Layer):
#     def __init__(self, anchor_num):
#         super(SiamRpn, self).__init__()
#         self.anchor_num = anchor_num
#         self.backbone = AlexNet(out_channels=256)
#         self.head = Rpn(self.anchor_num)

#     def forward(self, x, z):
#         x = self.backbone(x)
#         z = self.backbone(z)
#         cls_out, reg_out = self.head(x, z)
#         return cls_out, reg_out

#     def track_init(self, z):
#         z = self.backbone(z)
#         reg_z, cls_z = Rpn(self.anchor_num).track_init(z)
#         return reg_z, cls_z

#     def track_update(self, x, cls_z, reg_z):
#         x = self.backbone(x)
#         cls_out, reg_out = Rpn(self.anchor_num).track_update(x, cls_z, reg_z)
#         return cls_out, reg_out


def main():
    model_SiamFc = Siamfc(out_scale=0.001, init=True)
    z_train = paddle.randn([32, 3, 127, 127])  # exemplar batch (127 matches cfg.exemplar_sz)
    x_train = paddle.randn([32, 3, 255, 255])  # instance batch
    pred_score = model_SiamFc(x_train, z_train, mode='train')
    print(pred_score.shape)



if __name__ == '__main__':
    main()

[32, 1, 17, 17]

3. Data Processing

The training data come from GOT-10k. Each training pair is formed by randomly sampling two frames from a video sequence and processing them into a template and a search image before feeding them to the network.

Let's walk through the processing of a single dataset image:
After reading the original image and its bbox, a patch containing context information is cut from the image. The patch is centered on the bbox center, and its size is computed as

    s_z = sqrt((h + (h+w)/2) * (w + (h+w)/2)),    s_x = s_z * 255 / 127

where s_x and s_z are the search and template patch sizes and h and w are the bbox height and width. Cutting a patch of this size centered on the bbox leads to two cases: if the patch lies entirely inside the image, it is cut directly; if it crosses the image border, the image is first padded with its mean pixel value and the patch is cut afterwards.

The patch is then resized to 255x255, and after augmentations such as random stretching and random cropping we obtain the 255x255 search image and the 127x127 template.
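As a quick numeric check of the formula (a hypothetical 50x80 bbox):

import numpy as np
h, w = 50, 80
context = 0.5 * (h + w)                       # 65
s_z = np.sqrt((h + context) * (w + context))  # sqrt(115 * 145), about 129.1
s_x = s_z * 255 / 127                         # about 259.3
print(s_z, s_x)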

The overall flow: read the image and bbox -> cut the context patch (padding with the mean color if needed) -> resize -> augment.


After unpacking the dataset you also need a list.txt file, which is the index of the dataset; it goes under /work/data/train.

import numbers
from got10k.datasets import *
import paddle
import numpy as np
import cv2


# [l,t,w,h] -> [y,x,h,w]
def convert_coordinate(org_box: list):
    box = np.array([
        org_box[1] - 1 + (org_box[3] - 1) / 2,
        org_box[0] - 1 + (org_box[2] - 1) / 2,
        org_box[3], org_box[2]], dtype=np.float32)
    return box


# Chains a sequence of transforms together
class Compose(object):

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):  # makes an instance callable: a(img) is equivalent to a.__call__(img)
        for t in self.transforms:
            img = t(img)
        return img


# Randomly stretches the image; the scale factor lies in [0.95, 1.05]. Note that cv2.resize takes (width, height).
class RandomStretch(object):

    def __init__(self, max_stretch=0.05):
        self.max_stretch = max_stretch

    def __call__(self, img):
        interp = np.random.choice([  # pick an interpolation method at random
            cv2.INTER_LINEAR,  # bilinear (cv2 default)
            cv2.INTER_CUBIC,  # bicubic over a 4x4 pixel neighborhood
            cv2.INTER_AREA,  # pixel-area-relation resampling
            cv2.INTER_NEAREST,  # nearest neighbor
            cv2.INTER_LANCZOS4])  # Lanczos over an 8x8 pixel neighborhood
        scale = 1.0 + np.random.uniform(
            -self.max_stretch, self.max_stretch)
        out_size = (
            round(img.shape[1] * scale),  # width
            round(img.shape[0] * scale))  # height; cv2.resize expects (w, h)
        return cv2.resize(img, out_size, interpolation=interp)


# Crops a (size, size) patch from the center of img; if the image is too small, pads with the mean pixel value before cropping
class CenterCrop(object):

    def __init__(self, size):
        if isinstance(size, numbers.Number):  # a plain number means a square crop
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, img):
        h, w = img.shape[:2]  # img.shape is [height, width, channel]
        tw, th = self.size
        i = round((h - th) / 2.)  # top-left corner of the crop
        j = round((w - tw) / 2.)

        npad = max(0, -i, -j)
        if npad > 0:
            avg_color = np.mean(img, axis=(0, 1))  # mean pixel value of the whole image
            img = cv2.copyMakeBorder(  # pad npad pixels on all four sides with the constant avg_color
                img, npad, npad, npad, npad,
                cv2.BORDER_CONSTANT, value=avg_color)
            i += npad
            j += npad

        return img[i:i + th, j:j + tw]


# Like CenterCrop but crops at a random position; padding is not needed here
class RandomCrop(object):

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, img):
        h, w = img.shape[:2]
        tw, th = self.size
        i = np.random.randint(0, h - th + 1)
        j = np.random.randint(0, w - tw + 1)
        return img[i:i + th, j:j + tw]


# Resizes the image to a fixed square size
class Resize(object):
    def __init__(self,size):
        self.size = size

    def __call__(self,img):
        interp = np.random.choice([  # pick an interpolation method at random
            cv2.INTER_LINEAR,  # bilinear (cv2 default)
            cv2.INTER_CUBIC,  # bicubic over a 4x4 pixel neighborhood
            cv2.INTER_AREA,  # pixel-area-relation resampling
            cv2.INTER_NEAREST,  # nearest neighbor
            cv2.INTER_LANCZOS4])  # Lanczos over an 8x8 pixel neighborhood
        out = cv2.resize(img,dsize=(self.size,self.size),interpolation=interp)
        return out


class ToTensor(object):
    def __call__(self, img):
        out = paddle.to_tensor(img,'float32')
        return paddle.transpose(out,perm=[2,0,1])


class SiamFCTransforms(object):
    def __init__(self, exemplar_sz=127, instance_sz=255, context=0.5):
        self.exemplar_sz = exemplar_sz
        self.instance_sz = instance_sz
        self.context = context
        self.transforms_z = Compose(
            [RandomStretch(),
             CenterCrop(instance_sz - 8),
             RandomCrop(instance_sz - 16),
             CenterCrop(exemplar_sz),
             ToTensor()])
        self.transforms_x = Compose([
            RandomStretch(),
            CenterCrop(instance_sz - 8),
            RandomCrop(instance_sz - 16),
            Resize(instance_sz),
            ToTensor()])


        
    def __call__(self, z, x, box_z, box_x):
        z = self._crop(z, box_z, self.instance_sz)  # 1. convert the box (l,t,w,h)->(y,x,h,w) in float32, giving center [y,x] and target_sz [h,w]
        x = self._crop(x, box_x, self.instance_sz)  # 2. patch size = sqrt((h+(h+w)/2)*(w+(h+w)/2)) * instance_sz / exemplar_sz
        z = self.transforms_z(z)  # 3. crop_and_resize cuts a square of side `size` around `center` (padding with the mean color if necessary)
        x = self.transforms_x(x)  #    and resizes it to out_size = instance_sz; the transforms then yield the final 127 / 255 patches
        return z, x

    def _crop(self, img, box, out_size):
        box = convert_coordinate(box)  # convert [l,t,w,h] to [y,x,h,w]
        center,target_sz = box[:2],box[2:]

        context = self.context * np.sum(target_sz)
        size = np.sqrt(np.prod(target_sz + context))
        size *= out_size / self.exemplar_sz

        avg_color = np.mean(img, axis=(0, 1), dtype=float)
        interp = np.random.choice([
            cv2.INTER_LINEAR,
            cv2.INTER_CUBIC,
            cv2.INTER_AREA,
            cv2.INTER_NEAREST,
            cv2.INTER_LANCZOS4])
        patch = crop_and_resize(
            img, center, size, out_size, border_value=avg_color, interp=interp)

        return patch


def crop_and_resize(img, center, size, out_size,
                    border_type=cv2.BORDER_CONSTANT,
                    border_value=(0, 0, 0),  # in practice the image mean (avgR, avgG, avgB) is passed in
                    interp=cv2.INTER_LINEAR):
    size = round(size)  # round the patch side to an integer
    corners = np.concatenate((  # corners = [ymin, xmin, ymax, xmax]
        np.round(center - (size - 1) / 2),
        np.round(center - (size - 1) / 2) + size))
    corners = np.round(corners).astype(int)  # cast to int
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # channel swap; note this undoes the BGR->RGB conversion done by the caller
    # print(img.shape)
    # cv2.imshow('original', img)
    # cv2.imwrite('original.png', img)
    # padding
    pads = np.concatenate((
        -corners[:2], corners[2:] - img.shape[:2]))
    npad = max(0, int(max(pads)))  # largest overflow over the four sides, clipped at 0
    if npad > 0:
        img = cv2.copyMakeBorder(
            img, npad, npad, npad, npad,
            border_type, value=border_value)

    # crop image patch
    corners = (corners + npad).astype(int)  # after padding, the corner coordinates shift by npad
    # print(corners)
    patch = img[corners[0]:corners[2], corners[1]:corners[3]]  # cut out the patch
    # cv2.imshow('padding_img',img)
    # cv2.imwrite('padding_img.png', img)
    # cv2.imshow('contest_img',patch)
    # cv2.imwrite('contest_img.png', patch)
    # resize to out_size
    patch = cv2.resize(patch, (out_size, out_size),
                       interpolation=interp)
    # cv2.imshow('resize255_img',patch)
    # cv2.imwrite('resize255_img.png', patch)
    # cv2.waitKey(0)  # only needed together with the imshow calls above
    return patch
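A quick standalone check of crop_and_resize on a synthetic image (all values made up): note that center is given in (y, x) order, and no padding is triggered here because the patch lies fully inside the frame.

import numpy as np
img = (np.random.rand(360, 480, 3) * 255).astype(np.uint8)  # fake 360x480 frame
patch = crop_and_resize(img, center=np.array([120., 150.]), size=200,
                        out_size=255, border_value=img.mean(axis=(0, 1)))
print(patch.shape)  # (255, 255, 3)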


class GOT10kDataset(paddle.io.Dataset):
    def __init__(self, seqs, transforms=None, pairs_per_seq=1):
        super(GOT10kDataset, self).__init__()
        self.seqs = seqs
        self.transforms = transforms
        self.pairs_per_seq = pairs_per_seq
        self.indices = np.random.permutation(len(seqs))
        self.return_meta = getattr(seqs, 'return_meta', False)  # False if seqs has no return_meta attribute

    # Returns item=(z,x,box_z,box_x) for an index; the transforms turn it into a (z, x) pair
    def __getitem__(self, index):
        # print(self.indices)
        index = self.indices[index % len(self.indices)]
        # print(index)
        # index = self.indices[index]  # equivalent to the line above
        # get filename lists and annotations
        if self.return_meta:
            img_files, anno, meta = self.seqs[index]
            vis_ratios = meta.get('cover', None)
        else:
            img_files, anno = self.seqs[index][:2]
            vis_ratios = None

        val_indices = self._filter(
            cv2.imread(img_files[0], cv2.IMREAD_COLOR),
            anno, vis_ratios)
        if len(val_indices) < 2:
            index = np.random.choice(len(self))
            return self.__getitem__(index)

        rand_z, rand_x = self._sample_pair_(val_indices)

        z = cv2.imread(img_files[rand_z], cv2.IMREAD_COLOR)
        x = cv2.imread(img_files[rand_x], cv2.IMREAD_COLOR)
        z = cv2.cvtColor(z, cv2.COLOR_BGR2RGB)
        x = cv2.cvtColor(x, cv2.COLOR_BGR2RGB)

        box_z = anno[rand_z]
        box_x = anno[rand_x]

        item = (z, x, box_z, box_x)  # the boxes are the ground truth
        if self.transforms is not None:
            item = self.transforms(*item)

        return item

    # the dataset length is (number of indexed sequences) x (pairs per sequence)
    def __len__(self):
        return len(self.indices) * self.pairs_per_seq  # 9335 sequences in the full train split -> 9335*1 pairs

    # Randomly picks two frame indices at most T=100 frames apart
    def _sample_pair_(self, indices):
        n = len(indices)
        assert n > 0
        if n == 1:
            return indices[0], indices[0]
        elif n == 2:
            return indices[0], indices[1]
        else:
            for i in range(100):
                rand_z, rand_x = np.sort(
                    np.random.choice(indices, 2, replace=False))
                if rand_x - rand_z < 100:
                    break
            else:
                rand_z = np.random.choice(indices)
                rand_x = rand_z

            return rand_z, rand_x

    # Filters the annotations and returns the valid indices val_indices
    def _filter(self, img0, anno, vis_ratios=None):
        size = np.array(img0.shape[1::-1])[np.newaxis, :]
        areas = anno[:, 2] * anno[:, 3]

        # acceptance conditions
        c1 = areas >= 20
        c2 = np.all(anno[:, 2:] >= 20, axis=1)
        c3 = np.all(anno[:, 2:] <= 500, axis=1)
        c4 = np.all((anno[:, 2:] / size) >= 0.01, axis=1)
        c5 = np.all((anno[:, 2:] / size) <= 0.5, axis=1)
        c6 = (anno[:, 2] / np.maximum(1, anno[:, 3])) >= 0.25
        c7 = (anno[:, 2] / np.maximum(1, anno[:, 3])) <= 4
        if vis_ratios is not None:
            c8 = (vis_ratios > max(1, vis_ratios.max() * 0.3))
        else:
            c8 = np.ones_like(c1)

        mask = np.logical_and.reduce(
            (c1, c2, c3, c4, c5, c6, c7, c8))
        val_indices = np.where(mask)[0]

        return val_indices


if __name__ == "__main__":
    root_dir = 'work/data/'
    seq_dataset = GOT10k(root_dir, subset='train')
    transforms = SiamFCTransforms(
        exemplar_sz=cfg.exemplar_sz,  # 127
        instance_sz=cfg.instance_sz,  # 255
        context=cfg.context)  # 0.5
    train_dataset = GOT10kDataset(seq_dataset, transforms)
    item = train_dataset.__getitem__(1)  # a processed pair of frames from a random video sequence
    print(item[0].shape)
    print(train_dataset.__len__())


[3, 127, 127]
1500

4. Loss Function

The paper uses the logistic loss

    l(y, v) = log(1 + exp(-y * v))

where y is the ground-truth label (+1 or -1) and v is the predicted template-search similarity at one position of the 17x17 output, so l(y, v) is the loss at a single output position.
The loss of the whole output map is the mean over all positions:

    L(y, v) = (1 / |D|) * sum_{u in D} l(y[u], v[u])

where D is the output map: the total loss is the sum of the per-position losses divided by the size of the output.

In the code, l is implemented with binary_cross_entropy_with_logits, and the labels are constructed as {0, 1} rather than the paper's {-1, +1}. The two are equivalent: with target t in {0, 1}, BCE-with-logits is -t*log(sigmoid(v)) - (1-t)*log(1-sigmoid(v)); for t = 1 this equals log(1 + exp(-v)) = l(+1, v), and for t = 0 it equals log(1 + exp(v)) = l(-1, v).
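A quick numeric check of this equivalence (a sketch of my own, not from the original code):

import numpy as np
v = 0.8                            # an arbitrary logit
s = 1 / (1 + np.exp(-v))           # sigmoid(v)
print(np.log(1 + np.exp(-v)), -np.log(s))      # y = +1 vs BCE with t = 1: both 0.3711...
print(np.log(1 + np.exp(v)), -np.log(1 - s))   # y = -1 vs BCE with t = 0: both 1.1711...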

# loss and labels
import numpy as np
import paddle
import paddle.nn as nn


def logistic_labels(x, y, r_pos):
    # positions with sqrt(x^2 + y^2) <= r_pos are labeled 1, the rest 0
    dist = np.sqrt(x ** 2 + y ** 2)
    labels = np.where(dist <= r_pos,  # r_pos = 16/8 = 2
                      np.ones_like(x),
                      np.zeros_like(x))
    return labels


def get_label(size):

    n, c, h, w = size    # [8,1,17,17]
    x = np.arange(w) - (w - 1) / 2
    y = np.arange(h) - (h - 1) / 2
    x, y = np.meshgrid(x, y)

    r_pos = cfg.r_pos / cfg.total_stride
    labels = logistic_labels(x, y, r_pos)
    labels = labels.reshape((1, 1, h, w))
    labels = np.tile(labels, (n, c, 1, 1))
    return labels


class GetLoss(nn.Layer):
    def __init__(self, neg_weight=1.0):
        super(GetLoss, self).__init__()
        self.neg_weight = neg_weight

    def forward(self, input, target):
        pos_mask = (target == 1)
        neg_mask = (target == 0)
        pos_num = float(pos_mask.sum())
        neg_num = float(neg_mask.sum())
        # balance the classes: positives and negatives each contribute equal total weight
        weight = paddle.to_tensor(np.zeros(target.shape), 'float32')
        weight[pos_mask] = 1 / pos_num
        weight[neg_mask] = 1 / neg_num * self.neg_weight
        weight /= weight.sum()
        return paddle.nn.functional.binary_cross_entropy_with_logits(
            input, target, weight, reduction='sum')


if __name__ == '__main__':
    labels = get_label([8, 1, 17, 17])
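    # Added sanity checks (my own sketch): the 5x5 window around the center of the
    # 17x17 label map shows the positive disc of radius r_pos/total_stride = 2, and
    # random logits scored against these labels give a loss on the order of 0.7.
    print(labels[0, 0, 6:11, 6:11])
    logits = paddle.randn([8, 1, 17, 17])
    criterion = GetLoss()
    print(criterion(logits, paddle.to_tensor(labels, 'float32')))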

5. Training

With the loss function and labels in place, training can begin. The parameters are optimized with stochastic gradient descent, and the learning rate is decayed exponentially:

    lr_epoch = lr_0 * gamma^epoch,    gamma = (lr_final / lr_0)^(1 / Epoch)

where lr_0 is the initial learning rate and gamma is the decay factor; epoch is the current epoch and Epoch is the total number of epochs.

Parameter settings: the learning rate decays from 1e-2 to 1e-5, 50 epochs in total, batch size 64.
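As a quick check of the decay factor (my own sketch): gamma = (1e-5 / 1e-2)^(1/50) ≈ 0.871, which matches the lr of 8.71e-03 printed for epoch 2 in the log further below.

import numpy as np
print(np.power(1e-5 / 1e-2, 1.0 / 50))  # 0.87096...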

A scale factor is applied to the network output because the raw correlation values are large; computing the loss on them directly leads to exploding gradients during the parameter update.

# training
import paddle
from got10k.datasets import *
from paddle.io import DataLoader
from paddle import optimizer
from tqdm import tqdm
from paddle.optimizer.lr import ExponentialDecay
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

def train(data_dir, net_path=None, save_dir='pre_trained'):
    # load the dataset
    seq_dataset = GOT10k(data_dir, subset='train', return_meta=False)
    # define the data augmentations
    transforms = SiamFCTransforms(exemplar_sz=cfg.exemplar_sz,
                                  instance_sz=cfg.instance_sz,
                                  context=cfg.context)
    train_dataset = GOT10kDataset(seq_dataset, transforms)

    loader_dataset = DataLoader(dataset=train_dataset,
                                batch_size=cfg.batch_size,
                                shuffle=True,
                                num_workers=cfg.num_workers,
                                drop_last=True)

    # initialize the network
    paddle.device.set_device('gpu')
    model = Siamfc(out_scale=cfg.out_scale,init=True)
    # build the loss function and the labels
    getloss = GetLoss()
    labels = get_label(size=[cfg.batch_size, 1, cfg.response_sz, cfg.response_sz])
    labels = paddle.to_tensor(labels,'float32')

    # build the optimizer
    gamma = np.power(cfg.ultimate_lr / cfg.initial_lr, 1.0 / cfg.epoch_num)
    lr_scheduler = ExponentialDecay(cfg.initial_lr, gamma)  # exponential decay
    opt = optimizer.SGD(learning_rate=lr_scheduler,
                        parameters=model.parameters(),
                        weight_decay=cfg.weight_decay)  # L2 weight decay



    # training loop
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    start_epoch = 1
    for epoch in range(start_epoch,cfg.epoch_num+1):
        model.train()
        for it, batch in enumerate(tqdm(loader_dataset)):
            # fetch the inputs
            z = batch[0]  # z.shape = [64, 3, 127, 127]
            x = batch[1]  # x.shape = [64, 3, 255, 255]
            # forward pass
            output = model(x,z)
            loss = getloss(output,labels)
            # backward pass
            opt.clear_grad()
            loss.backward()
            opt.step()
            print('Epoch: {}[{}/{}]    Loss: {:.5f}    lr: {:.2e}'.format(
                    epoch, it + 1, len(loader_dataset), loss.item(), opt.get_lr()))
        # update the learning rate
        lr_scheduler.step()
        # print(lr_scheduler)
        # save checkpoint
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        save_path = os.path.join(
            save_dir, 'siamfc_alexnet_e%d.pdparams' % (epoch))

        paddle.save({'epoch': epoch,
                    'model': model.state_dict(),
                    'optimizer': opt.state_dict()}, save_path)


if __name__ == '__main__':
    train('work/data/')

  0%|          | 0/23 [00:00<?, ?it/s]
Epoch: 1[1/23]    Loss: 0.93592    lr: 1.00e-02
Epoch: 1[2/23]    Loss: 0.80803    lr: 1.00e-02
Epoch: 1[3/23]    Loss: 0.74549    lr: 1.00e-02
...
Epoch: 1[23/23]    Loss: 0.59723    lr: 1.00e-02
Epoch: 2[1/23]    Loss: 0.60008    lr: 8.71e-03
...
Epoch: 2[23/23]    Loss: 0.58262    lr: 8.71e-03
Epoch: 3[23/23]    Loss: 0.55234    lr: 7.59e-03
Epoch: 4[23/23]    Loss: 0.56480    lr: 6.61e-03
Epoch: 5[23/23]    Loss: 0.53630    lr: 5.75e-03
Epoch: 6[23/23]    Loss: 0.53160    lr: 5.01e-03
Epoch: 7[23/23]    Loss: 0.55369    lr: 4.37e-03
Epoch: 8[23/23]    Loss: 0.53947    lr: 3.80e-03
Epoch: 9[23/23]    Loss: 0.53795    lr: 3.31e-03
Epoch: 10[14/23]    Loss: 0.53252    lr: 2.88e-03

(progress bars trimmed; the loss falls from 0.94 to about 0.53 over the first ten epochs, after which this run was interrupted)

RuntimeError: DataLoader 2 workers exit unexpectedly, pids: 4275, 4276
KeyboardInterrupt

6. Inference

Inference consists of three main steps: processing the initial frame, computing the response map with a forward pass, and recovering the target location from the response.

Processing the initial frame produces the template used for all subsequent tracking, and this template stays fixed. It takes the first frame and its ground-truth box and outputs a [6,6,128] feature map that serves as a fixed convolution kernel; the template is never updated online. This keeps the tracker fast, but it easily loses the target under fast motion or deformation.

With the first-frame template in hand, tracking proceeds by feeding each subsequent frame as the search image x together with the first frame as the template z, producing a 17x17 response map in which we locate the maximum response.

Because the 17x17 map is coarse, it is upsampled to 272x272 with bicubic interpolation. In practice the target's size changes over time, so several scales are searched to adapt the box size, and a Hanning-window penalty is applied to favor locations near the previous position. Repeatedly updating the target center and box size completes the tracking loop.
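Concretely, the windowed response and the back-projected displacement computed in update() below are (notation mine):

    R' = (1 - w) * R + w * H,  with w = window_influence = 0.176 and H the normalized Hanning window
    disp_image = (loc - (272 - 1) / 2) / response_up * total_stride * (x_sz * s_k / instance_sz)

where loc is the position of the peak in the 272x272 map, response_up = 16, total_stride = 8, and s_k is the selected scale factor.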

# inference

import time
import paddle
import numpy as np
import cv2
from got10k.trackers import Tracker



def read_image(img_file, cvt_code=cv2.COLOR_BGR2RGB):  # convert BGR to RGB
    img = cv2.imread(img_file, cv2.IMREAD_COLOR)  # cv2.imread loads a color image in BGR order with shape [height, width, channel]
    if cvt_code is not None:  # always true here, since cvt_code has a default value
        img = cv2.cvtColor(img, cvt_code)
    return img


def show_image(img, boxes=None, box_fmt='ltwh', colors=None,
               thickness=3, fig_n=1, delay=1, visualize=True,
               cvt_code=cv2.COLOR_RGB2BGR):
    if cvt_code is not None:
        img = cv2.cvtColor(img, cvt_code)  # cv2 displays BGR, so convert RGB back to BGR

    # resize the image if necessary
    max_size = 960  # cap the longer side at 960
    if max(img.shape[:2]) > max_size:
        scale = max_size / max(img.shape[:2])
        out_size = (
            int(img.shape[1] * scale),  # scaled width
            int(img.shape[0] * scale))  # scaled height
        img = cv2.resize(img, out_size)
        if boxes is not None:
            boxes = np.array(boxes, dtype=np.float32) * scale

    if boxes is not None:
        assert box_fmt in ['ltwh', 'ltrb']
        boxes = np.array(boxes, dtype=np.int32)  # boxes.shape = (4,)
        if boxes.ndim == 1:  # a single box
            boxes = np.expand_dims(boxes, axis=0)  # (4,) -> (1, 4)
        if box_fmt == 'ltrb':
            boxes[:, 2:] -= boxes[:, :2]

        # clip bounding boxes
        bound = np.array(img.shape[1::-1])[None, :]  # img.shape[1::-1] is (w, h); [None, :] makes it shape (1, 2)
        boxes[:, :2] = np.clip(boxes[:, :2], 0, bound)  # clip the top-left corners
        boxes[:, 2:] = np.clip(boxes[:, 2:], 0, bound - boxes[:, :2])  # clip widths and heights

        if colors is None:
            colors = [
                (0, 0, 255),
                (0, 255, 0),
                (255, 0, 0),
                (0, 255, 255),
                (255, 0, 255),
                (255, 255, 0),
                (0, 0, 128),
                (0, 128, 0),
                (128, 0, 0),
                (0, 128, 128),
                (128, 0, 128),
                (128, 128, 0)]  # len(colors)=12
        colors = np.array(colors, dtype=np.int32)  # colors.shape=[12 3]
        if colors.ndim == 1:
            colors = np.expand_dims(colors, axis=0)

        for i, box in enumerate(boxes): 
            color = colors[i % len(colors)]  # cycle through the 12 colors
            pt1 = (box[0], box[1])
            pt2 = (box[0] + box[2], box[1] + box[3])
            img = cv2.rectangle(img, pt1, pt2, color.tolist(), thickness)

    if visualize:
        winname = 'window_{}'.format(fig_n)  # e.g. window_1
        cv2.imshow(winname, img)
        cv2.waitKey(delay)  # wait `delay` milliseconds

    return img


def ltwh_to_yxhw(ltwh):
    yxhw = np.array([
        ltwh[1] - 1 + (ltwh[3] - 1) / 2,
        ltwh[0] - 1 + (ltwh[2] - 1) / 2,
        ltwh[3], ltwh[2]], dtype=np.float32)
    return yxhw


def yxhw_to_ltwh(yxhw):
    ltwh = np.array([
        yxhw[1] + 1 - (yxhw[3] - 1) / 2,
        yxhw[0] + 1 - (yxhw[2] - 1) / 2,
        yxhw[3], yxhw[2]])
    return ltwh


def map_process(response, hanning_window):
    # normalize the response and apply the Hanning-window penalty
    response -= response.min()
    response /= response.sum() + 1e-16
    response = (1 - cfg.window_influence) * response + \
               cfg.window_influence * hanning_window  # window_influence=0.176
    return response


def map_to272(responses, out_size):
    responses = np.stack([cv2.resize(
        u, (out_size, out_size),
        interpolation=cv2.INTER_CUBIC)
        for u in responses])
    return responses


def x_to3s255(img, center, patch_size, three_scales, out_size, border_value):
    x = [crop_and_resize(
        img, center, patch_size * scale,
        out_size=out_size,
        border_value=border_value) for scale in three_scales]
    x = np.stack(x, axis=0)  # [3, 255, 255, 3]; the leading 3 is the three scales
    return x


def create_hanning_window(size):
    hann_window = np.outer(
        np.hanning(size),
        np.hanning(size))  # for row vectors a and b, np.outer(a, b) = a^T b
    hann_window /= hann_window.sum()
    return hann_window


def scales():
    scale_factors = cfg.scale_step ** np.linspace(  # 1.0375^(-1,0,1)
        -(cfg.scale_num // 2),
        cfg.scale_num // 2, cfg.scale_num)
    return scale_factors


def z_to127(img, center, patch_size, out_size, border_value):
    z = crop_and_resize(
        img, center, patch_size,
        out_size=out_size,
        border_value=border_value)
    return z


class TrackerSiamFC(Tracker):
    def __init__(self, net_path=None):
        super(TrackerSiamFC, self).__init__('SiamFC', True)
        self.model = Siamfc(out_scale=cfg.out_scale)
        if net_path is not None:
            checkpoint = paddle.load(net_path)
            self.model.load_dict(checkpoint['model'])

    # initialize with the first frame and its ground-truth box
    def init(self, img, box):
        # switch to eval mode (fixes the BatchNorm statistics)
        self.model.eval()
        # convert the box from [l,t,w,h] to [center_y, center_x, h, w]
        yxhw = ltwh_to_yxhw(box)
        self.center, self.target_sz = yxhw[:2], yxhw[2:]
        # Hanning window
        self.response_upsz = cfg.response_up * cfg.response_sz
        self.hanning_window = create_hanning_window(size=self.response_upsz)
        # three scales: 1.0375 ** (-1, 0, 1)
        self.scale_factors = scales()
        # patch side lengths
        context = cfg.context * np.sum(self.target_sz)  # context amount (h+w)/2
        self.z_sz = np.sqrt(np.prod(self.target_sz + context))  # sqrt((h+(h+w)/2)*(w+(h+w)/2))
        self.x_sz = self.z_sz * cfg.instance_sz / cfg.exemplar_sz  # z_sz * 255 / 127
        # mean RGB value of the image
        self.avg_color = np.mean(img, axis=(0, 1))
        # crop a patch of side z_sz centered on the target and resize it to exemplar_sz
        z = z_to127(img, self.center, self.z_sz, cfg.exemplar_sz, self.avg_color)
        z = paddle.transpose(paddle.to_tensor(z, 'float32'), perm=[2, 0, 1]).unsqueeze(0)
        self.sample = z

    def update(self, img):
        self.model.eval()
        x = x_to3s255(img, self.center, self.x_sz, self.scale_factors, cfg.instance_sz, self.avg_color)
        x = paddle.to_tensor(x, 'float32')
        x = paddle.transpose(x, perm=[0, 3, 1, 2])
        # x: [3, 3, 255, 255]; the leading 3 is the three scales
        responses = self.model(x, self.sample)
        responses = responses.squeeze(1).cpu().numpy()
        # response size: 17x17 -> 272x272
        responses = map_to272(responses, out_size=self.response_upsz)
        # penalize scale changes (all scales except the middle, unchanged one)
        responses[:cfg.scale_num // 2] *= cfg.scale_penalty
        responses[cfg.scale_num // 2 + 1:] *= cfg.scale_penalty
        # find the scale with the largest response
        scale_id = np.argmax(np.amax(responses, axis=(1, 2)))
        response = responses[scale_id]  # [272, 272]
        # post-process the response map
        response = map_process(response, self.hanning_window)
        loc = np.unravel_index(response.argmax(), response.shape)  # coordinates of the peak response
        # project the displacement back to the original image
        disp_in_response = np.array(loc) - (self.response_upsz - 1) / 2
        disp = disp_in_response / cfg.response_up  # back to the 17x17 response map
        disp = disp * cfg.total_stride  # back to instance-image pixels
        disp = disp * self.x_sz * self.scale_factors[scale_id] / cfg.instance_sz  # back to original-image pixels
        self.center += disp
        # update the scale
        scale = (1 - cfg.scale_lr) * 1.0 + cfg.scale_lr * self.scale_factors[scale_id]
        self.target_sz *= scale  # new target height and width
        self.z_sz *= scale
        self.x_sz *= scale
        # [y,x,h,w] -> [l,t,w,h]
        box = yxhw_to_ltwh(np.concatenate([self.center, self.target_sz]))
        return box


    def track(self, img_files, box, visualize=False):
        frame_num = len(img_files)
        boxes = np.zeros((frame_num, 4))
        boxes[0] = box
        times = np.zeros(frame_num)

        for f, img_file in enumerate(img_files):
            img = read_image(img_file)
            begin = time.time()
            if f == 0:
                self.init(img, box)  # the first frame only initializes the template
            else:
                boxes[f, :] = self.update(img)
            times[f] = time.time() - begin

            if visualize:
                show_image(img, boxes[f, :])
        return boxes, times

7. Testing

Since the notebook does not support cv2.imshow, you have to run this on your own machine to see the tracker in action; quantitative results on the benchmark datasets are still to be added.
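For quantitative evaluation, the got10k toolkit ships experiment helpers; a sketch (assuming an OTB-style dataset placed under work/OTB) would look like:

# from got10k.experiments import ExperimentOTB
# e = ExperimentOTB('work/OTB', version=2015)
# e.run(tracker, visualize=False)
# e.report([tracker.name])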

The plan going forward is to reproduce SiamRPN and SiamRPN++; still learning.

# testing

import os
import paddle
import glob
import numpy as np

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

if __name__ == '__main__':
    seq_dir = os.path.expanduser('work/Crossing/')
    img_files = sorted(glob.glob(seq_dir + 'img/*.jpg'))
    # print(img_files[0])
    anno = np.loadtxt(seq_dir + 'groundtruth_rect.txt', delimiter='\t')  # load the ground truth
    net_path = 'pre_trained/siamfc_alexnet_e49.pdparams'
    tracker = TrackerSiamFC(net_path=net_path)
    tracker.track(img_files, anno[0], visualize=False)
