假设我有一个张量
t = torch.tensor([1,2,3,4,5])
我想使用一个相同大小的张量来对它进行分割,该张量告诉我每个元素应该放在哪个子集中。
indices = torch.tensor([0,1,1,0,2])
因此最终结果是
splits
[tensor([1,4]), tensor([2,3]), tensor([5])]
有没有一种简便的方法在Pytorch中实现这个功能?
编辑:通常会有超过2或3个拆分。
对于一般情况, 可以使用argsort
进行排序:
def mask_split(tensor, indices):
sorter = torch.argsort(indices)
_, counts = torch.unique(indices, return_counts=True)
return torch.split(t[sorter], counts.tolist())
mask_split(t, indices)
如果这是你真实的使用情况,最好使用@flawr的答案(同时列表推导式
可能更快,因为它不需要排序),可以尝试以下代码:
def mask_split(tensor, indices):
unique = torch.unique(indices)
return [tensor[indices == i] for i in unique]
使用逻辑索引确实是可能的,您只需要确保索引“掩码”由布尔值组成,因此在您的情况下
splits = t[indices > 0] , t[indices < 1]
或者,您可以先将张量indices
转换为布尔类型。
t = torch.tensor([1,2,3,4])
print(t[[0,1,3]])
a = np.array([0, 1, 1, 0])
ind_ones = np.argwhere(a == 1).squeeze()
ind_zers = np.argwhere(a == 0).squeeze()
print(t[ind_ones]) # tensor([2, 3])
print(t[ind_zers]) # tensor([1, 4])