我的编程空间,编程开发者的网络收藏夹
学习永远不晚

Python实现LeNet网络模型的训练及预测

短信预约 -IT技能 免费直播动态提醒
省份

北京

  • 北京
  • 上海
  • 天津
  • 重庆
  • 河北
  • 山东
  • 辽宁
  • 黑龙江
  • 吉林
  • 甘肃
  • 青海
  • 河南
  • 江苏
  • 湖北
  • 湖南
  • 江西
  • 浙江
  • 广东
  • 云南
  • 福建
  • 海南
  • 山西
  • 四川
  • 陕西
  • 贵州
  • 安徽
  • 广西
  • 内蒙
  • 西藏
  • 新疆
  • 宁夏
  • 兵团
手机号立即预约

请填写图片验证码后获取短信验证码

看不清楚,换张图片

免费获取短信验证码

Python实现LeNet网络模型的训练及预测

1.LeNet模型训练脚本

整体的训练代码如下,下面我会为大家详细讲解这些代码的意思


import torch
import torchvision
from torchvision.transforms import transforms
import torch.nn as nn
from torch.utils.data import DataLoader
from pytorch.lenet.model import LeNet
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

transform = transforms.Compose(
    # 将数据集转换成tensor形式
    [transforms.ToTensor(),
     # 进行标准化,0.5是均值,也是方差,对应三个维度都是0.5
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

# 下载完整的数据集时,download=True,第一个为保存的路径,下载完后download要改为False
# 为训练集时,train=True,为测试集时,train=False
train_set = torchvision.datasets.CIFAR10('./data', train=True,
                                         download=False, transform=transform)

# 加载训练集,设置批次大小,是否打乱,number_works是线程数,window不设置为0会报错,linux可以设置非零
train_loader = DataLoader(train_set, batch_size=36,
                          shuffle=True, num_workers=0)

test_set = torchvision.datasets.CIFAR10('./data', train=False,
                                        download=False, transform=transform)
# 设置的批次大小一次性将所有测试集图片传进去
test_loader = DataLoader(test_set, batch_size=10000,
                         shuffle=False, num_workers=0)

# 迭代测试集的图片数据和标签值
test_img, test_label = next(iter(test_loader))

# CIFAR10的十个类别名称
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

# # ----------------------------显示图片-----------------------------------
# def imshow(img, label):
#     fig = plt.figure()
#     for i in range(len(img)):
#         ax = fig.add_subplot(1, len(img), i+1)
#         nping = img[i].numpy().transpose([1, 2, 0])
#         npimg = (nping * 2 + 0.5)
#         plt.imshow(npimg)
#         title = '{}'.format(classes[label[i]])
#         ax.set_title(title)
#         plt.axis('off')
#     plt.show()
# 
# 
# batch_image = test_img[: 5]
# label_img = test_label[: 5]
# imshow(batch_image, label_img)
# # ----------------------------------------------------------------------

net = LeNet()
# 定义损失函数,nn.CrossEntropyLoss()自带softmax函数,所以模型的最后一层不需要softmax进行激活
loss_function = nn.CrossEntropyLoss()
# 定义优化器,优化网络模型所有参数
optimizer = optim.Adam(net.parameters(), lr=0.001)

# 迭代五次
for epoch in range(5):
    # 初始损失设置为0
    running_loss = 0
    # 循环训练集,从1开始
    for step, data in enumerate(train_loader, start=1):
        inputs, labels = data
        # 优化器的梯度清零,每次循环都需要清零,否则梯度会无限叠加,相当于增加批次大小
        optimizer.zero_grad()
        # 将图片数据输入模型中
        outputs = net(inputs)
        # 传入预测值和真实值,计算当前损失值
        loss = loss_function(outputs, labels)
        # 损失反向传播
        loss.backward()
        # 进行梯度更新
        optimizer.step()
        # 计算该轮的总损失,因为loss是tensor类型,所以需要用item()取具体值
        running_loss += loss.item()
        # 每500次进行日志的打印,对测试集进行预测
        if step % 500 == 0:
            # torch.no_grad()就是上下文管理,测试时不需要梯度更新,不跟踪梯度
            with torch.no_grad():
                # 传入所有测试集图片进行预测
                outputs = net(test_img)
                # torch.max()中dim=1是因为结果为(batch, 10)的形式,我们只需要取第二个维度的最大值
                # max这个函数返回[最大值, 最大值索引],我们只需要取索引就行了,所以用[1]
                predict_y = torch.max(outputs, dim=1)[1]
                # (predict_y == test_label)相同返回True,不相等返回False,sum()对正确率进行叠加
                # 因为计算的变量都是tensor,所以需要用item()拿到取值
                accuracy = (predict_y == test_label).sum().item() / test_label.size(0)
                # running_loss/500是计算每一个step的loss,即每一步的损失
                print('[%d, %5d] train_loss: %.3f   test_accuracy: %.3f' %
                      (epoch+1, step, running_loss/500, accuracy))
                running_loss = 0.0

print('Finished Training!')

save_path = 'lenet.pth'
# 保存模型,字典形式
torch.save(net.state_dict(), save_path)

(1).下载CIFAR10数据集

首先要训练一个网络模型,我们需要足够多的图片做数据集,这里我们用的是torchvision.dataset为我们提供的CIFAR10数据集(更多的数据集可以去pytorch官网查看pytorch官网提供的数据集)


train_set = torchvision.datasets.CIFAR10('./data', train=True,
                                         download=False, transform=transform)
test_set = torchvision.datasets.CIFAR10('./data', train=False,
                                        download=False, transform=transform)

这部分代码是下载CIFAR10,第一个参数是下载数据集后存放的路径,train=True和False对应下载的训练集和测试集,transform是对应的图像增强方式

(2).图像增强


transform = transforms.Compose(
    # 将数据集转换成tensor形式
    [transforms.ToTensor(),
     # 进行标准化,0.5是均值,也是方差,对应三个维度都是0.5
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

这就是简单的图像图像增强,transforms.ToTensor()将数据集的所有图像转换成tensor, transforms.Normalize()是标准化处理,包含两个元组对应均值和标准差,每个元组包含三个元素对应图片的三个维度[channels, height, width],为什么是这样排序,别问,问就是pytorch要求的,顺序不能变,之后会看到transforms.Normalize([0.485, 0.406, 0.456], [0.229, 0.224, 0.225])这两组数据,这是官方给出的均值和标准差,之后标准化的时候会经常用到

(3).加载数据集


# 加载训练集,设置批次大小,是否打乱,number_works是线程数,window不设置为0会报错,linux可以设置非零
train_loader = DataLoader(dataset=train_set, batch_size=36,
                          shuffle=True, num_workers=0)
test_loader = DataLoader(dataset=test_set, batch_size=36,
                         shuffle=False, num_workers=0)

这里只简单的设置的四个参数也是比较重要的,第一个就是需要加载的训练集和测试集,shuffle=True表示将数据集打乱,batch_size表示一次性向设备放入36张图片,打包成一个batch,这时图片的shape就会从[3, 32, 32]----》[36, 3, 32, 32],传入网络模型的shape也必须是[None, channels, height, width],None代表一个batch多少张图片,否则就会报错,number_works是代表线程数,window系统必须设置为0,否则会报错,linux系统可以设置非0数

(4).显示部分图像


def imshow(img, label):
    fig = plt.figure()
    for i in range(len(img)):
        ax = fig.add_subplot(1, len(img), i+1)
        nping = img[i].numpy().transpose([1, 2, 0])
        npimg = (nping * 2 + 0.5)
        plt.imshow(npimg)
        title = '{}'.format(classes[label[i]])
        ax.set_title(title)
        plt.axis('off')
    plt.show()


batch_image = test_img[: 5]
label_img = test_label[: 5]
imshow(batch_image, label_img)

这部分代码是显示测试集当中前五张图片,运行后会显示5张拼接的图片

由于这个数据集的图片都比较小都是32x32的尺寸,有些可能也看的不太清楚,图中显示的是真实标签,注:显示图片的代码可能会这个报警(Clipping input data to the valid range for imshow with RGB data ([0…1] for floats or [0…255] for integers).),警告解决的方法:将图片数组转成uint8类型即可,即 plt.imshow(npimg.astype(‘uint8'),但是那样显示出来的图片会变,所以暂时可以先不用管。

(5).初始化模型

数据图片处理完了,下面就是我们的正式训练过程


net = LeNet()
# 定义损失函数,nn.CrossEntropyLoss()自带softmax函数,所以模型的最后一层不需要softmax进行激活
loss_function = nn.CrossEntropyLoss()
# 定义优化器,优化模型所有参数
optimizer = optim.Adam(net.parameters(), lr=0.001)

首先初始化LeNet网络,定义交叉熵损失函数,以及Adam优化器,关于注释写的,我们可以ctrl+鼠标左键查看CrossEntropyLoss(),翻到CrossEntropyLoss类,可以看到注释写的这个标准包含LogSoftmax函数,所以搭建LetNet模型的最后一层没有使用softmax激活函数

(6).训练模型及保存模型参数


for epoch in range(5):
    # 初始损失设置为0
    running_loss = 0
    # 循环训练集,从1开始
    for step, data in enumerate(train_loader, start=1):
        inputs, labels = data
        # 优化器的梯度清零,每次循环都需要清零,否则梯度会无限叠加,相当于增加批次大小
        optimizer.zero_grad()
        # 将图片数据输入模型中得到输出
        outputs = net(inputs)
        # 传入预测值和真实值,计算当前损失值
        loss = loss_function(outputs, labels)
        # 损失反向传播
        loss.backward()
        # 进行梯度更新(更新W,b)
        optimizer.step()
        # 计算该轮的总损失,因为loss是tensor类型,所以需要用item()取到值
        running_loss += loss.item()
        # 每500次进行日志的打印,对测试集进行测试
        if step % 500 == 0:
            # torch.no_grad()就是上下文管理,测试时不需要梯度更新,不跟踪梯度
            with torch.no_grad():
                # 传入所有测试集图片进行预测
                outputs = net(test_img)
                # torch.max()中dim=1是因为结果为(batch, 10)的形式,我们只需要取第二个维度的最大值,第二个维度是包含十个类别每个类别的概率的向量
                # max这个函数返回[最大值, 最大值索引],我们只需要取索引就行了,所以用[1]
                predict_y = torch.max(outputs, dim=1)[1]
                # (predict_y == test_label)相同返回True,不相等返回False,sum()对正确结果进行叠加,最后除测试集标签的总个数
                # 因为计算的变量都是tensor,所以需要用item()拿到取值
                accuracy = (predict_y == test_label).sum().item() / test_label.size(0)
                # running_loss/500是计算每一个step的loss,即每一步的损失
                print('[%d, %5d] train_loss: %.3f   test_accuracy: %.3f' %
                      (epoch+1, step, running_loss/500, accuracy))
                running_loss = 0.0
                
print('Finished Training!')

save_path = 'lenet.pth'
# 保存模型,字典形式
torch.save(net.state_dict(), save_path)

这段代码注释写的很清楚,大家仔细看就能看懂,流程不复杂,多看几遍就能理解,最后再对训练好的模型进行保存就好了(* ̄︶ ̄)

2.预测脚本

上面已经训练好了模型,得到了lenet.pth参数文件,预测就很简单了,可以去网上随便找一张数据集包含的类别图片,将模型参数文件载入模型,通过对图像进行一点处理,喂入模型即可,下面奉上代码:


import torch
import numpy as np
import torchvision.transforms as transforms
from PIL import Image
from pytorch.lenet.model import LeNet

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

transforms = transforms.Compose(
    # 对数据图片调整大小
    [transforms.Resize([32, 32]),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

net = LeNet()
# 加载预训练模型
net.load_state_dict(torch.load('lenet.pth'))
# 网上随便找的猫的图片
img_path = '../../Photo/cat2.jpg'
img = Image.open(img_path)
# 图片的处理
img = transforms(img)
# 增加一个维度,(channels, height, width)------->(batch, channels, height, width),pytorch要求必须输入这样的shape
img = torch.unsqueeze(img, dim=0)

with torch.no_grad():
    output = net(img)
    # dim=1,只取[batch, 10]中10个类别的那个维度,取预测结果的最大值索引,并转换为numpy类型
    prediction1 = torch.max(output, dim=1)[1].data.numpy()
    # 用softmax()预测出一个概率矩阵
    prediction2 = torch.softmax(output, dim=1)
    # 得到概率最大的值得索引
    prediction2 = np.argmax(prediction2)
# 两种方式都可以得到最后的结果
print(classes[int(prediction1)])
print(classes[int(prediction2)])

反正我最后预测出来结果把猫识别成了狗,还有90.01%的概率,就离谱哈哈哈,但也说明了LeNet这个网络模型确实很浅,特征提取的不够深,才会出现这种。

到此这篇关于Python 实现LeNet网络模型的训练及预测的文章就介绍到这了,更多相关LeNet网络模型训练及预测内容请搜索编程网以前的文章或继续浏览下面的相关文章希望大家以后多多支持编程网!

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

Python实现LeNet网络模型的训练及预测

下载Word文档到电脑,方便收藏和打印~

下载Word文档

猜你喜欢

Python怎样实现LeNet网络模型的训练及预测

本篇文章给大家分享的是有关Python怎样实现LeNet网络模型的训练及预测,小编觉得挺实用的,因此分享给大家学习,希望大家阅读完这篇文章后可以有所收获,话不多说,跟着小编一起来看看吧。1.LeNet模型训练脚本整体的训练代码如下,下面我会
2023-06-21

利用Pytorch实现ResNet网络构建及模型训练

这篇文章主要为大家介绍了利用Pytorch实现ResNet网络构建及模型训练详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
2023-05-17

Python底层技术揭秘:如何实现模型训练和预测

Python底层技术揭秘:如何实现模型训练和预测,需要具体代码示例作为一门易学易用的编程语言,Python在机器学习领域中被广泛使用。Python提供了大量的开源机器学习库和工具,比如Scikit-Learn、TensorFlow等。这些开
Python底层技术揭秘:如何实现模型训练和预测
2023-11-08

详解利用Pytorch实现ResNet网络之评估训练模型

这篇文章主要为大家介绍了利用Pytorch实现ResNet网络之评估训练模型详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
2023-05-16

Python中LazyPredict库的实施以及训练所有分类或回归模型

Python中LazyPredict库的实施以及训练所有分类或回归模型,很多新手对此不是很清楚,为了帮助大家解决这个难题,下面小编将为大家详细讲解,有这方面需求的人可以来学习下,希望你能有所收获。自动化机器学习(自动ML)是指自动化数据科学
2023-06-15

编程热搜

  • Python 学习之路 - Python
    一、安装Python34Windows在Python官网(https://www.python.org/downloads/)下载安装包并安装。Python的默认安装路径是:C:\Python34配置环境变量:【右键计算机】--》【属性】-
    Python 学习之路 - Python
  • chatgpt的中文全称是什么
    chatgpt的中文全称是生成型预训练变换模型。ChatGPT是什么ChatGPT是美国人工智能研究实验室OpenAI开发的一种全新聊天机器人模型,它能够通过学习和理解人类的语言来进行对话,还能根据聊天的上下文进行互动,并协助人类完成一系列
    chatgpt的中文全称是什么
  • C/C++中extern函数使用详解
  • C/C++可变参数的使用
    可变参数的使用方法远远不止以下几种,不过在C,C++中使用可变参数时要小心,在使用printf()等函数时传入的参数个数一定不能比前面的格式化字符串中的’%’符号个数少,否则会产生访问越界,运气不好的话还会导致程序崩溃
    C/C++可变参数的使用
  • css样式文件该放在哪里
  • php中数组下标必须是连续的吗
  • Python 3 教程
    Python 3 教程 Python 的 3.0 版本,常被称为 Python 3000,或简称 Py3k。相对于 Python 的早期版本,这是一个较大的升级。为了不带入过多的累赘,Python 3.0 在设计的时候没有考虑向下兼容。 Python
    Python 3 教程
  • Python pip包管理
    一、前言    在Python中, 安装第三方模块是通过 setuptools 这个工具完成的。 Python有两个封装了 setuptools的包管理工具: easy_install  和  pip , 目前官方推荐使用 pip。    
    Python pip包管理
  • ubuntu如何重新编译内核
  • 改善Java代码之慎用java动态编译

目录