如何在Pytorch中使用Dataset和DataLoader读取数据
本篇文章给大家分享的是有关如何在Pytorch中使用Dataset和DataLoader读取数据,小编觉得挺实用的,因此分享给大家学习,希望大家阅读完这篇文章后可以有所收获,话不多说,跟着小编一起来看看吧。
一、前言
确保安装
scikit-image
numpy
二、Dataset
一个例子:
# 导入需要的包import torchimport torch.utils.data.dataset as Datasetimport numpy as np # 编造数据Data = np.asarray([[1, 2], [3, 4],[5, 6], [7, 8]])Label = np.asarray([[0], [1], [0], [2]])# 数据[1,2],对应的标签是[0],数据[3,4],对应的标签是[1] #创建子类class subDataset(Dataset.Dataset): #初始化,定义数据内容和标签 def __init__(self, Data, Label): self.Data = Data self.Label = Label #返回数据集大小 def __len__(self): return len(self.Data) #得到数据内容和标签 def __getitem__(self, index): data = torch.Tensor(self.Data[index]) label = torch.IntTensor(self.Label[index]) return data, label # 主函数if __name__ == '__main__': dataset = subDataset(Data, Label) print(dataset) print('dataset大小为:', dataset.__len__()) print(dataset.__getitem__(0)) print(dataset[0])
输出的结果
我们有了对Dataset的一个整体的把握,再来分析里面的细节:
#创建子类class subDataset(Dataset.Dataset):
创建子类时,继承的时Dataset.Dataset,不是一个Dataset。因为Dataset是module模块,不是class类,所以需要调用module里的class才行,因此是Dataset.Dataset!
len和getitem这两个函数,前者给出数据集的大小**,后者是用于查找数据和标签。是最重要的两个函数,我们后续如果要对数据做一些操作基本上都是再这两个函数的基础上进行。
三、DatasetLoader
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_works=0, clollate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None)
功能:构建可迭代的数据装载器;
dataset:Dataset类,决定数据从哪里读取及如何读取;数据集的路径
batchsize:批大小;
num_works:是否多进程读取数据;只对于CPU
shuffle:每个epoch是否打乱;
drop_last:当样本数不能被batchsize整除时,是否舍弃最后一批数据;
Epoch:所有训练样本都已输入到模型中,称为一个Epoch;
Iteration:一批样本输入到模型中,称之为一个Iteration;
Batchsize:批大小,决定一个Epoch中有多少个Iteration;
还是举一个实例:
import torchimport torch.utils.data.dataset as Datasetimport torch.utils.data.dataloader as DataLoaderimport numpy as np Data = np.asarray([[1, 2], [3, 4],[5, 6], [7, 8]])Label = np.asarray([[0], [1], [0], [2]])#创建子类class subDataset(Dataset.Dataset): #初始化,定义数据内容和标签 def __init__(self, Data, Label): self.Data = Data self.Label = Label #返回数据集大小 def __len__(self): return len(self.Data) #得到数据内容和标签 def __getitem__(self, index): data = torch.Tensor(self.Data[index]) label = torch.IntTensor(self.Label[index]) return data, label if __name__ == '__main__': dataset = subDataset(Data, Label) print(dataset) print('dataset大小为:', dataset.__len__()) print(dataset.__getitem__(0)) print(dataset[0]) #创建DataLoader迭代器,相当于我们要先定义好前面说的Dataset,然后再用Dataloader来对数据进行一些操作,比如是否需要打乱,则shuffle=True,是否需要多个进程读取数据num_workers=4,就是四个进程 dataloader = DataLoader.DataLoader(dataset,batch_size= 2, shuffle = False, num_workers= 4) for i, item in enumerate(dataloader): #可以用enumerate来提取出里面的数据 print('i:', i) data, label = item #数据是一个元组 print('data:', data) print('label:', label)
四、将Dataset数据和标签放在GPU上(代码执行顺序出错则会有bug)
这部分可以直接去看博客:Dataset和DataLoader
总结下来时有两种方法解决
如果在创建Dataset的类时,定义__getitem__方法的时候,将数据转变为GPU类型。则需要将Dataloader里面的参数num_workers设置为0,因为这个参数是对于CPU而言的。如果数据改成了GPU,则只能单进程。如果是在Dataloader的部分,先多个子进程读取,再转变为GPU,则num_wokers不用修改。就是上述__getitem__部分的代码,移到Dataloader部分。
不过一般来讲,数据集和标签不会像我们上述编辑的那么简单。一般再kaggle上的标签都是存在CSV这种文件中。需要pandas的配合。
这个进阶可以看:WRITING CUSTOM DATASETS, DATALOADERS AND TRANSFORMS,他是用人脸图片作为数据和人脸特征点作为标签。
pytorch的优点
1.PyTorch是相当简洁且高效快速的框架;2.设计追求最少的封装;3.设计符合人类思维,它让用户尽可能地专注于实现自己的想法;4.与google的Tensorflow类似,FAIR的支持足以确保PyTorch获得持续的开发更新;5.PyTorch作者亲自维护的论坛 供用户交流和求教问题6.入门简单
以上就是如何在Pytorch中使用Dataset和DataLoader读取数据,小编相信有部分知识点可能是我们日常工作会见到或用到的。希望你能通过这篇文章学到更多知识。更多详情敬请关注编程网行业资讯频道。
免责声明:
① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。
② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341