我的编程空间,编程开发者的网络收藏夹
学习永远不晚

tensorflow使用tf.data.Dataset处理大型数据集问题

短信预约 -IT技能 免费直播动态提醒
省份

北京

  • 北京
  • 上海
  • 天津
  • 重庆
  • 河北
  • 山东
  • 辽宁
  • 黑龙江
  • 吉林
  • 甘肃
  • 青海
  • 河南
  • 江苏
  • 湖北
  • 湖南
  • 江西
  • 浙江
  • 广东
  • 云南
  • 福建
  • 海南
  • 山西
  • 四川
  • 陕西
  • 贵州
  • 安徽
  • 广西
  • 内蒙
  • 西藏
  • 新疆
  • 宁夏
  • 兵团
手机号立即预约

请填写图片验证码后获取短信验证码

看不清楚,换张图片

免费获取短信验证码

tensorflow使用tf.data.Dataset处理大型数据集问题

最近深度学习用到的数据集比较大,如果一次性将数据集读入内存,那服务器是顶不住的,所以需要分批进行读取,这里就用到了tf.data.Dataset构建数据集:

概括一下,tf.data.Dataset主要有几个部分最重要:

  • 构建生成器函数
  • 使用tf.data.Dataset的from_generator函数,通过指定数据类型,数据的shape等参数,构建一个Dataset
  • 指定batch_size
  • 使用make_one_shot_iterator()函数,构建一个iterator
  • 使用上面构建的迭代器开始get_next() 。(必须要有这个get_next(),迭代器才会工作)

一.构建生成器

生成器的要点是要在while True中加入yield,yield的功能有点类似return,有yield才能起到迭代的作用。

我的数据是一个[6047, 6000, 1]的文本数据,我每次迭代返回的shape为[1,6000,1],要注意的是返回的shape要和构建Dataset时的shape一致,下面会说到。

代码如下:

def gen():                
        train=pd.read_csv('/home/chenqiren/PycharmProjects/code/test/formal/small_sample/train2.csv', header=None)
        train.fillna(0, inplace = True)
        label_encoder = LabelEncoder().fit(train[6000])
        label = label_encoder.transform(train[6000])  
        train = train.drop([6000], axis=1) 
        scaler = StandardScaler().fit(train.values)   #train.values中的值是csv文件中的那些值,     这步标准化可以保留
        scaled_train = scaler.transform(train.values)
        #print(scaled_train)
        #拆分训练集和测试集--------------
        sss=StratifiedShuffleSplit(test_size=0.1, random_state=23)
        for train_index, valid_index in sss.split(scaled_train, label):   #需要的是数组,train.values得到的是数组
            X_train, X_valid=scaled_train[train_index], scaled_train[valid_index]  #https://www.cnblogs.com/Allen-rg/p/9453949.html
            y_train, y_valid=label[train_index], label[valid_index]
        X_train_r=np.zeros((len(X_train), 6000, 1))   #先构建一个框架出来,下面再赋值
        X_train_r[:,: ,0]=X_train[:,0:6000]     
    
        X_valid_r=np.zeros((len(X_valid), 6000, 1))
        X_valid_r[:,: ,0]=X_valid[:,0:6000]
    
        y_train=np_utils.to_categorical(y_train, 3)
        y_valid=np_utils.to_categorical(y_valid, 3)
        
        leng=len(X_train_r)
        index=0
        while True:
            x_train_batch=X_train_r[index, :, 0:1]
            y_train_batch=y_train[index, :]
            yield (x_train_batch, y_train_batch)
            index=index+1
            if index>leng:
                break

代码中while True上面的部分是标准化数据的代码,可以不用看,只需要看 while True中的代码即可。

x_train_batch, y_train_batch都只是一行的数据,这里是一行一行数据迭代。

二.使用tf.data.Dataset包装生成器

data=tf.data.Dataset.from_generator(gen_1, (tf.float32, tf.float32), (tf.TensorShape([6000,1]), tf.TensorShape([3])))
data=data.batch(128)
iterator=data.make_one_shot_iterator()

