0 引言

  近年来,基于深度学习的语义/实例分割技术不断发展,分割性能持续提高,但精确分割物体形状仍然不容易;而且在很多任务中(即使不是遥感领域),往往并不需要精确的分割结果,而只需要目标的位置信息。为此,本项目针对遥感影像中的目标定位进行了探索实验。

  该任务其实可以看作不估计边框的目标检测,但考虑到大幅改造检测模型的代价难以预估,笔者转而采用类语义分割的方法。若直接对物体质心做二值分割,正样本极少、难以挖掘,且梯度不平滑,因此此处将该任务视作回归任务——模型为单通道输出,训练标签是以各物体质心为中心的高斯热力图的集合(重叠处取最大值)。

  类似的工作有Centroid-UNet[1],实验数据来自文献[2]——“Road and Building Detection Datasets”。

  值得一提的是,热力图在相关领域的研究和应用中具有重要价值,例如目标检测模型 CenterNet[3],以及基于密度图(热力图)的人群计数模型 MSCNN[4]。

[1] Lakmal Deshapriya N, Tran D, Reddy S, et al. Centroid-UNet: Detecting Centroids in Aerial Images[J]. arXiv e-prints, 2021: arXiv: 2112.06530.

[2] Mnih V. Machine learning for aerial image labeling[M]. University of Toronto (Canada), 2013.

[3] Zhou X, Wang D, Krähenbühl P. Objects as points[J]. arXiv preprint arXiv:1904.07850, 2019.

[4] Zeng L, Xu X, Cai B, et al. Multi-scale convolutional neural networks for crowd counting[C]//2017 IEEE International Conference on Image Processing (ICIP). IEEE, 2017: 465-469.

1 数据介绍

  “马萨诸塞州建筑数据集”由 151 幅(训练:验证:测试 = 137:4:10)1500 × 1500 像素的航拍图像组成,空间分辨率约为每像素 1 米(即每像素约 1 平方米)(Mnih,2013 年)。它包含两个类别(建筑物和非建筑物),并提供像素级的真值标注。需要注意的是,该图像尺寸可能与部分网络架构不兼容(例如要求输入边长能被 32 整除的网络,而 1500 不能被 32 整除)。

  实验从建筑标注中提取各建筑物的质心,并据此构建围绕质心的高斯热力图。

  数据地址:https://www.cs.toronto.edu/~vmnih/data

imgimg

!pip install scikit-image
!pip install paddleseg
!unzip -oq data/data57019/mass_build.zip -d data/mass_build

2 热力图生成

  本步骤实现将图像分割标注转化为热力图标注。

import glob
import json
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import paddle
import paddle.nn.functional as F
import paddleseg
from skimage import measure
%matplotlib inline

  获取标注图中各目标的质心和半径。const_radius:高斯核半径;为 0 或 None 时按各目标的等效直径自适应确定半径,否则所有目标均使用该固定半径。

def label2centroid(label_path, size_wh, const_radius=10):
    """Extract per-object centroids and radii from a building label image.

    The label image is read as grayscale, resized to ``size_wh`` with
    nearest-neighbour interpolation (so label values are not blended), and
    split into connected components; each component is treated as one object.

    Args:
        label_path: Path of the label image on disk.
        size_wh: Target ``(width, height)`` passed to ``cv2.resize``.
        const_radius: Radius assigned to every object. If falsy (``0`` or
            ``None``), each object's radius is instead half of its equivalent
            diameter (diameter of the circle with the same area).

    Returns:
        A dict with ``'centroid_list'`` -- one ``(row, col)`` centroid per
        object as reported by ``skimage.measure.regionprops`` -- and
        ``'radius_list'`` -- one radius per centroid.

    Raises:
        FileNotFoundError: If ``label_path`` cannot be read as an image.
    """
    _label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
    if _label is None:
        # cv2.imread signals failure by returning None rather than raising;
        # fail here with a clear message instead of deep inside cv2.resize.
        raise FileNotFoundError(f'Cannot read label image: {label_path}')
    _label = cv2.resize(_label, size_wh, interpolation=cv2.INTER_NEAREST)

    # Connected-component labelling: every separate blob becomes one region.
    regions = measure.regionprops(measure.label(_label))
    centroid_list = [region.centroid for region in regions]
    if not const_radius:
        radius_list = [region.equivalent_diameter / 2. for region in regions]
    else:
        radius_list = [const_radius] * len(centroid_list)

    return {'centroid_list': centroid_list, 'radius_list': radius_list}

  将质心和半径数值转化为热力图。

