我试图理解使用不同的numba
算法实现所看到的性能差异。特别是,我期望下面的func1d
是最快的算法实现,因为它是唯一不复制数据的算法,但从我的计时结果来看,func1b
似乎是最快的。
import numpy
import numba
def func1a(data, a, b, c):
# pure numpy
return a * (1 + numpy.tanh((data / b) - c))
@numba.njit(fastmath=True)
def func1b(data, a, b, c):
new_data = a * (1 + numpy.tanh((data / b) - c))
return new_data
@numba.njit(fastmath=True)
def func1c(data, a, b, c):
new_data = numpy.empty(data.shape)
for i in range(new_data.shape[0]):
for j in range(new_data.shape[1]):
new_data[i, j] = a * (1 + numpy.tanh((data[i, j] / b) - c))
return new_data
@numba.njit(fastmath=True)
def func1d(data, a, b, c):
for i in range(data.shape[0]):
for j in range(data.shape[1]):
data[i, j] = a * (1 + numpy.tanh((data[i, j] / b) - c))
return data
用于测试内存拷贝的辅助函数
def get_data_base(arr):
"""For a given NumPy array, find the base array
that owns the actual data.
https://ipython-books.github.io/45-understanding-the-internals-of-numpy-to-avoid-unnecessary-array-copying/
"""
base = arr
while isinstance(base.base, numpy.ndarray):
base = base.base
return base
def arrays_share_data(x, y):
return get_data_base(x) is get_data_base(y)
def test_share(func):
data = data = numpy.random.randn(100, 3)
print(arrays_share_data(data, func(data, 0.5, 2.5, 2.5)))
时间
# force compiling
data = numpy.random.randn(10_000, 300)
_ = func1a(data, 0.5, 2.5, 2.5)
_ = func1b(data, 0.5, 2.5, 2.5)
_ = func1c(data, 0.5, 2.5, 2.5)
_ = func1d(data, 0.5, 2.5, 2.5)
data = numpy.random.randn(10_000, 300)
%timeit func1a(data, 0.5, 2.5, 2.5)
%timeit func1b(data, 0.5, 2.5, 2.5)
%timeit func1c(data, 0.5, 2.5, 2.5)
%timeit func1d(data, 0.5, 2.5, 2.5)
67.2 ms ± 230 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
13 ms ± 10.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
69.8 ms ± 60.4 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
69.8 ms ± 105 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
测试哪些实现可以拷贝内存
test_share(func1a)
test_share(func1b)
test_share(func1c)
test_share(func1d)
False
False
False
True
numexpr
进行比较,我之前不知道这个选项。话虽如此,我的结果似乎与@ead上面提到的一致,即如果我放弃使用tanh
,我的函数1b、1c和1d都显示出类似的性能。 - mgilbertnumpy
和numba
中的tanh
,并得到了相同的结果。但是如果他的答案正确,那么numba
版本应该更快。 - Lukas Snumpy.tanh
占用了 17.4 毫秒中的 15.7 毫秒。因此,临时数组的成本并不大,无法解释差异。正如我所说,使用哪个版本的 tanh 取决于使用的 Python 发行版和版本。 - eadtanh
的 vml 版本,这可能解释了 numba 和 numexpr 之间的差异。 - ead