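A trick I learned from Jaime's excellent answer here is to use an np.void dtype to view each row of the input arrays as a single element. This lets you treat them as 1D arrays, which can then be passed to np.in1d or one of the other set routines.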
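import numpy as np

def find_overlap(A, B):
    """Return a boolean mask over the rows of A that also occur in B."""
    if A.dtype != B.dtype:
        raise TypeError("A and B must have the same dtype")
    if A.shape[1:] != B.shape[1:]:
        raise ValueError("the shapes of A and B must be identical apart from "
                         "the row dimension")

    # reshape A and B to 2D arrays; this only copies if they are not
    # already C-contiguous
    A = np.ascontiguousarray(A.reshape(A.shape[0], -1))
    B = np.ascontiguousarray(B.reshape(B.shape[0], -1))

    # view each row as a single np.void item spanning the row's bytes
    t = np.dtype((np.void, A.dtype.itemsize * A.shape[1]))

    # np.in1d flattens the (n, 1) views; NumPy 2.0 deprecates it in favor of np.isin
    return np.in1d(A.view(t), B.view(t))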
For example:
gen = np.random.RandomState(0)
A = gen.randn(1000, 28, 28)
dupe_idx = gen.choice(A.shape[0], size=200, replace=False)
B = A[dupe_idx]
A_in_B = find_overlap(A, B)
print(np.all(np.where(A_in_B)[0] == np.sort(dupe_idx)))
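To make the view trick concrete, here is a minimal sketch on a tiny integer array. The arrays `a` and `b`, the name `row_t`, and the use of `np.isin` with `ravel()` are my own illustrative choices rather than part of the function above. Each row collapses to one opaque np.void item, so the row comparison becomes an ordinary 1D membership test.

```python
import numpy as np

# Hypothetical toy data: 4 rows of 3 int64 values each.
a = np.array([[1, 2, 3],
              [4, 5, 6],
              [7, 8, 9],
              [1, 2, 4]], dtype=np.int64)
b = np.array([[7, 8, 9],
              [1, 2, 3]], dtype=np.int64)

# One void item spans the bytes of a whole row (3 * 8 = 24 bytes here).
row_t = np.dtype((np.void, a.dtype.itemsize * a.shape[1]))

a_rows = np.ascontiguousarray(a).view(row_t)
b_rows = np.ascontiguousarray(b).view(row_t)
view_shape = a_rows.shape  # one element per row, with a trailing axis of length 1

# Flatten to 1D and test membership row by row.
mask = np.isin(a_rows.ravel(), b_rows.ravel())
common_rows = a[mask]
```

Note that the comparison is bytewise, so rows count as equal only if their raw bytes match. For floating-point data this means, for instance, that 0.0 and -0.0 would be treated as different values.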
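This method is much more memory-efficient than Divakar's, since it doesn't require broadcasting out to an (m, n, ...) boolean array. In fact, if A and B are row-major (C-contiguous) then no copying is necessary at all.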
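The no-copy claim can be checked with `np.shares_memory`. The sketch below is only an illustration: `as_row_items` is a hypothetical helper that repeats the reshape-and-view steps, and the array sizes are arbitrary.

```python
import numpy as np

def as_row_items(x):
    """View each row of x as one np.void item (hypothetical helper)."""
    flat = np.ascontiguousarray(x.reshape(x.shape[0], -1))
    row_t = np.dtype((np.void, flat.dtype.itemsize * flat.shape[1]))
    return flat.view(row_t)

c_order = np.zeros((100, 28, 28))     # row-major (C-contiguous)
f_order = np.asfortranarray(c_order)  # column-major copy of the same data

# Row-major input: the void view reuses the original buffer.
shares_c = np.shares_memory(c_order, as_row_items(c_order))
# Column-major input: the reshape has to copy before the view is taken.
shares_f = np.shares_memory(f_order, as_row_items(f_order))

print(shares_c, shares_f)
```

With a row-major input the view should share its buffer with the original array, whereas a column-major input forces a copy.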
For comparison I have slightly adapted Divakar's and B. M.'s solutions.
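def divakar(A, B):
    # flatten each row; reshape avoids changing the callers' array shapes
    A = A.reshape(A.shape[0], -1)
    B = B.reshape(B.shape[0], -1)
    # broadcasts to an (nb, na, rowsize) boolean array
    return (B[:, None] == A).all(axis=2).any(0)

def bm(A, B):
    # view each row as a fixed-width byte string
    t = 'S' + str(A.size // A.shape[0] * A.dtype.itemsize)
    ma = np.frombuffer(np.ascontiguousarray(A), t)
    mb = np.frombuffer(np.ascontiguousarray(B), t)
    # compares every row of B with every row of A: an (nb, na) boolean array
    return (mb[:, None] == ma).any(0)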
Benchmarks:
In [1]: na = 1000; nb = 200; rowshape = 28, 28
In [2]: %%timeit A = gen.randn(na, *rowshape); idx = gen.choice(na, size=nb, replace=False); B = A[idx]
divakar(A, B)
....:
1 loops, best of 3: 244 ms per loop
In [3]: %%timeit A = gen.randn(na, *rowshape); idx = gen.choice(na, size=nb, replace=False); B = A[idx]
bm(A, B)
....:
100 loops, best of 3: 2.81 ms per loop
In [4]: %%timeit A = gen.randn(na, *rowshape); idx = gen.choice(na, size=nb, replace=False); B = A[idx]
find_overlap(A, B)
....:
100 loops, best of 3: 15 ms per loop
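As you can see, B. M.'s solution is faster than mine for small n (2.81 ms vs. 15 ms here), but np.in1d, which is O(n log n), scales better than testing every row against every other row for equality, which is O(n²).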
In [5]: na = 10000; nb = 2000; rowshape = 28, 28
In [6]: %%timeit A = gen.randn(na, *rowshape); idx = gen.choice(na, size=nb, replace=False); B = A[idx]
bm(A, B)
....:
1 loops, best of 3: 271 ms per loop
In [7]: %%timeit A = gen.randn(na, *rowshape); idx = gen.choice(na, size=nb, replace=False); B = A[idx]
find_overlap(A, B)
....:
10 loops, best of 3: 123 ms per loop
Divakar's solution is intractable on my laptop for arrays of this size, since it requires generating a 15GB intermediate array whereas I only have 8GB RAM.
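The 15GB figure follows from the shape of the broadcast boolean array. As a rough back-of-the-envelope check (the variable names are mine, and this assumes NumPy's one byte per boolean):

```python
na, nb = 10000, 2000
row_elems = 28 * 28   # elements per flattened row
bool_bytes = 1        # NumPy stores one byte per boolean

# (B[:, None] == A) has shape (nb, na, row_elems)
broadcast_bytes = nb * na * row_elems * bool_bytes
broadcast_gb = broadcast_bytes / 1e9

# For scale: the float64 input A itself
a_bytes = na * row_elems * 8

print(f"broadcast array: {broadcast_gb:.2f} GB, A itself: {a_bytes / 1e6:.1f} MB")
```

So the intermediate array is roughly 250 times larger than A.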
... the indices of the rows that are present in both train_dataset and val_dataset? - ali_m