def centroids2heatmap(size_wh, centroid_list, radius_list):
    """Render centroids as a heatmap of Gaussian blobs merged by element-wise max.

    Args:
        size_wh: Output ``(width, height)``; the returned array has shape
            ``(height, width)``.
        centroid_list: ``(row, col)`` centre of each blob; fractional values
            are truncated to integers.
        radius_list: Radius of each blob (one per centroid), rounded up to an
            integer. Each blob occupies a ``2r x 2r`` window; blobs with a
            radius of 0 are skipped.

    Returns:
        A ``float32`` array of shape ``(height, width)`` in which every blob
        peaks at 1; overlapping blobs are combined with ``np.maximum``.

    Raises:
        ValueError: If ``centroid_list`` and ``radius_list`` differ in length.
    """
    def get_gaussian(radius):
        # 2r samples over [-r, r] is an even count, so the peak is shared by
        # the four central pixels rather than sitting on a single pixel.
        x = np.linspace(-radius, +radius, radius * 2)
        y = np.linspace(-radius, +radius, radius * 2)
        xx, yy = np.meshgrid(x, y)
        d = np.sqrt(xx ** 2 + yy ** 2)
        sigma, mu = radius / 2, 0.0
        gauss = np.exp(-((d - mu) ** 2 / (2.0 * sigma ** 2)))
        # Min-max normalise so every blob peaks at 1 whatever its radius.
        gauss = (gauss - np.min(gauss) + 1e-8) / (np.max(gauss) - np.min(gauss) + 1e-8)
        return gauss

    if len(centroid_list) != len(radius_list):
        raise ValueError('centroid_list and radius_list must have the same length')

    width, height = size_wh
    radii = [int(np.ceil(r)) for r in radius_list]
    # Pad by the largest radius so a blob centred anywhere in the image (even
    # on the border) fits inside the canvas. A fixed pad overflowed as soon as
    # an adaptive radius exceeded it.
    pad = max(radii, default=0)
    _heatmap = np.zeros((height + pad * 2, width + pad * 2), dtype='float32')
    for (cy, cx), r in zip(centroid_list, radii):
        if r <= 0:
            continue  # a zero-size window has no Gaussian to draw
        y, x = int(cy), int(cx)
        window = _heatmap[pad + y - r:pad + y + r, pad + x - r:pad + x + r]
        window[...] = np.maximum(window, get_gaussian(r))
    # Explicit end indices: `pad:-pad` would yield an empty array when pad == 0.
    return _heatmap[pad:pad + height, pad:pad + width]

3 数据变换与读取

class Transforms:
    """Per-sample preprocessing for image/heatmap pairs.

    The image is resized to ``size_wh`` (INTER_AREA) and, in ``'train'`` mode,
    randomly flipped together with the label. It is then scaled to ``[-1, 1]``
    and transposed from HWC to CHW. The label is never resized here.

    Args:
        mode: ``'train'`` or ``'val'`` (case-insensitive).
        size_wh: Target ``(width, height)`` as a 2-element list or tuple.
    """

    def __init__(self, mode, size_wh=(1500, 1500)):
        self.mode = mode.lower()
        # Plain asserts: these checks are skipped under `python -O`.
        assert self.mode in ['train', 'val']
        assert isinstance(size_wh, (list, tuple)) and len(size_wh) == 2
        self.size_wh = size_wh

    def __call__(self, image, label):
        """Return ``(chw_image, label)``; the label is only flipped, never resized."""
        resized = cv2.resize(image, self.size_wh, interpolation=cv2.INTER_AREA)
        if self.mode == 'train':
            resized, label = self.random_flip(resized, label)
        return self.normalize(resized).transpose(2, 0, 1), label

    @staticmethod
    def normalize(image):
        """Return a float32 copy of ``image`` mapped from [0, 255] to [-1, 1]."""
        mean = std = [0.5, 0.5, 0.5]  # per-channel, broadcast over the last axis
        out = image.astype('float32') / 255.
        out -= mean
        out /= std
        return out

    @staticmethod
    def random_flip(image, label):
        """Flip image and label together: left-right, then up-down, each with p=0.5."""
        # cv2.flip code 1 mirrors around the y-axis, code 0 around the x-axis.
        for flip_code in (1, 0):
            if np.random.rand() < 0.5:
                image = cv2.flip(image, flip_code)
                label = cv2.flip(label, flip_code)
        return image, label
