从一个数组中删除另一个数组中存在的元素,保留重复项 - NumPy / Python

14

我有两个数组A(长度为380万)和B(长度为20k)。为了最小化示例,我们以这种情况为例:

A = np.array([1,1,2,3,3,3,4,5,6,7,8,8])
B = np.array([1,2,8])

现在我希望得到的数组为:

C = np.array([3,3,3,4,5,6,7])

i.e. 如果在数组 A 中发现任何一个值也出现在数组 B 中,就从数组 A 中删除该值;否则保留它。

我想知道是否有不使用 for 循环的方法来实现这个操作,因为数组很长,循环时间较长。


B数组已排序吗? - Divakar
@Divakar:不,但如果需要的话我可以排序,这不是问题。 - Srivatsan
1
numpy.setdiff1d 不正是你所需要的吗? - der_herr_g
3个回答

20

使用 searchsorted

有了已排序的B,我们可以使用searchsorted函数 -

A[B[np.searchsorted(B,A)] !=  A]
从链接的文档中,searchsorted(a,v)方法可用于查找已排序数组a中元素的索引,以便在将相应的v中的元素插入前,可以保持a的顺序。因此,假设idx = searchsorted(B,A),并使用这些索引从B中获取对应的元素:B[idx],我们将得到一个映射版本的B,对应于A中的每个元素。因此,将这个映射版本与A进行比较,可以告诉我们对于A中的每个元素,B中是否存在匹配项。最后,通过在A中选择非匹配项来进行索引。 通用情况(B未排序): 如果B尚未按照先决条件排序,请将其排序,然后再使用提议的方法。
或者,我们可以在searchsorted中使用sorter参数。
sidx = B.argsort()
out = A[B[sidx[np.searchsorted(B,A,sorter=sidx)]] != A]

更一般化的情况(A 的值高于 B 中的值):

sidx = B.argsort()
idx = np.searchsorted(B,A,sorter=sidx)
idx[idx==len(B)] = 0
out = A[B[sidx[idx]] != A]

使用in1d/isin

我们也可以使用np.in1d,它非常直观(链接文档应该有所帮助),因为它在B中查找任何匹配A中的每个元素,然后我们可以使用反转的掩码进行布尔索引以查找不匹配的元素 -

A[~np.in1d(A,B)]

isin相同 -

A[~np.isin(A,B)]

使用 invert 标志 -

A[np.in1d(A,B,invert=True)]

A[np.isin(A,B,invert=True)]

B不一定排序时,这可以解决一般情况。


我能否获得更多关于这两种方法如何工作的信息? - Srivatsan
使用 timeit,似乎第一种方法快了约两倍。 - CIsForCookies
1
@ThePredator 添加了注释。 - Divakar
1
这是一个简洁明了的好答案,但如果A中的值范围大于B(也就是说,如果A中有比B中更大的值),那么它将不能直接使用。我会添加一个补充答案以供参考。 - vmg
@vmg 很好的观点!编辑了解决方案以处理那些情况。 - Divakar

5

我对numpy不是很熟悉,但使用集合如何:

C = set(A.flat) - set(B.flat)

编辑:根据评论,集合不能包含重复的值。

因此,另一种解决方案是使用lambda表达式:

C = np.array(list(filter(lambda x: x not in B, A)))

集合不能包含重复的值。 - kur ag

1

Divakar的回答的基础上,如果原始数组A的范围比B更广,那么就会出现“索引超出范围”的错误。请参见:

A = np.array([1,1,2,3,3,3,4,5,6,7,8,8,10,12,14])
B = np.array([1,2,8])

A[B[np.searchsorted(B,A)] !=  A]
>> IndexError: index 3 is out of bounds for axis 0 with size 3


这将发生是因为np.searchsorted会将A中的10、12和14元素分配给B中的索引3(即B中插入位置的最后一个位置)。因此,在B[np.searchsorted(B,A)]中会出现IndexError。
为了避免这种情况,一种可能的方法是:
def subset_sorted_array(A,B):
    Aa = A[np.where(A <= np.max(B))]
    Bb = (B[np.searchsorted(B,Aa)] !=  Aa)
    Bb = np.pad(Bb,(0,A.shape[0]-Aa.shape[0]), method='constant', constant_values=True)
    return A[Bb]

它的工作方式如下:
# Take only the elements in A that would be inserted in B
Aa = A[np.where(A <= np.max(B))]

# Pad the resulting filter with 'Trues' - I split this in two operations for
# easier reading
Bb = (B[np.searchsorted(B,Aa)] !=  Aa)
Bb = np.pad(Bb,(0,A.shape[0]-Aa.shape[0]),  method='constant', constant_values=True)

# Then you can filter A by Bb
A[Bb]
# For the input arrays above:
>> array([ 3,  3,  3,  4,  5,  6,  7, 10, 12, 14])

请注意,这也适用于字符串数组和其他类型之间的比较(对于所有定义了比较运算符“<=”的类型)。

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接