我的编程空间,编程开发者的网络收藏夹
学习永远不晚

在pytorch中复制模型时出现问题如何解决

短信预约 -IT技能 免费直播动态提醒
省份

北京

  • 北京
  • 上海
  • 天津
  • 重庆
  • 河北
  • 山东
  • 辽宁
  • 黑龙江
  • 吉林
  • 甘肃
  • 青海
  • 河南
  • 江苏
  • 湖北
  • 湖南
  • 江西
  • 浙江
  • 广东
  • 云南
  • 福建
  • 海南
  • 山西
  • 四川
  • 陕西
  • 贵州
  • 安徽
  • 广西
  • 内蒙
  • 西藏
  • 新疆
  • 宁夏
  • 兵团
手机号立即预约

请填写图片验证码后获取短信验证码

看不清楚,换张图片

免费获取短信验证码

在pytorch中复制模型时出现问题如何解决

在pytorch中复制模型时出现问题如何解决?针对这个问题,这篇文章详细介绍了相对应的分析和解答,希望可以帮助更多想解决这个问题的小伙伴找到更简单易行的方法。

直接使用

model2=model1

会出现当更新model2时,model1的权重也会更新,这和自己的初始目的不同。

经评论指出可以使用:

model2=copy.deepcopy(model1)

来实现深拷贝,手上没有pytorch环境,具体还没测试过,谁测试过可以和我说下有没有用。

原方法:

所有要使用模型复制可以使用如下方法。

torch.save(model, "net_params.pkl")model5=Cnn(3,10)model5=torch.load('net_params.pkl')

这样编写不会影响原始模型的权重

补充:pytorch模型训练流程中遇到的一些坑(持续更新)

要训练一个模型,主要分成几个部分,如下。

数据预处理

入门的话肯定是拿 MNIST 手写数据集先练习。

pytorch 中有帮助我们制作数据生成器的模块,其中有 Dataset、TensorDataset、DataLoader 等类可以来创建数据入口。

之前在 tensorflow 中可以用 dataset.from_generator() 的形式,pytorch 中也类似,目前我了解到的有两种方法可以实现。

第一种就继承 pytorch 定义的 dataset,改写其中的方法即可。如下,就获得了一个 DataLoader 生成器。

class MyDataset(Dataset): def __init__(self, data, labels): self.data = data self.labels = labels def __getitem__(self, index): return self.data[index], self.labels[index] def __len__(self): return len(self.labels) train_dataset = MyDataset(train_data, train_label)train_loader = DataLoader(dataset = train_dataset, batch_size = 1, shuffle = True)

第二种就是转换,先把我们准备好的数据转化成 pytorch 的变量(或者是 Tensor),然后传入 TensorDataset,再构造 DataLoader。

X = torch.from_numpy(train_data).float()Y = torch.from_numpy(train_label).float()train_dataset = TensorDataset(X, Y) train_loader = DataLoader(dataset = train_dataset, batch_size = 1, shuffle = True) #num_workers = 2)

模型定义

class Net(nn.Module):  def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(1, 6, 3) self.conv2 = nn.Conv2d(6 ,16, 3)  self.fc1 = nn.Linear(400, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10)  def forward(self, x): relu = F.relu(self.conv1(x)) x = F.max_pool2d(relu, (2, 2)) x = F.max_pool2d(F.relu(self.conv2(x)), 2) x = x.view(-1, self.num_flat_features(x)) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x)  return x  def num_flat_features(self, x): size = x.size()[1:] #除了batch_size之外的维度 num_features = 1 for s in size: num_features *= s return num_features

训练模型那么肯定要先定义一个网络结构,如上定义一个前向传播网络。里面包含了卷积层、全连接层、最大池化层和 relu 非线性激活层(名字我自己取的)以及一个 view 展开,把一个多维的特征图平展成一维的。

其中nn.Conv2d(in_channels, out_channels, kernel_size),第一个参数是输入的深度,第二是输出的深度,第三是卷积核的尺寸。

F.max_pool2d(input, (pool_size, pool_size)),第二个参数是池话

nn.Linear(in_features, out_features)

x.view是平展的操作,不过实际上相当于 numpy 的 reshape,需要计算转换后的尺寸。

损失函数定义

import torch.optim as optim criterion = nn.CrossEntropyLoss()optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

模型定义完之后,意味着给出输入,就可以得到输出的结果。那么就来比较 outputs 和 targets 之间的区别,那么就需要用到损失函数来描述。

