这是我的问题。我有两个矩阵 A
和 B
,它们的维度分别为 (n,n,m,m)
和 (n,n)
,其中每个元素都是复数。
以下是我执行的操作以获得矩阵 C
-
C = np.sum(B[:,:,None,None]*A, axis=(0,1))
计算上述内容大约需要6-8秒钟。由于我必须计算许多这样的C
,所以需要很长时间。有更快的方法吗?(我是在多核CPU上使用JAX NumPy进行这些操作;普通NumPy需要更长时间)
如果你想知道,n=77
并且m=512
。我可以并行处理,因为我正在使用集群,但数组的大小消耗了很多内存。