一、前言

本项目为百度论文复现赛《From Two to One: A New Scene Text Recognizer with Visual Language Modeling Network》论文复现代码。

依赖环境:

  • paddlepaddle-gpu2.3.1
  • python3.7

复现精度:

MethodsIIIT5KIC13SVTIC15SVTPCUTE
论文95.895.791.783.786.088.5
官方repo95.996.390.784.185.388.9
复现repo95.996.390.984.185.489.2

二、模型背景及其介绍

这篇场景文本检测论文有别于以往的分步两阶段工作需要先进行视觉预测再利用语言模型纠正的策略,该工作提出了视觉语言网络 Vision-LAN,直接赋予视觉模型语言能力,将视觉和语言模型当作一个整体。由于语言信息是和视觉特征一同获取的,不需要额外的语言模型,Vision-LAN显著提高39%的前向速度,并且能够自适应考虑语言信息来增强视觉特征,进而达到更高的识别准确率。

1、 Vision-LAN模型框架

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-YEEJApWW-1665148595347)(https://ai-studio-static-online.cdn.bcebos.com/e9101ae52a314bbfa1b9fdb454ef0463131e65e4b60a45798c1b48c715fcb761)]

如上图,Vision-LAN模型包括三部分,特征提取网络,掩码语言感知模块(Masked Language Aware Module)和视觉推理模块(Visual Reasoning Module)。训练阶段,通过特征提取网络得到视觉特征,接着MLM模块输入视觉特征以及字符索引,通过弱监督的方法在对应字符索引的位置生成掩码Mask。该模块主要用来模拟视觉信息字符遮挡的情况。VRM模块输入带遮挡的文本图片,通过在视觉空间捕获长距离的信息,预测对应的文本行识别内容。在测试阶段,移除MLM模块,只使用VRM模块用于文本识别。由于无需额外的语言模型即可获取语言信息和视觉特征,Vision-LAN 可以零计算成本即获得语言信息。

2、掩码语言感知模块

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-HQI83BMr-1665148595348)(https://ai-studio-static-online.cdn.bcebos.com/9d07e465483b4d69ad582f92acd4f0202679049a8c154ec0add54353c7135efe)]

如上图,为了引导掩码模块的学习,设计了两个额外的分支。第一个分支,将特征图和对应的字符mask相乘,得到遮挡字符的特征图;第二个分支,将特征图和1-mask相乘,得到未被遮挡的字符特征图。通过这两个分支使用交叉熵监督训练,使得mask区域只遮挡第i个字符的位置,而不交叠到其它的字符区域。

3、视觉推理模块

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-qLXo0OCu-1665148595349)(https://ai-studio-static-online.cdn.bcebos.com/6dd6a59d361948af9cb23c8eee82a87cd1580f2156c74475af2f4277af818379)]

如上图,视觉推理模块(VRM)通过一个结构同时建模视觉信息和语言信息,通过使用视觉上下文中的字符信息,进而从被遮挡的特征中预测出字符。VRM模块包括视觉语义推理(VSR)和并行预测层(PP)。不同于使用Transformer单元进行纯语言建模,VRM中的Transformer单元用于序列建模,不会受单词长度影响。

官方参考github项目复现github项目

三、数据集

MJSynth和SynthText都是合成生成的数据集,其中单词实例被放置在自然场景图像中,同时考虑到场景布局。 SynthText数据集由80万张图像和大约800万合成词实例组成。每个文本实例都用其文本字符串、单词级和字符级的边界框进行注释。

MJSynth和SynthText数据集下载:训练集 分别解压后放在./datasets/train/下

评估数据集下载:测试集 解压后放在./datasets/下

数据目录的结构为:

datasets
├── evaluation
│   ├── Sumof6benchmarks
│   ├── CUTE
│   ├── IC13
│   ├── IC15
│   ├── IIIT5K
│   ├── SVT
│   └── SVTP
└── train
    ├── MJSynth
    └── SynText

四、模型运行

1、加载环境

%cd /home/aistudio/work/Paddle-VisionLAN/
!pip install -r requirements.txt

2、解压数据集

# 解压MJSynth和SynText训练集
%cd /home/aistudio/work/Paddle-VisionLAN/datasets/train
!unzip -oq /home/aistudio/data/data168907/MJSynth.zip
!unzip -oq /home/aistudio/data/data168907/SynText.zip
/home/aistudio/work/Paddle-VisionLAN/datasets/train
# 解压测试集
%cd /home/aistudio/work/Paddle-VisionLAN/datasets/
!unzip -oq /home/aistudio/data/data168908/evaluation.zip
/home/aistudio/work/Paddle-VisionLAN/datasets

3、训练

训练分为两个过程:Language-free (LF) process和Language-aware (LA) process

模型训练权重保存到./output文件下, 训练日志保存到./logs文件下

可以将训练好的模型权重下载 解压后放在本repo/下,直接运行评估和预测。

# Step 1 (LF_1): first train the vision model without MLM
%cd /home/aistudio/work/Paddle-VisionLAN/
!python train.py --cfg_type LF_1 --batch_size 384 --epochs 8 --output_dir './output/LF_1/'
# Step 2 (LF_2): finetune the MLM with vision model.
%cd /home/aistudio/work/Paddle-VisionLAN/
!python train.py --cfg_type LF_2 --batch_size 220 --epochs 4 --output_dir './output/LF_2/' --pretrained './output/LF_1/best_acc_M.pdparams'
%cd /home/aistudio/work/Paddle-VisionLAN/
!python train.py --cfg_type LA --batch_size 220 --epochs 8 --output_dir './output/LA/' --pretrained './output/LF_2/best_acc_M.pdparams'

4、测试

通过Language-aware (LA) process训练后,执行测试。

%cd /home/aistudio/work/Paddle-VisionLAN/
!python eval.py

5、预测

加载./output/LA/best_acc_M.pdparams文件,预测如下图片demo1.png结果。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-AW6IIXmE-1665148595349)(https://ai-studio-static-online.cdn.bcebos.com/9a1b5bdbcd844f648bb521713c03f37abad81fad54ce4a21bd8139606f2399a3)]

%cd /home/aistudio/work/Paddle-VisionLAN/
!python predict.py --img_file './images/demo1.png'
/home/aistudio/work/Paddle-VisionLAN
W0926 17:19:11.611407 49570 gpu_resources.cc:61] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 11.2
W0926 17:19:11.614293 49570 gpu_resources.cc:91] device: 0, cuDNN Version: 8.2.
pre_string: residencia

