基于另一个数组中元素出现次数的情况下,NumPy数组中最高效的删除元素方法是什么?

3

假设我有两个Numpy数组:

a = np.array([1,2,2,3,3,3])
b = np.array([2,2,3])

我希望能够从a中移除所有在b中出现相同次数的元素。例如:

diff(a, b)
>>> np.array([1,3,3])

请注意,对于我的用例来说,b将始终是a的子集,两者都可能无序,但类似于集合的方法,如numpy.setdiff1d并不适用,因为重要的是要移除每个元素一定次数。
我的当前的懒惰解决方案如下:
def diff(a, b):
    for el in b:
        idx = (el == a).argmax()
        if a[idx] == el:
            a = np.delete(a, idx)
    return a

但我想知道是否有更高效或更紧凑的写法,类似于“NumPy”?


1
你可以使用 for 循环遍历并切换掩码中的位,然后将该掩码应用于 numpy 数组。这样做可能比使用 np.delete 反复重新创建数组要快得多。 - Mateen Ulhaq
无序:(1):使用np.unique和计数,然后使用该值重构数组。 (2):使用类似于枚举的想法,将每个数组项与唯一索引配对,然后取结果的集合差异。 - Mateen Ulhaq
@MateenUlhaq 可能是无序的,帖子已更新。 - Johan Dettmar
a 是否按照排序顺序排列? - Divakar
如果aarray([3, 2, 3, 3, 2, 1]),你希望使用diff输出array([3, 3, 1])还是希望得到排序后的输出array([1, 3, 3]) - Divakar
显示剩余3条评论
3个回答

3
这里提供一种基于 np.searchsorted 的矢量化方法 -
import pandas as pd

def diff_v2(a, b):
    # Get sorted orders
    sidx = a.argsort(kind='stable')
    A = a[sidx]
    
    # Get searchsorted indices per sorted order
    idx = np.searchsorted(A,b)
    
    # Get increments
    s = pd.Series(idx)
    inc = s.groupby(s).cumcount().values
    
    # Delete elemnents off traced back positions
    return np.delete(a,sidx[idx+inc])

进一步优化

让我们使用NumPy来处理按组计数部分的groupby cumcount -

# Perform groupby cumcount on sorted array
def groupby_cumcount(idx):
    mask = np.r_[False,idx[:-1]==idx[1:],False]
    ids = mask[:-1].cumsum()
    count = np.diff(np.flatnonzero(~mask))
    return ids - np.repeat(ids[~mask[:-1]],count)

def diff_v3(a, b):
    # Get sorted orders
    sidx = a.argsort(kind='stable')
    A = a[sidx]
    
    # Get searchsorted indices per sorted order
    idx = np.searchsorted(A,b)
    
    # Get increments
    idx = np.sort(idx)
    inc = groupby_cumcount(idx)
    
    # Delete elemnents off traced back positions
    return np.delete(a,sidx[idx+inc])

基准测试

使用一个设置,包含 10000 个元素,对于大小为 a 的元素和大小为 a 一半的元素 b 进行 ~2x 次重复测试。

In [52]: np.random.seed(0)
    ...: a = np.random.randint(0,5000,10000)
    ...: b = a[np.random.choice(len(a), 5000,replace=False)]

In [53]: %timeit diff(a,b)
    ...: %timeit diff_v2(a,b)
    ...: %timeit diff_v3(a,b)
108 ms ± 821 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
3.85 ms ± 53.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
1.89 ms ± 15.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

接下来,在 100000 个元素上进行 -
In [54]: np.random.seed(0)
    ...: a = np.random.randint(0,50000,100000)
    ...: b = a[np.random.choice(len(a), 50000,replace=False)]

In [55]: %timeit diff(a,b)
    ...: %timeit diff_v2(a,b)
    ...: %timeit diff_v3(a,b)
4.45 s ± 20.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
37.5 ms ± 661 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
28 ms ± 122 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

对于正数且需要排序输出

我们可以使用np.bincount-

def diff_v4(a, b):
    C = np.bincount(a)
    C -= np.bincount(b,minlength=len(C))
    return np.repeat(np.arange(len(C)), C)

敢挑战吗? - Paul Panzer
@PaulPanzer 你是在笔记本上试吗?另外,在导入 benchit 之前是否导入了 matplotlib?时间大致相近 :) - Divakar
@PaulPanzer 如果您能够导入 benchit,能否为我获取 benchit.print_specs() 的输出?这是Windows系统吗?似乎我需要进行更多的测试。 - Divakar
@PaulPanzer,我认为您需要安装qT5Agg。可以尝试使用pip安装命令pip install pyqt5进行安装。您可以试试看吗? - Divakar
我创建了一个聊天室 https://chat.stackoverflow.com/rooms/216791/benchit - Paul Panzer
显示剩余2条评论

3

这里提供一种方法,与@Divakar的方法类似但速度略快(截至撰写本文时,情况可能会有所改变…)。

import numpy as np

def pp():
    if a.dtype.kind == "i":
        small = np.iinfo(a.dtype).min
    else:
        small = -np.inf
    ba = np.concatenate([[small],b,a])
    idx = ba.argsort(kind="stable")
    aux = np.where(idx<=b.size,-1,1)
    aux = aux.cumsum()
    valid = aux==np.maximum.accumulate(aux)
    valid[0] = False
    valid[1:] &= valid[:-1]
    aux2 = np.zeros(ba.size,bool)
    aux2[idx[valid]] = True
    return ba[aux2.nonzero()]

