是否有一个“增强版”的numpy/scipy点积方法?

28

问题

我想使用numpy或scipy计算以下内容:

Y = A**T * Q * A

其中A是一个m x n的矩阵,A**TA的转置矩阵,Q是一个m x m的对角线矩阵。

由于Q是一个对角线矩阵,我只将其对角线元素存储为一个向量。

计算Y的方法

目前我能想到两种计算Y的方法:

  1. Y = np.dot(np.dot(A.T, np.diag(Q)), A)
  2. Y = np.dot(A.T * Q, A)

显然,选项2比选项1更好,因为不需要创建带有diag(Q)的实际矩阵(如果这就是numpy所做的...)
然而,这两种方法都存在缺陷,需要分配比实际所需更多的内存,因为A.T * Qnp.dot(A.T, np.diag(Q))必须与A一起存储才能计算Y

问题

在numpy/scipy中是否有一种方法可以消除额外内存的不必要分配,只需传递两个矩阵AB(在我的情况下,BA.T)以及一个加权向量Q

3个回答

26
关于原帖中的最后一句话:我不知道是否有这样的numpy/scipy方法,但是关于原帖标题中的问题(即如何提高NumPy点积性能),以下内容应该会对你有所帮助。换句话说,我的回答旨在改善大部分构成Y函数的步骤的性能。
首先,这应该比原始的NumPy点积方法给你带来明显的提升:
>>> from scipy.linalg import blas as FB
>>> vx = FB.dgemm(alpha=1., a=v1, b=v2, trans_b=True)

请注意,这两个数组v1、v2都是按照C_FORTRAN顺序排列的。
您可以通过数组的flags属性来访问NumPy数组的字节顺序,如下所示:
>>> c = NP.ones((4, 3))
>>> c.flags
      C_CONTIGUOUS : True          # refers to C-contiguous order
      F_CONTIGUOUS : False         # fortran-contiguous
      OWNDATA : True
      MASKNA : False
      OWNMASKNA : False
      WRITEABLE : True
      ALIGNED : True
      UPDATEIFCOPY : False

要改变其中一个数组的顺序,使两个数组对齐,只需调用NumPy数组构造函数,传入该数组并将适当的order标志设置为True即可。

>>> c = NP.array(c, order="F")

>>> c.flags
      C_CONTIGUOUS : False
      F_CONTIGUOUS : True
      OWNDATA : True
      MASKNA : False
      OWNMASKNA : False
      WRITEABLE : True
      ALIGNED : True
      UPDATEIFCOPY : False

您可以通过利用数组顺序对齐来进一步优化,以减少由于复制原始数组而导致的过多内存消耗。

但是为什么在传递给点乘之前要复制数组呢?

点积依赖于BLAS操作。这些操作需要以C连续顺序存储的数组--正是这个约束导致数组被复制。

另一方面,转置不会产生副本,但不幸的是结果返回为Fortran顺序:

因此,为了消除性能瓶颈,您需要消除谓词数组复制步骤;只需要以C连续顺序将两个数组传递给点乘即可做到这一点*。

因此,要计算dot(A.T., A)不需要额外复制:

>>> import scipy.linalg.blas as FB
>>> vx = FB.dgemm(alpha=1.0, a=A.T, b=A.T, trans_b=True)

总之,上面的表达式(连同谓词导入语句)可以替代点号以提供相同的功能但更好的性能。
您可以将该表达式绑定到一个函数中,如下所示:
>>> super_dot = lambda v, w: FB.dgemm(alpha=1., a=v.T, b=w.T, trans_b=True)

1
这确实是一种很好的访问BLAS例程的方式。我相信将来我可以充分利用它。 然而,仍然有一个“Q”需要在这里插入... :) - Woltan
当然,我应该在上面的回答中加上“我不知道有这样一个numpy/scipy方法,正如OP最后一句所描述的那样,但是”这样的前缀,以下是如何提高您的函数Y中包含的大多数步骤的性能;如果这让您产生了误导,我很抱歉(我的回答已经相应地进行了编辑)。 - doug
@doug 速度方面怎么样? 在我的测试中,无论矩阵的顺序如何,dgemm似乎比np.dot慢得多。 dgemm应该比numpy.dot更快吗? - user2675516
你能解释一下'trans_b=True'是什么意思吗?参考资料中没有详细的描述。 - nosense

4

我只是想在SO上发布这个内容,但这个拉取请求应该很有帮助,可以消除numpy.dot需要单独函数的需求。 https://github.com/numpy/numpy/pull/2730 这应该在numpy 1.7中可用。

与此同时,我使用上面的例子编写了一个函数,可以替换numpy dot,无论数组的顺序如何,并正确调用fblas.dgemm。 http://pastebin.com/M8TfbURi

希望这可以帮到你,


您是否有任何关于numpy 1.8中np.dot是否会复制参数的示例?或者,我该如何判断? - denis

2

numpy.einsum 是您寻找的函数:

numpy.einsum('ij, i, ik -> jk', A, Q, A)

这不需要任何额外的内存(虽然通常einsum比BLAS操作慢)。

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接