这里的tf.TensorShape([6000,1]) 和 tf.TensorShape([3])中的shape要和上面生成器yield返回的数据的shape一致。

  • data=data.batch(128)是设置batchsize,这里设为128,在运行时,因为我们yield的是一行的数据[1, 6000, 1],所以将会循环yield够128次,得到[128, 6000, 1],即一个batch,才会开始训练。
  • iterator=data.make_one_shot_iterator()是构建迭代器,one_shot迭代器人如其名,意思就是数据输出一次后就丢弃了。

三.获取生成器返回的数据

x, y=iterator.get_next()
x_batch, y_batch=sess.run([x,y])

注意要有get_next(),迭代器才能开始工作。

第二行是run第一行代码。获取训练数据和训练标签。

这里做个关于yield的小笔记:

上一次迭代,yield返回了值,然后get_next()开启了下一次迭代,此时,程序是从yield处开始运行的,也就是说,如果yield后面还有程序,那就会运行yield后面的程序。一直运行的是while True中的程序,没有运行while True外面的程序。

下面是我写的总的代码。可以不用看。

import os
import keras
import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.model_selection import train_test_split
from keras.models import Sequential, Model
from keras.layers import Dense, Activation, Flatten, Conv1D, Dropout, MaxPooling1D, GlobalAveragePooling1D
from keras.layers import GlobalAveragePooling2D,BatchNormalization, UpSampling1D, RepeatVector,Reshape
from keras.layers.core import Lambda
from keras.optimizers import SGD, Adam, Adadelta
from keras.utils import np_utils
from keras.applications.inception_resnet_v2 import InceptionResNetV2
from keras.backend import conv3d,reshape, shape, categorical_crossentropy, mean, square
from keras.applications.vgg16 import VGG16
from keras.layers import Input,LSTM
from keras import regularizers
from keras.utils import multi_gpu_model
import tensorflow as tf
import keras.backend.tensorflow_backend as KTF
os.environ["CUDA_VISIBLE_DEVICES"]="2" 
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
session = tf.Session(config=config)
keep_prob = tf.placeholder("float")
# 设置session
KTF.set_session(session )

#-----生成训练数据-----------------------------------------------
def gen_1():
    train=pd.read_csv('/home/chenqiren/PycharmProjects/code/test/formal/small_sample/train2.csv', header=None)
    train.fillna(0, inplace = True)
    label_encoder = LabelEncoder().fit(train[6000])
    label = label_encoder.transform(train[6000])  
    train = train.drop([6000], axis=1) 
    scaler = StandardScaler().fit(train.values)   #train.values中的值是csv文件中的那些值,     这步标准化可以保留
    scaled_train = scaler.transform(train.values)
    #print(scaled_train)
    #拆分训练集和测试集--------------
    sss=StratifiedShuffleSplit(test_size=0.1, random_state=23)
    for train_index, valid_index in sss.split(scaled_train, label):   #需要的是数组,train.values得到的是数组
        X_train, X_valid=scaled_train[train_index], scaled_train[valid_index]  #https://www.cnblogs.com/Allen-rg/p/9453949.html
        y_train, y_valid=label[train_index], label[valid_index]
    X_train_r=np.zeros((len(X_train), 6000, 1))   #先构建一个框架出来,下面再赋值
    #开始赋值
    #https://stackoverflow.com/questions/43290202/python-typeerror-unhashable-type-slice-for-encoding-categorical-data
    X_train_r[:,: ,0]=X_train[:,0:6000]     

    X_valid_r=np.zeros((len(X_valid), 6000, 1))
    X_valid_r[:,: ,0]=X_valid[:,0:6000]

    y_train=np_utils.to_categorical(y_train, 3)
    y_valid=np_utils.to_categorical(y_valid, 3)
    
    leng=len(X_train_r)
    index=0
    while True:
        x_train_batch=X_train_r[index, :, 0:1]
        y_train_batch=y_train[index, :]
        yield (x_train_batch, y_train_batch)
        index=index+1
        if index>leng:
            break
        
