二元交叉熵计算中的正样本权重

3

当我们处理不平衡的训练数据(负样本更多,正样本更少)时,通常会使用pos_weight参数。 pos_weight期望的是,当positive sample标签错误时,模型的损失值将比negative sample高。 当我使用binary_cross_entropy_with_logits函数时,我发现:

bce = torch.nn.functional.binary_cross_entropy_with_logits

pos_weight = torch.FloatTensor([5])
preds_pos_wrong =  torch.FloatTensor([0.5, 1.5])
label_pos = torch.FloatTensor([1, 0])
loss_pos_wrong = bce(preds_pos_wrong, label_pos, pos_weight=pos_weight)

preds_neg_wrong =  torch.FloatTensor([1.5, 0.5])
label_neg = torch.FloatTensor([0, 1])
loss_neg_wrong = bce(preds_neg_wrong, label_neg, pos_weight=pos_weight)

然而:

>>> loss_pos_wrong
tensor(2.0359)

>>> loss_neg_wrong
tensor(2.0359)

错误的正样本和负样本导致的损失相同,那么在不平衡数据的损失计算中,pos_weight 如何起作用?
1个回答

10
TLDR; 两个损失是相同的,因为你计算的是相同的量:两个输入是相同的,两个批次元素和标签只是交换了位置。

为什么你获得相同的损失值?

我认为你在使用F.binary_cross_entropy_with_logits你可以在nn.BCEWithLogitsLoss中查找更详细的文档页面)方面有些混淆。在你的情况下,你的输入形状(也就是模型的输出)是一维的,这意味着你只有一个逻辑值 x,而不是两个)。

在你的例子中:

preds_pos_wrong = torch.FloatTensor([0.5, 1.5])
label_pos = torch.FloatTensor([1, 0])

这意味着您的批量大小为 2,并且由于默认情况下该函数在批量元素上平均损失,因此您会得到相同的结果 BCE(preds_pos_wrong, label_pos)BCE(preds_neg_wrong, label_neg)。您批次的两个元素只是被交换了。
您可以通过使用 reduction='none' 选项来轻松验证不对批次元素进行平均损失。
>>> F.binary_cross_entropy_with_logits(preds_pos_wrong, label_pos, 
       pos_weight=pos_weight, reduction='none')
tensor([2.3704, 1.7014])

>>> F.binary_cross_entropy_with_logits(preds_pos_wrong, label_pos, 
       pos_weight=pos_weight, reduction='none')
tensor([1.7014, 2.3704])

研究F.binary_cross_entropy_with_logits

话虽如此,二元交叉熵的公式如下:

bce = -[y*log(sigmoid(x)) + (1-y)*log(1- sigmoid(x))]

y(分别是sigmoid(x))代表与该逻辑回归相关联的正类别,1-y(分别是1-sigmoid(x))则是负类别。

关于pos_weight的加权方案,文档可以更加精确(不要与加权不同的weight混淆,后者是对不同逻辑回归结果的加权)。如你所说,pos_weight的想法是为了加权正项,而不是整个项。

bce = -[w_p*y*log(sigmoid(x)) + (1-y)*log(1- sigmoid(x))]

w_p是正类项的权重,用于补偿正负样本不平衡。在实践中,应该设定为w_p = #negative/#positive

因此:

>>> w_p = torch.FloatTensor([5])
>>> preds = torch.FloatTensor([0.5, 1.5])
>>> label = torch.FloatTensor([1, 0])

使用内置的损失函数,

>>> F.binary_cross_entropy_with_logits(preds, label, pos_weight=w_p, reduction='none')
tensor([2.3704, 1.7014])

与手动计算相比:

>>> z = torch.sigmoid(preds)
>>> -(w_p*label*torch.log(z) + (1-label)*torch.log(1-z))
tensor([2.3704, 1.7014])

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接