1. Background

This article is part of reproducing the LIIF paper with PaddlePaddle. The encoder used in LIIF is RDN, so this post introduces RDN.

RDN paper: https://arxiv.org/abs/1802.08797

PyTorch code: https://github.com/yinboc/liif/blob/main/models/rdn.py

2. Components of RDN

[Figure: RDN network architecture]

2.1 Shallow Feature Extraction Network (SFENet)

These are simply the two convolutional layers at the very beginning of the network.
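
In the notation of the RDN paper (matching `SFENet1` / `SFENet2` in the code of Section 4.1), the two layers can be written as:

$$F_{-1} = H_{SFE1}(I_{LR}), \qquad F_{0} = H_{SFE2}(F_{-1})$$

where $F_{-1}$ is reused later for global residual learning and $F_{0}$ is fed to the first RDB.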

2.2 Residual Dense Block (RDB)

[Figure: residual dense block (RDB)]

2.3 Residual Dense Blocks (RDBs)

[Figure: RDB internals; the red lines mark the contiguous-memory connections]

The RDB extracts rich local features by densely connecting its convolutional layers.
An RDB also allows direct connections from the state of the preceding RDB to every layer of the current RDB, which forms the contiguous memory (CM) mechanism.

Residual dense block (RDB) = densely connected layers + local feature fusion (LFF) + local residual learning; together these form the contiguous memory (CM) mechanism.

Contiguous memory (CM)
The output of the (d-1)-th RDB is fed directly into every layer of the d-th RDB (shown by the red lines in the dense part of the figure above). Thanks to the dense connections, all of the features $F_{d-1}, F_{d,1}, \ldots, F_{d,c}, \ldots, F_{d,C}$ are exploited.
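
In the paper's notation, the output of the c-th convolutional layer inside the d-th RDB (a Conv + ReLU, matching `RDB_Conv` in Section 4.1, bias omitted) is:

$$F_{d,c} = \sigma\left(W_{d,c}\,[F_{d-1}, F_{d,1}, \ldots, F_{d,c-1}]\right)$$

where $\sigma$ is ReLU and $[\cdot]$ denotes channel-wise concatenation.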

Local feature fusion (LFF)

This is the concat inside the RDB: the output $F_{d-1}$ of the previous RDB and the states produced by every layer of the current RDB are fused by concatenating them. A 1 x 1 convolution is then applied to the concatenated result to reduce the number of channels and simplify the data.
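
In the paper's notation (corresponding to the `self.LFF` layer in Section 4.1):

$$F_{d,LF} = H_{LFF}^{d}\left([F_{d-1}, F_{d,1}, \ldots, F_{d,C}]\right)$$

where $H_{LFF}^{d}$ is the 1 x 1 convolution of the d-th RDB.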

Local residual learning (LRL)

Since an RDB contains multiple convolutional layers, local residual learning is introduced to further improve the information flow.
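
In the paper's notation, the final output of the d-th RDB is:

$$F_{d} = F_{d-1} + F_{d,LF}$$

which is exactly the `self.LFF(self.convs(x)) + x` in `RDB.forward` in Section 4.1.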

2.4 Dense Feature Fusion (DFF)

After local dense features have been extracted by a series of RDBs, dense feature fusion (DFF) is further proposed to exploit hierarchical features from a global point of view. DFF consists of two parts: global feature fusion (GFF) and global residual learning (GRL).
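
In the paper's notation, DFF takes the shallow features and the outputs of all RDBs as input:

$$F_{DF} = H_{DFF}(F_{-1}, F_{0}, F_{1}, \ldots, F_{D})$$

where $H_{DFF}$ is the composite function of GFF and GRL, described next.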

Global Feature Fusion (GFF)
[Figure 2: overall RDN architecture]

As shown in Figure 2 above, global feature fusion works as follows (see the formula after this list):

  • concatenate the outputs of all RDBs ($F_{1}, \ldots, F_{d}, \ldots, F_{D}$);
  • pass them through a 1 x 1 Conv layer, which adaptively fuses these features from different levels;
  • pass the result through a 3 x 3 Conv layer to further extract features, giving $F_{GF}$, which is used by the global residual learning (GRL) step below.
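
In the paper's notation (corresponding to `self.GFF` in Section 4.1):

$$F_{GF} = H_{GFF}\left([F_{1}, \ldots, F_{D}]\right)$$

where $[F_{1}, \ldots, F_{D}]$ is the concatenation of all RDB outputs and $H_{GFF}$ is the composite 1 x 1 and 3 x 3 convolution.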

Global Residual Learning (GRL)

Global residual learning, also shown in Figure 2 above, adds the shallow feature map $F_{-1}$ produced by the first Conv layer element-wise to the $F_{GF}$ obtained from global feature fusion, giving $F_{DF}$.
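
In the paper's notation:

$$F_{DF} = F_{-1} + F_{GF}$$

which corresponds to the `x += f__1` line in `RDN.forward` in Section 4.1.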

2.5 Up-sampling Network (UPNet)

This is simply an up-sampling plus convolution stage that produces the final HR output $I_{HR}$.
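
As a minimal sketch of the x4 branch (matching the r == 4 case of the implementation in Section 4.1), the up-sampling net stacks two Conv + PixelShuffle(2) stages followed by a final 3 x 3 convolution:

# Sketch of the x4 up-sampling net: each PixelShuffle(2) trades 4x channels for 2x spatial size
import paddle.nn as nn

G0, G, n_colors = 64, 64, 3              # channel sizes of config 'B' in Section 4.1
upnet_x4 = nn.Sequential(
    nn.Conv2D(G0, G * 4, 3, padding=1),  # G0 -> 4G channels
    nn.PixelShuffle(2),                  # 4G -> G channels, 2x height and width
    nn.Conv2D(G, G * 4, 3, padding=1),
    nn.PixelShuffle(2),                  # another 2x, so 4x overall
    nn.Conv2D(G, n_colors, 3, padding=1) # back to a 3-channel image
)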

3. Dataset

3.1 Dataset introduction

DIV2K is a popular single-image super-resolution dataset. It contains 1000 images of diverse scenes, split into 800 for training, 100 for validation, and 100 for testing. It was collected for the NTIRE 2017 and NTIRE 2018 super-resolution challenges to encourage research on image super-resolution with more realistic degradations, and it provides low-resolution images with several types of degradation.

Official DIV2K dataset page: https://data.vision.ee.ethz.ch/cvl/DIV2K/

This project uses the DIV2K copy already available among the AI Studio public datasets: https://aistudio.baidu.com/aistudio/datasetdetail/104667

# Unzip the dataset
!unzip -qo /home/aistudio/data/data104667/DIV2K_train_HR.zip -d /home/aistudio/DIV2K
!unzip -qo /home/aistudio/data/data104667/DIV2K_valid_HR.zip -d /home/aistudio/DIV2K

# Import packages
import os
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

import time
import warnings 
warnings.filterwarnings('ignore')
os.environ["CUDA_VISIBLE_DEVICES"]="0"


# Define constants
BATCHSIZE = 16          # patches per batch
SCALE = 4               # super-resolution scale factor
PATCHSIZE = [48, 48]    # LR patch size; HR patches are PATCHSIZE * SCALE

3.2 Writing the dataset reader

Read the image directory and process it into paired (LR, HR) samples.

# Define a generator that reads the dataset directory and yields batches of patches
def reader_patch(batchsize,scale=SCALE,patchsize=PATCHSIZE):

    dirsname = '/home/aistudio/DIV2K/DIV2K_train_HR/'
    dirs = os.listdir(dirsname)
    np.random.shuffle(dirs)
    LRs = np.zeros((batchsize,3,patchsize[0],patchsize[1])).astype("float32")
    HRs = np.zeros((batchsize,3,patchsize[0]*scale,patchsize[1]*scale)).astype("float32")
    for filename in dirs:
        image = Image.open(dirsname+filename)
        sz = image.size
        sz_row = sz[1]//(patchsize[0]*scale)*patchsize[0]*scale
        diff_row = sz[1] - sz_row
        sz_col = sz[0]//(patchsize[1]*scale)*patchsize[1]*scale
        diff_col = sz[0] - sz_col
        row_min = np.random.randint(diff_row+1)
        col_min = np.random.randint(diff_col+1)
        HR = image.crop((col_min,row_min,col_min+sz_col,row_min+sz_row))
        LR = HR.resize((sz[0]//(patchsize[1]*scale)*patchsize[1],sz[1]//(patchsize[0]*scale)*patchsize[0]), Image.BICUBIC)
        LR = np.array(LR).astype("float32").transpose([2,0,1]) / 255 * 2 - 1
        HR = np.array(HR).astype("float32").transpose([2,0,1]) / 255 * 2 - 1
        for batch in range(batchsize):
            rowMin, colMin = np.random.randint(0,LR.shape[1]-patchsize[0]+1), np.random.randint(0,LR.shape[2]-patchsize[1]+1)
            LRs[batch,:,:,:] = LR[:,rowMin:rowMin+patchsize[0], colMin:colMin+patchsize[1]]
            HRs[batch,:,:,:] = HR[:,scale*rowMin:scale*(rowMin+patchsize[0]), scale*colMin:scale*(colMin+patchsize[1])]
        yield LRs, HRs

# Randomly preview the LR and HR patches produced by the reader
data = reader_patch(1)
for i in range(2):
    LR, HR = next(data)
    LR = LR.transpose([2,3,1,0]).reshape(PATCHSIZE[0],PATCHSIZE[1],3)
    LR = Image.fromarray(np.uint8((LR+1)/2*255))
    HR = HR.transpose([2,3,1,0]).reshape(PATCHSIZE[0]*SCALE,PATCHSIZE[1]*SCALE,3)
    HR = Image.fromarray(np.uint8((HR+1)/2*255))
    plt.subplot(1,2,1), plt.imshow(LR),plt.title('LRx'+str(SCALE)) # LR is the HR patch downscaled by SCALE, i.e. the factor it must be upscaled by
    plt.subplot(1,2,2), plt.imshow(HR),plt.title('HR')
    plt.show()

[Output: two examples of LRx4 and HR patch pairs]

# Define the data-loading function
def load_data(mode='train',batchsize=BATCHSIZE,scale=SCALE,patchsize=PATCHSIZE):
    if mode=='train':
        dirsname = '/home/aistudio/DIV2K/DIV2K_train_HR/'
    elif mode=='valid':
        dirsname = '/home/aistudio/DIV2K/DIV2K_valid_HR/'
    dirs = os.listdir(dirsname)

    # Define the data generator
    def data_generator():
        # In training mode, shuffle the training data
        if mode == 'train':
            np.random.shuffle(dirs)
        
        LRs = np.zeros((batchsize,3,patchsize[0],patchsize[1])).astype("float32")
        HRs = np.zeros((batchsize,3,patchsize[0]*scale,patchsize[1]*scale)).astype("float32")
        for filename in dirs:
            # print(filename)
            image = Image.open(dirsname+filename)
            sz = image.size
            sz_row = sz[1]//(patchsize[0]*scale)*patchsize[0]*scale
            diff_row = sz[1] - sz_row
            sz_col = sz[0]//(patchsize[1]*scale)*patchsize[1]*scale
            diff_col = sz[0] - sz_col
            row_min = np.random.randint(diff_row+1)
            col_min = np.random.randint(diff_col+1)
            HR = image.crop((col_min,row_min,col_min+sz_col,row_min+sz_row))
            LR = HR.resize((sz[0]//(patchsize[1]*scale)*patchsize[1],sz[1]//(patchsize[0]*scale)*patchsize[0]), Image.BICUBIC)
            LR = np.array(LR).astype("float32").transpose([2,0,1]) / 255 * 2 - 1
            HR = np.array(HR).astype("float32").transpose([2,0,1]) / 255 * 2 - 1
            for batch in range(batchsize):
                rowMin, colMin = np.random.randint(0,LR.shape[1]-patchsize[0]+1), np.random.randint(0,LR.shape[2]-patchsize[1]+1)
                LRs[batch,:,:,:] = LR[:,rowMin:rowMin+patchsize[0], colMin:colMin+patchsize[1]]
                HRs[batch,:,:,:] = HR[:,scale*rowMin:scale*(rowMin+patchsize[0]), scale*colMin:scale*(colMin+patchsize[1])]
            yield LRs, HRs

    return data_generator
# Randomly load a batch and compare the LR and HR patches
train_loader = load_data('train',1)
for i in range(2):
    LR, HR = next(train_loader())
    LR = LR.transpose([2,3,1,0]).reshape(PATCHSIZE[0],PATCHSIZE[1],3)
    LR = Image.fromarray(np.uint8((LR+1)/2*255))
    HR = HR.transpose([2,3,1,0]).reshape(PATCHSIZE[0]*SCALE,PATCHSIZE[1]*SCALE,3)
    HR = Image.fromarray(np.uint8((HR+1)/2*255))
    plt.subplot(1,2,1), plt.imshow(LR),plt.title('LRx'+str(SCALE))
    plt.subplot(1,2,2), plt.imshow(HR),plt.title('HR')
    plt.show()

[Output: two examples of LRx4 and HR patch pairs]

4. Building the network

4.1 The RDN network model

# Convolutional layer used inside an RDB: Conv + ReLU, with the input concatenated to the output
class RDB_Conv(nn.Layer):
    def __init__(self, inChannels, growRate, kSize=3):
        super().__init__()
        Cin = inChannels
        G  = growRate
        self.conv = nn.Sequential(*[
            nn.Conv2D(Cin, G, kSize, padding=(kSize-1)//2, stride=1),
            nn.ReLU()
        ])

    def forward(self, x):
        out = self.conv(x)
        return paddle.concat([x, out], 1)

# Residual dense block (RDB)
class RDB(nn.Layer):
    def __init__(self, growRate0, growRate, nConvLayers, kSize=3):
        super().__init__()
        G0 = growRate0
        G  = growRate
        C  = nConvLayers
        convs = []
        for c in range(C):
            convs.append(RDB_Conv(G0 + c*G, G))
        self.convs = nn.Sequential(*convs)
        # Local Feature Fusion
        self.LFF = nn.Conv2D(G0 + C*G, G0, 1, padding=0, stride=1)

    def forward(self, x):
        return self.LFF(self.convs(x)) + x
# Define the RDN network
class RDN(nn.Layer):
    def __init__(self):
        super(RDN, self).__init__()

        # self.args = args
        # r = args.scale[0]
        # G0 = args.G0
        # kSize = args.RDNkSize

        r = SCALE
        G0 = 64
        kSize = 3
        n_colors = 3
        self.no_upsampling = False

        # number of RDB blocks, conv layers, out channels
        self.D, C, G = {
            'A': (20, 6, 32),
            'B': (16, 8, 64),
        }['B']
        # Shallow feature extraction net
        self.SFENet1 = nn.Conv2D(n_colors, G0, kSize, padding=(kSize-1)//2, stride=1)
        self.SFENet2 = nn.Conv2D(G0, G0, kSize, padding=(kSize-1)//2, stride=1)
        # Residual dense blocks and dense feature fusion
        self.RDBs = nn.LayerList()
        for i in range(self.D):
            self.RDBs.append(
                RDB(growRate0 = G0, growRate = G, nConvLayers = C)
            )
        # Global Feature Fusion
        self.GFF = nn.Sequential(*[
            nn.Conv2D(self.D * G0, G0, 1, padding=0, stride=1),
            nn.Conv2D(G0, G0, kSize, padding=(kSize-1)//2, stride=1)
        ])
        #
        if self.no_upsampling:
            self.out_dim = G0
        else:
            self.out_dim = n_colors
            # Up-sampling net
            if r == 2 or r == 3:
                self.UPNet = nn.Sequential(*[
                    nn.Conv2D(G0, G * r * r, kSize, padding=(kSize-1)//2, stride=1),
                    nn.PixelShuffle(r),
                    nn.Conv2D(G, n_colors, kSize, padding=(kSize-1)//2, stride=1)
                ])
            elif r == 4:
                self.UPNet = nn.Sequential(*[
                    nn.Conv2D(G0, G * 4, kSize, padding=(kSize-1)//2, stride=1),
                    nn.PixelShuffle(2),
                    nn.Conv2D(G, G * 4, kSize, padding=(kSize-1)//2, stride=1),
                    nn.PixelShuffle(2),
                    nn.Conv2D(G, n_colors, kSize, padding=(kSize-1)//2, stride=1)
                ])
            else:
                raise ValueError("scale must be 2 or 3 or 4.")

    def forward(self, x):
        f__1 = self.SFENet1(x)
        x  = self.SFENet2(f__1)
        #
        RDBs_out = []
        for i in range(self.D):
            x = self.RDBs[i](x)
            RDBs_out.append(x)
        #
        x = self.GFF(paddle.concat(RDBs_out,1))
        x += f__1
        #
        if self.no_upsampling:
            return x
        else:
            return self.UPNet(x)
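
As a quick sanity check (an illustrative snippet, not part of the original project), a random LR tensor can be passed through the model to confirm that the output is SCALE times larger:

# Hypothetical check: a 48x48 LR input should produce a 192x192 output when SCALE = 4
net = RDN()
lr = paddle.randn([1, 3, PATCHSIZE[0], PATCHSIZE[1]])
sr = net(lr)
print(sr.shape)   # expected: [1, 3, PATCHSIZE[0] * SCALE, PATCHSIZE[1] * SCALE]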

Define a function that displays the LR (low-resolution), HR (high-resolution), and SR (super-resolved) images.

# Function that displays LR / SR / HR side by side
def show_image(G=None):
    if G is None:
        G = RDN()
    G.eval()
    dirsname = '/home/aistudio/DIV2K/DIV2K_train_HR/'
    dirs = os.listdir(dirsname)
    np.random.shuffle(dirs)
    fig = plt.figure(figsize=(25, 25))
    gs = plt.GridSpec(1, 3)
    gs.update(wspace=0.1, hspace=0.1)
    image = Image.open(dirsname+dirs[0])
    # image = image.crop([0,0,image.size[0]//SCALE*SCALE,image.size[1]//SCALE*SCALE])
    # image = image.crop([0,0,100,100])
    image = image.crop([0,0,200,200])
    LR0 = image.resize((image.size[0]//SCALE,image.size[1]//SCALE),Image.BICUBIC)
    LR = np.array(LR0).astype('float32').reshape([image.size[1]//SCALE,image.size[0]//SCALE,3,1]).transpose([3,2,0,1]) / 255 * 2 - 1
    LSR = G(paddle.to_tensor(LR)).numpy()
    print(np.max(LSR), np.min(LSR))
    LSR = LSR.reshape([3,image.size[1]//SCALE*SCALE,image.size[0]//SCALE*SCALE]).transpose([1,2,0])
    # LSR = Image.fromarray(np.uint8((LSR+1)/2*255)) ### culprit of the bright-spot artifacts: values outside [-1, 1] wrap around when cast to uint8
    LSR = (LSR+1)/2

    ax = plt.subplot(gs[0])
    plt.imshow(LR0)
    plt.title('LR')
    ax = plt.subplot(gs[1])
    plt.imshow(LSR)
    plt.title('SR')
    ax = plt.subplot(gs[2])
    plt.imshow(image)
    plt.title('HR')
    plt.show()

4.2 模型训练

Define the data-augmentation function: a random horizontal flip combined with a random rotation by 0, 90, 180, or 270 degrees.

# Data augmentation: random horizontal flip plus a random 0/90/180/270-degree rotation
def data_augmentation(LR, HR): 
    if np.random.randint(2) == 1:
        LR = LR[:,:,:,::-1]
        HR = HR[:,:,:,::-1]
    n = np.random.randint(4)
    if n == 1:
        LR = LR.transpose([0,1,3,2])
        LR = LR[:,:,::-1,:]
        HR = HR.transpose([0,1,3,2])
        HR = HR[:,:,::-1,:]
    if n == 2:
        LR = LR[:,:,:,::-1]
        LR = LR[:,:,::-1,:]
        HR = HR[:,:,:,::-1]
        HR = HR[:,:,::-1,:]
    if n == 3:
        LR = LR.transpose([0,1,3,2])
        LR = LR[:,:,:,::-1]
        HR = HR.transpose([0,1,3,2])
        HR = HR[:,:,:,::-1]
    return LR, HR

Define the data loading and the training function.

from visualdl import LogWriter
log_writer = LogWriter(logdir="./output/RDN/log")

# Create the data loader
train_loader = load_data('train')
# LR, HR = next(train_loader())
# print(LR, HR )

def train(model,epoch_num=200,batchsize=1,load_model=False):
    # Note: the batchsize argument is not used here; the actual batch size is
    # fixed by BATCHSIZE when train_loader is created above.
    model.train()
    optimizer = paddle.optimizer.Adam(learning_rate=1e-4, beta1=0.9, parameters=model.parameters())
    model_path = './output/RDN/'
    if load_model:
        model.set_state_dict(paddle.load(model_path+'Model.pdparams'))

    iteration_num = 0
    iters=[]
    losses=[]
    for epoch_id in range(epoch_num):
        for batch_id, data in enumerate(train_loader()):
            iteration_num += 1             
            LR, HR =  data
            LR, HR = data_augmentation(LR, HR) # 数据增强
            LR = paddle.to_tensor(LR)
            HR = paddle.to_tensor(HR)
            y = model(LR) 
            loss = paddle.mean(paddle.abs(y - HR))

            # Print the current loss every 20 iterations
            if(iteration_num % 20 == 0):
                datetime = time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time()))
                print("{} epoch: {},batch_id: {}, iter: {}, loss is: {}".format(datetime,epoch_id,batch_id, iteration_num, loss.numpy())) 
                # Log the iteration count and the corresponding loss
                log_writer.add_scalar(tag = 'loss', step = iteration_num, value = loss.numpy())
            
            # Backward pass and parameter update
            loss.backward()
            optimizer.step()
            optimizer.clear_grad()
        
        # Save the model parameters
        paddle.save(model.state_dict(), model_path+'Model.pdparams')
        print('save model in {}'.format(model_path+'Model.pdparams'))
    


Start training

# train(epoch_num=1,  load_model=False, batchsize=12)   # first epoch: train and save the model, then remember to comment this out
# train(epoch_num=1000, load_model=True,  batchsize=16)  # later runs can load the saved model and continue training

# Launch the training process
model = RDN()
train(model=model,epoch_num=10, load_model=False,  batchsize=16)
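
Because the loss is written with VisualDL's LogWriter above, the loss curve can be viewed in a browser. Assuming VisualDL is installed, something like the following command (run in a terminal, pointing at the logdir used above, with a port of your choice) starts the dashboard:

# View the logged training loss curve
visualdl --logdir ./output/RDN/log --port 8040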

4.3 Model inference

# Load the trained RDN model
GG = RDN()
GG.eval()

GG.set_state_dict(paddle.load('./Best.pdparams'))

# Display the images: SR compared with LR and HR
show_image(GG)

0.6906135 -0.94815046

[Output: LR / SR / HR comparison]

5. Closing

Once you understand RDN, you can move on to LIIF. See the follow-up post 《超分辨率模型-LIIF,可放大30多倍》 (the super-resolution model LIIF, which can upscale by more than 30x).

If this helped you, please follow, like, and fork.

This article is a repost.
Original project link