#----生成测试数据--------------------------------------
def gen_2():
    train=pd.read_csv('/home/chenqiren/PycharmProjects/code/test/formal/small_sample/train2.csv', header=None)
    train.fillna(0, inplace = True)
    label_encoder = LabelEncoder().fit(train[6000])
    label = label_encoder.transform(train[6000])  
    train = train.drop([6000], axis=1) 
    scaler = StandardScaler().fit(train.values)   #train.values中的值是csv文件中的那些值,     这步标准化可以保留
    scaled_train = scaler.transform(train.values)
    #print(scaled_train)
    #拆分训练集和测试集--------------
    sss=StratifiedShuffleSplit(test_size=0.1, random_state=23)
    for train_index, valid_index in sss.split(scaled_train, label):   #需要的是数组,train.values得到的是数组
        X_train, X_valid=scaled_train[train_index], scaled_train[valid_index]  #https://www.cnblogs.com/Allen-rg/p/9453949.html
        y_train, y_valid=label[train_index], label[valid_index]
    X_train_r=np.zeros((len(X_train), 6000, 1))   #先构建一个框架出来,下面再赋值
    #开始赋值
    #https://stackoverflow.com/questions/43290202/python-typeerror-unhashable-type-slice-for-encoding-categorical-data
    X_train_r[:,: ,0]=X_train[:,0:6000]     

    X_valid_r=np.zeros((len(X_valid), 6000, 1))
    X_valid_r[:,: ,0]=X_valid[:,0:6000]

    y_train=np_utils.to_categorical(y_train, 3)
    y_valid=np_utils.to_categorical(y_valid, 3)
    
    leng=len(X_valid_r)
    index=0
    while True:
        x_test_batch=X_valid_r[index, :, 0:1]
        y_test_batch=y_valid[index, :]
        yield (x_test_batch, y_test_batch)
        index=index+1
        if index>leng:
            break
        
#---------------------------------------------------------------------
        
def custom_mean_squared_error(y_true, y_pred):
    return mean(square(y_pred - y_true))
def custom_categorical_crossentropy(y_true, y_pred):
    return categorical_crossentropy(y_true, y_pred)

def loss_func(y_loss, x_loss):
    return categorical_crossentropy + 0.05 * mean_squared_error

