我的编程空间,编程开发者的网络收藏夹
学习永远不晚

PyTorch加载模型model.load_state_dict()问题及解决

短信预约 -IT技能 免费直播动态提醒
省份

北京

  • 北京
  • 上海
  • 天津
  • 重庆
  • 河北
  • 山东
  • 辽宁
  • 黑龙江
  • 吉林
  • 甘肃
  • 青海
  • 河南
  • 江苏
  • 湖北
  • 湖南
  • 江西
  • 浙江
  • 广东
  • 云南
  • 福建
  • 海南
  • 山西
  • 四川
  • 陕西
  • 贵州
  • 安徽
  • 广西
  • 内蒙
  • 西藏
  • 新疆
  • 宁夏
  • 兵团
手机号立即预约

请填写图片验证码后获取短信验证码

看不清楚,换张图片

免费获取短信验证码

PyTorch加载模型model.load_state_dict()问题及解决

PyTorch加载模型model.load_state_dict()问题

希望将训练好的模型加载到新的网络上。

如上面题目所描述的,PyTorch在加载之前保存的模型参数的时候,遇到了问题。

Unexpected key(s) in state_dict: "module.features. ...".,Expected ".features....". 直接原因是key值名字不对应。

表明了加载过程中,期望获得的key值为feature...,而不是module.features....。

这是由模型保存过程中导致的,模型应该是在DataParallel模式下面,也就是采用了多GPU训练模型,然后直接保存的。

You probably saved the model using nn.DataParallel, which stores the model in module, and now you are trying to load it without . You can either add a nn.DataParallel temporarily in your network for loading purposes, or you can load the weights file, create a new ordered dict without the module prefix, and load it back.

解决上面的问题有三个办法: 

1. 对load的模型创建新的字典

去掉不需要的key值"module".

# original saved file with DataParallel
state_dict = torch.load('checkpoint.pt')  # 模型可以保存为pth文件,也可以为pt文件。
# create new OrderedDict that does not contain `module.`
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] # remove `module.`,表面从第7个key值字符取到最后一个字符,正好去掉了module.
    new_state_dict[name] = v #新字典的key值对应的value为一一对应的值。 
# load params
model.load_state_dict(new_state_dict) # 从新加载这个模型。

2. 直接用空白''代替'module.'

model.load_state_dict({k.replace('module.',''):v for k,v in torch.load('checkpoint.pt').items()})
 
# 相当于用''代替'module.'。
#直接使得需要的键名等于期望的键名。

3. 最简单的方法

加载模型之后,接着将模型DataParallel,此时就可以load_state_dict。

如果有多个GPU,将模型并行化,用DataParallel来操作。

这个过程会将key值加一个"module. ***"。

model = VGGNet()
params=model.state_dict() #获得模型的原始状态以及参数。
for k,v in params.items():
    print(k) #只打印key值,不打印具体参数。

4. 总结

从出错显示的问题就可以看出,key值不匹配,因此可以选择多种方法,将模型参数加载进去。

这个方法通常会在load_state_dict过程中遇到。将训练好的一个网络参数,移植到另外一个网络上面,继续训练。

或者将训练好的网络checkpoint加载进模型,再次进行训练。可以打印出model state_dict来看出两者的差别。

model = VGGNet()
params=model.state_dict() #获得模型的原始状态以及参数。
for k,v in params.items():
    print(k) #只打印key值,不打印具体参数。

features.0.0.weight   
features.0.1.weight
features.1.conv.3.weight
features.1.conv.4.num_batches_tracked

model = VGGNet()
checkpoint = torch.load('checkpoint.pt', map_location='cpu')
# Load weights to resume from checkpoint。
# print('**************************************')
# 这个方法能够直接打印出你保存的checkpoint的键和值。
for k,v in checkpoint.items():
    print(k) 
print("*****************************************")
 

输出结果为:

module.features.0.0.weight",

"module.features.0.1.weight",

"module.features.0.1.bias

可以看出不匹配,模型的参数中,key值不同,多了module。

PS: 追加

在移植参数的过程中,对于出现 .total_ops和.total_params结尾的参数,可参考以下代码:

from collections import OrderedDict
checkpoint = torch.load(
    pretrained_model_file_path,
    map_location=(None if use_cuda and not remap_to_cpu else "cpu"))
new_state_dict = OrderedDict()
for k, v in checkpoint.items():
    if not k.endswith('total_ops') and not k.endswith('total_params'):
        name = k[7:]
        new_state_dict[name] = v

最后

以上为个人经验,希望能给大家一个参考,也希望大家多多支持编程网。

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

PyTorch加载模型model.load_state_dict()问题及解决

下载Word文档到电脑,方便收藏和打印~

下载Word文档

猜你喜欢

PyTorch加载模型model.load_state_dict()问题及解决

这篇文章主要介绍了PyTorch加载模型model.load_state_dict()问题及解决,具有很好的参考价值,希望对大家有所帮助。
2023-02-03

pytorch模型保存与加载问题怎么解决

