检查单个元素是否包含在Numpy数组中

11

我只想快速检查一个numpy数组是否包含单个数字,类似于列表的contains。是否有一种简洁的方法来做到这一点?

a = np.array(9,2,7,0)
a.contains(0)  == true

3
可能是 https://dev59.com/ZWw05IYBdhLWcg3wszwl 的重复问题。 - abagshaw
np.array(9,2,7,0)会引发错误。 - hpaulj
4个回答

19

你可以在a中使用0,即

a = np.array([9,2,7,0])
0 in a

6

如果a是一个NumPy数组:

a = np.array([1, 2])

然后使用:

1 in a

如果返回 true,则表示:

0 in a

返回false


5
我使用Python 3.7计时了一些处理方法:
import numpy as np
rnd = np.random.RandomState(42)
one_d = rnd.randint(100, size=10000)
n_d = rnd.randint(100, size=(10000, 10000))
searched = 42

# One dimension
%timeit if np.isin(one_d, searched, assume_unique=True).any(): pass
%timeit if np.in1d(one_d, searched, assume_unique=True).any(): pass
%timeit if searched in one_d: pass
%timeit if one_d[np.searchsorted(one_d, searched)] == searched: pass
%timeit if np.count_nonzero(one_d == searched): pass

print("------------------------------------------------------------------")

# N dimensions
%timeit if np.isin(n_d, searched, assume_unique=True).any(): pass
%timeit if np.in1d(n_d, searched, assume_unique=True).any(): pass
%timeit if searched in n_d: pass
%timeit if np.count_nonzero(n_d == searched): pass

>>> 42.8 µs ± 79.3 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
>>> 38.6 µs ± 76.2 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
>>> 16.4 µs ± 57.3 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
>>> 4.7 µs ± 62.7 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
>>> 12.1 µs ± 69.1 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
>>> ------------------------------------------------------------------
>>> 239 ms ± 1.04 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
>>> 241 ms ± 1.17 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
>>> 156 ms ± 2.78 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
>>> 163 ms ± 527 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

对于一维数组,最快的方法是使用上面提出的np.searchsorted,但是它不能用于多维数组。另外,np.count_nonzero是最快的方法,但它并不比Pythonic的in更快,因此建议使用in


有时,在一维数组中,将 one_d 转换为 set_one_d = set(one_d.tolist()),然后检查 searched in set_one_d 会更快,如果初始开销在您的用例中是可以接受的话。对于小的一维数组,转换为集合的开销非常低,甚至比 np.searchsorted() 更快。 - Larry Panozzo

-1
x = 0
if x in a:
   print 'find'

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接