PyTorch中复数矩阵的行列式

3
有没有在PyTorch中计算复矩阵行列式的方法? torch.det不支持'ComplexFloat'类型。
2个回答

1

很遗憾,目前还没有实现。一种方法是实现自己的版本或者直接使用np.linalg.det。这里是一个使用LU分解计算复杂矩阵行列式的简短函数:

def complex_det(A):
    def complex_diag(A):
        return torch.view_as_complex(torch.stack((A.real.diag(), A.imag.diag()),dim=1))
    #Perform LU decomposition to matrix A:
    A_LU, pivots = A.lu()
    P, A_L, A_U = torch.lu_unpack(A_LU, pivots)
    #Det. of multiplied matrices is multiplcation of det.:
    det = torch.prod(complex_diag(A_L)) * torch.prod(complex_diag(A_U)) * torch.det(P.real) #Could probably calculate det(P) [which is +-1] efficiently using Sylvester's determinant identity
    return det
#Test it:
A = torch.view_as_complex(torch.randn(3,3,2))
complex_det(A)

你说得对,我一直在努力编写自己的版本。完成后我会在这里发布。谢谢Gil。 - DeepRazi
1
我刚更新了这篇帖子,并添加了该版本。快去看看吧 :) - Gil Pinsky

0
从1.8版本开始,PyTorch原生支持numpy风格的torch.linalg操作。特别地,torch.linalg.det支持cfloatcdouble复数数据类型:
torch.linalg.det(input)

计算方阵输入的行列式,或批次输入中每个方阵的行列式。
此函数支持 float、double、cfloat 和 cdouble 数据类型。

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接