PyTorch中复数矩阵乘法

3
2个回答

3

当前,torch.matmul 不支持像 ComplexFloatTensor 这样的复杂张量,但您可以使用以下紧凑的代码:

def matmul_complex(t1,t2):
    return torch.view_as_complex(torch.stack((t1.real @ t2.real - t1.imag @ t2.imag, t1.real @ t2.imag + t1.imag @ t2.real),dim=2))

尽可能避免使用for循环,因为这会导致实现速度变慢。向量化是通过使用内置方法来实现的,如我附加的代码所示。例如,对于2个随机复杂矩阵,维度为1000 X 1000,您的代码在CPU上大约需要6.1秒,而向量化版本仅需要101ms(快了约60倍)。
更新: 自PyTorch 1.7.0以来(如@EduardoReis所提到的),您可以像处理实值矩阵一样处理复杂矩阵进行矩阵乘法,如下所示: t1 @ t2(对于t1,t2为复杂矩阵)。

最近,我使用torch 1.8.1+cu101成功地将两个张量简单相乘,得到它们的复数积。 - Eduardo Reis
1
@EduardoReis 您是正确的。自 PyTorch 1.7.0 版本以来,您可以缩短上面的代码。但请注意,t1 * t2 是张量 t1t2 之间的逐点乘法。您可以使用 t1 @ t2 来获得等同于 matmul_complex 的矩阵乘法。我已更新帖子。 - Gil Pinsky

0
我使用torch.mv为复数实现了pytorch.matmul函数,目前运行良好。
def matmul_complex(t1, t2):
  m = list(t1.size())[0]
  n = list(t2.size())[1]
  t = torch.empty((1,n), dtype=torch.cfloat)
  t_total = torch.empty((m,n), dtype=torch.cfloat)
  for i in range(0,n):
    if i == 0:
      t_total = torch.mv(t1,t2[:,i])
    else:
      t_total = torch.cat((t_total, torch.mv(t1,t2[:,i])), 0)
  t_final = torch.reshape(t_total, (m,n))
  return t_final

我是PyTorch的新手,如果我有错误,请纠正我。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接