我的编程空间,编程开发者的网络收藏夹
学习永远不晚

如何在Pytorch中实现一个模型迁移功能

短信预约 -IT技能 免费直播动态提醒
省份

北京

  • 北京
  • 上海
  • 天津
  • 重庆
  • 河北
  • 山东
  • 辽宁
  • 黑龙江
  • 吉林
  • 甘肃
  • 青海
  • 河南
  • 江苏
  • 湖北
  • 湖南
  • 江西
  • 浙江
  • 广东
  • 云南
  • 福建
  • 海南
  • 山西
  • 四川
  • 陕西
  • 贵州
  • 安徽
  • 广西
  • 内蒙
  • 西藏
  • 新疆
  • 宁夏
  • 兵团
手机号立即预约

请填写图片验证码后获取短信验证码

看不清楚,换张图片

免费获取短信验证码

如何在Pytorch中实现一个模型迁移功能

这篇文章给大家介绍如何在Pytorch中实现一个模型迁移功能,内容非常详细,感兴趣的小伙伴们可以参考借鉴,希望对大家能有所帮助。

1. 利用resnet18做迁移学习

import torchfrom torchvision import models if __name__ == "__main__":  # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  device = 'cpu'  print("-----device:{}".format(device))  print("-----Pytorch version:{}".format(torch.__version__))   input_tensor = torch.zeros(1, 3, 100, 100)  print('input_tensor:', input_tensor.shape)  pretrained_file = "model/resnet18-5c106cde.pth"  model = models.resnet18()  model.load_state_dict(torch.load(pretrained_file))  model.eval()  out = model(input_tensor)  print("out:", out.shape, out[0, 0:10])

结果输出:

input_tensor: torch.Size([1, 3, 100, 100])
out: torch.Size([1, 1000]) tensor([ 0.4010, 0.8436, 0.3072, 0.0627, 0.4446, 0.8470, 0.1882, 0.7012,0.2988, -0.7574], grad_fn=<SliceBackward>)

如果,我们修改了resnet18的网络结构,如何将原来预训练模型参数(resnet18-5c106cde.pth)迁移到新的resnet18网络中呢?

比如,这里将官方的resnet18的self.layer4 = self._make_layer(block, 512, layers[3], stride=2)改为:self.layer44 = self._make_layer(block, 512, layers[3], stride=2)

class ResNet(nn.Module):   def __init__(self, block, layers, num_classes=1000, zero_init_residual=False):    super(ResNet, self).__init__()    self.inplanes = 64    self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,                bias=False)    self.bn1 = nn.BatchNorm2d(64)    self.relu = nn.ReLU(inplace=True)    self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)    self.layer1 = self._make_layer(block, 64, layers[0])    self.layer2 = self._make_layer(block, 128, layers[1], stride=2)    self.layer3 = self._make_layer(block, 256, layers[2], stride=2)    self.layer44 = self._make_layer(block, 512, layers[3], stride=2)    self.avgpool = nn.AdaptiveAvgPool2d((1, 1))    self.fc = nn.Linear(512 * block.expansion, num_classes)     for m in self.modules():      if isinstance(m, nn.Conv2d):        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')      elif isinstance(m, nn.BatchNorm2d):        nn.init.constant_(m.weight, 1)        nn.init.constant_(m.bias, 0)     # Zero-initialize the last BN in each residual branch,    # so that the residual branch starts with zeros, and each residual block behaves like an identity.    # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677    if zero_init_residual:      for m in self.modules():        if isinstance(m, Bottleneck):          nn.init.constant_(m.bn3.weight, 0)        elif isinstance(m, BasicBlock):          nn.init.constant_(m.bn2.weight, 0)   def _make_layer(self, block, planes, blocks, stride=1):    downsample = None    if stride != 1 or self.inplanes != planes * block.expansion:      downsample = nn.Sequential(        conv1x1(self.inplanes, planes * block.expansion, stride),        nn.BatchNorm2d(planes * block.expansion),      )     layers = []    layers.append(block(self.inplanes, planes, stride, downsample))    self.inplanes = planes * block.expansion    for _ in range(1, blocks):      layers.append(block(self.inplanes, planes))     return nn.Sequential(*layers)   def forward(self, x):    x = self.conv1(x)    x = self.bn1(x)    x = self.relu(x)    x = self.maxpool(x)     x = self.layer1(x)    x = self.layer2(x)    x = self.layer3(x)    x = self.layer44(x)     x = self.avgpool(x)    x = x.view(x.size(0), -1)    x = self.fc(x)     return x

