我的编程空间,编程开发者的网络收藏夹
学习永远不晚

图文详解梯度下降算法的原理及Python实现

短信预约 -IT技能 免费直播动态提醒
省份

北京

  • 北京
  • 上海
  • 天津
  • 重庆
  • 河北
  • 山东
  • 辽宁
  • 黑龙江
  • 吉林
  • 甘肃
  • 青海
  • 河南
  • 江苏
  • 湖北
  • 湖南
  • 江西
  • 浙江
  • 广东
  • 云南
  • 福建
  • 海南
  • 山西
  • 四川
  • 陕西
  • 贵州
  • 安徽
  • 广西
  • 内蒙
  • 西藏
  • 新疆
  • 宁夏
  • 兵团
手机号立即预约

请填写图片验证码后获取短信验证码

看不清楚,换张图片

免费获取短信验证码

图文详解梯度下降算法的原理及Python实现

1.引例

给定如图所示的某个函数,如何通过计算机算法编程求f(x)min?

2.数值解法

传统方法是数值解法,如图所示

按照以下步骤迭代循环直至最优:

① 任意给定一个初值x0

② 随机生成增量方向,结合步长生成Δx;

③ 计算比较f(x0)与f(x0+Δx)的大小,若f(x0+Δx)<f(x0)则更新位置,否则重新生成Δx;

④ 重复②③直至收敛到最优f(x)min。

数值解法最大的优点是编程简明,但缺陷也很明显:

① 初值的设定对结果收敛快慢影响很大;

② 增量方向随机生成,效率较低;

③ 容易陷入局部最优解;

④ 无法处理“高原”类型函数。

所谓陷入局部最优解是指当迭代进入到某个极小值或其邻域时,由于步长选择不恰当,无论正方向还是负方向,学习效果都不如当前,导致无法向全局最优迭代。就本问题而言如图所示,当迭代陷入x=xj时,由于学习步长step的限制,无法使f(xj±Step)<f(xj),因此迭代就被锁死在了图中的红色区段。可以看出x=xj并非期望的全局最优。

若出现下图所示的“高原”函数,也可能使迭代得不到更新。

3.梯度下降算法

梯度下降算法可视为数值解法的一种改进,阐述如下:

记第k轮迭代后,自变量更新为x=xk,令目标函数f(x)在x=xk泰勒展开:

f(x)=f(xk​)+f′(xk​)(x−xk​)+o(x)

考察f(x)min ,则期望f(xk+1)<f(xk),从而:

f(xk+1​)−f(xk​)=f′(xk​)(xk+1​−xk​)<0

若f′(xk)>0则xk+1<xk ,即迭代方向为负;反之为正。不妨设xk+1−xk=−f′(xk),从而保证f(xk+1)−f(xk)<0。必须指出,泰勒公式成立的条件是x→x0,故|f′(xk)|不能太大,否则xk+1与xk距离太远产生余项误差。因此引入学习率γ∈(0,1)来减小偏移度,即xk+1-xk=−γf′(xk​)

在工程上,学习率γ \gammaγ要结合实际应用合理选择,γ \gammaγ过大会使迭代在极小值两侧振荡,算法无法收敛;γ \gammaγ过小会使学习效率下降,算法收敛慢。

对于向量 ,将上述迭代公式推广为

xk+1​=xk​−γ∇xk​​

其中

为多元函数的梯度,故此迭代算法也称为梯度下降算法

梯度下降算法通过函数梯度确定了每一次迭代的方向和步长,提高了算法效率。但从原理上可以知道,此算法并不能解决数值解法中初值设定、局部最优陷落和部分函数锁死的问题。

4.代码实战:Logistic回归

import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt
import matplotlib as mpl
from Logit import Logit

'''
* @breif: 从CSV中加载指定数据
* @param[in]: file -> 文件名
* @param[in]: colName -> 要加载的列名
* @param[in]: mode -> 加载模式, set: 列名与该列数据组成的字典, df: df类型
* @retval: mode模式下的返回值
'''
def loadCsvData(file, colName, mode='df'):
    assert mode in ('set', 'df')
    df = pd.read_csv(file, encoding='utf-8-sig', usecols=colName)
    if mode == 'df':
        return df
    if mode == 'set':
        res = {}
        for col in colName:
            res[col] = df[col].values
        return res

