Python中检查两个高维数组是否重叠的有效方法

4
例如,我有两个ndarrays,train_dataset的形状为(10000, 28, 28),而val_dateset的形状为(2000, 28, 28)。除了使用循环迭代的方法外,是否有任何有效的方式可以使用numpy数组函数来查找两个ndarrays之间的重叠部分?

1
你能解释一下“overlap”具体是什么意思吗?你是在寻找train_datasetval_dataset中都存在的行的索引吗? - ali_m
1
好的,我想找出在这两个数据集中都出现的元素(28*28)。 - Vicky
如果你想要创建训练和验证数据集,最好使用scikit-learn的cross_validation模块 - Praveen
5个回答

4

我从Jaime在这里的出色回答中学到的一个技巧是使用np.void dtype来将输入数组的每一行视为单个元素。这允许您将它们作为1D数组处理,然后可以将它们传递给np.in1d或其他集合例程

import numpy as np

def find_overlap(A, B):

    if not A.dtype == B.dtype:
        raise TypeError("A and B must have the same dtype")
    if not A.shape[1:] == B.shape[1:]:
        raise ValueError("the shapes of A and B must be identical apart from "
                         "the row dimension")

    # reshape A and B to 2D arrays. force a copy if neccessary in order to
    # ensure that they are C-contiguous.
    A = np.ascontiguousarray(A.reshape(A.shape[0], -1))
    B = np.ascontiguousarray(B.reshape(B.shape[0], -1))

    # void type that views each row in A and B as a single item
    t = np.dtype((np.void, A.dtype.itemsize * A.shape[1]))

    # use in1d to find rows in A that are also in B
    return np.in1d(A.view(t), B.view(t))

例如:

gen = np.random.RandomState(0)

A = gen.randn(1000, 28, 28)
dupe_idx = gen.choice(A.shape[0], size=200, replace=False)
B = A[dupe_idx]

A_in_B = find_overlap(A, B)

print(np.all(np.where(A_in_B)[0] == np.sort(dupe_idx)))
# True

这种方法比Divakar的方法更加节省内存,因为它不需要广播到一个(m, n, ...)布尔数组上。实际上,如果AB是按行存储的,则根本不需要复制。


为了比较,我稍微修改了Divakar和B. M.的解决方案。

def divakar(A, B):
    A.shape = A.shape[0], -1
    B.shape = B.shape[0], -1
    return (B[:,None] == A).all(axis=(2)).any(0)

