[Object Tracking] SiamFC: A Siamese-Network-Based Object Tracker
A Paddle implementation of SiamFC, a classic object-tracking network
1. Introduction
SiamFC, a Siamese-network-based single-object tracker, is one of the most important classic works in single-object tracking. This article briefly introduces SiamFC from four angles: network architecture, data processing, loss function, and tracking procedure. These are personal study notes; if anything is wrong or misunderstood, discussion is welcome.
Paper: https://arxiv.org/pdf/1606.09549.pdf
This project uses the GOT-10k dataset.
Dataset website: http://got-10k.aitestunion.com/
The Paddle implementation draws on the various PyTorch versions on GitHub:
PyTorch version: https://github.com/NieHa0ha0/siamfc-pytorch-got10k
# Unpack the dataset and install the dataset toolkit
# Only split_01-03 are used here for testing; list.txt is the index file for the data
# !mkdir work/data
# !mkdir work/data/train
# !mv work/list.txt work/data/train/
# !unzip data/data126011/GOT-10k_Train_split_01.zip -d work/data/train/
# !unzip data/data126011/GOT-10k_Train_split_02.zip -d work/data/train/
# !unzip data/data126011/GOT-10k_Train_split_03.zip -d work/data/train/
# !pip install got10k
# Configuration parameters
class Config:
    # basic parameters
    out_scale = 0.001         # scale applied after cross-correlation; the raw values are so large they cause exploding gradients
    exemplar_sz = 127         # size of the template (exemplar) image
    instance_sz = 255         # size of the search image
    context = 0.5             # amount of context kept around the target
    # inference parameters
    scale_num = 3             # number of scales searched during inference
    scale_step = 1.0375       # step between adjacent search scales
    scale_lr = 0.59           # learning rate of the scale update
    scale_penalty = 0.9745    # penalty applied to scale changes
    window_influence = 0.176  # weight of the Hanning-window penalty
    response_sz = 17          # size of the response map
    response_up = 16          # upsampling factor for the response map
    total_stride = 8          # total stride of the backbone
    # train parameters
    epoch_num = 50
    batch_size = 64
    num_workers = 2
    initial_lr = 1e-2
    ultimate_lr = 1e-5
    weight_decay = 5e-4
    r_pos = 16
    r_neg = 0
cfg = Config()
2. Network Architecture
The SiamFC architecture is shown in the figure. The network takes two inputs: z, the exemplar (template), and x, the search image. Both pass through a weight-sharing feature extractor, yielding a 6x6x128 feature map for the template and a 22x22x128 feature map for the search image.
The feature extractor is a modified AlexNet. Its layers, with output sizes computed from the implementation below (channels x height x width, for the 127x127 template / the 255x255 search image):
conv1  11x11, stride 2, 96 ch:   96x59x59  /  96x123x123
pool1  3x3 max, stride 2:        96x29x29  /  96x61x61
conv2  5x5, stride 1, 256 ch:    256x25x25 /  256x57x57
pool2  3x3 max, stride 2:        256x12x12 /  256x28x28
conv3  3x3, stride 1, 384 ch:    384x10x10 /  384x26x26
conv4  3x3, stride 1, 384 ch:    384x8x8   /  384x24x24
conv5  3x3, stride 1, 128 ch:    128x6x6   /  128x22x22
Cross-correlating the two feature maps then produces a 17x17 response map, in which each value measures the similarity between the template and the search image at that position. The cross-correlation is implemented by using the template feature map as a convolution kernel and convolving it over the search feature map, via paddle.nn.functional.conv2d. Beyond this plain correlation, later work introduced depthwise correlation which, in the spirit of depthwise-separable convolution, correlates the channels separately and reduces computation.
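For comparison only, here is a minimal sketch of depthwise cross-correlation in Paddle (the variant popularized by SiamRPN++ and pysot); SiamFC below does not use it, and the function name is my own:

import paddle
import paddle.nn.functional as F

def xcorr_depthwise(x, z):
    # x: [B, C, Hx, Wx] search features; z: [B, C, Hz, Wz] template features
    b, c = x.shape[0], x.shape[1]
    x = paddle.reshape(x, [1, b * c, x.shape[2], x.shape[3]])  # fold the batch into the channel axis
    z = paddle.reshape(z, [b * c, 1, z.shape[2], z.shape[3]])  # one single-channel kernel per (sample, channel)
    out = F.conv2d(x, z, groups=b * c)                         # correlate every channel independently
    return paddle.reshape(out, [b, c, out.shape[2], out.shape[3]])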
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import time
paddle.set_device('cpu')
# backbone
class AlexNet(nn.Layer):
def __init__(self, out_channels, init=False, conv_weight_init=None, bias_init=None):
super(AlexNet, self).__init__()
self.conv_weight_init = conv_weight_init
self.bias_init = bias_init
        # if init is requested, set the two initializer attributes (this must run before the conv layers below are created)
if init:
self._init_weights()
self.conv1 = nn.Sequential(
nn.Conv2D(in_channels=3, out_channels=96, kernel_size=11, stride=2, padding=0
, weight_attr=self.conv_weight_init, bias_attr=self.bias_init),
nn.BatchNorm2D(96),
nn.ReLU(),
nn.MaxPool2D(3, 2))
self.conv2 = nn.Sequential(
nn.Conv2D(in_channels=96, out_channels=256, kernel_size=5, stride=1, padding=0
, weight_attr=self.conv_weight_init, bias_attr=self.bias_init),
nn.BatchNorm2D(256),
nn.ReLU(),
nn.MaxPool2D(3, 2))
self.conv3 = nn.Sequential(
nn.Conv2D(in_channels=256, out_channels=384, kernel_size=3, stride=1, padding=0
, weight_attr=self.conv_weight_init, bias_attr=self.bias_init),
nn.BatchNorm2D(384),
nn.ReLU())
self.conv4 = nn.Sequential(
nn.Conv2D(in_channels=384, out_channels=384, kernel_size=3, stride=1, padding=0
, weight_attr=self.conv_weight_init, bias_attr=self.bias_init),
nn.BatchNorm2D(384),
nn.ReLU())
self.conv5 = nn.Sequential(
nn.Conv2D(in_channels=384, out_channels=out_channels, kernel_size=3, stride=1, padding=0
, weight_attr=self.conv_weight_init, bias_attr=self.bias_init))
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
x = self.conv4(x)
x = self.conv5(x)
return x
    # Parameter initialization: XavierUniform for weights, Constant(0) for biases
def _init_weights(self):
self.conv_weight_init = nn.initializer.XavierUniform()
self.bias_init = nn.initializer.Constant(value=0)
# RPN head (commented out; not needed yet, kept for a future SiamRPN reproduction)
# class Rpn(nn.Layer):
# def __init__(self, anchor_num):
# super(Rpn, self).__init__()
# self.anchor_num = anchor_num # number of anchors
# self.conv_x_cls = nn.Conv2D(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=0)
# self.conv_x_reg = nn.Conv2D(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=0)
# self.conv_z_cls = nn.Conv2D(in_channels=256, out_channels=2 * anchor_num * 256, kernel_size=3, stride=1,
# padding=0)
# self.conv_z_reg = nn.Conv2D(in_channels=256, out_channels=4 * anchor_num * 256, kernel_size=3, stride=1,
# padding=0)
# def Xcorr_train(self, x, z, type):
# out = []
# b = x.shape[0]
# k = 2 if type == 'cls' else 4
# for i in range(b):
# out.append(
# F.conv2d(x[i, :, :, :].unsqueeze(0),
# paddle.reshape(z[i, :, :, :].unsqueeze(0), [k * self.anchor_num, -1, 0, 0]),
# reshape the template tensor into [k * anchor_num, -1, h, w] kernels
# stride=1, padding=0))
# out = paddle.concat(out, axis=0)
# return out
# def Xcorr_test(self, x, z, type):
# k = 2 if type == 'cls' else 4
# out = F.conv2d(x,
# paddle.reshape(z, [k * self.anchor_num, -1, 0, 0]),
# stride=1, padding=0)
# return out
# def forward(self, x, z):
# x_reg = self.conv_x_reg(x)
# z_reg = self.conv_z_reg(z)
# x_cls = self.conv_x_cls(x)
# z_cls = self.conv_z_cls(z)
# cls_out = self.Xcorr_train(x_cls, z_cls, 'cls')
# reg_out = self.Xcorr_train(x_reg, z_reg, 'reg')
# return cls_out, reg_out
# def track_init(self, z):
# z_reg = self.conv_z_reg(z)
# z_cls = self.conv_z_cls(z)
# return z_reg, z_cls
# def track_update(self, x, z_cls, z_reg):
# x_reg = self.conv_x_reg(x)
# x_cls = self.conv_x_cls(x)
# cls_out = self.Xcorr_test(x_cls, z_cls, 'cls')
# reg_out = self.Xcorr_test(x_reg, z_reg, 'reg')
# return cls_out, reg_out
#——————————————————————————————————————————————————————————
# Cross-correlation implemented with a Python loop over the batch; slow
def Xcorr(x, z):
out = []
    b = x.shape[0]  # correlate each sample independently, then concatenate
for i in range(b):
out.append(
nn.functional.conv2d(x[i, :, :, :].unsqueeze(0), z[i, :, :, :].unsqueeze(0), stride=1, padding=0))
out = paddle.concat(out, axis=0)
return out
# Faster cross-correlation using grouped convolution (following pysot)
def Xcorr_fast(x, z):
    b = z.shape[0]
    # in paddle.reshape, a 0 keeps the corresponding dimension unchanged
    z = paddle.reshape(z, [-1, x.shape[1], 0, 0])  # [b, c, hz, wz]: one kernel per sample
    x = paddle.reshape(x, [1, -1, 0, 0])           # fold the batch into the channel axis
    out = F.conv2d(x, z, groups=b)                 # each sample is correlated with its own template
    out = paddle.reshape(out, [b, -1, 0, 0])       # back to [b, 1, h, w]
    return out
# At inference the template z is fixed and its batch size is always 1
def Xcorr_test(x, z):
out = []
    b = x.shape[0]  # number of search crops; each is correlated with the single template
for i in range(b):
out.append(
nn.functional.conv2d(x[i, :, :, :].unsqueeze(0), z, stride=1, padding=0))
out = paddle.concat(out, axis=0)
return out
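# Sanity check (a sketch): the looped and grouped-convolution implementations
# should agree numerically; uncomment to verify.
# x_chk = paddle.randn([4, 128, 22, 22])
# z_chk = paddle.randn([4, 128, 6, 6])
# print(paddle.allclose(Xcorr(x_chk, z_chk), Xcorr_fast(x_chk, z_chk), atol=1e-4))  # Tensor(True)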
#————————————————————————————————————————————————————————
# The SiamFC model
class Siamfc(nn.Layer):
def __init__(self, out_scale, init=False):
super(Siamfc, self).__init__()
self.out_scale = out_scale
self.backbone = AlexNet(out_channels=128, init=init)
    def forward(self, x, z, mode='train'):  # x: search image, z: template
x = self.backbone(x)
z = self.backbone(z)
if mode == 'train':
# out = Xcorr(x, z) * self.out_scale
out = Xcorr_fast(x,z) * self.out_scale
return out
else:
out = Xcorr_test(x, z) * self.out_scale
return out
# class SiamRpn(nn.Layer):
# def __init__(self, anchor_num):
# super(SiamRpn, self).__init__()
# self.anchor_num = anchor_num
# self.backbone = AlexNet(out_channels=256)
# self.head = Rpn(self.anchor_num)
# def forward(self, x, z):
# x = self.backbone(x)
# z = self.backbone(z)
# cls_out, reg_out = self.head(x, z)
# return cls_out, reg_out
# def track_init(self, z):
# z = self.backbone(z)
# reg_z, cls_z = Rpn(self.anchor_num).track_init(z)
# return reg_z, cls_z
# def track_update(self, x, cls_z, reg_z):
# x = self.backbone(x)
# cls_out, reg_out = Rpn(self.anchor_num).track_update(x, cls_z, reg_z)
# return cls_out, reg_out
def main():
model_SiamFc = Siamfc(out_scale=0.001, init=True)
    z_train = paddle.randn([32, 3, 128, 128])  # template batch (127 and 128 both yield 6x6 features)
    x_train = paddle.randn([32, 3, 255, 255])  # search batch
pred_score = model_SiamFc(x_train, z_train, mode='train')
print(pred_score.shape)
if __name__ == '__main__':
main()
[32, 1, 17, 17]
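For the inference path (any mode other than 'train'), a single fixed template is correlated against a small batch of scaled search crops. A quick sketch of the expected shapes (not part of the recorded run above):

model_test = Siamfc(out_scale=0.001, init=True)
z = paddle.randn([1, 3, 127, 127])          # the fixed first-frame template
x = paddle.randn([3, 3, 255, 255])          # three scaled crops of the search region
print(model_test(x, z, mode='test').shape)  # [3, 1, 17, 17]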
3. Data Processing
The training data comes from GOT-10k. Each training sample is built by randomly picking two frames from a video sequence and processing them into a template and a search image for the network.
Taking one image from the dataset as an example, the processing steps are as follows.
After reading the original image and its bbox, a patch containing context around the target is cropped from the image, centred on the bbox centre. Its side length is computed as

$$p = \frac{w + h}{4}, \qquad s_z = \sqrt{(h + 2p)(w + 2p)}, \qquad s_x = s_z \cdot \frac{255}{127}$$

where $s_x$ and $s_z$ are the search and template patch sizes and $h$ and $w$ are the bbox height and width. When cutting the patch out of the original image, two cases arise: if the patch lies fully inside the image, it is cut directly; if it crosses the image border, the image is first padded with its mean pixel value and the patch is cut afterwards.
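As a concrete example (numbers made up for illustration), for a bbox of width w = 100 and height h = 50:

import numpy as np

w, h = 100, 50                                 # hypothetical bbox size
context = 0.5 * (w + h)                        # 2p = (w + h) / 2 = 75
z_sz = np.sqrt((h + context) * (w + context))  # template patch side, about 147.9
x_sz = z_sz * 255 / 127                        # search patch side, about 297.0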
The patch is then resized to 255x255, and augmentations such as random stretching and random cropping are applied, producing the 255x255 search image and the 127x127 template.
The whole pipeline is illustrated in the figure below.
After unpacking the dataset you also need a list.txt file, the index of the dataset, placed under work/data/train.
import numbers
from got10k.datasets import *
import paddle
import numpy as np
import cv2
# [l,t,w,h] -> [y,x,h,w]
def convert_coordinate(org_box: list):
box = np.array([
org_box[1] - 1 + (org_box[3] - 1) / 2,
org_box[0] - 1 + (org_box[2] - 1) / 2,
org_box[3], org_box[2]], dtype=np.float32)
return box
# Chains a sequence of transforms together
class Compose(object):
def __init__(self, transforms):
self.transforms = transforms
    def __call__(self, img):  # __call__ makes the instance callable: t = Compose([...]); t(img)
for t in self.transforms:
img = t(img)
return img
# Randomly stretches the image by a factor in [0.95, 1.05]; note that cv2.resize takes (width, height)
class RandomStretch(object):
def __init__(self, max_stretch=0.05):
self.max_stretch = max_stretch
def __call__(self, img):
        interp = np.random.choice([  # pick an interpolation mode at random
            cv2.INTER_LINEAR,        # bilinear (cv2 default)
            cv2.INTER_CUBIC,         # bicubic over a 4x4 neighbourhood
            cv2.INTER_AREA,          # pixel-area resampling
            cv2.INTER_NEAREST,       # nearest neighbour
            cv2.INTER_LANCZOS4])     # Lanczos over an 8x8 neighbourhood
scale = 1.0 + np.random.uniform(
-self.max_stretch, self.max_stretch)
out_size = (
            round(img.shape[1] * scale),  # width first
            round(img.shape[0] * scale))  # cv2.resize expects (width, height)
        return cv2.resize(img, out_size, interpolation=interp)  # resize img to out_size
# Crops a (size, size) patch from the image centre; pads with the image mean first if the image is too small
class CenterCrop(object):
def __init__(self, size):
        if isinstance(size, numbers.Number):  # size may be a single number or a (w, h) pair
self.size = (int(size), int(size))
else:
self.size = size
def __call__(self, img):
        h, w = img.shape[:2]  # img.shape is [height, width, channels]
tw, th = self.size
        i = round((h - th) / 2.)  # top-left corner of the centred crop
        j = round((w - tw) / 2.)
npad = max(0, -i, -j)
if npad > 0:
            avg_color = np.mean(img, axis=(0, 1))  # mean pixel value over the whole image
            img = cv2.copyMakeBorder(  # pad npad pixels on each side with the constant avg_color
img, npad, npad, npad, npad,
cv2.BORDER_CONSTANT, value=avg_color)
i += npad
j += npad
return img[i:i + th, j:j + tw]
# Like CenterCrop, but crops from a random position; no padding is needed here
class RandomCrop(object):
def __init__(self, size):
if isinstance(size, numbers.Number):
self.size = (int(size), int(size))
else:
self.size = size
def __call__(self, img):
h, w = img.shape[:2]
tw, th = self.size
i = np.random.randint(0, h - th + 1)
j = np.random.randint(0, w - tw + 1)
return img[i:i + th, j:j + tw]
# Resizes the image to a fixed square size, with a randomly chosen interpolation mode
class Resize(object):
    def __init__(self, size):
        self.size = size
    def __call__(self, img):
        interp = np.random.choice([  # same random interpolation choice as in RandomStretch
            cv2.INTER_LINEAR,
            cv2.INTER_CUBIC,
            cv2.INTER_AREA,
            cv2.INTER_NEAREST,
            cv2.INTER_LANCZOS4])
        out = cv2.resize(img, dsize=(self.size, self.size), interpolation=interp)
        return out
class ToTensor(object):
def __call__(self, img):
out = paddle.to_tensor(img,'float32')
return paddle.transpose(out,perm=[2,0,1])
class SiamFCTransforms(object):
def __init__(self, exemplar_sz=127, instance_sz=255, context=0.5):
self.exemplar_sz = exemplar_sz
self.instance_sz = instance_sz
self.context = context
self.transforms_z = Compose(
[RandomStretch(),
CenterCrop(instance_sz - 8),
RandomCrop(instance_sz - 16),
CenterCrop(exemplar_sz),
ToTensor()])
self.transforms_x = Compose([
RandomStretch(),
CenterCrop(instance_sz - 8),
RandomCrop(instance_sz - 16),
Resize(instance_sz),
ToTensor()])
    def __call__(self, z, x, box_z, box_x):
        # 1. _crop converts the box from (l,t,w,h) to (y,x,h,w) float32, giving center [y,x] and target_sz [h,w]
        # 2. patch side = sqrt((h+(h+w)/2) * (w+(h+w)/2)) * instance_sz / exemplar_sz
        # 3. crop_and_resize cuts a square of that side centred on the target,
        #    padding with the image mean at the borders, and resizes it to instance_sz (255)
        z = self._crop(z, box_z, self.instance_sz)
        x = self._crop(x, box_x, self.instance_sz)
        z = self.transforms_z(z)
        x = self.transforms_x(x)
return z, x
def _crop(self, img, box, out_size):
        box = convert_coordinate(box)  # convert [l,t,w,h] to [y,x,h,w]
center,target_sz = box[:2],box[2:]
context = self.context * np.sum(target_sz)
size = np.sqrt(np.prod(target_sz + context))
size *= out_size / self.exemplar_sz
avg_color = np.mean(img, axis=(0, 1), dtype=float)
interp = np.random.choice([
cv2.INTER_LINEAR,
cv2.INTER_CUBIC,
cv2.INTER_AREA,
cv2.INTER_NEAREST,
cv2.INTER_LANCZOS4])
patch = crop_and_resize(
img, center, size, out_size, border_value=avg_color, interp=interp)
return patch
def crop_and_resize(img, center, size, out_size,
border_type=cv2.BORDER_CONSTANT,
                    border_value=(0, 0, 0),  # in practice the image mean (avgR, avgG, avgB)
interp=cv2.INTER_LINEAR):
    size = round(size)  # integer patch side
    corners = np.concatenate((  # corners = [ymin, xmin, ymax, xmax]
        np.round(center - (size - 1) / 2),
        np.round(center - (size - 1) / 2) + size))
    corners = np.round(corners).astype(int)
    # note: the dataset pipeline has already converted images to RGB, so this call
    # swaps the channels back; at least it is applied consistently in training and inference
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# print(img.shape)
# cv2.imshow('original', img)
# cv2.imwrite('original.png', img)
    # pad if the patch crosses the image border
    pads = np.concatenate((
        -corners[:2], corners[2:] - img.shape[:2]))
    npad = max(0, int(max(pads)))  # largest overshoot across the four borders, or 0
if npad > 0:
img = cv2.copyMakeBorder(
img, npad, npad, npad, npad,
border_type, value=border_value)
    # crop image patch
    corners = (corners + npad).astype(int)  # padding shifts the coordinates
    # print(corners)
    patch = img[corners[0]:corners[2], corners[1]:corners[3]]  # slice out the square patch
# cv2.imshow('padding_img',img)
# cv2.imwrite('padding_img.png', img)
# cv2.imshow('contest_img',patch)
# cv2.imwrite('contest_img.png', patch)
# resize to out_size
patch = cv2.resize(patch, (out_size, out_size),
interpolation=interp)
# cv2.imshow('resize255_img',patch)
# cv2.imwrite('resize255_img.png', patch)
    # cv2.waitKey(0)  # leftover debug call; uncommented, it would block waiting for a key press
return patch
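# Usage sketch for crop_and_resize (made-up numbers): crop a 200-px context
# patch around an image-centre target and resize it to 255x255. Uncomment to try.
# dummy = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
# patch = crop_and_resize(dummy, center=np.array([240., 320.]), size=200,
#                         out_size=255, border_value=(117, 117, 117))
# print(patch.shape)  # (255, 255, 3)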
class GOT10kDataset(paddle.io.Dataset):
def __init__(self, seqs, transforms=None, pairs_per_seq=1):
super(GOT10kDataset, self).__init__()
self.seqs = seqs
self.transforms = transforms
self.pairs_per_seq = pairs_per_seq
self.indices = np.random.permutation(len(seqs))
        self.return_meta = getattr(seqs, 'return_meta', False)  # whether the dataset also yields meta info
    # __getitem__ returns item = (z, x, box_z, box_x); after the transforms it becomes a training pair
def __getitem__(self, index):
# print(self.indices)
index = self.indices[index % len(self.indices)]
# print(index)
        # index = self.indices[index]  # equivalent to the line above
# get filename lists and annotations
        if self.return_meta:
            img_files, anno, meta = self.seqs[index]
            vis_ratios = meta.get('cover', None)  # visible ratio of the target per frame
        else:
            img_files, anno = self.seqs[index][:2]
            vis_ratios = None
        val_indices = self._filter(
            cv2.imread(img_files[0], cv2.IMREAD_COLOR),
            anno, vis_ratios)
if len(val_indices) < 2:
index = np.random.choice(len(self))
return self.__getitem__(index)
rand_z, rand_x = self._sample_pair_(val_indices)
z = cv2.imread(img_files[rand_z], cv2.IMREAD_COLOR)
x = cv2.imread(img_files[rand_x], cv2.IMREAD_COLOR)
z = cv2.cvtColor(z, cv2.COLOR_BGR2RGB)
x = cv2.cvtColor(x, cv2.COLOR_BGR2RGB)
box_z = anno[rand_z]
box_x = anno[rand_x]
        item = (z, x, box_z, box_x)  # the boxes are the ground truth
if self.transforms is not None:
item = self.transforms(*item)
return item
    # dataset length = number of indexed sequences x pairs per sequence (1 here)
    def __len__(self):
        return len(self.indices) * self.pairs_per_seq  # 9335 for the full GOT-10k train set; fewer when only splits 01-03 are used
    # Randomly picks two frame indices, at most T=100 frames apart
def _sample_pair_(self, indices):
n = len(indices)
assert n > 0
if n == 1:
return indices[0], indices[0]
elif n == 2:
return indices[0], indices[1]
else:
for i in range(100):
rand_z, rand_x = np.sort(
np.random.choice(indices, 2, replace=False))
if rand_x - rand_z < 100:
break
else:
rand_z = np.random.choice(indices)
rand_x = rand_z
return rand_z, rand_x
    # Filters the annotations and returns the indices of frames that pass all acceptance conditions
def _filter(self, img0, anno, vis_ratios=None):
size = np.array(img0.shape[1::-1])[np.newaxis, :]
areas = anno[:, 2] * anno[:, 3]
# acceptance conditions
c1 = areas >= 20
c2 = np.all(anno[:, 2:] >= 20, axis=1)
c3 = np.all(anno[:, 2:] <= 500, axis=1)
c4 = np.all((anno[:, 2:] / size) >= 0.01, axis=1)
c5 = np.all((anno[:, 2:] / size) <= 0.5, axis=1)
c6 = (anno[:, 2] / np.maximum(1, anno[:, 3])) >= 0.25
c7 = (anno[:, 2] / np.maximum(1, anno[:, 3])) <= 4
if vis_ratios is not None:
c8 = (vis_ratios > max(1, vis_ratios.max() * 0.3))
else:
c8 = np.ones_like(c1)
mask = np.logical_and.reduce(
(c1, c2, c3, c4, c5, c6, c7, c8))
val_indices = np.where(mask)[0]
return val_indices
if __name__ == "__main__":
root_dir = 'work/data/'
seq_dataset = GOT10k(root_dir, subset='train')
transforms = SiamFCTransforms(
exemplar_sz=cfg.exemplar_sz, # 127
instance_sz=cfg.instance_sz, # 255
context=cfg.context) # 0.5
train_dataset = GOT10kDataset(seq_dataset, transforms)
    item = train_dataset.__getitem__(1)  # a processed pair of frames from a random video sequence
print(item[0].shape)
print(train_dataset.__len__())
[3, 127, 127]
1500
4. Loss Function
The paper uses the logistic loss at each location of the response map:

$$\ell(y, v) = \log\left(1 + e^{-yv}\right)$$

where $y \in \{+1, -1\}$ is the ground-truth label and $v$ is the predicted template-search similarity score at that location of the 17x17 output. The loss over the whole output is

$$L(y, v) = \frac{1}{|\mathcal{D}|} \sum_{u \in \mathcal{D}} \ell(y[u], v[u])$$

where $\mathcal{D}$ is the set of positions in the response map, i.e. the per-location losses are summed and divided by the size of the map.
In the code, $\ell$ is implemented with binary_cross_entropy_with_logits, and the labels are built as $\{0, 1\}$ rather than the paper's $\{-1, +1\}$. The two are equivalent: let $t = (y+1)/2 \in \{0, 1\}$ and $\sigma(v) = 1/(1+e^{-v})$; then

$$\mathrm{BCE}(v, t) = -t \log \sigma(v) - (1-t)\log(1 - \sigma(v)) = \begin{cases} \log(1 + e^{-v}) & t = 1\ (y = +1) \\ \log(1 + e^{v}) & t = 0\ (y = -1) \end{cases}$$

which is exactly $\log(1 + e^{-yv})$ in both cases.
# loss and labels
import numpy as np
import paddle
import paddle.nn as nn
def logistic_labels(x, y, r_pos):
    # positions within distance r_pos of the centre are set to 1, the rest to 0
dist = np.sqrt(x ** 2 + y ** 2)
labels = np.where(dist <= r_pos, # r_pos=2
np.ones_like(x),
np.zeros_like(x))
return labels
def get_label(size):
n, c, h, w = size # [8,1,17,17]
x = np.arange(w) - (w - 1) / 2
y = np.arange(h) - (h - 1) / 2
x, y = np.meshgrid(x, y)
r_pos = cfg.r_pos / cfg.total_stride
labels = logistic_labels(x, y, r_pos)
labels = labels.reshape((1, 1, h, w))
labels = np.tile(labels, (n, c, 1, 1))
return labels
class GetLoss(nn.Layer):
def __init__(self, neg_weight=1.0):
super(GetLoss, self).__init__()
self.neg_weight = neg_weight
    def forward(self, input, target):
        # weight positives and negatives so that each class contributes equally to the loss
        pos_mask = (target == 1)
neg_mask = (target == 0)
pos_num = float(pos_mask.sum())
neg_num = float(neg_mask.sum())
weight = paddle.to_tensor(np.zeros(target.shape),'float32')
weight[pos_mask] = 1 / pos_num
weight[neg_mask] = 1 / neg_num * self.neg_weight
weight /= weight.sum()
return paddle.nn.functional.binary_cross_entropy_with_logits(
input, target, weight, reduction='sum')
if __name__ == '__main__':
labels = get_label([8, 1, 17, 17])
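    # Quick check (a sketch): with r_pos = 16/8 = 2 response cells, the positive
    # region is a small disc at the centre of each 17x17 map.
    print(labels.shape)             # (8, 1, 17, 17)
    print(int(labels[0, 0].sum()))  # 13 positive locations per map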
5. Training
With the loss function and labels in place, training can start. The parameters are optimized with stochastic gradient descent, and the learning rate is decayed exponentially:

$$lr_{epoch} = lr_0 \cdot \gamma^{\,epoch}, \qquad \gamma = \left(\frac{lr_E}{lr_0}\right)^{1/E}$$

where $lr_0$ is the initial learning rate, $lr_E$ the final learning rate, $epoch$ the current epoch, and $E$ the total number of epochs.
Settings: the learning rate decays from 1e-2 to 1e-5, 50 epochs in total, batch size 64.
A scaling factor is applied to the network output because the raw correlation values are large; computing the loss on them directly leads to exploding gradients during the parameter updates.
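As a quick check of the schedule (a sketch, not part of the original run), these values reproduce the lr column printed in the training log below:

import numpy as np

initial_lr, ultimate_lr, epochs = 1e-2, 1e-5, 50
gamma = (ultimate_lr / initial_lr) ** (1.0 / epochs)  # about 0.8710
for epoch in range(3):
    print('epoch %d: lr = %.2e' % (epoch + 1, initial_lr * gamma ** epoch))
# epoch 1: lr = 1.00e-02
# epoch 2: lr = 8.71e-03
# epoch 3: lr = 7.59e-03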
# training
import paddle
from got10k.datasets import *
from paddle.io import DataLoader
from paddle import optimizer
from tqdm import tqdm
from paddle.optimizer.lr import ExponentialDecay
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
def train(data_dir, net_path=None, save_dir='pre_trained'):
    # load the dataset
    seq_dataset = GOT10k(data_dir, subset='train', return_meta=False)
    # data augmentation
    transforms = SiamFCTransforms(exemplar_sz=cfg.exemplar_sz,
                                  instance_sz=cfg.instance_sz,
                                  context=cfg.context)
    train_dataset = GOT10kDataset(seq_dataset, transforms)
    loader_dataset = DataLoader(dataset=train_dataset,
                                batch_size=cfg.batch_size,
                                shuffle=True,
                                num_workers=cfg.num_workers,
                                drop_last=True)
    # build the model
    paddle.device.set_device('gpu')
    model = Siamfc(out_scale=cfg.out_scale, init=True)
    # loss and labels
    getloss = GetLoss()
    labels = get_label(size=[cfg.batch_size, 1, cfg.response_sz, cfg.response_sz])
    labels = paddle.to_tensor(labels, 'float32')
    # optimizer: SGD with an exponentially decaying learning rate
    gamma = np.power(cfg.ultimate_lr / cfg.initial_lr, 1.0 / cfg.epoch_num)
    lr_scheduler = ExponentialDecay(cfg.initial_lr, gamma)  # lr = initial_lr * gamma^epoch
    opt = optimizer.SGD(learning_rate=lr_scheduler,
                        parameters=model.parameters(),
                        weight_decay=cfg.weight_decay)
    # training loop
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    start_epoch = 1
    for epoch in range(start_epoch, cfg.epoch_num + 1):
        model.train()
        for it, batch in enumerate(tqdm(loader_dataset)):
            # fetch inputs
            z = batch[0]  # z.shape = [batch_size, 3, 127, 127]
            x = batch[1]  # x.shape = [batch_size, 3, 255, 255]
            # forward pass
            output = model(x, z)
            loss = getloss(output, labels)
            # backward pass
            opt.clear_grad()
            loss.backward()
            opt.step()
            print('Epoch: {}[{}/{}] Loss: {:.5f} lr: {:.2e}'.format(
                epoch, it + 1, len(loader_dataset), loss.item(), opt.get_lr()))
        # update the learning rate once per epoch
        lr_scheduler.step()
        # save a checkpoint after each epoch
        save_path = os.path.join(
            save_dir, 'siamfc_alexnet_e%d.pdparams' % (epoch))
        paddle.save({'epoch': epoch,
                     'model': model.state_dict(),
                     'optimizer': opt.state_dict()}, save_path)
if __name__ == '__main__':
train('work/data/')
Epoch: 1[1/23] Loss: 0.93592 lr: 1.00e-02
Epoch: 1[2/23] Loss: 0.80803 lr: 1.00e-02
Epoch: 1[3/23] Loss: 0.74549 lr: 1.00e-02
...
Epoch: 1[23/23] Loss: 0.59723 lr: 1.00e-02
Epoch: 2[23/23] Loss: 0.58262 lr: 8.71e-03
Epoch: 3[23/23] Loss: 0.55234 lr: 7.59e-03
Epoch: 4[23/23] Loss: 0.56480 lr: 6.61e-03
Epoch: 5[23/23] Loss: 0.53630 lr: 5.75e-03
Epoch: 6[23/23] Loss: 0.53160 lr: 5.01e-03
Epoch: 7[23/23] Loss: 0.55369 lr: 4.37e-03
Epoch: 8[23/23] Loss: 0.53947 lr: 3.80e-03
Epoch: 9[23/23] Loss: 0.53795 lr: 3.31e-03
Epoch: 10[14/23] Loss: 0.53252 lr: 2.88e-03
(tqdm progress bars and intermediate steps truncated; the run was interrupted during epoch 10, which surfaced as "RuntimeError: DataLoader 2 workers exit unexpectedly" followed by a KeyboardInterrupt)
6. Inference
Inference has three main steps: process the initial frame, propagate each new frame through the network to obtain a response map, and map the peak of the response back to the target position.
Processing the initial frame produces the template used for the rest of the sequence: given the first frame and its ground-truth box, the output is a 6x6x128 feature map that acts as the correlation kernel. This kernel is kept fixed and never updated online, which keeps the tracker fast, but makes it easy to lose the target under fast motion or deformation.
With the template in hand, tracking proceeds frame by frame: each subsequent frame is fed in as the search image x together with the first-frame template z, yielding a 17x17 response map whose maximum must be located.
Because the 17x17 map is coarse, it is upsampled to 272x272 with bicubic interpolation. Since the target size may change, several search scales are evaluated (with a penalty on scale changes), and a Hanning-window penalty emphasizes locations near the previous centre. Repeatedly updating the target centre and the box size completes the tracking loop.
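To make the back-projection step concrete, here is a toy walk-through with made-up numbers (a sketch mirroring the arithmetic in update() below), using response_up = 16, total_stride = 8, instance_sz = 255:

import numpy as np

upsz = 16 * 17                                 # the upsampled response is 272x272
loc = np.array([150.0, 140.0])                 # hypothetical peak position (row, col)
disp_in_response = loc - (upsz - 1) / 2        # offset from the response centre: [14.5, 4.5]
disp_in_instance = disp_in_response * 8 / 16   # undo the x16 upsampling, apply the x8 backbone stride
x_sz = 290.0                                   # hypothetical search-patch side in the frame
disp_in_image = disp_in_instance * x_sz / 255  # displacement of the target centre, in frame pixels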
# inference
import time
import paddle
import numpy as np
import cv2
from got10k.trackers import Tracker
def read_image(img_file, cvt_code=cv2.COLOR_BGR2RGB):
    # cv2.imread loads images as BGR with shape [height, width, channels]; convert to RGB
    img = cv2.imread(img_file, cv2.IMREAD_COLOR)
    if cvt_code is not None:
img = cv2.cvtColor(img, cvt_code)
return img
def show_image(img, boxes=None, box_fmt='ltwh', colors=None,
thickness=3, fig_n=1, delay=1, visualize=True,
cvt_code=cv2.COLOR_RGB2BGR):
    if cvt_code is not None:
        img = cv2.cvtColor(img, cvt_code)  # cv2 displays BGR, so convert RGB back before showing
    # downscale the image if necessary
    max_size = 960  # longest side at most 960 px
    if max(img.shape[:2]) > max_size:
        scale = max_size / max(img.shape[:2])
        out_size = (
            int(img.shape[1] * scale),  # width
            int(img.shape[0] * scale))  # height (cv2.resize expects (width, height))
img = cv2.resize(img, out_size)
if boxes is not None:
boxes = np.array(boxes, dtype=np.float32) * scale
if boxes is not None:
assert box_fmt in ['ltwh', 'ltrb']
        boxes = np.array(boxes, dtype=np.int32)  # a single box has shape (4,)
        if boxes.ndim == 1:
            boxes = np.expand_dims(boxes, axis=0)  # promote a single box to shape (1, 4)
if box_fmt == 'ltrb':
boxes[:, 2:] -= boxes[:, :2]
# clip bounding boxes
        bound = np.array(img.shape[1::-1])[None, :]  # (w, h) limits, shape (1, 2)
        boxes[:, :2] = np.clip(boxes[:, :2], 0, bound)  # clip the top-left corners
        boxes[:, 2:] = np.clip(boxes[:, 2:], 0, bound - boxes[:, :2])  # clip the extents
if colors is None:
colors = [
(0, 0, 255),
(0, 255, 0),
(255, 0, 0),
(0, 255, 255),
(255, 0, 255),
(255, 255, 0),
(0, 0, 128),
(0, 128, 0),
(128, 0, 0),
(0, 128, 128),
(128, 0, 128),
            (128, 128, 0)]  # 12 preset colours
        colors = np.array(colors, dtype=np.int32)  # colors.shape = (12, 3)
if colors.ndim == 1:
colors = np.expand_dims(colors, axis=0)
for i, box in enumerate(boxes):
            color = colors[i % len(colors)]  # cycle through the colour list
pt1 = (box[0], box[1])
pt2 = (box[0] + box[2], box[1] + box[3])
img = cv2.rectangle(img, pt1, pt2, color.tolist(), thickness)
if visualize:
            winname = 'window_{}'.format(fig_n)
            cv2.imshow(winname, img)
            cv2.waitKey(delay)  # wait `delay` ms between frames
return img
def ltwh_to_yxhw(ltwh):
yxhw = np.array([
ltwh[1] - 1 + (ltwh[3] - 1) / 2,
ltwh[0] - 1 + (ltwh[2] - 1) / 2,
ltwh[3], ltwh[2]], dtype=np.float32)
return yxhw
def yxhw_to_ltwh(center, target_sz):
    # inverse of ltwh_to_yxhw: center = [y, x], target_sz = [h, w] -> [l, t, w, h]
    ltwh = np.array([
        center[1] + 1 - (target_sz[1] - 1) / 2,
        center[0] + 1 - (target_sz[0] - 1) / 2,
        target_sz[1], target_sz[0]])
    return ltwh
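# Round-trip check (a sketch): ltwh -> (center, size) -> ltwh recovers the box; uncomment to verify.
# box = np.array([50., 80., 120., 60.])    # [l, t, w, h]
# yxhw = ltwh_to_yxhw(box)
# print(yxhw_to_ltwh(yxhw[:2], yxhw[2:]))  # [ 50.  80. 120.  60.]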
def map_process(response, hanning_window):
    # normalize the response and blend in the Hanning-window penalty
response -= response.min()
response /= response.sum() + 1e-16
response = (1 - cfg.window_influence) * response + \
cfg.window_influence * hanning_window # window_influence=0.176
return response
def map_to272(responses, out_size):
responses = np.stack([cv2.resize(
u, (out_size, out_size),
interpolation=cv2.INTER_CUBIC)
for u in responses])
return responses
def x_to3s255(img, center, patch_size, three_scales, out_size, border_value):
x = [crop_and_resize(
img, center, patch_size * scale,
out_size=out_size,
border_value=border_value) for scale in three_scales]
    x = np.stack(x, axis=0)  # [3, 255, 255, 3]: the leading 3 is the number of scales
return x
def create_hanning_window(size):
hann_window = np.outer(
np.hanning(size),
        np.hanning(size))  # outer product of two 1-D Hann windows gives a 2-D window
hann_window /= hann_window.sum()
return hann_window
def scales():
scale_factors = cfg.scale_step ** np.linspace( # 1.0375^(-1,0,1)
-(cfg.scale_num // 2),
cfg.scale_num // 2, cfg.scale_num)
return scale_factors
def z_to127(img, center, patch_size, out_size, border_value):
z = crop_and_resize(
img, center, patch_size,
out_size=out_size,
border_value=border_value)
return z
class TrackerSiamFC(Tracker):
def __init__(self, net_path=None):
super(TrackerSiamFC, self).__init__('SiamFC', True)
self.model = Siamfc(out_scale=cfg.out_scale)
if net_path is not None:
checkpoint = paddle.load(net_path)
self.model.load_dict(checkpoint['model'])
    # initialize the tracker with the first frame and its ground-truth box
    def init(self, img, box):
        # switch to evaluation mode
        self.model.eval()
        # convert the box from [l,t,w,h] to [center_y, center_x, h, w]
        yxhw = ltwh_to_yxhw(box)
        self.center, self.target_sz = yxhw[:2], yxhw[2:]
        # Hanning window over the upsampled response map
        self.response_upsz = cfg.response_up * cfg.response_sz
        self.hanning_window = create_hanning_window(size=self.response_upsz)
        # three scales: 1.0375 ** (-1, 0, 1)
        self.scale_factors = scales()
        # patch side lengths; the context margin is (h + w) / 2
        context = cfg.context * np.sum(self.target_sz)
        self.z_sz = np.sqrt(np.prod(self.target_sz + context))     # sqrt((h+(h+w)/2) * (w+(h+w)/2))
        self.x_sz = self.z_sz * cfg.instance_sz / cfg.exemplar_sz  # z_sz * 255 / 127
        # mean colour of the frame, used for border padding
        self.avg_color = np.mean(img, axis=(0, 1))
        # crop a z_sz patch centred on the target and resize it to exemplar_sz
        z = z_to127(img, self.center, self.z_sz, cfg.exemplar_sz, self.avg_color)
        z = paddle.transpose(paddle.to_tensor(z, 'float32'), perm=[2, 0, 1]).unsqueeze(0)
        self.sample = z
    def update(self, img):
        self.model.eval()
        # crop the search region at three scales and stack them into a batch
        x = x_to3s255(img, self.center, self.x_sz, self.scale_factors, cfg.instance_sz, self.avg_color)
        x = paddle.to_tensor(x, 'float32')
        x = paddle.transpose(x, perm=[0, 3, 1, 2])  # x: [3, 3, 255, 255], one crop per scale
        responses = self.model(x, self.sample)
        responses = responses.squeeze(1).cpu().numpy()
        # upsample the responses: 17x17 -> 272x272
        responses = map_to272(responses, out_size=self.response_upsz)
        # penalize scale changes (the middle, unchanged scale keeps weight 1)
        responses[:cfg.scale_num // 2] *= cfg.scale_penalty
        responses[cfg.scale_num // 2 + 1:] *= cfg.scale_penalty
        # pick the scale with the strongest response
        scale_id = np.argmax(np.amax(responses, axis=(1, 2)))
        response = responses[scale_id]  # [272, 272]
        # normalize and apply the Hanning-window penalty
        response = map_process(response, self.hanning_window)
        loc = np.unravel_index(response.argmax(), response.shape)  # (row, col) of the peak
        # map the peak back to a displacement in the original frame
        disp_in_response = np.array(loc) - (self.response_upsz - 1) / 2
        disp_in_instance = disp_in_response * cfg.total_stride / cfg.response_up
        disp_in_image = disp_in_instance * self.x_sz * self.scale_factors[scale_id] / cfg.instance_sz
        self.center += disp_in_image
        # damped update of the target size and both patch sizes
        scale = (1 - cfg.scale_lr) * 1.0 + cfg.scale_lr * self.scale_factors[scale_id]
        self.target_sz *= scale
        self.z_sz *= scale
        self.x_sz *= scale
        # [y,x,h,w] -> [l,t,w,h]
        box = yxhw_to_ltwh(self.center, self.target_sz)
        return box
    def track(self, img_files, box, visualize=False):
        frame_num = len(img_files)
        boxes = np.zeros((frame_num, 4))
        boxes[0] = box
        times = np.zeros(frame_num)
        for f, img_file in enumerate(img_files):
            img = read_image(img_file)
            begin = time.time()
            if f == 0:
                self.init(img, box)
            else:
                boxes[f, :] = self.update(img)
            times[f] = time.time() - begin
            if visualize:
                show_image(img, boxes[f, :])
        return boxes, times
7. Testing
Since the notebook environment does not support cv2.imshow, you need to run the tracker on your own machine to see it in action; quantitative results on the benchmark are still to be added.
The next step is to reproduce SiamRPN and SiamRPN++; still learning.
# testing
import os
import paddle
import glob
import numpy as np
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
if __name__ == '__main__':
seq_dir = os.path.expanduser('work/Crossing/')
img_files = sorted(glob.glob(seq_dir + 'img/*.jpg'))
# print(img_files[0])
    anno = np.loadtxt(seq_dir + 'groundtruth_rect.txt', delimiter='\t')  # load the ground-truth boxes
net_path = 'pre_trained/siamfc_alexnet_e49.pdparams'
tracker = TrackerSiamFC(net_path=net_path)
tracker.track(img_files, anno[0], visualize=False)