我的编程空间,编程开发者的网络收藏夹
学习永远不晚

pytorch中的transforms.ToTensor和transforms.Normalize怎么实现

短信预约 -IT技能 免费直播动态提醒
省份

北京

  • 北京
  • 上海
  • 天津
  • 重庆
  • 河北
  • 山东
  • 辽宁
  • 黑龙江
  • 吉林
  • 甘肃
  • 青海
  • 河南
  • 江苏
  • 湖北
  • 湖南
  • 江西
  • 浙江
  • 广东
  • 云南
  • 福建
  • 海南
  • 山西
  • 四川
  • 陕西
  • 贵州
  • 安徽
  • 广西
  • 内蒙
  • 西藏
  • 新疆
  • 宁夏
  • 兵团
手机号立即预约

请填写图片验证码后获取短信验证码

看不清楚,换张图片

免费获取短信验证码

pytorch中的transforms.ToTensor和transforms.Normalize怎么实现

本文小编为大家详细介绍“pytorch中的transforms.ToTensor和transforms.Normalize怎么实现”,内容详细,步骤清晰,细节处理妥当,希望这篇“pytorch中的transforms.ToTensor和transforms.Normalize怎么实现”文章能帮助大家解决疑惑,下面跟着小编的思路慢慢深入,一起来学习新知识吧。

transforms.ToTensor

最近看pytorch时,遇到了对图像数据的归一化,如下图所示:

pytorch中的transforms.ToTensor和transforms.Normalize怎么实现

该怎么理解这串代码呢?我们一句一句的来看,先看transforms.ToTensor(),我们可以先转到官方给的定义,如下图所示:

pytorch中的transforms.ToTensor和transforms.Normalize怎么实现

大概的意思就是说,transforms.ToTensor()可以将PIL和numpy格式的数据从[0,255]范围转换到[0,1] ,具体做法其实就是将原始数据除以255。另外原始数据的shape是(H x W x C),通过transforms.ToTensor()后shape会变为(C x H x W)。这样说我觉得大家应该也是能理解的,这部分并不难,但想着还是用一些例子来加深大家的映像

先导入一些包

import cv2import numpy as npimport torchfrom torchvision import transforms

定义一个数组模型图片,注意数组数据类型需要时np.uint8【官方图示中给出】

data = np.array([                [[1,1,1],[1,1,1],[1,1,1],[1,1,1],[1,1,1]],                [[2,2,2],[2,2,2],[2,2,2],[2,2,2],[2,2,2]],                [[3,3,3],[3,3,3],[3,3,3],[3,3,3],[3,3,3]],                [[4,4,4],[4,4,4],[4,4,4],[4,4,4],[4,4,4]],                [[5,5,5],[5,5,5],[5,5,5],[5,5,5],[5,5,5]]        ],dtype='uint8')

这是可以看看data的shape,注意现在为(W H C)。

pytorch中的transforms.ToTensor和transforms.Normalize怎么实现

使用transforms.ToTensor()将data进行转换

data = transforms.ToTensor()(data)

这时候我们来看看data中的数据及shape。

pytorch中的transforms.ToTensor和transforms.Normalize怎么实现

很明显,数据现在都映射到了[0, 1]之间,并且data的shape发生了变换。

**注意:不知道大家是如何理解三维数组的,这里提供我的一个方法。

原始的data的shape为(5,5,3),则其表示有5个(5 , 3)的二维数组,即我们把最外层的[]去掉就得到了5个五行三列的数据。同样的,变换后data的shape为(3,5,5),则其表示有3个(5 , 5)的二维数组,即我们把最外层的[]去掉就得到了3个五行五列的数据。

transforms.Normalize????

相信通过前面的叙述大家应该对transforms.ToTensor有了一定的了解,下面将来说说这个transforms.Normalize????????????同样的,我们先给出官方的定义,如下图所示:

pytorch中的transforms.ToTensor和transforms.Normalize怎么实现

可以看到这个函数的输出output[channel] = (input[channel] - mean[channel]) / std[channel]。这里[channel]的意思是指对特征图的每个通道都进行这样的操作。【mean为均值,std为标准差】接下来我们看第一张图片中的代码,即

pytorch中的transforms.ToTensor和transforms.Normalize怎么实现

这里的第一个参数(0.5,0.5,0.5)表示每个通道的均值都是0.5,第二个参数(0.5,0.5,0.5)表示每个通道的方差都为0.5。【因为图像一般是三个通道,所以这里的向量都是1x3的】有了这两个参数后,当我们传入一个图像时,就会按照上面的公式对图像进行变换。【注意:这里说图像其实也不够准确,因为这个函数传入的格式不能为PIL Image,我们应该先将其转换为Tensor格式