训练网络

for epoch in range(2): # loop over the dataset multiple times  running_loss = 0.0 for i, data in enumerate(trainloader, 0): # get the inputs; data is a list of [inputs, labels] inputs, labels = data  # zero the parameter gradients optimizer.zero_grad()  # forward + backward + optimize outputs = net(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step()  # print statistics running_loss += loss.item() if i % 2000 == 1999: # print every 2000 mini-batches  print('[%d, %5d] loss: %.3f' %   (epoch + 1, i + 1, running_loss / 2000))  running_loss = 0.0 print('Finished Training')

以上的代码是官方教程中给出来的,我们要做的就是学习他的思路。

首先是 epoch 的数量为 2,每个 epoch 都会历遍一次整个训练集。在每个 epoch 内累积统计 running_loss,每 2000 个 batch 数据计算一次损失的平均值,然后 print 再重新将 running_loss 置为 0。

然后分 mini-batch 进行训练,在每个计算每个 mini-batch 的损失之前,都会将优化器 optimizer 中的梯度清空,防止不同 mini-batch 的梯度被累加到一起。更新分成两步:第一步计算损失函数,然后把总的损失分配到各个层中,即 loss.backward(),然后就使用优化器更新权重,即 optimizer.step()。

保存模型

PATH = '...'torch.save(net.state_dict(), PATH)

爬坑总结

总的来说流程就是上面那几步,但自己做的时候就遇到了挺多问题,最主要是对于其中张量传播过程中的要求不清楚,导致出了不少错误。

首先是输入的数据,pytorch 默认图片的 batch 数据的结构是(BATCH_SIZE, CHANNELS, IMG_H, IMG_W),所以要在生成数据时做一些调整,满足这种 BCHW 的规则。

会经常出现一些某个矩阵或者张量要求的数据,例如 “RuntimeError: Expected object of scalar type Double but got scalar type Float for argument #2 ‘mat2'” 等错误信息。

可以使用 x.double(),y.float(),z.long() 等方式转换成他要求的格式。

RuntimeError: multi-target not supported。这个错误出现在损失函数那个地方,对于分类问题肯定是优先考虑交叉熵。

criterion = nn.CrossEntropyLoss()loss = criterion(outputs, labels.long())#报错的地方

当我batch-size=1时这个地方不会报错,但是当batch-size>1时就会报错。

查了别人的代码,大家基本都是和官方教程里面写的一样,使用官方的 mnist 数据接口,代码如下。一开始我是不愿意的,因为那样子意味着可能数据格式被封装起来看不见,但是自己折腾成本比较高,所以还是试了,真香!

train_dataset = datasets.MNIST(root='./data/',    train=True,    transform=transforms.ToTensor(),    download=True)train_loader = DataLoader(dataset = train_dataset,  batch_size = 4,  shuffle = True)

打印了一下从生成器中获得数据,看一下 size,发现果然和我自己写的不同。当 batch_size=4 时,数据 data.size() 都是4*1*28*28,这个是相同的;但是 labels.size() 是不同的,我写的是 one_hot 向量所以是 4*10,但它的是 4。

直接打印 labels 看看,果然,是单个指,例如 tensor([3, 2, 6, 2]) 这样。

不过模型的 outputs 依然是 4*10,看来是 nn.CrossEntropyLoss() 这个函数自己会做计算,所以他才会报错说 multi-target not supported,因为 lables.size() 不对,原本只有一个数字,但现在是10个数字,相当于被分配了10个属性,自然就报错啦。

所以稍微修改了自己写的生成器之后,就没问题了。

不过,如果想要更自由的调用数据,还是需要对对象进行一些方法的重载,使用 pytoch 定义的 DataLoader,用 enumerate,就会把所有的数据历遍一次,如果使用 iter() 得到一个可迭代对象之后 next(),并不可以像 tensorflow 那样子生成训练数据。

例如说,如果使用如上的形式,DataLoader 得到的是一个生成器,python 中的生成器对象主要有 __next__ 和 __iter__ 等魔术方法决定。

__iter__ 方法使得实例可以如下调用,可以得到一个可迭代对象,iterable,但是如果不加也没关系,因为更重要的是 __next__ 类方法。

如下自己写了 __next__ 方法之后就可以看到,原本会出现越界的现象不见了,可以循环的历遍数据,当然也可以想被注释的那部分一样,抛出 StopIteration 来终止。