if __name__ == '__main__':
    # ============================
    # 读取CSV数据
    # ============================
    csvPath = os.path.abspath(os.path.join(__file__, "../../data/dataset3.0alpha.csv"))
    dataX = loadCsvData(csvPath, ["含糖率", "密度"], 'df')
    dataY = loadCsvData(csvPath, ["好瓜"], 'df')
    label = np.array([
        1 if i == "是" else 0
        for i in list(map(lambda s: s.strip(), list(dataY['好瓜'])))
    ])

    # ============================
    # 绘制样本点
    # ============================
    line_x = np.array([np.min(dataX['密度']), np.max(dataX['密度'])])
    mpl.rcParams['font.sans-serif'] = [u'SimHei']
    plt.title('对数几率回归模拟\nLogistic Regression Simulation')
    plt.xlabel('density')
    plt.ylabel('sugarRate')
    plt.scatter(dataX['密度'][label==0],
                dataX['含糖率'][label==0],
                marker='^',
                color='k',
                s=100,
                label='坏瓜')
    plt.scatter(dataX['密度'][label==1],
                dataX['含糖率'][label==1],
                marker='^',
                color='r',
                s=100,
                label='好瓜')

    # ============================
    # 实例化对数几率回归模型
    # ============================
    logit = Logit(dataX, label)

    # 采用梯度下降法
    logit.logitRegression(logit.gradientDescent)
    line_y = -logit.w[0, 0] / logit.w[1, 0] * line_x - logit.w[2, 0] / logit.w[1, 0]
    plt.plot(line_x, line_y, 'b-', label="梯度下降法")

    # 绘图
    plt.legend(loc='upper left')
    plt.show()

到此这篇关于图文详解梯度下降算法的原理及Python实现的文章就介绍到这了,更多相关Python梯度下降算法内容请搜索编程网以前的文章或继续浏览下面的相关文章希望大家以后多多支持编程网!

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

图文详解梯度下降算法的原理及Python实现

下载Word文档到电脑,方便收藏和打印~

下载Word文档

猜你喜欢

pytorch实现梯度下降和反向传播图文详细讲解

这篇文章主要介绍了pytorch实现梯度下降和反向传播,反向传播的目的是计算成本函数C对网络中任意w或b的偏导数。一旦我们有了这些偏导数,我们将通过一些常数α的乘积和该数量相对于成本函数的偏导数来更新网络中的权重和偏差
2023-05-17

使用Python实现小批量梯度下降算法的代码逻辑

让theta=模型参数和max_iters=时期数。对于itr=1,2,3,...,max_iters:对于mini_batch(X_mini,y_mini):批量X_mini的前向传递:1、对小批量进行预测2、使用参数的当前值计算预
使用Python实现小批量梯度下降算法的代码逻辑
2024-01-22

详解Bagging算法的原理及Python实现

目录一、什么是集成学习二、Bagging算法三、Bagging用于分类四、Bagging用于回归一、什么是集成学习 集成学习是一种技术框架,它本身不是一个单独的机器学习算法,而是通过构建并结合多个机器学习器来完成学习任务,一般结构是:先产生
2022-06-02

