我的编程空间,编程开发者的网络收藏夹
学习永远不晚

怎么在Pytorch中利用WGAN生成动漫头像

短信预约 -IT技能 免费直播动态提醒
省份

北京

  • 北京
  • 上海
  • 天津
  • 重庆
  • 河北
  • 山东
  • 辽宁
  • 黑龙江
  • 吉林
  • 甘肃
  • 青海
  • 河南
  • 江苏
  • 湖北
  • 湖南
  • 江西
  • 浙江
  • 广东
  • 云南
  • 福建
  • 海南
  • 山西
  • 四川
  • 陕西
  • 贵州
  • 安徽
  • 广西
  • 内蒙
  • 西藏
  • 新疆
  • 宁夏
  • 兵团
手机号立即预约

请填写图片验证码后获取短信验证码

看不清楚,换张图片

免费获取短信验证码

怎么在Pytorch中利用WGAN生成动漫头像

本篇文章为大家展示了怎么在Pytorch中利用WGAN生成动漫头像,内容简明扼要并且容易理解,绝对能使你眼前一亮,通过这篇文章的详细介绍希望你能有所收获。

WGAN与GAN的不同

  • 去除sigmoid

  • 使用具有动量的优化方法,比如使用RMSProp

  • 要对Discriminator的权重做修整限制以确保lipschitz连续约

WGAN实战卷积生成动漫头像 

