基于PaddlePaddle的PredNet模型
PredNet是一个利用神经生物学中的预测性编码(Predictive Coding)原理所构建的视频预测模型,预测过程自顶向下,感知过程自底向上。
基于PaddlePaddle实现的PredNet模型
1. 简介
PredNet是一个利用神经生物学中的预测性编码(Predictive Coding)原理所构建的视频预测模型,预测过程自顶向下,感知过程自底向上,对于其本身原理更感兴趣的同学可以参考Wikipedia。PreNet为该原理在视频数据上的一个实现,结构如下所示。网络为层级结构,每一层都会通过R来计算A_hat以预测输入A,两者的差异E再作为输入传递给下一层。其中R为每一层的表征,由一个LSTM来完成对时序信息的整合。特别的是R的输入除了由当前层的误差E外,还包含高层的表征,从而使得顶层的语义信息可以向底层传递。
以下给出两张官方参数的预测结果,第一行为真实图像,第二行为预测图像,第一个时间步网络处于初始化状态,没有外部输入,故输出为0。
参考repo: prednet
2. 数据集
本项目使用KITTI数据集进行训练。KITTI是一个自动驾驶领域常用的数据集,其由配置了传感器的汽车于德国卡尔斯鲁厄采集,数据中包含了双目彩色图像、深度图像、雷达点云等传感数据,并提供目标检测、实例分割等常见CV任务的标签。
按照论文中的设置,将RGB图像降采样至128x160。由于原始KITTI数据过大(~165G),作者在DropBox上提供了处理过的版本。又由于处理过后的数据为.hkl
格式,只能在python2中使用hickle 2.1.0进行加载,所以我将数据集转换为hdf5
格式,上传至AI Studio。
数据分为train,val,test三个数据集,分别包含41396,154,832张图像。每个文件中包含两个变量,images
为所有图像帧,sources
是每帧图像
的来源,用于判断帧之间的连续性。本项目已挂载数据集,位于/home/aistudio/data/data119650
。
3. 复现精度
采用与原文代码一致的训练参数:batch size 4,epoch 150,samples per epoch 500,优化器Adam,初始学习率1e-3,75个epoch后减小为1e-4。保存验证集loss最低的模型为最优模型。
数据集 | 复现精度要求 | 原始代码库精度 | 本项目精度 |
---|---|---|---|
KITTI | 0.007000 | 0.006995 | 0.006900 |
4. 环境依赖
-
硬件:
- x86 cpu
- NVIDIA GPU
-
框架:
- PaddlePaddle==2.2.0
-
其他依赖项:
- matplotlib
- h5py
- tqdm
5. 快速开始
Step 1 准备环境
%cd ~/work/prednet-paddle
!pip install -r requirements.txt
Step 2 使用官方预训练模型,验证模型(可选)
!python kitti_evaluate.py --weight_file model_data_keras2/tensorflow_weights/prednet_kitti_weights.hdf5 --data_dir ~/data/data119650/
!cat kitti_results/prediction_scores.txt
Step 3 训练模型
# 注:使用V100训练需2h左右
!python kitti_train.py --data_dir ~/data/data119650/
Step 4 验证模型
### 验证
!python kitti_evaluate.py --data_dir ~/data/data119650/
!cat kitti_results/prediction_scores.txt
### 可视化结果
import matplotlib.pyplot as plt
plot_dir = 'kitti_results/prediction_plots/'
plot_names = os.listdir(plot_dir)
img = plt.imread(os.path.join(plot_dir, plot_names[0])) # 可以修改index查看不同结果
plt.imshow(img)
mshow(img)
plt.show()
6. 代码结构
├── LICENSE
├── README.md
├── data.py # 数据集定义
├── kitti_data # 数据文件夹
│ ├── test.h5 # 测试数据
│ ├── train.h5 # 训练数据
│ └── val.h5 # 验证数据
├── kitti_evaluate.py # 评估脚本
├── kitti_results
│ ├── prediction_plots/ # 测试集预测可视化
│ └── prediction_scores.txt # 测试集指标
├── kitti_settings.py # 路径定义
├── kitti_train.py # 训练脚本
├── prednet.py # 网络定义
├── requirements.txt # 依赖包
└── utils.py # 功能函数
7. 复现心得
PredNet作为2016年的模型,与现在的模型相比更加纯粹一些,没有各种花式trick,复现起来相对容易。
原始代码是使用keras实现的,所以在复现时有几个小细节需要注意:
- 默认初始化方式:keras的默认初始化方式为
XavierUniform
,而paddle的默认初始化方式为XavierNormal
,本项目通过传入weight_attr
进行修改。 - HardSigmoid表达式:keras为
clip(x/5+0.5, 0, 1)
,而paddle为clip(x/6+0.5, 0, 1)
,本项目通过重写HardSigmoid进行对齐。
可能受限于当时的计算资源,作者并没有对模型进行深度的调参,在实验的过程中发现,调整学习率衰减率至0.5,训练400个epoch,test MSE可以达到0.006546,并且仍在减少,说明模型的潜力还没有被完全发掘。感兴趣的同学可以继续调参~
请点击此处查看本环境基本用法.
Please click here for more detailed instructions.
更多推荐
所有评论(0)