我的编程空间,编程开发者的网络收藏夹
学习永远不晚

如何深入理解Pytorch微调torchvision模型

短信预约 -IT技能 免费直播动态提醒
省份

北京

  • 北京
  • 上海
  • 天津
  • 重庆
  • 河北
  • 山东
  • 辽宁
  • 黑龙江
  • 吉林
  • 甘肃
  • 青海
  • 河南
  • 江苏
  • 湖北
  • 湖南
  • 江西
  • 浙江
  • 广东
  • 云南
  • 福建
  • 海南
  • 山西
  • 四川
  • 陕西
  • 贵州
  • 安徽
  • 广西
  • 内蒙
  • 西藏
  • 新疆
  • 宁夏
  • 兵团
手机号立即预约

请填写图片验证码后获取短信验证码

看不清楚,换张图片

免费获取短信验证码

如何深入理解Pytorch微调torchvision模型

如何深入理解Pytorch微调torchvision模型,针对这个问题,这篇文章详细介绍了相对应的分析和解答,希望可以帮助更多想解决这个问题的小伙伴找到更简单易行的方法。

一、简介

在本小节,深入探讨如何对torchvision进行微调和特征提取。所有模型都已经预先在1000类的magenet数据集上训练完成。 本节将深入介绍如何使用几个现代的CNN架构,并将直观展示如何微调任意的PyTorch模型。
本节将执行两种类型的迁移学习:

  • 微调:从预训练模型开始,更新我们新任务的所有模型参数,实质上是重新训练整个模型。

  • 特征提取:从预训练模型开始,仅更新从中导出预测的最终图层权重。它被称为特征提取,因为我们使用预训练的CNN作为固定 的特征提取器,并且仅改变输出层。

通常这两种迁移学习方法都会遵循一下步骤:

  • 初始化预训练模型

  • 重组最后一层,使其具有与新数据集类别数相同的输出数

  • 为优化算法定义想要的训练期间更新的参数

  • 运行训练步骤

二、导入相关包

from __future__ import print_functionfrom __future__ import divisionimport torchimport torch.nn as nnimport torch.optim as optimimport numpy as npimport torchvision from torchvision import datasets,models,transformsimport matplotlib.pyplot as pltimport timeimport osimport copyprint("Pytorch version:",torch.__version__)print("torchvision version:",torchvision.__version__)

运行结果

如何深入理解Pytorch微调torchvision模型

三、数据输入

数据集——>我在这里

链接:https://pan.baidu.com/s/1G3yRfKTQf9sIq1iCSoymWQ
提取码:1234

#%%输入data_dir="D:\Python\Pytorch\data\hymenoptera_data"# 从[resnet,alexnet,vgg,squeezenet,desenet,inception]model_name='squeezenet'# 数据集中类别数量num_classes=2# 训练的批量大小batch_size=8# 训练epoch数num_epochs=15# 用于特征提取的标志。为FALSE,微调整个模型,为TRUE只更新图层参数feature_extract=True

四、辅助函数

1、模型训练和验证

  • train_model函数处理给定模型的训练和验证。作为输入,它需要PyTorch模型、数据加载器字典、损失函数、优化器、用于训练和验 证epoch数,以及当模型是初始模型时的布尔标志。

  • is_inception标志用于容纳 Inception v3 模型,因为该体系结构使用辅助输出, 并且整体模型损失涉及辅助输出和最终输出,如此处所述。 这个函数训练指定数量的epoch,并且在每个epoch之后运行完整的验证步骤。它还跟踪最佳性能的模型(从验证准确率方面),并在训练 结束时返回性能最好的模型。在每个epoch之后,打印训练和验证正确率。

#%%模型训练和验证device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")def train_model(model,dataloaders,criterion,optimizer,num_epochs=25,is_inception=False):    since=time.time()    val_acc_history=[]    best_model_wts=copy.deepcopy(model.state_dict())    best_acc=0.0    for epoch in range(num_epochs):        print('Epoch{}/{}'.format(epoch, num_epochs-1))        print('-'*10)        # 每个epoch都有一个训练和验证阶段        for phase in['train','val']:            if phase=='train':                model.train()            else:                model.eval()                            running_loss=0.0            running_corrects=0            # 迭代数据            for inputs,labels in dataloaders[phase]:                inputs=inputs.to(device)                labels=labels.to(device)                # 梯度置零                optimizer.zero_grad()                # 向前传播                with torch.set_grad_enabled(phase=='train'):                    # 获取模型输出并计算损失,开始的特殊情况在训练中他有一个辅助输出                    # 在训练模式下,通过将最终输出和辅助输出相加来计算损耗,在测试中值考虑最终输出                    if is_inception and phase=='train':                        outputs,aux_outputs=model(inputs)                        loss1=criterion(outputs,labels)                        loss2=criterion(aux_outputs,labels)                        loss=loss1+0.4*loss2                    else:                        outputs=model(inputs)                        loss=criterion(outputs,labels)                                            _,preds=torch.max(outputs,1)                                        if phase=='train':                        loss.backward()                        optimizer.step()                                        # 添加                running_loss+=loss.item()*inputs.size(0)                running_corrects+=torch.sum(preds==labels.data)                            epoch_loss=running_loss/len(dataloaders[phase].dataset)            epoch_acc=running_corrects.double()/len(dataloaders[phase].dataset)                        print('{}loss : {:.4f} acc:{:.4f}'.format(phase, epoch_loss,epoch_acc))                        if phase=='train' and epoch_acc>best_acc:                best_acc=epoch_acc                best_model_wts=copy.deepcopy(model.state_dict())            if phase=='val':                val_acc_history.append(epoch_acc)                    print()    time_elapsed=time.time()-since    print('training complete in {:.0f}s'.format(time_elapsed//60, time_elapsed%60))    print('best val acc:{:.4f}'.format(best_acc))        model.load_state_dict(best_model_wts)    return model,val_acc_history