a = A()a_iter = iter(a)class A(): def __init__(self): self.list = [1,2,3] self.index = 0 #def __getitem__(self, index): # return self.list[i] #def __iter__(self): # return self def __next__(self): #for i in range(): if self.index >= len(self.list): #raise StopIteration  self.index = self.index%len(self.list) result = self.list[self.index] self.index += 1 return result b = A() for i in range(20): print(next(b))

关于在pytorch中复制模型时出现问题如何解决问题的解答就分享到这里了,希望以上内容可以对大家有一定的帮助,如果你还有很多疑惑没有解开,可以关注编程网行业资讯频道了解更多相关知识。

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

在pytorch中复制模型时出现问题如何解决

下载Word文档到电脑,方便收藏和打印~

下载Word文档

猜你喜欢

在pytorch中复制模型时出现问题如何解决

在pytorch中复制模型时出现问题如何解决?针对这个问题,这篇文章详细介绍了相对应的分析和解答,希望可以帮助更多想解决这个问题的小伙伴找到更简单易行的方法。直接使用model2=model1会出现当更新model2时,model1的权重也
2023-06-06

pytorch网络模型构建场景的问题如何解决

今天小编给大家分享一下pytorch网络模型构建场景的问题如何解决的相关知识点,内容详细,逻辑清晰,相信大部分人都还太了解这方面的知识,所以分享这篇文章给大家参考一下,希望大家阅读完这篇文章后有所收获,下面我们一起来了解一下吧。网络模型构建
2023-07-05

vuejs在解析时出现闪烁问题如何解决

这篇文章主要介绍“vuejs在解析时出现闪烁问题如何解决”的相关知识,小编通过实际案例向大家展示操作过程,操作方法简单快捷,实用性强,希望这篇“vuejs在解析时出现闪烁问题如何解决”文章能帮助大家解决问题。原因: 在使用vuejs、ang
2023-07-04

如何在pytorch中解决state_dict()的拷贝问题

如何在pytorch中解决state_dict()的拷贝问题?很多新手对此不是很清楚,为了帮助大家解决这个难题,下面小编将为大家详细讲解,有这方面需求的人可以来学习下,希望你能有所收获。model.state_dict()是浅拷贝,返回的参
2023-06-06

在Tomcat中访问localhost时出现404如何解决

今天就跟大家聊聊有关在Tomcat中访问localhost时出现404如何解决,可能很多人都不太了解,为了让大家更加了解,小编给大家总结了以下内容,希望大家根据这篇文章可以有所收获。错误的路径配置如下:正确的配置应该是:但是当我这样操作之后
2023-06-14

在Android中使用AutoWrapTextView时出现中英文排版问题如何解决

这篇文章将为大家详细讲解有关在Android中使用AutoWrapTextView时出现中英文排版问题如何解决,文章内容质量较高,因此小编分享给大家做个参考,希望大家阅读完这篇文章后对相关知识有一定的了解。实现首先创建一个继承自View的A
2023-05-31

如何解决InternetExplorer9安装时出现的问题

要解决Internet Explorer 9安装时出现的问题,可以尝试以下方法:1. 检查系统要求:确保你的计算机符合Internet Explorer 9的系统要求。例如,你的操作系统是否为Windows 7或更高版本。2. 关闭防火墙和
2023-09-07

如何解决go Fscanf在读取文件时出现的问题

这篇文章将为大家详细讲解有关如何解决go Fscanf在读取文件时出现的问题,小编觉得挺实用的,因此分享给大家做个参考,希望大家阅读完这篇文章后可以有所收获。先要明白Fscanf的工作原理Fscanf在遇到\n才结束遇到\r时就会把\r替换
2023-06-14

如何解决CentOS MAKE中出现的问题

如何解决CentOS MAKE中出现的问题,针对这个问题,这篇文章详细介绍了相对应的分析和解答,希望可以帮助更多想解决这个问题的小伙伴找到更简单易行的方法。在CentOS MAKE的使用中会出现很多的问题,这次,我就碰到了CentOS MA
2023-06-16

labview执行请求时出现问题如何解决

当LabVIEW执行请求时出现问题,可以尝试以下解决方法:1. 检查错误信息:LabVIEW会提供详细的错误信息,可以通过查看错误信息来了解问题的具体原因。根据错误信息进行排查和修复。2. 调试程序:使用LabVIEW的调试工具,例如断点、
2023-09-15

