我的编程空间,编程开发者的网络收藏夹
学习永远不晚

PyTorch中的参数类torch.nn.Parameter()详解

短信预约 -IT技能 免费直播动态提醒
省份

北京

  • 北京
  • 上海
  • 天津
  • 重庆
  • 河北
  • 山东
  • 辽宁
  • 黑龙江
  • 吉林
  • 甘肃
  • 青海
  • 河南
  • 江苏
  • 湖北
  • 湖南
  • 江西
  • 浙江
  • 广东
  • 云南
  • 福建
  • 海南
  • 山西
  • 四川
  • 陕西
  • 贵州
  • 安徽
  • 广西
  • 内蒙
  • 西藏
  • 新疆
  • 宁夏
  • 兵团
手机号立即预约

请填写图片验证码后获取短信验证码

看不清楚,换张图片

免费获取短信验证码

PyTorch中的参数类torch.nn.Parameter()详解

前言

今天来聊一下PyTorch中的torch.nn.Parameter()这个函数,笔者第一次见的时候也是大概能理解函数的用途,但是具体实现原理细节也是云里雾里,在参考了几篇博文,做过几个实验之后算是清晰了,本文在记录的同时希望给后来人一个参考,欢迎留言讨论。

分析

先看其名,parameter,中文意为参数。我们知道,使用PyTorch训练神经网络时,本质上就是训练一个函数,这个函数输入一个数据(如CV中输入一张图像),输出一个预测(如输出这张图像中的物体是属于什么类别)。而在我们给定这个函数的结构(如卷积、全连接等)之后,能学习的就是这个函数的参数了,我们设计一个损失函数,配合梯度下降法,使得我们学习到的函数(神经网络)能够尽量准确地完成预测任务。

通常,我们的参数都是一些常见的结构(卷积、全连接等)里面的计算参数。而当我们的网络有一些其他的设计时,会需要一些额外的参数同样很着整个网络的训练进行学习更新,最后得到最优的值,经典的例子有注意力机制中的权重参数、Vision Transformer中的class token和positional embedding等。

而这里的torch.nn.Parameter()就可以很好地适应这种应用场景。

下面是这篇博客的一个总结,笔者认为讲的比较明白,在这里引用一下:

首先可以把这个函数理解为类型转换函数,将一个不可训练的类型Tensor转换成可以训练的类型parameter并将这个parameter绑定到这个module里面(net.parameter()中就有这个绑定的parameter,所以在参数优化的时候可以进行优化的),所以经过类型转换这个self.v变成了模型的一部分,成为了模型中根据训练可以改动的参数了。使用这个函数的目的也是想让某些变量在学习的过程中不断的修改其值以达到最优化。

ViT中nn.Parameter()的实验

看过这个分析后,我们再看一下Vision Transformer中的用法:

...

