在Numpy/PyTorch中快速查找大于阈值的值的索引

5

任务

给定一个 numpy 或者 pytorch 矩阵,找到数值大于给定阈值的单元格的索引。

我的实现

#abs_cosine is the matrix
#sim_vec is the wanted

sim_vec = []
for m in range(abs_cosine.shape[0]):
    for n in range(abs_cosine.shape[1]):
        # exclude diagonal cells
        if m != n and abs_cosine[m][n] >= threshold:
            sim_vec.append((m, n))

问题

速度。所有其他计算都是基于Pytorch构建的,使用numpy已经是一种妥协,因为它已将计算从GPU移动到CPU。纯python for循环将使整个过程变得更糟(对于小型数据集,速度已经慢了5倍)。我想知道是否可以将整个计算移动到Numpy(或pytorch)而不调用任何for循环?

我能想到的一个改进(但卡住了...)

bool_cosine = abs_cosine > 阈值

它返回一个布尔矩阵,其中包含TrueFalse。但我找不到一种快速检索True单元格的索引。

2个回答

7
以下内容适用于完全在GPU上运行的PyTorch:
# abs_cosine should be a Tensor of shape (m, m)
mask = torch.ones(abs_cosine.size()[0])
mask = 1 - mask.diag()
sim_vec = torch.nonzero((abs_cosine >= threshold)*mask)

# sim_vec is a tensor of shape (?, 2) where the first column is the row index and second is the column index

以下是numpy中可用的代码:
mask = 1 - np.diag(np.ones(abs_cosine.shape[0]))
sim_vec = np.nonzero((abs_cosine >= 0.2)*mask)
# sim_vec is a 2-array tuple where the first array is the row index and the second array is column index

事实上,如果您初始化mask = torch.ones(img.size())而不是精选第一个大小轴,则PyTorch方法适用于任意形状的ND张量。 - Addison Klinke
如果您的意图是让 sim_vec 排除返回的索引中的对角线元素(根据 OP 的请求),则只有 mask = 1 - mask.diag() 加法是必需的。然而,它可以轻松地被省略以允许所有索引作为有效索引,或者修改以排除一些非对角线子集的索引。 - Addison Klinke
我不确定这是否完全涉及到CUDA。torch.nonzero的文档中说:“当输入在CUDA上时,torch.nonzero()会导致主机设备同步。” - nairbv

0

这比 np.where 快大约两倍

import numba as nb
@nb.njit(fastmath=True)

def get_threshold(abs_cosine,threshold):
  idx=0
  sim_vec=np.empty((abs_cosine.shape[0]*abs_cosine.shape[1],2),dtype=np.uint32)
  for m in range(abs_cosine.shape[0]):
    for n in range(abs_cosine.shape[1]):
      # exclude diagonal cells
      if m != n and abs_cosine[m,n] >= threshold:
        sim_vec[idx,0]=m
        sim_vec[idx,1]=n
        idx+=1

  return sim_vec[0:idx,:]

第一次调用需要大约0.2秒的时间(编译开销)。如果数组在GPU上,可能还有一种方法可以在GPU上完成整个计算。

尽管如此,我对性能并不是很满意,因为简单的布尔运算比上面显示的解决方案快约5倍,比np.where快10倍。如果索引的顺序不重要,这个问题也可以并行化。


顺序不重要。 - GabrielChu

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接