高效地按元素将numpy数组与自身进行比较

6

我正在执行大量这样的计算:

A == A[np.newaxis].T

其中A是一个密集的numpy数组,经常具有共同的值。

为了基准测试目的,我们可以使用:

n = 30000
A = np.random.randint(0, 1000, n)
A == A[np.newaxis].T

当我执行这个计算时,我遇到了内存问题。我认为这是因为输出不是更高效的bitarray或np.packedbits格式。第二个问题是我们进行了两倍于必要的比较,因为结果的布尔数组是对称的。
我的问题是:
  1. 是否可能以更节省内存的方式生成布尔numpy数组输出,而不会牺牲速度? 我知道的选项是bitarray和np.packedbits,但我只知道如何在创建大型布尔数组后应用它们。
  2. 我们能利用计算的对称性将处理的比较次数减半,而不会牺牲速度吗?
我需要能够对布尔数组执行&和|操作。我已经尝试过bitarray,它非常适用于这些位运算。但是将np.ndarray打包成bitarray然后再解压缩bitarray -> np.ndarray很慢。

@DanielF,在稀疏矩阵上执行布尔运算是否可能 / 高效?我尝试了使用&和|运算,但它们似乎没有被实现。 - jpp
@PaulPanzer,这似乎是个好主意。我们可以应用一个[可能高度优化的]排序算法。然后只需比较相邻元素即可。您能否提供一些代码来展示您如何在numpy中实现这个想法? - jpp
你对这个问题期望的输出是什么?是一个常规的NumPy布尔数组还是一些scipy稀疏矩阵? - Divakar
@Divakar,我一开始使用了布尔numpy数组输出。但是,下面DanielF和PaulPanzer的答案让我倾向于使用坐标(或稀疏矩阵)。因此,我上面的问题是关于是否可以计算csr矩阵A、B.multiply(~)。如果可能的话,那么csr_sparse可能是一个好方法。布尔输出矩阵的用途很多(按行或列进行切片,&与其他矩阵进行比较,|与其他矩阵进行比较等)。 - jpp
或者是 A-(A.multiply(B))。可能会稍微便宜一些。 - Paul Panzer
显示剩余8条评论
4个回答

4

这是一个使用numba的示例,它可以提供一个NumPy布尔数组作为输出:

from numba import njit

@njit
def numba_app1(idx, n, s, out):
    for i,j in zip(idx[:-1],idx[1:]):
        s0 = s[i:j]
        c = 0
        for p1 in s0[c:]:
            for p2 in s0[c+1:]:
                out[p1,p2] = 1
                out[p2,p1] = 1
            c += 1
    return out

def app1(A):
    s = A.argsort()
    b = A[s]
    n = len(A)
    idx = np.flatnonzero(np.r_[True,b[1:] != b[:-1],True])
    out = np.zeros((n,n),dtype=bool)
    numba_app1(idx, n, s, out)
    out.ravel()[::out.shape[1]+1] = 1
    return out

时间 -

In [287]: np.random.seed(0)
     ...: n = 30000
     ...: A = np.random.randint(0, 1000, n)

# Original soln
In [288]: %timeit A == A[np.newaxis].T
1 loop, best of 3: 317 ms per loop

# @Daniel F's soln-1 that skips assigning lower diagonal in output
In [289]: %timeit sparse_outer_eq(A)
1 loop, best of 3: 450 ms per loop

# @Daniel F's soln-2 (complete one)
In [291]: %timeit sparse_outer_eq(A)
1 loop, best of 3: 634 ms per loop

# Solution from this post
In [292]: %timeit app1(A)
10 loops, best of 3: 66.9 ms per loop

1
最终,我喜欢这个解决方案,因为生成的数组可以利用numpy的所有“好”特性,例如~A。如果稀疏矩阵的实现足够通用,使得默认值可以是非零值,那就太好了。 - jpp
稀疏矩阵的实际目的并不是如此。如果允许非零默认值,整个稀疏代数领域就会爆炸。这就像说布尔值应该允许“Maybe” - 是的,少数模糊逻辑博士可能会喜欢它,但实现起来会很困难,对于其他99.99%的人口来说,这只会是另一个要检查错误来源并减慢一切的问题。 - Daniel F
@DanielF 不太明白 - 稀疏矩阵的目的是什么? - Divakar
非零默认值,正如jp所期望的那样。我假设他希望~A成为一个“稀疏”矩阵,其中A.dataFalse,默认值为True。抱歉,应该使用@。 - Daniel F
@DanielF,你当然是正确的。对于非零默认值,整个数学分支都会崩溃。我的想法更多地是:稀疏矩阵可以成为通过 CSR / COO / DOK 算法存储矩阵的更一般化方式的子类 [具有其特殊方法和应用程序]。事实上,我确信可以对自己进行子类化,并覆盖特定用例的某些功能。 - jpp
显示剩余6条评论

2

这并不是一个numpy的解答,但可以使用一些自制的稀疏符号来降低数据需求量

from numba import jit

@jit   # because this is gonna be loopy
def sparse_outer_eq(A):
    n = A.size
    c = []
    for i in range(n):
        for j in range(i + 1, n):
            if A[i] == A[j]:
                 c.append((i, j))
    return c

