高效地遍历numpy数组测试多个元素

4
我有以下代码,它迭代名为“m”的2D NumPy数组。 它的运行速度非常慢。 如何使用NumPy函数改写此代码,以避免使用for循环?
pairs = []
for i in range(size):
    for j in range(size):
        if(i >= j):
            continue
        if(m[i][j] + m[j][i] >= 0.75):
            pairs.append([i, j, m[i][j] + m[j][i]])

这个数组的维度是多少? - Eduardo Soares
about 5000x5000 - nota
2个回答

6
您可以使用NumPy来进行向量化处理。思路如下:
  • 首先初始化一个矩阵m,然后创建m+m.T,这相当于m[i][j] + m[j][i],其中m.T是矩阵的转置,并将其称为summ
  • np.triu(summ)返回矩阵的上三角部分(这相当于在代码中使用continue忽略下部分)。这避免了在代码中明确使用if(i >= j):。这里您必须使用k=1来排除对角元素。默认情况下,k=0也包括对角线元素。
  • 然后,您可以使用np.argwhere获取点的索引,其中和m+m.T之和大于等于0.75
  • 然后,您可以将这些索引和相应的值存储在列表中以供稍后处理/打印。

可验证的示例(使用小型3x3随机数据集)

import numpy as np

np.random.seed(0)
m = np.random.rand(3,3)
summ = m + m.T

index = np.argwhere(np.triu(summ, k=1)>=0.75)

pairs = [(x,y, summ[x,y]) for x,y in index]
print (pairs)
# # [(0, 1, 1.2600725493693163), (0, 2, 1.0403505873343364), (1, 2, 1.537667113848736)]

进一步提高性能

我刚刚想出了一个更快的方法来生成最终的pairs列表,避免使用显式的for循环。

pairs = list(zip(index[:, 0], index[:, 1], summ[index[:,0], index[:,1]]))

建议您在顶部添加 np.random.seed(0) 并重新运行以获得可重复的结果。 - 8one6
将我的程序执行时间从55秒减少到1.5秒。非常感谢! - nota
1
@nota:我稍微编辑了一下,加入了 k=1,因为您不希望处理对角线上的元素。 - Sheldore
1
@nota:为了获得更快的速度,请查看我的编辑,使用“zip”。 - Sheldore

5
优化代码的方法之一是避免使用比较语句 if (i >= j)。若要遍历数组的下三角,可以让内层循环从外层循环的变量的值开始。这样做可以避免进行规模为size x sizeif比较。
import numpy as np
size = 5000
m = np.random.rand(size, size)
pairs = []


for i in range(size):
    for j in range(i , size):

        if(m[i][j] + m[j][i] >= 0.75):
            pairs.append([i, j, m[i][j] + m[j][i]])

最好先定义大小,然后在 m = np.random.rand(size, size) 中使用它。 - Sheldore

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接