在PyTorch中使用nn.embedding时遇到问题,期望的标量类型为Long,但实际得到的是torch.cuda.FloatTensor类型(如何修复)?

4

所以我有一个RNN编码器,它是更大的语言模型的一部分。其中过程为:编码 -> rnn -> 解码。

作为我的rnn类的__init__的一部分,我有以下内容:

self.encode_this = nn.Embedding(self.vocab_size, self.embedded_vocab_dim)

现在我正在尝试实现一个前向类,它接收批次数据,执行编码然后解码。

def f_calc(self, batch):
    #Here, batch.shape[0] is the size of batch while batch.shape[1] is the sequence length

    hidden_states = (torch.zeros(self.num_layers, batch.shape[0], self.hidden_vocab_dim).to(device))
    embedded_states = (torch.zeros(batch.shape[0],batch.shape[1], self.embedded_vocab_dim).to(device))

    o1, h = self.encode_this(embedded_states)

然而,我的问题始终与编码器有关,这给了我以下错误:
/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py in embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
   1465         # remove once script supports set_grad_enabled
   1466         _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> 1467     return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
   1468 
   1469 

RuntimeError: Expected tensor for argument #1 'indices' to have scalar type Long; but got torch.cuda.FloatTensor instead (while checking arguments for embedding)

任何人有任何想法如何解决吗?我完全不熟悉pytorch,所以如果这是一个愚蠢的问题,请原谅我。我知道涉及到某种类型转换,但我不确定如何去做...非常感谢!
1个回答

4

嵌入层期望输入整数。

import torch as t

emb = t.nn.Embedding(embedding_dim=3, num_embeddings=26)

emb(t.LongTensor([0,1,2]))

在您的代码中添加 long()

enter image description here

embedded_states = (torch.zeros(batch.shape[0],batch.shape[1], self.embedded_vocab_dim).to(device)).long()

当我这样做并通过我的rnn时,self.rnn(embedded_states),会添加一个全新的维度...所以张量变成了4d而不是3d,您知道为什么会发生这种情况吗? - skidjoe
因为你进行嵌入:你用它们的嵌入替换最后一个维度中的整数!换句话说,你用向量替换数字(标量),并添加一个额外的维度。 - Alexey Golyshev
嗯,你有没有想过如何保持它的3D?我一直在尝试挤压方法,但似乎不起作用。我的RNN输入是3D而不是4D...我应该传入3D嵌入和隐藏状态,对吧?或者我完全做错了吗? - skidjoe
不了解您的想法很难说。您可以在这里看到我的代码。我正在进行字符嵌入。在输入时,我有[句子、单词、字符]。我用它的嵌入表示每个字符(现在我有[句子、单词、字符、嵌入]),并使用LSTM获取单词的嵌入。然后将LSTM应用于单词以解决分类问题。 - Alexey Golyshev
我的输入形式为NL,其中N是批量大小,L是序列长度,我的输出应该是NL*V,其中V是词汇表大小(每个单词的嵌入)...现在我的嵌入状态的初始化是一个三维零向量(批量大小,序列长度,嵌入维度),但是当我运行nn.embedding时,我得到了一个额外的维度。 - skidjoe
显示剩余2条评论

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接