Tensorflow均方误差损失函数

30

我在TensorFlow的各种回归模型帖子中看到了几个不同的均方误差损失函数:

loss = tf.reduce_sum(tf.pow(prediction - Y,2))/(n_instances)
loss = tf.reduce_mean(tf.squared_difference(prediction, Y))
loss = tf.nn.l2_loss(prediction - Y)

这些之间有什么区别?


4
另一种计算 MSE 的方法(与您的第一和第二种方法等效):tf.losses.mean_squared_error - Yibo Yang
2个回答

17

我认为第三个方程式(使用l2_loss)只返回输入的元素平方和的1/2,即预测值和实际值之差的各项平方和,即x=prediction-Y。在任何地方都没有除以样本数量。因此,如果您有非常大量的样本,计算可能会溢出(返回Inf)。

另外两个方程式形式上相同,计算了元素平方x张量的均值。然而,虽然文档没有明确说明,但很可能reduce_mean使用了一种适应于避免在样本数量非常大时溢出的算法。换句话说,它不会先求和再除以N,而是使用一种滚动平均值的方法,可以适应任意数量的样本而不一定会导致溢出。


很可能reduce_mean使用了一种适应于避免在样本数量非常大时溢出的算法。我不认为这是正确的。这里是相关代码,很明显你所声称的并没有发生。虽然,tf代码有点像一个兔子洞,而我也不是这方面的专家。但是,如果没有一些引用或证据,我认为这些说法有些可疑。 - Him

16

第一和第二个损失函数计算相同的内容,但是方式略有不同。第三个函数则完全计算不同的内容。通过执行以下代码可以看到:

import tensorflow as tf

shape_obj = (5, 5)
shape_obj = (100, 6, 12)
Y1 = tf.random_normal(shape=shape_obj)
Y2 = tf.random_normal(shape=shape_obj)

loss1 = tf.reduce_sum(tf.pow(Y1 - Y2, 2)) / (reduce(lambda x, y: x*y, shape_obj))
loss2 = tf.reduce_mean(tf.squared_difference(Y1, Y2))
loss3 = tf.nn.l2_loss(Y1 - Y2)

with tf.Session() as sess:
    print sess.run([loss1, loss2, loss3])
# when I run it I got: [2.0291963, 2.0291963, 7305.1069]

现在你可以通过注意到tf.pow(a - b, 2)tf.squared_difference(a - b, 2)相同来验证第一和第二次计算理论上计算的是相同的东西。此外,reduce_mean等于reduce_sum / number_of_element。问题在于计算机不能精确地计算所有事情。要了解数值不稳定性对计算的影响,请看这个:
import tensorflow as tf

shape_obj = (5000, 5000, 10)
Y1 = tf.zeros(shape=shape_obj)
Y2 = tf.ones(shape=shape_obj)

loss1 = tf.reduce_sum(tf.pow(Y1 - Y2, 2)) / (reduce(lambda x, y: x*y, shape_obj))
loss2 = tf.reduce_mean(tf.squared_difference(Y1, Y2))

with tf.Session() as sess:
    print sess.run([loss1, loss2])

很容易看出答案应该是1,但你会得到像这样的结果:[1.0, 0.26843545]
关于你最后的函数,文档说:
计算张量的L2范数的一半,不带平方根:输出= sum(t ** 2) / 2
因此,如果你想让它在理论上计算与第一个相同的东西,你需要适当地缩放它。
loss3 = tf.nn.l2_loss(Y1 - Y2) * 2 / (reduce(lambda x, y: x*y, shape_obj))

这是否意味着内置的tensorflow函数比显式计算平方并取平均值更差?这对我来说没有意义,内置实现不应该更加数值稳定吗?否则为什么要费力地制作一个函数呢? - Guilherme de Lazari
3
tf.nn.l2_loss的用例是什么? - mrgloom

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接