我的编程空间,编程开发者的网络收藏夹
学习永远不晚

pytorch autograd bac

短信预约 -IT技能 免费直播动态提醒
省份

北京

  • 北京
  • 上海
  • 天津
  • 重庆
  • 河北
  • 山东
  • 辽宁
  • 黑龙江
  • 吉林
  • 甘肃
  • 青海
  • 河南
  • 江苏
  • 湖北
  • 湖南
  • 江西
  • 浙江
  • 广东
  • 云南
  • 福建
  • 海南
  • 山西
  • 四川
  • 陕西
  • 贵州
  • 安徽
  • 广西
  • 内蒙
  • 西藏
  • 新疆
  • 宁夏
  • 兵团
手机号立即预约

请填写图片验证码后获取短信验证码

看不清楚,换张图片

免费获取短信验证码

pytorch autograd bac

retain_graph参数的作用

官方定义:

retain_graph (bool, optional) – If False, the graph used to compute the grad will be freed. Note that in nearly all cases setting this option to True is not needed and often can be worked around in a much more efficient way. Defaults to the value of create_graph.

大意是如果设置为False,计算图中的中间变量在计算完后就会被释放。但是在平时的使用中这个参数默认都为False从而提高效率,和creat_graph的值一样。

具体看一个例子理解:

假设一个我们有一个输入x,y = x **2, z = y*4,然后我们有两个输出,一个output_1 = z.mean(),另一个output_2 = z.sum()。然后我们对两个output执行backward。

 1 import torch
 2 x = torch.randn((1,4),dtype=torch.float32,requires_grad=True)
 3 y = x ** 2
 4 z = y * 4
 5 print(x)
 6 print(y)
 7 print(z)
 8 loss1 = z.mean()
 9 loss2 = z.sum()
10 print(loss1,loss2)
11 loss1.backward()    # 这个代码执行正常,但是执行完中间变量都free了,所以下一个出现了问题
12 print(loss1,loss2)
13 loss2.backward()    # 这时会引发错误

程序正常执行到第12行,所有的变量正常保存。但是在第13行报错:

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

分析:计算节点数值保存了,但是计算图x-y-z结构被释放了,而计算loss2的backward仍然试图利用x-y-z的结构,因此会报错。

因此需要retain_graph参数为True去保留中间参数从而两个loss的backward()不会相互影响。正确的代码应当把第11行以及之后改成

1 # 假如你需要执行两次backward,先执行第一个的backward,再执行第二个backward
2 loss1.backward(retain_graph=True)# 这里参数表明保留backward后的中间参数。
3 loss2.backward() # 执行完这个后,所有中间变量都会被释放,以便下一次的循环
4  #如果是在训练网络optimizer.step() # 更新参数

create_graph参数比较简单,参考官方定义:
  • create_graph (booloptional) – If True, graph of the derivative will be constructed, allowing to compute higher order derivative products. Defaults to False.

附参考学习的链接如下,并对作者表示感谢:retain_graph参数的作用.

 

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

pytorch autograd bac

下载Word文档到电脑,方便收藏和打印~

下载Word文档

猜你喜欢

pytorch autograd bac

retain_graph参数的作用官方定义:retain_graph (bool, optional) – If False, the graph used to compute the grad will be freed. Note t
2023-01-31

PyTorch的张量tensor和自动求导autograd详解

PyTorch中的张量是多维数据结构,具有数据类型、形状和设备属性,可用于表示神经网络的数据。自动求导通过记录计算图自动计算梯度,加速优化过程。它包含前向传递、计算梯度和更新参数三个步骤。尽管PyTorch提供自动梯度计算、GPU加速和易用性等优点,但存在内存消耗、调试困难和性能开销等限制。
PyTorch的张量tensor和自动求导autograd详解
2024-04-02

PyTorch与PyTorch Geometric的安装过程

这篇文章主要介绍了PyTorch与PyTorch Geometric的安装,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
2023-05-14

PyTorch基础

Infi-chu:http://www.cnblogs.com/Infi-chu/torch.FloatTensor:用于生成数据类型为浮点型的Tensor,参数可以是一个列表,也可以是一个维度。import torcha = torch.
2023-01-30

PyTorch入门指南:在PyCharm中轻松安装PyTorch

PyTorch是当前深度学习领域中备受瞩目的框架之一,它的易用性和灵活性受到很多开发者的喜爱。对于很多新手来说,安装PyTorch可能是一个挑战,尤其是在选择合适的开发环境时。本文将介绍如何使用PyCharm这一流行的集成开发环境安装PyT
PyTorch入门指南:在PyCharm中轻松安装PyTorch
2024-02-27

什么是 PyTorch?

PyTorch是一款基于Python的深度学习框架,具有动态计算图和命令式编程接口,使模型开发灵活且快速。它支持硬件加速,拥有丰富的生态系统和广泛的应用,包括计算机视觉、自然语言处理和强化学习。与其他框架相比,PyTorch提供了更高的灵活性、可修改性和性能。其缺点包括内存占用和文档较少。总的来说,PyTorch是机器学习研究和开发的强大选择。
什么是 PyTorch?
2024-04-02

PyTorch与PyTorch Geometric的安装过程是什么

这篇文章主要讲解了“PyTorch与PyTorch Geometric的安装过程是什么”,文中的讲解内容简单清晰,易于学习与理解,下面请大家跟着小编的思路慢慢深入,一起来研究和学习“PyTorch与PyTorch Geometric的安装过
2023-07-05

编程热搜

  • Python 学习之路 - Python
    一、安装Python34Windows在Python官网(https://www.python.org/downloads/)下载安装包并安装。Python的默认安装路径是:C:\Python34配置环境变量:【右键计算机】--》【属性】-
    Python 学习之路 - Python
  • chatgpt的中文全称是什么
    chatgpt的中文全称是生成型预训练变换模型。ChatGPT是什么ChatGPT是美国人工智能研究实验室OpenAI开发的一种全新聊天机器人模型,它能够通过学习和理解人类的语言来进行对话,还能根据聊天的上下文进行互动,并协助人类完成一系列
    chatgpt的中文全称是什么
  • C/C++中extern函数使用详解
  • C/C++可变参数的使用
    可变参数的使用方法远远不止以下几种,不过在C,C++中使用可变参数时要小心,在使用printf()等函数时传入的参数个数一定不能比前面的格式化字符串中的’%’符号个数少,否则会产生访问越界,运气不好的话还会导致程序崩溃
    C/C++可变参数的使用
  • css样式文件该放在哪里
  • php中数组下标必须是连续的吗
  • Python 3 教程
    Python 3 教程 Python 的 3.0 版本,常被称为 Python 3000,或简称 Py3k。相对于 Python 的早期版本,这是一个较大的升级。为了不带入过多的累赘,Python 3.0 在设计的时候没有考虑向下兼容。 Python
    Python 3 教程
  • Python pip包管理
    一、前言    在Python中, 安装第三方模块是通过 setuptools 这个工具完成的。 Python有两个封装了 setuptools的包管理工具: easy_install  和  pip , 目前官方推荐使用 pip。    
    Python pip包管理
  • ubuntu如何重新编译内核
  • 改善Java代码之慎用java动态编译

目录