基于另一个数组的元素,过滤NumPy数组的行

3

我有一个Nx2的数组group:

array([[    1,     6],
       [    1,     0],
       [    2,     1],
       ...,
       [40196, 40197],
       [40196, 40198],
       [40196, 40199]], dtype=uint32)

还有一个数组selection,其形状为(M,):

array([3216, 3217, 3218, ..., 8039]) 

我想创建一个新数组,其中包含group的所有行,这些行中的两个元素都在selection中。这是我的方法:

np.array([(i,j) for (i,j) in group if i in selection and j in selection])

这个方法可行,但我知道必须有一个更高效的方式来利用一些numpy函数。
1个回答

3
你可以使用 np.isin 函数来获取一个布尔数组,其形状与 group 相同,表示元素是否在 selection 中。然后,要检查行中的两个条目是否都在 selection 中,可以使用 all 函数和 axis=1 参数,它将返回一个一维的布尔数组,表示哪些行需要保留。最后,我们使用该数组进行索引。
group[np.isin(group, selection).all(axis=1)]

样例:

>>> group

array([[    1,     6],
       [    1,     0],
       [    2,     1],
       [40196, 40197],
       [40196, 40198],
       [40196, 40199]])

>>> selection

array([    1,     2,     3,     4,     5,     6, 40196, 40199])

>>> np.isin(group, selection)

array([[ True,  True],
       [ True, False],
       [ True,  True],
       [ True, False],
       [ True, False],
       [ True,  True]])

>>> np.isin(group, selection).all(axis=1)

array([ True, False,  True, False, False,  True])

>>> group[np.isin(group, selection).all(axis=1)]

array([[    1,     6],
       [    2,     1],
       [40196, 40199]])

1
非常好,谢谢!我用timeit比较了一下结果: 你的解决方案 4.07 ms ± 197 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) 和我的 263 ms ± 9.38 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)。 快多了! :) - jenny_wren

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接