在tensorflow张量中替换nan值

22
我正在使用TensorFlow工作,开发一种卷积神经网络。然而我遇到了一个问题:我通过tfrecords读取的输入图像包含一定数量的NaN值。这是因为该图像代表一个深度图,其中存在一些无限值,在对其进行编码并解码后,这些无限值就变成了NaN值。
在我的情况下,无法在编码图像之前替换原始图像中的无限值。那么,在将图像张量提供给神经网络之前,是否有任何方法可以替换图像张量中的NaN值?

我尝试过 input_clean = tf.map_fn(lambda x: x if x == x else 0.0, input),但它并没有删除 NaNs... 还有这个:cleaned = tf.map_fn(lambda x: 0.0 if math.isnan(x) else 2*x, input) - 引发了 TypeError: a float is required... - Tomasz Gandor
5个回答

37

tf.where和tf.is_nan的组合可以实现:

import tensorflow as tf
with tf.Session():
    has_nans = tf.constant([float('NaN'), 1.])
    print(tf.where(tf.is_nan(has_nans), tf.zeros_like(has_nans), has_nans).eval())

打印输出(使用 TensorFlow 0.12.1):

[ 0.  1.]

1
有没有比tf.where()更快的选项?我们在代码库中使用了很多次tf.where,它正在迅速成为瓶颈。 - optimist
1
我不知道有没有针对这个的融合操作;理论上你可以编写一个,或者使用XLA来为您融合。你知道NaN是从哪里来的吗?理想情况下,在很多地方甚至根本不需要这样做。add_check_numerics_opstfdbg可能有助于跟踪它们。 - Allen Lavoie

10
如果有人正在寻找Tensorflow 2.0的解决方案,那么适用的代码是 Allen Lavoie 改编的:
import tensorflow as tf
with tf.compat.v1.Session():
    has_nans = tf.constant([float('NaN'), 1.])
    print(tf.where(tf.math.is_nan(has_nans), tf.zeros_like(has_nans), has_nans).eval())

8
一个更简单的方法,兼容TF2.0,就是使用tf.clip_by_value函数,它类似于np.clip并且可以移除NaN值(参见这里):
no_nans = tf.clip_by_value(has_nans, -1e12, 1e12)

一些注意事项: 1)这也会移除"无穷大"的值 2)根据您的应用程序,您可能需要将剪辑值设置为较高的数值,以避免丢失信息。


1
它对NaN值做什么?将它们映射为零吗? - Thomas Ahle
2
这在tensorflow 2.8中不起作用。 - Taw

8

通过值剪裁产生了NaN和无穷大,并且对于一个变量来说,这是过头了。我使用这个方法将单个值转换为0,如果它是NaN:

value_not_nan = tf.dtypes.cast(tf.math.logical_not(tf.math.is_nan(value)), dtype=tf.float32)
tf.math.multiply_no_nan(value, value_not_nan)

3

在TensorFlow 2.0中,您可以使用tf.math.is_nantf.tensor_scatter_nd_update实现:

tensor_with_nan = tf.convert_to_tensor([[np.nan,1.],[0.,np.nan]])
new_value = 9.

indices = tf.where(tf.math.is_nan(tensor_with_nan))
tensor_without_nan = tf.tensor_scatter_nd_update(
    tensor_with_nan,
    indices,
    tf.ones((tf.shape(indices)[0]))*new_value
)

在我看来,这是最干净的方式。让我的一天变得美好! - P. Navarro

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接