运行时错误:CUDA错误:在损失函数上触发了设备端断言。

3

/pytorch/aten/src/ATen/native/cuda/Loss.cu:102: operator(): block: [18,0,0], thread: [54,0,0] 断言 input_val >= zero && input_val <= one 失败。

/pytorch/aten/src/ATen/native/cuda/Loss.cu:102: operator(): block: [18,0,0], thread: [55,0,0] 断言 input_val >= zero && input_val <= one 失败。

/pytorch/aten/src/ATen/native/cuda/Loss.cu:102: operator(): block: [18,0,0], thread: [56,0,0] 断言 input_val >= zero && input_val <= one 失败。

/pytorch/aten/src/ATen/native/cuda/Loss.cu:102: operator(): block: [18,0,0], thread: [57,0,0] 断言 input_val >= zero && input_val <= one 失败。

/pytorch/aten/src/ATen/native/cuda/Loss.cu:102: operator(): block: [18,0,0], thread: [58,0,0] 断言input_val >= zero && input_val <= one失败。

/pytorch/aten/src/ATen/native/cuda/Loss.cu:102: operator(): block: [18,0,0], thread: [59,0,0] 断言input_val >= zero && input_val <= one失败。

追溯(最近的调用最先):
文件“run_toys.py”,第215行
loss = criterion(torch.reshape(out, [-1, dataset.out_dim]), torch.reshape(target, [-1, dataset.out_dim]))
文件“/usr/local/python3/lib/python3.6/site-packages/torch/nn/modules/module.py”,第727行
result = self.forward(*input, **kwargs)
文件“/usr/local/python3/lib/python3.6/site-packages/torch/nn/modules/loss.py”,第530行
return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction) 文件“/usr/local/python3/lib/python3.6/site-packages/torch/nn/functional.py”,第2526行
binary_cross_entropy
input, target, weight, reduction_enum)
运行时错误:CUDA错误:设备端触发断言

代码

criterion = nn.CrossEntropyLoss()
loss = criterion(torch.reshape(out, [-1, dataset.out_dim]), torch.reshape(target, [-1, dataset.out_dim]))
loss = torch.mean(loss)

目标和输出的形状相同 # torch.Size([640, 32])
模型在我的CPU上运行良好,但在GPU上运行有问题。

我在使用自定义损失函数时遇到了同样的问题。你解决了吗? - moefasa
1个回答

6
可能有两个原因导致错误:
  1. 日志显示input_val不在[0; 1]范围内。因此,您应确保模型输出在该范围内。您可以使用pytorch的torch.clamp()函数。在计算损失之前,请添加以下代码行:
    out = out.clamp(0, 1)
  1. 也许你确定模型输出在 [0; 1] 范围内。但是很常见的问题是输出包含一些 nan 值,这会触发断言错误。为了防止这种情况,在计算损失之前可以使用以下技巧:
    out[out!=out] = 0 # or 1 depending on your model's need

这里的诀窍在于使用 nan!=nan 属性,我们应该将它们更改为一些有效的数字。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接