现在,c 是一个坐标元组列表 (i, j),其中 i < j 对应于布尔数组中为“True”的坐标。您可以轻松地对这些集合进行 andor 运算:
list(set(c1) & set(c2))
list(set(c1) | set(c2))

之后,当你想将这个掩码应用到一个数组上时,你可以回溯坐标并使用它们进行高级索引:

i_, j_ = list(np.array(c).T)
i = np.r_[i_, j_, np.arange(n)]
j = np.r_[j_, i_, np.arange(n)]

如果您关心顺序,您可以使用np.lexsortij进行排序。

或者,您可以将sparse_outer_eq定义为:

@jit
def sparse_outer_eq(A):
    n = A.size
    c = []
    for i in range(n):
        for j in range(n):
            if A[i] == A[j]:
                 c.append((i, j))
    return c

这个功能可以保留超过2倍的数据,但是坐标会变得简单:

 i, j = list(np.array(c).T)

如果您已经进行了任何set操作,如果您想要一个合理的顺序,这仍然需要进行lexsort

如果您的坐标是n位整数,只要您的稀疏度小于1/n->32位约为3%,这应该比布尔格式更节省空间。

关于时间,由于numba的存在,它甚至比广播更快:

n = 3000
A = np.random.randint(0, 1000, n)

%timeit sparse_outer_eq(A)
100 loops, best of 3: 4.86 ms per loop

%timeit A == A[:, None]
100 loops, best of 3: 11.8 ms per loop

以及比较:

a = A == A[:, None]

b = B == B[:, None]

a_ = sparse_outer_eq(A)

b_ = sparse_outer_eq(B)

%timeit a & b
100 loops, best of 3: 5.9 ms per loop

%timeit list(set(a_) & set(b_))
1000 loops, best of 3: 641 µs per loop

%timeit a | b
100 loops, best of 3: 5.52 ms per loop

%timeit list(set(a_) | set(b_))
1000 loops, best of 3: 955 µs per loop

编辑:如果您想执行&~(根据您的评论),请使用第二个sparse_outer_eq方法(这样您就不必跟踪对角线),然后只需执行:

list(set(a_) - set(b_))

这很有用,但有几个问题: (1)性能对n和唯一值的数量非常敏感,例如,当n = 15,000且有200个唯一值时,我发现您的方法要慢3倍。 (2)像~A这样的操作不太容易。我可能需要存储所有坐标的完整集合[或循环遍历它们]并排除集合(a_)中的项目。但请参见https://github.com/scipy/scipy/issues/1166。 - jpp
是的,这种方法在很大程度上依赖于原始问题非常稀疏(1k个唯一值)。如果您的输入和输出矩阵相对密集,则布尔矩阵将更具性能。 - Daniel F
然而,我要指出的是,即使在1位/布尔值的情况下,只要稀疏度小于1/log2(n),这种方法比布尔矩阵更加内存高效,即n<4e9时有32个唯一值,n<65k时有16个唯一值。当然,一个非操作的~将会扩大任何类型的稀疏度,并使布尔更加高效。 - Daniel F

2

这里是几乎标准的argsort解决方案:

import numpy as np

def f_argsort(A):
    idx = np.argsort(A)
    As = A[idx]
    ne_ = np.r_[True, As[:-1] != As[1:], True]
    bnds = np.flatnonzero(ne_)
    valid = np.diff(bnds) != 1
    return [idx[bnds[i]:bnds[i+1]] for i in np.flatnonzero(valid)]

n = 30000
A = np.random.randint(0, 1000, n)
groups = f_argsort(A)

for grp in groups:
    print(len(grp), set(A[grp]), end=' ')
print()

这很不错,但像DanielF的解决方案一样,它不容易计算/存储~A。公平地说,我应该在我的原始查询中明确说明这一点。然而,你的规范argsort函数也被用在Divakar的解决方案中,这是我要选择的解决方案。 - jpp
@jp_data_analysis 没问题,选择最适合你的方法就好。 - Paul Panzer

0
我正在添加一个解决方案到我的问题,因为它满足以下三个属性:
- 低、固定的内存需求 - 快速的位运算(&, |, ~等) - 低存储,每个布尔值1位通过打包整数
缺点是它以np.packbits格式存储。它比其他方法慢得多(特别是argsort),但如果速度不是问题,该算法应该工作良好。如果有人找到了进一步优化的方法,这将非常有帮助。
更新:下面算法的更高效版本可以在这里找到:Improving performance on comparison algorithm np.packbits(A==A[:, None], axis=1)
import numpy as np
from numba import jit

@jit(nopython=True)
def bool2int(x):
    y = 0
    for i, j in enumerate(x):
        if j: y += int(j)<<(7-i)
    return y

@jit(nopython=True)
def compare_elementwise(arr, result, section):
    n = len(arr)

    for row in range(n):
        for col in range(n):

            section[col%8] = arr[row] == arr[col]

            if ((col + 1) % 8 == 0) or (col == (n-1)):
                result[row, col // 8] = bool2int(section)
                section[:] = 0

    return result

A = np.random.randint(0, 10, 100)
n = len(A)
result_arr = np.zeros((n, n // 8 if n % 8 == 0 else n // 8 + 1)).astype(np.uint8)
selection_arr = np.zeros(8).astype(np.uint8)

packed = compare_elementwise(A, result_arr, selection_arr)

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接