用Python(numpy)快速乘积矩阵数组的最佳方法

6

我有两个2x2复数矩阵的数组,想知道最快的方法是什么。 (我想对矩阵数组的元素进行矩阵乘法运算。)目前,我的代码如下:

numpy.array(map(lambda i: numpy.dot(m1[i], m2[i]), range(l)))

但是有没有更好的方法呢?

谢谢,

v923z

5个回答

4

numpy.einsum是解决此问题的最佳方案,它在DaveP参考文献的底部提到。代码清晰易懂,并且比循环遍历数组并逐个进行乘法运算要快一个数量级。以下是一些示例代码:

import numpy
l = 100

m1 = rand(l,2,2)
m2 = rand(l,2,2)

m3 = numpy.array(map(lambda i: numpy.dot(m1[i], m2[i]), range(l)))
m3e = numpy.einsum('lij,ljk->lik', m1, m2)

%timeit numpy.array(map(lambda i: numpy.dot(m1[i], m2[i]), range(l)))
%timeit numpy.einsum('lij,ljk->lik', m1, m2)

print np.all(m3==m3e)

在ipython笔记本中运行时的返回值如下:
每次循环1000次,3次中最好的结果是每次479微秒
每次循环10000次,3次中最好的结果是每次48.9微秒
True


为什么使用循环比手写更快呢? - dfrankow
啊,我想我找到答案了:numpy使用Atlas/BLAS。 - dfrankow
请参见 https://dev59.com/kmsy5IYBdhLWcg3wsADQ#8385745 以获取更多关于此的猜测。 - dfrankow

2
我认为你在寻找的答案在这里不幸的是,这是一个相当混乱的解决方案,涉及到重新塑形。

0
如果m1m2是2x2复矩阵的一维数组,那么它们的形状基本上是(l,2,2)。因此,在最后两个轴上进行矩阵乘法等价于将m1的最后一个轴与m2的倒数第二个轴的乘积相加。这正是np.dot所做的事情:
np.dot(m1,m2)

或者,由于您有复杂矩阵,也许您想先对 m1 取共轭。在这种情况下,请使用 np.vdot

附:如果 m1 是一个 2x2 复杂矩阵的列表,则可以尝试重排代码,使 m1 从一开始就成为形状为 (l,2,2) 的数组。

如果不可能,可以使用列表推导式。

[np.dot(m1[i],m2[i]) for i in range(l)]

使用lambdamap比这种方法更快,但是执行lnp.dot会比上面建议的在两个形状为(l,2,2)的数组上执行一个np.dot要慢。


0
如果m1和m2是1维数组,每个元素都是2x2的复矩阵,则它们的形状基本上为(l,2,2)。因此,在最后两个轴上进行矩阵乘法等价于将m1的最后一个轴与m2的倒数第二个轴相乘,并求和。这正是np.dot所做的。但那不是np.dot所做的。
 a = numpy.array([numpy.diag([1, 2]), numpy.diag([2, 3]), numpy.diag([3, 4])])

创建一个(3,2,2)的数组,其中包含2x2的矩阵。但是,numpy.dot(a,a)会创建6个矩阵,而结果的形状为(3,2,3,2)。这不是我需要的。我需要的是一个包含numpy.dot(a[0],a[0]), numpy.dot(a[1],a[1]), numpy.dot(a[2],a[2])...的数组。

[np.dot(m1[i],m2[i]) for i in range(l)]

应该是可以正常工作,但我还没有验证它是否比 lambda 表达式的映射更快。

祝好,

v923z

编辑:for 循环和 map 运行速度大约相同。消耗大量时间的是转换为 numpy.array 的强制类型转换,但这对于两种方法都必须进行,因此在这里没有任何收益。


0

也许这是一个过时的问题,但我仍在寻找答案。

我尝试了这段代码

a=np.asarray(range(1048576),dtype='complex');b=np.reshape(a//1024,(1024,1024));b=b+1J*b
%timeit c=np.dot(b,b)
%timeit d=np.einsum('ij, ki -> jk', b,b).T

结果是:对于“点”

10 loops, best of 3: 174 ms per loop

关于 'einsum'

1 loops, best of 3: 4.51 s per loop

我已经检查过 c 和 d 是相同的

(c==d).all()
True

仍然是“点”赢了,我仍在寻找更好的方法,但没有成功


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接