使用np.where在二维数组中查找匹配行

9

我想知道如何在2D数组中使用np.where

我有以下数组:

arr1 = np.array([[ 3.,  0.],
                 [ 3.,  1.],
                 [ 3.,  2.],
                 [ 3.,  3.],
                 [ 3.,  6.],
                 [ 3.,  5.]])

我想要找到这个数组:
arr2 = np.array([3.,0.])

但是当我使用np.where()时:

np.where(arr1 == arr2)

它返回:
(array([0, 0, 1, 2, 3, 4, 5]), array([0, 1, 0, 0, 0, 0, 0]))

我不明白这是什么意思。有人可以为我解释一下吗?


你看过 arr1 == arr2 吗? - hpaulj
1个回答

12

你可能想要所有与你的arr2相等的行:

>>> np.where(np.all(arr1 == arr2, axis=1))
(array([0], dtype=int64),)

这意味着第一行(零索引)匹配。


你的方法存在问题,即numpy会广播数组(可视化使用np.broadcast_arrays):

>>> arr1_tmp, arr2_tmp = np.broadcast_arrays(arr1, arr2)
>>> arr2_tmp
array([[ 3.,  0.],
       [ 3.,  0.],
       [ 3.,  0.],
       [ 3.,  0.],
       [ 3.,  0.],
       [ 3.,  0.]]) 

然后进行逐元素比较:

>>> arr1 == arr2
array([[ True,  True],
       [ True, False],
       [ True, False],
       [ True, False],
       [ True, False],
       [ True, False]], dtype=bool)

np.where然后为您提供每个True的坐标:

(说明:该翻译保留原句中的英文单词“True”,因为其在编程中是一种特定的数据类型,直接翻译成“真”或“真值”可能会导致误解。)

>>> np.where(arr1 == arr2)
(array([0, 0, 1, 2, 3, 4, 5], dtype=int64),
 array([0, 1, 0, 0, 0, 0, 0], dtype=int64))
#       ^---- first match (0, 0)
#          ^--- second match (0, 1)
#             ^--- third match (1, 0)
#  ...

这意味着 (0, 0)(第一行左侧项目)是第一个True,然后是 0, 1(第一行右侧项目),接着是 1, 0(第二行左侧项目),....


如果您沿第一轴使用np.all,则会得到所有完全相等的行:

>>> np.all(arr1 == arr2, axis=1)
array([ True, False, False, False, False, False], dtype=bool)

如果保留尺寸,就可以更好地进行可视化:

>>> np.all(arr1 == arr2, axis=1, keepdims=True)
array([[ True],
       [False],
       [False],
       [False],
       [False],
       [False]], dtype=bool)

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接