在不显式复制的情况下,对角线数组上的numpy线性代数

4
我有一个数组w(形状为(3000, 100, 100)),我想将其乘以另一个数组e(形状为(5, 3000)),使得结果k的形状为(5, 5, 100, 100)
k[:, :, i, j] = e @ np.diag(w[:, i, j]) @ e.T

由于w非常大,使用形状为(3000, 3000, 100, 100)super_w数组并显式填充主对角线是不切实际的。循环遍历ij也不是非常高效的方法。有没有一种内存有效的方法来解决这个问题,而不是将w分成块?

1个回答

4

使用np.einsum -

k = np.einsum('li,ijk,mi->lmjk',e,w,e)

我从未想过使用einsum可以做到这一点! - DathosPachy

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接