我的编程空间,编程开发者的网络收藏夹
学习永远不晚

【Pytorch】model.train() 和 model.eval() 原理与用法

短信预约 -IT技能 免费直播动态提醒
省份

北京

  • 北京
  • 上海
  • 天津
  • 重庆
  • 河北
  • 山东
  • 辽宁
  • 黑龙江
  • 吉林
  • 甘肃
  • 青海
  • 河南
  • 江苏
  • 湖北
  • 湖南
  • 江西
  • 浙江
  • 广东
  • 云南
  • 福建
  • 海南
  • 山西
  • 四川
  • 陕西
  • 贵州
  • 安徽
  • 广西
  • 内蒙
  • 西藏
  • 新疆
  • 宁夏
  • 兵团
手机号立即预约

请填写图片验证码后获取短信验证码

看不清楚,换张图片

免费获取短信验证码

【Pytorch】model.train() 和 model.eval() 原理与用法

文章目录


一、两种模式

pytorch可以给我们提供两种方式来切换训练和评估(推断)的模式,分别是:model.train()model.eval()

一般用法是:在训练开始之前写上 model.trian() ,在测试时写上 model.eval() 。


二、功能

1. model.train()

在使用 pytorch 构建神经网络的时候,训练过程中会在程序上方添加一句model.train(),作用是 启用 batch normalization 和 dropout

如果模型中有BN层(Batch Normalization)和 Dropout ,需要在 训练时 添加 model.train()。

model.train() 是保证 BN 层能够用到 每一批数据 的均值和方差。对于 Dropout,model.train() 是 随机取一部分 网络连接来训练更新参数。

2. model.eval()

model.eval()的作用是 不启用 Batch Normalization 和 Dropout

如果模型中有 BN 层(Batch Normalization)和 Dropout,在 测试时 添加 model.eval()。

model.eval() 是保证 BN 层能够用 全部训练数据 的均值和方差,即测试过程中要保证 BN 层的均值和方差不变。对于 Dropout,model.eval() 是利用到了 所有 网络连接,即不进行随机舍弃神经元。

为什么测试时要用 model.eval() ?

训练完 train 样本后,生成的模型 model 要用来测试样本了。在 model(test) 之前,需要加上model.eval(),否则的话,有输入数据,即使不训练,它也会改变权值。这是 model 中含有 BN 层和 Dropout 所带来的的性质。

eval() 时,pytorch 会自动把 BN 和 DropOut 固定住,不会取平均,而是用训练好的值。
不然的话,一旦 test 的 batch_size 过小,很容易就会被 BN 层导致生成图片颜色失真极大。
eval() 在非训练的时候是需要加的,没有这句代码,一些网络层的值会发生变动,不会固定,你神经网络每一次生成的结果也是不固定的,生成质量可能好也可能不好。

也就是说,测试过程中使用model.eval(),这时神经网络会 沿用 batch normalization 的值,而并 不使用 dropout

3. 总结与对比

如果模型中有 BN 层(Batch Normalization)和 Dropout,需要在训练时添加 model.train(),在测试时添加 model.eval()。

其中 model.train() 是保证 BN 层用每一批数据的均值和方差,而 model.eval() 是保证 BN 用全部训练数据的均值和方差;

而对于 Dropout,model.train() 是随机取一部分网络连接来训练更新参数,而 model.eval() 是利用到了所有网络连接。


三、Dropout 简介

dropout 常常用于抑制过拟合。

设置Dropout时,torch.nn.Dropout(0.5),这里的 0.5 是指该层(layer)的神经元在每次迭代训练时会随机有 50% 的可能性被丢弃(失活),不参与训练。也就是将上一层数据减少一半传播。


参考链接

  1. PyTorch中train()方法的作用是什么
  2. 【pytorch】model.train()和model.evel()的用法
  3. pytorch中net.eval() 和net.train()的使用
  4. Pytorch学习笔记11----model.train()与model.eval()的用法、Dropout原理、relu,sigmiod,tanh激活函数、nn.Linear浅析、输出整个tensor的方法
  5. 好文:Pytorch:model.train()和model.eval()用法和区别,以及model.eval()和torch.no_grad()的区别

来源地址:https://blog.csdn.net/weixin_44211968/article/details/123774649

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