这时,直接加载模型:

  model = models.resnet18()  model.load_state_dict(torch.load(pretrained_file))

这时,肯定会报错,类似:Missing key(s) in state_dict或者Unexpected key(s) in state_dict的错误:

RuntimeError: Error(s) in loading state_dict for ResNet:
Missing key(s) in state_dict: "layer44.0.conv1.weight", "layer44.0.bn1.weight", "layer44.0.bn1.bias", "layer44.0.bn1.running_mean", "layer44.0.bn1.running_var", "layer44.0.conv2.weight", "layer44.0.bn2.weight", "layer44.0.bn2.bias", "layer44.0.bn2.running_mean", "layer44.0.bn2.running_var", "layer44.0.downsample.0.weight", "layer44.0.downsample.1.weight", "layer44.0.downsample.1.bias", "layer44.0.downsample.1.running_mean", "layer44.0.downsample.1.running_var", "layer44.1.conv1.weight", "layer44.1.bn1.weight", "layer44.1.bn1.bias", "layer44.1.bn1.running_mean", "layer44.1.bn1.running_var", "layer44.1.conv2.weight", "layer44.1.bn2.weight", "layer44.1.bn2.bias", "layer44.1.bn2.running_mean", "layer44.1.bn2.running_var".
Unexpected key(s) in state_dict: "layer4.0.conv1.weight", "layer4.0.bn1.running_mean", "layer4.0.bn1.running_var", "layer4.0.bn1.weight", "layer4.0.bn1.bias", "layer4.0.conv2.weight", "layer4.0.bn2.running_mean", "layer4.0.bn2.running_var", "layer4.0.bn2.weight", "layer4.0.bn2.bias", "layer4.0.downsample.0.weight", "layer4.0.downsample.1.running_mean", "layer4.0.downsample.1.running_var", "layer4.0.downsample.1.weight", "layer4.0.downsample.1.bias", "layer4.1.conv1.weight", "layer4.1.bn1.running_mean", "layer4.1.bn1.running_var", "layer4.1.bn1.weight", "layer4.1.bn1.bias", "layer4.1.conv2.weight", "layer4.1.bn2.running_mean", "layer4.1.bn2.running_var", "layer4.1.bn2.weight", "layer4.1.bn2.bias".

Process finished with

RuntimeError: Error(s) in loading state_dict for ResNet:
Unexpected key(s) in state_dict: "layer4.0.conv1.weight", "layer4.0.bn1.running_mean", "layer4.0.bn1.running_var", "layer4.0.bn1.weight", "layer4.0.bn1.bias", "layer4.0.conv2.weight", "layer4.0.bn2.running_mean", "layer4.0.bn2.running_var", "layer4.0.bn2.weight", "layer4.0.bn2.bias", "layer4.0.downsample.0.weight", "layer4.0.downsample.1.running_mean", "layer4.0.downsample.1.running_var", "layer4.0.downsample.1.weight", "layer4.0.downsample.1.bias", "layer4.1.conv1.weight", "layer4.1.bn1.running_mean", "layer4.1.bn1.running_var", "layer4.1.bn1.weight", "layer4.1.bn1.bias", "layer4.1.conv2.weight", "layer4.1.bn2.running_mean", "layer4.1.bn2.running_var", "layer4.1.bn2.weight", "layer4.1.bn2.bias".

