
【特训营第三期】生成学习-3D物体生成
AI特训营第三期,目前可以拟合单张3D点云,用网络生成飞机3D点云
·
★★★ 本文源自AlStudio社区精品项目,【点击此处】查看更多精品内容 >>>
介绍
生成对抗网络(GAN)是一种深度学习模型,是近年来复杂分布上无监督学习最具前景的方法之一。GAN由生成模型和判别模型组成,生成模型的任务是去生成某个特定分布的数据,例如合成和真实图片接近的图片,它生成的数据(图片),我们称为Fake Image。对于判别模型,它的任务是区分真实的图片和生成器生成的图片,理想状态下,判别模型对生成器生成图片,判别为假,输出为0,对真实图片,判别为真,输出为1。
其中,3D点云在未来的元宇宙时代也会发挥很大的作用,GAN与3D点云的结合,又会产生哪些有趣的新技术呢?
创新点
1、完成了单图拟合生成网络的功能
2、完成了用three.js显示3D点云的功能
训练
引入依赖
import os
import numpy as np
import random
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from model.PointNet import PointNet, Generator
from tool.point_tools import pointDataLoader
配置列表
MAX_POINT = 200
category = {
'bathtub': 0,
'bed': 1,
'chair': 2,
'desk': 3,
'dresser': 4,
'monitor': 5,
'night_stand': 6,
'sofa': 7,
'table': 8,
'toilet': 9
}
数据处理
整理训练、测试文件夹中的图片
def getDatalist(file_path='./dataset/modelnet10_shape_names.txt'):
f = open(file_path, 'r')
f_train = open('./dataset/train.txt', 'w')
f_test = open('./dataset/test.txt', 'w')
for category in f:
dict_path = os.path.join('./dataset/', category.split('\n')[0])
data_dict = os.listdir(dict_path)
count = 0
for data_path in data_dict:
if count % 60 != 0:
f_train.write(os.path.join(dict_path, data_path) + ' ' + category)
else:
f_test.write(os.path.join(dict_path, data_path) + ' ' + category)
count += 1
f_train.close()
f_test.close()
f.close()
def getNames():
if not os.path.exists(f'./dataset/train'):
os.mkdir(f'./dataset/result')
if not os.path.exists(f'./dataset/test'):
os.mkdir(f'./dataset/result')
if not os.path.exists(f'./dataset/result'):
os.mkdir(f'./dataset/result')
dirs = ['dataset/train', 'dataset/test']
for dir in dirs:
fns = os.listdir(dir)
fh = open(dir + '.txt', 'w')
for fn in fns:
fh.write(os.path.join(dir, fn) + ' real\n')
if __name__ == '__main__':
# getDatalist()
getNames()
定义网络/训练/测试
这里损失函数处使用了单张图进行拟合,所以最后生成网络会生成和训练图里类似的图片,并没有使用到GAN。如果需要改成GAN可以调整生成网络loss的GAN部分,不过GAN网络很难收敛
pointnet = PointNet()
paddle.summary(pointnet, (64, 3, MAX_POINT, 1))
def train():
dataloader = pointDataLoader(file_path='./dataset/train.txt', mode='train')
netD = PointNet()
netG = Generator()
paddle.summary(netG, (1, 250))
# paddle.summary(netG, (1, 3, 100))
# paddle.summary(netG, (1, 100, 1, 1, 1))
# 加载历史模型
if os.path.exists('./model/netD.pdparams'):
netD.load_dict(paddle.load('./model/netD.pdparams'))
if os.path.exists('./model/netG.pdparams'):
netG.load_dict(paddle.load('./model/netG.pdparams'))
# 加载损失函数
criterion = F.nll_loss
# criterion = F.binary_cross_entropy_with_logits
# 加载优化器
# optimizerD = paddle.optimizer.Adam(parameters=netD.parameters(), learning_rate=0.0002)
# optimizerG = paddle.optimizer.Adam(parameters=netG.parameters(), learning_rate=0.0002)
optimizerD = paddle.optimizer.SGD(parameters=netD.parameters(), learning_rate=0.0002)
optimizerG = paddle.optimizer.SGD(parameters=netG.parameters(), learning_rate=0.002)
epoch_num = 2000
for epoch in range(epoch_num+1):
for i, data in enumerate(dataloader()):
# 1. 训练判别器
# 1.1 训练判别器的真实图片
optimizerD.clear_grad()
real, _ = data
label = paddle.full((real.shape[0], 1), 1).astype('int64')
real = paddle.to_tensor(real)
output = netD(real)
errD_real = ((output - label) ** 2).sum() # criterion(output, label)
errD_real.backward()
D_x = output.mean()
# 1.2 训练判别器的假图片
noise = paddle.rand([label.shape[0], 250])
# noise = paddle.randn([label.shape[0], 3, 100])
# noise = paddle.rand([label.shape[0], 100, 1, 1, 1])
fake = netG(noise)
label.fill_(0).astype('int64')
output = netD(fake.detach())
errD_fake = ((output - label) ** 2).sum()
# errD_fake = criterion(output, label)
errD_fake.backward()
D_G_z1 = output.mean()
errD = errD_real + errD_fake
# errD.backward()
optimizerD.step()
optimizerG.clear_grad()
# 2. 训练生成器
label.fill_(1).astype('int64')
output = netD(fake)
# from mpl_toolkits import mplot3d
# import matplotlib.pyplot as plt
# import numpy as np
# ax = plt.axes(projection='3d')
# i = 0
# ax.scatter3D(real[i][0], real[i][1], real[i][2], c='r', s=1)
# plt.show()
# 拟合单图
errG = float(0)
for i, _ in enumerate(fake):
errG += ((fake[i] - real[i])**2).sum()
# GAN拟合
# errG = ((output - label) ** 2).sum()
# errG = criterion(output, label)
errG.backward()
D_G_z2 = output.mean()
# if errD.item() > errG.item():
# optimizerD.step()
# if errG.item() > 0.1:
optimizerG.step()
# 3. 打印训练信息
# print('[%d/%d][%d] Loss_D: %.4f' % (epoch, epoch_num, i, errD_real.item()))
print('[%d/%d][%d] Loss_G: %.4f' % (epoch, epoch_num, i, errG))
# print('[%d/%d][%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f' % (epoch, epoch_num, i, errD.item(), errG, D_x, D_G_z1, D_G_z2))
if epoch % 100 == 0:
paddle.save(netD.state_dict(), './model/netD.pdparams')
paddle.save(optimizerD.state_dict(), './model/netD.pdopt')
paddle.save(netG.state_dict(), './model/netG.pdparams')
paddle.save(optimizerG.state_dict(), './model/netG.pdopt')
test(epoch)
def test(epoch=-1):
netG = Generator()
netG.load_dict(paddle.load('./model/netG.pdparams'))
netG.eval()
# 测试生成并显示点图
noise = paddle.ones([1, 250])
# noise = paddle.randn([1, 3, 100])
# noise = paddle.randn([1, 100, 1, 1, 1])
fake = netG(noise)
out = ''
for i in fake[0].transpose([2,1,0])[0]:
out += str(float(i[0])) + ' ' + str(float(i[1])) + ' ' + str(float(i[2])) + '\n'
fh = open(f'./dataset/result/test_{epoch}.off', 'w')
fh.write(f'OFF\n{MAX_POINT} 0 0\n')
fh.write(out)
fh.close()
if __name__ == '__main__':
train()
# test()
最终结果
用MeshLab查看
1、我生成的图如上(仔细看左下角是飞机头),如果有 off 查看工具的话,可以去 dataset/result/ 里面找最新的文件
2、切换 dataset/train 里面的 off 文件可以换其他 3D点云 拟合
用three.js查看
可以用鼠标左键旋转,右键移动,滚轮缩放。把生成的.off下载到本地,然后拖入到下面这个网页中展示
from IPython.display import IFrame
IFrame(src='./3d_show.html', width=1000, height=500)
使用演示
此文章为搬运
原项目链接
更多推荐
所有评论(0)