【Pytorch】model.train() 和 model.eval() 原理与用法

下载Word文档到电脑,方便收藏和打印~

下载Word文档

猜你喜欢

Pytorch中的model.train() 和 model.eval() 原理与用法解析

pytorch可以给我们提供两种方式来切换训练和评估(推断)的模式,分别是:model.train() 和 model.eval(),这篇文章主要介绍了Pytorch中的model.train() 和 model.eval() 原理与用法,需要的朋友可以参考下
2023-05-15

详解model.train()和model.eval()两种模式的原理与用法

这篇文章主要介绍了详解model.train()和model.eval()两种模式的原理与用法,相信很多没有经验的人对此束手无策,那么看完这篇文章一定会对你有所帮助
2023-03-23

Pytorch中的model.train()和model.eval()怎么使用

本文小编为大家详细介绍“Pytorch中的model.train()和model.eval()怎么使用”,内容详细,步骤清晰,细节处理妥当,希望这篇“Pytorch中的model.train()和model.eval()怎么使用”文章能帮助
2023-07-06

MySQL 8.0用户和角色管理原理与用法详解

本文实例讲述了MySQL 8.0用户和角色管理。分享给大家供大家参考,具体如下: MySQL8.0新加了很多功能,其中在用户管理中增加了角色的管理, 默认的密码加密方式也做了调整,由之前的sha1改为了sha2,同时加上5.7的禁用用户和用
2022-05-31

mysql数据类型和字段属性原理与用法详解

本文实例讲述了mysql数据类型和字段属性。分享给大家供大家参考,具体如下: 本文内容:数据类型数值类型整数型浮点型定点型日期时间类型字符串类型补充:显示宽度与zerofll记录长度字段属性空\不为空值:NULL、NOT NULL主键:pr
2022-05-22

Java SQL注入的原理和用法

本篇内容介绍了“Java SQL注入的原理和用法”的有关知识,在实际案例的操作过程中,不少人都会遇到这样的困境,接下来就让小编带领大家学习一下如何处理这些情况吧!希望大家仔细阅读,能够学有所成!目录一,SQL注入–1,需求–2,测试–3,总
2023-06-20

mysql视图原理与用法实例详解

本文实例讲述了mysql视图原理与用法。分享给大家供大家参考,具体如下: 本文内容:什么是视图创建视图查看视图视图的修改视图的删除视图的数据操作首发日期:2018-04-13什么是视图:视图是一种基于查询结果的虚拟表,数据来源的表称为基本表
2022-05-29

编程热搜

  • Python 学习之路 - Python
    一、安装Python34Windows在Python官网(https://www.python.org/downloads/)下载安装包并安装。Python的默认安装路径是:C:\Python34配置环境变量:【右键计算机】--》【属性】-
    Python 学习之路 - Python
  • chatgpt的中文全称是什么
    chatgpt的中文全称是生成型预训练变换模型。ChatGPT是什么ChatGPT是美国人工智能研究实验室OpenAI开发的一种全新聊天机器人模型,它能够通过学习和理解人类的语言来进行对话,还能根据聊天的上下文进行互动,并协助人类完成一系列
    chatgpt的中文全称是什么
  • C/C++中extern函数使用详解
  • C/C++可变参数的使用
    可变参数的使用方法远远不止以下几种,不过在C,C++中使用可变参数时要小心,在使用printf()等函数时传入的参数个数一定不能比前面的格式化字符串中的’%’符号个数少,否则会产生访问越界,运气不好的话还会导致程序崩溃
    C/C++可变参数的使用
  • css样式文件该放在哪里
  • php中数组下标必须是连续的吗
  • Python 3 教程
    Python 3 教程 Python 的 3.0 版本,常被称为 Python 3000,或简称 Py3k。相对于 Python 的早期版本,这是一个较大的升级。为了不带入过多的累赘,Python 3.0 在设计的时候没有考虑向下兼容。 Python
    Python 3 教程
  • Python pip包管理
    一、前言    在Python中, 安装第三方模块是通过 setuptools 这个工具完成的。 Python有两个封装了 setuptools的包管理工具: easy_install  和  pip , 目前官方推荐使用 pip。    
    Python pip包管理
  • ubuntu如何重新编译内核
  • 改善Java代码之慎用java动态编译

目录