五、训练推理开发

飞桨除了基本的模型训练和预测,还提供了支持多端多平台的高性能推理部署工具。
本部分为飞桨训推一体全流程(Training and Inference Pipeline Criterion(TIPC))基础训练推理内容,读者可选择自行跳过这部分学习。

1、模型推理

包含模型动转静脚本(export_model.py)以及模型基于 Paddle Inference 的预测脚本(infer.py)。

# 基于推理引擎导出模型
%cd /home/aistudio/work/Paddle-VisionLAN
!python3 deploy/export_model.py
/home/aistudio/work/Paddle-VisionLAN
inference model has been saved into deploy
# 基于推理引擎的模型预测
%cd /home/aistudio/work/Paddle-VisionLAN
!python3 deploy/infer.py --img_path images/demo1.png
/home/aistudio/work/Paddle-VisionLAN
image_name: images/demo1.png, predict data: residencia

2、训推自动化测试

为了方便快速验证训练/评估/推理过程,建立了一个小数据集,放在lite_data文件夹下。

%cd /home/aistudio/work/Paddle-VisionLAN
!bash test_tipc/test_train_inference_python.sh test_tipc/configs/VisionLAN/train_infer_python.txt lite_train_lite_infer

六、代码解析

1、代码结构

    |--images                         # 测试使用的样例图片
    |--deploy                         # 预测部署相关
        |--export_model.py            # 导出模型
        |--infer.py                   # 部署预测
    |--datasets                       # 训练和测试数据集
    |--lite_data                      # 用于tipc的小数据集
    |--logs                           # 训练日志信息  
    |--output                         # 模型输出文件
    |--modules                        # 论文模块
        |--modules.py                 # 模块组件
        |--resnet.py                  # resnet45模型
    |--test_tipc                      # tipc代码
    |--utils                          # 工具代码
    |--VisionLAN                      # 论文模型
    |--predict.py                     # 预测代码
    |--eval.py                        # 评估代码
    |--train.py                       # 训练代码
    |----README.md                    # 用户手册

2、构造数据加载器