说了这么多,那么这个函数到底有什么用呢?我们通过前面的ToTensor已经将数据归一化到了0-1之间,现在又接上了一个Normalize函数有什么用呢?其实Normalize函数做的是将数据变换到了[-1,1]之间。之前的数据为0-1,当取0时,output =(0 - 0.5)/ 0.5 = -1;当取1时,output =(1 - 0.5)/ 0.5 = 1。这样就把数据统一到了[-1,1]之间了那么问题又来了,数据统一到[-1,1]有什么好处呢?数据如果分布在(0,1)之间,可能实际的bias,就是神经网络的输入b会比较大,而模型初始化时b=0的,这样会导致神经网络收敛比较慢,经过Normalize后,可以加快模型的收敛速度。【这句话是再网络上找到最多的解释,自己也不确定其正确性】

读到这里大家是不是以为就完了呢?这里还想和大家唠上一唠上面的两个参数(0.5,0.5,0.5)是怎么得来的呢?这是根据数据集中的数据计算出的均值和标准差,所以往往不同的数据集这两个值是不同的?这里再举一个例子帮助大家理解其计算过程。同样采用上文例子中提到的数据。

上文已经得到了经ToTensor转换后的数据,现需要求出该数据每个通道的mean和std。

# 需要对数据进行扩维,增加batch维度data = torch.unsqueeze(data,0)    #在pytorch中一般都是(batch,C,H,W)nb_samples = 0.#创建3维的空列表channel_mean = torch.zeros(3)channel_std = torch.zeros(3)N, C, H, W = data.shape[:4]data = data.view(N, C, -1)  #将数据的H,W合并#展平后,w,h属于第2维度,对他们求平均,sum(0)为将同一纬度的数据累加channel_mean += data.mean(2).sum(0)  #展平后,w,h属于第2维度,对他们求标准差,sum(0)为将同一纬度的数据累加channel_std += data.std(2).sum(0)#获取所有batch的数据,这里为1nb_samples += N#获取同一batch的均值和标准差channel_mean /= nb_sampleschannel_std /= nb_samplesprint(channel_mean, channel_std)   #结果为tensor([0.0118, 0.0118, 0.0118]) tensor([0.0057, 0.0057, 0.0057])

将上述得到的mean和std带入公式,计算输出。

for i in range(3):    data[i] = (data[i] - channel_mean[i]) / channel_std[i]print(data)

输出结果:

pytorch中的transforms.ToTensor和transforms.Normalize怎么实现

从结果可以看出,我们计算的mean和std并不是0.5,且最后的结果也没有在[-1,1]之间。

读到这里,这篇“pytorch中的transforms.ToTensor和transforms.Normalize怎么实现”文章已经介绍完毕,想要掌握这篇文章的知识点还需要大家自己动手实践使用过才能领会,如果想了解更多相关内容的文章,欢迎关注编程网行业资讯频道。

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

pytorch中的transforms.ToTensor和transforms.Normalize怎么实现

下载Word文档到电脑,方便收藏和打印~

下载Word文档

猜你喜欢

pytorch中的transforms.ToTensor和transforms.Normalize怎么实现

本文小编为大家详细介绍“pytorch中的transforms.ToTensor和transforms.Normalize怎么实现”,内容详细,步骤清晰,细节处理妥当,希望这篇“pytorch中的transforms.ToTensor和tr
2023-06-30

pytorch中矩阵乘法和数组乘法怎么实现

本篇内容介绍了“pytorch中矩阵乘法和数组乘法怎么实现”的有关知识,在实际案例的操作过程中,不少人都会遇到这样的困境,接下来就让小编带领大家学习一下如何处理这些情况吧!希望大家仔细阅读,能够学有所成!一、torch.mul该乘法可简单理
2023-07-05

PyTorch中的卷积神经网络怎么实现

在PyTorch中,可以使用torch.nn模块中的Conv2d类来实现卷积神经网络。以下是一个简单的示例,展示如何在PyTorch中实现一个简单的卷积神经网络:import torchimport torch.nn as nnclas
PyTorch中的卷积神经网络怎么实现
2024-03-05

Pytorch中的model.train()和model.eval()怎么使用

本文小编为大家详细介绍“Pytorch中的model.train()和model.eval()怎么使用”,内容详细,步骤清晰,细节处理妥当,希望这篇“Pytorch中的model.train()和model.eval()怎么使用”文章能帮助
2023-07-06

PyTorch模型转TensorRT是怎么实现的?

转换步骤概览准备好模型定义文件(.py文件)准备好训练完成的权重文件(.pth或.pth.tar)安装onnx和onnxruntime将训练好的模型转换为.onnx格式安装tensorRT环境参数ubuntu-18.04 PyTorch-1
2022-06-02

PyTorch中的train()、eval()和no_grad()怎么使用

本篇内容介绍了“PyTorch中的train()、eval()和no_grad()怎么使用”的有关知识,在实际案例的操作过程中,不少人都会遇到这样的困境,接下来就让小编带领大家学习一下如何处理这些情况吧!希望大家仔细阅读,能够学有所成!什么
2023-07-05

PyTorch中怎么实现自定义损失函数

