运行时错误:期望的标量类型为Double,但参数#2的标量类型为Float。

3

我有一个PyTorchLSTM模型,我的forward函数如下:

    def forward(self, x, hidden):
        print('in forward', x.dtype, hidden[0].dtype, hidden[1].dtype)
        lstm_out, hidden = self.lstm(x, hidden)
        return lstm_out, hidden

所有的print语句都显示torch.float64,我相信这是double类型。那么为什么会出现这个问题呢?
我已经在所有相关的位置进行了double类型转换。
1个回答

6

确保您的数据和模型都使用 double 类型。

对于模型:

net = net.double()

关于数据:

net(x.double())

这个问题已经在PyTorch论坛上被讨论过了


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接