★★★ 本文源自AlStudio社区精品项目,【点击此处】查看更多精品内容 >>>

介绍

生成对抗网络(GAN)是一种深度学习模型,是近年来复杂分布上无监督学习最具前景的方法之一。GAN由生成模型和判别模型组成,生成模型的任务是去生成某个特定分布的数据,例如合成和真实图片接近的图片,它生成的数据(图片),我们称为Fake Image。对于判别模型,它的任务是区分真实的图片和生成器生成的图片,理想状态下,判别模型对生成器生成图片,判别为假,输出为0,对真实图片,判别为真,输出为1。

其中,3D点云在未来的元宇宙时代也会发挥很大的作用,GAN与3D点云的结合,又会产生哪些有趣的新技术呢?

创新点

1、完成了单图拟合生成网络的功能

2、完成了用three.js显示3D点云的功能

训练

引入依赖

import os
import numpy as np
import random
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from model.PointNet import PointNet, Generator
from tool.point_tools import pointDataLoader

配置列表

MAX_POINT = 200
category = {
    'bathtub': 0,
    'bed': 1,
    'chair': 2,
    'desk': 3,
    'dresser': 4,
    'monitor': 5,
    'night_stand': 6,
    'sofa': 7,
    'table': 8,
    'toilet': 9
}

数据处理

整理训练、测试文件夹中的图片

def getDatalist(file_path='./dataset/modelnet10_shape_names.txt'):
    f = open(file_path, 'r')
    f_train = open('./dataset/train.txt', 'w')
    f_test = open('./dataset/test.txt', 'w')
    for category in f:
        dict_path = os.path.join('./dataset/', category.split('\n')[0])
        data_dict = os.listdir(dict_path)
        count = 0
        for data_path in data_dict:
            if count % 60 != 0:
                f_train.write(os.path.join(dict_path, data_path) + ' ' + category)
            else:
                f_test.write(os.path.join(dict_path, data_path) + ' ' + category)
            count += 1
    f_train.close()
    f_test.close()
    f.close()

def getNames():
    if not os.path.exists(f'./dataset/train'):
        os.mkdir(f'./dataset/result')
    if not os.path.exists(f'./dataset/test'):
        os.mkdir(f'./dataset/result')
    if not os.path.exists(f'./dataset/result'):
        os.mkdir(f'./dataset/result')
    dirs = ['dataset/train', 'dataset/test']
    for dir in dirs:
        fns = os.listdir(dir)
        fh = open(dir + '.txt', 'w')
        for fn in fns:
            fh.write(os.path.join(dir, fn) + ' real\n')

if __name__ == '__main__':
    # getDatalist()
    getNames()

定义网络/训练/测试

这里损失函数处使用了单张图进行拟合,所以最后生成网络会生成和训练图里类似的图片,并没有使用到GAN。如果需要改成GAN可以调整生成网络loss的GAN部分,不过GAN网络很难收敛

pointnet = PointNet()
paddle.summary(pointnet, (64, 3, MAX_POINT, 1))

