在numpy矩阵中查找每行的最大值及其不同列索引的方法

4
我有一个规模约为10000行和10000列的大型`numpy`浮点类型矩阵。对于每一行,我需要找到具有最大值的列索引,并且每个列最多只选择一次。
例如,对于给定的数组`arr`,我需要将输出作为`(行索引,列索引)`元组的`list`/`array`,如`out`所示:
arr = np.array(
   [[0.86, 0.23, 0.83, 0.79],
   [0.15, 0.98, 0.86, 0.47],
   [1.  , 0.08, 0.01, 0.04],
   [0.78, 0.82, 0.17, 0.56],
   [0.73, 0.91, 0.52, 0.31]])
out = [(0,0),(1,1),(2,3),(3,2)]

说明:

  • 初始时out为空。
  • 对于row 0,在column 0中最大值为0.86,因此out现在为[(0,0)]
  • 对于row 1,在column 1中最大值为0.98,而且column 1还没有在out中出现过,因此out现在为[(0,0),(1,1)]
  • 对于row 2,在column 0中最大值是1,但是column 0已经被选择了,所以我们查找下一个最大值,即column 1中的0.08,它也出现在out中,然后是下一个最大值,即column 3中的0.04,因此out现在为[(0,0),(1,1),(2,3)]
  • 同样地,对于row 3,最大值还未被选择的列是column 2,因此最终的out[(0,0),(1,1),(2,3),(3,2)]

我希望尽可能高效地计算它。使用两个for循环的O(n2)解决方案很简单,因此任何比这更好的解决方案(无论是更好的时间复杂度还是使用内置的numpy函数更好的运行时)都将非常有帮助。

2个回答

2

如果您愿意使用cython,那么迭代速度可以有所提升。

# distutils: language = c++
# cython: boundscheck = False

from libcpp.set cimport set as cset
from libc.math cimport INFINITY

def method2(double[:, :] x):
    cdef:
        int nrows, ncols, i, j, best_j
        double best_value
        # Define the set of columns that have already been used.
        cset[int] usedcols
        
    nrows = x.shape[0]
    ncols = x.shape[1]
    out = []
    
    for i in range(nrows):
        best_value = -INFINITY
        best_idx = -1
        
        # Find the largest value for each row that's not already used.
        for j in range(ncols):
            if x[i, j] > best_value and usedcols.find(j) == usedcols.end():
                best_value = x[i, j]
                best_j = j
        out.append((i, best_j))
        usedcols.insert(best_j)
    
    return out

假设Samarth的解决方案是method1,以下是性能比较。
x = np.random.normal(0, 1, (10000, 10000))
%timeit method1(x)  # 770 ms ± 4.98 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit method2(x)  # 57.7 ms ± 149 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

1
假设矩阵中的最小值为0。
for i in range(len(arr)): # for the ith row
    maxcol = np.argmax(arr[i])
    arr[:,maxcol] = np.zeros( len(arr[:,maxcol]) )
    out.append((i,maxcol))

这个方法是通过找到每一行最大列的索引(使用argmax),然后将该列的所有元素设置为任意最小值(这里用np.zeros高效地给出了0)。

这应该可以工作,我正在尝试它。而且肯定比O(n^2)少,但不知道具体少多少。

编辑:我尝试了一下,out[(0, 0), (1, 1), (2, 3), (3, 2), (4, 0)]。所以它确实有所期望的输出,但有一个额外的项(4,0)。[这是因为迭代了所有行]。如果你想要迭代直到你“用完”列,你可以使用min(arr.shape)而不是len(arr)


时间复杂度

我们运行了一个 for 循环,它的时间复杂度为 O(n)
在循环中,我们使用了 np.argmax,其时间复杂度也为 O(n)
然后是常数时间的替换操作 [O(1)]

因此,它并不是很优化。你可以编写自己的代码,直接编码边界情况和其他改进,如仅检查所需部分等...


1
有趣的是,我刚刚写了几乎完全相同的代码。但我认为你应该用-np.inf替换np.zeros(len(arr[:,maxcol]))。这样更快更干净。 - Lukas S
是的,特别是如果实际矩阵也可能有负值。 - pu239
感谢您的想法。如果能对一个大数组(10000x10000)进行运行时比较,并与两个for循环的基准进行比较,那就更好了。虽然它是O(n^2),但由于numpy函数的存在,它可能具有更好的运行时间。 - Kapil
@Kapil 我试着用10^5乘以10^5个随机数来运行它,结果内存不足 :) - Lukas S
也许是10^4乘以10^4 - Kapil
@Kapil 对于10^4乘以10^5,使用-np.inf替换需要17.6秒,而不使用它则需要38.8秒。 - Lukas S

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接