我的编程空间,编程开发者的网络收藏夹
学习永远不晚

pytorch可视化之hook钩子怎么使用

短信预约 -IT技能 免费直播动态提醒
省份

北京

  • 北京
  • 上海
  • 天津
  • 重庆
  • 河北
  • 山东
  • 辽宁
  • 黑龙江
  • 吉林
  • 甘肃
  • 青海
  • 河南
  • 江苏
  • 湖北
  • 湖南
  • 江西
  • 浙江
  • 广东
  • 云南
  • 福建
  • 海南
  • 山西
  • 四川
  • 陕西
  • 贵州
  • 安徽
  • 广西
  • 内蒙
  • 西藏
  • 新疆
  • 宁夏
  • 兵团
手机号立即预约

请填写图片验证码后获取短信验证码

看不清楚,换张图片

免费获取短信验证码

pytorch可视化之hook钩子怎么使用

这篇文章主要介绍了pytorch可视化之hook钩子怎么使用的相关知识,内容详细易懂,操作简单快捷,具有一定借鉴价值,相信大家阅读完这篇pytorch可视化之hook钩子怎么使用文章都会有所收获,下面我们一起来看看吧。

一、hook

在PyTorch中,提供了一个专用的接口使得网络在前向传播过程中能够获取到特征图,这个接口的名称非常形象,叫做hook。
可以想象这样的场景,数据通过网络向前传播,网络某一层我们预先设置了一个钩子,数据传播过后钩子上会留下数据在这一层的样子,读取钩子的信息就是这一层的特征图。
具体实现如下:

1.1 什么是hook,什么情况下使用?

首先,明确一下,为什么需要用hook,假设有这么一个函数

pytorch可视化之hook钩子怎么使用

需要通过梯度下降法求最小值,其实现方法如下:

import torchx = torch.tensor(3.0, requires_grad=True)y = (x-2)z = ((y-x) ** 2)z.backward()print("x.grad:",x.requires_grad,x.grad)print("y.grad:",y.requires_grad,y.grad)print("z.grad:",z.requires_grad,z.grad)

结果如下:

x.grad: True tensor(0.)
y.grad: True None
z.grad: True None

注意:在使用训练PyTorch训练模型时,只有叶节点(即直接指定数值的变量,而不是由其他变量计算得到的,比如网络输入)的梯度会保留,其余中间节点梯度在反向传播完成后就会自动释放以节省显存。 因此y.requires_grad的返回值为True,y.grad却为None。

可以看到上面的requires_grad方法都显示True,但是grad没有返回值。当然pytorch也提供某种方法保留非叶子节点的梯度信息。
使用 retain_grad() 方法可以保留非叶子节点的梯度,使用 retain_grad 保留的grad会占用显存,具体操作如下:

x = torch.tensor(3.0, requires_grad=True)y = (x-2)z = ((y-x) ** 2)y.retain_grad()z.retain_grad()z.backward()print("x.grad:",x.requires_grad,x.grad)print("y.grad:",y.requires_grad,y.grad)print("z.grad:",z.requires_grad,z.grad)

out:

x.grad: True tensor(0.)y.grad: True tensor(-4.)z.grad: True tensor(1.)

** 重申一次** 使用retain_grad方法会占用显存,如果不想要占用显存,就使用到了hook方法。

对于中间节点的变量a,可以使用a.register_hook(hook_fn)对其grad进行操作。 而hook_fn是一个自定义的函数,其声明为hook_fn(grad) -> Tensor or None

1.2 hook在变量中的使用

1 hook的打印功能

# 自定义hook方法,其传入参数为grad,打印出使用钩子的节点梯度def hook_fn(grad):    print(grad)x = torch.tensor(3.0, requires_grad=True)y = (x-2)z = ((y-x) ** 2)y.register_hook(hook_fn)z.register_hook(hook_fn)print("backward前")z.backward()print("backward后\n")print("x.grad:",x.requires_grad,x.grad)print("y.grad:",y.requires_grad,y.grad)print("z.grad:",z.requires_grad,z.grad)

out:

backward前tensor(1.)tensor(-4.)backward后x.grad: True tensor(0.)y.grad: True Nonez.grad: True None

可以看到绑定hook后,backward打印的时候打印了y和z的梯度,调用grad的时候没有保留grad值,已经释放掉内存。注意,打印出来的结果是反向传播,所以先打印z的梯度,再打印y的梯度。

2 使用hook改变grad的功能

对标记的节点,梯度加2

def hook_fn(grad):    grad += 2    print(grad)    return gradx = torch.tensor(3.0, requires_grad=True)y = (x-2)z = ((y-x) ** 2)y.register_hook(hook_fn)z.register_hook(hook_fn)print("backward前")z.backward()print("backward后\n")print("x.grad:",x.requires_grad,x.grad)print("y.grad:",x.requires_grad,y.grad)print("z.grad:",x.requires_grad,z.grad)

