是否有一个只计算结果对角线条目的numpy / scipy点积?

59

想象一下有两个NumPy数组:

> A, A.shape = (n,p)
> B, B.shape = (p,p)

通常情况下,p是一个较小的数字(p ≤ 200),而n可以任意大。

我正在做以下事情:

result = np.diag(A.dot(B).dot(A.T))

你可以看到,我只保留了n个对角线元素,但中间有一个(n x n)的数组,只从中保留了对角线元素。

我希望有一个像diag_dot()这样的函数,它只计算结果的对角线元素,并且不分配完整的内存。

最终结果如下:

> result = diag_dot(A.dot(B), A.T)

是否有类似的预制功能?而且能否在不需要分配中间 (n x n) 数组的情况下高效地完成此项操作?

3个回答

68

我认为我自己解决了这个问题,但是无论如何我还是会分享解决方案:

因为只想要矩阵乘法的对角线:

> Z = N.diag(X.dot(Y))

等价于X的行与Y的列的数量积的总和,前面的陈述等价于:

> Z = (X * Y.T).sum(-1)

对于原始变量,这意味着:

> result = (A.dot(B) * A).sum(-1)

如果我有错的话,请纠正我,但应该就是这样了...


22
+1 聪明的代数运算总比复杂的算法更好。 - Jaime
如果有人对numpy不熟悉,这里的重点在于X.dot(Y)运算符和*运算符之间的区别。 X.dot(Y)表示线性代数中传统的矩阵乘积,而X * Y返回X和Y条目之间的逐点乘积,因此X和Y需要具有相同的形状。 - ijuneja
虽然我喜欢你的回答,但我认为它计算了所有元素。这个问题的重点是让编译器理解我们只对对角线元素感兴趣 :) - taless

41

使用 numpy.einsum,你可以获得几乎梦想中的任何东西。在你开始掌握它之前,它基本上看起来像黑魔法...

>>> a = np.arange(15).reshape(5, 3)
>>> b = np.arange(9).reshape(3, 3)

>>> np.diag(np.dot(np.dot(a, b), a.T))
array([  60,  672, 1932, 3840, 6396])
>>> np.einsum('ij,ji->i', np.dot(a, b), a.T)
array([  60,  672, 1932, 3840, 6396])
>>> np.einsum('ij,ij->i', np.dot(a, b), a)
array([  60,  672, 1932, 3840, 6396])

编辑:你实际上可以一次性获取整个东西,这太荒谬了...

>>> np.einsum('ij,jk,ki->i', a, b, a.T)
array([  60,  672, 1932, 3840, 6396])
>>> np.einsum('ij,jk,ik->i', a, b, a)
array([  60,  672, 1932, 3840, 6396])

编辑 不过你不希望它自己解决得太多......另外,为了比较,将原帖回答添加到了问题本身中。

n, p = 10000, 200
a = np.random.rand(n, p)
b = np.random.rand(p, p)

In [2]: %timeit np.einsum('ij,jk,ki->i', a, b, a.T)
1 loops, best of 3: 1.3 s per loop

In [3]: %timeit np.einsum('ij,ij->i', np.dot(a, b), a)
10 loops, best of 3: 105 ms per loop

In [4]: %timeit np.diag(np.dot(np.dot(a, b), a.T))
1 loops, best of 3: 5.73 s per loop

In [5]: %timeit (a.dot(b) * a).sum(-1)
10 loops, best of 3: 115 ms per loop

我之前不知道这个函数,但现在肯定会了。谢谢分享! - user2051916
我相信'In [3]'依赖于“dot”是高度优化的c代码(blas?)但是这确实构建了一个潜在的大型中间数组。 - Dave
我不确定 np.einsum('ij,ij->i', np.dot(a, b), a) 的结果是什么,但它肯定与 Z = N.diag(X.dot(Y)) 不同,后者提供了点积的对角元素(在这种情况下为 [15, 54, 111])。 - Rony Armon
1
自2013年以来,已经发生了一些变化。使用optimize=True选项可以提高einsum的性能。有一个新的面向批处理的matmul可以通过(A[:,None,:]@B@A[:,:,None])[:,0,0]来解决这个问题。但是(A.dot(B) * A).sum(-1)仍然是一种好的方法。 - hpaulj

2
一个避免构建大型中间数组的简单方法是:
result=np.empty([n,], dtype=A.dtype )
for i in xrange(n):
    result[i]=A[i,:].dot(B).dot(A[i,:])

[n.] 不是有效的 Python 语法。您是否想表达 A.shape - wkschwartz
@wkschwartz已修复。不,只是预先分配结果数组。 - Dave

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接