我的编程空间,编程开发者的网络收藏夹
学习永远不晚

pytorch怎么加载自己的图片数据集

短信预约 -IT技能 免费直播动态提醒
省份

北京

  • 北京
  • 上海
  • 天津
  • 重庆
  • 河北
  • 山东
  • 辽宁
  • 黑龙江
  • 吉林
  • 甘肃
  • 青海
  • 河南
  • 江苏
  • 湖北
  • 湖南
  • 江西
  • 浙江
  • 广东
  • 云南
  • 福建
  • 海南
  • 山西
  • 四川
  • 陕西
  • 贵州
  • 安徽
  • 广西
  • 内蒙
  • 西藏
  • 新疆
  • 宁夏
  • 兵团
手机号立即预约

请填写图片验证码后获取短信验证码

看不清楚,换张图片

免费获取短信验证码

pytorch怎么加载自己的图片数据集

本文小编为大家详细介绍“pytorch怎么加载自己的图片数据集”,内容详细,步骤清晰,细节处理妥当,希望这篇“pytorch怎么加载自己的图片数据集”文章能帮助大家解决疑惑,下面跟着小编的思路慢慢深入,一起来学习新知识吧。

ImageFolder 适合于分类数据集,并且每一个类别的图片在同一个文件夹, ImageFolder加载的数据集, 训练数据为文件件下的图片, 训练标签是对应的文件夹, 每个文件夹为一个类别

导入ImageFolder()包from torchvision.datasets import ImageFolder

pytorch怎么加载自己的图片数据集

在Flower_Orig_dataset文件夹下有flower_orig 和 sunflower这两个文件夹, 这两个文件夹下放着同一个类别的图片。 使用 ImageFolder 加载的图片, 就会返回图片信息和对应的label信息, 但是label信息是根据文件夹给出的, 如flower_orig就是标签0, sunflower就是标签1。

ImageFolder 加载数据集

导入包和设置transform

