我知道有很多关于交叉熵的解释,但我仍然感到困惑。
它只是描述损失函数的一种方法吗?我们可以使用梯度下降算法利用损失函数来找到最小值吗?
我知道有很多关于交叉熵的解释,但我仍然感到困惑。
它只是描述损失函数的一种方法吗?我们可以使用梯度下降算法利用损失函数来找到最小值吗?
Pr(Class A) Pr(Class B) Pr(Class C)
0.0 1.0 0.0
Pr(Class A) Pr(Class B) Pr(Class C)
0.228 0.619 0.153
预测分布与实际分布的接近程度由交叉熵损失函数来确定。使用以下公式:
其中p(x)
是真实的概率分布(one-hot),q(x)
是预测的概率分布。求和范围是三个类别 A、B 和 C。在这种情况下,损失为0.479:
H = - (0.0*ln(0.228) + 1.0*ln(0.619) + 0.0*ln(0.153)) = 0.479
请注意,只要您始终使用相同的对数底,使用哪种对数底并不重要。恰巧,Python Numpy log()
函数计算自然对数(以e为底)。
以下是上述示例使用 Numpy 在 Python 中表达的方式:
import numpy as np
p = np.array([0, 1, 0]) # True probability (one-hot)
q = np.array([0.228, 0.619, 0.153]) # Predicted probability
cross_entropy_loss = -np.sum(p * np.log(q))
print(cross_entropy_loss)
# 0.47965000629754095
这就是您的预测与真实分布之间“错误”或“偏离”的方式。机器学习优化器将尝试最小化损失(即它将尝试将损失从0.479减少到0.0)。
从上面的例子中,我们可以看到损失为0.4797。因为我们使用的是自然对数(以e为底的对数),所以单位为nats,因此我们说损失为0.4797 nats。如果对数是以2为底的,则单位为比特。请参见this page以获取进一步的解释。
为了更好地理解这些损失值反映了什么,让我们看一些极端的例子。
同样,假设真实的(one-hot)分布是:
Pr(Class A) Pr(Class B) Pr(Class C)
0.0 1.0 0.0
现在假设你的机器学习算法表现得非常出色,以非常高的概率预测了B类:
Pr(Class A) Pr(Class B) Pr(Class C)
0.001 0.998 0.001
p = np.array([0, 1, 0])
q = np.array([0.001, 0.998, 0.001])
print(-np.sum(p * np.log(q)))
# 0.0020020026706730793
Pr(Class A) Pr(Class B) Pr(Class C)
0.001 0.001 0.998
p = np.array([0, 1, 0])
q = np.array([0.001, 0.001, 0.998])
print(-np.sum(p * np.log(q)))
# 6.907755278982137
Pr(Class A) Pr(Class B) Pr(Class C)
0.333 0.333 0.334
导致的损失为1.10。
p = np.array([0, 1, 0])
q = np.array([0.333, 0.333, 0.334])
print(-np.sum(p * np.log(q)))
# 1.0996127890016931
交叉熵是众多可能的损失函数之一(另一个流行的是SVM hinge loss)。这些损失函数通常被写成J(theta),可以在梯度下降中使用,这是一种迭代算法,用于将参数(或系数)移向最优值。在下面的方程中,您需要用H(p, q)替换J(theta)。但请注意,您需要先计算H(p,q)相对于参数的导数。
所以直接回答您最初的问题:
这只是一种描述损失函数的方法吗?
是的,交叉熵描述了两个概率分布之间的损失。它是许多可能损失函数之一。
那么我们可以使用例如梯度下降算法来找到最小值。
是的,交叉熵损失函数可以作为梯度下降的一部分使用。
进一步阅读:我其他答案中与TensorFlow相关的内容。
p(x)
将是每个类别的真实概率列表,即 [0.0, 1.0, 0.0]
。同样,q(x)
是每个类别的预测概率列表,即 [0.228, 0.619, 0.153]
。然后,H(p, q)
就是 - (0 * log(2.28) + 1.0 * log(0.619) + 0 * log(0.153))
,计算结果为0.479。请注意,通常使用Python的np.log()
函数,它实际上是自然对数;这并不重要。 - stackoverflowuser2010ln(0.619)
的结果相同。 - HAr