我的编程空间,编程开发者的网络收藏夹
学习永远不晚

怎么在Pytorch 中对TORCH.NN.INIT 参数进行初始化

短信预约 -IT技能 免费直播动态提醒
省份

北京

  • 北京
  • 上海
  • 天津
  • 重庆
  • 河北
  • 山东
  • 辽宁
  • 黑龙江
  • 吉林
  • 甘肃
  • 青海
  • 河南
  • 江苏
  • 湖北
  • 湖南
  • 江西
  • 浙江
  • 广东
  • 云南
  • 福建
  • 海南
  • 山西
  • 四川
  • 陕西
  • 贵州
  • 安徽
  • 广西
  • 内蒙
  • 西藏
  • 新疆
  • 宁夏
  • 兵团
手机号立即预约

请填写图片验证码后获取短信验证码

看不清楚,换张图片

免费获取短信验证码

怎么在Pytorch 中对TORCH.NN.INIT 参数进行初始化

怎么在Pytorch 中对TORCH.NN.INIT 参数进行初始化?很多新手对此不是很清楚,为了帮助大家解决这个难题,下面小编将为大家详细讲解,有这方面需求的人可以来学习下,希望你能有所收获。

初始化函数:torch.nn.init

# -*- coding: utf-8 -*-"""Created on 2019@author: fancp"""import torch import torch.nn as nnw = torch.empty(3,5)#1.均匀分布 - u(a,b)#torch.nn.init.uniform_(tensor, a=0.0, b=1.0)print(nn.init.uniform_(w))# =============================================================================# tensor([[0.9160, 0.1832, 0.5278, 0.5480, 0.6754],#     [0.9509, 0.8325, 0.9149, 0.8192, 0.9950],#     [0.4847, 0.4148, 0.8161, 0.0948, 0.3787]])# =============================================================================#2.正态分布 - N(mean, std)#torch.nn.init.normal_(tensor, mean=0.0, std=1.0)print(nn.init.normal_(w))# =============================================================================# tensor([[ 0.4388, 0.3083, -0.6803, -1.1476, -0.6084],#     [ 0.5148, -0.2876, -1.2222, 0.6990, -0.1595],#     [-2.0834, -1.6288, 0.5057, -0.5754, 0.3052]])# =============================================================================#3.常数 - 固定值 val#torch.nn.init.constant_(tensor, val)print(nn.init.constant_(w, 0.3))# =============================================================================# tensor([[0.3000, 0.3000, 0.3000, 0.3000, 0.3000],#     [0.3000, 0.3000, 0.3000, 0.3000, 0.3000],#     [0.3000, 0.3000, 0.3000, 0.3000, 0.3000]])# =============================================================================#4.全1分布#torch.nn.init.ones_(tensor)print(nn.init.ones_(w))# =============================================================================# tensor([[1., 1., 1., 1., 1.],#     [1., 1., 1., 1., 1.],#     [1., 1., 1., 1., 1.]])# =============================================================================#5.全0分布#torch.nn.init.zeros_(tensor)print(nn.init.zeros_(w))# =============================================================================# tensor([[0., 0., 0., 0., 0.],#     [0., 0., 0., 0., 0.],#     [0., 0., 0., 0., 0.]])# =============================================================================#6.对角线为 1,其它为 0#torch.nn.init.eye_(tensor)print(nn.init.eye_(w))# =============================================================================# tensor([[1., 0., 0., 0., 0.],#     [0., 1., 0., 0., 0.],#     [0., 0., 1., 0., 0.]])# =============================================================================#7.xavier_uniform 初始化#torch.nn.init.xavier_uniform_(tensor, gain=1.0)#From - Understanding the difficulty of training deep feedforward neural networks - Bengio 2010print(nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu')))# =============================================================================# tensor([[-0.1270, 0.3963, 0.9531, -0.2949, 0.8294],#     [-0.9759, -0.6335, 0.9299, -1.0988, -0.1496],#     [-0.7224, 0.2181, -1.1219, 0.8629, -0.8825]])# =============================================================================#8.xavier_normal 初始化#torch.nn.init.xavier_normal_(tensor, gain=1.0)print(nn.init.xavier_normal_(w))# =============================================================================# tensor([[ 1.0463, 0.1275, -0.3752, 0.1858, 1.1008],#     [-0.5560, 0.2837, 0.1000, -0.5835, 0.7886],#     [-0.2417, 0.1763, -0.7495, 0.4677, -0.1185]])# =============================================================================#9.kaiming_uniform 初始化#torch.nn.init.kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')#From - Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification - HeKaiming 2015print(nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu'))# =============================================================================# tensor([[-0.7712, 0.9344, 0.8304, 0.2367, 0.0478],#     [-0.6139, -0.3916, -0.0835, 0.5975, 0.1717],#     [ 0.3197, -0.9825, -0.5380, -1.0033, -0.3701]])# =============================================================================#10.kaiming_normal 初始化#torch.nn.init.kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')print(nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu'))# =============================================================================# tensor([[-0.0210, 0.5532, -0.8647, 0.9813, 0.0466],#     [ 0.7713, -1.0418, 0.7264, 0.5547, 0.7403],#     [-0.8471, -1.7371, 1.3333, 0.0395, 1.0787]])# =============================================================================#11.正交矩阵 - (semi)orthogonal matrix#torch.nn.init.orthogonal_(tensor, gain=1)#From - Exact solutions to the nonlinear dynamics of learning in deep linear neural networks - Saxe 2013print(nn.init.orthogonal_(w))# =============================================================================# tensor([[-0.0346, -0.7607, -0.0428, 0.4771, 0.4366],#     [-0.0412, -0.0836, 0.9847, 0.0703, -0.1293],#     [-0.6639, 0.4551, 0.0731, 0.1674, 0.5646]])# =============================================================================#12.稀疏矩阵 - sparse matrix #torch.nn.init.sparse_(tensor, sparsity, std=0.01)#From - Deep learning via Hessian-free optimization - Martens 2010print(nn.init.sparse_(w, sparsity=0.1))# =============================================================================# tensor([[ 0.0000, 0.0000, -0.0077, 0.0000, -0.0046],#     [ 0.0152, 0.0030, 0.0000, -0.0029, 0.0005],#     [ 0.0199, 0.0132, -0.0088, 0.0060, 0.0000]])# =============================================================================

