我的编程空间,编程开发者的网络收藏夹
学习永远不晚

python简单批量梯度下降代码怎么写

短信预约 -IT技能 免费直播动态提醒
省份

北京

  • 北京
  • 上海
  • 天津
  • 重庆
  • 河北
  • 山东
  • 辽宁
  • 黑龙江
  • 吉林
  • 甘肃
  • 青海
  • 河南
  • 江苏
  • 湖北
  • 湖南
  • 江西
  • 浙江
  • 广东
  • 云南
  • 福建
  • 海南
  • 山西
  • 四川
  • 陕西
  • 贵州
  • 安徽
  • 广西
  • 内蒙
  • 西藏
  • 新疆
  • 宁夏
  • 兵团
手机号立即预约

请填写图片验证码后获取短信验证码

看不清楚,换张图片

免费获取短信验证码

python简单批量梯度下降代码怎么写

python简单批量梯度下降代码怎么写,相信很多没有经验的人对此束手无策,为此本文总结了问题出现的原因和解决方法,通过这篇文章希望你能解决这个问题。

简单批量梯度下降代码

其中涉及到公式

alpha表示超参数,由外部设定。过大则会出现震荡现象,过小则会出现学习速度变慢情况,因此alpha应该不断的调整改进。

python简单批量梯度下降代码怎么写

python简单批量梯度下降代码怎么写

注意1/m前正负号的改变

python简单批量梯度下降代码怎么写

Xj的意义为j个维度的样本。
下面为代码部分

import numpy as np#该处数据和linear_model中数据相同x = np.array([4,8,5,10,12])y = np.array([20,50,30,70,60])#一元线性回归 即 h_theta(x)=  y= theta0 +theta1*x#初始化系数,最开始要先初始化theta0 和theta1theta0,theta1 = 0,0#最开始梯度下降法中也有alpha 为超参数,提前初始化为0.01alpha = 0.01#样本的个数 ,在梯度下降公式中有xm = len(x)#设置停止条件,即梯度下降到满足实验要求时即可停止。# 方案1:设置迭代次数,如迭代5000次后停止。#(此处为2)方案2:设置epsilon,计算mse(均方误差,线性回归指标之一)的误差,如果mse的误差《= epsilon,即停止#在更改epsilon的次数后,越小,迭代次数会越多,结果更加准确。epsilon = 0.00000001#设置误差error0,error1 = 0,0#计算迭代次数cnt = 0def h_theta_x(x):    return theta0+theta1*x#接下来开始各种迭代#"""用while 迭代"""while True:    cnt+=1    diff=[0,0]    #该处为梯度,设置了两个梯度后再进行迭代,梯度每次都会清零后再进行迭代    for i in range(m):        diff[0]+=(y[i]-h_theta_x(x[i]))*1        diff[1]+=(y[i]-h_theta_x(x[i]))*x[i]    theta0 = theta0 + alpha * diff[0] / m    theta1 = theta1 + alpha * diff[1] / m    #输出theta值    # ”%s“表示输出的是输出字符串。格式化    print("theta0:%s,theta1:%s"%(theta0,theta1))    #计算mse    for i in range(m):        error1 +=(y[i]-h_theta_x(x[i]))**2    error1/=m    if(abs(error1-error0)<=epsilon):        break    else:        error0 = error1print("迭代次数:%s"%cnt)#线性回归结果:5.714285714285713     1.4285714285714448      87.14285714285714#批量梯度下降结果:theta0:1.4236238440026219,theta1:5.71483960227916   迭代次数:3988#在更改epsilon的次数后,越小,迭代次数会越多,结果更加准确。
在线性模型的代码(代码可参见另一条文章)中,得到运算结果a,b的值,与梯度下降后得到的结果theta0和theta1相近。增加实验次数(如修改epsilon的次数)可以得到更为相近的结果。

运行完毕后发现其实该处理方式并不理想
因为梯度下降开始后,theta数量会增加,即变量也会增加。每次增加都需要重新编写其中的循环和函数。
因此可以将他们编写成向量的形式

