使用与原始张量相同大小的索引张量来分割torch张量

4

假设我有一个张量

t = torch.tensor([1,2,3,4,5])

我想使用一个相同大小的张量来对它进行分割,该张量告诉我每个元素应该放在哪个子集中。
indices = torch.tensor([0,1,1,0,2])

因此最终结果是

splits
[tensor([1,4]), tensor([2,3]), tensor([5])]

有没有一种简便的方法在Pytorch中实现这个功能?
编辑:通常会有超过2或3个拆分。
3个回答

4

对于一般情况, 可以使用argsort进行排序:

def mask_split(tensor, indices):
    sorter = torch.argsort(indices)
    _, counts = torch.unique(indices, return_counts=True)
    return torch.split(t[sorter], counts.tolist())


mask_split(t, indices)

如果这是你真实的使用情况,最好使用@flawr的答案(同时列表推导式可能更快,因为它不需要排序),可以尝试以下代码:

def mask_split(tensor, indices):
    unique = torch.unique(indices)
    return [tensor[indices == i] for i in unique]

2

使用逻辑索引确实是可能的,您只需要确保索引“掩码”由布尔值组成,因此在您的情况下

splits = t[indices > 0] , t[indices < 1]

或者,您可以先将张量indices转换为布尔类型。


1
除了其他答案之外,对于在pytorch中进行索引,您可以直接使用索引位置来访问这些元素:
t = torch.tensor([1,2,3,4])
print(t[[0,1,3]])

所以你并不需要为索引存储张量。如果想要的话,仍然可以存储一个由1和0组成的numpy数组,然后从该数组中找到访问索引。
a = np.array([0, 1, 1, 0])
ind_ones = np.argwhere(a == 1).squeeze()
ind_zers = np.argwhere(a == 0).squeeze()
print(t[ind_ones])   # tensor([2, 3])
print(t[ind_zers])   # tensor([1, 4])

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接