一个2x更快的方法是只使用np.count_nonzero()
,但需要按照所需的条件使用。
In [3]: arr
Out[3]:
array([[1, 2, 0, 3],
[3, 9, 0, 4]])
In [4]: np.count_nonzero(arr==0)
Out[4]: 2
In [5]:def func_cnt():
for arr in arrs:
zero_els = np.count_nonzero(arr==0)
你还可以使用
np.where()
,但速度比
np.count_nonzero()
慢。
In [6]: np.where( arr == 0)
Out[6]: (array([0, 1]), array([2, 2]))
In [7]: len(np.where( arr == 0))
Out[7]: 2
效率:(按降序排列)
In [8]: %timeit func_cnt()
10 loops, best of 3: 29.2 ms per loop
In [9]: %timeit func1()
10 loops, best of 3: 46.5 ms per loop
In [10]: %timeit func_where()
10 loops, best of 3: 61.2 ms per loop
加速器可带来更多速度提升
使用JAX并配备加速器(GPU/TPU),现在可以实现超过3个数量级的速度提升。使用JAX的另一个好处是NumPy代码几乎不需要修改即可与JAX兼容。下面是一个可复现的示例:
In [1]: import jax.numpy as jnp
In [2]: from jax import jit
In [3]: arrs = []
In [4]: for _ in range(1000):
...: arrs.append(np.random.randint(-5, 5, 10000))
In [5]: @jit
...: def func_cnt():
...: for arr in arrs:
...: zero_els = jnp.count_nonzero(arr==0)
# efficiency test
In [8]: %timeit func_cnt()
15.6 µs ± 391 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
count_nonzero
是一个非常基本的编译操作。无论您想知道零的数量还是非零的数量,您仍然必须遍历整个数组。让numpy在编译代码中完成这项工作,不用担心效率问题。 - hpauljlen(arr) - np.count_nonzero(arr)
是低效的? - juanpa.arrivillagalen(are)
只是一个简单的属性查找,对吧?它不会再次迭代数组... - juanpa.arrivillagalen(arr)
是通过函数调用进行的属性查找。纯属性查找a.size
花费的时间少了25%。 - DYZa.size
,特别是对于多维数组,而len(a)
会给出错误的答案。但我认为这不是 OP 所指的问题…… - juanpa.arrivillaga