PyTorch张量中的唯一值

19

我正在尝试在PyTorch张量中查找不同的值。 是否有Tensorflow的unique op的高效类比?


1
仅供参考 - 在pytorch的github上有一个功能请求,链接在这里:https://github.com/pytorch/pytorch/issues/2031 - cleros
正如其他答案所指出的那样,unique op 存在于 torch>=0.4。 - arosa
3个回答

20

0.4.0版本中有一个torch.unique()方法。

torch <= 0.3.1版本中,你可以尝试:

import torch
import numpy as np

x = torch.rand((3,3)) * 10
np.unique(x.round().numpy())

1
不错,但是目前 (0.4.0) 版本的 unique 只支持 CPU.. :-) - aerin

7
您可以将其转换为numpy数组,并利用numpy内置的“unique”函数:
您可以将其转换为numpy数组,并利用numpy内置的“unique”函数:
def unique(tensor1d):
    t, idx = np.unique(tensor1d.numpy(), return_inverse=True)
    return torch.from_numpy(t), torch.from_numpy(idx)  

例子:

t, idx = unique(torch.LongTensor([1, 1, 2, 4, 4, 4, 7, 8, 8]))  
# t --> [1, 2, 4, 7, 8]
# idx --> [0, 0, 1, 2, 2, 2, 3, 4, 4]

我认为这个方法可以行得通,但我宁愿避免使用numpy操作,因为它可能会花费太多时间。无论如何,我认为这是目前唯一的解决方案,谢谢。 - arosa

1
  1. 使用torch.eq()获取两个张量之间的公共项
  2. 获取索引并连接张量
  3. 最后通过torch.unique获得公共项:
import torch as pt

a = pt.tensor([1,2,3,2,3,4,3,4,5,6])
b = pt.tensor([7,2,3,2,7,4,9,4,9,8])

equal_data = pt.eq(a, b)
pt.unique(pt.cat([a[equal_data],b[equal_data]]))

3
如果您在代码中添加一些解释说明,那就更好了。 - L. F.
我们获取两个张量之间的共同项。通过 @2 tensor.eq() 相当于获取索引,最后通过 'torch.unique' 来连接张量以获取共同项。 - Sunil
a[equal_data].unique()应该也能实现目标,同时避免“猫” ;) - Tom

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接