一个numpy数组的累计argmax

6
考虑数组a
np.random.seed([3,1415])
a = np.random.randint(0, 10, (10, 2))
a

array([[0, 2],
       [7, 3],
       [8, 7],
       [0, 6],
       [8, 6],
       [0, 2],
       [0, 4],
       [9, 7],
       [3, 2],
       [4, 3]])

什么是向量化的获取累积argmax的方法?
array([[0, 0],  <-- both start off as max position
       [1, 1],  <-- 7 > 0 so 1st col = 1, 3 > 2 2nd col = 1
       [2, 2],  <-- 8 > 7 1st col = 2, 7 > 3 2nd col = 2
       [2, 2],  <-- 0 < 8 1st col stays the same, 6 < 7 2nd col stays the same
       [2, 2],  
       [2, 2],
       [2, 2],
       [7, 2],  <-- 9 is new max of 2nd col, argmax is now 7
       [7, 2],
       [7, 2]])

这里是一种非向量化的方法。

注意随着窗口扩大,argmax应用于不断增长的窗口。

pd.DataFrame(a).expanding().apply(np.argmax).astype(int).values

array([[0, 0],
       [1, 1],
       [2, 2],
       [2, 2],
       [2, 2],
       [2, 2],
       [2, 2],
       [7, 2],
       [7, 2],
       [7, 2]])
3个回答

8
这里有一个向量化的纯NumPy解决方案,执行起来非常迅速:
def cumargmax(a):
    m = np.maximum.accumulate(a)
    x = np.repeat(np.arange(a.shape[0])[:, None], a.shape[1], axis=1)
    x[1:] *= m[:-1] < m[1:]
    np.maximum.accumulate(x, axis=0, out=x)
    return x

然后我们有:
>>> cumargmax(a)
array([[0, 0],
       [1, 1],
       [2, 2],
       [2, 2],
       [2, 2],
       [2, 2],
       [2, 2],
       [7, 2],
       [7, 2],
       [7, 2]])

对具有数千到数百万个值的数组进行快速测试表明,在Python层面(无论是隐式还是显式)循环相比,这将快10-50倍。


1
为了进一步优化 - 如果 a.ndim == 1:--> x = np.arange(a.shape[0])。这是因为在这里调用np.repeat相对较昂贵。 - Brad Solomon

1

我想不到一种简单地在两列上进行向量化的方法;但如果相对于行数,列数很小,那么这不应该成为问题,可以使用for循环来处理该轴:

import numpy as np
import numpy_indexed as npi
a = np.random.randint(0, 10, (10))
max = np.maximum.accumulate(a)
idx = npi.indices(a, max)
print(idx)

1
我希望创建一个函数,用于计算一维数组的累积 argmax,然后将其应用于所有列。以下是代码:
import numpy as np

np.random.seed([3,1415])
a = np.random.randint(0, 10, (10, 2))

def cumargmax(v):
    uargmax = np.frompyfunc(lambda i, j: j if v[j] > v[i] else i, 2, 1)
    return uargmax.accumulate(np.arange(0, len(v)), 0, dtype=np.object).astype(v.dtype)

np.apply_along_axis(cumargmax, 0, a)

将数据类型转换为np.object,然后再转换回来的原因是为了解决Numpy 1.9中的问题,正如generalized cumulative functions in NumPy/SciPy?中所提到的那样。

2
请注意,frompyfunc仅矢量化语法,而不是性能。这将与朴素的Python循环具有可比性的性能。 - Eelco Hoogendoorn

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接