我的编程空间,编程开发者的网络收藏夹
学习永远不晚

Tensorflow与RNN、双向LSTM等的踩坑记录及解决

短信预约 -IT技能 免费直播动态提醒
省份

北京

  • 北京
  • 上海
  • 天津
  • 重庆
  • 河北
  • 山东
  • 辽宁
  • 黑龙江
  • 吉林
  • 甘肃
  • 青海
  • 河南
  • 江苏
  • 湖北
  • 湖南
  • 江西
  • 浙江
  • 广东
  • 云南
  • 福建
  • 海南
  • 山西
  • 四川
  • 陕西
  • 贵州
  • 安徽
  • 广西
  • 内蒙
  • 西藏
  • 新疆
  • 宁夏
  • 兵团
手机号立即预约

请填写图片验证码后获取短信验证码

看不清楚,换张图片

免费获取短信验证码

Tensorflow与RNN、双向LSTM等的踩坑记录及解决

1、tensorflow(不定长)文本序列读取与解析

tensorflow读取csv时需要指定各列的数据类型。

但是对于RNN这种接受序列输入的模型来说,一条序列的长度是不固定。这时如果使用csv存储序列数据,应当首先将特征序列拼接成一列。

例如两条数据序列,第一项是标签,之后是特征序列

[0, 1.1, 1.2, 2.3] 转换成 [0, '1.1_1.2_2.3']

[1, 1.0, 2.5, 1.6, 3.2, 4.5] 转换成 [1, '1.0_2.5_1.6_3.2_4.5']

这样每条数据都只包含固定两列了。

读取方式是指定第二列为字符串类型,再将字符串按照'_'分割并转换为数字。

关键的几行代码示例如下:


def readMyFileFormat(fileNameQueue):
    reader = tf.TextLineReader()
    key, value = reader.read(fileNameQueue)

    record_defaults = [["Null"], [-1], ["Null"], ["Null"], [-1]]
    phone1, seqlen, ts_diff_strseq, t_cod_strseq, userlabel = tf.decode_csv(value, record_defaults=record_defaults)
    ts_diff_str = tf.string_split([ts_diff_strseq], delimiter='_')
    t_cod_str = tf.string_split([t_cod_strseq], delimiter='_')
    # 每个字符串转数字
    Str2Float = lambda string: tf.string_to_number(string, tf.float32)
    Str2Int = lambda string: tf.string_to_number(string, tf.int32)
    ts_diff_seq = tf.map_fn(Str2Float, ts_diff_str.values, dtype = tf.float32) # 一定要加上dtype,且必须与fn的输出类型一致
    t_cod_seq = tf.map_fn(Str2Int, t_cod_str.values, dtype = tf.int32)

2、时序建模的序列预测、序列拟合、标签预测,及输入数据格式

序列预测、拟合的“标签”都是序列本身,区别是未来时刻或者是当前时刻,当前时刻的拟合任务类似于antoencoder的reconstruction

标签预测常见于语言学建模,有单词级标签的分词与整句标签的情感分析,前者需要对每一个单词输入都要输出其分词标识,后者是取最后若干输出级联前馈神经网络分类器

keras的输入-输出对:需要将序列拆分成多个片段

序列形式:

按时间列表:static_bidirectional_rnn

多维数组:bidirectional_dynamic_rnn与stack_bidirectional_dynamic_rnn 变长双向rnn的正确使用姿势

3、多任务设置及相应的输出向量划分

对于标签预测任务,按需取输出即可

对于序列预测、拟合:

双向lstm:通常用于拟合。但如果需要捕捉动态信息,尽管需要序列完整输入,则仍可以加上正向预测与反向预测

单向lstm:拟合与预测

4、zero padding

后一般需要通过tf.boolean_mask()隔离这些零的影响,函数输入包括数据矩阵和补零位置的指示矩阵。

5、get_shape()方法

与 tf.shape() 类型区别,前者得到一个list,后者得到一个tensor

6、双向LSTM的信息瓶颈的解决

在这里插入图片描述

如果在时间步的最后输出,则可能会导致开始的一些字符被遗忘门给遗忘。

所以这里就对每个时间步的输出做出了处理,

主要处理有:

1、拼接:把所有的输出拼接在一起。

2、Average

3、Pooling

以上为个人经验,希望能给大家一个参考,也希望大家多多支持编程网。

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

Tensorflow与RNN、双向LSTM等的踩坑记录及解决

下载Word文档到电脑,方便收藏和打印~

下载Word文档

编程热搜

  • Python 学习之路 - Python
    一、安装Python34Windows在Python官网(https://www.python.org/downloads/)下载安装包并安装。Python的默认安装路径是:C:\Python34配置环境变量:【右键计算机】--》【属性】-
    Python 学习之路 - Python
  • chatgpt的中文全称是什么
    chatgpt的中文全称是生成型预训练变换模型。ChatGPT是什么ChatGPT是美国人工智能研究实验室OpenAI开发的一种全新聊天机器人模型,它能够通过学习和理解人类的语言来进行对话,还能根据聊天的上下文进行互动,并协助人类完成一系列
    chatgpt的中文全称是什么
  • C/C++中extern函数使用详解
  • C/C++可变参数的使用
    可变参数的使用方法远远不止以下几种,不过在C,C++中使用可变参数时要小心,在使用printf()等函数时传入的参数个数一定不能比前面的格式化字符串中的’%’符号个数少,否则会产生访问越界,运气不好的话还会导致程序崩溃
    C/C++可变参数的使用
  • css样式文件该放在哪里
  • php中数组下标必须是连续的吗
  • Python 3 教程
    Python 3 教程 Python 的 3.0 版本,常被称为 Python 3000,或简称 Py3k。相对于 Python 的早期版本,这是一个较大的升级。为了不带入过多的累赘,Python 3.0 在设计的时候没有考虑向下兼容。 Python
    Python 3 教程
  • Python pip包管理
    一、前言    在Python中, 安装第三方模块是通过 setuptools 这个工具完成的。 Python有两个封装了 setuptools的包管理工具: easy_install  和  pip , 目前官方推荐使用 pip。    
    Python pip包管理
  • ubuntu如何重新编译内核
  • 改善Java代码之慎用java动态编译

目录