class lmdbDataset(Dataset):
    def __init__(self, roots=None, ratio=None, img_height=32, img_width=128, transform=None, global_state='Test'):
        super().__init__()
        self.envs = []
        self.nSamples = 0
        self.lengths = []
        self.ratio = []
        self.global_state = global_state
        for i in range(0, len(roots)):
            env = lmdb.open(
                roots[i],
                max_readers=1,
                readonly=True,
                lock=False,
                readahead=False,
                meminit=False)
            if not env:
                print('cannot creat lmdb from %s' % (roots[i]))
                sys.exit(0)

            with env.begin(write=False) as txn:
                nSamples = int(txn.get('num-samples'.encode()))
                self.nSamples += nSamples
            self.lengths.append(nSamples)
            self.envs.append(env)

        if ratio != None:
            assert len(roots) == len(ratio), 'length of ratio must equal to length of roots!'
            for i in range(0, len(roots)):
                self.ratio.append(ratio[i] / float(sum(ratio)))
        else:
            for i in range(0, len(roots)):
                self.ratio.append(self.lengths[i] / float(self.nSamples))
        self.transform = transform
        self.maxlen = max(self.lengths)
        self.img_height = img_height
        self.img_width = img_width
        self.target_ratio = img_width / float(img_width)
        self.min_size = (img_width * 0.5, img_width * 0.75, img_width)

        self.augment_tfs = transforms.Compose([
            CVGeometry(degrees=45, translate=(0.0, 0.0), scale=(0.5, 2.), shear=(45, 15), distortion=0.5, p=0.5),
            CVDeterioration(var=20, degrees=6, factor=4, p=0.25),
            CVColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1, p=0.25)
        ])

    def __len__(self):
        return self.nSamples

    def __getitem__(self, index):
        fromwhich = self.__fromwhich__()
        if self.global_state == 'Train':
            index = random.randint(0, self.maxlen - 1)
        index = index % self.lengths[fromwhich]
        assert index <= len(self), 'index range error'
        index += 1
        with self.envs[fromwhich].begin(write=False) as txn:
            img_key = 'image-%09d' % index
            try:
                imgbuf = txn.get(img_key.encode())
                buf = six.BytesIO()
                buf.write(imgbuf)
                buf.seek(0)
                img = Image.open(buf).convert('RGB')
            except:
                print('Corrupted image for %d' % index)
                return self[index + 1]
            label_key = 'label-%09d' % index
            label = str(txn.get(label_key.encode()))
            # if python3
            label = str(txn.get(label_key.encode()), 'utf-8')
            label = re.sub('[^0-9a-zA-Z]+', '', label)
            if (len(label) > 25 or len(label) <= 0) and self.global_state == 'Train':
                return self[index + 1]
            try:
                img = self.keepratio_resize(img, self.global_state)
            except:
                print('Size error for %d' % index)
                return self[index + 1]
            if self.transform:
                img = self.transform(img)
            # generate masked_id masked_character remain_string
            label_res, label_sub, label_id = des_orderlabel(label)
            sample = {'image': img, 'label': label, 'label_res': label_res, 'label_sub': label_sub, 'label_id': label_id}
            return sample

3、Vision-LAN模型

class MLM(nn.Layer):
    def __init__(self, n_dim=512):
        super(MLM, self).__init__()
        self.MLM_SequenceModeling_mask = Transforme_Encoder(n_layers=2, n_position=256)
        self.MLM_SequenceModeling_WCL = Transforme_Encoder(n_layers=1, n_position=256)
        self.pos_embedding = nn.Embedding(25, 512)
        self.w0_linear = nn.Linear(1, 256)
        self.wv = nn.Linear(n_dim, n_dim)
        self.active = nn.Tanh()
        self.we = nn.Linear(n_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input, label_pos, state=False):
        # transformer unit for generating mask_c
        feature_v_seq = self.MLM_SequenceModeling_mask(input, src_mask=None)[0]
        # position embedding layer
        pos_emb = self.pos_embedding(label_pos.cast('int64'))
        pos_emb = self.w0_linear(paddle.unsqueeze(pos_emb, axis=2)).transpose([0, 2, 1])
        # fusion position embedding with features V & generate mask_c
        att_map_sub = self.active(pos_emb + self.wv(feature_v_seq))
        att_map_sub = self.we(att_map_sub)  # b,256,1
        att_map_sub = self.sigmoid(att_map_sub.transpose([0, 2, 1]))  # b,1,256
        # WCL
        ## generate inputs for WCL
        f_res = input * (1 - att_map_sub.transpose([0, 2, 1]))  # second path with remaining string
        f_sub = input * (att_map_sub.transpose([0, 2, 1]))  # first path with occluded character
        ## transformer units in WCL
        f_res = self.MLM_SequenceModeling_WCL(f_res, src_mask=None)[0]
        f_sub = self.MLM_SequenceModeling_WCL(f_sub, src_mask=None)[0]
        return f_res, f_sub, att_map_sub