补充:【pytorch参数初始化】 pytorch默认参数初始化以及自定义参数初始化

本文用两个问题来引入

1.pytorch自定义网络结构不进行参数初始化会怎样,参数值是随机的吗?

2.如何自定义参数初始化?

先回答第一个问题

在pytorch中,有自己默认初始化参数方式,所以在你定义好网络结构以后,不进行参数初始化也是可以的。

Conv2d继承自_ConvNd,在_ConvNd中,可以看到默认参数就是进行初始化的,如下图所示

怎么在Pytorch 中对TORCH.NN.INIT 参数进行初始化

怎么在Pytorch 中对TORCH.NN.INIT 参数进行初始化

torch.nn.BatchNorm2d也一样有默认初始化的方式

怎么在Pytorch 中对TORCH.NN.INIT 参数进行初始化

torch.nn.Linear也如此

怎么在Pytorch 中对TORCH.NN.INIT 参数进行初始化

现在来回答第二个问题。

pytorch中对神经网络模型中的参数进行初始化方法如下:

from torch.nn import init#define the initial function to init the layer's parameters for the networkdef weigth_init(m):  if isinstance(m, nn.Conv2d):    init.xavier_uniform_(m.weight.data)    init.constant_(m.bias.data,0.1)  elif isinstance(m, nn.BatchNorm2d):    m.weight.data.fill_(1)    m.bias.data.zero_()  elif isinstance(m, nn.Linear):    m.weight.data.normal_(0,0.01)    m.bias.data.zero_()

首先定义了一个初始化函数,接着进行调用就ok了,不过要先把网络模型实例化:

 #Define Network  model = Net(args.input_channel,args.output_channel)  model.apply(weigth_init)

此上就完成了对模型中训练参数的初始化。

在知乎上也有看到一个类似的版本,也相应的贴上来作为参考了:

def initNetParams(net):  '''Init net parameters.'''  for m in net.modules():    if isinstance(m, nn.Conv2d):      init.xavier_uniform(m.weight)      if m.bias:        init.constant(m.bias, 0)    elif isinstance(m, nn.BatchNorm2d):      init.constant(m.weight, 1)      init.constant(m.bias, 0)    elif isinstance(m, nn.Linear):      init.normal(m.weight, std=1e-3)      if m.bias:        init.constant(m.bias, 0) initNetParams(net)

再说一下关于模型的保存及加载