def bm(A, B):
    t = 'S' + str(A.size // A.shape[0] * A.dtype.itemsize)
    ma = np.frombuffer(np.ascontiguousarray(A), t)
    mb = np.frombuffer(np.ascontiguousarray(B), t)
    return (mb[:, None] == ma).any(0)

基准测试:

In [1]: na = 1000; nb = 200; rowshape = 28, 28

In [2]: %%timeit A = gen.randn(na, *rowshape); idx = gen.choice(na, size=nb, replace=False); B = A[idx]
divakar(A, B)
   ....: 
1 loops, best of 3: 244 ms per loop

In [3]: %%timeit A = gen.randn(na, *rowshape); idx = gen.choice(na, size=nb, replace=False); B = A[idx]
bm(A, B)
   ....: 
100 loops, best of 3: 2.81 ms per loop

In [4]: %%timeit A = gen.randn(na, *rowshape); idx = gen.choice(na, size=nb, replace=False); B = A[idx]
find_overlap(A, B)
   ....: 
100 loops, best of 3: 15 ms per loop

可以看出,对于小的n,B. M.的解决方案比我的略快一些,但是使用np.in1d进行元素相等性测试比测试所有元素的相等性(O(n²)复杂度)更好地扩展(O(n log n)复杂度)。

In [5]: na = 10000; nb = 2000; rowshape = 28, 28

In [6]: %%timeit A = gen.randn(na, *rowshape); idx = gen.choice(na, size=nb, replace=False); B = A[idx]
bm(A, B)
   ....: 
1 loops, best of 3: 271 ms per loop

In [7]: %%timeit A = gen.randn(na, *rowshape); idx = gen.choice(na, size=nb, replace=False); B = A[idx]
find_overlap(A, B)
   ....: 
10 loops, best of 3: 123 ms per loop
对于这个大小的数组,Divakar的解决方案在我的笔记本电脑上是不可行的,因为它需要生成一个15GB的中间数组,而我只有8GB的RAM。

谢谢,这个解决方案真的很有帮助。虽然我正在尝试更好地理解它。使用np.ascontiguousarray(A.reshape(A.shape[0], -1))而不是我看到其他代码使用的np.array([x.flatten() for x in A])的原因是什么?这是一种风格问题,还是它们执行不同的操作? - Barker
1
@Barker 请尝试计时针对一个相当大的输入数组的这两行代码。首先,列表推导式几乎肯定比单个调用reshape慢,特别是如果A中有很多行的话。此外,.flatten()总是返回一个副本(而.reshape().ravel()仅在必要时才返回副本),因此您正在创建A中每一行的临时副本,然后在列表上调用np.array(...)时又创建了另一个副本。np.ascontiguousarray仅在必要时返回副本,因此我的代码最多只会创建A的一个副本。 - ali_m
这是此页面上唯一适用于大型数据集的解决方案。它应该有更多的赞! - Per Quested Aronsson

3

如果内存允许,您可以使用广播,像这样 -

val_dateset[(train_dataset[:,None] == val_dateset).all(axis=(2,3)).any(0)]

运行示例 -

In [55]: train_dataset
Out[55]: 
array([[[1, 1],
        [1, 1]],

       [[1, 0],
        [0, 0]],

       [[0, 0],
        [0, 1]],

       [[0, 1],
        [0, 0]],

       [[1, 1],
        [1, 0]]])

In [56]: val_dateset
Out[56]: 
array([[[0, 1],
        [1, 0]],

       [[1, 1],
        [1, 1]],

       [[0, 0],
        [0, 1]]])

In [57]: val_dateset[(train_dataset[:,None] == val_dateset).all(axis=(2,3)).any(0)]
Out[57]: 
array([[[1, 1],
        [1, 1]],

       [[0, 0],
        [0, 1]]])

如果元素是整数,您可以将输入数组中的每个axis=(1,2)块折叠成标量,将它们视为可线性索引的数字,然后有效地使用np.in1dnp.intersect1d查找匹配项。

3
这段文本的翻译是:“完全广播在这里生成一个大小为10000*2000*28*28=150兆布尔数组。为了提高效率,你可以:”
  • pack data, for a 200 ko array:

    from pylab import *
    N=10000
    a=rand(N,28,28)
    b=a[[randint(0,N,N//5)]]
    
    packedtype='S'+ str(a.size//a.shape[0]*a.dtype.itemsize) # 'S6272' 
    ma=frombuffer(a,packedtype)  # ma.shape=10000
    mb=frombuffer(b,packedtype)  # mb.shape=2000
    
    %timeit a[:,None]==b   : 102 s
    %timeit ma[:,None]==mb   : 800 ms
    allclose((a[:,None]==b).all((2,3)),(ma[:,None]==mb)) : True
    

    less memory is helped here by lazy string comparison, breaking at first difference :

    In [31]: %timeit a[:100]==b[:100]
    10000 loops, best of 3: 175 µs per loop
    
    In [32]: %timeit a[:100]==a[:100]
    10000 loops, best of 3: 133 µs per loop
    
    In [34]: %timeit ma[:100]==mb[:100]
    100000 loops, best of 3: 7.55 µs per loop
    
    In [35]: %timeit ma[:100]==ma[:100]
    10000 loops, best of 3: 156 µs per loop
    
这里提供的解决方案是使用(ma[:,None]==mb).nonzero()
  • use in1d, for a (Na+Nb) ln(Na+Nb) complexity, against Na*Nb on full comparison :

    %timeit in1d(ma,mb).nonzero()  : 590ms 
    
这里的收益不是很大,但渐进地更好。

太棒了!从来不知道可以用这种方式在字符串中实现短路。但是也许应该在其中添加一个 ascontiguousarray,这样它才能正确地处理像 a = rand(28,28,N).T 这样的数组。 - user2379410
@morningsun:你说得对。在这种情况下,我只是创建了数组,所以它是连续的,但对于外部数据,“ma=frombuffer(ascontiguousarray(a),packedtype)”更安全。谢谢。 - B. M.

1

这个问题来自谷歌的在线深度学习课程吗?以下是我的解决方案:

sum = 0 # number of overlapping rows
for i in range(val_dataset.shape[0]): # iterate over all rows of val_dataset
    overlap = (train_dataset == val_dataset[i,:,:]).all(axis=1).all(axis=1).sum()
    if overlap:
        sum += 1
print(sum)

自动广播被用来代替迭代。您可以测试性能差异。

1
解决方案
def overlap(a,b):
    """
    returns a boolean index array for input array b representing
    elements in b that are also found in a
    """
    a.repeat(b.shape[0],axis=0)
    b.repeat(a.shape[0],axis=0)
    c = aa == bb
    c = c[::a.shape[0]]
    return c.all(axis=1)[:,0]

你可以使用返回的索引数组来索引b,以提取在a中也找到的元素。
b[overlap(a,b)]

说明

为了简化起见,我假设您已经在此示例中导入了numpy中的所有内容:

from numpy import *

例如,给定两个ndarrays:

a = arange(4*2*2).reshape(4,2,2)
b = arange(3*2*2).reshape(3,2,2)

我们重复 ab 以使它们具有相同的形状。
aa = a.repeat(b.shape[0],axis=0)
bb = b.repeat(a.shape[0],axis=0)

我们可以简单地比较aabb的元素。
c = aa == bb

最后,通过查看 c 的每个第四个元素,或者实际上是每个 shape(a)[0] 个元素,获取在 b 中也找到的元素的索引。
cc == c[::a.shape[0]]

最后,我们提取一个索引数组,其中仅包含子数组中所有元素都为 True 的元素。
c.all(axis=1)[:,0]

在我们的例子中,我们得到
array([True,  True,  True], dtype=bool)

检查一下,更改b的第一个元素。

b[0] = array([[50,60],[70,80]])

and we get

array([False,  True,  True], dtype=bool)

1
Divakar的解决方案实际上更加清晰简洁,点赞。 - Jan Christoph Terasa
我建议不要通过 from numpy import * 的方式污染你的命名空间。 - ali_m

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接