我的编程空间,编程开发者的网络收藏夹
学习永远不晚

解读torch.nn.GRU的输入及输出示例

短信预约 -IT技能 免费直播动态提醒
省份

北京

  • 北京
  • 上海
  • 天津
  • 重庆
  • 河北
  • 山东
  • 辽宁
  • 黑龙江
  • 吉林
  • 甘肃
  • 青海
  • 河南
  • 江苏
  • 湖北
  • 湖南
  • 江西
  • 浙江
  • 广东
  • 云南
  • 福建
  • 海南
  • 山西
  • 四川
  • 陕西
  • 贵州
  • 安徽
  • 广西
  • 内蒙
  • 西藏
  • 新疆
  • 宁夏
  • 兵团
手机号立即预约

请填写图片验证码后获取短信验证码

看不清楚,换张图片

免费获取短信验证码

解读torch.nn.GRU的输入及输出示例

我们有时会看到GRU中输入的参数有时是一个,但是有时又有两个。这难免会让人们感到疑惑,那么这些参数到底是什么呢。

一、输入到GRU的参数

输入的参数有两个,分别是input和h_0。

Inputs: input, h_0

①input的shape

The shape of input:(seq_len, batch, input_size) : tensor containing the feature of the input sequence. The input can also be a packed variable length sequence。

See functorch.nn.utils.rnn.pack_padded_sequencefor details.

②h_0的shape

从下面的解释中也可以看出,这个参数可以不提供,那么就默认为0.

The shape of h_0:(num_layers * num_directions, batch, hidden_size): tensor containing the initial hidden state for each element in the batch.

Defaults to zero if not provided. If the RNN is bidirectional num_directions should be 2, else it should be 1.

综上,可以只输入一个参数。当输入两个参数的时候,那么第二个参数相当于是一个隐含层的输出。

为了便于理解,下面是一幅图:

二、GRU返回的数据

输出有两个,分别是output和h_n

①output

output 的shape是:(seq_len, batch, num_directions * hidden_size): tensor containing the output features h_t from the last layer of the GRU, for each t.

If a class:torch.nn.utils.rnn.PackedSequence has been given as the input, the output will also be a packed sequence.

For the unpacked case, the directions can be separated using output.view(seq_len, batch, num_directions, hidden_size), with forward and backward being direction 0 and 1 respectively.

Similarly, the directions can be separated in the packed case.

②h_n

h_n的shape是:(num_layers * num_directions, batch, hidden_size): tensor containing the hidden state for t = seq_len
Like output, the layers can be separated using
h_n.view(num_layers, num_directions, batch, hidden_size).

三、代码示例

数据的shape是[batch,seq_len,emb_dim]

RNN接收输入的数据的shape是[seq_len,batch,emb_dim]

即前两个维度调换就行了。

可以知道,加入批处理的时候一次处理128个句子,每个句子中有5个单词,那么上图中展示的input_data的shape是:[128,5,emb_dim]。

结合代码分析,本例子将演示有1个句子和5个句子的情况。假设每个句子中有9个单词,所以seq_len=9,并且每个单词对应的emb_dim=3,所以对应数据的shape是: [batch,9,3],由于输入到RNN中数据格式的格式,所以为[9,batch,3]

import torch
import torch.nn as nn

emb_dim = 3
hidden_dim = 2
rnn = nn.GRU(emb_dim,hidden_dim)
#rnn = nn.GRU(9,1,3)
print(type(rnn))

tensor1 = torch.tensor([[-0.5502, -0.1920, 1.1845],
[-0.8003, 2.0783, 0.0175],
[ 0.6761, 0.7183, -1.0084],
[ 0.9514, 1.4772, -0.2271],
[-1.0146, 0.7912, 0.2003],
[-0.5502, -0.1920, 1.1845],
[-0.8003, 2.0783, 0.0175],
[ 0.1718, 0.1070, 0.4255],
[-2.6727, -1.5680, -0.8369]])

tensor2 = torch.tensor([[-0.5502, -0.1920]])

