我的编程空间,编程开发者的网络收藏夹
学习永远不晚

解决pytorch rnn 变长输入序列的问题

短信预约 -IT技能 免费直播动态提醒
省份

北京

  • 北京
  • 上海
  • 天津
  • 重庆
  • 河北
  • 山东
  • 辽宁
  • 黑龙江
  • 吉林
  • 甘肃
  • 青海
  • 河南
  • 江苏
  • 湖北
  • 湖南
  • 江西
  • 浙江
  • 广东
  • 云南
  • 福建
  • 海南
  • 山西
  • 四川
  • 陕西
  • 贵州
  • 安徽
  • 广西
  • 内蒙
  • 西藏
  • 新疆
  • 宁夏
  • 兵团
手机号立即预约

请填写图片验证码后获取短信验证码

看不清楚,换张图片

免费获取短信验证码

解决pytorch rnn 变长输入序列的问题

pytorch实现变长输入的rnn分类

输入数据是长度不固定的序列数据,主要讲解两个部分

1、Data.DataLoader的collate_fn用法,以及按batch进行padding数据

2、pack_padded_sequence和pad_packed_sequence来处理变长序列

collate_fn

Dataloader的collate_fn参数,定义数据处理和合并成batch的方式。

由于pack_padded_sequence用到的tensor必须按照长度从大到小排过序的,所以在Collate_fn中,需要完成两件事,一是把当前batch的样本按照当前batch最大长度进行padding,二是将padding后的数据从大到小进行排序。


def pad_tensor(vec, pad):
    """
    args:
        vec - tensor to pad
        pad - the size to pad to
    return:
        a new tensor padded to 'pad'
    """
    return torch.cat([vec, torch.zeros(pad - len(vec), dtype=torch.float)], dim=0).data.numpy()
class Collate:
    """
    a variant of callate_fn that pads according to the longest sequence in
    a batch of sequences
    """
    def __init__(self):
        pass
    def _collate(self, batch):
        """
        args:
            batch - list of (tensor, label)
        reutrn:
            xs - a tensor of all examples in 'batch' before padding like:
                '''
                [tensor([1,2,3,4]),
                 tensor([1,2]),
                 tensor([1,2,3,4,5])]
                '''
            ys - a LongTensor of all labels in batch like:
                '''
                [1,0,1]
                '''
        """
        xs = [torch.FloatTensor(v[0]) for v in batch]
        ys = torch.LongTensor([v[1] for v in batch])
        # 获得每个样本的序列长度
        seq_lengths = torch.LongTensor([v for v in map(len, xs)])
        max_len = max([len(v) for v in xs])
        # 每个样本都padding到当前batch的最大长度
        xs = torch.FloatTensor([pad_tensor(v, max_len) for v in xs])
        # 把xs和ys按照序列长度从大到小排序
        seq_lengths, perm_idx = seq_lengths.sort(0, descending=True)
        xs = xs[perm_idx]
        ys = ys[perm_idx]
        return xs, seq_lengths, ys
    def __call__(self, batch):
        return self._collate(batch)

定义完collate类以后,在DataLoader中直接使用


train_data = Data.DataLoader(dataset=train_dataset, batch_size=32, num_workers=0, collate_fn=Collate())

torch.nn.utils.rnn.pack_padded_sequence()

pack_padded_sequence将一个填充过的变长序列压紧。输入参数包括

input(Variable)- 被填充过后的变长序列组成的batch data

lengths (list[int]) - 变长序列的原始序列长度

batch_first (bool,optional) - 如果是True,input的形状应该是(batch_size,seq_len,input_size)

返回值:一个PackedSequence对象,可以直接作为rnn,lstm,gru的传入数据。

用法:


from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
# x是填充过后的batch数据,seq_lengths是每个样本的序列长度
packed_input = pack_padded_sequence(x, seq_lengths, batch_first=True)

RNN模型

定义了一个单向的LSTM模型,因为处理的是变长序列,forward函数传入的值是一个PackedSequence对象,返回值也是一个PackedSequence对象


class Model(nn.Module):
    def __init__(self, in_size, hid_size, n_layer, drop=0.1, bi=False):
        super(Model, self).__init__()
        self.lstm = nn.LSTM(input_size=in_size,
                            hidden_size=hid_size,
                            num_layers=n_layer,
                            batch_first=True,
                            dropout=drop,
                            bidirectional=bi)
        # 分类类别数目为2
        self.fc = nn.Linear(in_features=hid_size, out_features=2)
    def forward(self, x):
        '''
        :param x: 变长序列时,x是一个PackedSequence对象
        :return: PackedSequence对象
        '''
        # lstm_out: tensor of shape (batch, seq_len, num_directions * hidden_size)
        lstm_out, _ = self.lstm(x)  
        
        return lstm_out
model = Model()
lstm_out = model(packed_input)

torch.nn.utils.rnn.pad_packed_sequence()