我们希望将原来预训练模型参数(resnet18-5c106cde.pth)迁移到新的resnet18网络,当然只能迁移二者相同的模型参数,不同的参数还是随机初始化的.

 def transfer_model(pretrained_file, model):  '''  只导入pretrained_file部分模型参数  tensor([-0.7119, 0.0688, -1.7247, -1.7182, -1.2161, -0.7323, -2.1065, -0.5433,-1.5893, -0.5562]  update:    D.update([E, ]**F) -> None. Update D from dict/iterable E and F.    If E is present and has a .keys() method, then does: for k in E: D[k] = E[k]    If E is present and lacks a .keys() method, then does: for k, v in E: D[k] = v    In either case, this is followed by: for k in F: D[k] = F[k]  :param pretrained_file:  :param model:  :return:  '''  pretrained_dict = torch.load(pretrained_file) # get pretrained dict  model_dict = model.state_dict() # get model dict  # 在合并前(update),需要去除pretrained_dict一些不需要的参数  pretrained_dict = transfer_state_dict(pretrained_dict, model_dict)  model_dict.update(pretrained_dict) # 更新(合并)模型的参数  model.load_state_dict(model_dict)  return model def transfer_state_dict(pretrained_dict, model_dict):  '''  根据model_dict,去除pretrained_dict一些不需要的参数,以便迁移到新的网络  url: https://blog.csdn.net/qq_34914551/article/details/87871134  :param pretrained_dict:  :param model_dict:  :return:  '''  # state_dict2 = {k: v for k, v in save_model.items() if k in model_dict.keys()}  state_dict = {}  for k, v in pretrained_dict.items():    if k in model_dict.keys():      # state_dict.setdefault(k, v)      state_dict[k] = v    else:      print("Missing key(s) in state_dict :{}".format(k))  return state_dict if __name__ == "__main__":   input_tensor = torch.zeros(1, 3, 100, 100)  print('input_tensor:', input_tensor.shape)  pretrained_file = "model/resnet18-5c106cde.pth"  # model = resnet18()  # model.load_state_dict(torch.load(pretrained_file))  # model.eval()  # out = model(input_tensor)  # print("out:", out.shape, out[0, 0:10])   model1 = resnet18()  model1 = transfer_model(pretrained_file, model1)  out1 = model1(input_tensor)  print("out1:", out1.shape, out1[0, 0:10])

2. 修改网络名称并迁移学习

上面的例子,只是将官方的resnet18的self.layer4 = self._make_layer(block, 512, layers[3], stride=2)改为了:self.layer44 = self._make_layer(block, 512, layers[3], stride=2),我们仅仅是修改了一个网络名称而已,就导致 model.load_state_dict(torch.load(pretrained_file))出错,

那么,我们如何将预训练模型"model/resnet18-5c106cde.pth"转换成符合新的网络的模型参数呢?

方法很简单,只需要将resnet18-5c106cde.pth的模型参数中所有前缀为layer4的名称,改为layer44即可

本人已经定义好了方法:

modify_state_dict(pretrained_dict, model_dict, old_prefix, new_prefix)
def string_rename(old_string, new_string, start, end):  new_string = old_string[:start] + new_string + old_string[end:]  return new_string def modify_model(pretrained_file, model, old_prefix, new_prefix):  '''  :param pretrained_file:  :param model:  :param old_prefix:  :param new_prefix:  :return:  '''  pretrained_dict = torch.load(pretrained_file)  model_dict = model.state_dict()  state_dict = modify_state_dict(pretrained_dict, model_dict, old_prefix, new_prefix)  model.load_state_dict(state_dict)  return model  def modify_state_dict(pretrained_dict, model_dict, old_prefix, new_prefix):  '''  修改model dict  :param pretrained_dict:  :param model_dict:  :param old_prefix:  :param new_prefix:  :return:  '''  state_dict = {}  for k, v in pretrained_dict.items():    if k in model_dict.keys():      # state_dict.setdefault(k, v)      state_dict[k] = v    else:      for o, n in zip(old_prefix, new_prefix):        prefix = k[:len(o)]        if prefix == o:          kk = string_rename(old_string=k, new_string=n, start=0, end=len(o))          print("rename layer modules:{}-->{}".format(k, kk))          state_dict[kk] = v  return state_dict
if __name__ == "__main__":  input_tensor = torch.zeros(1, 3, 100, 100)  print('input_tensor:', input_tensor.shape)  pretrained_file = "model/resnet18-5c106cde.pth"  # model = models.resnet18()  # model.load_state_dict(torch.load(pretrained_file))  # model.eval()  # out = model(input_tensor)  # print("out:", out.shape, out[0, 0:10])  #  # model1 = resnet18()  # model1 = transfer_model(pretrained_file, model1)  # out1 = model1(input_tensor)  # print("out1:", out1.shape, out1[0, 0:10])  #  new_file = "new_model.pth"  model = resnet18()  new_model = modify_model(pretrained_file, model, old_prefix=["layer4"], new_prefix=["layer44"])  torch.save(new_model.state_dict(), new_file)   model2 = resnet18()  model2.load_state_dict(torch.load(new_file))  model2.eval()  out2 = model2(input_tensor)  print("out2:", out2.shape, out2[0, 0:10])

这时,输出,跟之前一模一样了。

out: torch.Size([1, 1000]) tensor([ 0.4010, 0.8436, 0.3072, 0.0627, 0.4446, 0.8470, 0.1882, 0.7012,0.2988, -0.7574], grad_fn=<SliceBackward>)

3.去除原模型的某些模块

下面是在不修改原模型代码的情况下,通过"resnet18.named_children()"和"resnet18.children()"的方法去除子模块"fc"和"avgpool"

import torchimport torchvision.models as modelsfrom collections import OrderedDict if __name__=="__main__":  resnet18 = models.resnet18(False)  print("resnet18",resnet18)   # use named_children()  resnet18_v1 = OrderedDict(resnet18.named_children())  # remove avgpool,fc  resnet18_v1.pop("avgpool")  resnet18_v1.pop("fc")  resnet18_v1 = torch.nn.Sequential(resnet18_v1)  print("resnet18_v1",resnet18_v1)  # use children  resnet18_v2 = torch.nn.Sequential(*list(resnet18.children())[:-2])  print(resnet18_v2,resnet18_v2)

补充:pytorch导入(部分)模型参数

背景介绍:

我的想法是把一个预训练的网络的参数导入到我的模型中,但是预训练模型的参数只是我模型参数的一小部分,怎样导进去不出差错了,请来听我说说。

解法

首先把你需要添加参数的那一小部分模型提取出来,并新建一个类进行重新定义,如图向Alexnet中添加前三层的参数,重新定义前三层。

如何在Pytorch中实现一个模型迁移功能

接下来就是导入参数

checkpoint = torch.load(config.pretrained_model)    # change name and load parameters    model_dict = model.net1.state_dict()    checkpoint = {k.replace('features.features', 'featureExtract1'): v for k, v in checkpoint.items()}    checkpoint = {k:v for k,v in checkpoint.items() if k in model_dict.keys()}     model_dict.update(checkpoint)    model.net1.load_state_dict(model_dict)

关于如何在Pytorch中实现一个模型迁移功能就分享到这里了,希望以上内容可以对大家有一定的帮助,可以学到更多知识。如果觉得文章不错,可以把它分享出去让更多的人看到。

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

如何在Pytorch中实现一个模型迁移功能

下载Word文档到电脑,方便收藏和打印~

下载Word文档

猜你喜欢

如何在Pytorch中实现一个模型迁移功能

这篇文章给大家介绍如何在Pytorch中实现一个模型迁移功能,内容非常详细,感兴趣的小伙伴们可以参考借鉴,希望对大家能有所帮助。1. 利用resnet18做迁移学习import torchfrom torchvision import mo
2023-06-06

如何在PyTorch中创建一个神经网络模型

