在PyTorch张量中筛选数据

19

我有一个张量 X,如 [0.1, 0.5, -1.0, 0, 1.2, 0],我想实现一个名为 filter_positive() 的函数,它可以将正数数据过滤到一个新的张量中,并返回原始张量的索引。例如:

new_tensor, index = filter_positive(X)

new_tensor = [0.1, 0.5, 1.2]
index = [0, 1, 4]

我该如何在PyTorch中高效地实现这个函数?

2个回答

28

看一下torch.nonzero,它大致相当于np.where。它将二进制掩码转换为索引:

>>> X = torch.tensor([0.1, 0.5, -1.0, 0, 1.2, 0])
>>> mask = X >= 0
>>> mask
tensor([1, 1, 0, 1, 1, 1], dtype=torch.uint8)

>>> indices = torch.nonzero(mask)
>>> indices
tensor([[0],
        [1],
        [3],
        [4],
        [5]])

>>> X[indices]
tensor([[0.1000],
        [0.5000],
        [0.0000],
        [1.2000],
        [0.0000]])

那么一个解决方案是写成:

mask = X >= 0
new_tensor = X[mask]
indices = torch.nonzero(mask)

7

如果索引不是必要的,您可以直接这样操作:

X = X[X > 0]

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接