将大型NumPy乘法向量化

5

我希望能计算一个大的NumPy数组。我有一个大数组 A, 包含了许多数字。我想计算这些数字不同组合的总和。数据的结构如下:

A = np.random.uniform(0,1, (3743, 1388, 3))
Combinations = np.random.randint(0,3, (306,3))
Final_Product = np.array([  np.sum( A*cb, axis=2)  for cb in Combinations])

我的问题是是否有一种更优雅和内存效率更高的方法来计算这个问题?当涉及到3-D数组时,我发现使用np.dot()令人沮丧。

如果可以的话,Final_Product的形状理想上应该是(3743, 306, 1388)。目前,Final_Product的形状为(306, 3743, 1388),所以我只需要重新调整形状即可。

2个回答

5

np.dot() 不会给你期望的输出,除非你涉及额外的步骤,这可能包括 reshaping。这里有一种 向量化 方法,使用 np.einsum 一次性完成而不需要任何额外的内存开销 -

Final_Product = np.einsum('ijk,lk->lij',A,Combinations)

为了完整起见,这里使用之前讨论过的np.dotreshaping进行解释:

M,N,R = A.shape
Final_Product = A.reshape(-1,R).dot(Combinations.T).T.reshape(-1,M,N)

运行时测试并验证输出 -

In [138]: # Inputs ( smaller version of those listed in question )
     ...: A = np.random.uniform(0,1, (374, 138, 3))
     ...: Combinations = np.random.randint(0,3, (30,3))
     ...: 

In [139]: %timeit np.array([  np.sum( A*cb, axis=2)  for cb in Combinations])
1 loops, best of 3: 324 ms per loop

In [140]: %timeit np.einsum('ijk,lk->lij',A,Combinations)
10 loops, best of 3: 32 ms per loop

In [141]: M,N,R = A.shape

In [142]: %timeit A.reshape(-1,R).dot(Combinations.T).T.reshape(-1,M,N)
100 loops, best of 3: 15.6 ms per loop

In [143]: Final_Product =np.array([np.sum( A*cb, axis=2)  for cb in Combinations])
     ...: Final_Product2 = np.einsum('ijk,lk->lij',A,Combinations)
     ...: M,N,R = A.shape
     ...: Final_Product3 = A.reshape(-1,R).dot(Combinations.T).T.reshape(-1,M,N)
     ...: 

In [144]: print np.allclose(Final_Product,Final_Product2)
True

In [145]: print np.allclose(Final_Product,Final_Product3)
True

谢谢!我也发现@ajcr的答案非常有帮助。使用张量,我将np.einsum所用的时间减少了一半。 - Julien
@Julien 我也喜欢ajcr的解决方案!我认为这是dot在这里所做的简洁版本。 - Divakar

5

你可以使用tensordot代替dot。你当前的方法相当于:

np.tensordot(A, Combinations, [2, 1]).transpose(2, 0, 1)

注意结尾的 transpose 来将轴按正确顺序排列。
类似于 dottensordot 函数可以调用快速的 BLAS/LAPACK 库 (如果已安装),因此对于大型数组应该具有良好的性能表现。

简短而简单,我喜欢它! - Divakar
@Divakar:谢谢!不过我还是更喜欢einsum :-) - Alex Riley
我也有同感!!当einsum输出一个3D数组时,它不如将其缩减为一个2D数组或最好的情况是一个scalar高效。 - Divakar

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接