#建立模型
with tf.device('/cpu:0'):
    inputs1=tf.placeholder(tf.float32, shape=(None,6000,1))

    x1=LSTM(128, return_sequences=True)(inputs1)
    encoded=LSTM(64 ,return_sequences=True)(x1)
    print('encoded shape:',shape(encoded))

    #decode
    x1=LSTM(128, return_sequences=True)(encoded)
    decoded=LSTM(1, return_sequences=True,name='decode')(x1)
    #classify
    labels=tf.placeholder(tf.float32, shape=(None,3))
    x2=Conv1D(20,kernel_size=50, strides=2, activation='relu' )(encoded)  #步数论文中未提及,第一层
    x2=MaxPooling1D(pool_size=2, strides=1)(x2)
    x2=Conv1D(20,kernel_size=50, strides=2, activation='relu')(x2)   #第二层
    x2=MaxPooling1D(pool_size=2, strides=1)(x2)
    x2=Dropout(0.25)(x2)
    x2=Conv1D(24,kernel_size=30, strides=2, activation='relu')(x2)   #第三层
    x2=MaxPooling1D(pool_size=2, strides=1)(x2)
    x2=Dropout(0.25)(x2)
    x2=Conv1D(24,kernel_size=30, strides=2, activation='relu')(x2)   #第四层
    x2=MaxPooling1D(pool_size=2, strides=1)(x2)
    x2=Dropout(0.25)(x2)
    x2=Conv1D(24,kernel_size=10, strides=2, activation='relu')(x2)  #第五层
    x2=MaxPooling1D(pool_size=2, strides=1)(x2)
    x2=Dropout(0.25)(x2)

    x2=Dense(192)(x2) #第一个全连接层
    x2=Dense(192)(x2)  #第二个全连接层
    x2=Flatten()(x2)
    x2=Dense(3,activation='softmax', name='classify')(x2)

    def get_accuracy(x2, labels):
        current = tf.cast(tf.equal(tf.argmax(x2, 1), tf.argmax(labels, 1)), 'float')
        accuracy = tf.reduce_mean(current)
        return accuracy
    #实例化获取准确率函数
    getAccuracy = get_accuracy(x2, labels)
    #定义损失函数
    all_loss=tf.reduce_mean(categorical_crossentropy(x2 , labels) + tf.convert_to_tensor(0.5)*square(decoded-inputs1))
    train_option=tf.train.AdamOptimizer(0.01).minimize(all_loss)
    #-----------------------------------------
    #生成训练数据
    data=tf.data.Dataset.from_generator(gen_1, (tf.float32, tf.float32), (tf.TensorShape([6000,1]), tf.TensorShape([3])))
    data=data.batch(128)
    iterator=data.make_one_shot_iterator()
    
    #生成测试数据
    data2=tf.data.Dataset.from_generator(gen_2, (tf.float32, tf.float32), (tf.TensorShape([6000,1]), tf.TensorShape([3])))
    data2=data2.batch(128)
    iterator2=data2.make_one_shot_iterator()
    #-----------------------------------------
    with tf.Session() as sess:
        init=tf.global_variables_initializer()
        sess.run(init)
        i=-1
        
        for k in range(20):
            #-----------------------------------------
            x, y=iterator.get_next()
            x_batch, y_batch=sess.run([x,y])
            print('batch shape:',x_batch.shape, y_batch.shape)
            #-----------------------------------------
            if k%2==0:
                print('第',k,'轮')
                x3=sess.run(x2, feed_dict={inputs1:x_batch, labels:y_batch })
                dc=sess.run(decoded, feed_dict={inputs1:x_batch})
                accuracy=sess.run(getAccuracy, feed_dict={x2:x3, labels:y_batch, keep_prob: 1.0})
                loss=sess.run(all_loss, feed_dict={x2:x3, labels:y_batch, inputs1:x_batch, decoded:dc})
                print("step(s): %d ----- accuracy: %g -----loss: %g" % (i, accuracy, loss))
                sess.run(train_option, feed_dict={inputs1:x_batch, labels:y_batch, keep_prob: 0.5})
        x, y=iterator2.get_next()
        x_test_batch, y_test_batch=sess.run([x,y])
        print('batch shape:',x_test_batch.shape, y_test_batch.shape)
        x_test=sess.run(x2, feed_dict={inputs1:x_test_batch, labels:y_test_batch })
        print ("test accuracy %f"%getAccuracy.eval(feed_dict={x2:x_test, labels:y_test_batch, keep_prob: 1.0}))

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持编程网。

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

tensorflow使用tf.data.Dataset处理大型数据集问题

下载Word文档到电脑,方便收藏和打印~

下载Word文档

猜你喜欢

tensorflow使用tf.data.Dataset处理大型数据集问题

这篇文章主要介绍了tensorflow使用tf.data.Dataset处理大型数据集问题,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
2022-12-16

如何在TensorFlow中使用数据集API加载和处理数据

在TensorFlow中,可以使用数据集API来加载和处理数据。下面是一个简单的例子,展示如何使用数据集API加载和处理数据:import tensorflow as tf# 创建一个数据集data = tf.data.Dataset.
如何在TensorFlow中使用数据集API加载和处理数据
2024-03-01

MariaDB中如何处理大型数据集

在MariaDB中处理大型数据集时,可以采取以下几种方法:数据分区:可以将大型表拆分成多个小表,每个小表处理的数据量更小,查询效率更高。可以按照时间范围、地理位置等条件对数据进行分区。索引优化:通过在表的列上创建适当的索引,可以加快查询速
MariaDB中如何处理大型数据集
2024-04-09

PostgreSQL中如何处理大型数据集和高并发访问

处理大型数据集和高并发访问是 PostgreSQL 数据库管理员经常面临的挑战之一。以下是一些处理大型数据集和高并发访问的常用方法:分区表:将数据表按照某种规则进行分区,可以将大型数据集分解成更小的部分,便于管理和查询。这样可以减少查询时需
PostgreSQL中如何处理大型数据集和高并发访问
2024-04-09