连接服务器时出现问题如何解决

连接服务器时出现问题可能是由于网络连接或者服务器上的相关设置不正确引起的。可以通过以下步骤来解决:1、重新连接网络确保网络连接正常,若重新连接网络无效,可能是 IP 地址问题,这时应该查看服务器 IP 地址是否正确,或重新配置 IP 地址。
2023-03-11

如何解决canvas在移动端绘制模糊的问题

小编给大家分享一下如何解决canvas在移动端绘制模糊的问题,相信大部分人都还不怎么了解,因此分享这篇文章给大家参考一下,希望大家阅读完这篇文章后大有收获,下面让我们一起去了解一下吧!由于一些移动端的兼容性原因,我们某个项目需要前端将pdf
2023-06-09

如何解决编写代码时出现的Go问题

这篇文章主要介绍“如何解决编写代码时出现的Go问题”,在日常操作中,相信很多人在如何解决编写代码时出现的Go问题问题上存在疑惑,小编查阅了各式资料,整理出简单好用的操作方法,希望对大家解答”如何解决编写代码时出现的Go问题”的疑惑有所帮助!
2023-06-15

如何解决PIP安装python包出现超时问题

这篇文章给大家分享的是有关如何解决PIP安装python包出现超时问题的内容。小编觉得挺实用的,因此分享给大家做个参考,一起跟随小编过来看看吧。我们在使用pip默认源进行模块安装时,经常会超时问题导致不能下载。如图所示: 下面是解决方法—
2023-06-14

win10运行此工具时出现问题如何解决

首先,您可以尝试以下几种方法来解决Windows 10运行此工具时出现的问题:1. 更新Windows 10:确保您的操作系统已经安装了最新的更新。打开Windows设置,点击“更新和安全”,然后点击“检查更新”。2. 重新安装工具:如果问
2023-10-09

怎么解决复制网页上面的一些文字时出现了无法复制问题

这篇文章主要介绍了怎么解决复制网页上面的一些文字时出现了无法复制问题,具有一定借鉴价值,感兴趣的朋友可以参考下,希望大家阅读完这篇文章之后大有收获,下面让小编带着大家一起了解一下。  1、首先打开浏览器,然后点击浏览器上方的“工具--Int
2023-06-13

如何解决LNMP安装composer install时出现Warning: putenv()问题

小编给大家分享一下如何解决LNMP安装composer install时出现Warning: putenv()问题,相信大部分人都还不怎么了解,因此分享这篇文章给大家参考一下,希望大家阅读完这篇文章后大有收获,下面让我们一起去了解一下吧!L
2023-06-14

编程热搜

  • Python 学习之路 - Python
    一、安装Python34Windows在Python官网(https://www.python.org/downloads/)下载安装包并安装。Python的默认安装路径是:C:\Python34配置环境变量:【右键计算机】--》【属性】-
    Python 学习之路 - Python
  • chatgpt的中文全称是什么
    chatgpt的中文全称是生成型预训练变换模型。ChatGPT是什么ChatGPT是美国人工智能研究实验室OpenAI开发的一种全新聊天机器人模型,它能够通过学习和理解人类的语言来进行对话,还能根据聊天的上下文进行互动,并协助人类完成一系列
    chatgpt的中文全称是什么
  • C/C++中extern函数使用详解
  • C/C++可变参数的使用
    可变参数的使用方法远远不止以下几种,不过在C,C++中使用可变参数时要小心,在使用printf()等函数时传入的参数个数一定不能比前面的格式化字符串中的’%’符号个数少,否则会产生访问越界,运气不好的话还会导致程序崩溃
    C/C++可变参数的使用
  • css样式文件该放在哪里
  • php中数组下标必须是连续的吗
  • Python 3 教程
    Python 3 教程 Python 的 3.0 版本,常被称为 Python 3000,或简称 Py3k。相对于 Python 的早期版本,这是一个较大的升级。为了不带入过多的累赘,Python 3.0 在设计的时候没有考虑向下兼容。 Python
    Python 3 教程
  • Python pip包管理
    一、前言    在Python中, 安装第三方模块是通过 setuptools 这个工具完成的。 Python有两个封装了 setuptools的包管理工具: easy_install  和  pip , 目前官方推荐使用 pip。    
    Python pip包管理
  • ubuntu如何重新编译内核
  • 改善Java代码之慎用java动态编译

目录