2、设置模型参数的'.requires_grad属性'

当我们进行特征提取时,此辅助函数将模型中参数的 .requires_grad 属性设置为False。
默认情况下,当我们加载一个预训练模型时,所有参数都是 .requires_grad = True,如果我们从头开始训练或微调,这种设置就没问题。
但是,如果我们要运行特征提取并且只想为新初始化的层计算梯度,那么我们希望所有其他参数不需要梯度变化。

#%%设置模型参数的.require——grad属性def set_parameter_requires_grad(model,feature_extracting):    if feature_extracting:        for param in model.parameters():            param.require_grad=False

关于如何深入理解Pytorch微调torchvision模型问题的解答就分享到这里了,希望以上内容可以对大家有一定的帮助,如果你还有很多疑惑没有解开,可以关注编程网行业资讯频道了解更多相关知识。

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

如何深入理解Pytorch微调torchvision模型

下载Word文档到电脑,方便收藏和打印~

下载Word文档

猜你喜欢

如何深入理解Pytorch微调torchvision模型

如何深入理解Pytorch微调torchvision模型,针对这个问题,这篇文章详细介绍了相对应的分析和解答,希望可以帮助更多想解决这个问题的小伙伴找到更简单易行的方法。一、简介在本小节,深入探讨如何对torchvision进行微调和特征提
2023-06-25

如何在PyTorch中进行模型的微调

在PyTorch中进行模型微调的步骤如下:加载预训练模型:首先,你需要加载一个预训练的模型。PyTorch提供了许多常见的预训练模型,如ResNet、VGG等。你可以使用torchvision.models中的模型来加载预训练模型。impo
如何在PyTorch中进行模型的微调
2024-03-14

如何深入理解 JavaScript 原型链?(JavaScript原型链怎样深入理解)

一、引言在JavaScript中,原型链是一个非常重要的概念,它对于理解对象的继承和属性访问机制起着关键作用。然而,对于许多初学者来说,原型链可能是一个比较抽象和难以理解的概念。本文将深入探讨JavaScript原型链的原理和工作方式,
如何深入理解 JavaScript 原型链?(JavaScript原型链怎样深入理解)
Java2024-12-21

如何深入理解Java设计模式的中介者模式

这期内容当中小编将会给大家带来有关如何深入理解Java设计模式的中介者模式,文章内容丰富且以专业的角度为大家分析和叙述,阅读完这篇文章希望大家可以有所收获。一、什么是中介者模式用一个中介对象来封装一系列的对象交互。中介者使各对象不需要显式地
2023-06-25

如何深入理解Java设计模式的迭代器模式

如何深入理解Java设计模式的迭代器模式,很多新手对此不是很清楚,为了帮助大家解决这个难题,下面小编将为大家详细讲解,有这方面需求的人可以来学习下,希望你能有所收获。一、什么是迭代器模式迭代器模式是针对集合对象而生的,对于集合对象而言,肯定
2023-06-25

如何深入理解关系型数据库的三大范式

该文章,GitHub已收录,欢迎老板们前来Star!GitHub地址: https://github.com/Ziphtracks/JavaLearningmanual数据库范式一、什么是数据库范式 设计关系数据库时,遵从不同的规范要求,设计出合理的关系型数
如何深入理解关系型数据库的三大范式
2018-01-24

如何深入理解 Java 泛型中 extends 的继承关系?(如何理解Java泛型extends的继承关系)

在Java编程中,泛型是一种强大的特性,它允许我们编写更灵活、类型安全的代码。其中,extends关键字在泛型中的使用对于理解泛型的继承关系至关重要。一、extends关键字的基本概念
如何深入理解 Java 泛型中 extends 的继承关系?(如何理解Java泛型extends的继承关系)
Java2024-12-14

编程热搜

  • Python 学习之路 - Python
    一、安装Python34Windows在Python官网(https://www.python.org/downloads/)下载安装包并安装。Python的默认安装路径是:C:\Python34配置环境变量:【右键计算机】--》【属性】-
    Python 学习之路 - Python
  • chatgpt的中文全称是什么
    chatgpt的中文全称是生成型预训练变换模型。ChatGPT是什么ChatGPT是美国人工智能研究实验室OpenAI开发的一种全新聊天机器人模型,它能够通过学习和理解人类的语言来进行对话,还能根据聊天的上下文进行互动,并协助人类完成一系列
    chatgpt的中文全称是什么
  • C/C++中extern函数使用详解
  • C/C++可变参数的使用
    可变参数的使用方法远远不止以下几种,不过在C,C++中使用可变参数时要小心,在使用printf()等函数时传入的参数个数一定不能比前面的格式化字符串中的’%’符号个数少,否则会产生访问越界,运气不好的话还会导致程序崩溃
    C/C++可变参数的使用
  • css样式文件该放在哪里
  • php中数组下标必须是连续的吗
  • Python 3 教程
    Python 3 教程 Python 的 3.0 版本,常被称为 Python 3000,或简称 Py3k。相对于 Python 的早期版本,这是一个较大的升级。为了不带入过多的累赘,Python 3.0 在设计的时候没有考虑向下兼容。 Python
    Python 3 教程
  • Python pip包管理
    一、前言    在Python中, 安装第三方模块是通过 setuptools 这个工具完成的。 Python有两个封装了 setuptools的包管理工具: easy_install  和  pip , 目前官方推荐使用 pip。    
    Python pip包管理
  • ubuntu如何重新编译内核
  • 改善Java代码之慎用java动态编译

目录