1. 引入

  • 随着 Vision Transformer 的流行,越来越多计算机视觉任务开始采用 Transformer 来实现,也都取得了不错的效果

  • 本次就来分享两个基于 Swin Transformer 的图像恢复 / 图像超分辨率模型

2. 效果演示

原图SwinIR-M 2倍
原图SwinIR-M 4倍
原图SwinIR-L 4倍
原图Swin2SR 4倍

3. 快速体验

  • 可以使用 PaddleHub 快速调用这两个图像超分辨率预训练模型

  • 可用的模型列表如下:

    模型名称网络结构放大倍率
    swin2sr_real_sr_x4Swin2SR4
    swinir_m_real_sr_x2SwinIR-M2
    swinir_m_real_sr_x4SwinIR-M4
    swinir_l_real_sr_x4SwinIR-L4
  • 注:随着输入图像的分辨率和放大倍率的提高,运行所需的显存 / 内存占用会随之增大。

# hub run 模型名称 --input_path 输入图片路径 --output_dir 输出文件目录
!hub run swin2sr_real_sr_x4 --input_path images/test.jpeg --output_dir outputs
import paddlehub as hub

# 加载模型
module = hub.Module(name="swin2sr_real_sr_x4")

# 图像超分辨率
result = module.real_sr(
    image='images/test.jpeg', # 输入图片路径
    visualization=True, # 是否可视化
    output_dir='outputs' # 输出文件目录
)

4. 参考资料

5. 模型结构

  • 两个模型的大致结构如下图所示:

    SwinIR

    Swin2SR
  • 总体结构与绝大多数图像恢复或图像超分辨率模型区别不是很大,只是将 Swin Transformer 融入进模型的特征提取与转换网络中。

  • 更多具体的细节可以参考论文或其官方代码实现。

6. 代码实现

  • 由于模型代码长度太长,不太便于展示在代码块中,具体的代码请参考 swinsr 目录中的文件。
import cv2
import paddle

from swinsr import preprocess, postprocess
from swinsr import swin2sr_real_sr_x4, swinir_l_real_sr_x4, swinir_m_real_sr_x2, swinir_m_real_sr_x4


model_dict = {
    'swin2sr_real_sr_x4': {
        'model': swin2sr_real_sr_x4,
        'ckpt': 'swinsr/pretrained_models/Swin2SR_RealworldSR_X4_64_BSRGAN_PSNR.pdparams'
    },
    'swinir_l_real_sr_x4': {
        'model': swinir_l_real_sr_x4,
        'ckpt': 'swinsr/pretrained_models/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pdparams'
    },
    'swinir_m_real_sr_x4': {
        'model': swinir_m_real_sr_x4,
        'ckpt': 'swinsr/pretrained_models/003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x4_GAN.pdparams'
    },
    'swinir_m_real_sr_x2': {
        'model': swinir_m_real_sr_x2,
        'ckpt': 'swinsr/pretrained_models/003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x2_GAN.pdparams'
    }
}


class SwinSR:
    def __init__(self, name='swin2sr_real_sr_x4'):
        self.model = model_dict[name]['model'](ckpt=model_dict[name]['ckpt'])
        self.model.eval()
    
    def __call__(self, image_path, output_path):
        image = cv2.imread(image_path)

        with paddle.no_grad():
            img_input = preprocess(image)
            img_input = paddle.to_tensor(img_input[None, ...], dtype=paddle.float32)

            img_output = self.model(img_input)
            img_output = img_output.numpy()[0]
            img_output = postprocess(img_output)

        cv2.imwrite(output_path, img_output)
# 加载模型
model = SwinSR(name='swin2sr_real_sr_x4') # 模型名称

# 图像超分辨率
model(
    image_path='images/test.jpeg', # 输入图像路径
    output_path='outputs/swin2sr_real_sr_x4.jpg' # 输出图像路径
)

7. 小结

  • 简单介绍了两个基于 Swin Transformer 的图像恢复 / 图像超分辨率模型

此文章为搬运
原项目链接

Logo

学大模型,用大模型上飞桨星河社区!每天8点V100G算力免费领!免费领取ERNIE 4.0 100w Token >>>

更多推荐