我的编程空间,编程开发者的网络收藏夹
学习永远不晚

Python深度学习pytorch实现图像分类数据集

短信预约 -IT技能 免费直播动态提醒
省份

北京

  • 北京
  • 上海
  • 天津
  • 重庆
  • 河北
  • 山东
  • 辽宁
  • 黑龙江
  • 吉林
  • 甘肃
  • 青海
  • 河南
  • 江苏
  • 湖北
  • 湖南
  • 江西
  • 浙江
  • 广东
  • 云南
  • 福建
  • 海南
  • 山西
  • 四川
  • 陕西
  • 贵州
  • 安徽
  • 广西
  • 内蒙
  • 西藏
  • 新疆
  • 宁夏
  • 兵团
手机号立即预约

请填写图片验证码后获取短信验证码

看不清楚,换张图片

免费获取短信验证码

Python深度学习pytorch实现图像分类数据集

目前广泛使用的图像分类数据集之一是MNIST数据集。如今,MNIST数据集更像是一个健全的检查,而不是一个基准。

为了提高难度,我们将在接下来的章节中讨论在2017年发布的性质相似但相对复杂的Fashion-MNIST数据集。


import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l
d2l.use_svg_display()

读取数据集

我们可以通过框架中的内置函数将Fashion-MNIST数据集下载并读取到内存中。


# 通过ToTensor实例将图像数据从PIL类型变换成32位浮点数格式
# 并除以255使得所有像素的数值均在0到1之间
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)
mnist_test = torchvisino.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)

Fashion-MNIST由10个类别的图像组成,每个类别由训练集中的6000张图像和测试集中的1000张图像组成。

测试数据集(test dataset)不会用于训练,只用于评估模型性能。训练集和测试集分别包含60000和10000张图像。


len(mnist_train), len(mnist_test)

(60000, 10000)

每个输入图像的高度和宽度均为28像素。数据集由灰度图像组成,其通道数为1。

为了简洁起见,本篇中,我们将高度h像素,宽度w像素图像的形状即为 h×w或 (h,w)。


mnist_train[0][0].shape

torch.size([1, 28, 28])

Fashion-MNIST中包含10个类别分别是

t-shirt(T恤)、trouser(裤⼦)、pullover(套衫)、dress(连⾐裙)、coat(外套)、

sandal(凉鞋)、shirt(衬衫)、sneaker(运动鞋)、bag(包)和ankle boot(短靴)

以下函数用于在数字标签索引及其文本名称之间进行转换。


def get_fashion_mnist_labels(labels):
	"""返回Fashion-MNIST数据集的本文标签。"""
	text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
	return [text_labels[int(i)] for i in labels]

我们现在可以创建一个函数来可视化这些样本。


def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):
	"""Plot a list of images."""
	figsize = (num_cols * scale, num_rows * scale)
	_, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
	axes = axes.flatten()
	for i, (ax, img) in enumerate(zip(axes, imgs)):
		if torch.is_tensor(img):
			# 图片张量
			ax.imshow(img.numpy())
		else:
			# PIL图片
			ax.imshow(img)
		ax.axes.get_xaxis().set_visible(False)
		ax.axes.get_yaxis().set_visible(False)
		if titles:
			ax.set_title(titles[i])
	return axes

以下是训练数据集中前几个样本的图像及其相应的标签(文本形式)。


X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y))

在这里插入图片描述

读取小批量

为了使我们在读取训练集和测试集时更容易,我们使用内置的数据迭代器,而不是从零开始创建一个。回顾一下,在每次迭代中,数据加载器每次都会读取一小批量数据,大小为batch_size。我们在训练数据迭代其中还随机打乱了所有样本


batch_size = 256
def get_dataloader_workers():
	"""使用4个进程来读取数据。"""
	return 4
train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True, num_workers=get_dataloader_workers())

整合所有组件

现在我们定义load_data_fashion_mnist函数,用于获取和读取Fashion-MNIST数据集。它返回训练集和验证集的数据迭代器。此外,它还接受一个可选参数,用来将图像大小调整为另一种形状。


def load_data_fashion_mnist(batch_size, resize=None):
	"""下载Fashion-MNIST数据集,然后将其加载到内存中。"""
	trans = [transforms.ToTensor()]	
	if resize:
		trans.insert(0, transforms.Resize(resize))	
	trans = transforms.Compose(trans)
	mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transforms=trans, download=True)
	mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transforms=trans, download=True)
	return(data.DataLoader(mnist_train, batch_size, shuffle=True, num_workers=get_dataloader_workers()),
		   data.DataLoader(mnist_test, batch_size, shuffle=False, num_workers=get_dataloader_workers()))

下面,我们通过指定resize参数来测试load_data_fashion_mnist函数的图像大小调整功能。


train_iter, test_iter = load_data_fashion_mnist(32, resize=64)
for X, y in train_iter:
	print(X.shape, X.dtype, y.shape, y.dtype)
	break

torch.Size([32, 1, 64, 64]) torch.float32 torch.Size([32]) torch.int64

以上就是Python深度学习pytorch实现图像分类数据集的详细内容,更多关于pytorch图像分类数据集的资料请关注编程网其它相关文章!

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

Python深度学习pytorch实现图像分类数据集

下载Word文档到电脑,方便收藏和打印~

下载Word文档

编程热搜

  • Python 学习之路 - Python
    一、安装Python34Windows在Python官网(https://www.python.org/downloads/)下载安装包并安装。Python的默认安装路径是:C:\Python34配置环境变量:【右键计算机】--》【属性】-
    Python 学习之路 - Python
  • chatgpt的中文全称是什么
    chatgpt的中文全称是生成型预训练变换模型。ChatGPT是什么ChatGPT是美国人工智能研究实验室OpenAI开发的一种全新聊天机器人模型,它能够通过学习和理解人类的语言来进行对话,还能根据聊天的上下文进行互动,并协助人类完成一系列
    chatgpt的中文全称是什么
  • C/C++中extern函数使用详解
  • C/C++可变参数的使用
    可变参数的使用方法远远不止以下几种,不过在C,C++中使用可变参数时要小心,在使用printf()等函数时传入的参数个数一定不能比前面的格式化字符串中的’%’符号个数少,否则会产生访问越界,运气不好的话还会导致程序崩溃
    C/C++可变参数的使用
  • css样式文件该放在哪里
  • php中数组下标必须是连续的吗
  • Python 3 教程
    Python 3 教程 Python 的 3.0 版本,常被称为 Python 3000,或简称 Py3k。相对于 Python 的早期版本,这是一个较大的升级。为了不带入过多的累赘,Python 3.0 在设计的时候没有考虑向下兼容。 Python
    Python 3 教程
  • Python pip包管理
    一、前言    在Python中, 安装第三方模块是通过 setuptools 这个工具完成的。 Python有两个封装了 setuptools的包管理工具: easy_install  和  pip , 目前官方推荐使用 pip。    
    Python pip包管理
  • ubuntu如何重新编译内核
  • 改善Java代码之慎用java动态编译

目录