Python怎样实现LeNet网络模型的训练及预测
本篇文章给大家分享的是有关Python怎样实现LeNet网络模型的训练及预测,小编觉得挺实用的,因此分享给大家学习,希望大家阅读完这篇文章后可以有所收获,话不多说,跟着小编一起来看看吧。
1.LeNet模型训练脚本
整体的训练代码如下,下面我会为大家详细讲解这些代码的意思
import torchimport torchvisionfrom torchvision.transforms import transformsimport torch.nn as nnfrom torch.utils.data import DataLoaderfrom pytorch.lenet.model import LeNetimport torch.optim as optimimport numpy as npimport matplotlib.pyplot as plttransform = transforms.Compose( # 将数据集转换成tensor形式 [transforms.ToTensor(), # 进行标准化,0.5是均值,也是方差,对应三个维度都是0.5 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# 下载完整的数据集时,download=True,第一个为保存的路径,下载完后download要改为False# 为训练集时,train=True,为测试集时,train=Falsetrain_set = torchvision.datasets.CIFAR10('./data', train=True, download=False, transform=transform)# 加载训练集,设置批次大小,是否打乱,number_works是线程数,window不设置为0会报错,linux可以设置非零train_loader = DataLoader(train_set, batch_size=36, shuffle=True, num_workers=0)test_set = torchvision.datasets.CIFAR10('./data', train=False, download=False, transform=transform)# 设置的批次大小一次性将所有测试集图片传进去test_loader = DataLoader(test_set, batch_size=10000, shuffle=False, num_workers=0)# 迭代测试集的图片数据和标签值test_img, test_label = next(iter(test_loader))# CIFAR10的十个类别名称classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')# # ----------------------------显示图片-----------------------------------# def imshow(img, label):# fig = plt.figure()# for i in range(len(img)):# ax = fig.add_subplot(1, len(img), i+1)# nping = img[i].numpy().transpose([1, 2, 0])# npimg = (nping * 2 + 0.5)# plt.imshow(npimg)# title = '{}'.format(classes[label[i]])# ax.set_title(title)# plt.axis('off')# plt.show()# # # batch_image = test_img[: 5]# label_img = test_label[: 5]# imshow(batch_image, label_img)# # ----------------------------------------------------------------------net = LeNet()# 定义损失函数,nn.CrossEntropyLoss()自带softmax函数,所以模型的最后一层不需要softmax进行激活loss_function = nn.CrossEntropyLoss()# 定义优化器,优化网络模型所有参数optimizer = optim.Adam(net.parameters(), lr=0.001)# 迭代五次for epoch in range(5): # 初始损失设置为0 running_loss = 0 # 循环训练集,从1开始 for step, data in enumerate(train_loader, start=1): inputs, labels = data # 优化器的梯度清零,每次循环都需要清零,否则梯度会无限叠加,相当于增加批次大小 optimizer.zero_grad() # 将图片数据输入模型中 outputs = net(inputs) # 传入预测值和真实值,计算当前损失值 loss = loss_function(outputs, labels) # 损失反向传播 loss.backward() # 进行梯度更新 optimizer.step() # 计算该轮的总损失,因为loss是tensor类型,所以需要用item()取具体值 running_loss += loss.item() # 每500次进行日志的打印,对测试集进行预测 if step % 500 == 0: # torch.no_grad()就是上下文管理,测试时不需要梯度更新,不跟踪梯度 with torch.no_grad(): # 传入所有测试集图片进行预测 outputs = net(test_img) # torch.max()中dim=1是因为结果为(batch, 10)的形式,我们只需要取第二个维度的最大值 # max这个函数返回[最大值, 最大值索引],我们只需要取索引就行了,所以用[1] predict_y = torch.max(outputs, dim=1)[1] # (predict_y == test_label)相同返回True,不相等返回False,sum()对正确率进行叠加 # 因为计算的变量都是tensor,所以需要用item()拿到取值 accuracy = (predict_y == test_label).sum().item() / test_label.size(0) # running_loss/500是计算每一个step的loss,即每一步的损失 print('[%d, %5d] train_loss: %.3f test_accuracy: %.3f' % (epoch+1, step, running_loss/500, accuracy)) running_loss = 0.0print('Finished Training!')save_path = 'lenet.pth'# 保存模型,字典形式torch.save(net.state_dict(), save_path)
(1).下载CIFAR10数据集
首先要训练一个网络模型,我们需要足够多的图片做数据集,这里我们用的是torchvision.dataset为我们提供的CIFAR10数据集(更多的数据集可以去pytorch官网查看pytorch官网提供的数据集)
train_set = torchvision.datasets.CIFAR10('./data', train=True, download=False, transform=transform)test_set = torchvision.datasets.CIFAR10('./data', train=False, download=False, transform=transform)
这部分代码是下载CIFAR10,第一个参数是下载数据集后存放的路径,train=True和False对应下载的训练集和测试集,transform是对应的图像增强方式
(2).图像增强
transform = transforms.Compose( # 将数据集转换成tensor形式 [transforms.ToTensor(), # 进行标准化,0.5是均值,也是方差,对应三个维度都是0.5 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
这就是简单的图像图像增强,transforms.ToTensor()将数据集的所有图像转换成tensor, transforms.Normalize()是标准化处理,包含两个元组对应均值和标准差,每个元组包含三个元素对应图片的三个维度[channels, height, width],为什么是这样排序,别问,问就是pytorch要求的,顺序不能变,之后会看到transforms.Normalize([0.485, 0.406, 0.456], [0.229, 0.224, 0.225])这两组数据,这是官方给出的均值和标准差,之后标准化的时候会经常用到
(3).加载数据集
# 加载训练集,设置批次大小,是否打乱,number_works是线程数,window不设置为0会报错,linux可以设置非零train_loader = DataLoader(dataset=train_set, batch_size=36, shuffle=True, num_workers=0)test_loader = DataLoader(dataset=test_set, batch_size=36, shuffle=False, num_workers=0)
这里只简单的设置的四个参数也是比较重要的,第一个就是需要加载的训练集和测试集,shuffle=True表示将数据集打乱,batch_size表示一次性向设备放入36张图片,打包成一个batch,这时图片的shape就会从[3, 32, 32]----》[36, 3, 32, 32],传入网络模型的shape也必须是[None, channels, height, width],None代表一个batch多少张图片,否则就会报错,number_works是代表线程数,window系统必须设置为0,否则会报错,linux系统可以设置非0数
(4).显示部分图像
def imshow(img, label): fig = plt.figure() for i in range(len(img)): ax = fig.add_subplot(1, len(img), i+1) nping = img[i].numpy().transpose([1, 2, 0]) npimg = (nping * 2 + 0.5) plt.imshow(npimg) title = '{}'.format(classes[label[i]]) ax.set_title(title) plt.axis('off') plt.show()batch_image = test_img[: 5]label_img = test_label[: 5]imshow(batch_image, label_img)
这部分代码是显示测试集当中前五张图片,运行后会显示5张拼接的图片
由于这个数据集的图片都比较小都是32x32的尺寸,有些可能也看的不太清楚,图中显示的是真实标签,注:显示图片的代码可能会这个报警(Clipping input data to the valid range for imshow with RGB data ([0…1] for floats or [0…255] for integers).),警告解决的方法:将图片数组转成uint8类型即可,即 plt.imshow(npimg.astype(‘uint8'),但是那样显示出来的图片会变,所以暂时可以先不用管。
(5).初始化模型
数据图片处理完了,下面就是我们的正式训练过程
net = LeNet()# 定义损失函数,nn.CrossEntropyLoss()自带softmax函数,所以模型的最后一层不需要softmax进行激活loss_function = nn.CrossEntropyLoss()# 定义优化器,优化模型所有参数optimizer = optim.Adam(net.parameters(), lr=0.001)
首先初始化LeNet网络,定义交叉熵损失函数,以及Adam优化器,关于注释写的,我们可以ctrl+鼠标左键查看CrossEntropyLoss(),翻到CrossEntropyLoss类,可以看到注释写的这个标准包含LogSoftmax函数,所以搭建LetNet模型的最后一层没有使用softmax激活函数
(6).训练模型及保存模型参数
for epoch in range(5): # 初始损失设置为0 running_loss = 0 # 循环训练集,从1开始 for step, data in enumerate(train_loader, start=1): inputs, labels = data # 优化器的梯度清零,每次循环都需要清零,否则梯度会无限叠加,相当于增加批次大小 optimizer.zero_grad() # 将图片数据输入模型中得到输出 outputs = net(inputs) # 传入预测值和真实值,计算当前损失值 loss = loss_function(outputs, labels) # 损失反向传播 loss.backward() # 进行梯度更新(更新W,b) optimizer.step() # 计算该轮的总损失,因为loss是tensor类型,所以需要用item()取到值 running_loss += loss.item() # 每500次进行日志的打印,对测试集进行测试 if step % 500 == 0: # torch.no_grad()就是上下文管理,测试时不需要梯度更新,不跟踪梯度 with torch.no_grad(): # 传入所有测试集图片进行预测 outputs = net(test_img) # torch.max()中dim=1是因为结果为(batch, 10)的形式,我们只需要取第二个维度的最大值,第二个维度是包含十个类别每个类别的概率的向量 # max这个函数返回[最大值, 最大值索引],我们只需要取索引就行了,所以用[1] predict_y = torch.max(outputs, dim=1)[1] # (predict_y == test_label)相同返回True,不相等返回False,sum()对正确结果进行叠加,最后除测试集标签的总个数 # 因为计算的变量都是tensor,所以需要用item()拿到取值 accuracy = (predict_y == test_label).sum().item() / test_label.size(0) # running_loss/500是计算每一个step的loss,即每一步的损失 print('[%d, %5d] train_loss: %.3f test_accuracy: %.3f' % (epoch+1, step, running_loss/500, accuracy)) running_loss = 0.0 print('Finished Training!')save_path = 'lenet.pth'# 保存模型,字典形式torch.save(net.state_dict(), save_path)
这段代码注释写的很清楚,大家仔细看就能看懂,流程不复杂,多看几遍就能理解,最后再对训练好的模型进行保存就好了(* ̄︶ ̄)
2.预测脚本
上面已经训练好了模型,得到了lenet.pth参数文件,预测就很简单了,可以去网上随便找一张数据集包含的类别图片,将模型参数文件载入模型,通过对图像进行一点处理,喂入模型即可,下面奉上代码:
import torchimport numpy as npimport torchvision.transforms as transformsfrom PIL import Imagefrom pytorch.lenet.model import LeNetclasses = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')transforms = transforms.Compose( # 对数据图片调整大小 [transforms.Resize([32, 32]), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])net = LeNet()# 加载预训练模型net.load_state_dict(torch.load('lenet.pth'))# 网上随便找的猫的图片img_path = '../../Photo/cat2.jpg'img = Image.open(img_path)# 图片的处理img = transforms(img)# 增加一个维度,(channels, height, width)------->(batch, channels, height, width),pytorch要求必须输入这样的shapeimg = torch.unsqueeze(img, dim=0)with torch.no_grad(): output = net(img) # dim=1,只取[batch, 10]中10个类别的那个维度,取预测结果的最大值索引,并转换为numpy类型 prediction1 = torch.max(output, dim=1)[1].data.numpy() # 用softmax()预测出一个概率矩阵 prediction2 = torch.softmax(output, dim=1) # 得到概率最大的值得索引 prediction2 = np.argmax(prediction2)# 两种方式都可以得到最后的结果print(classes[int(prediction1)])print(classes[int(prediction2)])
反正我最后预测出来结果把猫识别成了狗,还有90.01%的概率,就离谱哈哈哈,但也说明了LeNet这个网络模型确实很浅,特征提取的不够深,才会出现这种。
以上就是Python怎样实现LeNet网络模型的训练及预测,小编相信有部分知识点可能是我们日常工作会见到或用到的。希望你能通过这篇文章学到更多知识。更多详情敬请关注编程网行业资讯频道。
免责声明:
① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。
② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341