out:

backward前tensor(3.)tensor(-10.)backward后x.grad: True tensor(2.)y.grad: True Nonez.grad: True None

可以看到梯度教上面的已经发生的改变。

1.3 hook在模型中的使用:

PyTorch中使用register_forward_hook和register_backward_hook获取Module输入和输出的feature_map和grad。使用结构如下: hook_fn(module, input, output) -> Tensor or None
模型中使用hook一点要带有这三个参数module, grad_input, grad_output

1 register_forward_hook的使用

import torch.nn as nndef hook_forward_fn(model,put,out):    print("model:",model)    print("input:",put)    print("output:",out)    # 定义一个modelclass Net(nn.Module):    def __init__(self):        super(Net,self).__init__()        self.conv = nn.Conv2d(3, 1, 1)        self.bn = nn.BatchNorm2d(1)        #self.conv.register_forward_hook(hook_forward_fn)        #self.bn.register_forward_hook(hook_forward_fn)    def forward(self, x):        x = self.conv(x)        x = self.bn(x)        return torch.relu(x)    net = Net()# 对模型中的具体某一层使用hooknet.conv.register_forward_hook(hook_forward_fn)net.bn.register_forward_hook(hook_forward_fn)x = torch.rand(1, 3, 2, 2, requires_grad=True)y = net(x).mean()

注意:该方法不需要使用。backword就能输出结果,是记录前向传播的钩子。
结果如下:

model: Conv2d(3, 1, kernel_size=(1, 1), stride=(1, 1))input: (tensor([[[[0.4570, 0.6791],          [0.0197, 0.5040]],         [[0.8883, 0.1808],          [0.6289, 0.9386]],         [[0.8772, 0.5290],          [0.0014, 0.3728]]]], requires_grad=True),)output: tensor([[[[-0.4909, -0.1122],          [-0.6301, -0.5649]]]], grad_fn=<ConvolutionBackward0>)model: BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)input: (tensor([[[[-0.4909, -0.1122],          [-0.6301, -0.5649]]]], grad_fn=<ConvolutionBackward0>),)output: tensor([[[[-0.2060,  1.6790],          [-0.8987, -0.5743]]]], grad_fn=<NativeBatchNormBackward0>)

2 register_backward_hook的使用

使用上面相同的Net模型

def hook_backward_fn(module, grad_input, grad_output):    print(f"module: {module}")    print(f"grad_output: {grad_output}")    print(f"grad_input: {grad_input}")    print("*"*20)    net = Net()net.conv.register_backward_hook(hook_backward_fn)net.bn.register_backward_hook(hook_backward_fn)x = x = torch.rand(1, 3, 2, 2, requires_grad=True)y = net(x).mean()y.backward()

out:

module: BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)grad_output: (tensor([[[[0.2500, 0.2500],          [0.0000, 0.0000]]]]),)grad_input: (tensor([[[[ 0.6586, -0.3360],          [-0.3009, -0.0218]]]]), tensor([0.4575]), tensor([0.5000]))********************module: Conv2d(3, 1, kernel_size=(1, 1), stride=(1, 1))grad_output: (tensor([[[[ 0.6586, -0.3360],          [-0.3009, -0.0218]]]]),)grad_input: (tensor([[[[-0.2974,  0.1517],          [ 0.1359,  0.0098]],         [[ 0.0270, -0.0138],          [-0.0123, -0.0009]],         [[ 0.2918, -0.1489],          [-0.1333, -0.0096]]]]), tensor([[[[0.4331]],         [[0.1386]],         [[0.4292]]]]), tensor([-1.4156e-07]))********************

其结果是逆向输出各节点层的梯度信息。

3 hook中使用展示卷积层

随便画一张图,图片张这个样子:

pytorch可视化之hook钩子怎么使用

使用读取图片发现是个4通道的图像,我们转成单通道并可视化:

import matplotlib.pyplot as pltimport matplotlib.image as mpingimg=mping.imread("./test1.png")print(img.shape)img = torch.tensor(img[:,:,0]).view(1,1,228,226)plt.imshow(img[0][0])

pytorch可视化之hook钩子怎么使用

接下来创建一个只有卷积层的模型

class Net(nn.Module):    def __init__(self):        super(Net,self).__init__()        self.conv = nn.Sequential(nn.Conv2d(1,1,7),                                  nn.ReLU()                                 )    def forward(self, x):        x=self.conv(x)        return x

使用我们的钩子hook对卷积层的输出进行可视化

