我的编程空间,编程开发者的网络收藏夹
学习永远不晚

python中的Pytorch建模流程是什么

短信预约 -IT技能 免费直播动态提醒
省份

北京

  • 北京
  • 上海
  • 天津
  • 重庆
  • 河北
  • 山东
  • 辽宁
  • 黑龙江
  • 吉林
  • 甘肃
  • 青海
  • 河南
  • 江苏
  • 湖北
  • 湖南
  • 江西
  • 浙江
  • 广东
  • 云南
  • 福建
  • 海南
  • 山西
  • 四川
  • 陕西
  • 贵州
  • 安徽
  • 广西
  • 内蒙
  • 西藏
  • 新疆
  • 宁夏
  • 兵团
手机号立即预约

请填写图片验证码后获取短信验证码

看不清楚,换张图片

免费获取短信验证码

python中的Pytorch建模流程是什么

小编给大家分享一下python中的Pytorch建模流程是什么,相信大部分人都还不怎么了解,因此分享这篇文章给大家参考一下,希望大家阅读完这篇文章后大有收获,下面让我们一起去了解一下吧!

一般我们训练神经网络有以下步骤:

  • 导入库

  • 设置训练参数的初始值

  • 导入数据集并制作数据集

  • 定义神经网络架构

  • 定义训练流程

  • 训练模型

以下,我就将上述步骤使用代码进行注释讲解:

1 导入库

import torchfrom torch import nnfrom torch.nn import functional as Ffrom torch import optimfrom torch.utils.data import DataLoader, DataLoaderimport torchvisionimport torchvision.transforms as transforms

2 设置初始值

# 学习率lr = 0.15# 优化算法参数gamma = 0.8# 每次小批次训练个数bs = 128# 整体数据循环次数epochs = 10

3 导入并制作数据集

本次我们使用FashionMNIST图像数据集,每个图像是一个28*28的像素数组,共有10个衣物类别,比如连衣裙、运动鞋、包等。

注:初次运行下载需要等待较长时间。

# 导入数据集mnist = torchvision.datasets.FashionMNIST(    root = './Datastes'    , train = True    , download = True    , transform = transforms.ToTensor())    # 制作数据集batchdata = DataLoader(mnist                       , batch_size = bs                       , shuffle = True                       , drop_last = False)

我们可以对数据进行检查:

for x, y in batchdata:    print(x.shape)    print(y.shape)    break# torch.Size([128, 1, 28, 28])# torch.Size([128])

可以看到一个batch中有128个样本,每个样本的维度是1*28*28。

之后我们确定模型的输入维度与输出维度:

# 输入的维度input_ = mnist.data[0].numel()# 784# 输出的维度output_ = len(mnist.targets.unique())# 10

4 定义神经网络架构

先使用一个128个神经元的全连接层,然后用relu激活函数,再将其结果映射到标签的维度,并使用softmax进行激活。

# 定义神经网络架构class Model(nn.Module):    def __init__(self, in_features, out_features):        super().__init__()        self.linear1 = nn.Linear(in_features, 128, bias = True)        self.output = nn.Linear(128, out_features, bias = True)        def forward(self, x):        x = x.view(-1, 28*28)        sigma1 = torch.relu(self.linear1(x))        sigma2 = F.log_softmax(self.output(sigma1), dim = -1)        return sigma2

5 定义训练流程

在实际应用中,我们一般会将训练模型部分封装成一个函数,而这个函数可以继续细分为以下几步:

  • 定义损失函数与优化器

  • 完成向前传播

  • 计算损失

  • 反向传播

  • 梯度更新

  • 梯度清零

在此六步核心操作的基础上,我们通常还需要对模型的训练进度、损失值与准确度进行监视。

注释代码如下:

# 封装训练模型的函数def fit(net, batchdata, lr, gamma, epochs):# 参数:模型架构、数据、学习率、优化算法参数、遍历数据次数    # 5.1 定义损失函数    criterion = nn.NLLLoss()    # 5.1 定义优化算法    opt = optim.SGD(net.parameters(), lr = lr, momentum = gamma)        # 监视进度:循环之前,一个样本都没有看过    samples = 0    # 监视准确度:循环之前,预测正确的个数为0    corrects = 0        # 全数据训练几次    for epoch in range(epochs):        # 对每个batch进行训练        for batch_idx, (x, y) in enumerate(batchdata):            # 保险起见,将标签转为1维,与样本对齐            y = y.view(x.shape[0])                        # 5.2 正向传播            sigma = net.forward(x)            # 5.3 计算损失            loss = criterion(sigma, y)            # 5.4 反向传播            loss.backward()            # 5.5 更新梯度            opt.step()            # 5.6 梯度清零            opt.zero_grad()                        # 监视进度:每训练一个batch,模型见过的数据就会增加x.shape[0]            samples += x.shape[0]                        # 求解准确度:全部判断正确的样本量/已经看过的总样本量            # 得到预测标签            yhat = torch.max(sigma, -1)[1]            # 将正确的加起来            corrects += torch.sum(yhat == y)                        # 每200个batch和最后结束时,打印模型的进度            if (batch_idx + 1) % 200 == 0 or batch_idx == (len(batchdata) - 1):                # 监督模型进度                print("Epoch{}:[{}/{} {: .0f}%], Loss:{:.6f}, Accuracy:{:.6f}".format(                    epoch + 1                    , samples                    , epochs*len(batchdata.dataset)                    , 100*samples/(epochs*len(batchdata.dataset))                    , loss.data.item()                    , float(100.0*corrects/samples)))

