在Numpy中检查矩阵是否对称

50

我正在尝试编写一个带有参数(a,tol=1e-8)的函数,该函数返回一个布尔值,告诉用户矩阵是否对称(对称矩阵等于其转置)。到目前为止,我的代码如下:

def check_symmetric(a, tol=1e-8):
if np.transpose(a, axes=axes) == np.transpose(a, axes=axes):
    return True
def sqr(s):
    rows = len(s)
    for row in sq:
        if len(row) != rows:
            return False
    return True
if a != sqr(s):
    raise ValueError

尽管我一直收到“axes未定义”的消息,所以我非常确定那根本不起作用......我想通过的测试是:
e = np.eye(4)
f = np.diag([1], k=3)
g = e[1:, :]

print(check_symmetric(e))
print(not check_symmetric(e + f))
print(check_symmetric(e + f * 1e-9))
print(not check_symmetric(e + f * 1e-9, 1e-10))
try:
    check_symmetric(g)
    print(False)
except ValueError:
    print(True)

Any help is appreciated, thanks!


假设您的矩阵仅为2D,则不需要使用“axes”关键字。此外,应保持一个矩阵未转置,然后检查其是否与矩阵的转置相等。目前,您正在检查两个转置矩阵的相等性。 - Paul Brodersen
昨天有一个关于对称矩阵测试的问题:http://stackoverflow.com/questions/42876082/python-numpy-see-if-an-array-is-symmetric-within-a-tolerance - hpaulj
哦,谢谢提供链接...虽然我的测试不太对,但我已经开始让它正常工作了:def check_symmetric(a, tol=1e-8): if np.transpose(a.any()) == np.array(a.any()): return True def sqr(s): rows = len(s) for row in sq: if len(row) != rows: return False return True if a != sqr(s): raise ValueError - plshalp
6个回答

112

您可以使用allclose函数将其与其转置进行比较。

def check_symmetric(a, rtol=1e-05, atol=1e-08):
    return numpy.allclose(a, a.T, rtol=rtol, atol=atol)

1
或者,检查a-a.T < (tol*a.shape**2) - jeremy_rutman

24

下面的函数同样可以解决这个问题:

def check_symmetric(a, tol=1e-8):
    return np.all(np.abs(a-a.T) < tol)

4
这个答案实际上也适用于矩阵元素是自定义类型的情况,与被接受的答案中的“allclose”不同。例如,四元数矩阵就是这种情况(参见https://github.com/moble/quaternion)。 - Giorgos Sfikas

7

如果您不担心tot阈值

(a==a.T).all()

这是最简单的解决方案。对于N维(N>2)数组同样适用。


4

这篇文章过时了,但我会推荐另一种方法。特别是对于稀疏矩阵,这种方法可以快上数百倍。

def is_symmetric(A, tol=1e-8):
    return scipy.sparse.linalg.norm(A-A.T, scipy.Inf) < tol;

或类似的,你可以理解这个概念。使用标准化是更优化的计算方法。


2
如果使用SciPy是可以接受的,您可以使用scipy.linalg.issymmetric()(从v1.8.0开始),它还包括一些输入验证。
  • 请参见实现此处
  • 关于性能的说明(来自文档;强调是我的):

    当设置了atol和/或rtol时,则通过numpy.allclose执行比较,并将容差值传递给它。否则,将通过内部函数对零进行精确比较。因此,性能可能会提高或降低,具体取决于数组的大小和dtype。


1
假设我们不考虑容差,即考虑整数矩阵。我们可以在这段代码中使用.all()来检查相应的矩阵元素,并使用.shape来找到矩阵的维度,然后我们可以写出类似这样的代码:
def check_matrix(A:np.array):
    a, b = A.shape
    if a == b and (-1*np.transpose(A) == A).all():
        return -1
    if a == b and (np.transpose(A) == A).all():
        return 1
    return 0

首先,接受一个矩阵作为输入。第一个if语句检查矩阵是否是斜对称的,第二个if语句检查矩阵是否是对称的,如果两者都不是,则返回0。要将此代码适应于浮点数,请使用.allclose()。我知道之前的答案有一个更好和更简洁的代码,但我想分享这个代码,因为其他答案没有使用.shape,所以这是独特的。 :)

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接