numba速度变慢是很奇怪的。
其实不太奇怪。当你在numba函数中调用NumPy函数时,你调用的是这些函数的numba版本。这些版本可能比NumPy版本更快、更慢或者速度相同。你可能会运气好,也可能会运气差(你就是运气不太好!)。但即使在numba函数中,你仍然会创建大量的临时数组,因为你使用了NumPy函数(一个用于点积结果、一个用于每个平方和、一个用于点积加上第一个和),所以你没有充分利用numba的可能性。
我用错了吗?
本质上:是的。
我真的需要加速。
好的,我会试一试。
让我们先从展开沿轴1的平方和开始:
import numba as nb
@nb.njit
def sum_squares_2d_array_along_axis1(arr):
res = np.empty(arr.shape[0], dtype=arr.dtype)
for o_idx in range(arr.shape[0]):
sum_ = 0
for i_idx in range(arr.shape[1]):
sum_ += arr[o_idx, i_idx] * arr[o_idx, i_idx]
res[o_idx] = sum_
return res
@nb.njit
def euclidean_distance_square_numba_v1(x1, x2):
return -2 * np.dot(x1, x2.T) + np.expand_dims(sum_squares_2d_array_along_axis1(x1), axis=1) + sum_squares_2d_array_along_axis1(x2)
在我的电脑上,这个代码比 NumPy 代码快了 2 倍,比你原来的 Numba 代码快了近 10 倍。
根据经验,将其加速到比 NumPy 快 2 倍通常是极限(至少如果 NumPy 版本不是无谓地复杂或低效),但是您可以通过展开所有内容来挤出更多性能:
import numba as nb
@nb.njit
def euclidean_distance_square_numba_v2(x1, x2):
f1 = 0.
for i_idx in range(x1.shape[1]):
f1 += x1[0, i_idx] * x1[0, i_idx]
res = np.empty(x2.shape[0], dtype=x2.dtype)
for o_idx in range(x2.shape[0]):
val = 0
for i_idx in range(x2.shape[1]):
val_from_x2 = x2[o_idx, i_idx]
val += (-2) * x1[0, i_idx] * val_from_x2 + val_from_x2 * val_from_x2
val += f1
res[o_idx] = val
return res
但这仅对最新方法有约10-20%的改进。
此时,您可能会意识到可以简化代码(尽管这可能不会加速它):
import numba as nb
@nb.njit
def euclidean_distance_square_numba_v3(x1, x2):
res = np.empty(x2.shape[0], dtype=x2.dtype)
for o_idx in range(x2.shape[0]):
val = 0
for i_idx in range(x2.shape[1]):
tmp = x1[0, i_idx] - x2[o_idx, i_idx]
val += tmp * tmp
res[o_idx] = val
return res
是的,看起来相当简单,而且并不会更慢。
然而,在所有兴奋中,我忘了提到显而易见的解决方案:scipy.spatial.distance.cdist
它有一个sqeuclidean
(平方欧几里得距离)选项:
from scipy.spatial import distance
distance.cdist(x1, x2, metric='sqeuclidean')
它并不比numba更快,但可用性更高,无需编写自己的函数...
测试
进行正确性测试和热身:
x1 = np.array([[1.,2,3]])
x2 = np.array([[1.,2,3], [2,3,4], [3,4,5], [4,5,6], [5,6,7]])
res1 = euclidean_distance_square(x1, x2)
res2 = euclidean_distance_square_numba_original(x1, x2)
res3 = euclidean_distance_square_numba_v1(x1, x2)
res4 = euclidean_distance_square_numba_v2(x1, x2)
res5 = euclidean_distance_square_numba_v3(x1, x2)
np.testing.assert_array_equal(res1, res2)
np.testing.assert_array_equal(res1, res3)
np.testing.assert_array_equal(res1[0], res4)
np.testing.assert_array_equal(res1[0], res5)
np.testing.assert_almost_equal(res1, distance.cdist(x1, x2, metric='sqeuclidean'))
时间:
x1 = np.random.random((1, 512))
x2 = np.random.random((1000000, 512))
# 2.09 s ± 54.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# 10.9 s ± 158 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# 907 ms ± 7.11 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# 715 ms ± 15 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# 731 ms ± 34.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# 706 ms ± 4.99 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
注意:如果您有整数数组,您可能希望将numba函数中硬编码的0.0
更改为0
。
numpy
使用 OpenCL 利用 GPGPU)。 - Basile Starynkevitch