用另一个多维张量索引多维 torch 张量

5
我有一个在PyTorch中的张量x,假设它的形状为(5,3,2,6),还有另一个形状为(5,3,2,1)的张量idx,其中包含了第一个张量中每个元素的索引。我想要用第二个张量的索引对第一个张量进行切片。我尝试了x = x[idx],但是我得到了一个奇怪的维度,而我真正想要的是形状为(5,3,2)或(5,3,2,1)。
我将尝试给出一个更简单的例子:假设
x=torch.Tensor([[10,20,30],
                 [8,4,43]])
idx = torch.Tensor([[0],
                    [2]])

我想要类似的东西

y = x[idx]

这样' y '将输出[[10],[43]]或类似的结果。

索引表示所需元素在最后一个维度中的位置。对于上面的示例,其中x.shape = (2,3),最后一个维度是列,则“idx”中的索引是列。我希望对于超过2个维度的情况也是如此。


如何解释指数idx = [[0],[2]]以从x中获取值[[10],[43]]?不清楚这些索引表示什么,它们是行/列还是平坦数组索引? - Ehsan
这意味着在最后一个维度中的位置,对于该示例来说是列。 - Jessica Borja
3个回答

2

从评论中我了解到,你需要 idx 成为最后一个维度上的索引,idx 中的每个索引与 x 中的相应索引对应(除了最后一个维度)。在这种情况下(这是numpy版本,你可以将其转换为torch):

ind = np.indices(idx.shape)
ind[-1] = idx
x[tuple(ind)]

输出:

[[10]
 [43]]

1
你可以使用 rangesqueeze 来获取正确的 idx 维度,如下所示:
x[range(x.size(0)), idx.squeeze()]
tensor([10., 43.])

# or
x[range(x.size(0)), idx.squeeze()].unsqueeze(1)
tensor([[10.],
        [43.]])

0
这是在PyTorch中使用“gather”运行的代码。 “idx”需要以“torch.int64”格式,以下一行将确保它(请注意“tensor”中小写的“t”)。
idx = torch.tensor([[0],
                    [2]])
torch.gather(x, 1, idx) # 1 is the axis to index here
tensor([[10.],
        [43.]])

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接