获取数组中匹配元素的索引,考虑到重复出现的情况。

3

我希望在Numpy中实现类似于SQL WHERE 表达式的功能,涉及到IT技术。假设我有两个数组,如下所示:

import numpy as np
dt = np.dtype([('f1', np.uint8), ('f2', np.uint8), ('f3', np.float_)])
a = np.rec.fromarrays([[3,    4,    4,   7,    9,    9],
                       [1,    5,    5,   4,    2,    2],
                       [2.0, -4.5, -4.5, 1.3, 24.3, 24.3]], dtype=dt)
b = np.rec.fromarrays([[ 1,    4,   7,    9,    9],
                       [ 7,    5,   4,    2,    2],
                       [-3.5, -4.5, 1.3, 24.3, 24.3]], dtype=dt)

我希望能够返回原始数组的索引,以便涵盖每个可能的匹配对。此外,我想利用两个数组都已排序的事实,因此不需要最坏情况下的 O(mn) 算法。在本例中,由于 (4, 5, -4.5) 匹配,但在第一个数组中出现了两次,因此它会在结果索引中出现两次,并且由于 (9, 2, 24.3) 在两个数组中都出现了两次,因此它将出现总共 4 次。由于 (3, 1, 2.0) 在第二个数组中不存在,因此将被跳过,第二个数组中的 (1, 7, -3.5) 也是如此。该函数应适用于任何 dtype。
在这种情况下,结果可能如下:
a_idx, b_idx = match_arrays(a, b)
a_idx = np.array([1, 2, 3, 4, 4, 5, 5])
b_idx = np.array([1, 1, 2, 3, 4, 3, 4])

同样输出结果的另一个例子:

dt2 = np.dtype([('f1', np.uint8), ('f2', dt)])
a2 = np.rec.fromarrays([[3, 4, 4, 7, 9, 9], a], dtype=dt2)
b2 = np.rec.fromarrays([[1, 4, 7, 9, 9], b], dtype=dt2)

我有一个纯Python实现,但在我的使用情况下速度非常慢。我希望有更多的向量化操作。以下是我目前已经实现的代码:

def match_arrays(a, b):
    len_a = len(a)
    len_b = len(b)

    a_idx = []
    b_idx = []

    i, j = 0, 0

    first_matched_j = 0

    while i < len_a and j < len_b:
        matched = False
        j = first_matched_j

        while j < len_b and a[i] == b[j]:
            a_idx.append(i)
            b_idx.append(j)
            if not matched:
                matched = True
                first_matched_j = j

            j += 1
        else:
            i += 1

        j = first_matched_j

        while i < len_a and j < len_b and a[i] > b[j]:
            j += 1
            first_matched_j = j

        while i < len_a and j < len_b and a[i] < b[j]:
            i += 1

    return np.array(a_idx), np.array(b_idx)

编辑:正如Divakar在他的答案中指出的那样,我可以使用a_idx, b_idx = np.where(np.equal.outer(a, b))。然而,这似乎是我想要避免的最坏情况O(mn)解决方案,通过对数组进行预排序可以避免这种情况。特别地,在没有重复项的情况下,如果能做到O(m + n)就太好了。

编辑2:Paul Panzer答案如果只使用Numpy,则不是O(m + n),但通常会更快。此外,他还提供了一个O(m + n)的答案,所以我接受了那个答案。我很快就会使用timeit进行性能比较。

编辑3:如承诺的那样,这里是性能测试结果:

╔════════════════╦═══════════════════╦═══════════════════╦═══════════════════╦══════════════════╦═══════════════════╗
║ User           ║ Version           ║ n = 10 ** 2       ║ n = 10 ** 4       ║ n = 10 ** 6      ║ n = 10 ** 8       ║
╠════════════════╬═══════════════════╬═══════════════════╬═══════════════════╬══════════════════╬═══════════════════╣
║ Paul Panzer    ║ USE_HEAPQ = False ║ 115 µs ± 385 ns   ║ 793 µs ± 8.43 µs  ║ 105 ms ± 1.57 ms ║ 18.2 s ± 116 ms   ║
║                ╠═══════════════════╬═══════════════════╬═══════════════════╬══════════════════╬═══════════════════╣
║                ║ USE_HEAPQ = True  ║ 189 µs ± 3.6 µs   ║ 6.38 ms ± 28.8 µs ║ 650 ms ± 2.49 ms ║ 1min 11s ± 420 ms ║
╠════════════════╬═══════════════════╬═══════════════════╬═══════════════════╬══════════════════╬═══════════════════╣
║ SigmaPiEpsilon ║ Generator         ║ 936 µs ± 1.52 µs  ║ 9.17 s ± 57 ms    ║ N/A              ║ N/A               ║
║                ╠═══════════════════╬═══════════════════╬═══════════════════╬══════════════════╬═══════════════════╣
║                ║ for loop          ║ 144 µs ± 526 ns   ║ 15.6 ms ± 18.6 µs ║ 1.74 s ± 33.9 ms ║ N/A               ║
╠════════════════╬═══════════════════╬═══════════════════╬═══════════════════╬══════════════════╬═══════════════════╣
║ Divakar        ║ np.where          ║ 39.1 µs ± 281 ns  ║ 302 ms ± 4.49 ms  ║ Out of memory    ║ N/A               ║
║                ╠═══════════════════╬═══════════════════╬═══════════════════╬══════════════════╬═══════════════════╣
║                ║ recarrays 1       ║ 69.9 µs ± 491 ns  ║ 1.6 ms ± 24.2 µs  ║ 230 ms ± 3.52 ms ║ 41.5 s ± 543 ms   ║
║                ╠═══════════════════╬═══════════════════╬═══════════════════╬══════════════════╬═══════════════════╣
║                ║ recarrays 2       ║ 82.6 µs ± 1.01 µs ║ 1.4 ms ± 4.51 µs  ║ 212 ms ± 2.59 ms ║ 36.7 s ± 900 ms   ║
╚════════════════╩═══════════════════╩═══════════════════╩═══════════════════╩══════════════════╩═══════════════════╝

看起来Paul Panzer答案USE_HEAPQ = False获胜。 我原本以为对于大型输入,USE_HEAPQ = True会获胜,因为它是O(m + n),但事实并非如此。另外一个评论是,USE_HEAPQ = False版本使用了更少的内存,最多只有5.79GB,而USE_HEAPQ = True则需要10.18GB,对于n=10**8。请记住,这是进程内存,并包括控制台的输入和其他内容。Divakar的recarrays答案1使用了8.42 GB的内存,recarrays答案2使用了10.61 GB。

你能否将我在帖子中新增的“方法3:通用情况”加入到时间测试中? - Divakar
在编辑的样本中出现了“ValueError: mismatch between the number of fields and the number of arrays”错误。我使用的是1.13.3 NumPy和Python 2.7。 - Divakar
@Divakar 已修复,谢谢。 - Hameer Abbasi
3个回答

2

方法一:基于广播的方法

使用两个数组之间的outer相等比较来利用矢量化的broadcasting,然后获取行、列索引,这将是与两个数组对应的匹配索引所需的 -

a_idx, b_idx = np.where(a[:,None]==b)
a_idx, b_idx = np.where(np.equal.outer(a,b))

我们还可以使用np.nonzero代替np.where
方法#2:特定情况的解决方案
如果输入数组没有重复项并且已排序,我们可以使用np.searchsorted,如下所示-
idx0 = np.searchsorted(a,b)
idx1 = np.searchsorted(b,a)
idx0[idx0==len(a)] = 0
idx1[idx1==len(b)] = 0

a_idx = idx0[a[idx0] == b]
b_idx = idx1[b[idx1] == a]

修改一下,可能更有效率的方法是 -
idx0 = np.searchsorted(a,b)
idx0[idx0==len(a)] = 0

a_idx = idx0[a[idx0] == b]
b_idx = np.searchsorted(b,a[a_idx])

方法三:通用情况

