在3维数组中查找2维数组。

4

有没有一种快速的方法可以找到2d数组在3d数组中的所有索引?

我有一个3d numpy数组:

arr = np.array([
        [[0,1],[0,2],[0,3],[0,4],[0,4],[0,5],[0,5],[0,5],[0,5],[0,5]],
        [[0,1],[0,2],[0,2],[0,2],[0,3],[0,4],[0,4],[0,4],[0,5],[0,5]],
        [[0,1],[0,2],[0,3],[0,3],[0,3],[0,4],[0,4],[0,5],[0,5],[0,5]]
       ])

我希望找到所有出现 [0,4] 的索引。 我尝试过以下方法:

whereInd = np.argwhere(arr == np.array([0,4]))

但是它不起作用。 预期结果是:
[[0 3],[0 4],[1 5],[1 6],[1 7],[2 5],[2 6]]

另一个问题是,这样做是否会很快?因为我想用它处理一个 (10000,100,2) 的数组。

2个回答

2

使用argwhere()是个好主意,但你还需要使用all()来得到你想要的输出:

>>> np.argwhere((arr == [0, 4]).all(axis=2))
array([[0, 3],
       [0, 4],
       [1, 5],
       [1, 6],
       [1, 7],
       [2, 5],
       [2, 6]])

这里使用all()来检查每一行是否与比较结果[0, 4]相等(即每一行都是[True, True])。在3D数组中,axis=2指向行。
这将把维度数减少到二,并且argwhere()返回所需索引的数组。
关于性能,该方法应该可以快速处理您指定大小的数组:
In [20]: arr = np.random.randint(0, 10, size=(10000, 100, 2))
In [21]: %timeit np.argwhere((arr == [0, 4]).all(axis=2))
10 loops, best of 3: 44.9 ms per loop

0
我能想到的最简单的解决方案是:
import numpy as np
arr = np.array([
        [[0,1],[0,2],[0,3],[0,4],[0,4],[0,5],[0,5],[0,5],[0,5],[0,5]],
        [[0,1],[0,2],[0,2],[0,2],[0,3],[0,4],[0,4],[0,4],[0,5],[0,5]],
        [[0,1],[0,2],[0,3],[0,3],[0,3],[0,4],[0,4],[0,5],[0,5],[0,5]]
       ])

whereInd = []
for i,row in enumerate(arr):
    for j,elem in enumerate(row):
        if all(elem == [0,4]):
            whereInd.append((i,j))

print whereInd
#prints [(0, 3), (0, 4), (1, 5), (1, 6), (1, 7), (2, 5), (2, 6)]

虽然使用 np.argwhere 的任何解决方案都应该比较快,大约快10倍。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接