我有一个PyTorch
LSTM模型,我的forward
函数如下:
def forward(self, x, hidden):
print('in forward', x.dtype, hidden[0].dtype, hidden[1].dtype)
lstm_out, hidden = self.lstm(x, hidden)
return lstm_out, hidden
所有的
print
语句都显示torch.float64
,我相信这是double类型。那么为什么会出现这个问题呢?我已经在所有相关的位置进行了
double
类型转换。