6 训练模型

# 设置随机种子torch.manual_seed(51)# 实例化模型net = Model(input_, output_)# 训练模型fit(net, batchdata, lr, gamma, epochs)# Epoch2:[25600/600000  4%], Loss:0.524430, Accuracy:69.570312# Epoch2:[51200/600000  9%], Loss:0.363422, Accuracy:74.984375# ......# Epoch20:[600000/600000  100%], Loss:0.284664, Accuracy:85.771835

现在我们已经用Pytorch训练了最基础的神经网络,并且可以查看其训练成果。大家可以将代码复制进行运行!

虽然没有用到复杂的模型,但是我们在每次建模时的基本思想都是一致的

以上是“python中的Pytorch建模流程是什么”这篇文章的所有内容,感谢各位的阅读!相信大家都有了一定的了解,希望分享的内容对大家有所帮助,如果还想学习更多知识,欢迎关注编程网行业资讯频道!

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

python中的Pytorch建模流程是什么

下载Word文档到电脑,方便收藏和打印~

下载Word文档

猜你喜欢

python中的Pytorch建模流程是什么

小编给大家分享一下python中的Pytorch建模流程是什么,相信大部分人都还不怎么了解,因此分享这篇文章给大家参考一下,希望大家阅读完这篇文章后大有收获,下面让我们一起去了解一下吧!一般我们训练神经网络有以下步骤:导入库设置训练参数的初
2023-06-29

python中Pexpect的工作流程是什么

这期内容当中小编将会给大家带来有关python中Pexpect的工作流程是什么,文章内容丰富且以专业的角度为大家分析和叙述,阅读完这篇文章希望大家可以有所收获。python可以做什么Python是一种编程语言,内置了许多有效的工具,Pyth
2023-06-14

PyTorch中创建张量的方法是什么

在PyTorch中创建张量有多种方法,最常用的方法包括:使用torch.tensor()函数:通过传入一个列表或数组来创建张量。import torchtensor = torch.tensor([1, 2, 3, 4, 5])使用torc
PyTorch中创建张量的方法是什么
2024-03-05

vitejs预构建的流程是什么

本文小编为大家详细介绍“vitejs预构建的流程是什么”,内容详细,步骤清晰,细节处理妥当,希望这篇“vitejs预构建的流程是什么”文章能帮助大家解决疑惑,下面跟着小编的思路慢慢深入,一起来学习新知识吧。为啥要预构建简单来讲就是为了提高本
2023-07-02

Python编程中的反模式是什么

Python编程中的反模式是什么,相信很多没有经验的人对此束手无策,为此本文总结了问题出现的原因和解决方法,通过这篇文章希望你能解决这个问题。这篇文章收集了我在Python新手开发者写的代码中所见到的不规范但偶尔又很微妙的问题。本文的目的是
2023-06-17

pytorch创建tensor的方法是什么

在PyTorch中,可以通过以下几种方法来创建tensor:使用torch.Tensor()函数创建一个空的tensor:tensor = torch.Tensor()使用torch.tensor()函数根据给定的数据创建一个tensor
pytorch创建tensor的方法是什么
2024-04-08

PyTorch模型剪枝的概念是什么

PyTorch模型剪枝是指通过消除神经网络中不必要的参数或神经元,从而减少模型的大小和计算量的过程。剪枝技术可以帮助优化模型,提高推理速度,降低模型的内存占用和功耗,并且可以通过减少模型参数来提高模型的泛化能力。在PyTorch中,可以使用
PyTorch模型剪枝的概念是什么
2024-03-05

PyTorch的模型部署方式是什么

PyTorch模型的部署方式通常有以下几种:部署到本地计算机:可以在本地计算机上使用PyTorch的预训练模型或自己训练的模型进行推理或应用。部署到服务器:将PyTorch模型部署到服务器上,可以通过REST API或其他方式提供服务给客户
PyTorch的模型部署方式是什么
2024-03-14

PyTorch与PyTorch Geometric的安装过程是什么