def trans_1d_2d(x):
    b, w_h, c = x.shape  # b, 256, 512
    x = x.transpose([0, 2, 1])
    x = x.reshape([b, c, 32, 8])
    x = x.transpose([0, 1, 3, 2])  # [16, 512, 8, 32]
    return x


class MLM_VRM(nn.Layer):
    def __init__(self, ):
        super(MLM_VRM, self).__init__()
        self.MLM = MLM()
        self.SequenceModeling = Transforme_Encoder(n_layers=3, n_position=256)
        self.Prediction = Prediction(n_position=256, N_max_character=26, n_class=37)  # N_max_character = 1 eos + 25 characters
        self.nclass = 37

    def forward(self, input, label_pos, training_stp, is_Train=False):
        b, c, h, w = input.shape
        nT = 25
        input = input.transpose([0, 1, 3, 2])
        input = input.reshape([b, c, h * w])
        input = input.transpose([0, 2, 1])
        if is_Train:
            if training_stp == 'LF_1':
                f_res = 0
                f_sub = 0
                input = self.SequenceModeling(input, src_mask=None)[0]
                text_pre, test_rem, text_mas = self.Prediction(input, f_res, f_sub, Train_is=True, use_mlm=False)
                return text_pre, text_pre, text_pre, text_pre
            elif training_stp == 'LF_2':
                # MLM
                f_res, f_sub, mask_c = self.MLM(input, label_pos, state=True)
                input = self.SequenceModeling(input, src_mask=None)[0]
                text_pre, test_rem, text_mas = self.Prediction(input, f_res, f_sub, Train_is=True)
                mask_c_show = trans_1d_2d(mask_c.transpose([0, 2, 1]))
                return text_pre, test_rem, text_mas, mask_c_show
            elif training_stp == 'LA':
                # MLM
                f_res, f_sub, mask_c = self.MLM(input, label_pos, state=True)
                ## use the mask_c (1 for occluded character and 0 for remaining characters) to occlude input
                ## ratio controls the occluded number in a batch
                ratio = 2
                character_mask = paddle.zeros_like(mask_c)
                character_mask = paddle.concat([mask_c[0:b // ratio], character_mask], 0)[:mask_c.shape[0]]
                input = input * (1 - character_mask.transpose([0, 2, 1]))
                # VRM
                ## transformer unit for VRM
                input = self.SequenceModeling(input, src_mask=None)[0]
                ## prediction layer for MLM and VSR
                text_pre, test_rem, text_mas = self.Prediction(input, f_res, f_sub, Train_is=True)
                mask_c_show = trans_1d_2d(mask_c.transpose([0, 2, 1]))
                return text_pre, test_rem, text_mas, mask_c_show
        else:  # VRM is only used in the testing stage
            f_res = 0
            f_sub = 0
            contextual_feature = self.SequenceModeling(input, src_mask=None)[0]
            C = self.Prediction(contextual_feature, f_res, f_sub, Train_is=False, use_mlm=False)
            C = C.transpose([1, 0, 2])  # (25, b, 38))
            lenText = nT
            nsteps = nT
            out_res = paddle.zeros([lenText, b, self.nclass])

            out_length = paddle.zeros([b])
            now_step = 0
            while 0 in out_length and now_step < nsteps:
                tmp_result = C[now_step, :, :]
                out_res[now_step] = tmp_result
                tmp_result = tmp_result.topk(1)[1].squeeze(axis=1)
                for j in range(b):
                    if out_length[j] == 0 and tmp_result[j] == 0:
                        out_length[j] = now_step + 1
                now_step += 1
            for j in range(0, b):
                if int(out_length[j]) == 0:
                    out_length[j] = nsteps
            start = 0
            output = paddle.zeros([int(out_length.sum()), self.nclass])
            for i in range(0, b):
                cur_length = int(out_length[i])
                output[start: start + cur_length] = out_res[0: cur_length, i, :]
                start += cur_length

            return output, out_length


class VisionLAN(nn.Layer):
    def __init__(self, strides, input_shape):
        super(VisionLAN, self).__init__()
        self.backbone = resnet.resnet45(strides)
        self.input_shape = input_shape
        self.MLM_VRM = MLM_VRM()

    def forward(self, input, label_pos, training_stp, Train_in=True):
        # extract features
        features = self.backbone(input)
        s = features[-1]
        # MLM + VRM
        if Train_in:
            text_pre, test_rem, text_mas, mask_map = self.MLM_VRM(features[-1], label_pos, training_stp, is_Train=Train_in)
            return text_pre, test_rem, text_mas, mask_map
        else:
            output, out_length = self.MLM_VRM(features[-1], label_pos, training_stp, is_Train=Train_in)
            return output, out_length

4、训练核心代码

 for nEpoch in range(0, cfgs.global_cfgs['epoch']):
        for batch_idx, sample_batched in enumerate(train_loader):
            # data_prepare
            data = sample_batched['image']
            label = sample_batched['label']  # original string
            label_res = sample_batched['label_res']  # remaining string
            label_sub = sample_batched['label_sub']  # occluded character
            label_id = sample_batched['label_id']  # character index
            target = encdec.encode(label)
            Train_or_Eval(model, 'Train')
            label_flatten, length = flatten_label(target)
            # prediction
            text_pre, text_rem, text_mas, att_mask_sub = model(data, label_id, cfgs.global_cfgs['step'])
            # loss_calculation
            if cfgs.global_cfgs['step'] == 'LF_1':
                text_pre = _flatten(text_pre, length)
                pre_ori, label_ori = train_acc_counter.add_iter(text_pre, length.cast('int64'), length, label) 
                loss_ori = criterion_CE(text_pre, label_flatten)
                loss = loss_ori
            else:
                target_res = encdec.encode(label_res)
                target_sub = encdec.encode(label_sub)
                label_flatten_res, length_res = flatten_label(target_res)
                label_flatten_sub, length_sub = flatten_label(target_sub)
                text_pre = _flatten(text_pre, length)
                text_rem = _flatten(text_rem, length_res)
                text_mas = _flatten(text_mas, length_sub)
                pre_ori, label_ori = train_acc_counter.add_iter(text_pre, length.cast('int64'), length, label)
                pre_rem, label_rem = train_acc_counter_rem.add_iter(text_rem, length_res.cast('int64'), length_res, label_res)
                pre_sub, label_sub = train_acc_counter_sub.add_iter(text_mas, length_sub.cast('int64'), length_sub, label_sub)

                loss_ori = criterion_CE(text_pre, label_flatten)
                loss_res = criterion_CE(text_rem, label_flatten_res)
                loss_mas = criterion_CE(text_mas, label_flatten_sub)
                loss = loss_ori + loss_res * ratio_res + loss_mas * ratio_sub
                loss_ori_show += loss_res
                loss_mas_show += loss_mas
            # loss for display
            loss_show += loss
            # optimize
            Zero_Grad(model)
            loss.backward()
            optimizer.step()

5、评估核心代码

def _test(test_loader, model, tools, best_acc, string_name):
    Train_or_Eval(model, 'Eval')
    print('------' + string_name + '--------')
    for sample_batched in test_loader:
        data = sample_batched['image']
        label = sample_batched['label']
        target = tools[0].encode(label)
        label_flatten, length = tools[1](target)
        output, out_length = model(data, target, '', False)
        tools[2].add_iter(output, out_length, length, label)
    best_acc, change = tools[2].show_test(best_acc)
    Train_or_Eval(model, 'Train')
    return best_acc, change
    
def show_test(self, best_acc, change=False):
    print(self.display_string)
    if self.total_samples == 0:
        pass
    if (self.correct / self.total_samples) >= best_acc:
        best_acc = np.copy(self.correct / self.total_samples)
        change = True
    print('Accuracy: {:.6f}, AR: {:.6f}, CER: {:.6f}, WER: {:.6f}, best_acc: {:.6f}'.format(
        self.correct / self.total_samples,
        1 - self.distance_C / self.total_C,
        self.distance_C / self.total_C,
        self.distance_W / self.total_W, best_acc))

    self.clear()
    return best_acc, change

七、复现总结

作为第一个带有Language能力的Visual模型工作,论文提出了一种简洁有效的场景文本识别框架。Vision-LAN实现了从两步识别到一步识别(从二到一)的转变,在一个统一的结构中自适应地考虑视觉和语言信息,无需额外的语言模型。相比于之前的语言模型,VisionLAN在保持高效的同时展现出更强的语言能力。

此文章为搬运
原项目链接

Logo

学大模型,用大模型上飞桨星河社区!每天8点V100G算力免费领!免费领取ERNIE 4.0 100w Token >>>

更多推荐