class Building(paddle.io.Dataset):
    """Building-centroid heatmap dataset built from PNG images and masks.

    On construction, a Gaussian-centroid heatmap is generated for every label
    image and written next to it as a ``.tiff`` file. ``__getitem__`` then reads
    the image and that heatmap from disk and applies ``Transforms``.

    Args:
        image_dir: Directory containing the ``*.png`` input images.
        label_dir: Directory containing the ``*.png`` building masks. Images
            and labels are paired by sorted filename and must share stems.
        mode: ``'train'`` (random flips) or ``'val'``; case-insensitive.
        size_wh: ``(width, height)`` used for both images and heatmaps.
        const_radius: Radius of each Gaussian blob; a falsy value selects a
            per-object adaptive radius in ``label2centroid``.
    """

    def __init__(self, image_dir, label_dir, mode='val', size_wh=(1500, 1500),
                 const_radius=10):
        super(Building, self).__init__()

        self.mode = mode.lower()
        assert self.mode in ['train', 'val']
        assert os.path.exists(image_dir) and os.path.exists(label_dir)

        self.image_path_list = sorted(glob.glob(os.path.join(image_dir, '*.png')))
        self.label_path_list = sorted(glob.glob(os.path.join(label_dir, '*.png')))
        assert len(self.image_path_list) == len(self.label_path_list)
        self._check_pairing()

        self.const_radius = const_radius
        self.transforms = Transforms(mode=self.mode, size_wh=size_wh)
        self.heatmap_path_list = self._init_heatmap()

    def __getitem__(self, index):
        image_path = self.image_path_list[index]
        heatmap_path = self.heatmap_path_list[index]
        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f'Cannot read image: {image_path}')
        heatmap = cv2.imread(heatmap_path, cv2.IMREAD_UNCHANGED)
        if heatmap is None:
            raise FileNotFoundError(f'Cannot read heatmap: {heatmap_path}')
        image, heatmap = self.transforms(image, heatmap)
        return image, heatmap

    def __len__(self):
        return len(self.heatmap_path_list)

    def _check_pairing(self):
        """Raise ValueError if sorted image and label files do not share stems."""
        def stem(path):
            return os.path.splitext(os.path.basename(path))[0]

        # Pairing is purely by sorted order, so a missing or extra file would
        # otherwise silently shift every subsequent image/label pair.
        for image_path, label_path in zip(self.image_path_list, self.label_path_list):
            if stem(image_path) != stem(label_path):
                raise ValueError(f'Image/label mismatch: {image_path} vs {label_path}')

    def _init_heatmap(self):
        """Generate a ``.tiff`` heatmap next to each label and return their paths."""
        heatmap_list = []
        for label_path in self.label_path_list:
            info = label2centroid(label_path,
                                  self.transforms.size_wh,
                                  const_radius=self.const_radius)
            heatmap = centroids2heatmap(self.transforms.size_wh,
                                        centroid_list=info['centroid_list'],
                                        radius_list=info['radius_list'])
            # Swap only the extension; str.replace would also rewrite any
            # '.png' that happens to appear in a directory name.
            heatmap_path = os.path.splitext(label_path)[0] + '.tiff'
            if not cv2.imwrite(heatmap_path, heatmap):
                raise IOError(f'Cannot write heatmap: {heatmap_path}')
            heatmap_list.append(heatmap_path)
        return heatmap_list
# Build the three splits; each one generates its heatmap labels on construction.
# Only the training split uses 'train' mode (random flips).
train_dataset = Building('data/mass_build/png/train',
                         'data/mass_build/png/train_labels',
                         mode='train')

val_dataset = Building('data/mass_build/png/val',
                       'data/mass_build/png/val_labels',
                       mode='val')

# The test split reuses 'val' mode, i.e. deterministic transforms.
test_dataset = Building('data/mass_build/png/test',
                        'data/mass_build/png/test_labels',
                        mode='val')

4 样本可视化

# Visualise one sample: input image (top-left), building mask (bottom-left)
# and generated centroid heatmap (bottom-right).
show_index = 0
show_dataset = val_dataset
image, label = show_dataset[show_index]