这篇“pytorch模型保存与加载问题怎么解决”文章的知识点大部分人都不太理解,所以小编给大家总结了以下内容,内容详细,步骤清晰,具有一定的借鉴价值,希望大家阅读完这篇文章能有所收获,下面我们一起来看看这篇“pytorch模型保存与加载问题
2023-07-04

模型的保存加载、模型微调、GPU使用及Pytorch常见报错

序列化与反序列化 序列化就是说内存中的某一个对象保存到硬盘当中,以二进制序列的形式存储下来,这就是一个序列化的过程。 而反序列化,就是将硬盘中存储的二进制的数,反序列化到内存当中,得到一个相应的对象,这样就可以再次使用这个模型了。 序列化和
2023-08-30

在pytorch中复制模型时出现问题如何解决

在pytorch中复制模型时出现问题如何解决?针对这个问题,这篇文章详细介绍了相对应的分析和解答,希望可以帮助更多想解决这个问题的小伙伴找到更简单易行的方法。直接使用model2=model1会出现当更新model2时,model1的权重也
2023-06-06

pytorch网络模型构建场景的问题如何解决

今天小编给大家分享一下pytorch网络模型构建场景的问题如何解决的相关知识点,内容详细,逻辑清晰,相信大部分人都还太了解这方面的知识,所以分享这篇文章给大家参考一下,希望大家阅读完这篇文章后有所收获,下面我们一起来了解一下吧。网络模型构建
2023-07-05

pytorch模型的保存加载与续训练详解

这篇文章主要为大家介绍了pytorch模型的保存加载与续训练详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
2022-11-13

ThinPHP无法加载模块问题的解决方案

ThinPHP无法加载模块可能是因模块未安装、路径不正确、权限错误或PHP配置等原因造成的。解决方法包括检查模块安装、验证模块路径、调整文件权限、检查PHP配置、启用调试、清除缓存、重启服务器、查看日志文件、安装调试包,以及寻求社区支持。
ThinPHP无法加载模块问题的解决方案
2024-04-02

Linux下如何解决IPV6模块加载失败问题

这篇文章主要为大家展示了“Linux下如何解决IPV6模块加载失败问题”,内容简而易懂,条理清晰,希望能够帮助大家解决疑惑,下面让小编带领大家一起研究并学习一下“Linux下如何解决IPV6模块加载失败问题”这篇文章吧。同事一个SUSE L
2023-06-27

Hibernate Lazy加载问题怎么解决

这篇文章主要讲解了“Hibernate Lazy加载问题怎么解决”,文中的讲解内容简单清晰,易于学习与理解,下面请大家跟着小编的思路慢慢深入,一起来研究和学习“Hibernate Lazy加载问题怎么解决”吧!Hbm文件2023-06-03

Pytorch之ToPILImage()不输出图片问题及解决

使用Pytorch的ToPILImage()函数时可能遇到的图片输出问题,通常与数据类型错误、张量值超出范围或通道顺序错误有关。解决方案包括将张量转换为uint8类型、归一化张量值范围并调整通道顺序。此外,确保张量形状正确,已安装Pillow库,并考虑使用其他方法显示图像,例如使用NumPy数组和matplotlib或OpenCV。
Pytorch之ToPILImage()不输出图片问题及解决
2024-04-02

Android webView加载数据时内存溢出问题及解决

这篇文章主要介绍了Android webView加载数据时内存溢出问题及解决方案,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
2022-12-08

编程热搜

  • Python 学习之路 - Python
    一、安装Python34Windows在Python官网(https://www.python.org/downloads/)下载安装包并安装。Python的默认安装路径是:C:\Python34配置环境变量:【右键计算机】--》【属性】-
    Python 学习之路 - Python
  • chatgpt的中文全称是什么
    chatgpt的中文全称是生成型预训练变换模型。ChatGPT是什么ChatGPT是美国人工智能研究实验室OpenAI开发的一种全新聊天机器人模型,它能够通过学习和理解人类的语言来进行对话,还能根据聊天的上下文进行互动,并协助人类完成一系列
    chatgpt的中文全称是什么
  • C/C++中extern函数使用详解
  • C/C++可变参数的使用
    可变参数的使用方法远远不止以下几种,不过在C,C++中使用可变参数时要小心,在使用printf()等函数时传入的参数个数一定不能比前面的格式化字符串中的’%’符号个数少,否则会产生访问越界,运气不好的话还会导致程序崩溃
    C/C++可变参数的使用
  • css样式文件该放在哪里
  • php中数组下标必须是连续的吗
  • Python 3 教程
    Python 3 教程 Python 的 3.0 版本,常被称为 Python 3000,或简称 Py3k。相对于 Python 的早期版本,这是一个较大的升级。为了不带入过多的累赘,Python 3.0 在设计的时候没有考虑向下兼容。 Python
    Python 3 教程
  • Python pip包管理
    一、前言    在Python中, 安装第三方模块是通过 setuptools 这个工具完成的。 Python有两个封装了 setuptools的包管理工具: easy_install  和  pip , 目前官方推荐使用 pip。    
    Python pip包管理
  • ubuntu如何重新编译内核
  • 改善Java代码之慎用java动态编译

目录