def hook_forward_fn(model,put,out):    print("inputshape:",put[0].shape) # 打印出输入图片的维度    print("outputshape:",out[0][0].shape) # 经过卷积之后的维度    # 可视化,因为卷积之后带有grad梯度信息,所以需要使用detach().numpy()方法,否则会报错    plt.imshow(out[0][0].detach().numpy())

具体完整实现以及可视化代码如下:

import matplotlib.pyplot as pltimport matplotlib.image as mpingimport numpy as npimg=mping.imread("./test1.png")img = torch.tensor(img[:,:,0]).view(1,1,228,226)def hook_forward_fn(model,put,out):    print("inputshape:",put[0].shape)    print("outputshape:",out[0][0].shape)    plt.imshow(out[0][0].detach().numpy())      class Net(nn.Module):    def __init__(self):        super(Net,self).__init__()        self.conv = nn.Sequential(nn.Conv2d(1,1,7),                                  nn.ReLU()                                 )    def forward(self, x):        x=self.conv(x)        return x    model = Net()model.conv.register_forward_hook(hook_forward_fn)y=model(img)

pytorch可视化之hook钩子怎么使用

关于“pytorch可视化之hook钩子怎么使用”这篇文章的内容就介绍到这里,感谢各位的阅读!相信大家对“pytorch可视化之hook钩子怎么使用”知识都有一定的了解,大家如果还想学习更多知识,欢迎关注编程网行业资讯频道。

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

pytorch可视化之hook钩子怎么使用

下载Word文档到电脑,方便收藏和打印~

下载Word文档

猜你喜欢

pytorch可视化之hook钩子怎么使用

这篇文章主要介绍了pytorch可视化之hook钩子怎么使用的相关知识,内容详细易懂,操作简单快捷,具有一定借鉴价值,相信大家阅读完这篇pytorch可视化之hook钩子怎么使用文章都会有所收获,下面我们一起来看看吧。一、hook在PyTo
2023-07-05

pytorch中可视化之hook钩子

本文主要介绍了pytorch中可视化之hook钩子,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
2023-03-23

Pytorch可视化之Visdom怎么用

这篇文章主要为大家展示了“Pytorch可视化之Visdom怎么用”,内容简而易懂,条理清晰,希望能够帮助大家解决疑惑,下面让小编带领大家一起研究并学习一下“Pytorch可视化之Visdom怎么用”这篇文章吧。一、Visdom简介Visd
2023-06-20

PyTorch可视化工具TensorBoard和Visdom怎么用

今天小编给大家分享一下PyTorch可视化工具TensorBoard和Visdom怎么用的相关知识点,内容详细,逻辑清晰,相信大部分人都还太了解这方面的知识,所以分享这篇文章给大家参考一下,希望大家阅读完这篇文章后有所收获,下面我们一起来了
2023-06-26

Python数据可视化之Seaborn怎么使用

这篇文章主要介绍了Python数据可视化之Seaborn怎么使用的相关知识,内容详细易懂,操作简单快捷,具有一定借鉴价值,相信大家阅读完这篇Python数据可视化之Seaborn怎么使用文章都会有所收获,下面我们一起来看看吧。1. 安装 s
2023-06-30

Python可视化tkinter怎么使用

这篇文章主要讲解了“Python可视化tkinter怎么使用”,文中的讲解内容简单清晰,易于学习与理解,下面请大家跟着小编的思路慢慢深入,一起来研究和学习“Python可视化tkinter怎么使用”吧!1、基本用法# coding:utf-
2023-07-02

Vue怎么使用echarts可视化图表

这篇文章主要介绍“Vue怎么使用echarts可视化图表”,在日常操作中,相信很多人在Vue怎么使用echarts可视化图表问题上存在疑惑,小编查阅了各式资料,整理出简单好用的操作方法,希望对大家解答”Vue怎么使用echarts可视化图表
2023-07-04

winform数据可视化控件怎么使用

WinForms 数据可视化控件是用于在 Windows 窗体应用程序中显示和分析数据的工具。以下是使用 WinForms 数据可视化控件的一般步骤:1. 打开 Visual Studio 创建一个新的 Windows 窗体应用程序项目。2
2023-09-16

怎么使用python的可视化工具Pandas_Alive

这篇文章主要介绍“怎么使用python的可视化工具Pandas_Alive”,在日常操作中,相信很多人在怎么使用python的可视化工具Pandas_Alive问题上存在疑惑,小编查阅了各式资料,整理出简单好用的操作方法,希望对大家解答”怎
2023-06-25

MongoDB可视化工具mongodb compass怎么使用

这篇文章主要介绍了MongoDB可视化工具mongodb compass怎么使用的相关知识,内容详细易懂,操作简单快捷,具有一定借鉴价值,相信大家阅读完这篇MongoDB可视化工具mongodb compass怎么使用文章都会有所收获,下面
2023-07-02

