Numpy的向量化(1d)版本itertools.combinations

3

我正在尝试实现一个向量化的 Numpy 版本的 itertools.combinations,以便我可以使用 @jit(Numba)进行修饰 - 使其更快。(我不确定是否可以完成这个任务)

我正在处理的数据集是一个一维 np.array,目标是获取三元组组合。

x = np.array([1,2,3,4,5])

set(combinations(x, 3))

Result:

{(1, 2, 3),
 (1, 2, 4),
 (1, 2, 5),
 (1, 3, 4),
 (1, 3, 5),
 (1, 4, 5),
 (2, 3, 4),
 (2, 3, 5),
 (2, 4, 5),
 (3, 4, 5)}

我一直在搜索堆栈和其他资源,我确实发现了关于这个主题的很多信息,但是没有任何有用的信息适用于我的用例。

numpy中itertools.combinations的N-D版本

这个线程表明,很可能很难找到比itertools.combinations更快的东西:

itertools.combinations的numba安全版本?


我被要求更具体地说明用例:

Example of regression

  1. 我有一个价格数据 np.array price= ([100,101,102,103,104,105 等等..])

  2. 我使用 scipy.signal.argrelmax 找到数组中的所有峰值。(上图中的红点。)

  3. 现在我需要获取所有峰值的组合。

  4. 然后我将对所有组合运行简单的线性回归,并寻找特定的 r_val 阈值,以验证趋势线。(上图中的绿点)

(选择 3 的组合是因为这样我就知道趋势线有 3 个触点。在我发布的插图中有 4 个触点,所以我会处理 3 或 4 个触点。)

(此外,我使用趋势线上方/下方的积分进行过滤)

我不需要 combi 函数返回一组元组,我编写的简单线性回归算法已经针对 numpy 进行了优化。


1
你能告诉我们更多关于你的用例吗?你真的需要输出为元组集合吗?创建这样的数据结构是不高效的,因为每个元组都需要单独分配、哈希,然后存储在CPython动态集合数据结构中。Numba不支持这种数据结构的高效实现(但支持类型化集合和类型化元组),因为它们本质上是动态的。输入数组是小还是大?元组中的项目数量总是3吗?这些细节对于编写快速代码至关重要。通用的高级代码几乎从来不高效。 - Jérôme Richard
@JérômeRichard 很好的问题!- 我在我的问题中添加了更多信息。如果您需要更多信息,请询问。 - traderblakeq
1个回答

阿里云服务器只需要99元/年,新老用户同享,点击查看详情
1

itertools.combinations在CPython中不是最佳选择,因为它必须为每个组合调用一个C函数,并且还因为它必须创建一个新的元组。在CPython中调用C函数和分配对象都是有一定代价的。此外,这段代码是通用的,而实际上可以根据你的情况将其专门化,例如生成三元组或四元组的组合。如果输入的x由排序的唯一项组成,则计算set肯定是不需要的,因为组合已经保证是唯一的并且按字典顺序排序的。可以使用二分查找来查找一个数组是否是另一个数组的排列。实际上,在这种情况下,可以只对数组进行排序并进行比较。Numba函数可以在Numpy数组中生成所有组合。首先让我们编写一个优化的通用函数:

import numba as nb
import numpy as np

@nb.njit('(int64, int64)')
def combCount(n, r):
    if r < 0:
        return 0
    res = 1
    if r > n - r:
        r = n - r
    for i in range(r):
        res *= (n - i)
        res //= (i + 1)
    return res

@nb.njit(inline='always')
def genComb_generic(arr, r):
    n = arr.size
    out = np.empty((combCount(n, r), r), dtype=arr.dtype)
    idx = np.empty(r, dtype=np.int32)

    for i in range(r):
        idx[i] = i

    i = r - 1

    cur = 0
    while idx[0] < n - r + 1:
        while i > 0 and idx[i] == n - r + i:
            i -= 1
        for j in range(r):
            out[cur, j] = arr[idx[j]]
        cur += 1
        idx[i] += 1
        while i < r - 1:
            idx[i + 1] = idx[i] + 1
            i += 1
    return out

@nb.njit
def genComb_x3(arr):
    return genComb_generic(arr, 3)

@nb.njit
def genComb_x4(arr):
    return genComb_generic(arr, 4)

x = np.array([1,2,3,4,5])
genComb_x3(x) # generate triplets based on `x`

genComb_generic是一种快速的算法,特别适用于需要生成大量输出的情况。但是如果已知r的值,它还可以更快。以下是一个特定的案例,其中r=3(三元组):

@nb.njit
def genComb_x3(arr):
    n = arr.size
    nComb = combCount(n, 3)
    out = np.empty((nComb, 3), dtype=arr.dtype)
    a, b, c = 0, 1, 2
    arr_a = arr[a]
    arr_b = arr[b]
    for cur in range(nComb):
        out[cur, 0] = arr_a
        out[cur, 1] = arr_b
        out[cur, 2] = arr[c]
        if c < n - 1:
            c += 1
        else:
            if b < n - 2:
                b, c = b + 1, b + 2
                arr_b = arr[b]
            else:
                a, b, c = a + 1, a + 2, a + 3
                arr_a = arr[a]
                arr_b = arr[b]
    return out
这个最后的实现与itertools相比在输出相对较大时非常快。当x.size为30时,它比itertools快68倍。当x.size为5时,只快了约2倍。这是因为从CPython调用Numba函数会有重要的开销(就像Numpy一样),更不用说分配输出的时间了。在我的机器上(i5-9600KF),这个开销大约为800纳秒。你可以通过从另一个Numba函数中调用这样的函数并在可能的情况下预分配输出来大大减少开销。最终,在这种情况下,Numba函数应该比itertools实现快80倍。 如果你计划生成很多组合,那么最好是即时计算它们而不是生成一个巨大的数组(因为RAM往往很慢)。这应该会更快。

1
我会在周一实现它,并给你统计数据。- 我从你的答案中学到了很多!谢谢。 - traderblakeq
1
它能像魔术一样运行 - 而且,它在我的使用案例中比Iter.comb快得多! - traderblakeq

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,