plt.figure(figsize=(9, 9))
plt.subplot(221)
plt.title(show_dataset.image_path_list[show_index])
# Undo CHW layout and [-1, 1] normalisation. cv2 loads images as BGR while
# matplotlib expects RGB, so reverse the channel axis for correct colours.
plt.imshow(np.transpose(image, (1, 2, 0))[..., ::-1] * 0.5 + 0.5)

plt.subplot(223)
plt.title(show_dataset.label_path_list[show_index])
mask = cv2.imread(show_dataset.label_path_list[show_index],
                  cv2.IMREAD_GRAYSCALE)
# Nearest-neighbour keeps mask values unblended at the display size.
mask = cv2.resize(mask,
                  show_dataset.transforms.size_wh,
                  interpolation=cv2.INTER_NEAREST)
plt.imshow(mask)

plt.subplot(224)
plt.title(show_dataset.heatmap_path_list[show_index])
plt.imshow(label)

plt.tight_layout()
plt.show()

![样本可视化:影像、建筑标注与质心热力图](output_16_0.png)

5 模型构建

class CentroidNet(paddle.nn.Layer):
    """Single-channel centroid-heatmap regressor: PaddleSeg FCN on HRNet-W18.

    Args:
        use_pretrained: If True, initialise the backbone from the PaddleSeg
            ``hrnet_w18_ssld`` weights and the whole FCN from the PaddleSeg
            Cityscapes ``fcn_hrnetw18`` checkpoint (both downloaded by URL).
    """

    def __init__(self, use_pretrained=False):
        super().__init__()

        if use_pretrained:
            backbone_weights = 'https://bj.bcebos.com/paddleseg/dygraph/hrnet_w18_ssld.tar.gz'
            model_weights = 'https://bj.bcebos.com/paddleseg/dygraph/cityscapes/fcn_hrnetw18_cityscapes_1024x512_80k/model.pdparams'
        else:
            backbone_weights = model_weights = None

        # The backbone is kept as an attribute as well as handed to FCN;
        # changing that would presumably alter the state-dict keys of saved
        # checkpoints, so it is left as is.
        self.backbone = paddleseg.models.HRNet_W18(pretrained=backbone_weights)
        self.model = paddleseg.models.FCN(
            num_classes=1,
            backbone=self.backbone,
            backbone_indices=(-1,),
            pretrained=model_weights)
        self.sigmoid = paddle.nn.Sigmoid()

    def forward(self, x):
        """Return the sigmoid of the first FCN output for input ``x``."""
        # The FCN returns a sequence of logits; only the first one is used.
        return self.sigmoid(self.model(x)[0])
model = CentroidNet(use_pretrained=True)
# Derive the summary input shape from the dataset configuration so it stays in
# sync with the resolution the model is trained on (size_wh is (width, height)).
_summary_w, _summary_h = train_dataset.transforms.size_wh
paddle.summary(model, input_size=(1, 3, _summary_h, _summary_w))

6 超参与优化器

  本项目采用固定学习率训练,预计模型会在训练中期收敛。

# Training hyper-parameters.
num_epochs = 60
learning_rate = 0.001
train_batch_size = 2
step_each_epoch = len(train_dataset) // train_batch_size  # full batches per epoch

# AdamW with a constant learning rate. An alternative is Momentum(0.9) with
# L2Decay(5e-4) and a paddle.optimizer.lr.PiecewiseDecay schedule; the training
# loop steps any LRScheduler it finds attached to the optimizer.
optimizer = paddle.optimizer.AdamW(
    learning_rate=learning_rate,
    beta1=0.9,
    beta2=0.999,
    parameters=model.parameters(),
    weight_decay=0.01)

train_dataloader = paddle.io.DataLoader(
    train_dataset,
    batch_size=train_batch_size,
    num_workers=2,
    shuffle=True)

# Evaluation loaders use batch_size=1, so a mean of per-batch losses is also
# a mean over images.
val_dataloader = paddle.io.DataLoader(
    val_dataset,
    batch_size=1)

test_dataloader = paddle.io.DataLoader(
    test_dataset,
    batch_size=1)

# Pixel-wise regression loss between predicted and target heatmaps.
mse_loss = paddle.nn.MSELoss(reduction='mean')

7 模型训练

