任务
给定一个 numpy 或者 pytorch 矩阵,找到数值大于给定阈值的单元格的索引。
我的实现
#abs_cosine is the matrix
#sim_vec is the wanted
sim_vec = []
for m in range(abs_cosine.shape[0]):
for n in range(abs_cosine.shape[1]):
# exclude diagonal cells
if m != n and abs_cosine[m][n] >= threshold:
sim_vec.append((m, n))
问题
速度。所有其他计算都是基于Pytorch构建的,使用numpy
已经是一种妥协,因为它已将计算从GPU移动到CPU。纯python for
循环将使整个过程变得更糟(对于小型数据集,速度已经慢了5倍)。我想知道是否可以将整个计算移动到Numpy(或pytorch)而不调用任何for
循环?
我能想到的一个改进(但卡住了...)
bool_cosine = abs_cosine > 阈值
它返回一个布尔矩阵,其中包含True
和False
。但我找不到一种快速检索True
单元格的索引。
mask = torch.ones(img.size())
而不是精选第一个大小轴,则PyTorch方法适用于任意形状的ND张量。 - Addison Klinkesim_vec
排除返回的索引中的对角线元素(根据 OP 的请求),则只有mask = 1 - mask.diag()
加法是必需的。然而,它可以轻松地被省略以允许所有索引作为有效索引,或者修改以排除一些非对角线子集的索引。 - Addison Klinketorch.nonzero
的文档中说:“当输入在CUDA上时,torch.nonzero()会导致主机设备同步。” - nairbv