使用SpringBoot怎么实现可视化监控

本篇文章给大家分享的是有关使用SpringBoot怎么实现可视化监控,小编觉得挺实用的,因此分享给大家学习,希望大家阅读完这篇文章后可以有所收获,话不多说,跟着小编一起来看看吧。具体如下:1、Spring Boot 应用暴露监控指标【版本
2023-06-15

git可视化提交工具Sourcetree怎么使用

这篇文章主要讲解了“git可视化提交工具Sourcetree怎么使用”,文中的讲解内容简单清晰,易于学习与理解,下面请大家跟着小编的思路慢慢深入,一起来研究和学习“git可视化提交工具Sourcetree怎么使用”吧!Sourcetree基
2023-06-29

Redis可视化工具怎么安装及使用

要安装和使用Redis可视化工具,可以按照以下步骤进行操作:1. 选择合适的Redis可视化工具:有很多可选的Redis可视化工具,比如Redis Desktop Manager、RedisInsight等。可以根据自己的需求和喜好选择一个
2023-08-15

vue可视化表单设计器怎么使用

Vue可视化表单设计器是一个基于Vue.js的表单设计器,用于快速生成表单,可以大大提高开发效率。使用步骤如下:1. 安装Vue可视化表单设计器在命令行中输入以下命令安装:```npm install vue-form-making --s
2023-06-12

Python中怎么使用使用Plotly实现数据可视化

这期内容当中小编将会给大家带来有关Python中怎么使用使用Plotly实现数据可视化,文章内容丰富且以专业的角度为大家分析和叙述,阅读完这篇文章希望大家可以有所收获。 Plotly 是一个数据绘图库,具有整洁的接口,它旨在允许你构建自己的
2023-06-16

Python数据可视化之怎么用Matplotlib绘制常用图形

这篇文章主要介绍Python数据可视化之怎么用Matplotlib绘制常用图形,文中介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们一定要看完!一、散点图散点图用两组数据构成多个坐标点,考察坐标点的分布,判断两变量之间是否存在某种关联或
2023-06-15

R语言怎么使用gganimate创建可视化动图

这篇“R语言怎么使用gganimate创建可视化动图”文章的知识点大部分人都不太理解,所以小编给大家总结了以下内容,内容详细,步骤清晰,具有一定的借鉴价值,希望大家阅读完这篇文章能有所收获,下面我们一起来看看这篇“R语言怎么使用gganim
2023-06-30

Python怎么使用树状图实现可视化聚类

今天小编给大家分享一下Python怎么使用树状图实现可视化聚类的相关知识点,内容详细,逻辑清晰,相信大部分人都还太了解这方面的知识,所以分享这篇文章给大家参考一下,希望大家阅读完这篇文章后有所收获,下面我们一起来了解一下吧。树状图树状图是显
2023-07-05

编程热搜

  • Python 学习之路 - Python
    一、安装Python34Windows在Python官网(https://www.python.org/downloads/)下载安装包并安装。Python的默认安装路径是:C:\Python34配置环境变量:【右键计算机】--》【属性】-
    Python 学习之路 - Python
  • chatgpt的中文全称是什么
    chatgpt的中文全称是生成型预训练变换模型。ChatGPT是什么ChatGPT是美国人工智能研究实验室OpenAI开发的一种全新聊天机器人模型,它能够通过学习和理解人类的语言来进行对话,还能根据聊天的上下文进行互动,并协助人类完成一系列
    chatgpt的中文全称是什么
  • C/C++中extern函数使用详解
  • C/C++可变参数的使用
    可变参数的使用方法远远不止以下几种,不过在C,C++中使用可变参数时要小心,在使用printf()等函数时传入的参数个数一定不能比前面的格式化字符串中的’%’符号个数少,否则会产生访问越界,运气不好的话还会导致程序崩溃
    C/C++可变参数的使用
  • css样式文件该放在哪里
  • php中数组下标必须是连续的吗
  • Python 3 教程
    Python 3 教程 Python 的 3.0 版本,常被称为 Python 3000,或简称 Py3k。相对于 Python 的早期版本,这是一个较大的升级。为了不带入过多的累赘,Python 3.0 在设计的时候没有考虑向下兼容。 Python
    Python 3 教程
  • Python pip包管理
    一、前言    在Python中, 安装第三方模块是通过 setuptools 这个工具完成的。 Python有两个封装了 setuptools的包管理工具: easy_install  和  pip , 目前官方推荐使用 pip。    
    Python pip包管理
  • ubuntu如何重新编译内核
  • 改善Java代码之慎用java动态编译

目录