import torchfrom torchvision import transforms, datasetsimport torch.nn as nnfrom torch.utils.data import DataLoader transforms = transforms.Compose([    transforms.Resize(256),    # 将图片短边缩放至256,长宽比保持不变:    transforms.CenterCrop(224),   #将图片从中心切剪成3*224*224大小的图片    transforms.ToTensor()          #把图片进行归一化,并把数据转换成Tensor类型])

加载数据集: 将分类图片的父目录作为路径传递给ImageFolder(), 并传入transform。这样就有了要加载的数据集, 之后就可以使用DataLoader加载数据, 并构建网络训练。

path = r'D:\数据集\Flower_Orig_dataset'data_train = datasets.ImageFolder(path, transform=transforms)data_loader = DataLoader(data_train, batch_size=64, shuffle=True)for i, data in enumerate(data_loader):    images, labels = data    print(images.shape)    print(labels.shape)    break

使用pytorch提供的Dataset类创建自己的数据集。

具体步骤:

  首先要有一个txt文件, 这个文件格式是: 图片路径  标签.  这样的格式, 所以使用os库, 遍历自己的图片名, 并把标签和图片路径写入txt文件。

有了这个txt文件, 我们就可以在类里面构造我们的数据集.

1    把图片路径和图片标签分割开, 有两个列表, 一个列表是图片路径名, 一个列表是标签号, 有一点就是第 i 个图片列表和 第 i 个标签是对应的

重写__len__方法  和  __getitem__方法

1 getitem方法中, 获得对应的图片路径,并用PIL库读取文件把图片transfrom后, 在getitem函数中返回读取的图片和标签即可

就可以构建数据集实例和加载数据集.

 定义一个用来生成[ 图片路径 标签] 这样的txt文件函数

def make_txt(root, file_name, label):    path = os.path.join(root, file_name)    data = os.listdir(path)    f = open(path+'\\'+'f.txt', 'w')    for line in data:        f.write(line+' '+str(label)+'\n')    f.close()#调用函数生成两个文件夹下的txt文件make_txt(path, file_name='flower_orig', label=0)make_txt(path, file_name='sunflower', label=1)

将连个txt文件合并成一个txt文件,表示数据集所有的图片和标签

def link_txt(file1, file2):    txt_list = []    path = r'D:\数据集\Flower_Orig_dataset\data.txt'     f = open(path, 'a')     f1 = open(file1, 'r')    data1 = f1.readlines()    for line in data1:        txt_list.append(line)     f2 = open(file2, 'r')    data2 = f2.readlines()    for line in data2:        txt_list.append(line)     for line in txt_list:        f.write(line)     f.close()    f1.close()    f2.close() #调用函数, 将两个文件夹下的txt文件合并file1 = r'D:\数据集\Flower_Orig_dataset\flower_orig\f.txt'file2 = r'D:\数据集\Flower_Orig_dataset\sunflower\f.txt'link_txt(file1=file1, file2=file2)

现在我们已经有了我们制作数据集所需要的txt文件, 接下来要做的即使继承Dataset类, 来构建自己的数据集 , 别忘了前面说的 构建数据集步骤, 在__getitem__函数中, 需要拿到图片路径和标签, 并且用PIL库方法读取图片,对图片进行transform转换后,返回图片信息和标签信息

Dataset加载数据集

我们读取图片的根目录, 在根目录下有所有图片的txt文件, 拿到txt文件后, 先读取txt文件, 之后遍历txt文件中的每一行, 首先去除掉尾部的换行符, 在以空格切分,前半部分是图片名称, 后半部分是图片标签, 当图片名称和根目录结合,就得到了我们的图片路径   class MyDataset(Dataset):    def __init__(self, img_path, transform=None):        super(MyDataset, self).__init__()        self.root = img_path         self.txt_root = self.root + 'data.txt'        f = open(self.txt_root, 'r')        data = f.readlines()         imgs = []        labels = []        for line in data:            line = line.rstrip()            word = line.split()            imgs.append(os.path.join(self.root, word[1], word[0]))             labels.append(word[1])        self.img = imgs        self.label = labels        self.transform = transform     def __len__(self):        return len(self.label)     def __getitem__(self, item):        img = self.img[item]        label = self.label[item]         img = Image.open(img).convert('RGB')         #此时img是PIL.Image类型   label是str类型         if transforms is not None:            img = self.transform(img)         label = np.array(label).astype(np.int64)        label = torch.from_numpy(label)                return img, label

 加载我们的数据集:

path = r'D:\数据集\Flower_Orig_dataset'dataset = MyDataset(path, transform=transform) data_loader = DataLoader(dataset=dataset, batch_size=64, shuffle=True)

接下来我们就可以构建我们的网络架构:

class Net(nn.Module):    def __init__(self):        super(Net, self).__init__()        self.conv1 = nn.Conv2d(3,16,3)        self.maxpool = nn.MaxPool2d(2,2)        self.conv2 = nn.Conv2d(16,5,3)         self.relu = nn.ReLU()        self.fc1 = nn.Linear(55*55*5, 1200)        self.fc2 = nn.Linear(1200,64)        self.fc3 = nn.Linear(64,2)     def forward(self,x):        x = self.maxpool(self.relu(self.conv1(x)))    #113        x = self.maxpool(self.relu(self.conv2(x)))    #55        x = x.view(-1, self.num_flat_features(x))        x = self.relu(self.fc1(x))        x = self.relu(self.fc2(x))        x = self.fc3(x)        return x            def num_flat_features(self, x):        size = x.size()[1:]        num_features = 1        for s in size:            num_features *= s         return num_features

 训练我们的网络:

model = Net() criterion = torch.nn.CrossEntropyLoss()optimizer = optim.SGD(model.parameters(), lr=0.01)  epochs = 10for epoch in range(epochs):    running_loss = 0.0    for i, data in enumerate(data_loader):        images, label = data         out = model(images)         loss = criterion(out, label)         optimizer.zero_grad()        loss.backward()        optimizer.step()         running_loss += loss.item()        if(i+1)%10 == 0:            print('[%d  %5d]   loss: %.3f'%(epoch+1, i+1, running_loss/100))            running_loss = 0.0 print('finished train')

 保存网络模型(这里不止是保存参数,还保存了网络结构)

#保存模型torch.save(net, 'model_name.pth')   #保存的是模型, 不止是w和b权重值 # 读取模型model = torch.load('model_name.pth')

读到这里,这篇“pytorch怎么加载自己的图片数据集”文章已经介绍完毕,想要掌握这篇文章的知识点还需要大家自己动手实践使用过才能领会,如果想了解更多相关内容的文章,欢迎关注编程网行业资讯频道。

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

pytorch怎么加载自己的图片数据集

下载Word文档到电脑,方便收藏和打印~

下载Word文档

猜你喜欢

pytorch怎么加载自己的图片数据集

本文小编为大家详细介绍“pytorch怎么加载自己的图片数据集”,内容详细,步骤清晰,细节处理妥当,希望这篇“pytorch怎么加载自己的图片数据集”文章能帮助大家解决疑惑,下面跟着小编的思路慢慢深入,一起来学习新知识吧。1.ImageFo
2023-07-02

pytorch中怎么加载自己的数据集

在PyTorch中,可以通过创建一个自定义的数据集类来加载自己的数据集。首先,需要导入以下必要的库和模块:```pythonimport torchfrom torch.utils.data import Dataset, DataLoad
2023-10-09

怎么使用pytorch准备自己的图片数据

本篇内容主要讲解“怎么使用pytorch准备自己的图片数据”,感兴趣的朋友不妨来看看。本文介绍的方法操作简单快捷,实用性强。下面就让小编来带大家学习“怎么使用pytorch准备自己的图片数据”吧!正文图片数据一般有两种情况:1、所有图片放在
2023-07-02

pytorch怎么制作自己的数据集

要制作自己的数据集,可以按照以下步骤操作:1. 准备数据:将数据整理成所需的格式。根据你的任务和数据类型,可能需要将数据转换为图像、文本、CSV等格式。2. 创建一个自定义数据集类:在PyTorch中,可以通过创建一个继承自torch.ut
2023-10-09

pytorch中怎么创建自己的数据集

在PyTorch中,可以通过继承torch.utils.data.Dataset类来创建自己的数据集。以下是一个简单的示例代码:import torchfrom torch.utils.data import Datasetclass
pytorch中怎么创建自己的数据集
2024-04-08

怎么用GAN训练自己数据生成新的图片

本文小编为大家详细介绍“怎么用GAN训练自己数据生成新的图片”,内容详细,步骤清晰,细节处理妥当,希望这篇“怎么用GAN训练自己数据生成新的图片”文章能帮助大家解决疑惑,下面跟着小编的思路慢慢深入,一起来学习新知识吧。一、读取数据问题# M
2023-07-05

pytorch通过自己的数据集训练Unet网络架构

Unet是一个最近比较火的网络结构。它的理论已经有很多大佬在讨论了。本文主要从实际操作的层面,讲解如何使用pytorch实现unet图像分割
2022-12-08

图片自动按需加载怎么回事

图片按需加载图片按需加载仅在需要时才加载图像,以优化加载速度、减少带宽,并提升用户体验。通过使用占位符或低分辨率图像,它能延迟加载图像,直到用户滚动到视图中时再加载。此技术可提高页面性能、降低网络流量并改善用户体验。
图片自动按需加载怎么回事
2024-04-26

PyTorch中怎么使用DataLoader加载数据

在PyTorch中使用DataLoader加载数据主要有以下几个步骤:创建数据集对象:首先,需要创建一个数据集对象,该数据集对象必须继承自torch.utils.data.Dataset类,并实现__len__和__getitem__方法。
PyTorch中怎么使用DataLoader加载数据
2024-03-05

tensorflow怎么加载本地数据集

要加载本地数据集到TensorFlow中,可以使用tf.data.Dataset.from_tensor_slices()函数。首先,将本地数据集加载到numpy数组中,然后使用from_tensor_slices()函数将numpy数组转
tensorflow怎么加载本地数据集
2024-03-15

电脑图标怎么更改自己喜欢的图片

本文小编为大家详细介绍“电脑图标怎么更改自己喜欢的图片”,内容详细,步骤清晰,细节处理妥当,希望这篇“电脑图标怎么更改自己喜欢的图片”文章能帮助大家解决疑惑,下面跟着小编的思路慢慢深入,一起来学习新知识吧。电脑图标更改自己喜欢的图片:1、首
2023-07-02

使用pytorch怎么将图片数据转换成tensor

这期内容当中小编将会给大家带来有关使用pytorch怎么将图片数据转换成tensor,文章内容丰富且以专业的角度为大家分析和叙述,阅读完这篇文章希望大家可以有所收获。一、数据转换把图片转成成torch的tensor数据,一般采用函数:tor
2023-06-06

js怎么实现图片的懒加载

这篇文章给大家分享的是有关js怎么实现图片的懒加载的内容。小编觉得挺实用的,因此分享给大家做个参考,一起跟随小编过来看看吧。图片的懒加载是前端优化必须要掌握的东西,图片实现懒加载可以节省带宽又可以减轻我们网页的负荷。接下来我来记录一下我所掌
2023-06-14

Keras中如何加载自定义的数据集

在Keras中加载自定义的数据集通常需要以下步骤:准备数据集:首先,将自定义的数据集准备好,包括数据文件、标签文件等。创建数据生成器:在Keras中通常使用ImageDataGenerator类来创建数据生成器,用于在训练模型时从数据集中生
Keras中如何加载自定义的数据集
2024-03-12

怎么在iOS中高效的加载图片

这篇文章主要介绍怎么在iOS中高效的加载图片,文中介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们一定要看完!图片的渲染流程在iOS中使用 UIImage和UIImageView来记载图片,他俩遵守经典的MVC架构,UIImage相当于
2023-06-25

编程热搜

  • Python 学习之路 - Python
    一、安装Python34Windows在Python官网(https://www.python.org/downloads/)下载安装包并安装。Python的默认安装路径是:C:\Python34配置环境变量:【右键计算机】--》【属性】-
    Python 学习之路 - Python
  • chatgpt的中文全称是什么
    chatgpt的中文全称是生成型预训练变换模型。ChatGPT是什么ChatGPT是美国人工智能研究实验室OpenAI开发的一种全新聊天机器人模型,它能够通过学习和理解人类的语言来进行对话,还能根据聊天的上下文进行互动,并协助人类完成一系列
    chatgpt的中文全称是什么
  • C/C++中extern函数使用详解
  • C/C++可变参数的使用
    可变参数的使用方法远远不止以下几种,不过在C,C++中使用可变参数时要小心,在使用printf()等函数时传入的参数个数一定不能比前面的格式化字符串中的’%’符号个数少,否则会产生访问越界,运气不好的话还会导致程序崩溃
    C/C++可变参数的使用
  • css样式文件该放在哪里
  • php中数组下标必须是连续的吗
  • Python 3 教程
    Python 3 教程 Python 的 3.0 版本,常被称为 Python 3000,或简称 Py3k。相对于 Python 的早期版本,这是一个较大的升级。为了不带入过多的累赘,Python 3.0 在设计的时候没有考虑向下兼容。 Python
    Python 3 教程
  • Python pip包管理
    一、前言    在Python中, 安装第三方模块是通过 setuptools 这个工具完成的。 Python有两个封装了 setuptools的包管理工具: easy_install  和  pip , 目前官方推荐使用 pip。    
    Python pip包管理
  • ubuntu如何重新编译内核
  • 改善Java代码之慎用java动态编译

目录