我希望在Numpy中实现类似于SQL WHERE
表达式的功能,涉及到IT技术。假设我有两个数组,如下所示:
import numpy as np
dt = np.dtype([('f1', np.uint8), ('f2', np.uint8), ('f3', np.float_)])
a = np.rec.fromarrays([[3, 4, 4, 7, 9, 9],
[1, 5, 5, 4, 2, 2],
[2.0, -4.5, -4.5, 1.3, 24.3, 24.3]], dtype=dt)
b = np.rec.fromarrays([[ 1, 4, 7, 9, 9],
[ 7, 5, 4, 2, 2],
[-3.5, -4.5, 1.3, 24.3, 24.3]], dtype=dt)
我希望能够返回原始数组的索引,以便涵盖每个可能的匹配对。此外,我想利用两个数组都已排序的事实,因此不需要最坏情况下的 O(mn) 算法。在本例中,由于 (4, 5, -4.5) 匹配,但在第一个数组中出现了两次,因此它会在结果索引中出现两次,并且由于 (9, 2, 24.3) 在两个数组中都出现了两次,因此它将出现总共 4 次。由于 (3, 1, 2.0) 在第二个数组中不存在,因此将被跳过,第二个数组中的 (1, 7, -3.5) 也是如此。该函数应适用于任何 dtype。
在这种情况下,结果可能如下:
a_idx, b_idx = match_arrays(a, b)
a_idx = np.array([1, 2, 3, 4, 4, 5, 5])
b_idx = np.array([1, 1, 2, 3, 4, 3, 4])
同样输出结果的另一个例子:
dt2 = np.dtype([('f1', np.uint8), ('f2', dt)])
a2 = np.rec.fromarrays([[3, 4, 4, 7, 9, 9], a], dtype=dt2)
b2 = np.rec.fromarrays([[1, 4, 7, 9, 9], b], dtype=dt2)
我有一个纯Python实现,但在我的使用情况下速度非常慢。我希望有更多的向量化操作。以下是我目前已经实现的代码:
def match_arrays(a, b):
len_a = len(a)
len_b = len(b)
a_idx = []
b_idx = []
i, j = 0, 0
first_matched_j = 0
while i < len_a and j < len_b:
matched = False
j = first_matched_j
while j < len_b and a[i] == b[j]:
a_idx.append(i)
b_idx.append(j)
if not matched:
matched = True
first_matched_j = j
j += 1
else:
i += 1
j = first_matched_j
while i < len_a and j < len_b and a[i] > b[j]:
j += 1
first_matched_j = j
while i < len_a and j < len_b and a[i] < b[j]:
i += 1
return np.array(a_idx), np.array(b_idx)
编辑:正如Divakar在他的答案中指出的那样,我可以使用a_idx, b_idx = np.where(np.equal.outer(a, b))
。然而,这似乎是我想要避免的最坏情况O(mn)
解决方案,通过对数组进行预排序可以避免这种情况。特别地,在没有重复项的情况下,如果能做到O(m + n)
就太好了。
编辑2:Paul Panzer的答案如果只使用Numpy,则不是O(m + n)
,但通常会更快。此外,他还提供了一个O(m + n)
的答案,所以我接受了那个答案。我很快就会使用timeit
进行性能比较。
编辑3:如承诺的那样,这里是性能测试结果:
╔════════════════╦═══════════════════╦═══════════════════╦═══════════════════╦══════════════════╦═══════════════════╗
║ User ║ Version ║ n = 10 ** 2 ║ n = 10 ** 4 ║ n = 10 ** 6 ║ n = 10 ** 8 ║
╠════════════════╬═══════════════════╬═══════════════════╬═══════════════════╬══════════════════╬═══════════════════╣
║ Paul Panzer ║ USE_HEAPQ = False ║ 115 µs ± 385 ns ║ 793 µs ± 8.43 µs ║ 105 ms ± 1.57 ms ║ 18.2 s ± 116 ms ║
║ ╠═══════════════════╬═══════════════════╬═══════════════════╬══════════════════╬═══════════════════╣
║ ║ USE_HEAPQ = True ║ 189 µs ± 3.6 µs ║ 6.38 ms ± 28.8 µs ║ 650 ms ± 2.49 ms ║ 1min 11s ± 420 ms ║
╠════════════════╬═══════════════════╬═══════════════════╬═══════════════════╬══════════════════╬═══════════════════╣
║ SigmaPiEpsilon ║ Generator ║ 936 µs ± 1.52 µs ║ 9.17 s ± 57 ms ║ N/A ║ N/A ║
║ ╠═══════════════════╬═══════════════════╬═══════════════════╬══════════════════╬═══════════════════╣
║ ║ for loop ║ 144 µs ± 526 ns ║ 15.6 ms ± 18.6 µs ║ 1.74 s ± 33.9 ms ║ N/A ║
╠════════════════╬═══════════════════╬═══════════════════╬═══════════════════╬══════════════════╬═══════════════════╣
║ Divakar ║ np.where ║ 39.1 µs ± 281 ns ║ 302 ms ± 4.49 ms ║ Out of memory ║ N/A ║
║ ╠═══════════════════╬═══════════════════╬═══════════════════╬══════════════════╬═══════════════════╣
║ ║ recarrays 1 ║ 69.9 µs ± 491 ns ║ 1.6 ms ± 24.2 µs ║ 230 ms ± 3.52 ms ║ 41.5 s ± 543 ms ║
║ ╠═══════════════════╬═══════════════════╬═══════════════════╬══════════════════╬═══════════════════╣
║ ║ recarrays 2 ║ 82.6 µs ± 1.01 µs ║ 1.4 ms ± 4.51 µs ║ 212 ms ± 2.59 ms ║ 36.7 s ± 900 ms ║
╚════════════════╩═══════════════════╩═══════════════════╩═══════════════════╩══════════════════╩═══════════════════╝
看起来Paul Panzer的答案以
USE_HEAPQ = False
获胜。 我原本以为对于大型输入,USE_HEAPQ = True
会获胜,因为它是O(m + n)
,但事实并非如此。另外一个评论是,USE_HEAPQ = False
版本使用了更少的内存,最多只有5.79GB,而USE_HEAPQ = True
则需要10.18GB,对于n=10**8
。请记住,这是进程内存,并包括控制台的输入和其他内容。Divakar的recarrays答案1使用了8.42 GB的内存,recarrays答案2使用了10.61 GB。