百度网盘AI大赛——图像处理挑战赛: 通用场景手写文字擦除BaseLine

一、赛题背景说明

随着技术发展,OCR扫描在学习、办公等众多场景中被使用,通过技术和算法,对扫描获得的纸张文档上的手写笔迹还原修复,恢复文件本身的样子,使得人们的使用体验越来越便捷。上一期比赛,我们举办了试卷场景下的手写文字擦除,帮助学生党们擦除试卷上的笔迹。本次比赛,我们诚邀各位选手并拓宽场景:不限于试卷,对通用文件上的手写笔迹进行擦除后还原文件,帮助更多人解决扫描上的问题。

二、赛题数据说明及预处理

本赛题为img2img的图像处理任务,因此数据集也全部以图像格式给出。这里我挑选数据集中两个具有代表性的图像作为示例。

  • 如下图1所示为第一类,从左到右依次为手写图片,mask,真实图片。这类图片的特点是

1、尺寸小,基本上宽高都在1000像素点以内

2、数量多,这类图片数据集一共大概25G左右十分庞大

3、手写文字有时会覆盖在文档图片上方,因此不能简单的将文字区域像素设置为固定值就完事

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-5cA259hj-1656682680942)(https://ai-studio-static-online.cdn.bcebos.com/2711ea44a3bc48d08456fee5871080cbbd5c0476d3a34ef3ac191fe8038648b4)]

  • 如下图2所示为第二类,从左到右依次为手写图片,mask(自己使用代码进行生成,本次比赛数据集官方没有提供这类图片的mask数据),真实图片。这类图片的特点是

1、尺寸巨大,有的宽高可以达到5000像素点,如果想要提高性能分,可以重点关注一下如何对该类图片进行处理

2、数量少,但是在A榜中的占比却高达0.4,不像训练数据集中的不到0.02。

3、是真实的拍照得到的图片,这类图片是先有手写图片,而后才有的真实图片(与第一类图片不一样,第一类是先有真实图片,而后才有的手写图片)

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-5bH94CJZ-1656682680943)(https://ai-studio-static-online.cdn.bcebos.com/2f9a3bfa6f44488e9f91cdce8f21ad958c250aeb6be045d891fefe35adae3f85)]

以下代码是用来生成mask的函数(不需要运行,仅作为个人想要提高得分来重新生成mask的一个参考),本项目已经将一类图片生成的mask和二类图片中的十分之一的数据打包作为数据集挂载。

个人认为:虽然机器学习定理告诉我们,训练数据量越多模型效果越好,越不容易过拟合;但这是有前提的,因为我们无法做到全批量梯度下降,真实的训练过程我们只会一次一个小batch的训练,最早期的batch对模型的梯度影响必然会被后期的batch洗掉一部分,反向传播决定了模型不能进行增量学习。所以,在显存不大的情况下,过大训练数据集对精度的提高是有限的,因此二类图片我只取了十分之一作为基线。另外这样也可以平衡训练中一二类图片的比例,对于单个模型来讲是有益的。

本项目将图片随机裁剪到512的大小作为模型的输入,并使用了SwinT接口来抽取全局特征。

总结一下:在数据处理部分,我们一共使用了三种策略, 1、缩减数据集25G–>4.1G 2、生成mask引导模型训练 3、随机裁剪至512x512大小

# 生成mask的函数如下

import os
import random

# 输入:水印图像路劲,原图路劲,保存的mask的路径
def generate_one_mask(image_path, gt_path, save_path):
    # 读取图像
    image = Image.open(image_path)
    gt = Image.open(gt_path)

    # 转成numpy数组格式
    image = 255 - np.array(image)[:, :, :3]
    gt = 255 - np.array(gt)[:, :, :3]

    # 设置阈值
    threshold = 15
    # 真实图片与手写图片做差,找出mask的位置
    diff_image = np.abs(image.astype(np.float32) - gt.astype(np.float32))  
    mean_image = np.max(diff_image, axis=-1)

    # 将mask二值化,即0和255。
    mask = np.greater(mean_image, threshold).astype(np.uint8) * 255
    mask[mask < 2] = 0
    mask[mask >= 1] = 255
    mask = 255 - mask
    mask = np.clip(mask, 0, 255)

    # 保存
    mask = np.array([mask, mask, mask, mask])
    mask = mask.transpose(1, 2, 0)
    mask = Image.fromarray(mask[:, :, :3])
    mask.save(save_path)

三、训练模型并可视化训练过程

我们模型用的是Erasenet,模型结构图如下:

Erasenet主体结构

运行以下代码块,进行训练。在基线中,提供了一个来自水印智能消除赛的预训练模型,可以加载此模型提高训练速度。训练过程中可以观察到模型输出图片与真实图片的对比,使用VisualDL可以清晰的看到psnr与loss曲线。

# 解压数据集文件
import os
if not os.path.exists('dataset/gts'):
    !unzip -oq data/data154420/dehw_train_clear.zip -d ./dataset
# 可视化
from visualdl import LogWriter

# paddle包
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.io import DataLoader
from dataset.data_loader import TrainDataSet, ValidDataSet

# 自定义的loss函数,包含mask的损失和image的损失
from loss.Loss import LossWithGAN_STE, LossWithSwin

# 使用SwinT增强的Erasenet
from models.swin_gan import STRnet2_change

# 其他工具
import utils
import random
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import math
%matplotlib inline


# 计算psnr
log = LogWriter('log')
def psnr(img1, img2):
   mse = np.mean((img1/1.0 - img2/1.0) ** 2 )
   if mse < 1.0e-10:
      return 100
   return 10 * math.log10(255.0**2/mse)


# 训练配置字典
CONFIG = {
    'modelsSavePath': 'train_models_swin_erasenet',
    'batchSize': 10,  # 模型大,batch_size调小一点防崩,拉满显存但刚好不超,就是炼丹仙人~
    'traindataRoot': 'dataset',
    'validdataRoot': 'dataset',   # 因为数据集量大,且分布一致,就直接取训练集中数据作为验证了。别问,问就是懒
    'pretrained': 'train_models_swin_erasenet/base_model.pdparams',
    'num_epochs': 100,
    'seed': 9420  # 就是爱你!~
}


# 设置随机种子
random.seed(CONFIG['seed'])
np.random.seed(CONFIG['seed'])
paddle.seed(CONFIG['seed'])
# noinspection PyProtectedMember
paddle.framework.random._manual_program_seed(CONFIG['seed'])


batchSize = CONFIG['batchSize']
if not os.path.exists(CONFIG['modelsSavePath']):
    os.makedirs(CONFIG['modelsSavePath'])

traindataRoot = CONFIG['traindataRoot']
validdataRoot = CONFIG['validdataRoot']

# 创建数据集容器
TrainData = TrainDataSet(training=True, file_path=traindataRoot)
TrainDataLoader = DataLoader(TrainData, batch_size=batchSize, shuffle=True,
                             num_workers=0, drop_last=True)
ValidData = ValidDataSet(file_path=validdataRoot)
ValidDataLoader = DataLoader(ValidData, batch_size=1, shuffle=True, num_workers=0, drop_last=True)


netG = STRnet2_change()


if CONFIG['pretrained'] is not None:
    print('loaded ')
    weights = paddle.load(CONFIG['pretrained'])
    netG.load_dict(weights)


# 开始直接上大火
lr = 2e-3
G_optimizer = paddle.optimizer.Adam(learning_rate=lr, parameters=netG.parameters())


loss_function = LossWithGAN_STE()


print('OK!')
num_epochs = CONFIG['num_epochs']
best_psnr = 0
iters = 0


for epoch_id in range(1, num_epochs + 1):

    netG.train()

    if epoch_id % 8 == 0:
        # 每8个epoch时重置优化器,学习率变为1/10,抖动式学习法
        lr /= 10
        G_optimizer = paddle.optimizer.Adam(learning_rate=lr, parameters=netG.parameters())

    for k, (imgs, gts, masks) in enumerate(TrainDataLoader):
        iters += 1

        fake_images, mm = netG(imgs)
        G_loss = loss_function(masks, fake_images, mm, gts)
        G_loss = G_loss.sum()

        #后向传播,更新参数的过程
        G_loss.backward()
        # 最小化loss,更新参数
        G_optimizer.step()
        # 清除梯度
        G_optimizer.clear_grad()

        # 打印训练信息
        if iters % 100 == 0:
            print('epoch{}, iters{}, loss:{:.5f}, lr:{}'.format(
                epoch_id, iters, G_loss.item(), G_optimizer.get_lr()
            ))
            log.add_scalar(tag="train_loss", step=iters, value=G_loss.item())

    # 对模型进行评价并保存
    netG.eval()
    val_psnr = 0

    # noinspection PyAssignmentToLoopOrWithParameter
    for index, (imgs, gt) in enumerate(ValidDataLoader):
        _, _, h, w = imgs.shape
        rh, rw = h, w
        step = 512
        pad_h = step - h if h < step else 0
        pad_w = step - w if w < step else 0
        m = nn.Pad2D((0, pad_w, 0, pad_h))
        imgs = m(imgs)
        _, _, h, w = imgs.shape
        res = paddle.zeros_like(imgs)
        mm_out = paddle.zeros_like(imgs)
        mm_in = paddle.zeros_like(imgs)

        for i in range(0, h, step):
            for j in range(0, w, step):
                if h - i < step:
                    i = h - step
                if w - j < step:
                    j = w - step
                clip = imgs[:, :, i:i + step, j:j + step]
                clip = clip.cuda()
                with paddle.no_grad():
                    g_images_clip, mm = netG(clip)
                g_images_clip = g_images_clip.cpu()
                mm = mm.cpu()
                clip = clip.cpu()
                mm_in[:, :, i:i + step, j:j + step] = mm
                g_image_clip_with_mask = clip * (1 - mm) + g_images_clip * mm
                res[:, :, i:i + step, j:j + step] = g_image_clip_with_mask


        res = res[:, :, :rh, :rw]
        # 改变通道
        output = utils.pd_tensor2img(res)
        target = utils.pd_tensor2img(gt)
        mm_in = utils.pd_tensor2img(mm_in)

        psnr_value = psnr(output, target)
        print('psnr: ', psnr_value)

        if index in [2, 3, 5, 7, 11]:
            fig = plt.figure(figsize=(20, 10),dpi=100)
            # 图一
            ax1 = fig.add_subplot(2, 2, 1)  # 1行 2列 索引为1
            ax1.imshow(output)
            # 图二
            ax2 = fig.add_subplot(2, 2, 2)
            ax2.imshow(mm_in)
            # 图三
            ax3 = fig.add_subplot(2, 2, 3)
            ax3.imshow(target)

            plt.show()

        del res
        del gt
        del target
        del output

        val_psnr += psnr_value
    ave_psnr = val_psnr / (index + 1)
    print('epoch:{}, psnr:{}'.format(epoch_id, ave_psnr))
    log.add_scalar(tag="valid_psnr", step=epoch_id, value=ave_psnr)
    paddle.save(netG.state_dict(), CONFIG['modelsSavePath'] +
                '/STE_{}_{:.4f}.pdparams'.format(epoch_id, ave_psnr
                ))
    if ave_psnr > best_psnr:
        best_psnr = ave_psnr
        paddle.save(netG.state_dict(), CONFIG['modelsSavePath'] + '/STE_best.pdparams')

四、模型预测

在模型预测时,我们将单个图片进行重叠裁剪,裁剪之后为512x512的尺寸,将这个序列依次输入网络进行预测。如下

import os
import sys
import glob
import json
import cv2


import paddle
import paddle.nn as nn
import paddle.nn.functional as F
# 加载Erasenet改
from models.swin_gan import STRnet2_change
import utils
from paddle.vision.transforms import Compose, ToTensor
from PIL import Image


# 加载我们训练到的最好的模型
netG = STRnet2_change()
weights = paddle.load('train_models_swin_erasenet/best_submit_model.pdparams')
netG.load_dict(weights)
netG.eval()


def ImageTransform():
    return Compose([ToTensor(), ])


ImgTrans = ImageTransform()


def process(src_image_dir, save_dir):
    image_paths = glob.glob(os.path.join(src_image_dir, "*.jpg"))
    for image_path in image_paths:

        # do something
        img = Image.open(image_path)
        inputImage = paddle.to_tensor([ImgTrans(img)])

        _, _, h, w = inputImage.shape
        rh, rw = h, w
        step = 512
        pad_h = step - h if h < step else 0
        pad_w = step - w if w < step else 0
        m = nn.Pad2D((0, pad_w, 0, pad_h))
        imgs = m(inputImage)
        _, _, h, w = imgs.shape
        res = paddle.zeros_like(imgs)

        for i in range(0, h, step):
            for j in range(0, w, step):
                if h - i < step:
                    i = h - step
                if w - j < step:
                    j = w - step
                clip = imgs[:, :, i:i + step, j:j + step]
                clip = clip.cuda()
                with paddle.no_grad():
                    g_images_clip, mm = netG(clip)
                g_images_clip = g_images_clip.cpu()
                mm = mm.cpu()
                clip = clip.cpu()
                g_image_clip_with_mask = g_images_clip * mm + clip * (1 - mm)
                res[:, :, i:i + step, j:j + step] = g_image_clip_with_mask
                del g_image_clip_with_mask, g_images_clip, mm, clip
        res = res[:, :, :rh, :rw]
        output = utils.pd_tensor2img(res)

        # 保存结果图片
        save_path = os.path.join(save_dir, os.path.basename(image_path).replace(".jpg", ".png"))
        cv2.imwrite(save_path, output)
        del output, res
        

if __name__ == "__main__":
    assert len(sys.argv) == 3

    src_image_dir = sys.argv[1]
    save_dir = sys.argv[2]

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    
    process(src_image_dir, save_dir)

五、打包提交

  • 1、在本地新建models文件夹,将本项目中models文件夹下文件全部下载放入。

  • 2、在本地新建train_models_swin_erasenet文件夹,将本项目中train_models_swin_erasenet文件夹下文件全部下载放入。

  • 3、下载utils.py文件

  • 4、下载predict.py文件

完成之后打包成submit.zip文件,然后就可以进行提交,得分为0.67分左右,psnr为35.6。

六、提分思路(不一定能提分,但可以试)

1、将一类图片和二类图片各单独使用一个模型,如一类图片我们可以使用手写文字擦除挑战赛那边最好的模型。由于一类图片和二类图片良好的可区分性,给这个增加了可实施的可能性。

2、使用更多的二类图片数据,使用旋转、缩放等更多形式的增广,改进网络模型

3、改进mask的生成,例如使用腐蚀和膨胀操作,调整真实图片和手写图片做差的阈值。

els文件夹下文件全部下载放入。

  • 2、在本地新建train_models_swin_erasenet文件夹,将本项目中train_models_swin_erasenet文件夹下文件全部下载放入。

  • 3、下载utils.py文件

  • 4、下载predict.py文件

完成之后打包成submit.zip文件,然后就可以进行提交,得分为0.67分左右,psnr为35.6。

六、提分思路(不一定能提分,但可以试)

1、将一类图片和二类图片各单独使用一个模型,如一类图片我们可以使用手写文字擦除挑战赛那边最好的模型。由于一类图片和二类图片良好的可区分性,给这个增加了可实施的可能性。

2、使用更多的二类图片数据,使用旋转、缩放等更多形式的增广,改进网络模型

3、改进mask的生成,例如使用腐蚀和膨胀操作,调整真实图片和手写图片做差的阈值。

原项目链接:https://aistudio.baidu.com/aistudio/projectdetail/4282393

Logo

学大模型,用大模型上飞桨星河社区!每天8点V100G算力免费领!免费领取ERNIE 4.0 100w Token >>>

更多推荐