# 假设input只有一个句子,那么batch为1
print('--------------batch=1时------------')
data = tensor1.unsqueeze(0)
h_0 = tensor2[0].unsqueeze(0).unsqueeze(0)
print('data.shape: [batch,seq_len,emb_dim]',data.shape)
print('')
input = data.transpose(0,1)
print('input.shape: [seq_len,batch,emb_dim]',input.shape)
print('h_0.shape: [1,batch,hidden_dim]',h_0.shape)
print('')
# 输入到rnn中
output,h_n = rnn(input,h_0)
print('output.shape: [seq_len,batch,hidden_dim]',output.shape)
print('h_n.shape: [1,batch,hidden_dim]',h_n.shape)

# 假设input中有5个句子,所以,batch = 5
print('\n--------------batch=5时------------')
data = tensor1.unsqueeze(0).repeat(5,1,1) # 由于batch为5
h_0 = tensor2[0].unsqueeze(0).repeat(1,5,1) # 由于batch为5
print('data.shape: [batch,seq_len,emb_dim]',data.shape)
print('')
input = data.transpose(0,1)

print('input.shape: [seq_len,batch,emb_dim]',input.shape)
print('h_0.shape: [1,batch,hidden_dim]',h_0.shape)
print('')
# 输入到rnn中
output,h_n = rnn(input,h_0)
print('output.shape: [seq_len,batch,hidden_dim]',output.shape)
print('h_n.shape: [1,batch,hidden_dim]',h_n.shape)

四、输出

<class ‘torch.nn.modules.rnn.GRU’>
--------------batch=1时------------
data.shape: [batch,seq_len,emb_dim] torch.Size([1, 9, 3])

input.shape: [seq_len,batch,emb_dim] torch.Size([9, 1, 3])
h_0.shape: [1,batch,hidden_dim] torch.Size([1, 1, 2])

output.shape: [seq_len,batch,hidden_dim] torch.Size([9, 1, 2])
h_n.shape: [1,batch,hidden_dim] torch.Size([1, 1, 2])

--------------batch=5时------------
data.shape: [batch,seq_len,emb_dim] torch.Size([5, 9, 3])

input.shape: [seq_len,batch,emb_dim] torch.Size([9, 5, 3])
h_0.shape: [1,batch,hidden_dim] torch.Size([1, 5, 2])

output.shape: [seq_len,batch,hidden_dim] torch.Size([9, 5, 2])
h_n.shape: [1,batch,hidden_dim] torch.Size([1, 5, 2])

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持编程网。

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

解读torch.nn.GRU的输入及输出示例

下载Word文档到电脑,方便收藏和打印~

下载Word文档

猜你喜欢

解读torch.nn.GRU的输入及输出示例

这篇文章主要介绍了解读torch.nn.GRU的输入及输出示例,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
2023-01-28

C++输入和输出流的示例分析

这篇文章给大家分享的是有关C++输入和输出流的示例分析的内容。小编觉得挺实用的,因此分享给大家做个参考,一起跟随小编过来看看吧。输入和输出流从键盘输入数据,输出到显示器屏幕。这种输入输出称为标准的输入输出,简称标准I/O。从磁盘文件输入数据
2023-06-29

JAVA语言输入输出流的示例代码