这是一个适用于通用情况(允许重复)的解决方案 -

def findwhere(a, b):
    c = np.bincount(b, minlength=a.max()+1)[a]
    a_idx1 = np.repeat(np.flatnonzero(c),c[c!=0])
    
    b_idx1 = np.searchsorted(b,a[a_idx1])
    m1 = np.r_[False,a_idx1[1:] == a_idx1[:-1],False]
    idx11 = np.flatnonzero(m1[1:] != m1[:-1])
    id_arr = m1.astype(int)
    id_arr[idx11[1::2]+1] = idx11[::2]-idx11[1::2]
    b_idx1 += id_arr.cumsum()[:-1]
    return a_idx1, b_idx1

时间

使用 @Paul Panzer 的解决方案中的 mock_data 来设置输入:

In [295]: a, b = mock_data(1000000)

# @Paul Panzer's soln
In [296]: %timeit sqlwhere(a, b) # USE_HEAPQ = False
10 loops, best of 3: 118 ms per loop

# Approach #3 from this post
In [297]: %timeit findwhere(a,b)
10 loops, best of 3: 61.7 ms per loop

将记录数组(uint8数据)转换为1D数组的工具

def convert_recarrays_to_1Darrs(a, b):
    a2D = a.view('u1').reshape(-1,2)
    b2D = b.view('u1').reshape(-1,2)
    s = max(a2D[:,0].max(), b2D[:,0].max())+1
    
    a1D = s*a2D[:,1] + a2D[:,0]
    b1D = s*b2D[:,1] + b2D[:,0]
    return a1D, b1D

样例运行 -

In [90]: dt = np.dtype([('f1', np.uint8), ('f2', np.uint8)])
    ...: a = np.rec.fromarrays([[3, 4, 4, 7, 9, 9],
    ...:                        [1, 5, 5, 4, 2, 2]], dtype=dt)
    ...: b = np.rec.fromarrays([[1, 4, 7, 9, 9],
    ...:                        [7, 5, 4, 2, 2]], dtype=dt)

In [91]: convert_recarrays_to_1Darrs(a, b)
Out[91]: 
(array([13, 54, 54, 47, 29, 29], dtype=uint8),
 array([71, 54, 47, 29, 29], dtype=uint8))

通用版本来覆盖rec-arrays

版本 #1 :

def findwhere_generic_v1(a, b):
    cidx = np.flatnonzero(np.r_[True,b[1:] != b[:-1],True])
    count = np.diff(cidx)
    b_starts = b[cidx[:-1]]
    
    a_starts = np.searchsorted(a,b_starts)
    a_starts[a_starts==len(a)] = 0
    
    valid_mask = (b_starts == a[a_starts])
    count_valid = count[valid_mask]
    
    idx2m0 = np.searchsorted(a,b_starts[valid_mask],'right')    
    idx1m0 = a_starts[valid_mask]
    
    id_arr = np.zeros(len(a)+1, dtype=int)
    id_arr[idx2m0] -= 1
    id_arr[idx1m0] += 1
    
    n = idx2m0 - idx1m0
    r1 = np.flatnonzero(id_arr.cumsum()!=0)
    r2 = np.repeat(count_valid,n)
    a_idx1 = np.repeat(r1, r2)
    
    b_idx1 = np.searchsorted(b,a[a_idx1])
    m1 = np.r_[False,a_idx1[1:] == a_idx1[:-1],False]
    idx11 = np.flatnonzero(m1[1:] != m1[:-1])
    id_arr = m1.astype(int)
    id_arr[idx11[1::2]+1] = idx11[::2]-idx11[1::2]
    b_idx1 += id_arr.cumsum()[:-1]
    return a_idx1, b_idx1

版本 #2:

def findwhere_generic_v2(a, b):    
    cidx = np.flatnonzero(np.r_[True,b[1:] != b[:-1],True])
    count = np.diff(cidx)
    b_starts = b[cidx[:-1]]
    
    idxx = np.flatnonzero(np.r_[True,a[1:] != a[:-1],True])
    av = a[idxx[:-1]]
    idxxs = np.searchsorted(av,b_starts)
    idxxs[idxxs==len(av)] = 0
    valid_mask0 = av[idxxs] == b_starts
    
    starts = idxx[idxxs]
    stops = idxx[idxxs+1]
    
    idx1m0 = starts[valid_mask0]
    idx2m0 = stops[valid_mask0]  
    
    count_valid = count[valid_mask0]
    
    id_arr = np.zeros(len(a)+1, dtype=int)
    id_arr[idx2m0] -= 1
    id_arr[idx1m0] += 1
    
    n = idx2m0 - idx1m0
    r1 = np.flatnonzero(id_arr.cumsum()!=0)
    r2 = np.repeat(count_valid,n)
    a_idx1 = np.repeat(r1, r2)
    
    b_idx1 = np.searchsorted(b,a[a_idx1])
    m1 = np.r_[False,a_idx1[1:] == a_idx1[:-1],False]
    idx11 = np.flatnonzero(m1[1:] != m1[:-1])
    id_arr = m1.astype(int)
    id_arr[idx11[1::2]+1] = idx11[::2]-idx11[1::2]
    b_idx1 += id_arr.cumsum()[:-1]
    return a_idx1, b_idx1

我本来也想这样做,但不幸的是它需要太多内存并且如果没有很多索引匹配的话会变得非常慢。这是因为它没有将数组排序的事实考虑在内。我曾希望能够使用np.searchsortednp.repeat之类的方法,但我无法理解。 - Hameer Abbasi
哇,我刚看到了您最新的编辑。能否将其通用化?我的实际数组是numpy结构化数组。也许可以使用np.uniquereturn_counts=True - Hameer Abbasi
@HameerAbbasi 我们谈论代表性样本数据吧。嗯,发布样本数据? - Divakar
@Divakar 刚刚完成了。另外,你的示例还有一个小问题,如果数组中的值很高,由于bincount操作,它仍然不会是O(m + n) - Hameer Abbasi
实际上,np.bincount([1000])是一个形状为(1001,)的数组,只有最后一个元素等于1。minlength只会使它变得更长。所以...如果我在里面放大数字(我确实这样做了),它就会崩溃,特别是当数组有7个字段之类的时候。我有99%的把握,任何你可以不转换为整数而只使用某种方式的相等性来使其工作的东西都会起作用。 - Hameer Abbasi
显示剩余9条评论

2
这是一个 O(n)-ish 解决方案(ish 是因为如果重复很长,它显然不能是 O(n))。在实践中,根据输入长度,可以通过牺牲 O(n) 并将 heapq.merge 替换为稳定的 np.argsort 来节省一些时间。目前,在 N=10^6 的情况下,它大约需要一秒钟。
代码:
import numpy as np

USE_HEAPQ = True

def sqlwhere(a, b):
    asw = np.r_[0, 1 + np.flatnonzero(a[:-1]!=a[1:]), len(a)]
    bsw = np.r_[0, 1 + np.flatnonzero(b[:-1]!=b[1:]), len(b)]
    al, bl = np.diff(asw), np.diff(bsw)
    na, nb = len(al), len(bl)
    abunq = np.r_[a[asw[:-1]], b[bsw[:-1]]]
    if USE_HEAPQ:
        from heapq import merge
        m = np.fromiter(merge(range(na), range(na, na+nb), key=abunq.__getitem__), int, na+nb)
    else:
        m = np.argsort(abunq, kind='mergesort')
    mv = abunq[m]
    midx = np.flatnonzero(mv[:-1]==mv[1:])
    ai, bi = m[midx], m[midx+1] - na
    aic = np.r_[0, np.cumsum(al[ai])]
    a_idx = np.ones((aic[-1],), dtype=int)
    a_idx[aic[:-1]] = asw[ai]
    a_idx[aic[1:-1]] -= asw[ai[:-1]] + al[ai[:-1]] - 1
    a_idx = np.repeat(np.cumsum(a_idx), np.repeat(bl[bi], al[ai]))
    bi = np.repeat(bi, al[ai])
    bic = np.r_[0, np.cumsum(bl[bi])]
    b_idx = np.ones((bic[-1],), dtype=int)
    b_idx[bic[:-1]] = bsw[bi]
    b_idx[bic[1:-1]] -= bsw[bi[:-1]] + bl[bi[:-1]] - 1
    b_idx = np.cumsum(b_idx)
    return a_idx, b_idx

