重新加载Pytorch模型时出现CUDA内存不足错误

4

这里有一个关于pytorch的常见错误,但我在特定情况下遇到了它:当重新加载模型时,即使我还没有将模型放在GPU上,我也会收到一个CUDA: Out of Memory错误。

model = model.load_state_dict(torch.load(model_file_path))
optimizer = optimizer.load_state_dict(torch.load(optimizer_file_path))
# Error happens here ^, before I send the model to the device.
model = model.to(device_id)
1个回答

9
问题在于我试图在一个新的 GPU (cuda:2) 上加载模型, 但是最初却是从另一个 GPU (cuda:0) 中保存了模型和优化器。即使我没有明确告诉它重新加载到之前的 GPU,其默认行为也是重新加载到原始 GPU(这恰好被占用了)。
在每个 torch.load 调用中添加 map_location=device_id 可以解决问题。
model.to(device_id)
model = model.load_state_dict(torch.load(model_file_path, map_location=device_id))
optimizer = optimizer.load_state_dict(torch.load(optimizer_file_path, map_location=device_id))

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接