为什么 np.linalg.norm(..., axis=1) 比手写向量范数公式慢?

3

为了将矩阵 X 的行归一化为单位长度,我通常使用以下方法:

X /= np.linalg.norm(X, axis=1, keepdims=True)

尝试为算法优化此操作时,我很惊讶地发现在我的机器上写出规范化大约快了40%。
X /= np.sqrt(X[:,0]**2+X[:,1]**2+X[:,2]**2)[:,np.newaxis]
X /= np.sqrt(sum(X[:,i]**2 for i in range(X.shape[1])))[:,np.newaxis]

为什么会出现这种情况?在 np.linalg.norm() 中性能损失在哪里?

import numpy as np
X = np.random.randn(10000,3)

%timeit X/np.linalg.norm(X,axis=1, keepdims=True)
# 276 µs ± 4.55 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%timeit X/np.sqrt(X[:,0]**2+X[:,1]**2+X[:,2]**2)[:,np.newaxis]
# 169 µs ± 1.38 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

%timeit X/np.sqrt(sum(X[:,i]**2 for i in range(X.shape[1])))[:,np.newaxis]
# 185 µs ± 4.17 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

我观察到这个问题在MacbookPro 2015上使用OpenBLAS支持的(1) python3.6 + numpy v1.17.2(2) python3.9 + numpy v1.19.3中存在。
我认为这不是this post的重复,那篇文章涉及矩阵范数,而这篇文章是关于向量的L2范数。

3
看源代码,它在幕后处理了很多东西,这很可能就是原因。一种快速的“检查”方法是复制粘贴源代码,但删除所有不适用的废话,然后再次运行测试。 - IanQ
你可以尝试逐行分析性能。https://dev59.com/d2865IYBdhLWcg3wNL07#3927671 - Trilarion
@IanQuah 这就是消耗大部分时间的这一行代码 - normanius
1个回答

4

计算按行L2范数的源代码如下所示:

def norm(x, keepdims=False):
    x = np.asarray(x)
    s = x**2
    return np.sqrt(s.sum(axis=(1,), keepdims=keepdims))

简化后的代码假设实值x,并利用了np.add.reduce(s, ...)等价于s.sum(...)的事实。
因此,OP问题就是询问为什么np.sum(x,axis=1)sum(x[:,i] for i in range(x.shape[1]))慢:
%timeit X.sum(axis=1, keepdims=False)
# 131 µs ± 1.6 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit sum(X[:,i] for i in range(X.shape[1]))
# 36.7 µs ± 91.2 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

这个问题已经在这里得到了解答。简单来说,降维(.sum(axis=1))会带来一定的开销,通常可以换取浮点精度和速度(如缓存机制、并行性),但是在只对三列进行降维的特殊情况下不会。在这种情况下,相对于实际计算,开销比较大。

如果X有更多的列,情况就会改变。使用numpy加速的标准化现在比使用Python for循环进行降维要快得多:

X = np.random.randn(10000,100)
%timeit X/np.linalg.norm(X,axis=1, keepdims=True)
# 3.36 ms ± 132 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit X/np.sqrt(sum(X[:,i]**2 for i in range(X.shape[1])))[:,np.newaxis]
# 5.92 ms ± 168 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

这里可以找到另一个相关的SO主题: numpy ufuncs vs. for loop

问题在于为什么numpy没有明确处理常见的约简特殊情况(例如对低轴维度矩阵的列或行求和),可能是因为这种优化的效果往往强烈依赖于目标机器,并且会显著增加代码复杂性。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接