我的编程空间,编程开发者的网络收藏夹
学习永远不晚

tensorflow2 自定义损失函数使用的隐藏坑

短信预约 -IT技能 免费直播动态提醒
省份

北京

  • 北京
  • 上海
  • 天津
  • 重庆
  • 河北
  • 山东
  • 辽宁
  • 黑龙江
  • 吉林
  • 甘肃
  • 青海
  • 河南
  • 江苏
  • 湖北
  • 湖南
  • 江西
  • 浙江
  • 广东
  • 云南
  • 福建
  • 海南
  • 山西
  • 四川
  • 陕西
  • 贵州
  • 安徽
  • 广西
  • 内蒙
  • 西藏
  • 新疆
  • 宁夏
  • 兵团
手机号立即预约

请填写图片验证码后获取短信验证码

看不清楚,换张图片

免费获取短信验证码

tensorflow2 自定义损失函数使用的隐藏坑

Keras的核心原则是逐步揭示复杂性,可以在保持相应的高级便利性的同时,对操作细节进行更多控制。当我们要自定义fit中的训练算法时,可以重写模型中的train_step方法,然后调用fit来训练模型。

这里以tensorflow2官网中的例子来说明:


import numpy as np
import tensorflow as tf
from tensorflow import keras
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
class CustomModel(keras.Model):
    tf.random.set_seed(100)
    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute the loss value
            # (the loss function is configured in `compile()`)
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)

        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        # Update metrics (includes the metric that tracks the loss)
        self.compiled_metrics.update_state(y, y_pred)
        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}
    


# Construct and compile an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss=tf.losses.MSE, metrics=["mae"])

# Just use `fit` as usual

model.fit(x, y, epochs=1, shuffle=False)
32/32 [==============================] - 0s 1ms/step - loss: 0.2783 - mae: 0.4257

 
<tensorflow.python.keras.callbacks.History at 0x7ff7edf6dfd0>

这里的loss是tensorflow库中实现了的损失函数,如果想自定义损失函数,然后将损失函数传入model.compile中,能正常按我们预想的work吗?

答案竟然是否定的,而且没有错误提示,只是loss计算不会符合我们的预期。


def custom_mse(y_true, y_pred):
    return tf.reduce_mean((y_true - y_pred)**2, axis=-1)
a_true = tf.constant([1., 1.5, 1.2])
a_pred = tf.constant([1., 2, 1.5])
custom_mse(a_true, a_pred)
<tf.Tensor: shape=(), dtype=float32, numpy=0.11333332>
tf.losses.MSE(a_true, a_pred)
<tf.Tensor: shape=(), dtype=float32, numpy=0.11333332>

以上结果证实了我们自定义loss的正确性,下面我们直接将自定义的loss置入compile中的loss参数中,看看会发生什么。


my_model = CustomModel(inputs, outputs)
my_model.compile(optimizer="adam", loss=custom_mse, metrics=["mae"])
my_model.fit(x, y, epochs=1, shuffle=False)
32/32 [==============================] - 0s 820us/step - loss: 0.1628 - mae: 0.3257

<tensorflow.python.keras.callbacks.History at 0x7ff7edeb7810>

我们看到,这里的loss与我们与标准的tf.losses.MSE明显不同。这说明我们自定义的loss以这种方式直接传递进model.compile中,是完全错误的操作。

正确运用自定义loss的姿势是什么呢?下面揭晓。


loss_tracker = keras.metrics.Mean(name="loss")
mae_metric = keras.metrics.MeanAbsoluteError(name="mae")

class MyCustomModel(keras.Model):
    tf.random.set_seed(100)
    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute the loss value
            # (the loss function is configured in `compile()`)
            loss = custom_mse(y, y_pred)
            # loss += self.losses

        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        
        # Compute our own metrics
        loss_tracker.update_state(loss)
        mae_metric.update_state(y, y_pred)
        return {"loss": loss_tracker.result(), "mae": mae_metric.result()}
    
    @property
    def metrics(self):
        # We list our `Metric` objects here so that `reset_states()` can be
        # called automatically at the start of each epoch
        # or at the start of `evaluate()`.
        # If you don't implement this property, you have to call
        # `reset_states()` yourself at the time of your choosing.
        return [loss_tracker, mae_metric]
    
