我的编程空间,编程开发者的网络收藏夹
学习永远不晚

Pytorch怎么搭建SRGAN平台提升图片超分辨率

短信预约 -IT技能 免费直播动态提醒
省份

北京

  • 北京
  • 上海
  • 天津
  • 重庆
  • 河北
  • 山东
  • 辽宁
  • 黑龙江
  • 吉林
  • 甘肃
  • 青海
  • 河南
  • 江苏
  • 湖北
  • 湖南
  • 江西
  • 浙江
  • 广东
  • 云南
  • 福建
  • 海南
  • 山西
  • 四川
  • 陕西
  • 贵州
  • 安徽
  • 广西
  • 内蒙
  • 西藏
  • 新疆
  • 宁夏
  • 兵团
手机号立即预约

请填写图片验证码后获取短信验证码

看不清楚,换张图片

免费获取短信验证码

Pytorch怎么搭建SRGAN平台提升图片超分辨率

本篇内容介绍了“Pytorch怎么搭建SRGAN平台提升图片超分辨率”的有关知识,在实际案例的操作过程中,不少人都会遇到这样的困境,接下来就让小编带领大家学习一下如何处理这些情况吧!希望大家仔细阅读,能够学有所成!

网络构建

一、什么是SRGAN

SRGAN出自论文Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network。

如果将SRGAN看作一个黑匣子,其主要的功能就是输入一张低分辨率图片,生成高分辨率图片。

Pytorch怎么搭建SRGAN平台提升图片超分辨率


该文章提到,普通的超分辨率模型训练网络时只用到了均方差作为损失函数,虽然能够获得很高的峰值信噪比,但是恢复出来的图像通常会丢失高频细节。

SRGAN利用感知损失(perceptual loss)和对抗损失(adversarial loss)来提升恢复出的图片的真实感。

二、生成网络的构建

Pytorch怎么搭建SRGAN平台提升图片超分辨率


生成网络的构成如上图所示,生成网络的作用是输入一张低分辨率图片,生成高分辨率图片。:

SRGAN的生成网络由三个部分组成。

低分辨率图像进入后会经过一个卷积+RELU函数。

然后经过B个残差网络结构,每个残差结构都包含两个卷积+标准化+RELU,还有一个残差边。

然后进入上采样部分,在经过两次上采样后,原图的高宽变为原来的4倍,实现分辨率的提升。

前两个部分用于特征提取,第三部分用于提高分辨率。

import mathimport torchfrom torch import nnclass ResidualBlock(nn.Module):    def __init__(self, channels):        super(ResidualBlock, self).__init__()        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)        self.bn1 = nn.BatchNorm2d(channels)        self.prelu = nn.PReLU(channels)        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)        self.bn2 = nn.BatchNorm2d(channels)    def forward(self, x):        short_cut = x        x = self.conv1(x)        x = self.bn1(x)        x = self.prelu(x)        x = self.conv2(x)        x = self.bn2(x)        return x + short_cutclass UpsampleBLock(nn.Module):    def __init__(self, in_channels, up_scale):        super(UpsampleBLock, self).__init__()        self.conv = nn.Conv2d(in_channels, in_channels * up_scale ** 2, kernel_size=3, padding=1)        self.pixel_shuffle = nn.PixelShuffle(up_scale)        self.prelu = nn.PReLU(in_channels)    def forward(self, x):        x = self.conv(x)        x = self.pixel_shuffle(x)        x = self.prelu(x)        return xclass Generator(nn.Module):    def __init__(self, scale_factor, num_residual=16):        upsample_block_num = int(math.log(scale_factor, 2))        super(Generator, self).__init__()        self.block_in = nn.Sequential(            nn.Conv2d(3, 64, kernel_size=9, padding=4),            nn.PReLU(64)        )        self.blocks = []        for _ in range(num_residual):            self.blocks.append(ResidualBlock(64))        self.blocks = nn.Sequential(*self.blocks)        self.block_out = nn.Sequential(            nn.Conv2d(64, 64, kernel_size=3, padding=1),            nn.BatchNorm2d(64)        )        self.upsample = [UpsampleBLock(64, 2) for _ in range(upsample_block_num)]        self.upsample.append(nn.Conv2d(64, 3, kernel_size=9, padding=4))        self.upsample = nn.Sequential(*self.upsample)    def forward(self, x):        x = self.block_in(x)        short_cut = x        x = self.blocks(x)        x = self.block_out(x)        upsample = self.upsample(x + short_cut)        return torch.tanh(upsample)

三、判别网络的构建

Pytorch怎么搭建SRGAN平台提升图片超分辨率


判别网络的构成如上图所示:

SRGAN的判别网络由不断重复的 卷积+LeakyRELU和标准化 组成。
对于判断网络来讲,它的目的是判断输入图片的真假,它的输入是图片,输出是判断结果。

判断结果处于0-1之间,利用接近1代表判断为真图片,接近0代表判断为假图片。

判断网络的构建和普通卷积网络差距不大,都是不断的卷积对图片进行下采用,在多次卷积后,最终接一次全连接判断结果。

实现代码如下:

class Discriminator(nn.Module):    def __init__(self):        super(Discriminator, self).__init__()        self.net = nn.Sequential(            nn.Conv2d(3, 64, kernel_size=3, padding=1),            nn.LeakyReLU(0.2),            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),            nn.BatchNorm2d(64),            nn.LeakyReLU(0.2),            nn.Conv2d(64, 128, kernel_size=3, padding=1),            nn.BatchNorm2d(128),            nn.LeakyReLU(0.2),            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1),            nn.BatchNorm2d(128),            nn.LeakyReLU(0.2),            nn.Conv2d(128, 256, kernel_size=3, padding=1),            nn.BatchNorm2d(256),            nn.LeakyReLU(0.2),            nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1),            nn.BatchNorm2d(256),            nn.LeakyReLU(0.2),            nn.Conv2d(256, 512, kernel_size=3, padding=1),            nn.BatchNorm2d(512),            nn.LeakyReLU(0.2),            nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),            nn.BatchNorm2d(512),            nn.LeakyReLU(0.2),            nn.AdaptiveAvgPool2d(1),            nn.Conv2d(512, 1024, kernel_size=1),            nn.LeakyReLU(0.2),            nn.Conv2d(1024, 1, kernel_size=1)        )    def forward(self, x):        batch_size = x.size(0)        return torch.sigmoid(self.net(x).view(batch_size))

训练思路

SRGAN的训练可以分为生成器训练和判别器训练:
每一个step中一般先训练判别器,然后训练生成器。

一、判别器的训练

在训练判别器的时候我们希望判别器可以判断输入图片的真伪,因此我们的输入就是真图片、假图片和它们对应的标签。

因此判别器的训练步骤如下:

随机选取batch_size个真实高分辨率图片。
 

利用resize后的低分辨率图片,传入到Generator中生成batch_size个虚假高分辨率图片。
 

真实图片的label为1,虚假图片的label为0,将真实图片和虚假图片当作训练集传入到Discriminator中进行训练。

Pytorch怎么搭建SRGAN平台提升图片超分辨率

二、生成器的训练

在训练生成器的时候我们希望生成器可以生成极为真实的假图片。因此我们在训练生成器需要知道判别器认为什么图片是真图片。

因此生成器的训练步骤如下:

将低分辨率图像传入生成模型,得到虚假高分辨率图像,将虚假高分辨率图像获得判别结果与1进行对比得到loss。(与1对比的意思是,让生成器根据判别器判别的结果进行训练)。
 

将真实高分辨率图像和虚假高分辨率图像传入VGG网络,获得两个图像的特征,通过这两个图像的特征进行比较获得loss

Pytorch怎么搭建SRGAN平台提升图片超分辨率

利用SRGAN生成图片

SRGAN的库整体结构如下:

Pytorch怎么搭建SRGAN平台提升图片超分辨率

一、数据集的准备

在训练前需要准备好数据集,数据集保存在datasets文件夹里面。

Pytorch怎么搭建SRGAN平台提升图片超分辨率

二、数据集的处理

打开txt_annotation.py,默认指向根目录下的datasets。运行txt_annotation.py。
此时生成根目录下面的train_lines.txt。

Pytorch怎么搭建SRGAN平台提升图片超分辨率

三、模型训练

在完成数据集处理后,运行train.py即可开始训练。

Pytorch怎么搭建SRGAN平台提升图片超分辨率


训练过程中,可在results文件夹内查看训练效果:

Pytorch怎么搭建SRGAN平台提升图片超分辨率

“Pytorch怎么搭建SRGAN平台提升图片超分辨率”的内容就介绍到这里了,感谢大家的阅读。如果想了解更多行业相关的知识可以关注编程网网站,小编将为大家输出更多高质量的实用文章!

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

Pytorch怎么搭建SRGAN平台提升图片超分辨率

下载Word文档到电脑,方便收藏和打印~

下载Word文档

猜你喜欢

Pytorch怎么搭建SRGAN平台提升图片超分辨率

本篇内容介绍了“Pytorch怎么搭建SRGAN平台提升图片超分辨率”的有关知识,在实际案例的操作过程中,不少人都会遇到这样的困境,接下来就让小编带领大家学习一下如何处理这些情况吧!希望大家仔细阅读,能够学有所成!网络构建一、什么是SRGA
2023-06-30

编程热搜

  • Python 学习之路 - Python
    一、安装Python34Windows在Python官网(https://www.python.org/downloads/)下载安装包并安装。Python的默认安装路径是:C:\Python34配置环境变量:【右键计算机】--》【属性】-
    Python 学习之路 - Python
  • chatgpt的中文全称是什么
    chatgpt的中文全称是生成型预训练变换模型。ChatGPT是什么ChatGPT是美国人工智能研究实验室OpenAI开发的一种全新聊天机器人模型,它能够通过学习和理解人类的语言来进行对话,还能根据聊天的上下文进行互动,并协助人类完成一系列
    chatgpt的中文全称是什么
  • C/C++中extern函数使用详解
  • C/C++可变参数的使用
    可变参数的使用方法远远不止以下几种,不过在C,C++中使用可变参数时要小心,在使用printf()等函数时传入的参数个数一定不能比前面的格式化字符串中的’%’符号个数少,否则会产生访问越界,运气不好的话还会导致程序崩溃
    C/C++可变参数的使用
  • css样式文件该放在哪里
  • php中数组下标必须是连续的吗
  • Python 3 教程
    Python 3 教程 Python 的 3.0 版本,常被称为 Python 3000,或简称 Py3k。相对于 Python 的早期版本,这是一个较大的升级。为了不带入过多的累赘,Python 3.0 在设计的时候没有考虑向下兼容。 Python
    Python 3 教程
  • Python pip包管理
    一、前言    在Python中, 安装第三方模块是通过 setuptools 这个工具完成的。 Python有两个封装了 setuptools的包管理工具: easy_install  和  pip , 目前官方推荐使用 pip。    
    Python pip包管理
  • ubuntu如何重新编译内核
  • 改善Java代码之慎用java动态编译

目录