在PyTorch中创建神经网络模型通常需要定义一个继承自torch.nn.Module类的自定义类。下面是一个简单的示例:import torchimport torch.nn as nnclass SimpleNN(nn.Module
如何在PyTorch中创建一个神经网络模型
2024-03-05

如何在Python中使用Tqdm模块实现一个进度条功能

本文章向大家介绍如何在Python中使用Tqdm模块实现一个进度条功能,主要包括如何在Python中使用Tqdm模块实现一个进度条功能的使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。Pytho
2023-06-06

如何在Android中实现一个计时器功能

本篇文章为大家展示了如何在Android中实现一个计时器功能,内容简明扼要并且容易理解,绝对能使你眼前一亮,通过这篇文章的详细介绍希望你能有所收获。Android是什么Android是一种基于Linux内核的自由及开放源代码的操作系统,主要
2023-06-14

如何在Android应用中实现一个侧滑功能

本篇文章给大家分享的是有关如何在Android应用中实现一个侧滑功能,小编觉得挺实用的,因此分享给大家学习,希望大家阅读完这篇文章后可以有所收获,话不多说,跟着小编一起来看看吧。实现说明:通过自定义布局实现:SlidingLayout继承于
2023-05-31

怎么在SQL Server中实现一个模糊查询功能

怎么在SQL Server中实现一个模糊查询功能?很多新手对此不是很清楚,为了帮助大家解决这个难题,下面小编将为大家详细讲解,有这方面需求的人可以来学习下,希望你能有所收获。1.用_通配符查询"_"号表示任意单个字符,该字符号只能匹配一个字
2023-06-14

如何在Android开发中中实现一个App更新功能

如何在Android开发中中实现一个App更新功能?针对这个问题,这篇文章详细介绍了相对应的分析和解答,希望可以帮助更多想解决这个问题的小伙伴找到更简单易行的方法。开发环境:AndroidStudio2.1.2+gradle-2.10部分代
2023-05-31

如何在Android中利用MediaRecorder实现一个录像功能

今天就跟大家聊聊有关如何在Android中利用MediaRecorder实现一个录像功能,可能很多人都不太了解,为了让大家更加了解,小编给大家总结了以下内容,希望大家根据这篇文章可以有所收获。在AndroidManifest.xml加入以下
2023-05-31

如何在Android应用中实现一个返回键功能

今天就跟大家聊聊有关如何在Android应用中实现一个返回键功能,可能很多人都不太了解,为了让大家更加了解,小编给大家总结了以下内容,希望大家根据这篇文章可以有所收获。记录用户点击的操作历史,使用栈数据结构,频繁的操作栈顶(添加,获取,删除
2023-05-31

如何在Golang中使用WebSocket实现一个通信功能

本篇文章给大家分享的是有关如何在Golang中使用WebSocket实现一个通信功能,小编觉得挺实用的,因此分享给大家学习,希望大家阅读完这篇文章后可以有所收获,话不多说,跟着小编一起来看看吧。什么是golanggolang 是Google
2023-06-06

如何在Android中实现一个在图片中添加文字功能

这篇文章给大家介绍如何在Android中实现一个在图片中添加文字功能,内容非常详细,感兴趣的小伙伴们可以参考借鉴,希望对大家能有所帮助。Android自定义实现图片加文字功能分四步来写: 1,组合控件的xml; 2,自定义组合控件的属性;
2023-05-31

如何在Android应用中实现一个图片添加功能

如何在Android应用中实现一个图片添加功能?很多新手对此不是很清楚,为了帮助大家解决这个难题,下面小编将为大家详细讲解,有这方面需求的人可以来学习下,希望你能有所收获。1、首先这是用GridView实现的2023-05-31

如何在Android中利用Dialog实现一个对话框功能

今天就跟大家聊聊有关如何在Android中利用Dialog实现一个对话框功能,可能很多人都不太了解,为了让大家更加了解,小编给大家总结了以下内容,希望大家根据这篇文章可以有所收获。一、普通对话框AlertDialog.Builder bui
2023-05-31

如何在Android应用中实现一个记住密码功能

本篇文章给大家分享的是有关如何在Android应用中实现一个记住密码功能,小编觉得挺实用的,因此分享给大家学习,希望大家阅读完这篇文章后可以有所收获,话不多说,跟着小编一起来看看吧。一、打开之前完成的Case_login进行修改再编辑二、将
2023-05-31

如何在vue中使用video实现一个播放器功能

这期内容当中小编将会给大家带来有关如何在vue中使用video实现一个播放器功能,文章内容丰富且以专业的角度为大家分析和叙述,阅读完这篇文章希望大家可以有所收获。当现有video播放器不能满足需求时,需要自己对video进行封装。video
2023-06-06

如何在java项目中实现一个递归调用功能

本篇文章为大家展示了如何在java项目中实现一个递归调用功能,内容简明扼要并且容易理解,绝对能使你眼前一亮,通过这篇文章的详细介绍希望你能有所收获。Java的特点有哪些Java的特点有哪些1.Java语言作为静态面向对象编程语言的代表,实现
2023-06-06

如何在Android应用中实现一个手势密码功能

如何在Android应用中实现一个手势密码功能?很多新手对此不是很清楚,为了帮助大家解决这个难题,下面小编将为大家详细讲解,有这方面需求的人可以来学习下,希望你能有所收获。实现思路:1. 正上方的提示区域,用一个类(LockIndicato
2023-05-31

怎么在python中使用translate模块实现一个翻译功能

怎么在python中使用translate模块实现一个翻译功能?很多新手对此不是很清楚,为了帮助大家解决这个难题,下面小编将为大家详细讲解,有这方面需求的人可以来学习下,希望你能有所收获。Python的优点有哪些1、简单易用,与C/C++、
2023-06-14

如何在Android中实现一个滑块拼图验证码功能

本篇文章给大家分享的是有关如何在Android中实现一个滑块拼图验证码功能,小编觉得挺实用的,因此分享给大家学习,希望大家阅读完这篇文章后可以有所收获,话不多说,跟着小编一起来看看吧。一、实现步骤:1、定义自定义属性; 2、确认目标位置,这
2023-06-06

编程热搜

  • Python 学习之路 - Python
    一、安装Python34Windows在Python官网(https://www.python.org/downloads/)下载安装包并安装。Python的默认安装路径是:C:\Python34配置环境变量:【右键计算机】--》【属性】-
    Python 学习之路 - Python
  • chatgpt的中文全称是什么
    chatgpt的中文全称是生成型预训练变换模型。ChatGPT是什么ChatGPT是美国人工智能研究实验室OpenAI开发的一种全新聊天机器人模型,它能够通过学习和理解人类的语言来进行对话,还能根据聊天的上下文进行互动,并协助人类完成一系列
    chatgpt的中文全称是什么
  • C/C++中extern函数使用详解
  • C/C++可变参数的使用
    可变参数的使用方法远远不止以下几种,不过在C,C++中使用可变参数时要小心,在使用printf()等函数时传入的参数个数一定不能比前面的格式化字符串中的’%’符号个数少,否则会产生访问越界,运气不好的话还会导致程序崩溃
    C/C++可变参数的使用
  • css样式文件该放在哪里
  • php中数组下标必须是连续的吗
  • Python 3 教程
    Python 3 教程 Python 的 3.0 版本,常被称为 Python 3000,或简称 Py3k。相对于 Python 的早期版本,这是一个较大的升级。为了不带入过多的累赘,Python 3.0 在设计的时候没有考虑向下兼容。 Python
    Python 3 教程
  • Python pip包管理
    一、前言    在Python中, 安装第三方模块是通过 setuptools 这个工具完成的。 Python有两个封装了 setuptools的包管理工具: easy_install  和  pip , 目前官方推荐使用 pip。    
    Python pip包管理
  • ubuntu如何重新编译内核
  • 改善Java代码之慎用java动态编译

目录