多类别分割的广义Dice Loss:Keras实现

10
我刚刚在Keras中实现了广义Dice损失(Dice损失的多类版本),如ref所述:
(我的目标定义为:(batch_size,image_dim1,image_dim2,image_dim3,nb_of_classes))
def generalized_dice_loss_w(y_true, y_pred): 
    # Compute weights: "the contribution of each label is corrected by the inverse of its volume"
    Ncl = y_pred.shape[-1]
    w = np.zeros((Ncl,))
    for l in range(0,Ncl): w[l] = np.sum( np.asarray(y_true[:,:,:,:,l]==1,np.int8) )
    w = 1/(w**2+0.00001)

    # Compute gen dice coef:
    numerator = y_true*y_pred
    numerator = w*K.sum(numerator,(0,1,2,3))
    numerator = K.sum(numerator)

    denominator = y_true+y_pred
    denominator = w*K.sum(denominator,(0,1,2,3))
    denominator = K.sum(denominator)

    gen_dice_coef = numerator/denominator

    return 1-2*gen_dice_coef

但是一定有什么问题。我正在处理必须分割为4类(1个背景类和3个物体类,我有一个不平衡的数据集)的3D图像。第一个奇怪的事情:虽然我的训练损失和准确度在训练过程中得到了改善(并且收敛非常快),但我的验证损失/准确度在整个时期内都保持不变(见image)。其次,在对测试数据进行预测时,只有背景类被预测:我得到一个恒定的体积。

我使用完全相同的数据和脚本,但使用分类交叉熵损失得到合理的结果(物体类被分割)。这意味着我的实现有问题。有什么想法吗?

此外,我认为为keras社区提供通用的dice损失实现将是有用的,因为它似乎在大多数最新的语义分割任务中都被使用(至少在医学图像社区中是如此)。

附言:我觉得权重的定义很奇怪;我得到的值大约在10^-10左右。还有其他人尝试过实现这个吗?我也测试了没有权重的函数,但仍然出现相同的问题。


嗨@Manu,你解决了这个问题吗? - clifgray
1个回答

7
我认为问题在于你的权重。想象一下你正在尝试解决一个多类别分割问题,但是在每个图像中只有很少的类别出现。这种情况的一个玩具例子(也是导致我遇到这个问题的例子)是以以下方式从mnist创建分割数据集。
x = 28x28图像和y = 28x28x11,其中如果像素低于归一化灰度值0.4,则将每个像素分类为背景,否则将其分类为原始类别的数字。因此,如果您看到数字1的图片,您将会有许多像素被分类为1和背景。
现在,在这个数据集中,您将只会有两个类出现在图像中。这意味着,按照你的骰子损失函数,9个权重将变成1./(0. + eps) = large,因此对于每个图像,我们都会强烈惩罚所有9个不存在的类别。网络在这种情况下要找到的明显的强局部最小值是把所有东西预测为背景类别。
我们确实希望惩罚任何不正确预测的不在图像中的类别,但是不要那么强烈。所以我们只需要修改权重。这就是我做的方法。
def gen_dice(y_true, y_pred, eps=1e-6):
    """both tensors are [b, h, w, classes] and y_pred is in logit form"""

    # [b, h, w, classes]
    pred_tensor = tf.nn.softmax(y_pred)
    y_true_shape = tf.shape(y_true)

    # [b, h*w, classes]
    y_true = tf.reshape(y_true, [-1, y_true_shape[1]*y_true_shape[2], y_true_shape[3]])
    y_pred = tf.reshape(pred_tensor, [-1, y_true_shape[1]*y_true_shape[2], y_true_shape[3]])

    # [b, classes]
    # count how many of each class are present in 
    # each image, if there are zero, then assign
    # them a fixed weight of eps
    counts = tf.reduce_sum(y_true, axis=1)
    weights = 1. / (counts ** 2)
    weights = tf.where(tf.math.is_finite(weights), weights, eps)

    multed = tf.reduce_sum(y_true * y_pred, axis=1)
    summed = tf.reduce_sum(y_true + y_pred, axis=1)

    # [b]
    numerators = tf.reduce_sum(weights*multed, axis=-1)
    denom = tf.reduce_sum(weights*summed, axis=-1)
    dices = 1. - 2. * numerators / denom
    dices = tf.where(tf.math.is_finite(dices), dices, tf.zeros_like(dices))
    return tf.reduce_mean(dices)

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接