NumPy广播优化点积的性能

3

这是一个相对简单的操作,但在我的实际代码中重复了数百万次,如果可能的话,我希望提高其性能。

import numpy as np

# Initial data array
xx = np.random.uniform(0., 1., (3, 14, 1))
# Coefficients used to modify 'xx'
a, b, c = np.random.uniform(0., 1., 3)

# Operation on 'xx' to obtain the final array 'yy'
yy = xx[0] * a * b + xx[1] * b + xx[2] * c

最后一行是我想改进的。基本上,xx 中的每个术语都乘以一个因子(由 a、b、c 系数给出),然后将所有术语相加,以给出具有形状为 (14, 1) 的最终 yy 数组,与初始 xx 数组的形状 (3, 14, 1) 不同。可以通过 numpy 广播来实现这个目标吗?

1
将值相加仅为轴提供一个结果,因此yy.shape将为(1,14,1)。 - f5r5e5d
你是正确的 @f5r5e5d,我指的是保持 (14, 1) 的形状。我现在会修复这个问题。 - Gabriel
einsum 被认为是非常快的。 - f5r5e5d
2个回答

3
我们可以使用广播乘法,然后沿第一轴求和作为第一种选择。
作为第二种选择,我们还可以使用矩阵乘法,并使用np.dot。因此,我们有了两种更多的方法。以下是问题中提供的示例的时间记录 -
# Original one
In [81]: %timeit xx[0] * a * b + xx[1] * b + xx[2] * c
100000 loops, best of 3: 5.04 µs per loop

# Proposed alternative #1
In [82]: %timeit (xx *np.array([a*b,b,c])[:,None,None]).sum(0)
100000 loops, best of 3: 4.44 µs per loop

# Proposed alternative #2
In [83]: %timeit np.array([a*b,b,c]).dot(xx[...,0])[:,None]
1000000 loops, best of 3: 1.51 µs per loop

1
这种方法是所有方法中最快的,比我的原始代码快2.5倍。再次感谢Divakar! - Gabriel

2

这类似于Divakar的答案。交换xx的第一和第三个轴,并进行点积。

import numpy as np

# Initial data array
xx = np.random.uniform(0., 1., (3, 14, 1))
# Coefficients used to modify 'xx'
a, b, c = np.random.uniform(0., 1., 3)

def op():
    yy = xx[0] * a * b + xx[1] * b + xx[2] * c
    return yy

def tai():
    d = np.array([a*b, b, c])
    return np.swapaxes(np.swapaxes(xx, 0, 2).dot(d), 0, 1)

def Divakar():
    # improvement given by Divakar
    np.array([a*b,b,c]).dot(xx.swapaxes(0,1))

%timeit op()
7.21 µs ± 222 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit tai()
4.06 µs ± 140 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit Divakar()
3 µs ± 105 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

我之前并不知道np.swapaxes()函数,谢谢。这个方法比我的原来的方法快,但是仍然比Divakar的方法慢近两倍。 - Gabriel
@Gabriel,没错。Divakar的方法很快 :P - Tai
2
@Gabriel np.array([a*b,b,c]).dot(xx.swapaxes(0,1)) 可以使这个解决方案的性能更接近于我的。 - Divakar
@Divakar确实如此。不过你的速度仍然更快。 - Gabriel

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接