在 Python 中,有没有一种高效的方法来获取位数组中第 n 个位置上的布尔值数组?
- 创建一个值为 0 或 1 的 NumPy 数组:
import numpy as np
array = np.array(
[
[1, 0, 1],
[1, 1, 1],
[0, 0, 1],
]
)
- 使用 np.packbits 压缩大小:
pack_array = np.packbits(array, axis=1)
- 期望结果 - 一个函数,可以从位数组中获取第 n 列的所有值。例如,如果我想要第二列,我希望获得与调用 array[:,1] 相同的结果:
array([0, 1, 0])
我尝试了以下函数的numba版本,虽然结果正确但速度非常慢:
import numpy as np
from numba import njit
@njit(nopython=True, fastmath=True)
def getVector(packed, j):
n = packed.shape[0]
res = np.zeros(n, dtype=np.int32)
for i in range(n):
res[i] = bool(packed[i, j//8] & (128>>(j%8)))
return res
如何进行测试?
import numpy as np
import time
from numba import njit
array = np.random.choice(a=[False, True], size=(100000000,15))
pack_array = np.packbits(array, axis=1)
start = time.time()
array[:,10]
print('np array')
print(time.time()-start)
@njit(nopython=True, fastmath=True)
def getVector(packed, j):
n = packed.shape[0]
res = np.zeros(n, dtype=np.int32)
for i in range(n):
res[i] = bool(packed[i, j//8] & (128>>(j%8)))
return res
# To initialize
getVector(pack_array, 10)
start = time.time()
getVector(pack_array, 10)
print('getVector')
print(time.time()-start)
它返回:
np array
0.00010132789611816406
getVector
0.15648770332336426
j//8
和128>>(j%8)
,创建一个np.empty
(使用dtype=np.bool
?)作为res。但这些只是微小的优化,可能已经被编译器完成了。 - Michael Szczesnynumpy
的方法是O(1),你不能将其作为基准。它只返回一个调整过步幅的视图,没有任何计算。计时结果应该由print
调用主导。 - Michael SzczesnyLLVM
在循环内部没有对明显的常量进行优化(在我的机器上)。将它们移到循环外部后,速度提高了约3.5倍。 - Michael Szczesny