# Construct and compile an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
my_model_beta = MyCustomModel(inputs, outputs)
my_model_beta.compile(optimizer="adam")

# Just use `fit` as usual

my_model_beta.fit(x, y, epochs=1, shuffle=False)
32/32 [==============================] - 0s 960us/step - loss: 0.2783 - mae: 0.4257

<tensorflow.python.keras.callbacks.History at 0x7ff7eda3d810>

终于,通过跳过在 compile() 中传递损失函数,而在 train_step 中手动完成所有计算内容,我们获得了与之前默认tf.losses.MSE完全一致的输出,这才是我们想要的结果。

总结一下,当我们在模型中想用自定义的损失函数,不能直接传入fit函数,而是需要在train_step中手动传入,完成计算过程。

到此这篇关于tensorflow2 自定义损失函数使用的隐藏坑的文章就介绍到这了,更多相关tensorflow2 自定义损失函数内容请搜索编程网以前的文章或继续浏览下面的相关文章希望大家以后多多支持编程网!

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

tensorflow2 自定义损失函数使用的隐藏坑

下载Word文档到电脑,方便收藏和打印~

下载Word文档

猜你喜欢

tensorflow2中怎么自定义损失函数

tensorflow2中怎么自定义损失函数,相信很多没有经验的人对此束手无策,为此本文总结了问题出现的原因和解决方法,通过这篇文章希望你能解决这个问题。Keras的核心原则是逐步揭示复杂性,可以在保持相应的高级便利性的同时,对操作细节进行更
2023-06-20

使用tensorflow2自定义损失函数需要注意什么

小编给大家分享一下使用tensorflow2自定义损失函数需要注意什么,相信大部分人都还不怎么了解,因此分享这篇文章给大家参考一下,希望大家阅读完这篇文章后大有收获,下面让我们一起去了解一下吧!Keras的核心原则是逐步揭示复杂性,可以在保
2023-06-28

如何在Keras中使用自定义的损失函数

要在Keras中使用自定义的损失函数,首先需要定义一个Python函数来表示损失函数,然后将其传递给Keras模型的compile()方法中。下面是一个简单的例子,展示了如何在Keras中使用自定义的损失函数:import keras.
如何在Keras中使用自定义的损失函数
2024-03-12

MindSpore自定义模型损失函数的示例分析

这篇文章将为大家详细讲解有关MindSpore自定义模型损失函数的示例分析,小编觉得挺实用的,因此分享给大家做个参考,希望大家阅读完这篇文章后可以有所收获。一、技术背景损失函数是机器学习中直接决定训练结果好坏的一个模块,该函数用于定义计算出
2023-06-20

SparkSQL的自定义函数UDF使用

SparkSql可以通过UDF来对DataFrame的Column进行自定义操作。在特定场景下定义UDF可能需要用到SparkContext以外的资源或数据。比如从List或Map中取值,或是通过连接池从外部的数据源中读取数据,然后再参与Column的运算
2023-02-01

如何使用JavaScript定义自己的ajax函数

这篇文章将为大家详细讲解有关如何使用JavaScript定义自己的ajax函数,小编觉得挺实用的,因此分享给大家做个参考,希望大家阅读完这篇文章后可以有所收获。由于用原生js的方式发起的网络请求,都是以查询字符串的形式,提交给服务器的,用户
2023-06-21

使用自定义函数验证 MySQL 中的日期

让我们创建一个自定义函数来验证 MySQL 中的日期 -mysql> set global log_bin_trust_function_creators=1;Query OK, 0 rows affected (0.03 sec)my
2023-10-22

Linux下怎么编写和使用自定义的Shell函数和函数库

