使用einsum计算的叉积

4
我正在尝试尽可能快地计算许多3x1向量对的叉积。这个过程可以通过以下方式实现:
n = 10000
a = np.random.rand(n, 3)
b = np.random.rand(n, 3)
numpy.cross(a, b)

这个问题已经有正确的答案了,但是受到类似问题的这个答案的启发,我认为einsum可能会对我有所帮助。我发现两种方法都可以:

eijk = np.zeros((3, 3, 3))
eijk[0, 1, 2] = eijk[1, 2, 0] = eijk[2, 0, 1] = 1
eijk[0, 2, 1] = eijk[2, 1, 0] = eijk[1, 0, 2] = -1

np.einsum('ijk,aj,ak->ai', eijk, a, b)
np.einsum('iak,ak->ai', np.einsum('ijk,aj->iak', eijk, a), b)

计算叉积的方法很多,但它们的性能令人失望:这两种方法的性能都比np.cross差很多:

%timeit np.cross(a, b)
1000 loops, best of 3: 628 µs per loop

%timeit np.einsum('ijk,aj,ak->ai', eijk, a, b)
100 loops, best of 3: 9.02 ms per loop

%timeit np.einsum('iak,ak->ai', np.einsum('ijk,aj->iak', eijk, a), b)
100 loops, best of 3: 10.6 ms per loop

有没有什么想法可以改进 einsum 函数?
2个回答

5
你可以使用np.tensordot来进行矩阵乘法,并在第一级别上丢弃一个维度,然后再使用np.einsum来丢弃另一个维度,如下所示 -
np.einsum('aik,ak->ai',np.tensordot(a,eijk,axes=([1],[1])),b)

或者,我们可以使用np.einsum执行广播的逐元素乘法,然后使用np.tensordot一次性消除两个维度,如下所示 -

np.tensordot(np.einsum('aj,ak->ajk', a, b),eijk,axes=([1,2],[1,2]))

我们可以通过引入新的轴来执行元素级别的乘法,例如a[...,None]*b[:,None],但这似乎会减慢速度。


尽管如此,这些方法相对于仅基于np.einsum的方法有很大改进,但仍无法超过np.cross

运行时间测试 -

In [26]: # Setup input arrays
    ...: n = 10000
    ...: a = np.random.rand(n, 3)
    ...: b = np.random.rand(n, 3)
    ...: 

In [27]: # Time already posted approaches
    ...: %timeit np.cross(a, b)
    ...: %timeit np.einsum('ijk,aj,ak->ai', eijk, a, b)
    ...: %timeit np.einsum('iak,ak->ai', np.einsum('ijk,aj->iak', eijk, a), b)
    ...: 
1000 loops, best of 3: 298 µs per loop
100 loops, best of 3: 5.29 ms per loop
100 loops, best of 3: 9 ms per loop

In [28]: %timeit np.einsum('aik,ak->ai',np.tensordot(a,eijk,axes=([1],[1])),b)
1000 loops, best of 3: 838 µs per loop

In [30]: %timeit np.tensordot(np.einsum('aj,ak->ajk',a,b),eijk,axes=([1,2],[1,2]))
1000 loops, best of 3: 882 µs per loop

5
einsum()的乘法运算次数比cross()多,并且在最新的NumPy版本中,cross()不会创建很多临时数组。因此,einsum()不能比cross()更快。
以下是cross()的旧代码:
x = a[1]*b[2] - a[2]*b[1]
y = a[2]*b[0] - a[0]*b[2]
z = a[0]*b[1] - a[1]*b[0]

这里是新的跨域代码:

multiply(a1, b2, out=cp0)
tmp = array(a2 * b1)
cp0 -= tmp
multiply(a2, b0, out=cp1)
multiply(a0, b2, out=tmp)
cp1 -= tmp
multiply(a0, b1, out=cp2)
multiply(a1, b0, out=tmp)
cp2 -= tmp

为加快速度,您需要使用cython或numba。

由于这些操作不涉及循环,因此Cython的优化可能并不显着。当以这种方式表达时,cross更像是一种代数运算而不是数组运算。 - hpaulj

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接