基于一个列的交集,过滤多个NumPy数组

4

我有三个具有不同行数的相当大的NumPy数组,它们的第一列都是整数。我的希望是过滤这些数组,使得只有那些第一列的值被所有三个数组共享的行留下来。这将留下三个相同大小的数组。其他列中的条目在数组之间未必共享。

因此,输入如下:

A = 
[[1, 1],
[2, 2],
[3, 3],]

B = 
[[2, 1],
[3, 2],
[4, 3],
[5, 4]]

C = 
[[2, 2],
[3, 1]
[5, 2]]

I hope to get back as output:

A = 
[[2, 2],
[3, 3]]


B = 
[[2, 1],
[3, 2]]

C = 
[[2, 2],
[3, 1]]

我的当前方法是:

  1. 使用numpy.intersect1d()找到三个第一列的交集

  2. 在此交集和每个数组的第一列上使用numpy.in1d(),找到在每个数组中没有共享的行索引(使用此处找到的方法的修改版本将boolean转换为索引:Python: intersection indices numpy array

  3. 最后,使用numpy.delete()与每个索引及其相应的数组,删除具有非共享条目的第一列的行。

但我想知道是否有更快或更优雅的 Pythonic 方法来处理这个问题,特别是适用于非常大的数组。

3个回答

3

您的示例中索引已经排序并且唯一。假设这不是巧合(这种情况经常出现,或者可以轻松实现),以下方法可行:

import numpy as np

A = np.array(
[[1, 1],
[2, 2],
[3, 3],])

B = np.array(
[[2, 1],
[3, 2],
[4, 3],
[5, 4]])

C = np.array(
[[2, 2],
[3, 1],
[5, 2],])

I = reduce(
    lambda l,r: np.intersect1d(l,r,True),
    (i[:,0] for i in (A,B,C)))

print A[np.searchsorted(A[:,0], I)]
print B[np.searchsorted(B[:,0], I)]
print C[np.searchsorted(C[:,0], I)]

如果第一列未按排序顺序排列(但仍然是唯一的):

C = np.array(
[[9, 2],
[1,6],
[5, 1],
[2, 5],
[3, 2],])

def index_by_first_column_entry(M, keys):
    colkeys = M[:,0]
    sorter = np.argsort(colkeys)
    index = np.searchsorted(colkeys, keys, sorter = sorter)
    return M[sorter[index]]

print index_by_first_column_entry(C, I)

请确保将“true”更改为“false”

I = reduce(
    lambda l,r: np.intersect1d(l,r,False),
    (i[:,0] for i in (A,B,C)))

可以使用np.unique来实现对重复值的泛化处理。


1
一些评论:这个代码应该相当高效;主要的开销在于argsort,如果第一列尚未排序,则会有开销。如果您正在进行重复调用,则可能希望预先计算此argsort。它应该是相当高效的,但我想象中,在一个为这些操作调优的代码库中使用数据库系统仍然可以获得巨大的收益,特别是针对真正的大数据集。 - Eelco Hoogendoorn
这个方法假设所涉及的整数是唯一的,否则searchsorted只会返回每个整数的第一个实例。这是一个常见的用例,例如数据库中的主键,但你应该在答案中注意到这一点。 - Bi Rico
我曾考虑过推广重复键,但我认为这样做会更加模糊而不是更清晰。但你是对的,把它明确化是有好处的。 - Eelco Hoogendoorn
如果数组未排序,则可以使用计算量较小的算法,而不是对它们进行排序,然后在每个数组中搜索交集,而是在交集中搜索每个数组的值,并添加一个检查,即np.searchsorted找到匹配项,例如idx = np.searchsorted(I, B[:, 0]); idx[idx == len(I)] = -1; B[I[idx] == B[:, 0]]即使B未排序也可以工作。 - Jaime
我不会那么确定。searchsorted在第一个参数中具有O(log(n))的复杂度,在最后一个参数中具有O(n)的复杂度。所以,通过这种方式,你将通常的O(n*log(n))换成了O(n)(尽管基数排序或快速排序在几乎排序好的数据上也是O(n)),但假设交集只占总条目的一小部分,这很可能是真实情况,那么现在你正在为更多的点进行搜索排序。我非常怀疑这种方法在任何真实世界的场景中都不会更快。 - Eelco Hoogendoorn
如果数组没有排序,那么我概述的算法肯定计算成本较低,但这并不意味着它会运行得更快...;-) - Jaime

2

一种方法是构建一个指示器数组或哈希表,以指示哪些整数在您的所有输入数组中。然后,您可以使用基于此指示器数组的布尔索引来获取子数组。类似这样:

import numpy as np

# Setup
A = np.array(
[[1, 1],
[2, 2],
[3, 3],])

B = np.array(
[[2, 1],
[3, 2],
[4, 3],
[5, 4]])

C = np.array(
[[2, 2],
[3, 1],
[5, 2],])


def take_overlap(*input):
    n = len(input)
    maxIndex = max(array[:, 0].max() for array in input)
    indicator = np.zeros(maxIndex + 1, dtype=int)
    for array in input:
        indicator[array[:, 0]] += 1
    indicator = indicator == n

    result = []
    for array in input:
        # Look up each integer in the indicator array
        mask = indicator[array[:, 0]]
        # Use boolean indexing to get the sub array
        result.append(array[mask])

    return result

subA, subB, subC = take_overlap(A, B, C)

这种方法应该非常快,而且不假定输入数组的元素是唯一的或已排序的。但是,如果索引整数是稀疏的(即[1,10,10000]),则此方法可能需要大量内存,并且可能会慢一些,但如果整数相对密集,则应该接近最佳。


1
@Jaime,它类似于bincount但略有不同。请注意,在操作中,indicator == n指示器值等于整数出现在数组中的数量,而不是整数在所有数组中出现的总次数。创建指示器数组的等效但我认为更加复杂的代码如下:reduce(np.logical_and, (np.bincount(A[:, 0], minlength=maxIndex+1) for A in inputs)) - Bi Rico

0

这个方法是可行的,但我不确定它是否比其他答案更快:

import numpy as np

A = np.array(
[[1, 1],
[2, 2],
[3, 3],])

B = np.array(
[[2, 1],
[3, 2],
[4, 3],
[5, 4]])

C = np.array(
[[2, 2],
[3, 1],
[5, 2],])

a = A[:,0]
b = B[:,0]
c = C[:,0]

ab = np.where(a[:, np.newaxis] == b[np.newaxis, :])
bc = np.where(b[:, np.newaxis] == c[np.newaxis, :])

ab_in_bc = np.in1d(ab[1], bc[0])
bc_in_ab = np.in1d(bc[0], ab[1])

arows = ab[0][ab_in_bc]
brows = ab[1][ab_in_bc]
crows = bc[1][bc_in_ab]

anew = A[arows, :]
bnew = B[brows, :]
cnew = C[crows, :]

print(anew)
print(bnew)
print(cnew)

给出:

[[2 2]
 [3 3]]
[[2 1]
 [3 2]]
[[2 2]
 [3 1]]

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接