如何在Keras中实现稀疏均方误差损失函数

5
我想修改以下keras均方误差(MSE)损失函数,使得只计算稀疏的损失。
我的输出y是一个3通道图像,其中第3个通道仅在需要计算损失的像素处为非零值。您有任何想法如何修改上述内容以计算稀疏损失吗?
代码如下: def mean_squared_error(y_true, y_pred): return K.mean(K.square(y_pred - y_true) * y_true[:, :, :, 2], axis=-1)
2个回答

8

这不是你要找的确切损失函数,但我希望它可以给你写函数提供一些提示(还可以在这里查看Github讨论):

def masked_mse(mask_value):
    def f(y_true, y_pred):
        mask_true = K.cast(K.not_equal(y_true, mask_value), K.floatx())
        masked_squared_error = K.square(mask_true * (y_true - y_pred))
        masked_mse = (K.sum(masked_squared_error, axis=-1) /
                      K.sum(mask_true, axis=-1))
        return masked_mse
    f.__name__ = 'Masked MSE (mask_value={})'.format(mask_value)
    return f

该函数计算预测输出的所有值(除了真实输出中对应值等于掩码值(例如-1)的元素)的MSE损失。
两个注意点:
- 计算均值时,分母必须是非遮蔽值的数量,而不是数组的维度,这就是为什么我不使用K.mean(masked_squared_error, axis=1)而是手动平均的原因。 - 掩蔽值必须是有效的数字(即np.nannp.inf不能胜任该工作),这意味着您需要调整数据以使其不包含mask_value
在本例中,目标输出始终为[1, 1, 1, 1],但某些预测值逐渐被屏蔽。
y_pred = K.constant([[ 1, 1, 1, 1], 
                     [ 1, 1, 1, 3],
                     [ 1, 1, 1, 3],
                     [ 1, 1, 1, 3],
                     [ 1, 1, 1, 3],
                     [ 1, 1, 1, 3]])
y_true = K.constant([[ 1, 1, 1, 1],
                     [ 1, 1, 1, 1],
                     [-1, 1, 1, 1],
                     [-1,-1, 1, 1],
                     [-1,-1,-1, 1],
                     [-1,-1,-1,-1]])

true = K.eval(y_true)
pred = K.eval(y_pred)
loss = K.eval(masked_mse(-1)(y_true, y_pred))

for i in range(true.shape[0]):
    print(true[i], pred[i], loss[i], sep='\t')

预期输出为:
[ 1.  1.  1.  1.]  [ 1.  1.  1.  1.]  0.0
[ 1.  1.  1.  1.]  [ 1.  1.  1.  3.]  1.0
[-1.  1.  1.  1.]  [ 1.  1.  1.  3.]  1.33333
[-1. -1.  1.  1.]  [ 1.  1.  1.  3.]  2.0
[-1. -1. -1.  1.]  [ 1.  1.  1.  3.]  4.0
[-1. -1. -1. -1.]  [ 1.  1.  1.  3.]  nan

1
为了防止出现nan,请按照此处的说明操作。以下假设您希望掩码值(背景)等于零:
 # Copied almost character-by-character (only change is default mask_value=0)
 # from https://github.com/keras-team/keras/issues/7065#issuecomment-394401137
 def masked_mse(mask_value=0):
    """
    Made default mask_value=0; not sure this is necessary/helpful
    """
    def f(y_true, y_pred):
        mask_true = K.cast(K.not_equal(y_true, mask_value), K.floatx())
        masked_squared_error = K.square(mask_true * (y_true - y_pred))
        # in case mask_true is 0 everywhere, the error would be nan, therefore divide by at least 1
        # this doesn't change anything as where sum(mask_true)==0, sum(masked_squared_error)==0 as well
        masked_mse = K.sum(masked_squared_error, axis=-1) / K.maximum(K.sum(mask_true, axis=-1), 1)
        return masked_mse
    f.__name__ = str('Masked MSE (mask_value={})'.format(mask_value))
    return f

@baldassarreFe 我注意到你的答案也在上面的位置;基于这个观察,我建议对你的答案进行编辑。 - brethvoice

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接