Pytorch:如何在2D张量的每一行中找到第一个非零元素的索引?

12

我有一个二维张量,每一行都有一些非零元素,就像这样:

import torch
tmp = torch.tensor([[0, 0, 1, 0, 1, 0, 0],
                    [0, 0, 0, 1, 1, 0, 0]], dtype=torch.float)

我想要一个张量,其中包含每行第一个非零元素的索引:

indices = tensor([2],
                 [3])

我该如何在Pytorch中计算它?

4个回答

14

我简化了Iman的方法,使其能够实现以下功能:

idx = torch.arange(tmp.shape[1], 0, -1)
tmp2= tmp * idx
indices = torch.argmax(tmp2, 1, keepdim=True)

1
这是一个如此简单但聪明的解决方案! - Tian
请注意,如果没有非零值,则默认为零作为第一个非零索引。 - sachinruk

7

我找到了一个巧妙的答案来回答我的问题:

  tmp = torch.tensor([[0, 0, 1, 0, 1, 0, 0],
                     [0, 0, 0, 1, 1, 0, 0]], dtype=torch.float)
  idx = reversed(torch.Tensor(range(1,8)))
  print(idx)

  tmp2= torch.einsum("ab,b->ab", (tmp, idx))

  print(tmp2)

  indices = torch.argmax(tmp2, 1, keepdim=True)
  print(indeces)

结果如下:

tensor([7., 6., 5., 4., 3., 2., 1.])
tensor([[0., 0., 5., 0., 3., 0., 0.],
       [0., 0., 0., 4., 3., 0., 0.]])
tensor([[2],
        [3]])

6
假设所有非零值都相等,argmax返回第一个索引。
tmp = torch.tensor([[0, 0, 1, 0, 1, 0, 0],
                    [0, 0, 0, 1, 1, 0, 0]])
indices = tmp.argmax(1)

1
借鉴@Seppo的答案,我们可以通过从原始张量创建掩码并使用pytorch函数来消除“所有非零值相等”的假设。
# tmp = some tensor of whatever shape and values
indices = torch.argmax((tmp != 0).to(dtype=torch.int), dim=-1)

然而,如果张量的一行全是零,则返回的信息不是第一个非零元素的索引。我想这个问题的性质使得这种情况不会发生。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接