这篇文章主要介绍了JAVA语言输入输出流的示例代码,具有一定借鉴价值,感兴趣的朋友可以参考下,希望大家阅读完这篇文章之后大有收获,下面让小编带着大家一起了解一下。public class IOStreamDemo { public
2023-06-03

Java中输入/输出流体系的示例分析

这篇文章主要介绍Java中输入/输出流体系的示例分析,文中介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们一定要看完!Java输入/输出流体系1.字节流和字符流字节流:按字节读取。字符流:按字符读取。字符流读取方便,字节流功能强大,当不
2023-05-30

Java IO中字节输入输出流的示例分析

这篇文章主要介绍Java IO中字节输入输出流的示例分析,文中介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们一定要看完!讲的是字节输入输出流:InputStream、OutputSteam(下图红色长方形框内),红色椭圆框内是其典型实
2023-06-26

java 对象输入输出流读写文件的操作实例

java 对象输入输出流读写文件的操作实例java 支持对对象的读写操作,所操作的对象必须实现Serializable接口。 实例代码:package vo; import java.io.Serializable; public cl
2023-05-31

Ruby迭代器及文件的输入与输出

这篇文章主要介绍了Ruby的迭代器和文件的输入输出,文章中有详细的代码示例,需要的朋友可以参考阅读一下
2023-05-15

Java中I/O输入输出的深入讲解

Java的I/O技术可以将数据保存到文本文件、二进制文件甚至是ZIP压缩文件中,以达到永久性保存数据的要求,下面这篇文章主要给大家介绍了关于Java中I/O输入输出的相关资料,需要的朋友可以参考下
2022-11-13

Python的输入和输出问题详解

输出用print()在括号中加上字符串,就可以向屏幕上输出指定的文字。比如输出'hello, world',用代码实现如下:>>> print('hello, world')print()函数也可以接受多个字符串,用逗号“,”隔开,就可以连
2023-01-30

Ruby迭代器及文件的输入与输出实例代码分析

这篇文章主要介绍“Ruby迭代器及文件的输入与输出实例代码分析”的相关知识,小编通过实际案例向大家展示操作过程,操作方法简单快捷,实用性强,希望这篇“Ruby迭代器及文件的输入与输出实例代码分析”文章能帮助大家解决问题。Ruby 迭代器简单
2023-07-06

javaFileOutputStream输出流的使用解读

这篇文章主要介绍了javaFileOutputStream输出流的使用解读,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
2022-12-26

JAVA语言的输入输出流详解(c)

详解b中的例子,详解[@more@]  1. BufferedReader是Reader的一个子类,它具有缓冲的作用,避免了频繁的从物理设备中读取信息。它有以下两个构造函数:BufferedReader(Reader in) Buffere
2023-06-03

python语言中流程的输入与输出案例

这篇文章将为大家详细讲解有关python语言中流程的输入与输出案例,小编觉得挺实用的,因此分享给大家做个参考,希望大家阅读完这篇文章后可以有所收获。流程中的输入,一般都会先保存在变量(即内存)中,而这个输入,可以来自于键盘(也称为标准输入)
2023-06-19

编程热搜

  • Python 学习之路 - Python
    一、安装Python34Windows在Python官网(https://www.python.org/downloads/)下载安装包并安装。Python的默认安装路径是:C:\Python34配置环境变量:【右键计算机】--》【属性】-
    Python 学习之路 - Python
  • chatgpt的中文全称是什么
    chatgpt的中文全称是生成型预训练变换模型。ChatGPT是什么ChatGPT是美国人工智能研究实验室OpenAI开发的一种全新聊天机器人模型,它能够通过学习和理解人类的语言来进行对话,还能根据聊天的上下文进行互动,并协助人类完成一系列
    chatgpt的中文全称是什么
  • C/C++中extern函数使用详解
  • C/C++可变参数的使用
    可变参数的使用方法远远不止以下几种,不过在C,C++中使用可变参数时要小心,在使用printf()等函数时传入的参数个数一定不能比前面的格式化字符串中的’%’符号个数少,否则会产生访问越界,运气不好的话还会导致程序崩溃
    C/C++可变参数的使用
  • css样式文件该放在哪里
  • php中数组下标必须是连续的吗
  • Python 3 教程
    Python 3 教程 Python 的 3.0 版本,常被称为 Python 3000,或简称 Py3k。相对于 Python 的早期版本,这是一个较大的升级。为了不带入过多的累赘,Python 3.0 在设计的时候没有考虑向下兼容。 Python
    Python 3 教程
  • Python pip包管理
    一、前言    在Python中, 安装第三方模块是通过 setuptools 这个工具完成的。 Python有两个封装了 setuptools的包管理工具: easy_install  和  pip , 目前官方推荐使用 pip。    
    Python pip包管理
  • ubuntu如何重新编译内核
  • 改善Java代码之慎用java动态编译

目录