牛顿法、梯度下降法、最小二乘法的原理以及利用它们解决实际问题的python编程

  牛顿法、梯度下降法、最小二乘法的原理以及利用它们解决实际问题的python编程  一、牛顿法原理  1、产生背景    2、牛顿迭代公式  二、梯度下降法原理  根据计算梯度时所用数据量不同,可以分为三种基本方法:批量梯度下降法(Bat
2023-06-01

(手写)PCA原理及其Python实现图文详解

目录1、背景2、样本均值和样本方差矩阵3、PCA3.1 最大投影方差3.2 最小重构距离4、Python实现总结1、背景 为什么需要降维呢?因为数据个数 N 和每个数据的维度 p 不满足 N >> p,造成了模型结果的“过拟合”。有两种方法
2022-06-02

图文讲解选择排序算法的原理及在Python中的实现

基本思想:从未排序的序列中找到一个最小的元素,放到第一位,再从剩余未排序的序列中找到最小的元素,放到第二位,依此类推,直到所有元素都已排序完毕。假设序列元素总共n+1个,则我们需要找n轮,就可以使该序列排好序。在每轮中,我们可以这样做:用未
2022-06-04

SPFA算法的实现原理及其应用详解

SPFA算法,全称为Shortest Path Faster Algorithm,是求解单源最短路径问题的一种常用算法,本文就来聊聊它的实现原理与简单应用吧
2023-05-20

Python实现B树插入算法的原理图解

B树是高度平衡的二叉搜索树,进行插入操作,要先获取插入节点的位置,遵循节点比左子树大,比右子树小,在需要时拆分节点。一图看懂B树插入操作原理B树插入算法BreeInsertion(T, k)r  root[T]if n[r] = 2t
Python实现B树插入算法的原理图解
2024-01-23

详解MD5算法的原理以及C#和JS的实现

MD5 是哈希算法(散列算法)的一种应用。这篇文章主要和大家介绍一下MD5算法的原理以及C#和JS的实现,文中的示例代码讲解详细,需要的可以参考一下
2023-03-19

t-SNE算法的原理和Python代码实现详解

T分布随机邻域嵌入(t-SNE),是一种用于可视化的无监督机器学习算法,使用非线性降维技术,根据数据点与特征的相似性,试图最小化高维和低维空间中这些条件概率(或相似性)之间的差异,以在低维空间中完美表示数据点。因此,t-SNE擅长在二维或三
t-SNE算法的原理和Python代码实现详解
2024-01-23

详解DES&3DES算法的原理以及C#和JS的实现

DES 全称为 Data Encryption Standard,即数据加密标准,是一种使用密钥加密的块算法。3DES 算法通过对 DES 算法进行改进,增加 DES 的密钥长度来避免类似的攻击。本文就来聊聊它们的原理与实现吧
2023-03-19

Matlab中图像数字水印算法的原理与实现详解

数字水印技术作为信息隐藏技术的一个重要分支,是将信息(水印)隐藏于数字图像、视频、音频及文本文档等数字媒体中,从而实现隐秘传输、存储、标注、身份识别、版权保护和防篡改等目的。本文就来讲讲图像数字水印算法的原理与实现,感兴趣的可以了解一下
2023-05-15

基数排序算法的原理与实现详解(Java/Go/Python/JS/C)

基数排序(RadixSort)是一种非比较型整数排序算法,其原理是将整数按位数切割成不同的数字,然后按每个位数分别比较。本文将利用Java/Go/Python/JS/C不同语言实现基数排序算法,感兴趣的可以了解一下
2023-03-06

编程热搜

  • Python 学习之路 - Python
    一、安装Python34Windows在Python官网(https://www.python.org/downloads/)下载安装包并安装。Python的默认安装路径是:C:\Python34配置环境变量:【右键计算机】--》【属性】-
    Python 学习之路 - Python
  • chatgpt的中文全称是什么
    chatgpt的中文全称是生成型预训练变换模型。ChatGPT是什么ChatGPT是美国人工智能研究实验室OpenAI开发的一种全新聊天机器人模型,它能够通过学习和理解人类的语言来进行对话,还能根据聊天的上下文进行互动,并协助人类完成一系列
    chatgpt的中文全称是什么
  • C/C++中extern函数使用详解
  • C/C++可变参数的使用
    可变参数的使用方法远远不止以下几种,不过在C,C++中使用可变参数时要小心,在使用printf()等函数时传入的参数个数一定不能比前面的格式化字符串中的’%’符号个数少,否则会产生访问越界,运气不好的话还会导致程序崩溃
    C/C++可变参数的使用
  • css样式文件该放在哪里
  • php中数组下标必须是连续的吗
  • Python 3 教程
    Python 3 教程 Python 的 3.0 版本,常被称为 Python 3000,或简称 Py3k。相对于 Python 的早期版本,这是一个较大的升级。为了不带入过多的累赘,Python 3.0 在设计的时候没有考虑向下兼容。 Python
    Python 3 教程
  • Python pip包管理
    一、前言    在Python中, 安装第三方模块是通过 setuptools 这个工具完成的。 Python有两个封装了 setuptools的包管理工具: easy_install  和  pip , 目前官方推荐使用 pip。    
    Python pip包管理
  • ubuntu如何重新编译内核
  • 改善Java代码之慎用java动态编译

目录