保存有两种方式,第一种是保存模型的整个结构信息和参数,第二种是只保存模型的参数

 #保存整个网络模型及参数 torch.save(net, 'net.pkl')   #仅保存模型参数 torch.save(net.state_dict(), 'net_params.pkl')

加载对应保存的两种网络

# 保存和加载整个模型 torch.save(model_object, 'model.pth') model = torch.load('model.pth')  # 仅保存和加载模型参数 torch.save(model_object.state_dict(), 'params.pth') model_object.load_state_dict(torch.load('params.pth'))

看完上述内容是否对您有帮助呢?如果还想对相关知识有进一步的了解或阅读更多相关文章,请关注编程网行业资讯频道,感谢您对编程网的支持。

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

怎么在Pytorch 中对TORCH.NN.INIT 参数进行初始化

下载Word文档到电脑,方便收藏和打印~

下载Word文档

猜你喜欢

怎么在Pytorch 中对TORCH.NN.INIT 参数进行初始化

怎么在Pytorch 中对TORCH.NN.INIT 参数进行初始化?很多新手对此不是很清楚,为了帮助大家解决这个难题,下面小编将为大家详细讲解,有这方面需求的人可以来学习下,希望你能有所收获。初始化函数:torch.nn.init# -*
2023-06-06

怎么在python中对defaultdict进行初始化

怎么在python中对defaultdict进行初始化?相信很多没有经验的人对此束手无策,为此本文总结了问题出现的原因和解决方法,通过这篇文章希望你能解决这个问题。python是什么意思Python是一种跨平台的、具有解释性、编译性、互动性
2023-06-14

JAVA对象怎么进行初始化