def groupby_cumcount(idx):
    mask = np.r_[False,idx[:-1]==idx[1:],False]
    ids = mask[:-1].cumsum()
    count = np.diff(np.flatnonzero(~mask))
    return ids - np.repeat(ids[~mask[:-1]],count)

def diff_v3():
    # Get sorted orders
    sidx = a.argsort(kind='stable')
    A = a[sidx]
    
    # Get searchsorted indices per sorted order
    idx = np.searchsorted(A,b)
    
    # Get increments
    idx = np.sort(idx)
    inc = groupby_cumcount(idx)
    
    # Delete elemnents off traced back positions
    return np.delete(a,sidx[idx+inc])

np.random.seed(0)
a = np.random.randint(0,5000,10000)
b = a[np.random.choice(len(a), 5000,replace=False)]

from timeit import timeit

print(timeit(pp,number=100)*10)
print(timeit(diff_v3,number=100)*10)
print((pp() == diff_v3()).all())

np.random.seed(0)
a = np.random.randint(0,50000,100000)
b = a[np.random.choice(len(a), 50000,replace=False)]

print(timeit(pp,number=10)*100)
print(timeit(diff_v3,number=10)*100)
print((pp() == diff_v3()).all())

示例运行:

1.4644702401710674
1.6345531499246135
True
22.230969095835462
24.67835019924678
True

更新:@MateenUlhaq的“dedup_unique”的相应时间如下:
7.986748410039581
81.83312350302003

请注意,此函数生成的结果与Divakar和我生成的结果并不完全相同(至少不是显而易见的)。

2

Your method:

def dedup_reference(a, b):
    for el in b:
        idx = (el == a).argmax()
        if a[idx] == el:
            a = np.delete(a, idx)
    return a

需要进行输入排序的扫描方法:
def dedup_scan(arr, sel):
    arr.sort()
    sel.sort()
    mask = np.ones_like(arr, dtype=np.bool)
    sel_idx = 0
    for i, x in enumerate(arr):
        if sel_idx == sel.size:
            break
        if x == sel[sel_idx]:
            mask[i] = False
            sel_idx += 1
    return arr[mask]

np.unique计数方法:

def dedup_unique(arr, sel):
    d_arr = dict(zip(*np.unique(arr, return_counts=True)))
    d_sel = dict(zip(*np.unique(sel, return_counts=True)))
    d = {k: v - d_sel.get(k, 0) for k, v in d_arr.items()}
    res = np.empty(sum(d.values()), dtype=arr.dtype)
    idx = 0
    for k, count in d.items():
        res[idx:idx+count] = k
        idx += count
    return res

你也许可以通过巧妙地使用numpy集合函数(例如np.in1d)来实现与上述相同的功能,但我认为这不比使用字典更快。


以下是一种懒惰的基准测试尝试(已更新以包括@Divakar的diff_v2diff_v3方法):

>>> def timeit_ab(f, n=10):
...     cmd = f"{f}(a.copy(), b.copy())"
...     t = timeit(cmd, globals=globals(), number=n) / n
...     print("{:.4f} {}".format(t, f))

>>> array_copy = lambda x, y: None

>>> funcs = [
...     'array_copy',
...     'dedup_reference',
...     'dedup_scan',
...     'dedup_unique',
...     'diff_v2',
...     'diff_v3',
... ]

>>> def run_test(maxval, an, bn):
...     global a, b
...     a = np.random.randint(maxval, size=an)
...     b = np.random.choice(a, size=bn, replace=False)
...     for f in funcs:
...         timeit_ab(f)

>>> run_test(10**1, 10000, 5000)
0.0000 array_copy
0.0617 dedup_reference
0.0035 dedup_scan
0.0004 dedup_unique     (*)
0.0020 diff_v2
0.0009 diff_v3

>>> run_test(10**2, 10000, 5000)
0.0000 array_copy
0.0643 dedup_reference
0.0037 dedup_scan
0.0007 dedup_unique     (*)
0.0023 diff_v2
0.0013 diff_v3

>>> run_test(10**3, 10000, 5000)
0.0000 array_copy
0.0641 dedup_reference
0.0041 dedup_scan
0.0022 dedup_unique
0.0027 diff_v2
0.0016 diff_v3          (*)

>>> run_test(10**4, 10000, 5000)
0.0000 array_copy
0.0635 dedup_reference
0.0041 dedup_scan
0.0082 dedup_unique
0.0029 diff_v2
0.0015 diff_v3          (*)

>>> run_test(10**5, 10000, 5000)
0.0000 array_copy
0.0635 dedup_reference
0.0041 dedup_scan
0.0118 dedup_unique
0.0031 diff_v2
0.0016 diff_v3          (*)

>>> run_test(10**6, 10000, 5000)
0.0000 array_copy
0.0627 dedup_reference
0.0043 dedup_scan
0.0126 dedup_unique
0.0032 diff_v2
0.0016 diff_v3          (*)

要点:

  • 随着重复项数量的增加,dedup_reference 的速度显著下降。
  • 如果值范围较小,则dedup_unique 是最快的。 diff_v3 很快,并且不取决于值的范围。
  • 数组复制时间可以忽略不计。
  • 字典非常酷。

性能特征强烈依赖于数据量(未经测试)和数据的统计分布。我建议使用自己的数据测试这些方法并选择最快的方法。请注意,各种解决方案会产生不同的输出,并对输入做出不同的假设。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接