import torchimport torch.nn as nnimport torchvision.transforms as transformsfrom torch.utils.data import DataLoaderfrom torchvision.utils import save_imageimport osfrom anime_face_generator.dataset import ImageDataset batch_size = 32num_epoch = 100z_dimension = 100dir_path = './wgan_img' # 创建文件夹if not os.path.exists(dir_path):  os.mkdir(dir_path)  def to_img(x):  """因为我们在生成器里面用了tanh"""  out = 0.5 * (x + 1)  return out  dataset = ImageDataset()dataloader = DataLoader(dataset, batch_size=32, shuffle=False)  class Generator(nn.Module):  def __init__(self):    super().__init__()     self.gen = nn.Sequential(      # 输入是一个nz维度的噪声,我们可以认为它是一个1*1*nz的feature map      nn.ConvTranspose2d(100, 512, 4, 1, 0, bias=False),      nn.BatchNorm2d(512),      nn.ReLU(True),      # 上一步的输出形状:(512) x 4 x 4      nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),      nn.BatchNorm2d(256),      nn.ReLU(True),      # 上一步的输出形状: (256) x 8 x 8      nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),      nn.BatchNorm2d(128),      nn.ReLU(True),      # 上一步的输出形状: (256) x 16 x 16      nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),      nn.BatchNorm2d(64),      nn.ReLU(True),      # 上一步的输出形状:(256) x 32 x 32      nn.ConvTranspose2d(64, 3, 5, 3, 1, bias=False),      nn.Tanh() # 输出范围 -1~1 故而采用Tanh      # nn.Sigmoid()      # 输出形状:3 x 96 x 96    )   def forward(self, x):    x = self.gen(x)    return x   def weight_init(m):    # weight_initialization: important for wgan    class_name = m.__class__.__name__    if class_name.find('Conv') != -1:      m.weight.data.normal_(0, 0.02)    elif class_name.find('Norm') != -1:      m.weight.data.normal_(1.0, 0.02)  class Discriminator(nn.Module):  def __init__(self):    super().__init__()    self.dis = nn.Sequential(      nn.Conv2d(3, 64, 5, 3, 1, bias=False),      nn.LeakyReLU(0.2, inplace=True),      # 输出 (64) x 32 x 32       nn.Conv2d(64, 128, 4, 2, 1, bias=False),      nn.BatchNorm2d(128),      nn.LeakyReLU(0.2, inplace=True),      # 输出 (128) x 16 x 16       nn.Conv2d(128, 256, 4, 2, 1, bias=False),      nn.BatchNorm2d(256),      nn.LeakyReLU(0.2, inplace=True),      # 输出 (256) x 8 x 8       nn.Conv2d(256, 512, 4, 2, 1, bias=False),      nn.BatchNorm2d(512),      nn.LeakyReLU(0.2, inplace=True),      # 输出 (512) x 4 x 4       nn.Conv2d(512, 1, 4, 1, 0, bias=False),      nn.Flatten(),      # nn.Sigmoid() # 输出一个数(概率)    )   def forward(self, x):    x = self.dis(x)    return x   def weight_init(m):    # weight_initialization: important for wgan    class_name = m.__class__.__name__    if class_name.find('Conv') != -1:      m.weight.data.normal_(0, 0.02)    elif class_name.find('Norm') != -1:      m.weight.data.normal_(1.0, 0.02)  def save(model, filename="model.pt", out_dir="out/"):  if model is not None:    if not os.path.exists(out_dir):      os.mkdir(out_dir)    torch.save({'model': model.state_dict()}, out_dir + filename)  else:    print("[ERROR]:Please build a model!!!")  import QuickModelBuilder as builder if __name__ == '__main__':  one = torch.FloatTensor([1]).cuda()  mone = -1 * one   is_print = True  # 创建对象  D = Discriminator()  G = Generator()  D.weight_init()  G.weight_init()   if torch.cuda.is_available():    D = D.cuda()    G = G.cuda()   lr = 2e-4  d_optimizer = torch.optim.RMSprop(D.parameters(), lr=lr, )  g_optimizer = torch.optim.RMSprop(G.parameters(), lr=lr, )  d_scheduler = torch.optim.lr_scheduler.ExponentialLR(d_optimizer, gamma=0.99)  g_scheduler = torch.optim.lr_scheduler.ExponentialLR(g_optimizer, gamma=0.99)   fake_img = None   # ##########################进入训练##判别器的判断过程#####################  for epoch in range(num_epoch): # 进行多个epoch的训练    pbar = builder.MyTqdm(epoch=epoch, maxval=len(dataloader))    for i, img in enumerate(dataloader):      num_img = img.size(0)      real_img = img.cuda() # 将tensor变成Variable放入计算图中      # 这里的优化器是D的优化器      for param in D.parameters():        param.requires_grad = True      # ########判别器训练train#####################      # 分为两部分:1、真的图像判别为真;2、假的图像判别为假       # 计算真实图片的损失      d_optimizer.zero_grad() # 在反向传播之前,先将梯度归0      real_out = D(real_img) # 将真实图片放入判别器中      d_loss_real = real_out.mean(0).view(1)      d_loss_real.backward(one)       # 计算生成图片的损失      z = torch.randn(num_img, z_dimension).cuda() # 随机生成一些噪声      z = z.reshape(num_img, z_dimension, 1, 1)      fake_img = G(z).detach() # 随机噪声放入生成网络中,生成一张假的图片。 # 避免梯度传到G,因为G不用更新, detach分离      fake_out = D(fake_img) # 判别器判断假的图片,      d_loss_fake = fake_out.mean(0).view(1)      d_loss_fake.backward(mone)       d_loss = d_loss_fake - d_loss_real      d_optimizer.step() # 更新参数       # 每次更新判别器的参数之后把它们的绝对值截断到不超过一个固定常数c=0.01      for parm in D.parameters():        parm.data.clamp_(-0.01, 0.01)       # ==================训练生成器============================      # ###############################生成网络的训练###############################      for param in D.parameters():        param.requires_grad = False       # 这里的优化器是G的优化器,所以不需要冻结D的梯度,因为不是D的优化器,不会更新D      g_optimizer.zero_grad() # 梯度归0       z = torch.randn(num_img, z_dimension).cuda()      z = z.reshape(num_img, z_dimension, 1, 1)      fake_img = G(z) # 随机噪声输入到生成器中,得到一副假的图片      output = D(fake_img) # 经过判别器得到的结果      # g_loss = criterion(output, real_label) # 得到的假的图片与真实的图片的label的loss      g_loss = torch.mean(output).view(1)      # bp and optimize      g_loss.backward(one) # 进行反向传播      g_optimizer.step() # .step()一般用在反向传播后面,用于更新生成网络的参数       # 打印中间的损失      pbar.set_right_info(d_loss=d_loss.data.item(),                g_loss=g_loss.data.item(),                real_scores=real_out.data.mean().item(),                fake_scores=fake_out.data.mean().item(),                )      pbar.update()      try:        fake_images = to_img(fake_img.cpu())        save_image(fake_images, dir_path + '/fake_images-{}.png'.format(epoch + 1))      except:        pass      if is_print:        is_print = False        real_images = to_img(real_img.cpu())        save_image(real_images, dir_path + '/real_images.png')    pbar.finish()    d_scheduler.step()    g_scheduler.step()    save(D, "wgan_D.pt")    save(G, "wgan_G.pt")