在Java中,对象可以通过以下几种方式进行初始化:1. 使用new关键字:通过使用new关键字可以创建一个对象,并调用构造方法对对象进行初始化。例如,可以使用以下方式创建一个String对象并初始化:```javaString str =
2023-09-09

怎么在python中初始化进程池Pool

本篇文章为大家展示了怎么在python中初始化进程池Pool,内容简明扼要并且容易理解,绝对能使你眼前一亮,通过这篇文章的详细介绍希望你能有所收获。python的数据类型有哪些?python的数据类型:1. 数字类型,包括int(整型)、l
2023-06-14

怎么在java中初始化数组

这篇文章给大家介绍怎么在java中初始化数组,内容非常详细,感兴趣的小伙伴们可以参考借鉴,希望对大家能有所帮助。Java可以用来干什么Java主要应用于:1. web开发;2. Android开发;3. 客户端开发;4. 网页开发;5. 企
2023-06-14

怎么在java中对方法参数进行核对

本篇文章为大家展示了怎么在java中对方法参数进行核对,内容简明扼要并且容易理解,绝对能使你眼前一亮,通过这篇文章的详细介绍希望你能有所收获。Java是什么Java是一门面向对象编程语言,可以编写桌面应用程序、Web应用程序、分布式系统和嵌
2023-06-14

怎么在Java中动态初始化数组

怎么在Java中动态初始化数组?针对这个问题,这篇文章详细介绍了相对应的分析和解答,希望可以帮助更多想解决这个问题的小伙伴找到更简单易行的方法。Java可以用来干什么Java主要应用于:1. web开发;2. Android开发;3. 客户
2023-06-14

怎么在Java中初始化二维数组

今天就跟大家聊聊有关怎么在Java中初始化二维数组,可能很多人都不太了解,为了让大家更加了解,小编给大家总结了以下内容,希望大家根据这篇文章可以有所收获。Java可以用来干什么Java主要应用于:1. web开发;2. Android开发;
2023-06-14

Java中怎么对参数进行传递

本篇文章给大家分享的是有关Java中怎么对参数进行传递,小编觉得挺实用的,因此分享给大家学习,希望大家阅读完这篇文章后可以有所收获,话不多说,跟着小编一起来看看吧。可以理解当我们要调用一个方法时,我们会把指定的数值,传递给方法中的参数,这样
2023-05-31

怎么用teamviewer远程控制正在初始化显示参数

这篇文章主要介绍了怎么用teamviewer远程控制正在初始化显示参数的相关知识,内容详细易懂,操作简单快捷,具有一定借鉴价值,相信大家阅读完这篇怎么用teamviewer远程控制正在初始化显示参数文章都会有所收获,下面我们一起来看看吧。t
2023-07-01

怎么在MySQL中对Group by进行优化

本篇文章为大家展示了怎么在MySQL中对Group by进行优化,内容简明扼要并且容易理解,绝对能使你眼前一亮,通过这篇文章的详细介绍希望你能有所收获。一个标准的 Group by 语句包含排序、分组、聚合函数,比如 select a,co
2023-06-08

怎么在MySQL中对查询进行优化

本篇文章给大家分享的是有关怎么在MySQL中对查询进行优化,小编觉得挺实用的,因此分享给大家学习,希望大家阅读完这篇文章后可以有所收获,话不多说,跟着小编一起来看看吧。一、创建索引规范在学习索引优化之前,需要对创建索引的规范有一定的了解,此
2023-06-08

怎么在Python中利用for循环初始化数组

这篇文章给大家介绍怎么在Python中利用for循环初始化数组,内容非常详细,感兴趣的小伙伴们可以参考借鉴,希望对大家能有所帮助。python是什么意思Python是一种跨平台的、具有解释性、编译性、互动性和面向对象的脚本语言,其最初的设计
2023-06-14

怎么在Android中对SQLite数据库进行数据持久化

怎么在Android中对SQLite数据库进行数据持久化?针对这个问题,这篇文章详细介绍了相对应的分析和解答,希望可以帮助更多想解决这个问题的小伙伴找到更简单易行的方法。1、SQLiteOpenHelper:创建数据库和数据库版本管理的辅助
2023-05-31

怎么在java中对数组进行排序

这期内容当中小编将会给大家带来有关怎么在java中对数组进行排序,文章内容丰富且以专业的角度为大家分析和叙述,阅读完这篇文章希望大家可以有所收获。Java是什么Java是一门面向对象编程语言,可以编写桌面应用程序、Web应用程序、分布式系统
2023-06-14

怎么在JavaScript中对数组进行求和

这篇文章将为大家详细讲解有关怎么在JavaScript中对数组进行求和,文章内容质量较高,因此小编分享给大家做个参考,希望大家阅读完这篇文章后对相关知识有一定的了解。JavaScript数组求和的方法:1、for循环,代码为【for (va
2023-06-14

怎么在java中对数据进行比较

这篇文章将为大家详细讲解有关怎么在java中对数据进行比较,文章内容质量较高,因此小编分享给大家做个参考,希望大家阅读完这篇文章后对相关知识有一定的了解。Java的特点有哪些Java的特点有哪些1.Java语言作为静态面向对象编程语言的代表
2023-06-14

怎么在java8中对函数进行引用

怎么在java8中对函数进行引用?很多新手对此不是很清楚,为了帮助大家解决这个难题,下面小编将为大家详细讲解,有这方面需求的人可以来学习下,希望你能有所收获。函数引用的类型函数引用分为以下四种:静态函数,比如 Integer 类的 pars
2023-05-31

编程热搜

  • Python 学习之路 - Python
    一、安装Python34Windows在Python官网(https://www.python.org/downloads/)下载安装包并安装。Python的默认安装路径是:C:\Python34配置环境变量:【右键计算机】--》【属性】-
    Python 学习之路 - Python
  • chatgpt的中文全称是什么
    chatgpt的中文全称是生成型预训练变换模型。ChatGPT是什么ChatGPT是美国人工智能研究实验室OpenAI开发的一种全新聊天机器人模型,它能够通过学习和理解人类的语言来进行对话,还能根据聊天的上下文进行互动,并协助人类完成一系列
    chatgpt的中文全称是什么
  • C/C++中extern函数使用详解
  • C/C++可变参数的使用
    可变参数的使用方法远远不止以下几种,不过在C,C++中使用可变参数时要小心,在使用printf()等函数时传入的参数个数一定不能比前面的格式化字符串中的’%’符号个数少,否则会产生访问越界,运气不好的话还会导致程序崩溃
    C/C++可变参数的使用
  • css样式文件该放在哪里
  • php中数组下标必须是连续的吗
  • Python 3 教程
    Python 3 教程 Python 的 3.0 版本,常被称为 Python 3000,或简称 Py3k。相对于 Python 的早期版本,这是一个较大的升级。为了不带入过多的累赘,Python 3.0 在设计的时候没有考虑向下兼容。 Python
    Python 3 教程
  • Python pip包管理
    一、前言    在Python中, 安装第三方模块是通过 setuptools 这个工具完成的。 Python有两个封装了 setuptools的包管理工具: easy_install  和  pip , 目前官方推荐使用 pip。    
    Python pip包管理
  • ubuntu如何重新编译内核
  • 改善Java代码之慎用java动态编译

目录