我的编程空间,编程开发者的网络收藏夹
学习永远不晚

pytorch实现逻辑回归

短信预约 -IT技能 免费直播动态提醒
省份

北京

  • 北京
  • 上海
  • 天津
  • 重庆
  • 河北
  • 山东
  • 辽宁
  • 黑龙江
  • 吉林
  • 甘肃
  • 青海
  • 河南
  • 江苏
  • 湖北
  • 湖南
  • 江西
  • 浙江
  • 广东
  • 云南
  • 福建
  • 海南
  • 山西
  • 四川
  • 陕西
  • 贵州
  • 安徽
  • 广西
  • 内蒙
  • 西藏
  • 新疆
  • 宁夏
  • 兵团
手机号立即预约

请填写图片验证码后获取短信验证码

看不清楚,换张图片

免费获取短信验证码

pytorch实现逻辑回归

本文实例为大家分享了pytorch实现逻辑回归的具体代码,供大家参考,具体内容如下

一、pytorch实现逻辑回归

逻辑回归是非常经典的分类算法,是用于分类任务,如垃圾分类任务,情感分类任务等都可以使用逻辑回归。

接下来使用逻辑回归模型完成一个二分类任务:


# 使用逻辑回归完成一个二分类任务
# 数据准备
import torch
import matplotlib.pyplot as plt

x1 = torch.randn(365)+1.5   # randn():输出一个形状为size的标准正态分布Tensor
x2 = torch.randn(365)-1.5
#print(x1.shape)  # torch.Size([365])
#print(x2.shape)  # torch.Size([365])
data = zip(x1.data.numpy(),x2.data.numpy())  # 创建一个聚合了来自每个可迭代对象中的元素的迭代器。 x = [1,2,3]

pos = []
neg = []
def classification(data):
    for i in data:
        if (i[0] > 1.5+0.1*torch.rand(1).item()*(-1)**torch.randint(1,10,(1,1)).item()):
            pos.append(i)
        else:
            neg.append(i)

classification(data)
# 将正、负两类数据可视化
pos_x = [i[0] for i in pos]
pos_y = [i[1] for i in pos]
neg_x = [i[0] for i in neg]
neg_y = [i[1] for i in neg]
plt.scatter(pos_x,pos_y,c = 'r',marker = "*")
plt.scatter(neg_x,neg_y,c = 'b',marker = "^")
plt.show()

# 构造正、负两类数据可视化结果如上图所示

# 构建模型
import torch.nn as nn
class LogisticRegression(nn.Module):
    def __init__(self):
        super(LogisticRegression, self).__init__()
        self.linear = nn.Linear(2,1)
        self.sigmoid = nn.Sigmoid()

    def forward(self,x):
        return self.sigmoid(self.linear(x))

model = LogisticRegression()
criterion = nn.BCELoss()
optimizer = torch.optim.SGD(model.parameters(),0.01)
epoch = 5000
features = [[i[0],i[1]] for i in pos]
features.extend([[i[0],i[1]] for i in neg])   #extend 接受一个参数,这个参数总是一个 list,并且把这个 list 中的每个元素添加到原 list 中
features = torch.Tensor(features)   # torch.Tensor 生成单精度浮点类型的张量

label = [1 for i in range(len(pos))]
label.extend(0 for i in range(len(neg)))
label = torch.Tensor(label)
print(label.shape)

for i in range(500000):
    out = model(features)
    #print(out.shape)
    loss = criterion(out.squeeze(1),label)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # 分类任务准确率
    acc = (out.ge(0.5).float().squeeze(1)==label).sum().float()/features.size()[0]
    if (i % 10000 ==0):
        plt.scatter(pos_x, pos_y, c='r', marker="*")
        plt.scatter(neg_x, neg_y, c='b', marker="^")
        weight = model.linear.weight[0]
        #print(weight.shape)
        wo = weight[0]
        w1 = weight[1]
        b = model.linear.bias.data[0]
        # 绘制分界线
        test_x = torch.linspace(-10,10,500)   # 500个点
        test_y = (-wo*test_x - b) / w1
        plt.plot(test_x.data.numpy(),test_y.data.numpy(),c="pink")
        plt.title("acc:{:.4f},loss:{:.4f}".format(acc,loss))
        plt.ylim(-5,3)
        plt.xlim(-3,5)
        plt.show()

附上分类结果:

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持编程网。

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

pytorch实现逻辑回归

下载Word文档到电脑,方便收藏和打印~

下载Word文档

猜你喜欢

python实现逻辑回归的方法示例

本文实现的原理很简单,优化方法是用的梯度下降。后面有测试结果。 先来看看实现的示例代码:# coding=utf-8 from math import expimport matplotlib.pyplot as plt import nu
2022-06-04

怎么在R语言中实现逻辑回归

怎么在R语言中实现逻辑回归?针对这个问题,这篇文章详细介绍了相对应的分析和解答,希望可以帮助更多想解决这个问题的小伙伴找到更简单易行的方法。什么是R语言R语言是用于统计分析、绘图的语言和操作环境,属于GNU系统的一个自由、免费、源代码开放的
2023-06-14

python怎么实现梯度下降求解逻辑回归

今天小编给大家分享一下python怎么实现梯度下降求解逻辑回归的相关知识点,内容详细,逻辑清晰,相信大部分人都还太了解这方面的知识,所以分享这篇文章给大家参考一下,希望大家阅读完这篇文章后有所收获,下面我们一起来了解一下吧。线性回归1.线性
2023-07-06

编程热搜

  • Python 学习之路 - Python
    一、安装Python34Windows在Python官网(https://www.python.org/downloads/)下载安装包并安装。Python的默认安装路径是:C:\Python34配置环境变量:【右键计算机】--》【属性】-
    Python 学习之路 - Python
  • chatgpt的中文全称是什么
    chatgpt的中文全称是生成型预训练变换模型。ChatGPT是什么ChatGPT是美国人工智能研究实验室OpenAI开发的一种全新聊天机器人模型,它能够通过学习和理解人类的语言来进行对话,还能根据聊天的上下文进行互动,并协助人类完成一系列
    chatgpt的中文全称是什么
  • C/C++中extern函数使用详解
  • C/C++可变参数的使用
    可变参数的使用方法远远不止以下几种,不过在C,C++中使用可变参数时要小心,在使用printf()等函数时传入的参数个数一定不能比前面的格式化字符串中的’%’符号个数少,否则会产生访问越界,运气不好的话还会导致程序崩溃
    C/C++可变参数的使用
  • css样式文件该放在哪里
  • php中数组下标必须是连续的吗
  • Python 3 教程
    Python 3 教程 Python 的 3.0 版本,常被称为 Python 3000,或简称 Py3k。相对于 Python 的早期版本,这是一个较大的升级。为了不带入过多的累赘,Python 3.0 在设计的时候没有考虑向下兼容。 Python
    Python 3 教程
  • Python pip包管理
    一、前言    在Python中, 安装第三方模块是通过 setuptools 这个工具完成的。 Python有两个封装了 setuptools的包管理工具: easy_install  和  pip , 目前官方推荐使用 pip。    
    Python pip包管理
  • ubuntu如何重新编译内核
  • 改善Java代码之慎用java动态编译

目录