我有一个二维张量,每一行都有一些非零元素,就像这样:
import torch
tmp = torch.tensor([[0, 0, 1, 0, 1, 0, 0],
[0, 0, 0, 1, 1, 0, 0]], dtype=torch.float)
我想要一个张量,其中包含每行第一个非零元素的索引:
indices = tensor([2],
[3])
我该如何在Pytorch中计算它?
我有一个二维张量,每一行都有一些非零元素,就像这样:
import torch
tmp = torch.tensor([[0, 0, 1, 0, 1, 0, 0],
[0, 0, 0, 1, 1, 0, 0]], dtype=torch.float)
我想要一个张量,其中包含每行第一个非零元素的索引:
indices = tensor([2],
[3])
我该如何在Pytorch中计算它?
我简化了Iman的方法,使其能够实现以下功能:
idx = torch.arange(tmp.shape[1], 0, -1)
tmp2= tmp * idx
indices = torch.argmax(tmp2, 1, keepdim=True)
我找到了一个巧妙的答案来回答我的问题:
tmp = torch.tensor([[0, 0, 1, 0, 1, 0, 0],
[0, 0, 0, 1, 1, 0, 0]], dtype=torch.float)
idx = reversed(torch.Tensor(range(1,8)))
print(idx)
tmp2= torch.einsum("ab,b->ab", (tmp, idx))
print(tmp2)
indices = torch.argmax(tmp2, 1, keepdim=True)
print(indeces)
结果如下:
tensor([7., 6., 5., 4., 3., 2., 1.])
tensor([[0., 0., 5., 0., 3., 0., 0.],
[0., 0., 0., 4., 3., 0., 0.]])
tensor([[2],
[3]])
argmax
返回第一个索引。tmp = torch.tensor([[0, 0, 1, 0, 1, 0, 0],
[0, 0, 0, 1, 1, 0, 0]])
indices = tmp.argmax(1)
# tmp = some tensor of whatever shape and values
indices = torch.argmax((tmp != 0).to(dtype=torch.int), dim=-1)
然而,如果张量的一行全是零,则返回的信息不是第一个非零元素的索引。我想这个问题的性质使得这种情况不会发生。