Jax - 调试NaN值

6
大家晚上好,
我过去的6个小时一直在尝试调试Jax中看似随机出现的NaN值。我已经缩小问题范围,发现这些NaN最初来自损失函数或其梯度。
一个可以重现错误的最小笔记本在这里提供:https://colab.research.google.com/drive/1uXa-igMm9QBOOl8ZNdK1OkwxRFlLqvZD?usp=sharing 这对于Jax也可能是一个有趣的用例。我使用Jax解决方向估计任务,只有有限的陀螺仪/加速度测量数据可用。在这种情况下,高效实现四元数运算很重要。
训练循环一开始还不错,但最终会发散。
Step 0| Loss: 4.550444602966309 | Time: 13.910547971725464s
Step 1| Loss: 4.110116481781006 | Time: 5.478027105331421s
Step 2| Loss: 3.7159230709075928 | Time: 5.476970911026001s
Step 3| Loss: 3.491917371749878 | Time: 5.474078416824341s
Step 4| Loss: 3.232130765914917 | Time: 5.433410406112671s
Step 5| Loss: 3.095140218734741 | Time: 5.433837413787842s
Step 6| Loss: 2.9580295085906982 | Time: 5.429029941558838s
Step 7| Loss: nan | Time: 5.427825689315796s
Step 8| Loss: nan | Time: 5.463077545166016s
Step 9| Loss: nan | Time: 5.479652643203735s

这可以通过分歧的梯度进行追踪,就像下面的片段所示。
(loss, _), grads = loss_fn(params, X[0], y[0], rnn.reset_carry(bs=2))

grads["params"]["Dense_0"]["bias"] # shape=(bs, out_features)
DeviceArray([[-0.38666773,         nan, -1.0433975 ,         nan],
             [ 0.623061  , -0.20950513,  0.8459796 , -0.42356613]],            dtype=float32)

我的问题是:如何调试这个问题?

启用NaN调试

启用NaN调试并没有真正帮助解决问题,反而会导致大量隐藏的堆栈跟踪。

from jax.config import config
config.update("jax_debug_nans", True)

非常感谢任何帮助!谢谢 :)


你尝试过通过 config.update('jax_disable_jit', True) 禁用JIT编译器,这样你就可以使用任何IDE调试代码而不会留下隐藏的痕迹。 - null
很高兴知道JIT编译实际上会导致隐藏的痕迹,我一直在想这个问题。明天会试一下。 - Simon B
1个回答

3
几种方法(在主要文档中得到很好的记录)可能有效:
1. 作为临时解决方案,切换到 `float64` 可以解决问题。更多信息请参见此处jax.config.update("jax_enable_x64", True)。 2. 梯度裁剪是你需要的全部内容 (文档)。 3. 有时候可以实现自己的反向传播,当你组合两个饱和函数进入一个不会饱和的函数,或者在奇点上强制值时,这可以帮助。 4. 通过检查计算图来诊断后向传递。通常寻找除法,用 div 标记:
from jax import make_jaxpr

# If grad_fn(x) gives you trouble, you can inspect the computation as follows:
grad_fn = jit(value_and_grad(my_forward_prop, argnums=0))
make_jaxpr(grad_fn)(x)

请注意,社区非常活跃,并且已经添加并正在添加一些支持来诊断 NaNs

希望这可以帮到您!
Andres

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接