def evaluate(model, dataloader, loss):
    """Return the mean per-batch loss of ``model`` over ``dataloader``.

    The model is switched to eval mode and gradients are disabled. The result
    averages per-batch loss values, which equals the per-sample mean only when
    every batch has the same size (e.g. ``batch_size=1``).

    Args:
        model: Network to evaluate, called as ``model(image)``.
        dataloader: Callable returning an iterable of ``(image, label)``
            batches (e.g. a ``paddle.io.DataLoader``).
        loss: Callable ``loss(predict, label)`` returning a single-element
            tensor.

    Returns:
        The mean loss as a Python float, or ``nan`` if no batch was produced.
    """
    model.eval()
    loss_sum, num_batches = 0.0, 0
    with paddle.no_grad():
        for image, label in dataloader():
            predict = model(image)
            # float() works for both shape-[1] and 0-D loss tensors, whereas
            # `.numpy()[0]` raises on 0-D tensors.
            loss_sum += float(loss(predict, label))
            num_batches += 1
    return loss_sum / num_batches if num_batches else float('nan')
# Train for `num_epochs` epochs; after each epoch evaluate on the validation
# split and checkpoint model + optimizer whenever the validation loss improves.
save_dir = 'output/fcn_hrnetw18_e60_adamw1e-3'
min_val_loss = float('inf')
for epoch in range(1, num_epochs + 1):  # 1-based; runs exactly num_epochs epochs
    model.train()
    train_loss_sum = 0.0
    num_batches = 0
    for image, label in train_dataloader():
        predict = model(image)

        train_loss = mse_loss(predict, label)
        # float() works for both shape-[1] and 0-D loss tensors, unlike `.numpy()[0]`.
        train_loss_sum += float(train_loss)
        num_batches += 1
        train_loss.backward()
        optimizer.step()

        # Step a per-iteration LR scheduler if one is attached to the optimizer.
        lr_sche = optimizer._learning_rate
        if isinstance(lr_sche, paddle.optimizer.lr.LRScheduler):
            lr_sche.step()
        optimizer.clear_grad()

    # Each loss is already a batch mean, so average over batches (not samples)
    # to keep the train loss on the same scale as the validation loss.
    train_loss = train_loss_sum / max(num_batches, 1)
    val_loss = evaluate(model, val_dataloader, mse_loss)

    print('[Epoch %d] loss: %.5f/%.5f, lr: %.5f.' % (epoch, train_loss, val_loss, optimizer.get_lr()))
    if val_loss < min_val_loss:
        min_val_loss = val_loss
        os.makedirs(save_dir, exist_ok=True)

        paddle.save(model.state_dict(), f'{save_dir}/model_{epoch}.pdparams')
        paddle.save(optimizer.state_dict(), f'{save_dir}/model_{epoch}.pdopt')

        print('[Info] current best model is epoch %d (mean mse: %.5f).' % (epoch, min_val_loss))

8 评估与预测

  (1)加载训练好的模型参数,计算测试集上的平均损失。

!cp output/fcn_hrnetw18_e60_adamw1e-3/model_32.pdparams work/fcn_hrnetw18.pdparams
# Rebuild the network (no download of pretrained weights) and load the copied
# checkpoint. NOTE(review): `min_val_loss` is the best validation loss recorded
# during training, not recomputed for this particular checkpoint.
model = CentroidNet()
state_dict = paddle.load('work/fcn_hrnetw18.pdparams')
model.set_state_dict(state_dict)
test_loss = evaluate(model, test_dataloader, mse_loss)
print('val loss: %.5f, test loss: %.5f' % (min_val_loss, test_loss))
val loss: 0.01373, test loss: 0.01780

  (2)测试集中图像预测结果可视化。

def preprocess(image):
    """Turn one HWC ``uint8`` image into a normalised ``(1, 3, H, W)`` float32 tensor.

    Mirrors the validation-time ``Transforms`` pipeline: resize to the
    validation ``size_wh`` with INTER_AREA, scale to ``[-1, 1]``, HWC -> CHW,
    then add a batch axis. The channel order is kept as given (the file loads
    images with ``cv2.imread``, i.e. BGR, for both training and inference).
    """
    # INTER_AREA matches the interpolation used during training; the default
    # INTER_LINEAR would make inference inputs differ from training inputs.
    image = cv2.resize(image, val_dataset.transforms.size_wh, interpolation=cv2.INTER_AREA)
    image = Transforms.normalize(image)
    image = np.transpose(image, (2, 0, 1))
    image = np.expand_dims(image, axis=0)
    image = paddle.to_tensor(image, dtype='float32')
    return image


