在torch.tensor中删除重复行

4

我有一个形状为(n,m)torch.tensor,我想要删除重复的行(或者至少找到它们)。例如:

t1 = torch.tensor([[1, 2, 3], [4, 5, 6], [1, 2, 3], [4, 5, 6]])
t2 = remove_duplicates(t1)

t2现在应该等于tensor([[1, 2, 3], [4, 5, 6]]),也就是删除了行13。您知道如何执行此操作吗?

我考虑使用torch.unique进行操作,但我无法弄清楚该怎么做。

1个回答

8
你可以简单地利用torch.unique中的参数dim来实现。
t1 = torch.tensor([[1, 2, 3], [4, 5, 6], [1, 2, 3], [4, 5, 6], [7, 8, 9]])
torch.unique(t1, dim=0)

这样做可以获得您想要的结果:

tensor([[1, 2, 3],
    [4, 5, 6],
    [7, 8, 9]])

在这里,你可以阅读该参数的含义。


谢谢,我误解了dim参数在unique函数中的使用。 - aretor

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接