Reposted from an AI Studio project: https://aistudio.baidu.com/aistudio/projectdetail/2432755

Project Background

While recently reproducing the ResNet50+FPN version of Faster R-CNN, I found that the official PaddlePaddle data-loading classes lack a good worked example for the data preprocessing and loading stage.

So in this project I show how to build a data loader for the standard VOC2012 dataset; afterwards I will also use a dataset of my own to demonstrate custom data loading.

Paddle documentation: dataset definition and loading

Introduction to the VOC dataset

This project uses the pascal-voc dataset provided on AI Studio by the user 笨笨.

The pascal-voc dataset contains the VOC2007 and VOC2012 data and is mainly used for vision tasks such as object detection and semantic segmentation.

The directory structure of the Pascal VOC dataset is shown below:

.
└── VOCdevkit     # root directory
    └── VOC2012   # data for one year; only 2012 is downloaded here, but 2007 and other years exist too
        ├── Annotations        # XML files, one per image in JPEGImages, describing the image contents
        ├── ImageSets          # txt files; each line holds an image name, optionally followed by ±1 marking positive/negative samples
        │   ├── Action
        │   ├── Layout
        │   ├── Main
        │   └── Segmentation
        ├── JPEGImages         # source images
        ├── SegmentationClass  # images for semantic segmentation
        └── SegmentationObject # images for instance segmentation
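The ±1 convention applies to the per-class list files under ImageSets/Main (a file such as aeroplane_train.txt pairs each image name with 1 or -1), while train.txt and val.txt carry only the names. A minimal sketch of parsing one such line (the image ids below are hypothetical):

```python
def parse_main_line(line):
    """Split one ImageSets/Main line into (image name, +1/-1 flag).

    Per-class files separate the name and flag with whitespace;
    train.txt/val.txt lines carry only the name, so the flag defaults to None.
    """
    parts = line.split()
    name = parts[0]
    flag = int(parts[1]) if len(parts) > 1 else None
    return name, flag

print(parse_main_line("2008_000008 -1"))  # negative sample in a per-class file
print(parse_main_line("2008_000015  1"))  # positive sample
print(parse_main_line("2008_000019"))     # plain train.txt-style line
```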

Here we will use the VOC2012 dataset.

Since Faster R-CNN is an object-detection model, we will use Annotations, JPEGImages, and the train.txt and val.txt files under ImageSets/Main:

Annotations: directory holding the XML annotation files

JPEGImages: directory holding the image files

train.txt: txt file listing the training image names

val.txt: txt file listing the validation image names
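Each name listed in train.txt/val.txt maps to exactly one annotation file and one image file. A small sketch of that mapping (the sample id is one that appears later in this project):

```python
import os

def voc_paths(voc_root, name):
    """Map an image name from train.txt/val.txt to its XML annotation and JPEG image paths."""
    xml_path = os.path.join(voc_root, "Annotations", name + ".xml")
    img_path = os.path.join(voc_root, "JPEGImages", name + ".jpg")
    return xml_path, img_path

print(voc_paths("VOC2012", "2010_001142"))
```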

First, unzip the pascal-voc dataset:

!unzip -oq data/data4379/pascalvoc.zip

Since only VOC2012 is needed here, move the VOC2012 folder to the working directory:

!mv pascalvoc/VOCdevkit/VOC2012 ./

Custom dataset walkthrough

The official PaddlePaddle documentation provides a very simple custom-dataset example:

import paddle
from paddle.io import Dataset

BATCH_SIZE = 64
BATCH_NUM = 20

IMAGE_SIZE = (28, 28)
CLASS_NUM = 10


class MyDataset(Dataset):
    """
    Step 1: subclass paddle.io.Dataset
    """
    def __init__(self, num_samples):
        """
        Step 2: implement the constructor and define the dataset size
        """
        super(MyDataset, self).__init__()
        self.num_samples = num_samples

    def __getitem__(self, index):
        """
        Step 3: implement __getitem__, which returns one sample
        (training data, corresponding label) for the given index
        """
        data = paddle.uniform(IMAGE_SIZE, dtype='float32')
        label = paddle.randint(0, CLASS_NUM-1, dtype='int64')

        return data, label

    def __len__(self):
        """
        Step 4: implement __len__, which returns the total number of samples
        """
        return self.num_samples

# test the dataset defined above
custom_dataset = MyDataset(BATCH_SIZE * BATCH_NUM)

print('=============custom dataset=============')
for data, label in custom_dataset:
    print(data.shape, label.shape)
    break

We can follow this template and implement our dataset step by step.

Create and define the class

# Data-reading class, subclassing paddle.io.Dataset
class VOCDataset(paddle.io.Dataset):

Implement the constructor and define the dataset paths

In __init__ we define the paths to the VOC2012 subdirectories,
and we also load the file listing the VOC2012 class names.

I placed the VOC2012 class file in the working directory.

Path: pascal_voc_classes.json
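If you need to recreate pascal_voc_classes.json yourself, it is simply a mapping from the 20 VOC class names to indices starting at 1 (index 0 is conventionally reserved for the background class in Faster R-CNN); this sketch writes such a file, with the mapping matching the class_dict printed later in this project:

```python
import json

VOC_CLASSES = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
               'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
               'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']

# number the classes from 1, leaving 0 for the background
class_dict = {name: i for i, name in enumerate(VOC_CLASSES, start=1)}

with open('pascal_voc_classes.json', 'w') as f:
    json.dump(class_dict, f, indent=4)

print(class_dict['person'])  # 15
```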


    def __init__(self, voc_root, year='2012', transforms=None, txt_name: str = 'train.txt'):
        assert year in ['2007', '2012'], "year must be in ['2007', '2012']"
        self.root = os.path.join(voc_root, f"VOC{year}")
        self.img_root = os.path.join(self.root, 'JPEGImages')
        self.annotations_root = os.path.join(self.root, 'Annotations')

        txt_path = os.path.join(self.root, 'ImageSets', 'Main', txt_name)
        assert os.path.exists(txt_path), 'not found {} file'.format(txt_name)

        with open(txt_path) as read:
            self.xml_list = [os.path.join(self.annotations_root, line.strip() + '.xml')
                             for line in read.readlines() if len(line.strip()) > 0]

        # check that every listed annotation file exists
        assert len(self.xml_list) > 0, "in '{}' file does not find any information.".format(txt_path)
        for xml_path in self.xml_list:
            assert os.path.exists(xml_path), "not found '{}' file.".format(xml_path)

        # read the class-name-to-index mapping
        json_file = './pascal_voc_classes.json'
        assert os.path.exists(json_file), "{} file not exist.".format(json_file)
        with open(json_file, 'r') as f:
            self.class_dict = json.load(f)

        self.transforms = transforms

Implement __getitem__, which returns a single sample (image, target) for a given index

    def __getitem__(self, idx):
        # read the xml annotation
        xml_path = self.xml_list[idx]
        with open(xml_path) as fid:
            xml_str = fid.read()
        xml = etree.fromstring(xml_str)
        data = self.parse_xml_to_dict(xml)["annotation"]
        img_path = os.path.join(self.img_root, data["filename"])
        image = Image.open(img_path)
        if image.format != "JPEG":
            raise ValueError("Image '{}' format not JPEG".format(img_path))

        boxes = []
        labels = []
        iscrowd = []

        assert "object" in data, "{} lack of object information.".format(xml_path)
        for obj in data["object"]:
            xmin = float(obj["bndbox"]["xmin"])
            xmax = float(obj["bndbox"]["xmax"])
            ymin = float(obj["bndbox"]["ymin"])
            ymax = float(obj["bndbox"]["ymax"])

            # Sanity-check the annotation: some boxes have zero width or height,
            # which would make the regression loss NaN
            if xmax <= xmin or ymax <= ymin:
                print("Warning: in '{}' xml, there are some bbox w/h <=0".format(xml_path))
                continue

            boxes.append([xmin, ymin, xmax, ymax])
            labels.append(self.class_dict[obj["name"]])
            if "difficult" in obj:
                iscrowd.append(int(obj["difficult"]))
            else:
                iscrowd.append(0)

        # convert everything into paddle.Tensor
        boxes = paddle.to_tensor(boxes).astype('float32')
        labels = paddle.to_tensor(labels).astype('int32')
        iscrowd = paddle.to_tensor(iscrowd, dtype=paddle.int64)
        image_id = paddle.to_tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])

        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["image_id"] = image_id
        target["area"] = area
        target["iscrowd"] = iscrowd

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def parse_xml_to_dict(self, xml):
        """
        Parse an XML tree into a dict, following TensorFlow's recursive_parse_xml_to_dict.
        Args:
            xml: xml tree obtained by parsing XML file contents using lxml.etree

        Returns:
            Python dictionary holding XML contents.
        """

        if len(xml) == 0:  # leaf node: return its tag and text directly
            return {xml.tag: xml.text}

        result = {}
        for child in xml:
            child_result = self.parse_xml_to_dict(child)  # recurse into child tags
            if child.tag != 'object':
                result[child.tag] = child_result[child.tag]
            else:
                if child.tag not in result:  # there may be several 'object' tags, so collect them in a list
                    result[child.tag] = []
                result[child.tag].append(child_result[child.tag])
        return {xml.tag: result}

Data returned by the parse_xml_to_dict method:

{'filename': '2010_001142.jpg', 'folder': 'VOC2012', 'object': [{'name': 'bottle', 'bndbox': {'xmax': '282', 'xmin': '264', 'ymax': '244', 'ymin': '210'}, 'difficult': '0', 'occluded': '0', 'pose': 'Unspecified', 'truncated': '0'}, {'name': 'bottle', 'bndbox': {'xmax': '308', 'xmin': '295', 'ymax': '184', 'ymin': '162'}, 'difficult': '1', 'occluded': '0', 'pose': 'Unspecified', 'truncated': '0'}, {'name': 'bottle', 'bndbox': {'xmax': '270', 'xmin': '254', 'ymax': '224', 'ymin': '196'}, 'difficult': '1', 'occluded': '0', 'pose': 'Unspecified', 'truncated': '1'}, {'name': 'bottle', 'bndbox': {'xmax': '292', 'xmin': '281', 'ymax': '225', 'ymin': '204'}, 'difficult': '1', 'occluded': '0', 'pose': 'Unspecified', 'truncated': '1'}, {'name': 'bottle', 'bndbox': {'xmax': '221', 'xmin': '212', 'ymax': '227', 'ymin': '208'}, 'difficult': '1', 'occluded': '0', 'pose': 'Unspecified', 'truncated': '0'}, {'name': 'person', 'bndbox': {'xmax': '371', 'xmin': '315', 'ymax': '220', 'ymin': '103'}, 'difficult': '0', 'occluded': '1', 'pose': 'Frontal', 'truncated': '1'}, {'name': 'person', 'bndbox': {'xmax': '379', 'xmin': '283', 'ymax': '342', 'ymin': '171'}, 'difficult': '0', 'occluded': '0', 'pose': 'Left', 'truncated': '0'}, {'name': 'person', 'bndbox': {'xmax': '216', 'xmin': '156', 'ymax': '260', 'ymin': '180'}, 'difficult': '0', 'occluded': '1', 'pose': 'Right', 'truncated': '1'}, {'name': 'person', 'bndbox': {'xmax': '223', 'xmin': '205', 'ymax': '198', 'ymin': '172'}, 'difficult': '1', 'occluded': '1', 'pose': 'Frontal', 'truncated': '1'}, {'name': 'person', 'bndbox': {'xmax': '280', 'xmin': '218', 'ymax': '234', 'ymin': '155'}, 'difficult': '0', 'occluded': '1', 'pose': 'Right', 'truncated': '1'}, {'name': 'person', 'bndbox': {'xmax': '343', 'xmin': '292', 'ymax': '241', 'ymin': '185'}, 'difficult': '1', 'occluded': '1', 'pose': 'Left', 'truncated': '1'}], 'segmented': '0', 'size': {'depth': '3', 'height': '375', 'width': '500'}, 'source': {'annotation': 'PASCAL VOC2010', 
'database': 'The VOC2010 Database', 'image': 'flickr'}}
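The recursion can be checked on a tiny annotation. In this sketch the standard-library xml.etree.ElementTree stands in for lxml.etree (both expose the same tag/text/iteration interface the parser relies on), and the XML snippet is made up:

```python
import xml.etree.ElementTree as etree  # stand-in for lxml.etree in this sketch

def parse_xml_to_dict(xml):
    """Recursively turn an element tree into nested dicts; 'object' tags collect into a list."""
    if len(xml) == 0:  # leaf node: return its tag and text directly
        return {xml.tag: xml.text}
    result = {}
    for child in xml:
        child_result = parse_xml_to_dict(child)  # recurse into child tags
        if child.tag != 'object':
            result[child.tag] = child_result[child.tag]
        else:
            result.setdefault(child.tag, []).append(child_result[child.tag])
    return {xml.tag: result}

xml_str = """
<annotation>
  <filename>demo.jpg</filename>
  <object><name>person</name></object>
  <object><name>bottle</name></object>
</annotation>
"""
data = parse_xml_to_dict(etree.fromstring(xml_str))["annotation"]
print(data["filename"])                      # demo.jpg
print([o["name"] for o in data["object"]])   # ['person', 'bottle']
```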
!pip install lxml
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Requirement already satisfied: lxml in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (4.8.0)

import paddle
import os
import json
from PIL import Image
from lxml import etree


# Data-reading class, subclassing paddle.io.Dataset
class VOCDataset(paddle.io.Dataset):

    def __init__(self, voc_root, year='2012', transforms=None, txt_name: str = 'train.txt'):
        assert year in ['2007', '2012'], "year must be in ['2007', '2012']"
        self.root = os.path.join(voc_root, f"VOC{year}")
        self.img_root = os.path.join(self.root, 'JPEGImages')
        self.annotations_root = os.path.join(self.root, 'Annotations')

        txt_path = os.path.join(self.root, 'ImageSets', 'Main', txt_name)
        assert os.path.exists(txt_path), 'not found {} file'.format(txt_name)

        with open(txt_path) as read:
            self.xml_list = [os.path.join(self.annotations_root, line.strip() + '.xml')
                             for line in read.readlines() if len(line.strip()) > 0]

        # check that every listed annotation file exists
        assert len(self.xml_list) > 0, "in '{}' file does not find any information.".format(txt_path)
        for xml_path in self.xml_list:
            assert os.path.exists(xml_path), "not found '{}' file.".format(xml_path)

        # read the class-name-to-index mapping
        json_file = './pascal_voc_classes.json'
        assert os.path.exists(json_file), "{} file not exist.".format(json_file)
        with open(json_file, 'r') as f:
            self.class_dict = json.load(f)

        self.transforms = transforms

    def __len__(self):
        return len(self.xml_list)

    def __getitem__(self, idx):
        # read xml
        xml_path = self.xml_list[idx]
        with open(xml_path) as fid:
            xml_str = fid.read()
        xml = etree.fromstring(xml_str)
        data = self.parse_xml_to_dict(xml)["annotation"]
        img_path = os.path.join(self.img_root, data["filename"])
        image = Image.open(img_path)
        if image.format != "JPEG":
            raise ValueError("Image '{}' format not JPEG".format(img_path))

        boxes = []
        labels = []
        iscrowd = []



        assert "object" in data, "{} lack of object information.".format(xml_path)
        for obj in data["object"]:
            xmin = float(obj["bndbox"]["xmin"])
            xmax = float(obj["bndbox"]["xmax"])
            ymin = float(obj["bndbox"]["ymin"])
            ymax = float(obj["bndbox"]["ymax"])

            # Sanity-check the annotation: boxes with zero width or height would make the regression loss NaN
            if xmax <= xmin or ymax <= ymin:
                print("Warning: in '{}' xml, there are some bbox w/h <=0".format(xml_path))
                continue
            
            boxes.append([xmin, ymin, xmax, ymax])
            labels.append(self.class_dict[obj["name"]])
            if "difficult" in obj:
                iscrowd.append(int(obj["difficult"]))
            else:
                iscrowd.append(0)

        # convert everything into a paddle.Tensor
        boxes = paddle.to_tensor(boxes).astype('float32')
        labels = paddle.to_tensor(labels).astype('int32')
        iscrowd = paddle.to_tensor(iscrowd, dtype=paddle.int64)
        image_id = paddle.to_tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])


        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["image_id"] = image_id
        target["area"] = area
        target["iscrowd"] = iscrowd

        if self.transforms is not None:
            image, target = self.transforms(image, target)
    
        return image, target
    
    def parse_xml_to_dict(self, xml):
        """
        Parse an XML tree into a dict, following TensorFlow's recursive_parse_xml_to_dict.
        Args:
            xml: xml tree obtained by parsing XML file contents using lxml.etree

        Returns:
            Python dictionary holding XML contents.
        """

        if len(xml) == 0:  # leaf node: return its tag and text directly
            return {xml.tag: xml.text}

        result = {}
        for child in xml:
            child_result = self.parse_xml_to_dict(child)  # recurse into child tags
            if child.tag != 'object':
                result[child.tag] = child_result[child.tag]
            else:
                if child.tag not in result:  # there may be several 'object' tags, so collect them in a list
                    result[child.tag] = []
                result[child.tag].append(child_result[child.tag])
        return {xml.tag: result}
    
    @staticmethod
    def collate_fn(batch):
        # without @staticmethod, calling this via an instance would pass the instance as `batch`
        return tuple(zip(*batch))
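collate_fn regroups a batch of (image, target) pairs into a tuple of images and a tuple of targets, which is what a detection model with variable-sized targets needs. A quick check with dummy data:

```python
def collate_fn(batch):
    # [(img1, tgt1), (img2, tgt2)] -> ((img1, img2), (tgt1, tgt2))
    return tuple(zip(*batch))

batch = [("img1", {"boxes": 2}), ("img2", {"boxes": 5})]
images, targets = collate_fn(batch)
print(images)   # ('img1', 'img2')
print(targets)  # ({'boxes': 2}, {'boxes': 5})
```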

train_dataset = VOCDataset('./', "2012")
print(train_dataset.class_dict)
{'aeroplane': 1, 'bicycle': 2, 'bird': 3, 'boat': 4, 'bottle': 5, 'bus': 6, 'car': 7, 'cat': 8, 'chair': 9, 'cow': 10, 'diningtable': 11, 'dog': 12, 'horse': 13, 'motorbike': 14, 'person': 15, 'pottedplant': 16, 'sheep': 17, 'sofa': 18, 'train': 19, 'tvmonitor': 20}

VOC loading test

import paddle.vision.transforms as transforms
from draw_box_utils import draw_box
from PIL import Image
import json
import matplotlib.pyplot as plt
import random

# read class_indict
category_index = {}
try:
    json_file = open('./pascal_voc_classes.json', 'r')
    class_dict = json.load(json_file)
    category_index = {v: k for k, v in class_dict.items()}
except Exception as e:
    print(e)
    exit(-1)

data_transform = {
    "train": transforms.Compose([transforms.ToTensor(),
                                 transforms.RandomHorizontalFlip(0.5)]),
    "val": transforms.Compose([transforms.ToTensor()])
}

# load train data set
train_data_set = VOCDataset('./', "2012")
print(len(train_data_set))
for index in random.sample(range(0, len(train_data_set)), k=5):
    img, target = train_data_set[index]
    draw_box(img,
             target["boxes"].numpy(),
             target["labels"].numpy(),
             [1 for i in range(len(target["labels"].numpy()))],
             category_index,
             thresh=0.5,
             line_thickness=5)
    plt.imshow(img)
    plt.show()
5717

(Output: five randomly sampled training images shown with their ground-truth boxes drawn; output_18_1.png through output_18_5.png.)

VOC-style custom data

!unzip -oq data/data106197/voc.zip
import paddle
import os
import json
from PIL import Image
from lxml import etree

# Data-reading class for a VOC-style custom dataset, subclassing paddle.io.Dataset
class Selfataset(paddle.io.Dataset):

    def __init__(self, voc_root, transforms=None, txt_name: str = 'train.txt'):
        self.root = voc_root
        self.img_root = os.path.join(self.root, 'JPEGImages')
        self.annotations_root = os.path.join(self.root, 'Annotations')

        txt_path = os.path.join(self.root, txt_name)
        print(txt_path)
        assert os.path.exists(txt_path), 'not found {} file'.format(txt_name)

        # each line of the list file holds "<image path> <xml path>"
        self.image_list = []
        self.xml_list = []
        with open(txt_path) as read:
            self.path_list = [line.strip() for line in read.readlines() if len(line.strip()) > 0]
            for path in self.path_list:
                self.image_list.append(os.path.join(self.root, path.split(' ')[0]))
                self.xml_list.append(os.path.join(self.root, path.split(' ')[1]))

        assert len(self.xml_list) > 0, "in '{}' file does not find any information.".format(txt_path)
        for xml_path in self.xml_list:
            assert os.path.exists(xml_path), "not found '{}' file.".format(xml_path)

        # read the class names from labels.txt and number them starting at 1
        self.class_dict = {}
        self.class_path = os.path.join(self.root, 'labels.txt')
        print(self.class_path)
        with open(self.class_path) as read:
            self.classes = [class_name.strip() for class_name in read.readlines()]
            print(self.classes)
            for number, class_name in enumerate(self.classes, 1):
                self.class_dict[class_name] = number

        self.transforms = transforms

    def __len__(self):
        return len(self.xml_list)

    def __getitem__(self, idx):
        # read xml
        xml_path = self.xml_list[idx]
        with open(xml_path) as fid:
            xml_str = fid.read()
        xml = etree.fromstring(xml_str)
        data = self.parse_xml_to_dict(xml)["annotation"]
        img_path = os.path.join(self.img_root, data["frame"] + '.jpg')
        image = Image.open(img_path)
        # note: the strict JPEG format check used in VOCDataset is skipped for this dataset

        boxes = []
        labels = []
        iscrowd = []
        assert "object" in data, "{} lack of object information.".format(xml_path)
        for obj in data["object"]:
            xmin = float(obj["bndbox"]["xmin"])
            xmax = float(obj["bndbox"]["xmax"])
            ymin = float(obj["bndbox"]["ymin"])
            ymax = float(obj["bndbox"]["ymax"])

            # Sanity-check the annotation: boxes with zero width or height would make the regression loss NaN
            if xmax <= xmin or ymax <= ymin:
                print("Warning: in '{}' xml, there are some bbox w/h <=0".format(xml_path))
                continue
            
            boxes.append([xmin, ymin, xmax, ymax])
            labels.append(self.class_dict[obj["name"]])
            if "difficult" in obj:
                iscrowd.append(int(obj["difficult"]))
            else:
                iscrowd.append(0)

        # convert everything into a paddle.Tensor
        boxes = paddle.to_tensor(boxes).astype('float32')
        labels = paddle.to_tensor(labels).astype('int32')
        iscrowd = paddle.to_tensor(iscrowd, dtype=paddle.int64)
        image_id = paddle.to_tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])


        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["image_id"] = image_id
        target["area"] = area
        target["iscrowd"] = iscrowd

        if self.transforms is not None:
            image, target = self.transforms(image, target)
    
        return image, target
    
    def parse_xml_to_dict(self, xml):
        """
        Parse an XML tree into a dict, following TensorFlow's recursive_parse_xml_to_dict.
        Args:
            xml: xml tree obtained by parsing XML file contents using lxml.etree

        Returns:
            Python dictionary holding XML contents.
        """

        if len(xml) == 0:  # leaf node: return its tag and text directly
            return {xml.tag: xml.text}

        result = {}
        for child in xml:
            child_result = self.parse_xml_to_dict(child)  # recurse into child tags
            if child.tag != 'object':
                result[child.tag] = child_result[child.tag]
            else:
                if child.tag not in result:  # there may be several 'object' tags, so collect them in a list
                    result[child.tag] = []
                result[child.tag].append(child_result[child.tag])
        return {xml.tag: result}

    @staticmethod
    def collate_fn(batch):
        return tuple(zip(*batch))
a = Selfataset('voc',None,'train_list.txt')
voc/train_list.txt
voc/labels.txt
['flv', 'gx', 'mbw']
a.class_dict
{'flv': 1, 'gx': 2, 'mbw': 3}
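The constructor numbers the classes from labels.txt starting at 1, which is exactly what the class_dict output above shows; the enumerate step in isolation, with the class names taken from the printed output:

```python
classes = ['flv', 'gx', 'mbw']  # as printed from voc/labels.txt above

class_dict = {}
for number, class_name in enumerate(classes, 1):  # start numbering at 1, leaving 0 for background
    class_dict[class_name] = number

print(class_dict)  # {'flv': 1, 'gx': 2, 'mbw': 3}
```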
import paddle.vision.transforms as transforms
from draw_box_utils import draw_box
from PIL import Image
import json
import matplotlib.pyplot as plt
import random

# read class_indict
category_index = {}
try:
    json_file = open('./pascal_voc_classes.json', 'r')
    class_dict = json.load(json_file)
    category_index = {v: k for k, v in class_dict.items()}
except Exception as e:
    print(e)
    exit(-1)

data_transform = {
    "train": transforms.Compose([transforms.ToTensor(),
                                 transforms.RandomHorizontalFlip(0.5)]),
    "val": transforms.Compose([transforms.ToTensor()])
}

# load train data set
train_data_set = Selfataset('voc',None,'train_list.txt')
print(len(train_data_set))
for index in random.sample(range(0, len(train_data_set)), k=5):
    img, target = train_data_set[index]
    draw_box(img,
             target["boxes"].numpy(),
             target["labels"].numpy(),
             [1 for i in range(len(target["labels"].numpy()))],
             category_index,
             thresh=0.6,
             line_thickness=5)
    plt.imshow(img)
    plt.show()

voc/train_list.txt
voc/labels.txt
['flv', 'gx', 'mbw']
1216

(Output: five randomly sampled images from the custom dataset with their boxes drawn; output_24_1.png through output_24_5.png.)