这篇文章主要讲解了“PyTorch与PyTorch Geometric的安装过程是什么”,文中的讲解内容简单清晰,易于学习与理解,下面请大家跟着小编的思路慢慢深入,一起来研究和学习“PyTorch与PyTorch Geometric的安装过
2023-07-05

springboot创建项目的流程是什么

Spring Boot创建项目的流程如下:在官方网站下载并安装Spring Boot CLI(Command Line Interface)工具。打开命令行窗口,进入到你想要创建项目的目录中。使用Spring Boot CLI的命令来创建一
springboot创建项目的流程是什么
2024-03-07

C#中流模型的作用是什么

C#中流模型的作用是什么,相信很多没有经验的人对此束手无策,为此本文总结了问题出现的原因和解决方法,通过这篇文章希望你能解决这个问题。访问的两种模型:在程序中访问进而操作XML文件一般有两种模型,分别是使用DOM(文档对象模型)和流模型,使
2023-06-17

asp网站建设的流程是什么

ASP网站建设的流程一般包括以下几个步骤:1.需求分析:与客户沟通,了解客户的需求和要求,确定网站的基本功能、页面布局、交互方式等。2.网站设计:根据需求分析的结果,进行网站的设计,包括页面设计、用户体验设计、交互设计等。3.数据库设计:根
2023-06-05

建企业网站的流程是什么

1.确定网站目的和目标:明确企业网站的目的和目标,如宣传企业形象、展示产品或服务、提供在线销售等。2.制定网站策略:根据目的和目标制定网站策略,包括网站结构、内容、设计风格等。3.确定网站内容:根据网站策略确定网站内容,包括文字、图片、视频
2023-06-14

python中try语句的工作流程是什么

python中try语句的工作流程是什么?相信很多没有经验的人对此束手无策,为此本文总结了问题出现的原因和解决方法,通过这篇文章希望你能解决这个问题。python的五大特点是什么python的五大特点:1.简单易学,开发程序时,专注的是解决
2023-06-14

PyTorch中的张量是什么

在PyTorch中,张量是一种类似于多维数组的数据结构,可以存储和处理多维数据。张量在PyTorch中是用来表示神经网络的输入、输出和参数的主要数据类型。张量可以是任意维度的,可以是标量(0维张量)、向量(1维张量)、矩阵(2维张量)等等。
PyTorch中的张量是什么
2024-03-05

pycharm搭建项目流程是什么

搭建项目流程如下:1. 下载和安装PyCharm:首先,你需要从PyCharm官方网站下载适合你操作系统的PyCharm版本,并按照提示进行安装。2. 创建新项目:打开PyCharm,点击"Create New Project"按钮或选择"
2023-09-23

springCloud项目搭建流程是什么

本篇内容主要讲解“springCloud项目搭建流程是什么”,感兴趣的朋友不妨来看看。本文介绍的方法操作简单快捷,实用性强。下面就让小编来带大家学习“springCloud项目搭建流程是什么”吧!实现跨服务的远程调用(RestTemplat
2023-06-30

编程热搜

  • Python 学习之路 - Python
    一、安装Python34Windows在Python官网(https://www.python.org/downloads/)下载安装包并安装。Python的默认安装路径是:C:\Python34配置环境变量:【右键计算机】--》【属性】-
    Python 学习之路 - Python
  • chatgpt的中文全称是什么
    chatgpt的中文全称是生成型预训练变换模型。ChatGPT是什么ChatGPT是美国人工智能研究实验室OpenAI开发的一种全新聊天机器人模型,它能够通过学习和理解人类的语言来进行对话,还能根据聊天的上下文进行互动,并协助人类完成一系列
    chatgpt的中文全称是什么
  • C/C++中extern函数使用详解
  • C/C++可变参数的使用
    可变参数的使用方法远远不止以下几种,不过在C,C++中使用可变参数时要小心,在使用printf()等函数时传入的参数个数一定不能比前面的格式化字符串中的’%’符号个数少,否则会产生访问越界,运气不好的话还会导致程序崩溃
    C/C++可变参数的使用
  • css样式文件该放在哪里
  • php中数组下标必须是连续的吗
  • Python 3 教程
    Python 3 教程 Python 的 3.0 版本,常被称为 Python 3000,或简称 Py3k。相对于 Python 的早期版本,这是一个较大的升级。为了不带入过多的累赘,Python 3.0 在设计的时候没有考虑向下兼容。 Python
    Python 3 教程
  • Python pip包管理
    一、前言    在Python中, 安装第三方模块是通过 setuptools 这个工具完成的。 Python有两个封装了 setuptools的包管理工具: easy_install  和  pip , 目前官方推荐使用 pip。    
    Python pip包管理
  • ubuntu如何重新编译内核
  • 改善Java代码之慎用java动态编译

目录