基于PP-ShiTu的珍稀动物识别
利用PP-ShiTu套件二次开发珍稀动物识别系统
基于PP-ShiTu的珍稀动物识别
1、背景介绍
近来,短视频平台上出现了贩卖或者杀害国家保护动物的视频,为了能够利用AI自动检测短视频中是否包含国家保护动物,本项目利用飞桨开源的PP-ShiTu套件开发了珍稀动物识别系统,本项目参考项目基于PP-ShiTu的商品识别系统。
本项目提供一个demo,其演示效果如下:
2、环境准备
# 克隆 PaddleClas
# 此代码只需要执行一次,本项目已完成克隆
# # github仓库
# !git clone https://github.com/PaddlePaddle/PaddleClas.git
# gitee仓库(推荐)
!git clone https://gitee.com/paddlepaddle/PaddleClas.git
Cloning into 'PaddleClas'...
remote: Enumerating objects: 33953, done.[K
remote: Counting objects: 100% (13836/13836), done.[K
remote: Compressing objects: 100% (4190/4190), done.[K
remote: Total 33953 (delta 10382), reused 12788 (delta 9521), pack-reused 20117[K
Receiving objects: 100% (33953/33953), 189.60 MiB | 23.37 MiB/s, done.
Resolving deltas: 100% (24167/24167), done.
Checking connectivity... done.
!pip install -r /home/aistudio/PaddleClas/requirements.txt
3、数据集准备及预处理
(1)数据集介绍
数据集名称:一级保护动物图像分类数据集汇总
数据集链接:https://aistudio.baidu.com/aistudio/datasetdetail/165243
数据集介绍:
数据集中包含八千余张各类国家一级保护动物的图像数据,而对于标签,不仅提供了每张图像的所属类别,另外还根据百度百科对所有类别进行归类,将一级保护动物分为九个等级。所有标签数据用txt存储,满足PP-ShiTu要求的标准格式,另外,也方便转换成其他格式。
参考链接:国家一级保护动物_百度百科
# 解压数据集
!unzip -oq /home/aistudio/data/data165243/animal_archive.zip -d /home/aistudio/PaddleClas/dataset
(2)数据增强
已选增强策略:
- Random Flip:以一定的概率翻转图片,这对商品比较重要,因为商品摆放的角度可能各异,甚至倒放,有利于增强模型对商品本身的学习能力。
- Random Crop:随机裁剪
备选增强策略:
- Cutout数据增强策略,在随机位置Crop正方形Patch。
- AutoAugmentation策略,使用了针对ImageNet搜索得到的策略。
- Random Erasing策略,随机擦除原图中的一个矩形区域,将区域内部像素值替换为随机值。
备选fine-tuning策略:
- FixRes的linear probing策略
# 数据预处理
import os
os.chdir('/home/aistudio/PaddleClas/dataset')
data_dir = 'animal_archive/dataset/train'
pathlist = os.listdir(data_dir)
ls = [int(k) for k in pathlist]
print(max(ls), min(ls), len(ls))
# 图像的可视化
import matplotlib.pyplot as plt
%matplotlib inline
from PIL import Image
import numpy as np
img_dir = 'animal_archive/dataset/train/17'
img_name = 'img_train_607.jpg'
img_path = os.path.join(img_dir, img_name)
img = Image.open(img_path)
img = np.array(img)
print(img.shape)
plt.figure(figsize=(8,8))
plt.imshow(img)
67 1 67
(1326, 2000, 3)
<matplotlib.image.AxesImage at 0x7f3cd850ae50>
4、 模型选择
PP-ShiTu是一个实用的轻量级通用图像识别系统,主要由主体检测、特征学习和向量检索三个模块组成。该系统从骨干网络选择和调整、损失函数的选择、数据增强、学习率变换策略、正则化参数选择、预训练模型使用以及模型裁剪量化8个方面,采用多种策略,对各个模块的模型进行优化,最终得到在CPU上仅0.2s即可完成10w+库的图像识别的系统。
(1)主体检测
主体检测技术是目前应用非常广泛的一种检测技术,它指的是检测出图片中一个或者多个主体的坐标位置,然后将图像中的对应区域裁剪下来,进行识别,从而完成整个识别过程。主体检测是识别任务的前序步骤,可以有效提升识别精度。
考虑到商品识别实际应用场景中,需要快速准确地获得识别结果,故本项目选取适用于 CPU 或者移动端场景的轻量级主体检测模型PicoDet作为本项目主体检测部分的模型。此模型融合了ATSS、Generalized Focal Loss、余弦学习率策略、Cycle-EMA、轻量级检测 head等一系列优化算法,基于COCO train2017数据集进行大规模预训练,最终inference模型大小(MB)仅30.1MB,mAP可达40.1%,在cpu下单张图片预测耗时仅29.8ms,完美符合本项目实际落地需求,故在本项目中不对主体检测部分做适应性训练。
(2)特征提取
特征提取是图像识别中的关键一环,它的作用是将输入的图片转化为固定维度的特征向量,用于后续的向量检索。好的特征需要具备相似度保持性,即在特征空间中,相似度高的图片对其特征相似度要比较高(距离比较近),相似度低的图片对,其特征相似度要比较小(距离比较远)。Deep Metric Learning用以研究如何通过深度学习的方法获得具有强表征能力的特征。
考虑到本项目的真实落地的场景中,推理速度及预测准确率是考量模型好坏的重要指标,所以本项目采用 CPU 级轻量化骨干网络 PP_LCNet_x2_5 作为骨干网络, Neck 部分选用 Linear Layer, Head 部分选用 ArcMargin,Loss 部分选用 CELoss,并结合度量学习arcmargin算法,对高相似物体的区分效果远超单一模型。在 Intel 至强 6148 处理器,PP-LCNet 的单张图像 5.39ms 的预测速度下,在 ImageNet 上 Top1 识别准确率可以达到 80.82%,准确率超越大模型 ResNet50 的模型效果,而预测速度却可以达到后者的 3 倍!PP-ShiTu 充分挖掘该网络的潜力,学习一个具有超强泛化能力的特征提取模型,同一模型可在多个数据集上同时达到较高精度。
(3)向量检索
PP-ShiTu 的第三个模块是向量检索。当获得了图像特征后,我们通过计算向量距离来获得两张图像的相似度,进一步通过向量检索获取最终识别结果。这种方式最大的优点是,当增加新的品类时,不需要重新训练提取特征模型,仅需要更新检索库即可识别新的目标。为了更好地兼容(Linux, Windows, MacOS)多平台,在图像识别系统中,本项目使用 Faiss 。在此过程中,本项目选取 HNSW32 为检索算法,使得检索精度、检索速度能够取得较好的平衡,更为贴切本项目实际应用场景的使用需求。
5、模型训练(分类模型)
-
需要修改配置文件:./ppcls/configs/GeneralRecognition/GeneralRecognition_PPLCNet_x2_5.yaml, 主要是将数据集的路径正确填写,其他配置按需修改
-
一般各种设置对于不同数据集都会有不同程度的影响,该配置文件中的各种设置仅供参考
# 分类模型训练(使用PPLCNet)
%cd /home/aistudio/PaddleClas
!python tools/train.py \
-c ./ppcls/configs/GeneralRecognition/GeneralRecognition_PPLCNet_x2_5.yaml \
-o Arch.Backbone.pretrained=True \
-o Arch.Head.class_num=68 \
-o Global.epochs=200 \
-o Global.device=gpu \
-o Global.output_dir=./output_new
6、特征提取模型评估
%cd /home/aistudio/PaddleClas
!python tools/eval.py \
-c ./ppcls/configs/GeneralRecognition/GeneralRecognition_PPLCNet_x2_5.yaml \
-o Global.pretrained_model="/home/aistudio/PaddleClas/output_new/RecModel/best_model"
评估结果展示,总共跑了三种配置,结果展示如下:
模型 | Epoch | Top-1 | Top-5 |
---|---|---|---|
PPLCNet | 200 | 84.244 | 90.960 |
PPLCNet+wide Neck | 200 | 83.553 | 91.086 |
7、模型推理
导出推理模型
PaddlePaddle框架保存的权重文件分为两种:支持前向推理和反向梯度的训练模型 和 只支持前向推理的推理模型。二者的区别是推理模型针对推理速度和显存做了优化,裁剪了一些只在训练过程中才需要的tensor,降低显存占用,并进行了一些类似层融合,kernel选择的速度优化。因此可执行如下命令导出推理模型。
# 导出推理模型
%cd /home/aistudio/PaddleClas
!python tools/export_model.py \
-c ./ppcls/configs/GeneralRecognition/GeneralRecognition_PPLCNet_x2_5.yaml \
-o Global.pretrained_model="output_new/RecModel/best_model"
生成的推理模型文件如下
PaddleClas/inference
|--inference.pdmodel :存储推理模型的结构
|--inference.pdiparams: 存储权重
|--inference.pdiparams.info: 存储推理模型相关的参数信息
# 推理模型测试代码:获取待测试图片的特征向量
%cd /home/aistudio/PaddleClas/deploy
!python python/predict_rec.py \
-c configs/inference_rec.yaml \
-o Global.rec_inference_model_dir="../inference" \
-o Global.infer_imgs="../dataset/test_images/test_001.jpg"
可以正常通过导出的推理模型完成图片特征向量的计算,说明推理模型导出成功
8、搭建动物识别系统
(1)识别原理
- 输入图片经过保存的推理模型后,被编码为长度为128的向量。
- 在索引文件中使用向量检索算法查找与当前图片编码向量最接近的一个已有向量。
- 如果检索到的向量的分数低于阈值,那么认为此时匹配失败,索引库中没有对应的类别,输出为空。
- 如果检索到的向量的分数高于阈值,那么匹配成功,输出分数最高的向量对应的label作为识别结果。
- 利用预训练完成的主体检测模型PP-Picodet对图片的动物主体进行检测,并将label和对应分数标注在预测锚框上方。
索引文件生成
# 首先根据classlabels.txt文件得到序号和中文标签的映射
%cd ~/PaddleClas
dic = {}
with open('./dataset/animal_archive/label_list.txt', "r") as r:
for line in r.readlines():
i, k = line.split()[0], line.split()[1]
dic[i] = k
# print(dic)
/home/aistudio/PaddleClas
# 数据集没有提供gallery_label.txt
# 读取test_list.txt,然后每一类选择三个(不足则全部选择),加入gallery_label.txt
import shutil
import os
os.chdir('/home/aistudio/PaddleClas/dataset')
contents = []
with open("./animal_archive/dataset/test_list.txt", 'r') as fa:
contents = fa.readlines()
def warp(s):
ls = s.split()[0:2]
new_s = " ".join(ls)
return new_s
contents = list(map(warp, contents))
with open("./animal_archive/dataset/gallery_label.txt", 'w') as fb:
cnt = 0
prev = ''
for line in contents:
if prev == '' or line.split(' ')[-1] != prev:
cnt = 0
if cnt < 3:
path = line.split(' ')[0]
name = path.split('/')[-1]
shutil.copy(path, './animal_archive/dataset/gallery/' + name)
prev = line.split(' ')[-1]
fb.write('animal_archive/dataset/gallery/' + name + '\t' + dic[prev] + '\n')
cnt += 1
建立索引库
修改configs/inference_general.yaml
文件内容:
Global:
rec_inference_model_dir: "/home/aistudio/PaddleClas/inference"
IndexProcess:
index_method: "HNSW32" # supported: HNSW32, IVF, Flat
image_root: "/home/aistudio/PaddleClas/dataset/"
index_dir: "/home/aistudio/PaddleClas/dataset/animal_archive/dataset/index"
data_file: "/home/aistudio/PaddleClas/dataset/animal_archive/dataset/gallery_label.txt"
index_operation: "new" # suported: "append", "remove", "new"
delimiter: "\t"
dist_type: "IP"
embedding_size: 512
执行如下代码:
# 建立索引库
%cd /home/aistudio/PaddleClas/deploy
!python3 python/build_gallery.py \
-c configs/inference_general.yaml
-
索引库建立完成后便可以使用主体检测模型进行检测,并使用推理模型进行识别。
-
运行如下命令,下载通用检测
inference
模型并解压: -
下载完成后/home/aistudio/PaddleClas/deploy/文件夹下建立了一个models的子文件夹,里面是主体检测模型解压完成的文件夹,文件夹里包含四个文件,其中三个是推理模型的相关权重参数,另一个是推理用的配置文件。
通用主体检测模型下载
%cd /home/aistudio/PaddleClas/deploy/
%mkdir models
%cd models
# 下载通用检测 inference 模型并解压
# 上面那里保存的是特征提取的backbone的推理模型,这里是主体检测的推理模型Picodet
# !wget https://paddledet.bj.bcebos.com/deploy/Inference/picodet_l_640_coco_lcnet_non_postprocess.tar
# !tar -xf picodet_l_640_coco_lcnet_non_postprocess.tar
!wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/picodet_PPLCNet_x2_5_mainbody_lite_v1.0_infer.zip
!unzip picodet_PPLCNet_x2_5_mainbody_lite_v1.0_infer.zip
(2)文件配置
修改推理文件configs/inference_general.yaml
内容:
Global:
det_inference_model_dir: "./models/picodet_PPLCNet_x2_5_mainbody_lite_v1.0_infer"
rec_inference_model_dir: "/home/aistudio/PaddleClas/inference"
rec_nms_thresold: 0.05
注意,配置文件中的infer_imgs字段不需要修改,我们在推理的时候使用命令行传入需要推理的图片路径即可
9、 系统测试
(1)待测试图片为索引库中已经存在的图片
# 基于索引库的图像识别
%cd /home/aistudio/PaddleClas/deploy
%pwd
!python python/predict_system.py \
-c configs/inference_general.yaml \
-o Global.infer_imgs="../dataset/test_images/2.jpg" \
-o IndexProcess.index_dir="../dataset/animal_archive/dataset/index"
其中 bbox 表示检测出的主体所在位置,rec_docs 表示索引库中与检测框最为相似的类别,rec_scores 表示对应的置信度。
检测的可视化结果也保存在 output 文件夹下,对于本张图像,识别结果可视化如下所示:
(2)待测试图片为索引库里没有的商品:
这里以国家二级保护动物红腹锦鸡为例
对图像 /test_images/红腹锦鸡.jpg
进行识别,待检索图像如下所示。
# 识别
%cd /home/aistudio/PaddleClas/deploy
!python python/predict_system.py \
-c configs/inference_general.yaml \
-o Global.infer_imgs="../dataset/test_images/红腹锦鸡.jpg" \
-o IndexProcess.index_dir="../dataset/animal_archive/dataset/index"
-
由于默认的索引库中不包含对应的索引信息,所以这里的识别结果有误,输出了相似度较高的动物,也就是“黑头角雉”,这种误判可以通过改进分类模型加以改进。
-
当索引库中的图像无法覆盖我们实际识别的场景时,我们可以通过构建新的索引库的方式,完成未知类别的图像识别。即在预测未知类别的图像时,只需要将对应类别的相似图像添加到索引库中,从而完成对未知类别的图像识别,这一过程是不需要重新训练的。
准备新的数据与标签
首先需要将与待检索图像相似的图像列表拷贝到索引库原始图像的文件夹。这里将所有的底库图像数据都放在文件夹 PaddleClas/dataset/animal_archive/dataset/gallery
中。
然后需要编辑记录了图像路径和标签信息的文本文件,这里 PaddleClas 将更正后的标签信息文件放在了 PaddleClas/dataset/animal_archive/dataset/gallery_update.txt
文件中。可以与原来的 PaddleClas/dataset/animal_archive/dataset/gallery_label.txt
标签文件进行对比,添加了红腹锦鸡的索引图像。
每一行的文本中,第一个字段表示图像的相对路径,第二个字段表示图像对应的标签信息,中间用 \t 键分隔开
更新索引库
# 使用下面的命令建立新的 index 索引,加速识别后的检索过程
%cd /home/aistudio/PaddleClas/deploy/
%pwd
!python python/build_gallery.py \
-c configs/inference_general.yaml \
-o IndexProcess.data_file="../dataset/animal_archive/dataset/gallery_update.txt" \
-o IndexProcess.index_dir="../dataset/animal_archive/dataset/index_update"
- 基于新的索引库的图像识别
最终新的索引信息保存在文件夹 /home/aistudio/dataset/all/index_update
中。
使用新的索引库,对上述图像进行识别,运行命令如下:
%cd /home/aistudio/PaddleClas/deploy/
!python python/predict_system.py \
-c configs/inference_general.yaml \
-o Global.infer_imgs="../dataset/test_images/红腹锦鸡.jpg" \
-o IndexProcess.index_dir="../dataset/animal_archive/dataset/index_update"
由测试效果图可知,模型对于未参与训练的商品及多个商品均有较好的识别效果
10、 模型优化思路
(1)检测模型调优
PP-ShiTu
中检测模型采用的 PP-PicoDet
算法,在使用官方模型后,如果不满足精度需求,则可以参考此部分文档,进行模型调优
(2)识别模型调优
因为要对模型进行训练,所以参照数据准备部分描述收集自己的数据集。值得注意的是,此部分需要准备大量的数据,以保证识别模型效果。
- 数据增强:根据实际情况选择不同数据增强方法。如:实际应用中数据遮挡比较严重,建议添加
RandomErasing
增强方法。 - 换不同的
backbone
,一般来说,越大的模型,特征提取能力更强。 - 增加模型的宽度,一般来说,模型宽度越大,学习能力越强。
- 选择不同的
Metric Learning
方法。不同的Metric Learning
方法,对不同的数据集效果可能不太一样,建议尝试其他Loss
- 采用蒸馏方法,对小模型进行模型能力提升,但是进行知识蒸馏可能比较困难。
- 增补数据集。针对错误样本,添加badcase数据。
模型训练完成后,参照系统测试进行检索库更新。同时,对整个pipeline进行测试,如果精度不达预期,则重复此步骤。
11、视频流检测
近来,短视频平台上出现了贩卖或者杀害国家保护动物的视频,为了能够利用AI自动检测短视频中是否包含国家保护动物,可以利用视频流检测方式。
import cv2
import numpy as np
from PIL import Image, ImageDraw
import os
os.chdir('/home/aistudio/PaddleClas')
def CutVideo2Image(video_path, img_path):
#将视频输出为图像
#video_path为输入视频文件路径
#img_path为输出图像文件夹路径
cap = cv2.VideoCapture(video_path)
index = 0
while(True):
ret,frame = cap.read()
if ret:
cv2.imwrite(img_path+'/%d.jpg'%index, frame)
index += 1
else:
break
cap.release()
CutVideo2Image('./dataset/test_video.mp4', './dataset/video_imgs_new')
# 对生成的图像进行推理
%cd /home/aistudio/PaddleClas/deploy/
!python python/predict_system.py \
-c configs/inference_general.yaml \
-o Global.infer_imgs="../dataset/video_imgs_new" \
-o IndexProcess.index_dir="../dataset/animal_archive/dataset/index_update"
# 将推理结果图像合成视频
import os
os.chdir('/home/aistudio/PaddleClas')
import cv2
size = (1280, 720)
# 分割结果
# 完成写入对象的创建,第一个参数是合成之后的视频的名称,第二个参数是可以使用的编码器,第三个参数是帧率即每秒钟展示多少张图片,第四个参数是图片大小信息
fourcc = cv2.VideoWriter.fourcc('m', 'p', '4', 'v')
videowrite = cv2.VideoWriter('dataset/final.mp4', fourcc, 15, size)# 15是帧数,size是图片尺寸
filelist = os.listdir('./deploy/output_video_new')
img_array = []
for filename in ['./deploy/output_video_new/{}.jpg'.format(i) for i in range(len(filelist))]:
img = cv2.imread(filename)
img_array.append(img)
for i in range(len(img_array)):
videowrite.write(img_array[i])
img = cv2.imread(filename)
img_array.append(img)
for i in range(len(img_array)):
videowrite.write(img_array[i])
videowrite.release()
12、项目总结
-
本项目基于PP-ShiTu图像识别系统进行动物识别,主要难点在于数据集不足,噪声数据较多,没有进行严格的数据清洗,因此识别模型的精度受到一定影响。另外,主体检测模型的识别有些问题,我怀疑是推理代码有问题,主体检测模型的域适应性有待进一步加强。
-
使用PaddleClas在模型调优上也比较方便,根据自己的数据集选择合适的模型即可,可以对其宽度进行适当加宽,但是学习率策略、数据增强、EMA和weight decay等都保持默认是可以达到最好的效果的,说明PaddleClas提供了高质量的配置文件,大大减轻了我们的调优负担。
作者介绍
赵祎安 大连理工大学飞桨领航团团长 计算机科学与技术专业 2019级 本科生
请点击此处查看本环境基本用法.
Please click here for more detailed instructions.
此文章为搬运
原项目链接
更多推荐
所有评论(0)