经过一些性能分析,我发现矩阵乘法是性能瓶颈。出于好奇,我尝试了“显式”地编写矩阵乘法。下面是一个带有测量运行时间的代码示例。
import timeit
import numpy as np
def explicit_2x2_matrices_multiplication(
mats_a: np.ndarray, mats_b: np.ndarray
) -> np.ndarray:
matrices_multiplied = np.empty_like(mats_b)
for i in range(2):
for j in range(2):
matrices_multiplied[:, i, j] = (
mats_a[:, i, 0] * mats_b[:, 0, j] + mats_a[:, i, 1] * mats_b[:, 1, j]
)
return matrices_multiplied
matrices_a = np.random.random((1000, 2, 2))
matrices_b = np.random.random((1000, 2, 2))
assert np.allclose( # Checking that the explicit version is correct
matrices_a @ matrices_b,
explicit_2x2_matrices_multiplication(matrices_a, matrices_b),
)
print( # 1.1814142999992328 seconds
timeit.timeit(lambda: matrices_a @ matrices_b, number=10000)
)
print( # 1.1954495010013488 seconds
timeit.timeit(lambda: np.matmul(matrices_a, matrices_b), number=10000)
)
print( # 2.2304022700009227 seconds
timeit.timeit(lambda: np.einsum('lij,ljk->lik', matrices_a, matrices_b), number=10000)
)
print( # 0.19581600800120214 seconds
timeit.timeit(
lambda: explicit_2x2_matrices_multiplication(matrices_a, matrices_b),
number=10000,
)
)
根据代码测试,这个函数产生的结果与常规矩阵的__matmul__结果相同。然而不同的是速度:在我的机器上,显式表达式要快10倍。
对我来说,这是一个相当令人惊讶的结果。我本以为numpy表达式会更快,或者至少与较长的Python版本相当,而不是像我们在这里看到的那样慢一个数量级。我很好奇为什么性能差异如此巨大。
我正在运行numpy版本1.25和Python版本3.10.6。
einsum
比较。 - Reinderieneinsum
的比较。 - Reinderien