import numpy as np#X_b = np.array([[1,4],[1,8],[1,5],[1,10],[1,12]])#y = np.array([20,50,30,70,60])#改写成向量形式#运用random随机生成100个样本np.random.seed(1)X = 2 * np.random.rand(100, 1)y = 4 + 3 * X + np.random.rand(100, 1)X_b = np.c_[np.ones((100, 1)), X]#print(X_b)#此处的learning_rate 就是alphalearning_rate = 0.01#设置最大迭代次数,避免学习时间过长n_iterations = 10000#样本格数m = 100#初始化thata, w0...wn,初始化两个2*1 的随机数theta = np.random.randn(2, 1)#不会设置阈值,直接设置超参数,迭代次数,迭代次数到了,我们就认为收敛了。先看结果,如果结果不好就去调参for _ in range(n_iterations):    #接着求梯度gradient,这儿的梯度是n个梯度。即x* (h_theta - y)    #会得到一次迭代的n个theta值    gradients = 1/m * X_b.T.dot(X_b.dot(theta)-y)    #应用公式调整theta的值,theta_t + 1 = theta_t - grad * learning_rate , 是一个向量    theta = theta - learning_rate * gradientsprint(theta)

看完上述内容,你们掌握python简单批量梯度下降代码怎么写的方法了吗?如果还想学到更多技能或想了解更多相关内容,欢迎关注编程网行业资讯频道,感谢各位的阅读!

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

python简单批量梯度下降代码怎么写

下载Word文档到电脑,方便收藏和打印~

下载Word文档

猜你喜欢

python简单批量梯度下降代码怎么写

python简单批量梯度下降代码怎么写,相信很多没有经验的人对此束手无策,为此本文总结了问题出现的原因和解决方法,通过这篇文章希望你能解决这个问题。简单批量梯度下降代码其中涉及到公式alpha表示超参数,由外部设定。过大则会出现震荡现象,过
2023-06-26

python简单批量梯度下降代码

简单批量梯度下降代码 其中涉及到公式 alpha表示超参数,由外部设定。过大则会出现震荡现象,过小则会出现学习速度变慢情况,因此alpha应该不断的调整改进。注意1/m前正负号的改变Xj的意义为j个维度的样本。下面为代码部分 import
2022-06-04

使用Python实现小批量梯度下降算法的代码逻辑

让theta=模型参数和max_iters=时期数。对于itr=1,2,3,...,max_iters:对于mini_batch(X_mini,y_mini):批量X_mini的前向传递:1、对小批量进行预测2、使用参数的当前值计算预
使用Python实现小批量梯度下降算法的代码逻辑
2024-01-22

编程热搜

  • Python 学习之路 - Python
    一、安装Python34Windows在Python官网(https://www.python.org/downloads/)下载安装包并安装。Python的默认安装路径是:C:\Python34配置环境变量:【右键计算机】--》【属性】-
    Python 学习之路 - Python
  • chatgpt的中文全称是什么
    chatgpt的中文全称是生成型预训练变换模型。ChatGPT是什么ChatGPT是美国人工智能研究实验室OpenAI开发的一种全新聊天机器人模型,它能够通过学习和理解人类的语言来进行对话,还能根据聊天的上下文进行互动,并协助人类完成一系列
    chatgpt的中文全称是什么
  • C/C++中extern函数使用详解
  • C/C++可变参数的使用
    可变参数的使用方法远远不止以下几种,不过在C,C++中使用可变参数时要小心,在使用printf()等函数时传入的参数个数一定不能比前面的格式化字符串中的’%’符号个数少,否则会产生访问越界,运气不好的话还会导致程序崩溃
    C/C++可变参数的使用
  • css样式文件该放在哪里
  • php中数组下标必须是连续的吗
  • Python 3 教程
    Python 3 教程 Python 的 3.0 版本,常被称为 Python 3000,或简称 Py3k。相对于 Python 的早期版本,这是一个较大的升级。为了不带入过多的累赘,Python 3.0 在设计的时候没有考虑向下兼容。 Python
    Python 3 教程
  • Python pip包管理
    一、前言    在Python中, 安装第三方模块是通过 setuptools 这个工具完成的。 Python有两个封装了 setuptools的包管理工具: easy_install  和  pip , 目前官方推荐使用 pip。    
    Python pip包管理
  • ubuntu如何重新编译内核
  • 改善Java代码之慎用java动态编译

目录