Pytorch期望类型为Long,但实际传入的类型是int。

3
我遇到了一个错误。
 Expected object of scalar type Long but got scalar type Int for argument #3 'index'

这是来自这一行。

targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1)

我不确定该怎么做,因为我尝试在几个地方将其转换为长整型。我试图放置一个

.long

在最后设置dtype为torch.long,但仍未生效。

与这个问题非常相似,但他没有做任何事情来得到答案: "Expected Long but got Int" while running PyTorch script

我已经改了很多代码,这是我的最后一次尝试,但现在出现了同样的问题。

    def forward(self, inputs, targets):
            """
            Args:
                inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
                targets: ground truth labels with shape (num_classes)
            """
            log_probs = self.logsoftmax(inputs)
            targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1)
            if self.use_gpu: targets = targets.to(torch.device('cuda'))
            targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
            loss = (- targets * log_probs).mean(0).sum()
            return loss

你所参考的问题已经大约5个月了。他们不太可能只是坐在那里等待这么长时间。在问题下面留言,看看他们是否找到了解决方案。 - ggdx
在Python 3中不再支持Long类型,因此请尝试使用int(var_name)。 - Chetan Vashisth
2个回答

3

你的索引参数的数据类型(即targets.unsqueeze(1).data.cpu())需要是torch.int64

(错误信息有点令人困惑:torch.long并不存在。但在PyTorch内部,“Long”表示int64)。


我该怎么做?我知道在Java中是这样的int(x);但我尝试了类似的Python代码,但它没有起作用。有什么建议吗?我还尝试了添加dtype=int64,但也不起作用。 - MNM
targets = torch.int64(torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1))目标 = torch.int64(torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1)) - MNM
targets = torch.zeros(log_probs.size()).scatter_(1, torch.int64(targets.unsqueeze(1).data.cpu()), 1)目标= torch.zeros(log_probs.size()).scatter_(1, torch.int64(targets.unsqueeze(1).data.cpu()), 1) - MNM
targets = torch.int64(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1) 将目标值用one-hot编码表示,即从大小为log_probs的张量中选取与目标值相同位置的下标,将其对应的元素赋值为1。最后得到的结果保存在targets中。 - MNM
targets = torch.zeros(log_probs.size(), dtype=torch.long).scatter_(1, targets.unsqueeze(1).data.cpu(), 1)目标= torch.zeros(log_probs.size(), dtype=torch.long).scatter_(1, targets.unsqueeze(1).data.cpu(), 1) - MNM
显示剩余2条评论

0
targets = torch.zeros(log_probs.size()).scatter_(1, (targets.unsqueeze(1).data.cpu()).long(), 1)

1
嗨!感谢分享答案! 将来,为答案添加解释肯定会有所帮助! :) - TheOneWhoPrograms

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接