要实现自定义损失函数,可以按照以下步骤在PyTorch中实现:创建一个继承自torch.nn.Module的类,该类用于定义自定义损失函数的计算逻辑。import torchimport torch.nn as nnclass Custo
PyTorch中怎么实现自定义损失函数
2024-03-05

PyTorch中的神经网络Mnist分类任务怎么实现

这篇“PyTorch中的神经网络Mnist分类任务怎么实现”文章的知识点大部分人都不太理解,所以小编给大家总结了以下内容,内容详细,步骤清晰,具有一定的借鉴价值,希望大家阅读完这篇文章能有所收获,下面我们一起来看看这篇“PyTorch中的神
2023-07-05

Pytorch中实现CPU和GPU之间的切换的两种方法

本文主要介绍了Pytorch中实现CPU和GPU之间的切换的两种方法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
2023-01-28

PyTorch中dropout设置训练和测试模式的实现示例

这篇文章主要介绍PyTorch中dropout设置训练和测试模式的实现示例,文中介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们一定要看完!看代码吧~class Net(nn.Module):…model = Net()…model.t
2023-06-15

怎么使用PyTorch和LSTM实现单变量时间序列预测

这篇“怎么使用PyTorch和LSTM实现单变量时间序列预测”文章的知识点大部分人都不太理解,所以小编给大家总结了以下内容,内容详细,步骤清晰,具有一定的借鉴价值,希望大家阅读完这篇文章能有所收获,下面我们一起来看看这篇“怎么使用PyTor
2023-07-05

oracle中decimal和number怎么实现

在Oracle中,DECIMAL和NUMBER都可以用来表示浮点数,但是在内部实现上有一些区别。DECIMAL是一种精确的数据类型,它在存储数据时不会引入任何舍入误差。 DECIMAL类型通常用于需要高度精度的金融数据或其他需要精确计算的
oracle中decimal和number怎么实现
2024-04-09

在Golang中怎么实现求和

今天小编给大家分享一下在Golang中怎么实现求和的相关知识点,内容详细,逻辑清晰,相信大部分人都还太了解这方面的知识,所以分享这篇文章给大家参考一下,希望大家阅读完这篇文章后有所收获,下面我们一起来了解一下吧。首先,我们可以使用循环的方式
2023-07-05

spring中REST和RESTful怎么实现

今天小编给大家分享一下spring中REST和RESTful怎么实现的相关知识点,内容详细,逻辑清晰,相信大部分人都还太了解这方面的知识,所以分享这篇文章给大家参考一下,希望大家阅读完这篇文章后有所收获,下面我们一起来了解一下吧。简介RES
2023-06-29

Node中的进程和线程怎么实现

这篇文章主要介绍了Node中的进程和线程怎么实现的相关知识,内容详细易懂,操作简单快捷,具有一定借鉴价值,相信大家阅读完这篇Node中的进程和线程怎么实现文章都会有所收获,下面我们一起来看看吧。一、进程和线程1.1、专业性文字定义进程(Pr
2023-07-04

Golang中的接口怎么定义和实现

在Golang中,接口定义的方式非常简单,只需要使用关键字type和interface即可。接口定义了一组方法的集合,任何类型只要实现了接口中的所有方法,就被认为是实现了该接口。接口的定义方式如下:type InterfaceName
Golang中的接口怎么定义和实现
2024-03-13

编程热搜

  • Python 学习之路 - Python
    一、安装Python34Windows在Python官网(https://www.python.org/downloads/)下载安装包并安装。Python的默认安装路径是:C:\Python34配置环境变量:【右键计算机】--》【属性】-
    Python 学习之路 - Python
  • chatgpt的中文全称是什么
    chatgpt的中文全称是生成型预训练变换模型。ChatGPT是什么ChatGPT是美国人工智能研究实验室OpenAI开发的一种全新聊天机器人模型,它能够通过学习和理解人类的语言来进行对话,还能根据聊天的上下文进行互动,并协助人类完成一系列
    chatgpt的中文全称是什么
  • C/C++中extern函数使用详解
  • C/C++可变参数的使用
    可变参数的使用方法远远不止以下几种,不过在C,C++中使用可变参数时要小心,在使用printf()等函数时传入的参数个数一定不能比前面的格式化字符串中的’%’符号个数少,否则会产生访问越界,运气不好的话还会导致程序崩溃
    C/C++可变参数的使用
  • css样式文件该放在哪里
  • php中数组下标必须是连续的吗
  • Python 3 教程
    Python 3 教程 Python 的 3.0 版本,常被称为 Python 3000,或简称 Py3k。相对于 Python 的早期版本,这是一个较大的升级。为了不带入过多的累赘,Python 3.0 在设计的时候没有考虑向下兼容。 Python
    Python 3 教程
  • Python pip包管理
    一、前言    在Python中, 安装第三方模块是通过 setuptools 这个工具完成的。 Python有两个封装了 setuptools的包管理工具: easy_install  和  pip , 目前官方推荐使用 pip。    
    Python pip包管理
  • ubuntu如何重新编译内核
  • 改善Java代码之慎用java动态编译

目录