理解交叉熵损失

4

我看到很多关于在背景下使用0或1作为真实值的CEL或二元交叉熵损失的解释,然后你会得到一个函数:

def CrossEntropy(yHat, y):
    if yHat == 1:
      return -log(y)
    else:
      return -log(1 - y)

然而,当yHat不是离散的0或1时,我对BCE的工作方式感到困惑。例如,如果我想查看MNIST数字的重构损失,其中我的基本事实为0 < yHat < 1,而我的预测也在同一范围内,那么这如何改变我的函数?
编辑:
抱歉,让我给出更多有关我的困惑的背景。在PyTorch VAE的教程中,他们使用BCE来计算重构损失,其中yhat(据我所知)不是离散的。请参见:

https://github.com/pytorch/examples/blob/master/vae/main.py

这个实现可以正常工作...但我不理解在这种情况下如何计算BCE损失。


对于使用图像的自编码器,您可以将像素值归一化为[0, 1]范围,然后使用BCE逐像素进行处理。 - Coolness
当然,但这是他们在这里做的吗? - Matt
在你发布的代码中,代码处理的不仅仅是0和1。第一个"if"语句处理了1的情况,但"else"语句处理的是所有其他值,而不仅仅是0。 - stackoverflowuser2010
1
请查看http://pytorch.org/docs/master/nn.html#torch.nn.functional.binary_cross_entropy_with_logits和http://pytorch.org/docs/master/nn.html#torch.nn.BCEWithLogitsLoss,它们从数据中获取sigmoid函数,因此将其归一化为`[0,1]`。 - Coolness
@stackoverflowuser2010 - 是的,但是如果它接受了除了0或1之外的任何内容,那么这段代码将无法正常工作(即无法计算正确的CE损失)。 - Matt
你读过这个吗?https://dev59.com/aFgQ5IYBdhLWcg3w7ofY - stackoverflowuser2010
3个回答

5
交叉熵衡量任意两个概率分布之间的距离。在所描述的VAE中,MNIST图像像素被解释为像素“开/关”的概率。在这种情况下,您的目标概率分布不仅是一个Dirac分布(0或1),而是可以具有不同的值。请参见维基百科上的交叉熵定义
以上作为参考,假设您的模型输出某个像素的重建值为0.7。这实际上表示您的模型估计p(pixel=1)= 0.7,并相应地估计p(pixel=0)= 0.3。
如果目标像素只是0或1,则此像素的交叉熵将为-log(0.3),如果真实像素为0,则为-log(0.7)(较小的值)如果真实像素为1。
如果真实像素为1,则完整公式为-(0*log(0.3)+1*log(0.7)),否则为-(1*log(0.3)+1*log(0.7))。
假设您的目标像素实际上是0.6!这基本上意味着该像素有0.6的概率为开,0.4的概率为关。
这只是将交叉熵计算更改为-(0.4*log(0.3) + 0.6*log(0.7))。
最后,您可以简单地对图像上的每个像素的交叉熵进行平均/求和。

0

交叉熵损失函数仅用于分类问题,即目标(yHat)是离散的情况。如果你面对的是回归问题,例如均方误差(MSE)损失函数更为合适。你可以在PyTorch库中找到各种损失函数及其实现这里

对于MNIST数据集来说,实际上是一个多类别分类问题(你要预测正确的数字,有10个可能的数字),所以二元交叉熵损失函数不适用,你应该使用一般的交叉熵损失函数。

无论如何,你调查的第一步应该是确定你的问题是“分类”还是“回归”。适用于一个问题的损失函数通常不适用于另一个问题。

编辑:你可以在TensorFlow网站的"MNIST for ML Beginners"教程中找到关于MNIST问题中交叉熵损失函数的更详细解释。


1
嗯,我主要是在VAE的上下文中感到困惑。例如,在pytorch官方示例中,BCE用于重构损失:https://github.com/pytorch/examples/blob/master/vae/main.py - Matt
对抗问题与传统的MNIST分类问题不同。在对抗问题中,第二个网络试图输出一个概率,即第一个网络生成的图像是真实数据的概率。这是一个二元问题:真实 vs. 假的。在传统的分类问题中,(单个)网络试图输出图像对应于任何数字的概率。这是一个多类问题:0 vs 1 vs 2 vs 等等。 - vbox
无论如何,对于真实与伪造的问题,BCE 是适当的选择。您编写的函数应该像平常一样工作,除了如果图像是假的,则 yHat 为 0,如果是真实的,则为 1;y 是它是真实的概率(由第二个网络生成)。 - vbox
我想你误解了,我指的是变分自编码器,而不是GAN。 - Matt
另外,我很想看到源代码,但是我在查找时遇到了一些麻烦。如果我没记错的话,nn模块调用F.模块,而BCE则调用其他东西(C文件?),我实际上在查找时遇到了一些问题... - Matt

0

通常情况下,您不应将非二进制类别编码为介于0和1之间的值。在MNIST的情况下,如果您将每个数字标记为0、0.1、0.2等,则意味着数字2的图像与数字0的图像相比,更类似于数字5的图像,这并不一定是正确的。

一个好的方法是使用“独热编码”来编码您的标签,作为一个由10个0组成的数组。然后,将对应于数字图像的索引设置为1。

如上所述,然后您将使用常规的交叉熵损失函数。您的模型应该输出每个样本的条件概率向量,对应于每个可能的类别。可能使用softmax函数。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接