self.pos_embedding = nn.Parameter(torch.randn(1, num_patches+1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
...

我们知道在ViT中,positonal embedding和class token是两个需要随着网络训练学习的参数,但是它们又不属于FC、MLP、MSA等运算的参数,在这时,就可以用nn.Parameter()来将这个随机初始化的Tensor注册为可学习的参数Parameter。

为了确定这两个参数确实是被添加到了net.Parameters()内,笔者稍微改动源码,显式地指定这两个参数的初始数值为0.98,并打印迭代器net.Parameters()。

...

self.pos_embedding = nn.Parameter(torch.ones(1, num_patches+1, dim) * 0.98)
self.cls_token = nn.Parameter(torch.ones(1, 1, dim) * 0.98)
...

实例化一个ViT模型并打印net.Parameters():

net_vit = ViT(
        image_size = 256,
        patch_size = 32,
        num_classes = 1000,
        dim = 1024,
        depth = 6,
        heads = 16,
        mlp_dim = 2048,
        dropout = 0.1,
        emb_dropout = 0.1
    )

for para in net_vit.parameters():
        print(para.data)

输出结果中可以看到,最前两行就是我们显式指定为0.98的两个参数pos_embedding和cls_token:

tensor([[[0.9800, 0.9800, 0.9800,  ..., 0.9800, 0.9800, 0.9800],
         [0.9800, 0.9800, 0.9800,  ..., 0.9800, 0.9800, 0.9800],
         [0.9800, 0.9800, 0.9800,  ..., 0.9800, 0.9800, 0.9800],
         ...,
         [0.9800, 0.9800, 0.9800,  ..., 0.9800, 0.9800, 0.9800],
         [0.9800, 0.9800, 0.9800,  ..., 0.9800, 0.9800, 0.9800],
         [0.9800, 0.9800, 0.9800,  ..., 0.9800, 0.9800, 0.9800]]])
tensor([[[0.9800, 0.9800, 0.9800,  ..., 0.9800, 0.9800, 0.9800]]])
tensor([[-0.0026, -0.0064,  0.0111,  ...,  0.0091, -0.0041, -0.0060],
        [ 0.0003,  0.0115,  0.0059,  ..., -0.0052, -0.0056,  0.0010],
        [ 0.0079,  0.0016, -0.0094,  ...,  0.0174,  0.0065,  0.0001],
        ...,
        [-0.0110, -0.0137,  0.0102,  ...,  0.0145, -0.0105, -0.0167],
        [-0.0116, -0.0147,  0.0030,  ...,  0.0087,  0.0022,  0.0108],
        [-0.0079,  0.0033, -0.0087,  ..., -0.0174,  0.0103,  0.0021]])
...
...

这就可以确定nn.Parameter()添加的参数确实是被添加到了Parameters列表中,会被送入优化器中随训练一起学习更新。

from torch.optim import Adam
opt = Adam(net_vit.parameters(), learning_rate=0.001)

其他解释

以下是国外StackOverflow的一个大佬的解读,笔者自行翻译并放在这里供大家参考,想查看原文的同学请戳这里。

我们知道Tensor相当于是一个高维度的矩阵,它是Variable类的子类。Variable和Parameter之间的差异体现在与Module关联时。当Parameter作为model的属性与module相关联时,它会被自动添加到Parameters列表中,并且可以使用net.Parameters()迭代器进行访问。

最初在Torch中,一个Variable(例如可以是某个中间state)也会在赋值时被添加为模型的Parameter。在某些实例中,需要缓存变量,而不是将它们添加到Parameters列表中。

文档中提到的一种情况是RNN,在这种情况下,您需要保存最后一个hidden state,这样就不必一次又一次地传递它。需要缓存一个Variable,而不是让它自动注册为模型的Parameter,这就是为什么我们有一个显式的方法将参数注册到我们的模型,即nn.Parameter类。

举个例子:

import torch
import torch.nn as nn
from torch.optim import Adam

class NN_Network(nn.Module):
    def __init__(self,in_dim,hid,out_dim):
        super(NN_Network, self).__init__()
        self.linear1 = nn.Linear(in_dim,hid)
        self.linear2 = nn.Linear(hid,out_dim)
        self.linear1.weight = torch.nn.Parameter(torch.zeros(in_dim,hid))
        self.linear1.bias = torch.nn.Parameter(torch.ones(hid))
        self.linear2.weight = torch.nn.Parameter(torch.zeros(in_dim,hid))
        self.linear2.bias = torch.nn.Parameter(torch.ones(hid))

    def forward(self, input_array):
        h = self.linear1(input_array)
        y_pred = self.linear2(h)
        return y_pred

in_d = 5
hidn = 2
out_d = 3
net = NN_Network(in_d, hidn, out_d)

然后检查一下这个模型的Parameters列表:

for param in net.parameters():
    print(type(param.data), param.size())

""" Output
<class 'torch.FloatTensor'> torch.Size([5, 2])
<class 'torch.FloatTensor'> torch.Size([2])
<class 'torch.FloatTensor'> torch.Size([5, 2])
<class 'torch.FloatTensor'> torch.Size([2])
"""

可以轻易地送入到优化器中:

opt = Adam(net.parameters(), learning_rate=0.001)

另外,请注意Parameter的require_grad会自动设定。

各位读者有疑惑或异议的地方,欢迎留言讨论。

参考:

https://www.jb51.net/article/238632.htm

https://stackoverflow.com/questions/50935345/understanding-torch-nn-parameter

总结

到此这篇关于PyTorch中torch.nn.Parameter()的文章就介绍到这了,更多相关PyTorch中torch.nn.Parameter()内容请搜索编程网以前的文章或继续浏览下面的相关文章希望大家以后多多支持编程网!

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

PyTorch中的参数类torch.nn.Parameter()详解

下载Word文档到电脑,方便收藏和打印~

下载Word文档

猜你喜欢

PyTorch中Torch.arange函数详解

PyTorch是由Facebook开发的开源机器学习库,它用于深度神经网络和自然语言处理,下面这篇文章主要给大家介绍了关于PyTorch中Torch.arange函数详解的相关资料,需要的朋友可以参考下
2023-02-03

Golang形参要求详解:参数类型、参数个数及顺序

Golang形参要求详解:参数类型、参数个数及顺序在Golang中,函数的形参定义非常灵活,可以传递不同类型的参数及不固定数量的参数。形参主要包括参数类型、参数个数及参数顺序,下面将通过具体的代码示例来详细解释。参数类型在Golang中
Golang形参要求详解:参数类型、参数个数及顺序
2024-03-03

Pytorch中的torch.distributions库详解

这篇文章主要介绍了Pytorch中的torch.distributions库,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
2023-02-24

C++ 函数参数详解:运行时类型识别在参数传递中的作用

C++ 函数参数详解:运行时类型识别在参数传递中的作用在 C++ 中,函数参数传递可以通过值传递、引用传递或指针传递实现。每种传递方式都有各自的优缺点。运行时类型识别 (RTTI) 是 C++ 中一种在运行时获取对象类型的机制。它允许我
C++ 函数参数详解:运行时类型识别在参数传递中的作用
2024-04-27

pytorch中BatchNorm2d函数的参数怎么使用

本篇内容主要讲解“pytorch中BatchNorm2d函数的参数怎么使用”,感兴趣的朋友不妨来看看。本文介绍的方法操作简单快捷,实用性强。下面就让小编来带大家学习“pytorch中BatchNorm2d函数的参数怎么使用”吧!BN原理、作
2023-07-04

pytorch和numpy默认浮点类型位数详解

这篇文章主要介绍了pytorch和numpy默认浮点类型位数,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
2023-02-02

pytorch中backward的参数含义是什么

本文小编为大家详细介绍“pytorch中backward的参数含义是什么”,内容详细,步骤清晰,细节处理妥当,希望这篇“pytorch中backward的参数含义是什么”文章能帮助大家解决疑惑,下面跟着小编的思路慢慢深入,一起来学习新知识吧
2023-07-05

Python中的默认参数详解

文章的主题不要使用可变对象作为函数的默认参数例如 list,dict,因为def是一个可执行语句,只有def执行的时候才会计算默认默认参数的值,所以使用默认参数会造成函数执行的时候一直在使用同一个对象,引起bug。基本原理在 Python
2023-01-31

编程热搜

  • Python 学习之路 - Python
    一、安装Python34Windows在Python官网(https://www.python.org/downloads/)下载安装包并安装。Python的默认安装路径是:C:\Python34配置环境变量:【右键计算机】--》【属性】-
    Python 学习之路 - Python
  • chatgpt的中文全称是什么
    chatgpt的中文全称是生成型预训练变换模型。ChatGPT是什么ChatGPT是美国人工智能研究实验室OpenAI开发的一种全新聊天机器人模型,它能够通过学习和理解人类的语言来进行对话,还能根据聊天的上下文进行互动,并协助人类完成一系列
    chatgpt的中文全称是什么
  • C/C++中extern函数使用详解
  • C/C++可变参数的使用
    可变参数的使用方法远远不止以下几种,不过在C,C++中使用可变参数时要小心,在使用printf()等函数时传入的参数个数一定不能比前面的格式化字符串中的’%’符号个数少,否则会产生访问越界,运气不好的话还会导致程序崩溃
    C/C++可变参数的使用
  • css样式文件该放在哪里
  • php中数组下标必须是连续的吗
  • Python 3 教程
    Python 3 教程 Python 的 3.0 版本,常被称为 Python 3000,或简称 Py3k。相对于 Python 的早期版本,这是一个较大的升级。为了不带入过多的累赘,Python 3.0 在设计的时候没有考虑向下兼容。 Python
    Python 3 教程
  • Python pip包管理
    一、前言    在Python中, 安装第三方模块是通过 setuptools 这个工具完成的。 Python有两个封装了 setuptools的包管理工具: easy_install  和  pip , 目前官方推荐使用 pip。    
    Python pip包管理
  • ubuntu如何重新编译内核
  • 改善Java代码之慎用java动态编译

目录