TensorFlow梯度 - 获取所有NaN值

8
我正在使用带有anaconda的Python 3和eager eval的tensorflow 1.12。
我使用它来为连体网络创建三元组损失函数,并需要计算不同数据样本之间的距离。
我创建了一个函数来计算距离,但无论我做什么,当我尝试计算相对于网络输出的梯度时,它都会给出所有nan梯度。
以下是代码:
def matrix_row_wise_norm(matrix):
    import tensorflow as tf
    tensor = tf.expand_dims(matrix, -1)

    tensor = tf.transpose(tensor, [0, 2, 1]) - tf.transpose(tensor, [2, 0, 1])
    norm = tf.norm(tensor, axis=2)
    return norm

我使用的是损失函数。
def loss(y_true, p_pred):
    with tf.GradientTape() as t:
    t.watch(y_pred)
        distance_matrix = matrix_row_wise_norm(y_pred)
        grad = t.gradient(distance_matrix, y_pred)

梯度全是 nan。我检查了 y_pred 的值都是合理的。我尝试创建关于自身的 y_pred * 2 的梯度并获得了合法的梯度值。

我错过了什么?创建距离矩阵时的索引有问题吗?


编辑:

y_predloss的数据类型均为tf.float32

编辑:在tf中发现了一个已公开的错误报告,这可能是问题所在吗?


编辑:

当我将norm轴更改为0或1时,我得到了合法的值,并且没有出现nan。使用axis=2进行norm操作时,我得到的是矩阵中行对之间的成对距离,我怀疑这可能与一行到自身的距离为0有关,因此我将值裁剪为最小值为1e-7,但没有成功。

谢谢


我曾经遇到过同样的问题,请检查 y_predlossdtype - Ankish Bansal
@AnkishBansal - 感谢您的回复,两者都是tf.float32。 - thebeancounter
你的矩阵中每个轴代表什么?我唯一能猜测的是norm(tensor, axis=2) 或者它上面的转置和减法操作没有梯度。我之前在自定义损失函数时遇到过这个问题,好像与重塑有关?非可导操作似乎会影响梯度计算。 - Engineero
@Engineero - 我在这里做的是取一个矩阵,每一行都是一个向量,我试图创建所有向量之间的成对距离,并通过复制向量、转置、相减和使用范数来实现。这怎么可能没有梯度呢? - thebeancounter
1个回答

4

似乎tf.norm存在数值不稳定性问题,如此处所述。

他们建议使用更加数值稳定的l2范数,因此我尝试了一下,但是也出现了nan值,这要感谢0梯度。因此,我将它们与梯度剪切一起使用,目前为止效果良好,损失函数能够正常工作并收敛。

def last_attempt(y_true, y_pred):
    import tensorflow as tf
    import numpy as np

    loss = tf.zeros(1)

    for i in range(y_pred.shape[0]):
        dist = tf.gather(y_pred, [i], axis=0)
        y = y_true.numpy().squeeze()
        norm = tf.map_fn(tf.nn.l2_loss, dist-y_pred)

        d = norm.numpy()
        d[np.where(y != y[i])] = 0.0
        max_pos = tf.gather(norm, np.argmax(d))

        d = norm.numpy()
        d[np.where(y == y[i])] = np.inf
        min_neg = tf.gather(norm, np.argmin(d))

        loss += tf.clip_by_value(max_pos - min_neg + tf.constant(1, dtype=tf.float32),
                                 1e-8, 1e1)

    return loss

针对该函数存在很大的优化空间,这里提供了另一个SO问题的参考答案-正在研究中。


1
我们能否在范数内部添加一个epsilon,以创建一个类似于此答案中的安全范数? - Lerner Zhang

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接