我的编程空间,编程开发者的网络收藏夹
学习永远不晚

python神经网络Batch Normalization底层原理详解

短信预约 -IT技能 免费直播动态提醒
省份

北京

  • 北京
  • 上海
  • 天津
  • 重庆
  • 河北
  • 山东
  • 辽宁
  • 黑龙江
  • 吉林
  • 甘肃
  • 青海
  • 河南
  • 江苏
  • 湖北
  • 湖南
  • 江西
  • 浙江
  • 广东
  • 云南
  • 福建
  • 海南
  • 山西
  • 四川
  • 陕西
  • 贵州
  • 安徽
  • 广西
  • 内蒙
  • 西藏
  • 新疆
  • 宁夏
  • 兵团
手机号立即预约

请填写图片验证码后获取短信验证码

看不清楚,换张图片

免费获取短信验证码

python神经网络Batch Normalization底层原理详解

什么是Batch Normalization

Batch Normalization是神经网络中常用的层,解决了很多深度学习中遇到的问题,我们一起来学习一哈。

Batch Normalization是由google提出的一种训练优化方法。参考论文:Batch Normalization Accelerating Deep Network Training by Reducing Internal Covariate Shift。

Batch Normalization的名称为批标准化,它的功能是使得输入的X数据符合同一分布,从而使得训练更加简单、快速。

一般来讲,Batch Normalization会放在卷积层后面,即卷积 + 标准化 + 激活函数。

其计算过程可以简单归纳为以下3点:

1、求数据均值。

2、求数据方差。

3、数据进行标准化。

Batch Normalization的计算公式

Batch Normalization的计算公式主要看如下这幅图:

这个公式一定要静下心来看,整个公式可以分为四行:

1、对输入进来的数据X进行均值求取。

2、利用输入进来的数据X减去第一步得到的均值,然后求平方和,获得输入X的方差。

3、利用输入X、第一步获得的均值和第二步获得的方差对数据进行归一化,即利用X减去均值,然后除上方差开根号。方差开根号前需要添加上一个极小值。

4、引入γ和β变量,对输入进来的数据进行缩放和平移。利用γ和β两个参数,让我们的网络可以学习恢复出原始网络所要学习的特征分布。

前三步是标准化工序,最后一步是反标准化工序。

Bn层的好处

1、加速网络的收敛速度。在神经网络中,存在内部协变量偏移的现象,如果每层的数据分布不同的话,会导致非常难收敛,如果把每层的数据都在转换在均值为零,方差为1的状态下,这样每层数据的分布都是一样的,训练会比较容易收敛。

2、防止梯度爆炸和梯度消失。对于梯度消失而言,以Sigmoid函数为例,它会使得输出在[0,1]之间,实际上当x到了一定的大小,sigmoid激活函数的梯度值就变得非常小,不易训练。归一化数据的话,就能让梯度维持在比较大的值和变化率;

对于梯度爆炸而言,在方向传播的过程中,每一层的梯度都是由上一层的梯度乘以本层的数据得到。如果归一化的话,数据均值都在0附近,很显然,每一层的梯度不会产生爆炸的情况。

3、防止过拟合。在网络的训练中,Bn使得一个minibatch中所有样本都被关联在了一起,因此网络不会从某一个训练样本中生成确定的结果,这样就会使得整个网络不会朝这一个方向使劲学习。一定程度上避免了过拟合。

为什么要引入γ和β变量

Bn层在进行前三步后,会引入γ和β变量,对输入进来的数据进行缩放和平移。

γ和β变量是网络参数,是可学习的。

引入γ和β变量进行缩放平移可以使得神经网络有自适应的能力,在标准化效果好时,尽量不抵消标准化的作用,而在标准化效果不好时,尽量去抵消一部分标准化的效果,相当于让神经网络学会要不要标准化,如何折中选择。

Bn层的代码实现

Pytorch代码看起来比较简单,而且和上面的公式非常符合,可以学习一下,参考自

https://www.jb51.net/article/247197.htm

def batch_norm(is_training, x, gamma, beta, moving_mean, moving_var, eps=1e-5, momentum=0.9):
    if not is_training:
        x_hat = (x - moving_mean) / torch.sqrt(moving_var + eps)
    else:
        mean = x.mean(dim=0, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
        var = ((x - mean) ** 2).mean(dim=0, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
        x_hat = (x - mean) / torch.sqrt(var + eps)
        moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
        moving_var = momentum * moving_var + (1.0 - momentum) * var
    Y = gamma * x_hat + beta
    return Y, moving_mean, moving_var
class BatchNorm2d(nn.Module):
    def __init__(self, num_features):
        super(BatchNorm2d, self).__init__()
        shape = (1, num_features, 1, 1)
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))
        self.register_buffer('moving_mean', torch.zeros(shape))
        self.register_buffer('moving_var', torch.ones(shape))
    def forward(self, x):
        if self.moving_mean.device != x.device:
            self.moving_mean = self.moving_mean.to(x.device)
            self.moving_var = self.moving_var.to(x.device)
        y, self.moving_mean, self.moving_var = batch_norm(self.training,
            x, self.gamma, self.beta, self.moving_mean,
            self.moving_var, eps=1e-5, momentum=0.9)
        return y

以上就是python神经网络Batch Normalization底层原理详解的详细内容,更多关于Batch Normalization底层原理的资料请关注编程网其它相关文章!

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

python神经网络Batch Normalization底层原理详解

下载Word文档到电脑,方便收藏和打印~

下载Word文档

猜你喜欢

Python底层技术解析:如何实现神经网络

Python底层技术解析:如何实现神经网络,需要具体代码示例在现代人工智能领域中,神经网络是最为常用和重要的技术之一。它模拟人脑的工作原理,通过多层神经元的连接来实现复杂的任务。Python作为一门功能强大且易于使用的编程语言,为实现神经网
Python底层技术解析:如何实现神经网络
2023-11-08

编程热搜

  • Python 学习之路 - Python
    一、安装Python34Windows在Python官网(https://www.python.org/downloads/)下载安装包并安装。Python的默认安装路径是:C:\Python34配置环境变量:【右键计算机】--》【属性】-
    Python 学习之路 - Python
  • chatgpt的中文全称是什么
    chatgpt的中文全称是生成型预训练变换模型。ChatGPT是什么ChatGPT是美国人工智能研究实验室OpenAI开发的一种全新聊天机器人模型,它能够通过学习和理解人类的语言来进行对话,还能根据聊天的上下文进行互动,并协助人类完成一系列
    chatgpt的中文全称是什么
  • C/C++中extern函数使用详解
  • C/C++可变参数的使用
    可变参数的使用方法远远不止以下几种,不过在C,C++中使用可变参数时要小心,在使用printf()等函数时传入的参数个数一定不能比前面的格式化字符串中的’%’符号个数少,否则会产生访问越界,运气不好的话还会导致程序崩溃
    C/C++可变参数的使用
  • css样式文件该放在哪里
  • php中数组下标必须是连续的吗
  • Python 3 教程
    Python 3 教程 Python 的 3.0 版本,常被称为 Python 3000,或简称 Py3k。相对于 Python 的早期版本,这是一个较大的升级。为了不带入过多的累赘,Python 3.0 在设计的时候没有考虑向下兼容。 Python
    Python 3 教程
  • Python pip包管理
    一、前言    在Python中, 安装第三方模块是通过 setuptools 这个工具完成的。 Python有两个封装了 setuptools的包管理工具: easy_install  和  pip , 目前官方推荐使用 pip。    
    Python pip包管理
  • ubuntu如何重新编译内核
  • 改善Java代码之慎用java动态编译

目录