这个操作和pack_padded_sequence()是相反的,把压紧的序列再填充回来。因为前面提到的LSTM模型传入和返回的都是PackedSequence对象,所以我们如果想要把返回的PackedSequence对象转换回Tensor,就需要用到pad_packed_sequence函数。

参数说明:

sequence (PackedSequence) – 将要被填充的 batch

batch_first (bool, optional) – 如果为True,返回的数据的形状为(batch_size,seq_len,input_size)

返回值: 一个tuple,包含被填充后的序列,和batch中序列的长度列表。

用法:


# 此处lstm_out是一个PackedSequence对象
output, _ = pad_packed_sequence(lstm_out)

返回的output是一个形状为(batch_size,seq_len,input_size)的tensor。

总结

1、pytorch在自定义dataset时,可以在DataLoader的collate_fn参数中定义对数据的变换,操作以及合成batch的方式。

2、处理变长rnn问题时,通过pack_padded_sequence()将填充的batch数据转换成PackedSequence对象,直接传入rnn模型中。通过pad_packed_sequence()来将rnn模型输出的PackedSequence对象转换回相应的Tensor。

补充:pytorch实现不定长输入的RNN / LSTM / GRU

情景描述

As we all know,RNN循环神经网络(及其改进模型LSTM、GRU)可以处理序列的顺序信息,如人类自然语言。但是在实际场景中,我们常常向模型输入一个批次(batch)的数据,这个批次中的每个序列往往不是等长的。

pytorch提供的模型(nn.RNN,nn.LSTM,nn.GRU)是支持可变长序列的处理的,但条件是传入的数据必须按序列长度排序。本文针对以下两种场景提出解决方法。

1、每个样本只有一个序列:(seq,label),其中seq是一个长度不定的序列。则使用pytorch训练时,我们将按列把一个批次的数据输入网络,seq这一列的形状就是(batch_size, seq_len),经过编码层(如word2vec)之后的形状是(batch_size, seq_len, emb_size)。

2、情况1的拓展:每个样本有两个(或多个)序列,如(seq1, seq2, label)。这种样本形式在问答系统、推荐系统多见。

通用解决方案

定义ImprovedRnn类。与nn.RNN,nn.LSTM,nn.GRU相比,除了此两点【①forward函数多一个参数lengths表示每个seq的长度】【②初始化函数(__init__)第一个参数module必须指定三者之一】外,使用方法完全相同。


import torch
from torch import nn
class ImprovedRnn(nn.Module):
    def __init__(self, module, *args, **kwargs):
        assert module in (nn.RNN, nn.LSTM, nn.GRU)
        super().__init__()
        self.module = module(*args, **kwargs)
    def forward(self, input, lengths):  # input shape(batch_size, seq_len, input_size)
        if not hasattr(self, '_flattened'):
            self.module.flatten_parameters()
            setattr(self, '_flattened', True)
        max_len = input.shape[1]
        # enforce_sorted=False则自动按lengths排序,并且返回值package.unsorted_indices可用于恢复原顺序
        package = nn.utils.rnn.pack_padded_sequence(input, lengths.cpu(), batch_first=self.module.batch_first, enforce_sorted=False)
        result, hidden = self.module(package)
        # total_length参数一般不需要,因为lengths列表中一般含最大值。但分布式训练时是将一个batch切分了,故一定要有!
        result, lens = nn.utils.rnn.pad_packed_sequence(result, batch_first=self.module.batch_first, total_length=max_len)
        return result[package.unsorted_indices], hidden  # output shape(batch_size, seq_len, rnn_hidden_size)

使用示例:


class TestNet(nn.Module):
    def __init__(self, word_emb, gru_in, gru_out):
        super().__init__()
        self.encode = nn.Embedding.from_pretrained(torch.Tensor(word_emb))
        self.rnn = ImprovedRnn(nn.RNN, input_size=gru_in, hidden_size=gru_out,
		        				batch_first=True, bidirectional=True)
    def forward(self, seq1, seq1_lengths, seq2, seq2_lengths):
        seq1_emb = self.encode(seq1)
        seq2_emb = self.encode(seq2)
        rnn1, hn = self.rnn(seq1_emb, seq1_lengths)
        rnn2, hn = self.rnn(seq2_emb, seq2_lengths)
        """
        此处略去rnn1和rnn2的后续计算,当前网络最后计算结果记为prediction
        """
        return prediction

以上为个人经验,希望能给大家一个参考,也希望大家多多支持编程网。

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

解决pytorch rnn 变长输入序列的问题

下载Word文档到电脑,方便收藏和打印~

下载Word文档

猜你喜欢

如何解决pytorch显存一直变大的问题

本篇内容介绍了“如何解决pytorch显存一直变大的问题”的有关知识,在实际案例的操作过程中,不少人都会遇到这样的困境,接下来就让小编带领大家学习一下如何处理这些情况吧!希望大家仔细阅读,能够学有所成!在代码中添加以下两行可以解决:torc
2023-06-14