def postprocess(predict):
    """Convert a prediction tensor to a NumPy array with all size-1 axes dropped."""
    return np.squeeze(predict.numpy())


def predict(model, image):
    """Run ``model`` on a single image and return its squeezed heatmap array.

    Args:
        model: Network returning a batch of heatmaps for a batched input.
        image: One HWC image as accepted by ``preprocess``.

    Returns:
        The first element of the model's output batch as a NumPy array with
        size-1 axes removed.
    """
    # Inference must run with layers in eval mode (BatchNorm/Dropout behave
    # differently in train mode, which a freshly built Layer starts in), and
    # needs no autograd graph, which would only waste memory at full resolution.
    model.eval()
    with paddle.no_grad():
        logit = model(preprocess(image))[0]
    return postprocess(logit)


def detect_peaks(image):
    """Return a boolean mask marking local maxima (8-neighbourhood) of ``image``.

    Pixels equal to the maximum of their 3x3 neighbourhood count as peaks,
    except for zero pixels lying inside the flat zero background.
    """
    from scipy import ndimage

    footprint = ndimage.generate_binary_structure(2, 2)  # full 3x3 connectivity
    is_local_max = ndimage.maximum_filter(image, footprint=footprint) == image
    # Interior of the zero background: every pixel there is trivially a
    # "local max", so it is cancelled out via XOR below.
    flat_background = ndimage.binary_erosion(image == 0, structure=footprint, border_value=1)
    return np.logical_xor(is_local_max, flat_background)
# Predict a centroid heatmap for one test tile and compare it with the
# ground-truth heatmap generated by the dataset.
image = cv2.imread('data/mass_build/png/test/23879080_15.png')
label = cv2.imread('data/mass_build/png/test_labels/23879080_15.tiff', cv2.IMREAD_UNCHANGED)
label = cv2.resize(label, val_dataset.transforms.size_wh, interpolation=cv2.INTER_NEAREST)

heatmap = predict(model, image)
# NOTE(review): empirical rescale from the original author, who observed the
# PaddleSeg FCN output in (0, 0.5) and mapped it to (0, 1); not derived from code.
heatmap /= 0.5
heatmap[heatmap < 0.5] = 0  # suppress weak responses before peak detection
plt.figure(figsize=(8, 8))
plt.subplot(221)
# cv2 loads BGR while matplotlib expects RGB; convert for correct colours.
plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
plt.subplot(223)
plt.imshow(label)
plt.subplot(224)
plt.imshow(heatmap)
plt.tight_layout()
plt.show()

![测试影像、真值热力图与预测热力图](output_36_0.png)

# NOTE(review): this cell looks like an export artefact. It repeats the
# `label2centroid` call of the following cell, and only the tail of a plotting
# cell survived (it read `.tight_layout()`, a syntax error). It was repaired to
# `plt.tight_layout()` so the notebook runs; consider deleting this cell.
info = label2centroid(
    label_path='data/mass_build/png/test_labels/23879080_15.png',
    size_wh=val_dataset.transforms.size_wh,
    const_radius=10)

plt.tight_layout()
plt.show()

![测试影像、真值热力图与预测热力图](output_36_0.png)

# Compare the number of ground-truth buildings with the number of peaks
# detected in the predicted heatmap.
info = label2centroid(label_path='data/mass_build/png/test_labels/23879080_15.png',
                      size_wh=val_dataset.transforms.size_wh,
                      const_radius=10)
gt_count = len(info['centroid_list'])
pred_count = np.sum(detect_peaks(heatmap))
print('GT Count: %d, Prediction Count: %d' % (gt_count, pred_count))
GT Count: 1392, Prediction Count: 1500

9 项目总结

  本项目将遥感影像语义分割标签转换为高斯热力图,采用单通道的语义分割网络,基于均方误差损失函数进行回归任务的训练,探索了基于热力图的遥感影像目标定位与计数。

  本项目还有很多优化空间,欢迎感兴趣的读者一起交流。


  我的AI Studio主页。

Logo

学大模型,用大模型上飞桨星河社区!每天8点V100G算力免费领!免费领取ERNIE 4.0 100w Token >>>

更多推荐