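Baidu Netdisk AI Competition: Table Detection Baseline
A baseline that regresses the four corner coordinates of the table with ResNet18. The project includes a sample submission that can be submitted as-is and scores 20+.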
1. Competition Overview
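Scanning technology is increasingly common in everyday life: a phone can turn a photo into an editable document. Current tools, however, often handle text laid out in tables poorly, and splitting a complete table into scattered, hard-to-use fragments is a common frustration. This competition targets that problem: the goal is an algorithm that accurately locates the table in an image and annotates its position.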
1.1. Dataset
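The newly released dataset has three parts: a training set of 10,000 images, an A-leaderboard test set of 500 images, and a B-leaderboard test set of 500 images. The imgs directory holds all training images; annos.txt is the annotation file, in JSON format. An example: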
{
    "a.jpg": [                            # image file name
        {
            "box": [xmin, ymin, xmax, ymax],  # table box: (xmin, ymin) is the top-left point, (xmax, ymax) the bottom-right point
            "lb": [x, y],                     # left-bottom vertex of the table
            "lt": [x, y],                     # left-top vertex of the table
            "rt": [x, y],                     # right-top vertex of the table
            "rb": [x, y]                      # right-bottom vertex of the table
        }],
    "b.jpg": [{
        "box": [xmin, ymin, xmax, ymax],
        "lb": [x, y],
        "lt": [x, y],
        "rt": [x, y],
        "rb": [x, y]
    }]
}
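As a minimal sketch of how one record in this format can be read, the snippet below flattens the four corners of the first table into an eight-value list in lb, lt, rt, rb order (the same order the data reader in 2.2 uses) and recomputes the axis-aligned box from them. The sample record is made up for illustration, and treating `box` as exactly the min/max of the corners is an assumption; it matches how predict.py derives `box` at inference time.

```python
import json

# Corner order used throughout this baseline (data reader and predict.py)
CORNER_KEYS = ["lb", "lt", "rt", "rb"]

def parse_annotation(anno_text):
    """Parse annos.txt content into {filename: [x1, y1, ..., x4, y4]} for the first table."""
    data = json.loads(anno_text)
    corners = {}
    for name, tables in data.items():
        table = tables[0]  # each image is assumed to hold exactly one table
        flat = []
        for key in CORNER_KEYS:
            flat.extend(table[key])
        corners[name] = flat
    return corners

def box_from_corners(flat):
    """Axis-aligned box [xmin, ymin, xmax, ymax] enclosing the four corners."""
    xs, ys = flat[0::2], flat[1::2]
    return [min(xs), min(ys), max(xs), max(ys)]

# Hypothetical record in the documented format (values invented for illustration)
sample = json.dumps({
    "a.jpg": [{
        "box": [10, 20, 300, 220],
        "lb": [12, 220], "lt": [10, 20], "rt": [300, 25], "rb": [298, 218],
    }]
})
parsed = parse_annotation(sample)
print(parsed["a.jpg"])                    # [12, 220, 10, 20, 300, 25, 298, 218]
print(box_from_corners(parsed["a.jpg"]))  # [10, 20, 300, 220]
```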
1.2. Task Analysis
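Inspecting the data shows that each image to be annotated contains exactly one table, so the task can be treated as a regression problem: the model outputs eight numbers, the (x, y) coordinates of the table's four corners.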
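Because images differ in size, the corner coordinates are normalized before regression: x is divided by the image width and y by the image height. This puts every target in [0, 1] and keeps it valid after the image is resized to 512x512, since x and y scale independently with the image. The helpers below are a hypothetical pure-Python sketch of that encoding and of its inverse; like predict.py, `denormalize` truncates with `int()`.

```python
def normalize(flat, w, h):
    """Divide each x by the image width and each y by the height (values end up in [0, 1])."""
    return [v / w if i % 2 == 0 else v / h for i, v in enumerate(flat)]

def denormalize(flat, w, h):
    """Map normalized values back to pixels, truncating to int as predict.py does."""
    return [int(v * w) if i % 2 == 0 else int(v * h) for i, v in enumerate(flat)]

# Corners (lb, lt, rt, rb) of a hypothetical table in a 1000x500 image
corners = [100, 400, 100, 50, 900, 50, 900, 400]
norm = normalize(corners, w=1000, h=500)
print(norm)                          # [0.1, 0.8, 0.1, 0.1, 0.9, 0.1, 0.9, 0.8]
print(denormalize(norm, 1000, 500))  # back to the original pixel corners
print(denormalize(norm, 512, 512))   # [51, 409, 51, 51, 460, 51, 460, 409] on the resized image
```

Truncation can lose up to one pixel per coordinate; rounding would be a reasonable alternative, but it would no longer match predict.py exactly.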
2. Code
2.1. Prepare the Data
! unzip -oq data/data182385/train.zip
! unzip -oq data/data182385/testA.zip
2.2. Create the Data Reader
import json
import os

import cv2
import paddle

class MyDateset(paddle.io.Dataset):
    """Returns a 3x512x512 image tensor and 8 normalized corner coordinates per sample."""
    def __init__(self, txt_dir='train/annos.txt', root_dir='train/imgs', mode='train'):
        super(MyDateset, self).__init__()
        self.mode = mode
        self.root_dir = root_dir
        # read the annotation file passed in (txt_dir) rather than a hard-coded path
        with open(txt_dir, 'r') as f:
            self.data = json.load(f)
        self.name_list = list(self.data.keys())

    def __getitem__(self, index):
        name = self.name_list[index]
        img_path = os.path.join(self.root_dir, name)
        img = cv2.imread(img_path)
        # original image size, used to normalize the labels
        h, w, c = img.shape
        # image pre-processing: resize to 512x512, HWC -> CHW, scale to [0, 1]
        img = paddle.vision.transforms.resize(img, (512, 512), interpolation='bilinear')
        img = img.transpose((2, 0, 1))
        img = img / 255
        img = paddle.to_tensor(img).astype('float32')
        # label pre-processing: first (only) table, corners in lb, lt, rt, rb order,
        # x divided by width and y by height
        data = self.data[name][0]
        labels = []
        for pos in ['lb', 'lt', 'rt', 'rb']:
            labels += [data[pos][0] / w, data[pos][1] / h]
        labels = paddle.to_tensor(labels).astype('float32')
        return img, labels

    def __len__(self):
        return len(self.name_list)

if 1:  # quick sanity check: read one batch and print its shapes
    train_dataset = MyDateset()
    train_dataloader = paddle.io.DataLoader(
        train_dataset,
        batch_size=16,
        shuffle=True,
        drop_last=False)
    for step, data in enumerate(train_dataloader):
        data, label = data
        print(step, data.shape, label.shape)
        break
0 [16, 3, 512, 512] [16, 8]
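The image pre-processing in `__getitem__` can be reproduced with plain NumPy to check shapes and value ranges without Paddle or the dataset. In this sketch the hypothetical `resize_nearest` helper is a dependency-free nearest-neighbour stand-in for the bilinear resize the baseline uses, and a random array stands in for the output of `cv2.imread`.

```python
import numpy as np

def resize_nearest(img, size):
    """Nearest-neighbour resize of an HWC array to size=(height, width)."""
    out_h, out_w = size
    h, w = img.shape[:2]
    rows = np.arange(out_h) * h // out_h  # source row for each output row
    cols = np.arange(out_w) * w // out_w  # source column for each output column
    return img[rows[:, None], cols[None, :]]

def preprocess(img, size=(512, 512)):
    """HWC uint8 image -> CHW float32 array scaled to [0, 1]."""
    img = resize_nearest(img, size)
    img = img.transpose((2, 0, 1))  # HWC -> CHW
    return (img / 255.0).astype("float32")

# Fake 720x1280 3-channel image in place of cv2.imread output
fake = np.random.randint(0, 256, size=(720, 1280, 3), dtype=np.uint8)
batch = np.stack([preprocess(fake) for _ in range(16)])
print(batch.shape, batch.dtype)  # (16, 3, 512, 512) float32
```

The batch shape agrees with the `[16, 3, 512, 512]` printed by the data loader check.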
2.3. Build the Network
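For a regression problem like this, a plain ResNet is enough: a pretrained ResNet18 with its classifier removed turns the image into a 512-dimensional feature vector, and a single linear layer maps that vector to the 8 normalized coordinates.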
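class MyNet(paddle.nn.Layer):
    def __init__(self):
        super(MyNet, self).__init__()
        # ImageNet-pretrained ResNet18; num_classes=0 drops its classifier and keeps the pooled [N, 512, 1, 1] feature
        self.resnet = paddle.vision.models.resnet18(pretrained=True, num_classes=0)
        # regression head: 512 features -> 8 normalized corner coordinates
        self.fc = paddle.nn.Linear(in_features=512, out_features=8)

    def forward(self, x):
        x = self.resnet(x)
        x = paddle.flatten(x, 1)  # [N, 512, 1, 1] -> [N, 512]
        x = self.fc(x)
        return x

if 1:
    paddle.summary(MyNet(), (16, 3, 512, 512))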
-------------------------------------------------------------------------------
Layer (type) Input Shape Output Shape Param #
===============================================================================
Conv2D-21 [[16, 3, 512, 512]] [16, 64, 256, 256] 9,408
BatchNorm2D-21 [[16, 64, 256, 256]] [16, 64, 256, 256] 256
ReLU-10 [[16, 64, 256, 256]] [16, 64, 256, 256] 0
MaxPool2D-2 [[16, 64, 256, 256]] [16, 64, 128, 128] 0
Conv2D-22 [[16, 64, 128, 128]] [16, 64, 128, 128] 36,864
BatchNorm2D-22 [[16, 64, 128, 128]] [16, 64, 128, 128] 256
ReLU-11 [[16, 64, 128, 128]] [16, 64, 128, 128] 0
Conv2D-23 [[16, 64, 128, 128]] [16, 64, 128, 128] 36,864
BatchNorm2D-23 [[16, 64, 128, 128]] [16, 64, 128, 128] 256
BasicBlock-9 [[16, 64, 128, 128]] [16, 64, 128, 128] 0
Conv2D-24 [[16, 64, 128, 128]] [16, 64, 128, 128] 36,864
BatchNorm2D-24 [[16, 64, 128, 128]] [16, 64, 128, 128] 256
ReLU-12 [[16, 64, 128, 128]] [16, 64, 128, 128] 0
Conv2D-25 [[16, 64, 128, 128]] [16, 64, 128, 128] 36,864
BatchNorm2D-25 [[16, 64, 128, 128]] [16, 64, 128, 128] 256
BasicBlock-10 [[16, 64, 128, 128]] [16, 64, 128, 128] 0
Conv2D-27 [[16, 64, 128, 128]] [16, 128, 64, 64] 73,728
BatchNorm2D-27 [[16, 128, 64, 64]] [16, 128, 64, 64] 512
ReLU-13 [[16, 128, 64, 64]] [16, 128, 64, 64] 0
Conv2D-28 [[16, 128, 64, 64]] [16, 128, 64, 64] 147,456
BatchNorm2D-28 [[16, 128, 64, 64]] [16, 128, 64, 64] 512
Conv2D-26 [[16, 64, 128, 128]] [16, 128, 64, 64] 8,192
BatchNorm2D-26 [[16, 128, 64, 64]] [16, 128, 64, 64] 512
BasicBlock-11 [[16, 64, 128, 128]] [16, 128, 64, 64] 0
Conv2D-29 [[16, 128, 64, 64]] [16, 128, 64, 64] 147,456
BatchNorm2D-29 [[16, 128, 64, 64]] [16, 128, 64, 64] 512
ReLU-14 [[16, 128, 64, 64]] [16, 128, 64, 64] 0
Conv2D-30 [[16, 128, 64, 64]] [16, 128, 64, 64] 147,456
BatchNorm2D-30 [[16, 128, 64, 64]] [16, 128, 64, 64] 512
BasicBlock-12 [[16, 128, 64, 64]] [16, 128, 64, 64] 0
Conv2D-32 [[16, 128, 64, 64]] [16, 256, 32, 32] 294,912
BatchNorm2D-32 [[16, 256, 32, 32]] [16, 256, 32, 32] 1,024
ReLU-15 [[16, 256, 32, 32]] [16, 256, 32, 32] 0
Conv2D-33 [[16, 256, 32, 32]] [16, 256, 32, 32] 589,824
BatchNorm2D-33 [[16, 256, 32, 32]] [16, 256, 32, 32] 1,024
Conv2D-31 [[16, 128, 64, 64]] [16, 256, 32, 32] 32,768
BatchNorm2D-31 [[16, 256, 32, 32]] [16, 256, 32, 32] 1,024
BasicBlock-13 [[16, 128, 64, 64]] [16, 256, 32, 32] 0
Conv2D-34 [[16, 256, 32, 32]] [16, 256, 32, 32] 589,824
BatchNorm2D-34 [[16, 256, 32, 32]] [16, 256, 32, 32] 1,024
ReLU-16 [[16, 256, 32, 32]] [16, 256, 32, 32] 0
Conv2D-35 [[16, 256, 32, 32]] [16, 256, 32, 32] 589,824
BatchNorm2D-35 [[16, 256, 32, 32]] [16, 256, 32, 32] 1,024
BasicBlock-14 [[16, 256, 32, 32]] [16, 256, 32, 32] 0
Conv2D-37 [[16, 256, 32, 32]] [16, 512, 16, 16] 1,179,648
BatchNorm2D-37 [[16, 512, 16, 16]] [16, 512, 16, 16] 2,048
ReLU-17 [[16, 512, 16, 16]] [16, 512, 16, 16] 0
Conv2D-38 [[16, 512, 16, 16]] [16, 512, 16, 16] 2,359,296
BatchNorm2D-38 [[16, 512, 16, 16]] [16, 512, 16, 16] 2,048
Conv2D-36 [[16, 256, 32, 32]] [16, 512, 16, 16] 131,072
BatchNorm2D-36 [[16, 512, 16, 16]] [16, 512, 16, 16] 2,048
BasicBlock-15 [[16, 256, 32, 32]] [16, 512, 16, 16] 0
Conv2D-39 [[16, 512, 16, 16]] [16, 512, 16, 16] 2,359,296
BatchNorm2D-39 [[16, 512, 16, 16]] [16, 512, 16, 16] 2,048
ReLU-18 [[16, 512, 16, 16]] [16, 512, 16, 16] 0
Conv2D-40 [[16, 512, 16, 16]] [16, 512, 16, 16] 2,359,296
BatchNorm2D-40 [[16, 512, 16, 16]] [16, 512, 16, 16] 2,048
BasicBlock-16 [[16, 512, 16, 16]] [16, 512, 16, 16] 0
AdaptiveAvgPool2D-2 [[16, 512, 16, 16]] [16, 512, 1, 1] 0
ResNet-2 [[16, 3, 512, 512]] [16, 512, 1, 1] 0
Linear-2 [[16, 512]] [16, 8] 4,104
===============================================================================
Total params: 11,190,216
Trainable params: 11,171,016
Non-trainable params: 19,200
-------------------------------------------------------------------------------
Input size (MB): 48.00
Forward/backward pass size (MB): 4768.13
Params size (MB): 42.69
Estimated Total Size (MB): 4858.81
-------------------------------------------------------------------------------
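Two entries in the summary can be checked by hand. ResNet18 downsamples by a total stride of 32 (the stride-2 stem convolution, stride-2 max pooling, and stride 2 at the start of each of the last three stages), so a 512x512 input gives a 512x16x16 feature map before pooling; and the 512-to-8 linear head has 512 x 8 weights plus 8 biases. The function names below are illustrative only.

```python
def resnet18_feature_shape(h, w):
    """(channels, height, width) of ResNet18's last stage for an h x w input."""
    # stem conv (stride 2) -> max pool (stride 2) -> stride 2 in layer2, layer3, layer4
    total_stride = 2 * 2 * 2 * 2 * 2
    return 512, h // total_stride, w // total_stride

def linear_params(in_features, out_features):
    """Weights plus biases of a fully connected layer."""
    return in_features * out_features + out_features

print(resnet18_feature_shape(512, 512))  # (512, 16, 16), as in BasicBlock-16
print(linear_params(512, 8))             # 4104, as in Linear-2
```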
2.4. Training
model = MyNet()
model.train()
if 1:
    try:
        # resume from a previously saved checkpoint, if one exists
        param_dict = paddle.load('./model.pdparams')
        model.load_dict(param_dict)
    except Exception:
        print('no such model file')
    train_dataset = MyDateset()
    train_dataloader = paddle.io.DataLoader(
        train_dataset,
        batch_size=16,
        shuffle=True,
        drop_last=False)
    max_epoch = 10
    # 1e-8 is a very small initial lr, apparently meant for fine-tuning a resumed model;
    # a run from scratch likely needs a much larger value (untested assumption)
    scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=0.00000001, T_max=max_epoch)
    opt = paddle.optimizer.Adam(learning_rate=scheduler, parameters=model.parameters())
    now_step = 0
    for epoch in range(max_epoch):
        for step, data in enumerate(train_dataloader):
            now_step += 1
            img, label = data
            pre = model(img)
            loss = paddle.nn.functional.mse_loss(pre, label)  # already averaged over the batch
            loss.backward()
            opt.step()
            opt.clear_gradients()
            if now_step % 100 == 0:
                print("epoch: {}, batch: {}, loss is: {}".format(epoch, step, loss.numpy()))
        scheduler.step()  # advance the cosine schedule once per epoch (T_max is in epochs)
        paddle.save(model.state_dict(), 'model.pdparams')
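Cosine annealing lowers the learning rate from its initial value eta_0 toward `eta_min` (0 by default in `CosineAnnealingDecay`) over `T_max` steps, following eta_t = eta_min + 0.5 (eta_0 - eta_min)(1 + cos(pi t / T_max)). With `T_max=max_epoch`, the scheduler is meant to be advanced once per epoch with `scheduler.step()`; if it is never stepped, the learning rate simply stays at its initial value. The pure-Python sketch below evaluates that closed form; the base rate of 1e-3 is a hypothetical value for a fresh run, not the author's setting.

```python
import math

def cosine_lr(base_lr, epoch, t_max, eta_min=0.0):
    """Closed-form cosine annealing: learning rate at `epoch` for period `t_max`."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max))

base_lr, max_epoch = 1e-3, 10  # hypothetical values
for epoch in range(max_epoch):
    print(epoch, f"{cosine_lr(base_lr, epoch, max_epoch):.2e}")
```

The rate starts at `base_lr`, reaches half of it at epoch `T_max / 2`, and approaches `eta_min` at epoch `T_max`.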
2.5. Save the Model
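# Export: reload the trained weights and save a static-graph model for inference
model = MyNet()
param_dict = paddle.load('./model.pdparams')
model.load_dict(param_dict)
model.eval()
paddle.jit.save(
    layer=model,
    path='mymodel/model',
    # batch dimension left as None so any batch size is accepted
    input_spec=[paddle.static.InputSpec(shape=[None, 3, 512, 512], dtype='float32')])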
2.6. Prepare predict.py
The contents of predict.py are as follows:
# Example code
# Usage: python predict.py [src_image_dir] [results]
import glob
import json
import os
import sys

import cv2
import paddle


def process(src_image_dir, save_dir):
    # load the static-graph model exported in step 2.5
    model = paddle.jit.load('./mymodel/model')
    model.eval()
    image_paths = glob.glob(os.path.join(src_image_dir, "*.jpg"))
    result = {}
    for image_path in image_paths:
        filename = os.path.split(image_path)[1]
        img = cv2.imread(image_path)
        # original image size, used to map the normalized outputs back to pixels
        h, w, c = img.shape
        # same pre-processing as in training: resize, HWC -> CHW, scale to [0, 1]
        img = paddle.vision.transforms.resize(img, (512, 512), interpolation='bilinear')
        img = img.transpose((2, 0, 1))
        img = img / 255
        img = paddle.to_tensor(img).astype('float32')
        img = paddle.unsqueeze(img, axis=0)  # add batch dimension: [1, 3, 512, 512]
        with paddle.no_grad():
            pre = model(img)[0]
        # clamp to [0, 1]; outputs are x, y pairs in lb, lt, rt, rb order
        pre = paddle.clip(pre, 0, 1).tolist()
        x1, y1, x2, y2, x3, y3, x4, y4 = pre
        x1, x2, x3, x4 = [int(x * w) for x in [x1, x2, x3, x4]]
        y1, y2, y3, y4 = [int(y * h) for y in [y1, y2, y3, y4]]
        # axis-aligned box enclosing the four corners
        xmin = min(x1, x2, x3, x4)
        xmax = max(x1, x2, x3, x4)
        ymin = min(y1, y2, y3, y4)
        ymax = max(y1, y2, y3, y4)
        if filename not in result:
            result[filename] = []
        result[filename].append({
            "box": [xmin, ymin, xmax, ymax],
            "lb": [x1, y1],
            "lt": [x2, y2],
            "rt": [x3, y3],
            "rb": [x4, y4],
        })
    with open(os.path.join(save_dir, "result.txt"), 'w', encoding="utf-8") as f:
        f.write(json.dumps(result))


if __name__ == "__main__":
    assert len(sys.argv) == 3
    src_image_dir = sys.argv[1]
    save_dir = sys.argv[2]
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    process(src_image_dir, save_dir)
# Run the script to check the prediction results
! python predict.py pubtest/imgs test_A_result
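The post-processing in predict.py can be tested without a model. The sketch below uses a hypothetical `decode_prediction` helper and made-up network outputs for a 1000x500 image, and mirrors the script's steps: clamp the eight outputs to [0, 1], scale x by width and y by height, truncate to int, and take the min/max of the corners as `box`.

```python
import json

def decode_prediction(pre, w, h):
    """Turn 8 normalized outputs (lb, lt, rt, rb order) into one submission record."""
    pre = [min(max(v, 0.0), 1.0) for v in pre]  # clamp to [0, 1]
    xs = [int(x * w) for x in pre[0::2]]
    ys = [int(y * h) for y in pre[1::2]]
    record = {"box": [min(xs), min(ys), max(xs), max(ys)]}
    for key, x, y in zip(["lb", "lt", "rt", "rb"], xs, ys):
        record[key] = [x, y]
    return record

# Hypothetical output for a 1000x500 image; 1.02 lies outside [0, 1] and gets clamped
pre = [0.1, 0.8, 0.1, 0.1, 0.9, 0.1, 1.02, 0.8]
result = {"a.jpg": [decode_prediction(pre, 1000, 500)]}
print(json.dumps(result))
```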
2.7. Package and Submit
! zip -r submit.zip predict.py mymodel
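Before uploading, it can help to confirm that the archive actually contains the script and the exported model. This sketch assumes the Paddle 2.x file naming of `paddle.jit.save(path='mymodel/model')`, which writes `model.pdmodel` and `model.pdiparams`; other Paddle versions may produce different files, so adjust `REQUIRED` accordingly.

```python
import zipfile

# Expected entries; the model file names are an assumption based on Paddle 2.x jit.save
REQUIRED = ["predict.py", "mymodel/model.pdmodel", "mymodel/model.pdiparams"]

def missing_files(zip_path, required=REQUIRED):
    """Return the required entries that are absent from the archive."""
    with zipfile.ZipFile(zip_path) as zf:
        names = set(zf.namelist())
    return [name for name in required if name not in names]
```

Usage: `print(missing_files("submit.zip"))` should print an empty list when the package is complete.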
3. Summary
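This project performs coordinate regression with ResNet18 in the simplest possible way. For better results, consider a detection model; see "Baidu Netdisk AI Competition: Table Detection Advanced: Table Structuring Baseline" for details.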
This article is a repost.
Original project link