def f_D(a, b):
    return np.where(np.equal.outer(a,b))

def mock_data(n):
    return np.cumsum(np.random.randint(0, 3, (2, n)), axis=1)


a = np.array([3, 4, 4, 7, 9, 9], dtype=np.uint8)
b = np.array([1, 4, 7, 9, 9], dtype=np.uint8)

# check correct
a, b = mock_data(1000)
ai0, bi0 = f_D(a, b)
ai1, bi1 = sqlwhere(a, b)
print(np.all(ai0 == ai1), np.all(bi0 == bi1))

# check fast
a, b = mock_data(1000000)
sqlwhere(a, b)

如果您为使用np.argsortUSE_HEAPQ版本提供了else子句,我将很高兴编辑我的问题并接受它作为答案。 - Hameer Abbasi
@HameerAbbasi 完成。 - Paul Panzer
我必须要给你很多信任,我编写了一个Cython版本的@SigmaPiEpsilon的答案,但它仍然无法胜过你的。 - Hameer Abbasi
@HameerAbbasi 如果你已经在使用Cython,可以尝试将heapq.merge进行Cython化。如果你硬编码n=2和key=abunq.__getitem__的话,应该不会太困难。 - Paul Panzer

1

纯Python实现

生成器推导式

使用生成器和列表推导式的另一种纯Python实现。与您的代码相比,可能更加内存高效,但可能比numpy版本慢。对于排序数组,这将更快。

def pywheregen(a, b):

    l = ((ia,ib) for ia,j in enumerate(a) for ib,k in enumerate(b) if j == k)
    a_idx,b_idx = zip(*l)
    return a_idx,b_idx

Python循环考虑数组排序

这里有一个替代版本,使用简单的Python for循环,并考虑到数组已排序,因此只检查它需要的对。

def pywhere(a, b):

    l = []
    a.sort()
    b.sort()
    match = 0
    for ia,j in enumerate(a):
        ib = match
        while ib < len(b) and j >= b[ib]:
            if j == b[ib]:
                l.append(((ia,ib)))
                if b[match] < b[ib]:
                    match = ib
            ib += 1

    a_ind,b_ind = zip(*l)

    return a_ind, b_ind

时间

我使用@Paul Panzer的mock_data()函数比较了使用findwhere()f_D()以及@Divakar的np.outer方法。 findwhere()仍然表现最佳,但考虑到它是纯Python,pywhere()表现也不错。 pywheregen()失败了,而出人意料的是f_D()花费了更长的时间。它们都在N = 10^6时失败了。由于heapq模块中的一个无关错误,我无法运行sqlwhere。

In [2]: a, b = mock_data(10000)
In [10]: %timeit -n 10 findwhere(a,b)                                     
10 loops, best of 3: 1.62 ms per loop

In [11]: %timeit -n 10 pywhere(a,b)                                       
10 loops, best of 3: 20.6 ms per loop

In [12]: %timeit pywheregen(a,b)                                          
1 loop, best of 3: 12.7 s per loop

In [13]: %timeit -n 10 f_D(a,b)                                           
10 loops, best of 3: 476 ms per loop

In [14]: a, b = mock_data(1000000)
In [15]: %timeit -n 10 findwhere(a,b)                                     
10 loops, best of 3: 109 ms per loop

In [16]: %timeit -n 10 pywhere(a,b)                                       
10 loops, best of 3: 2.51 s per loop

嘿,你的答案并不总是线性的。要使其如此,您需要进行更改。在顶部放置 ib = 0,并在 ib = match 之前放置 if b[match] == j: - Hameer Abbasi

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接