我的编程空间,编程开发者的网络收藏夹
学习永远不晚

Pytorch中TensorDataset与DataLoader怎么联合使用

短信预约 -IT技能 免费直播动态提醒
省份

北京

  • 北京
  • 上海
  • 天津
  • 重庆
  • 河北
  • 山东
  • 辽宁
  • 黑龙江
  • 吉林
  • 甘肃
  • 青海
  • 河南
  • 江苏
  • 湖北
  • 湖南
  • 江西
  • 浙江
  • 广东
  • 云南
  • 福建
  • 海南
  • 山西
  • 四川
  • 陕西
  • 贵州
  • 安徽
  • 广西
  • 内蒙
  • 西藏
  • 新疆
  • 宁夏
  • 兵团
手机号立即预约

请填写图片验证码后获取短信验证码

看不清楚,换张图片

免费获取短信验证码

Pytorch中TensorDataset与DataLoader怎么联合使用

这篇文章主要介绍了Pytorch中TensorDataset与DataLoader怎么联合使用的相关知识,内容详细易懂,操作简单快捷,具有一定借鉴价值,相信大家阅读完这篇Pytorch中TensorDataset与DataLoader怎么联合使用文章都会有所收获,下面我们一起来看看吧。

    Pytorch中TensorDataset,DataLoader的联合使用

    首先从字面意义上来理解TensorDataset和DataLoader,TensorDataset是个只用来存放tensor(张量)的数据集,而DataLoader是一个数据加载器,一般用到DataLoader的时候就说明需要遍历和操作数据了。

    TensorDataset(tensor1,tensor2)的功能就是形成数据tensor1和标签tensor2的对应,也就是说tensor1中是数据,而tensor2是tensor1所对应的标签。

    来个小例子

    from torch.utils.data import TensorDataset,DataLoaderimport torch a = torch.tensor([[1, 2, 3],                  [4, 5, 6],                  [7, 8, 9],                  [1, 2, 3],                  [4, 5, 6],                  [7, 8, 9],                  [1, 2, 3],                  [4, 5, 6],                  [7, 8, 9],                  [1, 2, 3],                  [4, 5, 6],                  [7, 8, 9]]) b = torch.tensor([44, 55, 66, 44, 55, 66, 44, 55, 66, 44, 55, 66])train_ids = TensorDataset(a,b)# 切片输出print(train_ids[0:4]) # 第0,1,2,3行# 循环取数据for x_train,y_label in train_ids:    print(x_train,y_label)

    下面是对应的输出:

    (tensor([[1, 2, 3],
            [4, 5, 6],
            [7, 8, 9],
            [1, 2, 3]]), tensor([44, 55, 66, 44]))
    ===============================================
    tensor([1, 2, 3]) tensor(44)
    tensor([4, 5, 6]) tensor(55)
    tensor([7, 8, 9]) tensor(66)
    tensor([1, 2, 3]) tensor(44)
    tensor([4, 5, 6]) tensor(55)
    tensor([7, 8, 9]) tensor(66)
    tensor([1, 2, 3]) tensor(44)
    tensor([4, 5, 6]) tensor(55)
    tensor([7, 8, 9]) tensor(66)
    tensor([1, 2, 3]) tensor(44)
    tensor([4, 5, 6]) tensor(55)
    tensor([7, 8, 9]) tensor(66)

    从输出结果我们就可以很好的理解,tensor型数据和tensor型标签的对应了,这就是TensorDataset的基本应用。

    接下来我们把构造好的TensorDataset封装到DataLoader来操作里面的数据:

    # 参数说明,dataset=train_ids表示需要封装的数据集,batch_size表示一次取几个# shuffle表示乱序取数据,设为False表示顺序取数据,True表示乱序取数据train_loader = DataLoader(dataset=train_ids,batch_size=4,shuffle=False)# 注意enumerate返回值有两个,一个是序号,一个是数据(包含训练数据和标签)for i,data in enumerate(train_loader,1):    train_data, label = data    print(' batch:{0} train_data:{1}  label: {2}'.format(i+1, train_data, label))

    下面是对应的输出:

     batch:1 x_data:tensor([[1, 2, 3],
            [4, 5, 6],
            [7, 8, 9],
            [1, 2, 3]])  label: tensor([44, 55, 66, 44])
     batch:2 x_data:tensor([[4, 5, 6],
            [7, 8, 9],
            [1, 2, 3],
            [4, 5, 6]])  label: tensor([55, 66, 44, 55])
     batch:3 x_data:tensor([[7, 8, 9],
            [1, 2, 3],
            [4, 5, 6],
            [7, 8, 9]])  label: tensor([66, 44, 55, 66])

    至此,TensorDataset和DataLoader的联合使用就介绍完了。

    我们再看一下这两种方法的源码:

    class TensorDataset(Dataset[Tuple[Tensor, ...]]):    r"""Dataset wrapping tensors.    Each sample will be retrieved by indexing tensors along the first dimension.    Arguments:        *tensors (Tensor): tensors that have the same size of the first dimension.    """    tensors: Tuple[Tensor, ...]     def __init__(self, *tensors: Tensor) -> None:        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)        self.tensors = tensors     def __getitem__(self, index):        return tuple(tensor[index] for tensor in self.tensors)     def __len__(self):        return self.tensors[0].size(0) # 由于此类内容过多,故仅列举了与本文相关的参数,其余参数可以自行去查看源码class DataLoader(Generic[T_co]):    r"""    Data loader. Combines a dataset and a sampler, and provides an iterable over    the given dataset.    The :class:`~torch.utils.data.DataLoader` supports both map-style and    iterable-style datasets with single- or multi-process loading, customizing    loading order and optional automatic batching (collation) and memory pinning.    See :py:mod:`torch.utils.data` documentation page for more details.    Arguments:        dataset (Dataset): dataset from which to load the data.        batch_size (int, optional): how many samples per batch to load            (default: ``1``).        shuffle (bool, optional): set to ``True`` to have the data reshuffled            at every epoch (default: ``False``).    """    dataset: Dataset[T_co]    batch_size: Optional[int]     def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,                 shuffle: bool = False):         self.dataset = dataset        self.batch_size = batch_size

    Pytorch的DataLoader和Dataset以及TensorDataset的源码分析

    1.为什么要用DataLoader和Dataset

    要对大量数据进行加载和处理时因为可能会出现内存不够用的情况,这时候就需要用到数据集类Dataset或TensorDataset和数据集加载类DataLoader了。

    使用这些类后可以将原本的数据分成小块,在需要使用的时候再一部分一本分读进内存中,而不是一开始就将所有数据读进内存中。

    2.Dateset的使用

    pytorch中的torch.utils.data.Dataset是表示数据集的抽象类,但它一般不直接使用,而是通过自定义一个数据集来使用。

    来自定义数据集应该继承Dataset并应该有实现返回数据集尺寸的__len__方法和用来获取索引数据的__getitem__方法。

    Dataset类的源码如下:

    class Dataset(object):    r"""An abstract class representing a :class:`Dataset`.    All datasets that represent a map from keys to data samples should subclass    it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a    data sample for a given key. Subclasses could also optionally overwrite    :meth:`__len__`, which is expected to return the size of the dataset by many    :class:`~torch.utils.data.Sampler` implementations and the default options    of :class:`~torch.utils.data.DataLoader`.    .. note::      :class:`~torch.utils.data.DataLoader` by default constructs a index      sampler that yields integral indices.  To make it work with a map-style      dataset with non-integral indices/keys, a custom sampler must be provided.    """    def __getitem__(self, index):        raise NotImplementedError    def __add__(self, other):        return ConcatDataset([self, other])    # No `def __len__(self)` default?    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]    # in pytorch/torch/utils/data/sampler.py

    可以看到Dataset类中没有__len__方法,虽然有__getitem__方法,但是并没有实现啥有用的功能。

    所以要写一个Dataset类的子类来实现其应有的功能。

    自定义类的实现举例:

    import torchfrom torch.utils.data import Dataset, DataLoader, TensorDatasetfrom torch.autograd import Variableimport numpy as npimport pandas as pdvalue_df = pd.read_csv('data1.csv')value_array = np.array(value_df)print("value_array.shape =", value_array.shape)  # (73700, 300)value_size = value_array.shape[0]  # 73700train_size = int(0.7*value_size)train_array = val_array[:train_size]  train_label_array = val_array[60:train_size+60]class DealDataset(Dataset):    """        下载数据、初始化数据,都可以在这里完成    """    def __init__(self, *arrays):        assert all(arrays[0].shape[0] == array.shape[0] for array in arrays)        self.arrays = arrays    def __getitem__(self, index):        return tuple(array[index] for array in self.arrays)    def __len__(self):        return self.arrays[0].shape[0]# 实例化这个类,然后我们就得到了Dataset类型的数据,记下来就将这个类传给DataLoader,就可以了。train_dataset = DealDataset(train_array, train_label_array)train_loader2 = DataLoader(dataset=train_dataset,                           batch_size=32,                           shuffle=True)for epoch in range(2):    for i, data in enumerate(train_loader2):        # 将数据从 train_loader 中读出来,一次读取的样本数是32个        inputs, labels = data        # 将这些数据转换成Variable类型        inputs, labels = Variable(inputs), Variable(labels)        # 接下来就是跑模型的环节了,我们这里使用print来代替        print("epoch:", epoch, "的第", i, "个inputs", inputs.data.size(), "labels", labels.data.size())

    结果:

    epoch: 0 的第 0 个inputs torch.Size([32, 300]) labels torch.Size([32, 300])
    epoch: 0 的第 1 个inputs torch.Size([32, 300]) labels torch.Size([32, 300])
    epoch: 0 的第 2 个inputs torch.Size([32, 300]) labels torch.Size([32, 300])
    epoch: 0 的第 3 个inputs torch.Size([32, 300]) labels torch.Size([32, 300])
    epoch: 0 的第 4 个inputs torch.Size([32, 300]) labels torch.Size([32, 300])
    epoch: 0 的第 5 个inputs torch.Size([32, 300]) labels torch.Size([32, 300])
    ...

    3.TensorDataset的使用

    TensorDataset是可以直接使用的数据集类,它的源码如下:

    class TensorDataset(Dataset):    r"""Dataset wrapping tensors.    Each sample will be retrieved by indexing tensors along the first dimension.    Arguments:        *tensors (Tensor): tensors that have the same size of the first dimension.    """    def __init__(self, *tensors):        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)        self.tensors = tensors    def __getitem__(self, index):        return tuple(tensor[index] for tensor in self.tensors)    def __len__(self):        return self.tensors[0].size(0)

    可以看到TensorDataset类是Dataset类的子类,且拥有返回数据集尺寸的__len__方法和用来获取索引数据的__getitem__方法,所以可以直接使用。

    它的结构跟上面自定义的子类的结构是一样的,惟一的不同是TensorDataset已经规定了传入的数据必须是torch.Tensor类型的,而自定义子类可以自由设定。

    使用举例:

    import torchfrom torch.utils.data import Dataset, DataLoader, TensorDatasetfrom torch.autograd import Variableimport numpy as npimport pandas as pdvalue_df = pd.read_csv('data1.csv')value_array = np.array(value_df)print("value_array.shape =", value_array.shape)  # (73700, 300)value_size = value_array.shape[0]  # 73700train_size = int(0.7*value_size)train_array = val_array[:train_size]  train_tensor = torch.tensor(train_array, dtype=torch.float32).to(device)train_label_array = val_array[60:train_size+60]train_labels_tensor = torch.tensor(train_label_array,dtype=torch.float32).to(device)train_dataset = TensorDataset(train_tensor, train_labels_tensor)train_loader = DataLoader(dataset=train_dataset,                          batch_size=100,                          shuffle=False,                          num_workers=0)for epoch in range(2):    for i, data in enumerate(train_loader):        inputs, labels = data        inputs, labels = Variable(inputs), Variable(labels)        print(epoch, i, "inputs", inputs.data.size(), "labels", labels.data.size())

    结果:

    0 0 inputs torch.Size([100, 300]) labels torch.Size([100, 300])
    0 1 inputs torch.Size([100, 300]) labels torch.Size([100, 300])
    0 2 inputs torch.Size([100, 300]) labels torch.Size([100, 300])
    0 3 inputs torch.Size([100, 300]) labels torch.Size([100, 300])
    0 4 inputs torch.Size([100, 300]) labels torch.Size([100, 300])
    0 5 inputs torch.Size([100, 300]) labels torch.Size([100, 300])
    0 6 inputs torch.Size([100, 300]) labels torch.Size([100, 300])
    0 7 inputs torch.Size([100, 300]) labels torch.Size([100, 300])
    0 8 inputs torch.Size([100, 300]) labels torch.Size([100, 300])
    0 9 inputs torch.Size([100, 300]) labels torch.Size([100, 300])
    0 10 inputs torch.Size([100, 300]) labels torch.Size([100, 300])
    ...

    关于“Pytorch中TensorDataset与DataLoader怎么联合使用”这篇文章的内容就介绍到这里,感谢各位的阅读!相信大家对“Pytorch中TensorDataset与DataLoader怎么联合使用”知识都有一定的了解,大家如果还想学习更多知识,欢迎关注编程网行业资讯频道。

    免责声明:

    ① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

    ② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

    Pytorch中TensorDataset与DataLoader怎么联合使用

    下载Word文档到电脑,方便收藏和打印~

    下载Word文档

    猜你喜欢

    Pytorch中TensorDataset与DataLoader怎么联合使用

    这篇文章主要介绍了Pytorch中TensorDataset与DataLoader怎么联合使用的相关知识,内容详细易懂,操作简单快捷,具有一定借鉴价值,相信大家阅读完这篇Pytorch中TensorDataset与DataLoader怎么联
    2023-07-05

    Pytorch中TensorDataset,DataLoader的联合使用方式

    这篇文章主要介绍了Pytorch中TensorDataset,DataLoader的联合使用方式,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2023-02-20

    PyTorch的TensorDataset功能怎么使用

    本文小编为大家详细介绍“PyTorch的TensorDataset功能怎么使用”,内容详细,步骤清晰,细节处理妥当,希望这篇“PyTorch的TensorDataset功能怎么使用”文章能帮助大家解决疑惑,下面跟着小编的思路慢慢深入,一起来
    2023-07-05

    PyTorch中怎么使用DataLoader加载数据

    在PyTorch中使用DataLoader加载数据主要有以下几个步骤:创建数据集对象:首先,需要创建一个数据集对象,该数据集对象必须继承自torch.utils.data.Dataset类,并实现__len__和__getitem__方法。
    PyTorch中怎么使用DataLoader加载数据
    2024-03-05

    FlexBuilder3.0与Eclipse3.4怎么联合使用

    这篇文章主要介绍了FlexBuilder3.0与Eclipse3.4怎么联合使用,具有一定借鉴价值,感兴趣的朋友可以参考下,希望大家阅读完这篇文章之后大有收获,下面让小编带着大家一起了解一下。FlexBuilder3.0ForEclipse
    2023-06-17

    Optimizer与optimizer.step()怎么在pytorch中使用

    今天就跟大家聊聊有关Optimizer与optimizer.step()怎么在pytorch中使用,可能很多人都不太了解,为了让大家更加了解,小编给大家总结了以下内容,希望大家根据这篇文章可以有所收获。当我们想指定每一层的学习率时:opti
    2023-06-15

    parameter与buffer怎么在Pytorch模型中使用

    本篇文章给大家分享的是有关parameter与buffer怎么在Pytorch模型中使用,小编觉得挺实用的,因此分享给大家学习,希望大家阅读完这篇文章后可以有所收获,话不多说,跟着小编一起来看看吧。第一种参数有两种方式我们可以直接将模型的成
    2023-06-15

    Pytorch中的Tensorboard与Transforms怎么搭配使用

    这篇文章主要介绍了Pytorch中的Tensorboard与Transforms怎么搭配使用,具有一定借鉴价值,感兴趣的朋友可以参考下,希望大家阅读完这篇文章之后大有收获,下面让小编带着大家一起了解一下。直接上代码:from PIL imp
    2023-06-22

    Python中Pytorch怎么使用

    这篇文章将为大家详细讲解有关Python中Pytorch怎么使用,小编觉得挺实用的,因此分享给大家做个参考,希望大家阅读完这篇文章后可以有所收获。一、TensorTensor(张量是一个统称,其中包括很多类型):0阶张量:标量、常数、0-D
    2023-06-15

    MySQL联合索引怎么使用

    MySQL联合索引是指在一个表中同时使用多个列作为索引的方式,可以提高查询效率。使用方法如下:创建联合索引:ALTER TABLE 表名 ADD INDEX 索引名称 (列1, 列2, 列3, ...);例如:ALTER TABLE
    2023-10-27

    PyTorch中torch.utils.data.DataLoader怎么使用

    这篇文章主要介绍“PyTorch中torch.utils.data.DataLoader怎么使用”,在日常操作中,相信很多人在PyTorch中torch.utils.data.DataLoader怎么使用问题上存在疑惑,小编查阅了各式资料,
    2023-07-02

    pytorch中nn.Dropout怎么使用

    小编给大家分享一下pytorch中nn.Dropout怎么使用,希望大家阅读完这篇文章之后都有所收获,下面让我们一起去探讨吧!看代码吧~Class USeDropout(nn.Module): def __init__(self):
    2023-06-15

    pytorch中[..., 0]怎么使用

    这篇文章将为大家详细讲解有关pytorch中[..., 0]怎么使用,小编觉得挺实用的,因此分享给大家做个参考,希望大家阅读完这篇文章后可以有所收获。在看程序的时候看到了x[…, 0]的语句不是很理解,后来自己做实验略微了解,以此记录方便自
    2023-06-15

    Pureftpd和PostgreSQL联合怎么使用

    小编给大家分享一下Pureftpd和PostgreSQL联合怎么使用,相信大部分人都还不怎么了解,因此分享这篇文章给大家参考一下,希望大家阅读完这篇文章后大有收获,下面让我们一起去了解一下吧!Pureftpd是一款在多种类Unix上使用并符
    2023-06-16

    pytorch中nn.RNN()怎么使用

    这篇文章主要介绍“pytorch中nn.RNN()怎么使用”,在日常操作中,相信很多人在pytorch中nn.RNN()怎么使用问题上存在疑惑,小编查阅了各式资料,整理出简单好用的操作方法,希望对大家解答”pytorch中nn.RNN()怎
    2023-07-04

    Pytorch中怎么使用TensorBoard

    本文小编为大家详细介绍“Pytorch中怎么使用TensorBoard”,内容详细,步骤清晰,细节处理妥当,希望这篇“Pytorch中怎么使用TensorBoard”文章能帮助大家解决疑惑,下面跟着小编的思路慢慢深入,一起来学习新知识吧。一
    2023-07-02

    pytorch中with torch.no_grad()怎么使用

    本篇内容主要讲解“pytorch中with torch.no_grad()怎么使用”,感兴趣的朋友不妨来看看。本文介绍的方法操作简单快捷,实用性强。下面就让小编来带大家学习“pytorch中with torch.no_grad()怎么使用”
    2023-06-29

    SpringBoot中怎么整合MyBatisPlus Join使用联表查询

    这篇文章主要介绍了SpringBoot中怎么整合MyBatisPlus Join使用联表查询的相关知识,内容详细易懂,操作简单快捷,具有一定借鉴价值,相信大家阅读完这篇SpringBoot中怎么整合MyBatisPlus Join使用联表查
    2023-07-05

    PyTorch中的nn.Embedding怎么使用

    这篇“PyTorch中的nn.Embedding怎么使用”文章的知识点大部分人都不太理解,所以小编给大家总结了以下内容,内容详细,步骤清晰,具有一定的借鉴价值,希望大家阅读完这篇文章能有所收获,下面我们一起来看看这篇“PyTorch中的nn
    2023-07-02

    Ant Design Vue中的table与pagination的联合使用方式

    这篇文章主要介绍了Ant Design Vue中的table与pagination的联合使用方式,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2022-11-13

    编程热搜

    • Python 学习之路 - Python
      一、安装Python34Windows在Python官网(https://www.python.org/downloads/)下载安装包并安装。Python的默认安装路径是:C:\Python34配置环境变量:【右键计算机】--》【属性】-
      Python 学习之路 - Python
    • chatgpt的中文全称是什么
      chatgpt的中文全称是生成型预训练变换模型。ChatGPT是什么ChatGPT是美国人工智能研究实验室OpenAI开发的一种全新聊天机器人模型,它能够通过学习和理解人类的语言来进行对话,还能根据聊天的上下文进行互动,并协助人类完成一系列
      chatgpt的中文全称是什么
    • C/C++中extern函数使用详解
    • C/C++可变参数的使用
      可变参数的使用方法远远不止以下几种,不过在C,C++中使用可变参数时要小心,在使用printf()等函数时传入的参数个数一定不能比前面的格式化字符串中的’%’符号个数少,否则会产生访问越界,运气不好的话还会导致程序崩溃
      C/C++可变参数的使用
    • css样式文件该放在哪里
    • php中数组下标必须是连续的吗
    • Python 3 教程
      Python 3 教程 Python 的 3.0 版本,常被称为 Python 3000,或简称 Py3k。相对于 Python 的早期版本,这是一个较大的升级。为了不带入过多的累赘,Python 3.0 在设计的时候没有考虑向下兼容。 Python
      Python 3 教程
    • Python pip包管理
      一、前言    在Python中, 安装第三方模块是通过 setuptools 这个工具完成的。 Python有两个封装了 setuptools的包管理工具: easy_install  和  pip , 目前官方推荐使用 pip。    
      Python pip包管理
    • ubuntu如何重新编译内核
    • 改善Java代码之慎用java动态编译

    目录