如何解决JSON反序列化Long变Integer或Double的问题

这篇文章主要为大家展示了“如何解决JSON反序列化Long变Integer或Double的问题”,内容简而易懂,条理清晰,希望能够帮助大家解决疑惑,下面让小编带领大家一起研究并学习一下“如何解决JSON反序列化Long变Integer或Do
2023-06-26

Gojson反序列化“null“的问题解决

本文主要介绍了Gojson反序列化“null“的问题解决,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
2023-03-14

Python 解决空列表.append() 输出为None的问题

想要实现的功能 空列表中添加数 原代码:FitnessBest = [] FitnessBest = FitnessBest.append(fitnessVal[0, 0]) print(FitnessBest)输出:None解决办法分析
2022-06-02

Android TimePicker 直接输入的问题解决方案

Android TimePicker 直接输入的问题解决方案 TimePicker 提供了上下的按钮,点击按钮,相关操作都是正常的。但是如果直接在输入框中修改小时或分钟后直接点击按钮取值,会发现不能真正改变时间。以下代码得不到预期结果。@O
2022-06-06

SpringBoot之Json的序列化和反序列化问题怎么解决

这篇文章主要讲解了“SpringBoot之Json的序列化和反序列化问题怎么解决”,文中的讲解内容简单清晰,易于学习与理解,下面请大家跟着小编的思路慢慢深入,一起来研究和学习“SpringBoot之Json的序列化和反序列化问题怎么解决”吧
2023-07-02

vue+element-ui中form输入框无法输入问题的解决方法

很多初次接触element-ui的同学,在用到elementform组件时可能会遇到input框无法输入文字的问题,下面这篇文章主要给大家介绍了关于vue+element-ui中form输入框无法输入问题的解决方法,需要的朋友可以参考下
2023-05-14

Android输入法弹出时覆盖输入框问题的解决方法

当一个activity中含有输入框时,我们点击输入框,会弹出输入法界面,整个界面的变化效果与manifest中对应设置的android:windowSoftInputMode属性有关,一般可以设置的值如下,
2022-06-06

useState 解决文本框无法输入的问题详解

这篇文章主要为大家介绍了useState 解决文本框无法输入的问题详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
2023-03-15

Vue-element中el-input输入卡顿问题的解决

这篇文章主要介绍了Vue-element中el-input输入卡顿问题的解决方案,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
2023-05-17

解决MySQL无法输入中文字符的问题

文章目录 问题描述问题排查解决方案1️⃣创建数据库时设置字符集为utf82️⃣修改数据库配置文件【比较麻烦】 写在最后 前几日在使用MySQL数据库的时候,出现了一处保存,故作此记录✍ 问题描述 下面是我这样exa
2023-08-16

Go json反序列化“null“的问题如何解决

本文小编为大家详细介绍“Go json反序列化“null“的问题如何解决”,内容详细,步骤清晰,细节处理妥当,希望这篇“Go json反序列化“null“的问题如何解决”文章能帮助大家解决疑惑,下面跟着小编的思路慢慢深入,一起来学习新知识吧
2023-07-05

编程热搜

  • Python 学习之路 - Python
    一、安装Python34Windows在Python官网(https://www.python.org/downloads/)下载安装包并安装。Python的默认安装路径是:C:\Python34配置环境变量:【右键计算机】--》【属性】-
    Python 学习之路 - Python
  • chatgpt的中文全称是什么
    chatgpt的中文全称是生成型预训练变换模型。ChatGPT是什么ChatGPT是美国人工智能研究实验室OpenAI开发的一种全新聊天机器人模型,它能够通过学习和理解人类的语言来进行对话,还能根据聊天的上下文进行互动,并协助人类完成一系列
    chatgpt的中文全称是什么
  • C/C++中extern函数使用详解
  • C/C++可变参数的使用
    可变参数的使用方法远远不止以下几种,不过在C,C++中使用可变参数时要小心,在使用printf()等函数时传入的参数个数一定不能比前面的格式化字符串中的’%’符号个数少,否则会产生访问越界,运气不好的话还会导致程序崩溃
    C/C++可变参数的使用
  • css样式文件该放在哪里
  • php中数组下标必须是连续的吗
  • Python 3 教程
    Python 3 教程 Python 的 3.0 版本,常被称为 Python 3000,或简称 Py3k。相对于 Python 的早期版本,这是一个较大的升级。为了不带入过多的累赘,Python 3.0 在设计的时候没有考虑向下兼容。 Python
    Python 3 教程
  • Python pip包管理
    一、前言    在Python中, 安装第三方模块是通过 setuptools 这个工具完成的。 Python有两个封装了 setuptools的包管理工具: easy_install  和  pip , 目前官方推荐使用 pip。    
    Python pip包管理
  • ubuntu如何重新编译内核
  • 改善Java代码之慎用java动态编译

目录