如何使用泛型解决golang中数据处理问题

go 中的泛型允许创建处理各种类型数据的函数和类型,从而简化数据处理。它通过类型参数实现,这些参数可以在函数和类型中使用,强制执行类型安全并提高代码重用性、可读性和可维护性。如何使用泛型解决 Go 中的数据处理问题背景在 Go 1.18
如何使用泛型解决golang中数据处理问题
2024-05-04

使用C++构建机器学习模型:大型数据集的处理技巧

通过利用 c++++ 的优势,我们可以构建机器学习模型来处理大型数据集:优化内存管理:使用智能指针(如 unique_ptr、shared_ptr)使用内存池并行化处理:多线程(使用 std::thread 库)openmp 并行编程标准c
使用C++构建机器学习模型:大型数据集的处理技巧
2024-05-12

使用Golang函数处理大数据集的策略

在 golang 中处理大数据集时,有效运用函数式特性至关重要,高阶函数(map、filter、reduce)可高效操作集合。此外,并发处理(goroutine 和 sync.waitgroup)和流式处理(channel 和 for-ra
使用Golang函数处理大数据集的策略
2024-04-12

C#开发中如何处理大数据集的操作问题

C#开发中如何处理大数据集的操作问题,需要具体代码示例摘要:在现代软件开发中,大数据已成为一种常见的数据处理形式。如何高效地处理大数据集是一个重要的问题。本文将介绍C#中处理大数据集的一些常见问题和解决方法,并提供具体的代码示例。数据集拆分
2023-10-22

巨大数据集处理:使用Go WaitGroup优化性能

在处理巨大数据集时,使用Go的WaitGroup可以帮助优化性能。WaitGroup是Go语言中用于等待一组goroutine完成任务的机制。下面是使用WaitGroup优化性能的基本步骤:1. 创建WaitGroup对象:在开始处理数据集
2023-10-12

STL 函数对象在优化大型数据集处理中的作用?

使用 stl 函数对象可以显著优化大型数据集处理。stl 提供了许多函数对象,例如 std::function、std::bind、std::for_each、std::transform 和 std::sort,它们可以用来提升处理效率。
STL 函数对象在优化大型数据集处理中的作用?
2024-04-26

编程热搜

  • Python 学习之路 - Python
    一、安装Python34Windows在Python官网(https://www.python.org/downloads/)下载安装包并安装。Python的默认安装路径是:C:\Python34配置环境变量:【右键计算机】--》【属性】-
    Python 学习之路 - Python
  • chatgpt的中文全称是什么
    chatgpt的中文全称是生成型预训练变换模型。ChatGPT是什么ChatGPT是美国人工智能研究实验室OpenAI开发的一种全新聊天机器人模型,它能够通过学习和理解人类的语言来进行对话,还能根据聊天的上下文进行互动,并协助人类完成一系列
    chatgpt的中文全称是什么
  • C/C++中extern函数使用详解
  • C/C++可变参数的使用
    可变参数的使用方法远远不止以下几种,不过在C,C++中使用可变参数时要小心,在使用printf()等函数时传入的参数个数一定不能比前面的格式化字符串中的’%’符号个数少,否则会产生访问越界,运气不好的话还会导致程序崩溃
    C/C++可变参数的使用
  • css样式文件该放在哪里
  • php中数组下标必须是连续的吗
  • Python 3 教程
    Python 3 教程 Python 的 3.0 版本,常被称为 Python 3000,或简称 Py3k。相对于 Python 的早期版本,这是一个较大的升级。为了不带入过多的累赘,Python 3.0 在设计的时候没有考虑向下兼容。 Python
    Python 3 教程
  • Python pip包管理
    一、前言    在Python中, 安装第三方模块是通过 setuptools 这个工具完成的。 Python有两个封装了 setuptools的包管理工具: easy_install  和  pip , 目前官方推荐使用 pip。    
    Python pip包管理
  • ubuntu如何重新编译内核
  • 改善Java代码之慎用java动态编译

目录