Keras + Tensorflow:调试NaNs

7

这里有一个关于如何在tensorflow图中找到第一个NaN出现的好问题:

在反向传递中调试nans

答案非常有帮助,这是其中的代码:

train_op = ...
check_op = tf.add_check_numerics_ops()

sess = tf.Session()
sess.run([train_op, check_op])  # Runs training and checks for NaNs

显然,同时运行训练和数值检查会在第一次遇到NaN时立即导致错误报告。

我该如何将其集成到Keras中呢? 在文档中,我找不到任何看起来像这样的东西。

我也检查了代码。 更新步骤在此处执行: https://github.com/fchollet/keras/blob/master/keras/engine/training.py

有一个名为_make_train_function的函数,它创建了一个操作来计算损失和应用更新。稍后调用它来训练网络。

我可以像这样更改代码(始终假设我们正在运行tf后端):

check_op = tf.add_check_numerics_ops()

self.train_function = K.function(inputs, 
    [self.total_loss] + self.metrics_tensors + [check_op],
    updates=updates, name='train_function', **self._function_kwargs)

我目前正在尝试正确设置这个,不确定上面的代码是否有效。也许有更简单的方法?


1
你找到答案了吗?这个方法有效吗?我和你去年遇到了同样的问题。 - 0vbb
1个回答

1

我遇到了完全相同的问题,并找到了替代check_add_numerics_ops()函数的方法。我使用TensorFlow Debugger来遍历我的模型,按照https://www.tensorflow.org/guide/debugger中的示例,找出我的代码在哪里产生了nan。这段代码应该可以用来替换Keras正在使用的TensorFlow会话,使用tfdbg

from tensorflow.python import debug as tf_debug
sess = K.get_session()
sess = tf_debug.LocalCLIDebugWrapperSession(sess)
K.set_session(sess)

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接