基于 Swin Transformer 的图像超分辨率模型
简单介绍两个基于 Swin Transformer 的图像超分辨率模型,并加载官方预训练模型完成模型推理。
·
1. 引入
-
随着 Vision Transformer 的流行,越来越多计算机视觉任务开始采用 Transformer 来实现,也都取得了不错的效果
-
本次就来分享两个基于 Swin Transformer 的图像恢复 / 图像超分辨率模型
2. 效果演示
原图 | SwinIR-M 2倍 |
---|---|
原图 | SwinIR-M 4倍 |
---|---|
原图 | SwinIR-L 4倍 |
---|---|
原图 | Swin2SR 4倍 |
---|---|
3. 快速体验
-
可以使用 PaddleHub 快速调用这两个图像超分辨率预训练模型
-
可用的模型列表如下:
模型名称 网络结构 放大倍率 swin2sr_real_sr_x4 Swin2SR 4 swinir_m_real_sr_x2 SwinIR-M 2 swinir_m_real_sr_x4 SwinIR-M 4 swinir_l_real_sr_x4 SwinIR-L 4 -
注:随着输入图像的分辨率和放大倍率的提高,运行所需的显存 / 内存占用会随之增大。
# hub run 模型名称 --input_path 输入图片路径 --output_dir 输出文件目录
!hub run swin2sr_real_sr_x4 --input_path images/test.jpeg --output_dir outputs
import paddlehub as hub
# 加载模型
module = hub.Module(name="swin2sr_real_sr_x4")
# 图像超分辨率
result = module.real_sr(
image='images/test.jpeg', # 输入图片路径
visualization=True, # 是否可视化
output_dir='outputs' # 输出文件目录
)
4. 参考资料
-
论文:
-
参考实现:
5. 模型结构
-
两个模型的大致结构如下图所示:
SwinIR Swin2SR -
总体结构与绝大多数图像恢复或图像超分辨率模型区别不是很大,只是将 Swin Transformer 融入进模型的特征提取与转换网络中。
-
更多具体的细节可以参考论文或其官方代码实现。
6. 代码实现
- 由于模型代码长度太长,不太便于展示在代码块中,具体的代码请参考 swinsr 目录中的文件。
import cv2
import paddle
from swinsr import preprocess, postprocess
from swinsr import swin2sr_real_sr_x4, swinir_l_real_sr_x4, swinir_m_real_sr_x2, swinir_m_real_sr_x4
model_dict = {
'swin2sr_real_sr_x4': {
'model': swin2sr_real_sr_x4,
'ckpt': 'swinsr/pretrained_models/Swin2SR_RealworldSR_X4_64_BSRGAN_PSNR.pdparams'
},
'swinir_l_real_sr_x4': {
'model': swinir_l_real_sr_x4,
'ckpt': 'swinsr/pretrained_models/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pdparams'
},
'swinir_m_real_sr_x4': {
'model': swinir_m_real_sr_x4,
'ckpt': 'swinsr/pretrained_models/003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x4_GAN.pdparams'
},
'swinir_m_real_sr_x2': {
'model': swinir_m_real_sr_x2,
'ckpt': 'swinsr/pretrained_models/003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x2_GAN.pdparams'
}
}
class SwinSR:
def __init__(self, name='swin2sr_real_sr_x4'):
self.model = model_dict[name]['model'](ckpt=model_dict[name]['ckpt'])
self.model.eval()
def __call__(self, image_path, output_path):
image = cv2.imread(image_path)
with paddle.no_grad():
img_input = preprocess(image)
img_input = paddle.to_tensor(img_input[None, ...], dtype=paddle.float32)
img_output = self.model(img_input)
img_output = img_output.numpy()[0]
img_output = postprocess(img_output)
cv2.imwrite(output_path, img_output)
# 加载模型
model = SwinSR(name='swin2sr_real_sr_x4') # 模型名称
# 图像超分辨率
model(
image_path='images/test.jpeg', # 输入图像路径
output_path='outputs/swin2sr_real_sr_x4.jpg' # 输出图像路径
)
7. 小结
- 简单介绍了两个基于 Swin Transformer 的图像恢复 / 图像超分辨率模型
此文章为搬运
原项目链接
更多推荐
已为社区贡献1438条内容
所有评论(0)