上述内容就是怎么在Pytorch中利用WGAN生成动漫头像,你们学到知识或技能了吗?如果还想学到更多技能或者丰富自己的知识储备,欢迎关注编程网行业资讯频道。

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

怎么在Pytorch中利用WGAN生成动漫头像

下载Word文档到电脑,方便收藏和打印~

下载Word文档

猜你喜欢

怎么在Pytorch中利用WGAN生成动漫头像

本篇文章为大家展示了怎么在Pytorch中利用WGAN生成动漫头像,内容简明扼要并且容易理解,绝对能使你眼前一亮,通过这篇文章的详细介绍希望你能有所收获。WGAN与GAN的不同去除sigmoid使用具有动量的优化方法,比如使用RMSProp
2023-06-06

怎么利用Python实现一键将头像转成动漫风

本篇内容主要讲解“怎么利用Python实现一键将头像转成动漫风”,感兴趣的朋友不妨来看看。本文介绍的方法操作简单快捷,实用性强。下面就让小编来带大家学习“怎么利用Python实现一键将头像转成动漫风”吧!PyQt5框架用Python编程语言
2023-07-02

怎么利用Python编写一个藏头诗在线生成器

这篇文章主要介绍了怎么利用Python编写一个藏头诗在线生成器的相关知识,内容详细易懂,操作简单快捷,具有一定借鉴价值,相信大家阅读完这篇怎么利用Python编写一个藏头诗在线生成器文章都会有所收获,下面我们一起来看看吧。一、藏头诗(“小浪
2023-06-30

怎么在树莓派中利用mjpg-streamer调用摄像头

本篇文章给大家分享的是有关怎么在树莓派中利用mjpg-streamer调用摄像头,小编觉得挺实用的,因此分享给大家学习,希望大家阅读完这篇文章后可以有所收获,话不多说,跟着小编一起来看看吧。1.更新软件:sudo apt-get updat
2023-06-06

C++中怎么利用Test自动生成函数

C++中怎么利用Test自动生成函数,针对这个问题,这篇文章详细介绍了相对应的分析和解答,希望可以帮助更多想解决这个问题的小伙伴找到更简单易行的方法。静态测试 C++Test内嵌了业界最出名的Effective C++(epcc)、More
2023-06-17

怎么在Android中利用文字生成图片

这期内容当中小编将会给大家带来有关怎么在Android中利用文字生成图片,文章内容丰富且以专业的角度为大家分析和叙述,阅读完这篇文章希望大家可以有所收获。1.根据原图片的大小和字体的大小创建一张空白图片 2.把原图片按字体的大小分成若干块,
2023-05-30

怎么在python中利用scipy.stats生成随机数

这期内容当中小编将会给大家带来有关怎么在python中利用scipy.stats生成随机数,文章内容丰富且以专业的角度为大家分析和叙述,阅读完这篇文章希望大家可以有所收获。Python主要用来做什么Python主要应用于:1、Web开发;2
2023-06-14

怎么在python中利用choice生成随机数

这篇文章将为大家详细讲解有关怎么在python中利用choice生成随机数,文章内容质量较高,因此小编分享给大家做个参考,希望大家阅读完这篇文章后对相关知识有一定的了解。Python的优点有哪些1、简单易用,与C/C++、Java、C# 等
2023-06-14

怎么在java中利用反射生成对象

这期内容当中小编将会给大家带来有关怎么在java中利用反射生成对象,文章内容丰富且以专业的角度为大家分析和叙述,阅读完这篇文章希望大家可以有所收获。Java是什么Java是一门面向对象编程语言,可以编写桌面应用程序、Web应用程序、分布式系
2023-06-14

怎么在SpringBoot中利用Captcha生成验证码

本篇文章给大家分享的是有关怎么在SpringBoot中利用Captcha生成验证码,小编觉得挺实用的,因此分享给大家学习,希望大家阅读完这篇文章后可以有所收获,话不多说,跟着小编一起来看看吧。springboot是什么springboot一
2023-06-14

怎么在python中利用生成器实现协程

这篇文章给大家介绍怎么在python中利用生成器实现协程,内容非常详细,感兴趣的小伙伴们可以参考借鉴,希望对大家能有所帮助。python是什么意思Python是一种跨平台的、具有解释性、编译性、互动性和面向对象的脚本语言,其最初的设计是用于
2023-06-14

怎么利用反射生成MyBatisPlus中QueryWrapper动态条件

这篇文章主要介绍了怎么利用反射生成MyBatisPlus中QueryWrapper动态条件的相关知识,内容详细易懂,操作简单快捷,具有一定借鉴价值,相信大家阅读完这篇怎么利用反射生成MyBatisPlus中QueryWrapper动态条件文
2023-06-29

python中怎么使用Pillow做动态图在图中生成二维码及图像处理

这篇文章主要讲解了“python中怎么使用Pillow做动态图在图中生成二维码及图像处理”,文中的讲解内容简单清晰,易于学习与理解,下面请大家跟着小编的思路慢慢深入,一起来研究和学习“python中怎么使用Pillow做动态图在图中生成二维
2023-06-29

怎么在SpringBoot中使用Mybatis-Plus自动代码生成

本篇文章为大家展示了怎么在SpringBoot中使用Mybatis-Plus自动代码生成,内容简明扼要并且容易理解,绝对能使你眼前一亮,通过这篇文章的详细介绍希望你能有所收获。springboot是什么springboot一种全新的编程规范
2023-06-14

编程热搜

  • Python 学习之路 - Python
    一、安装Python34Windows在Python官网(https://www.python.org/downloads/)下载安装包并安装。Python的默认安装路径是:C:\Python34配置环境变量:【右键计算机】--》【属性】-
    Python 学习之路 - Python
  • chatgpt的中文全称是什么
    chatgpt的中文全称是生成型预训练变换模型。ChatGPT是什么ChatGPT是美国人工智能研究实验室OpenAI开发的一种全新聊天机器人模型,它能够通过学习和理解人类的语言来进行对话,还能根据聊天的上下文进行互动,并协助人类完成一系列
    chatgpt的中文全称是什么
  • C/C++中extern函数使用详解
  • C/C++可变参数的使用
    可变参数的使用方法远远不止以下几种,不过在C,C++中使用可变参数时要小心,在使用printf()等函数时传入的参数个数一定不能比前面的格式化字符串中的’%’符号个数少,否则会产生访问越界,运气不好的话还会导致程序崩溃
    C/C++可变参数的使用
  • css样式文件该放在哪里
  • php中数组下标必须是连续的吗
  • Python 3 教程
    Python 3 教程 Python 的 3.0 版本,常被称为 Python 3000,或简称 Py3k。相对于 Python 的早期版本,这是一个较大的升级。为了不带入过多的累赘,Python 3.0 在设计的时候没有考虑向下兼容。 Python
    Python 3 教程
  • Python pip包管理
    一、前言    在Python中, 安装第三方模块是通过 setuptools 这个工具完成的。 Python有两个封装了 setuptools的包管理工具: easy_install  和  pip , 目前官方推荐使用 pip。    
    Python pip包管理
  • ubuntu如何重新编译内核
  • 改善Java代码之慎用java动态编译

目录