重新开始累加和,获取累加和大于某个值的索引

14

假设我有一个距离数组x=[1,2,1,3,3,2,1,5,1,1]

我想要得到满足累积总和为10时的x数组下标,即idx=[4,9]。

因此,在符合条件后,累积总和会重新开始计算。

我可以使用循环来实现,但对于大型数组,循环速度较慢,我想知道是否可以以向量化的方式进行。


难以通过向量化方法实现 - BENY
嗯,看起来像是 @WeNYoBen - yatu
1
@yatu 我想我之前回答过这种类型的问题 :-) ...更好的解决方案可能是使用numba.. - BENY
3
如果你想在累加和达到10时重置cumsum,可以尝试使用以下代码:x.cumsum()%10。如果x是一个数组,这个方法非常快速。但是这种方法可能不适合你,因为如何处理边缘情况还不清楚。例如,如果cumsum等于11,它应该重置为0还是1呢? - Brenlla
无论如何都需要循环。但是为了速度,您希望在快速编译的代码中进行循环。对于这种固有的顺序任务,NumPy的标准工具更加有限。 - hpaulj
显示剩余2条评论
3个回答

13
一个有趣的方法
sumlm = np.frompyfunc(lambda a,b:a+b if a < 10 else b,2,1)
newx=sumlm.accumulate(x, dtype=np.object)
newx
array([1, 3, 4, 7, 10, 2, 3, 8, 9, 10], dtype=object)
np.nonzero(newx==10)

(array([4, 9]),)

请使用 np.random.seed([3, 1415]); x = np.random.randint(100, size=1_000_000).tolist() 进行检查。 - piRSquared
在其他测试中,frompyfunc 往往比更明确的 Python 循环快 2 倍。 - hpaulj
我没有意识到frompyfunc会生成一个带有像accumulate这样的方法的ufunc - hpaulj
@hpaulj 啊,早上刚发现它 :-) - BENY

10

循环并不总是坏的(尤其是当你需要循环时)。另外,没有工具或算法可以比O(n)更快地完成此操作。因此,让我们编写一个好的循环。

生成器函数

def cumsum_breach(x, target):
    total = 0
    for i, y in enumerate(x):
        total += y
        if total >= target:
            yield i
            total = 0

list(cumsum_breach(x, 10))

[4, 9]

使用Numba进行即时编译

Numba是一个需要安装的第三方库。
Numba对支持的特性非常挑剔,但它可以正常工作。
此外,正如Divakar所指出的那样,Numba在处理数组方面具有更好的性能。

from numba import njit

@njit
def cumsum_breach_numba(x, target):
    total = 0
    result = []
    for i, y in enumerate(x):
        total += y
        if total >= target:
            result.append(i)
            total = 0

    return result

cumsum_breach_numba(x, 10)

测试这两个东西

因为我想这么做 ¯\_(ツ)_/¯

设置


np.random.seed([3, 1415])
x0 = np.random.randint(100, size=1_000_000)
x1 = x0.tolist()

准确度

i0 = cumsum_breach_numba(x0, 200_000)
i1 = list(cumsum_breach(x1, 200_000))

assert i0 == i1

时间

%timeit cumsum_breach_numba(x0, 200_000)
%timeit list(cumsum_breach(x1, 200_000))

582 µs ± 40.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
64.3 ms ± 5.66 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

Numba的速度大约是快100倍。

为了进行更加公正的比较,我将一个列表转换成了Numpy数组。

%timeit cumsum_breach_numba(np.array(x1), 200_000)
%timeit list(cumsum_breach(x1, 200_000))

43.1 ms ± 202 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
62.8 ms ± 327 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

这使它们差不多平衡了。

这应该足够快了。对于非负数组,当 total > 10 时我们可以使用 break 来使它更快。 - Quang Hoang
可能可以通过Numba加速 :-) 顺便说一句,祝7-4快乐 - BENY
您好,能否测试一下速度,想知道使用 frompyfunc 和 numba 哪个更加快速。谢谢! - BENY
@piRSquared 转换为数组的成本并不高。因此,总体来说还不错。 - Divakar
@Divakar请查看这个问题。我尝试通过参考上面的解决方案来解决,但是我没有成功... https://stackoverflow.com/questions/61346334/groupby-cumulative-in-pandas-then-update-using-numpy-based-specific-condition - Danish
显示剩余3条评论

10

这里有一个与numba和数组初始化有关的例子 -

from numba import njit

@njit
def cumsum_breach_numba2(x, target, result):
    total = 0
    iterID = 0
    for i,x_i in enumerate(x):
        total += x_i
        if total >= target:
            result[iterID] = i
            iterID += 1
            total = 0
    return iterID

def cumsum_breach_array_init(x, target):
    x = np.asarray(x)
    result = np.empty(len(x),dtype=np.uint64)
    idx = cumsum_breach_numba2(x, target, result)
    return result[:idx]

时序

包括@piRSquared的解决方案,并使用同一篇帖子中的基准测试设置-

In [58]: np.random.seed([3, 1415])
    ...: x = np.random.randint(100, size=1000000).tolist()

# @piRSquared soln1
In [59]: %timeit list(cumsum_breach(x, 10))
10 loops, best of 3: 73.2 ms per loop

# @piRSquared soln2
In [60]: %timeit cumsum_breach_numba(np.asarray(x), 10)
10 loops, best of 3: 69.2 ms per loop

# From this post
In [61]: %timeit cumsum_breach_array_init(x, 10)
10 loops, best of 3: 39.1 ms per loop

Numba: 追加 vs 数组初始化

为了更加深入地研究数组初始化的帮助作用,这似乎是两个 Numba 实现之间的重大区别,让我们对数组数据进行计时,因为数组数据的创建本身对运行时间影响很大,它们都依赖于它 -

In [62]: x = np.array(x)

In [63]: %timeit cumsum_breach_numba(x, 10)# with appending
10 loops, best of 3: 31.5 ms per loop

In [64]: %timeit cumsum_breach_array_init(x, 10)
1000 loops, best of 3: 1.8 ms per loop

为了让输出具有自己的内存空间,我们可以进行复制。虽然不会对事情产生太大的影响 -
In [65]: %timeit cumsum_breach_array_init(x, 10).copy()
100 loops, best of 3: 2.67 ms per loop

它可能不会有太大的差异,但我认为为了完全公正,您应该返回 result [:idx] 的副本,以避免泄漏 result [idx:] 的内存。 - Paul Panzer
很棒的答案!不意外,只是说一下。 - piRSquared
@piRSquared 肯定受到了你的激励。我认为这种基于数组初始化的方法可以用于需要在输出数组大小未知时使用附加操作的 numba 解决方案。而且,数组数据与 numba 更配哦。所以,这就是这个问题和答案带来的两个收获。 - Divakar
是的,很棒的技巧。我会使用它。 - piRSquared

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接