本篇内容主要讲解“Linux下怎么编写和使用自定义的Shell函数和函数库”,感兴趣的朋友不妨来看看。本文介绍的方法操作简单快捷,实用性强。下面就让小编来带大家学习“Linux下怎么编写和使用自定义的Shell函数和函数库”吧!在 Linu
2023-06-16

python自定义函数中的return和print使用及说明

这篇文章主要介绍了python自定义函数中的return和print使用及说明,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
2023-01-04

使用golang自定义函数实现实现特定领域的语言

使用 go 语言自定义函数可扩展特定领域的语言,其步骤为:创建自定义函数,定义特定领域问题的可重用代码块。在程序中使用自定义函数,简化和优化代码。针对不同场景的需求,还可以创建更复杂的实战案例,例如处理 json 数据的自定义函数。使用 G
使用golang自定义函数实现实现特定领域的语言
2024-04-27

帝国CMS使用用户自定义函数取发表的新闻数

添加用户自定义函数 1.函数内容如下: 复制代码代码如下:
2022-06-12

使用 PHP 内置函数和自定义函数去重数组的性能对比

array_unique() 是去重数组性能最好的内置函数。哈希表法自定义函数性能最优,哈希值用作键,值为空。循环法实现简单但效率低,建议使用内置或自定义函数进行去重。array_unique() 耗时 0.02 秒、array_rever
使用 PHP 内置函数和自定义函数去重数组的性能对比
2024-04-26

如何使用golang中的http.NewRequest函数创建自定义的HTTP请求

如何使用golang中的http.NewRequest函数创建自定义的HTTP请求在golang中,我们可以使用http.NewRequest函数创建自定义的HTTP请求。这个函数可以让我们更灵活地控制请求的各个方面,包括请求的方法、URL
如何使用golang中的http.NewRequest函数创建自定义的HTTP请求
2023-11-18

使用PHP自定义函数扩展数组交集和并集的功能

使用 php 自定义函数可扩展数组交集和并集功能,自定义交集函数允许按键或值查找交集,而自定义并集函数按键或值查找并集。这使您能够基于特定需求灵活操作数组。使用 PHP 自定义函数扩展数组交集和并集在 PHP 中,交集和并集是两个经常使用
使用PHP自定义函数扩展数组交集和并集的功能
2024-05-01

编程热搜

  • Python 学习之路 - Python
    一、安装Python34Windows在Python官网(https://www.python.org/downloads/)下载安装包并安装。Python的默认安装路径是:C:\Python34配置环境变量:【右键计算机】--》【属性】-
    Python 学习之路 - Python
  • chatgpt的中文全称是什么
    chatgpt的中文全称是生成型预训练变换模型。ChatGPT是什么ChatGPT是美国人工智能研究实验室OpenAI开发的一种全新聊天机器人模型,它能够通过学习和理解人类的语言来进行对话,还能根据聊天的上下文进行互动,并协助人类完成一系列
    chatgpt的中文全称是什么
  • C/C++中extern函数使用详解
  • C/C++可变参数的使用
    可变参数的使用方法远远不止以下几种,不过在C,C++中使用可变参数时要小心,在使用printf()等函数时传入的参数个数一定不能比前面的格式化字符串中的’%’符号个数少,否则会产生访问越界,运气不好的话还会导致程序崩溃
    C/C++可变参数的使用
  • css样式文件该放在哪里
  • php中数组下标必须是连续的吗
  • Python 3 教程
    Python 3 教程 Python 的 3.0 版本,常被称为 Python 3000,或简称 Py3k。相对于 Python 的早期版本,这是一个较大的升级。为了不带入过多的累赘,Python 3.0 在设计的时候没有考虑向下兼容。 Python
    Python 3 教程
  • Python pip包管理
    一、前言    在Python中, 安装第三方模块是通过 setuptools 这个工具完成的。 Python有两个封装了 setuptools的包管理工具: easy_install  和  pip , 目前官方推荐使用 pip。    
    Python pip包管理
  • ubuntu如何重新编译内核
  • 改善Java代码之慎用java动态编译

目录