def train():
    dataloader = pointDataLoader(file_path='./dataset/train.txt', mode='train')

    netD = PointNet()
    netG = Generator()
    paddle.summary(netG, (1, 250))
    # paddle.summary(netG, (1, 3, 100))
    # paddle.summary(netG, (1, 100, 1, 1, 1))
    # 加载历史模型
    if os.path.exists('./model/netD.pdparams'):
        netD.load_dict(paddle.load('./model/netD.pdparams'))
    if os.path.exists('./model/netG.pdparams'):
        netG.load_dict(paddle.load('./model/netG.pdparams'))
    
    # 加载损失函数
    criterion = F.nll_loss
    # criterion = F.binary_cross_entropy_with_logits
    # 加载优化器
    # optimizerD = paddle.optimizer.Adam(parameters=netD.parameters(), learning_rate=0.0002)
    # optimizerG = paddle.optimizer.Adam(parameters=netG.parameters(), learning_rate=0.0002)
    optimizerD = paddle.optimizer.SGD(parameters=netD.parameters(), learning_rate=0.0002)
    optimizerG = paddle.optimizer.SGD(parameters=netG.parameters(), learning_rate=0.002)

    epoch_num = 2000
    for epoch in range(epoch_num+1):
        for i, data in enumerate(dataloader()):
            # 1. 训练判别器
            # 1.1 训练判别器的真实图片
            optimizerD.clear_grad()
            real, _ = data
            label = paddle.full((real.shape[0], 1), 1).astype('int64')
            real = paddle.to_tensor(real)
            output = netD(real)
            errD_real = ((output - label) ** 2).sum() # criterion(output, label)
            errD_real.backward()
            D_x = output.mean()
            # 1.2 训练判别器的假图片
            noise = paddle.rand([label.shape[0], 250])
            # noise = paddle.randn([label.shape[0], 3, 100])
            # noise = paddle.rand([label.shape[0], 100, 1, 1, 1])

            fake = netG(noise)
            label.fill_(0).astype('int64')
            output = netD(fake.detach())
            errD_fake = ((output - label) ** 2).sum()
            # errD_fake = criterion(output, label)
            errD_fake.backward()
            D_G_z1 = output.mean()
            errD = errD_real + errD_fake
            # errD.backward()
            optimizerD.step()
            optimizerG.clear_grad()
            # 2. 训练生成器
            label.fill_(1).astype('int64')
            output = netD(fake)
            # from mpl_toolkits import mplot3d
            # import matplotlib.pyplot as plt
            # import numpy as np

            # ax = plt.axes(projection='3d')
            # i = 0
            # ax.scatter3D(real[i][0], real[i][1], real[i][2], c='r', s=1)
            # plt.show()
            # 拟合单图
            errG = float(0)
            for i, _ in enumerate(fake):
                errG += ((fake[i] - real[i])**2).sum()
            # GAN拟合
            # errG = ((output - label) ** 2).sum()
            # errG = criterion(output, label)
            errG.backward()
            D_G_z2 = output.mean()
            # if errD.item() > errG.item():
            #     optimizerD.step()
            # if errG.item() > 0.1:
            optimizerG.step()
            # 3. 打印训练信息
            # print('[%d/%d][%d] Loss_D: %.4f' % (epoch, epoch_num, i, errD_real.item()))
            print('[%d/%d][%d] Loss_G: %.4f' % (epoch, epoch_num, i, errG))
            # print('[%d/%d][%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f' % (epoch, epoch_num, i, errD.item(), errG, D_x, D_G_z1, D_G_z2))

        if epoch % 100 == 0:
            paddle.save(netD.state_dict(), './model/netD.pdparams')
            paddle.save(optimizerD.state_dict(), './model/netD.pdopt')
            paddle.save(netG.state_dict(), './model/netG.pdparams')
            paddle.save(optimizerG.state_dict(), './model/netG.pdopt')
            test(epoch)

def test(epoch=-1):
    netG = Generator()

    netG.load_dict(paddle.load('./model/netG.pdparams'))
    netG.eval()

    # 测试生成并显示点图
    noise = paddle.ones([1, 250])
    # noise = paddle.randn([1, 3, 100])
    # noise = paddle.randn([1, 100, 1, 1, 1])
    fake = netG(noise)
    out = ''
    for i in fake[0].transpose([2,1,0])[0]:
        out += str(float(i[0])) + ' ' + str(float(i[1])) + ' ' + str(float(i[2])) + '\n'
    fh = open(f'./dataset/result/test_{epoch}.off', 'w')
    fh.write(f'OFF\n{MAX_POINT} 0 0\n')
    fh.write(out)
    fh.close()

if __name__ == '__main__':
    train()
    # test()

最终结果

用MeshLab查看

单图拟合飞机结果

1、我生成的图如上(仔细看左下角是飞机头),如果有 off 查看工具的话,可以去 dataset/result/ 里面找最新的文件

2、切换 dataset/train 里面的 off 文件可以换其他 3D点云 拟合

用three.js查看

可以用鼠标左键旋转,右键移动,滚轮缩放。把生成的.off下载到本地,然后拖入到下面这个网页中展示

from IPython.display import IFrame
IFrame(src='./3d_show.html', width=1000, height=500)

使用演示


此文章为搬运
原项目链接

Logo

学大模型,用大模型上飞桨星河社区!每天8点V100G算力免费领!免费领取ERNIE 4.0 100w Token >>>

更多推荐