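Baidu Netdisk AI Competition: Table Detection

This project uses ResNet18 to regress the coordinates of a table's four corners. It includes a sample that can be submitted as-is and scores 20+.

Competition link

1. Competition overview

Scanning is increasingly part of everyday life: a phone can turn a photo into an editable document. Current tools, however, handle text laid out in tables poorly, and they often split a complete table into scattered fragments that are hard to use. This competition targets that problem: the goal is an algorithm that accurately locates and marks the table in an image.

1.1. Data

The latest dataset release for this competition has three parts: a training set (10,000 images), the leaderboard A test set (500 images), and the leaderboard B test set (500 images). The imgs directory holds all training images, and annos.txt is the annotation file in JSON format. An example of the format: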

{
    "a.jpg": [  # image file name
        {
            'box': [xmin, ymin, xmax, ymax],  # table bounding box; (xmin, ymin) is the top-left point, (xmax, ymax) the bottom-right point
            'lb': [x, y],  # left-bottom corner of the table
            'lt': [x, y],  # left-top corner of the table
            'rt': [x, y],  # right-top corner of the table
            'rb': [x, y]  # right-bottom corner of the table
        }],

    "b.jpg": [{
        'box': [xmin, ymin, xmax, ymax],
        'lb': [x, y],
        'lt': [x, y],
        'rt': [x, y],
        'rb': [x, y]
    }]
}
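
To make the annotation layout concrete, here is a minimal sketch that parses one entry and derives an enclosing box from the four corner points. The sample string and its coordinates are made up for illustration, and `box_from_corners` is a hypothetical helper. Whether the real `box` field always equals the tight bounding box of the corners is an assumption, not something the dataset description states.

```python
import json

# Hypothetical sample in the annos.txt layout; the coordinates are made up.
SAMPLE = """
{
  "a.jpg": [{
    "box": [10, 20, 210, 170],
    "lb": [12, 170], "lt": [10, 22], "rt": [210, 20], "rb": [208, 168]
  }]
}
"""

CORNER_KEYS = ("lb", "lt", "rt", "rb")


def box_from_corners(table):
    """Return [xmin, ymin, xmax, ymax] enclosing the four corner points."""
    xs = [table[key][0] for key in CORNER_KEYS]
    ys = [table[key][1] for key in CORNER_KEYS]
    return [min(xs), min(ys), max(xs), max(ys)]


annos = json.loads(SAMPLE)
for name, tables in annos.items():
    for table in tables:
        # Print the stored box next to the one derived from the corners
        print(name, table["box"], box_from_corners(table))
```

Each image name maps to a list of tables, so the same loop would also cover images with more than one annotated table.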

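1.2. Problem analysis

Inspecting the data shows that each image to be annotated contains exactly one table, so the task can be treated as a regression problem: predict the eight coordinate values of the table's four corners directly.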
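
As a rough illustration of this regression framing, the sketch below turns one annotation into an eight-value target normalized by the image size, and maps a prediction back to pixel coordinates. The helper names `encode_corners` and `decode_corners` and the sample numbers are hypothetical; the corner order lb, lt, rt, rb is an assumption chosen to match the annotation keys.

```python
CORNER_ORDER = ("lb", "lt", "rt", "rb")


def encode_corners(table, width, height):
    """Flatten the four corners into [x, y, ...] values scaled to [0, 1]."""
    target = []
    for key in CORNER_ORDER:
        x, y = table[key]
        target += [x / width, y / height]
    return target


def decode_corners(target, width, height):
    """Clamp a normalized prediction to [0, 1] and map it back to pixels."""
    corners = {}
    for i, key in enumerate(CORNER_ORDER):
        x = min(max(target[2 * i], 0.0), 1.0)
        y = min(max(target[2 * i + 1], 0.0), 1.0)
        corners[key] = [int(x * width), int(y * height)]
    return corners


# Made-up annotation for a 400 x 800 (width x height) image
table = {"lb": [100, 600], "lt": [100, 200], "rt": [300, 200], "rb": [300, 600]}
target = encode_corners(table, width=400, height=800)
print(target)
print(decode_corners(target, width=400, height=800))
```

Normalizing by width and height keeps the targets in the same range regardless of the original image resolution, which is convenient when every image is resized to a fixed input size.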

2. Code

2.1. Prepare the data

! unzip -oq data/data182385/train.zip
! unzip -oq data/data182385/testA.zip

2.2. Create the data reader

import json
import os

import cv2
import paddle


class MyDateset(paddle.io.Dataset):
    def __init__(self, txt_dir='train/annos.txt', root_dir='train/imgs', mode='train'):
        super(MyDateset, self).__init__()

        self.mode = mode
        self.root_dir = root_dir
        # Load the annotation file given by txt_dir
        with open(txt_dir, 'r') as f:
            self.data = json.load(f)
        self.name_list = list(self.data.keys())

    def __getitem__(self, index):
        name = self.name_list[index]
        img_path = os.path.join(self.root_dir, name)

        img = cv2.imread(img_path)
        # Original image size, needed to normalize the labels
        h, w, c = img.shape
        # Pre-process the image: resize, HWC -> CHW, scale to [0, 1]
        img = paddle.vision.transforms.resize(img, (512, 512), interpolation='bilinear')
        img = img.transpose((2, 0, 1))
        img = img / 255
        img = paddle.to_tensor(img).astype('float32')

        # Pre-process the labels: each image has a single table annotation
        data = self.data[name][0]
        labels = []
        for pos in ['lb', 'lt', 'rt', 'rb']:
            # Normalize x by the width and y by the height
            labels += [data[pos][0] / w, data[pos][1] / h]
        labels = paddle.to_tensor(labels).astype('float32')

        return img, labels

    def __len__(self):
        return len(self.name_list)


# Quick check: read one batch and print its shapes
train_dataset = MyDateset()

train_dataloader = paddle.io.DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
    drop_last=False)

for step, data in enumerate(train_dataloader):
    data, label = data
    print(step, data.shape, label.shape)
    break

0 [16, 3, 512, 512] [16, 8]

2.3. Build the network

For a regression problem like this one, a plain ResNet backbone followed by a linear layer is enough.

class MyNet(paddle.nn.Layer):
    def __init__(self):
        super(MyNet,self).__init__()
        self.resnet = paddle.vision.models.resnet18(pretrained = True, num_classes = 0)
        self.fc = paddle.nn.Linear(in_features=512, out_features=8)

    def forward(self,x):
        x = self.resnet(x)
        x = paddle.flatten(x, 1)
        x = self.fc(x)
        return x

if 1:
    paddle.summary(MyNet(),(16, 3, 512, 512))
-------------------------------------------------------------------------------
   Layer (type)         Input Shape          Output Shape         Param #    
===============================================================================
     Conv2D-21      [[16, 3, 512, 512]]   [16, 64, 256, 256]       9,408     
  BatchNorm2D-21    [[16, 64, 256, 256]]  [16, 64, 256, 256]        256      
      ReLU-10       [[16, 64, 256, 256]]  [16, 64, 256, 256]         0       
    MaxPool2D-2     [[16, 64, 256, 256]]  [16, 64, 128, 128]         0       
     Conv2D-22      [[16, 64, 128, 128]]  [16, 64, 128, 128]      36,864     
  BatchNorm2D-22    [[16, 64, 128, 128]]  [16, 64, 128, 128]        256      
      ReLU-11       [[16, 64, 128, 128]]  [16, 64, 128, 128]         0       
     Conv2D-23      [[16, 64, 128, 128]]  [16, 64, 128, 128]      36,864     
  BatchNorm2D-23    [[16, 64, 128, 128]]  [16, 64, 128, 128]        256      
   BasicBlock-9     [[16, 64, 128, 128]]  [16, 64, 128, 128]         0       
     Conv2D-24      [[16, 64, 128, 128]]  [16, 64, 128, 128]      36,864     
  BatchNorm2D-24    [[16, 64, 128, 128]]  [16, 64, 128, 128]        256      
      ReLU-12       [[16, 64, 128, 128]]  [16, 64, 128, 128]         0       
     Conv2D-25      [[16, 64, 128, 128]]  [16, 64, 128, 128]      36,864     
  BatchNorm2D-25    [[16, 64, 128, 128]]  [16, 64, 128, 128]        256      
   BasicBlock-10    [[16, 64, 128, 128]]  [16, 64, 128, 128]         0       
     Conv2D-27      [[16, 64, 128, 128]]  [16, 128, 64, 64]       73,728     
  BatchNorm2D-27    [[16, 128, 64, 64]]   [16, 128, 64, 64]         512      
      ReLU-13       [[16, 128, 64, 64]]   [16, 128, 64, 64]          0       
     Conv2D-28      [[16, 128, 64, 64]]   [16, 128, 64, 64]       147,456    
  BatchNorm2D-28    [[16, 128, 64, 64]]   [16, 128, 64, 64]         512      
     Conv2D-26      [[16, 64, 128, 128]]  [16, 128, 64, 64]        8,192     
  BatchNorm2D-26    [[16, 128, 64, 64]]   [16, 128, 64, 64]         512      
   BasicBlock-11    [[16, 64, 128, 128]]  [16, 128, 64, 64]          0       
     Conv2D-29      [[16, 128, 64, 64]]   [16, 128, 64, 64]       147,456    
  BatchNorm2D-29    [[16, 128, 64, 64]]   [16, 128, 64, 64]         512      
      ReLU-14       [[16, 128, 64, 64]]   [16, 128, 64, 64]          0       
     Conv2D-30      [[16, 128, 64, 64]]   [16, 128, 64, 64]       147,456    
  BatchNorm2D-30    [[16, 128, 64, 64]]   [16, 128, 64, 64]         512      
   BasicBlock-12    [[16, 128, 64, 64]]   [16, 128, 64, 64]          0       
     Conv2D-32      [[16, 128, 64, 64]]   [16, 256, 32, 32]       294,912    
  BatchNorm2D-32    [[16, 256, 32, 32]]   [16, 256, 32, 32]        1,024     
      ReLU-15       [[16, 256, 32, 32]]   [16, 256, 32, 32]          0       
     Conv2D-33      [[16, 256, 32, 32]]   [16, 256, 32, 32]       589,824    
  BatchNorm2D-33    [[16, 256, 32, 32]]   [16, 256, 32, 32]        1,024     
     Conv2D-31      [[16, 128, 64, 64]]   [16, 256, 32, 32]       32,768     
  BatchNorm2D-31    [[16, 256, 32, 32]]   [16, 256, 32, 32]        1,024     
   BasicBlock-13    [[16, 128, 64, 64]]   [16, 256, 32, 32]          0       
     Conv2D-34      [[16, 256, 32, 32]]   [16, 256, 32, 32]       589,824    
  BatchNorm2D-34    [[16, 256, 32, 32]]   [16, 256, 32, 32]        1,024     
      ReLU-16       [[16, 256, 32, 32]]   [16, 256, 32, 32]          0       
     Conv2D-35      [[16, 256, 32, 32]]   [16, 256, 32, 32]       589,824    
  BatchNorm2D-35    [[16, 256, 32, 32]]   [16, 256, 32, 32]        1,024     
   BasicBlock-14    [[16, 256, 32, 32]]   [16, 256, 32, 32]          0       
     Conv2D-37      [[16, 256, 32, 32]]   [16, 512, 16, 16]      1,179,648   
  BatchNorm2D-37    [[16, 512, 16, 16]]   [16, 512, 16, 16]        2,048     
      ReLU-17       [[16, 512, 16, 16]]   [16, 512, 16, 16]          0       
     Conv2D-38      [[16, 512, 16, 16]]   [16, 512, 16, 16]      2,359,296   
  BatchNorm2D-38    [[16, 512, 16, 16]]   [16, 512, 16, 16]        2,048     
     Conv2D-36      [[16, 256, 32, 32]]   [16, 512, 16, 16]       131,072    
  BatchNorm2D-36    [[16, 512, 16, 16]]   [16, 512, 16, 16]        2,048     
   BasicBlock-15    [[16, 256, 32, 32]]   [16, 512, 16, 16]          0       
     Conv2D-39      [[16, 512, 16, 16]]   [16, 512, 16, 16]      2,359,296   
  BatchNorm2D-39    [[16, 512, 16, 16]]   [16, 512, 16, 16]        2,048     
      ReLU-18       [[16, 512, 16, 16]]   [16, 512, 16, 16]          0       
     Conv2D-40      [[16, 512, 16, 16]]   [16, 512, 16, 16]      2,359,296   
  BatchNorm2D-40    [[16, 512, 16, 16]]   [16, 512, 16, 16]        2,048     
   BasicBlock-16    [[16, 512, 16, 16]]   [16, 512, 16, 16]          0       
AdaptiveAvgPool2D-2 [[16, 512, 16, 16]]    [16, 512, 1, 1]           0       
     ResNet-2       [[16, 3, 512, 512]]    [16, 512, 1, 1]           0       
     Linear-2           [[16, 512]]            [16, 8]             4,104     
===============================================================================
Total params: 11,190,216
Trainable params: 11,171,016
Non-trainable params: 19,200
-------------------------------------------------------------------------------
Input size (MB): 48.00
Forward/backward pass size (MB): 4768.13
Params size (MB): 42.69
Estimated Total Size (MB): 4858.81
-------------------------------------------------------------------------------

2.4. Training

model = MyNet()
model.train()

try:
    # Resume training from a previously saved checkpoint, if there is one
    param_dict = paddle.load('./model.pdparams')
    model.load_dict(param_dict)
except Exception:
    print('no such model file')

train_dataset = MyDateset()
train_dataloader = paddle.io.DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
    drop_last=False)

max_epoch = 10
# Note: 1e-8 is a very small learning rate. It may suit a final pass over a
# resumed checkpoint; a first run from the ImageNet weights likely needs a
# much larger value.
scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=1e-8, T_max=max_epoch)
opt = paddle.optimizer.Adam(learning_rate=scheduler, parameters=model.parameters())

now_step = 0
for epoch in range(max_epoch):
    for step, data in enumerate(train_dataloader):
        now_step += 1

        img, label = data
        pre = model(img)
        # mse_loss already averages over all elements by default
        loss = paddle.nn.functional.mse_loss(pre, label)

        loss.backward()
        opt.step()
        opt.clear_gradients()
        if now_step % 100 == 0:
            print("epoch: {}, batch: {}, loss is: {}".format(epoch, step, float(loss)))

    # Advance the cosine schedule once per epoch; without this call the
    # learning rate stays at its initial value
    scheduler.step()

paddle.save(model.state_dict(), 'model.pdparams')
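
For intuition about the schedule used above, the following is a minimal sketch of the standard closed-form cosine annealing curve, written in plain Python. The function name `cosine_annealing_lr` is hypothetical, a minimum learning rate of 0 is assumed, and Paddle's `CosineAnnealingDecay` may compute its values differently internally, so treat this as an approximation of the shape rather than a reproduction of the library.

```python
import math


def cosine_annealing_lr(base_lr, epoch, t_max, eta_min=0.0):
    """Closed-form cosine annealing: base_lr at epoch 0, eta_min at epoch t_max."""
    cosine = math.cos(math.pi * epoch / t_max)
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + cosine)


# Learning rate per epoch for the settings used in the training code
for epoch in range(11):
    print(epoch, cosine_annealing_lr(1e-8, epoch, t_max=10))
```

The curve starts at the base learning rate, passes through half of it at the midpoint, and reaches the minimum at `t_max`.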

2.5. Save the model

# Load the trained weights and export the model for inference
model = MyNet()
param_dict = paddle.load('./model.pdparams')
model.load_dict(param_dict)
model.eval()

paddle.jit.save(
    layer=model,
    path='mymodel/model',
    input_spec=[paddle.static.InputSpec(shape=[None, 3, 512, 512], dtype='float32')])

2.6. Prepare predict.py

The contents of predict.py are as follows:

# Example usage:
# python predict.py [src_image_dir] [results]

import os
import sys
import glob
import json
import cv2
import paddle


def process(src_image_dir, save_dir):
    # load
    model = paddle.jit.load('./mymodel/model')
    model.eval()

    image_paths = glob.glob(os.path.join(src_image_dir, "*.jpg"))
    result = {}
    for image_path in image_paths:
        filename = os.path.split(image_path)[1]
        # Read the image and record its original size
        img = cv2.imread(image_path)
        h, w, c = img.shape
        # Pre-process: resize, HWC -> CHW, scale to [0, 1], add a batch axis
        img = paddle.vision.transforms.resize(img, (512,512), interpolation='bilinear')
        img = img.transpose((2,0,1))
        img = img/255
        img = paddle.to_tensor(img).astype('float32')
        img = paddle.unsqueeze(img, axis=0)

        # Predict normalized corner coordinates and clamp them to [0, 1]
        pre = model(img)
        pre = paddle.clip(pre[0], min=0.0, max=1.0)
        pre = pre.tolist()
        x1, y1, x2, y2, x3, y3, x4, y4 = pre
        x1, x2, x3, x4 = [int(x*w) for x in [x1, x2, x3, x4]]
        y1, y2, y3, y4 = [int(y*h) for y in [y1, y2, y3, y4]]

        xmin = min(x1,x2,x3,x4)
        xmax = max(x1,x2,x3,x4)
        ymin = min(y1,y2,y3,y4)
        ymax = max(y1,y2,y3,y4)

        if filename not in result:
            result[filename] = []
        result[filename].append({
            "box": [xmin, ymin, xmax, ymax],
            "lb": [x1, y1],
            "lt": [x2, y2],
            "rt": [x3, y3],
            "rb": [x4, y4],
        })
    with open(os.path.join(save_dir, "result.txt"), 'w', encoding="utf-8") as f:
        f.write(json.dumps(result))


if __name__ == "__main__":
    assert len(sys.argv) == 3

    src_image_dir = sys.argv[1]
    save_dir = sys.argv[2]

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    process(src_image_dir, save_dir)
# Run the script to check the prediction results
! python predict.py pubtest/imgs test_A_result
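
Before packaging, it can help to sanity-check the generated result file. The sketch below is a hypothetical checker, not the competition's official validator: it only verifies that each entry carries the keys and value lengths shown in the annotation example, and the sample dictionaries are made up.

```python
import json

# Expected keys and the number of values each should hold
REQUIRED_KEYS = {"box": 4, "lb": 2, "lt": 2, "rt": 2, "rb": 2}


def check_result(result):
    """Return a list of problems found in a result dict; empty means none found."""
    problems = []
    for name, tables in result.items():
        for i, table in enumerate(tables):
            for key, length in REQUIRED_KEYS.items():
                value = table.get(key)
                if not isinstance(value, list) or len(value) != length:
                    problems.append(f"{name}[{i}]: missing or malformed '{key}'")
    return problems


# Made-up entries: one well-formed, one with a truncated box
good = {"a.jpg": [{"box": [10, 20, 210, 170], "lb": [12, 170],
                   "lt": [10, 22], "rt": [210, 20], "rb": [208, 168]}]}
bad = {"b.jpg": [{"box": [10, 20, 210], "lb": [12, 170],
                  "lt": [10, 22], "rt": [210, 20], "rb": [208, 168]}]}
print(check_result(good))
print(check_result(bad))

# To check a real run, load the file written by predict.py:
# with open("test_A_result/result.txt", encoding="utf-8") as f:
#     print(check_result(json.load(f)))
```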

2.7. Package and submit

! zip -r submit.zip predict.py mymodel

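3. Summary

This project performs the coordinate regression with ResNet18 in the simplest possible way. For better results, consider a detection model; for details, see "Baidu Netdisk AI Competition: Advanced Table Detection, Table Structuring Baseline".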


This article is a repost.
Original project link
