PyTorch error: CUDA error: CUBLAS_STATUS_INTERNAL_ERROR when calling `cublasCreate(handle)`

I have a very simple example:

import torch

if __name__ == "__main__":
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    m = torch.nn.Linear(20, 30).to(DEVICE)
    input = torch.randn(128, 20).to(DEVICE)
    output = m(input)
    print('output', output.size())
    exit()

Running it gives me:

Traceback (most recent call last):
  File "test.py", line 9, in <module>
    output = m(input)
  File "/home/shamoon/.local/share/virtualenvs/speech-reconstruction-7HMT9fTW/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/shamoon/.local/share/virtualenvs/speech-reconstruction-7HMT9fTW/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 94, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/shamoon/.local/share/virtualenvs/speech-reconstruction-7HMT9fTW/lib/python3.8/site-packages/torch/nn/functional.py", line 1753, in linear
    return torch._C._nn.linear(input, weight, bias)
RuntimeError: CUDA error: CUBLAS_STATUS_INTERNAL_ERROR when calling `cublasCreate(handle)`

I am using PyTorch 1.7.1. Any help would be greatly appreciated.

Thanks.

Edit: here is the output of python -m torch.utils.collect_env:

Collecting environment information...
PyTorch version: 1.8.0
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.2 LTS (x86_64)
GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Clang version: 11.1.0
CMake version: version 3.18.4

Python version: 3.8 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration: 
GPU 0: TITAN RTX
GPU 1: TITAN RTX
GPU 2: TITAN RTX
GPU 3: TITAN RTX
GPU 4: TITAN RTX
GPU 5: TITAN RTX
GPU 6: TITAN RTX
GPU 7: TITAN RTX

Nvidia driver version: 460.39
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.20.1
[pip3] torch==1.8.0
[pip3] torchaudio==0.8.0
[pip3] torchsummary==1.5.1
[conda] Could not collect

Would you mind sharing the output of python -m torch.utils.collect_env? - Berriel
Added to the original post. - Shamoon
1 Answer

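According to your logs, PyTorch 1.8 is installed, not 1.7.1. If that is not what you expect, please re-post your logs using the correct Python binary.
I ran into exactly the same problem with version 1.8. Downgrading to 1.7.1 fixed it (as mentioned in a huggingface transformers GitHub issue).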
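
To confirm which version a given Python binary actually sees, a check along these lines may help. This is only a minimal sketch: the helper names `installed_version` and `matches_expected` are made up for illustration, and it uses only the standard library, so it reports the installed package metadata without importing torch or touching CUDA.

```python
import sys
from importlib import metadata


def installed_version(package):
    """Return the installed version of `package`, or None if it is not installed."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def matches_expected(package, expected):
    """True if the installed version equals `expected`, ignoring a local suffix such as '+cu110'."""
    found = installed_version(package)
    if found is None:
        return False
    return found == expected or found.startswith(expected + "+")


if __name__ == "__main__":
    # The interpreter path shows which environment is actually being inspected.
    print("python:", sys.executable)
    print("torch:", installed_version("torch"))
    print("is 1.7.1:", matches_expected("torch", "1.7.1"))
```

Run it with the same interpreter that runs test.py (for example, from inside the same virtualenv). If it reports 1.8.0 while you expect 1.7.1, the environment is not the one you think it is.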
