比较两个具有numpy矩阵值的字典

21
我想要确认两个Python字典是否相等(即:键的数量相等,并且每个键值对应的映射值相等;顺序不重要)。一种简单的方法是使用assert A==B,但是如果字典的值为numpy数组,则此方法不起作用。我该如何编写一个通用的函数来检查两个字典是否相等?
>>> import numpy as np
>>> A = {1: np.identity(5)}
>>> B = {1: np.identity(5) + np.ones([5,5])}
>>> A == B
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

编辑 我知道应该使用.all()检查numpy矩阵是否相等。我想找到一种通用的方法来检查它们,而不必检查isinstance(np.ndarray)。这可能吗?

与numpy数组无关的相关主题:

4个回答

26

这不会返回布尔值。相反,如果对象不相等,它会抛出异常。它可以用于构建cmp函数,但本身不是一个函数。 - Guilherme de Lazari
@GuilhermedeLazari 在某些时候你在深究细节了。只要使用try/except块编写您的cmp函数即可。它几乎是自己编写的。 - eric
3
原问题是“我想要确定两个Python字典是否相等”。 - vitiral
这个答案只适用于你事先知道这些值是numpy数组的情况。问题在于找到一种通用方式,而不是首先检查值的实例类型。 - physicalattraction
我已经有一段时间没有使用它了,但文档说“给定两个对象(标量、列表、元组、字典或numpy数组),检查这些对象的所有元素是否相等。在第一个冲突值时会引发异常。”看起来应该适用于其他类型,但如果你发现这不是真的,那可能是一个错误。 - vitiral

10
我将回答你问题标题和前半部分中隐藏的一半问题,因为坦白地说,这是一个更常见的需要解决的问题,而现有的答案并没有很好地解决它。这个问题是“如何比较两个numpy数组字典的相等性”?
问题的第一部分是从远处检查字典:确保它们的键相同。如果所有键都相同,则第二部分是比较每个对应的值。
现在一个微妙的问题是许多numpy数组并非整数值,而双精度不准确。因此,除非您有整数值(或其他类似于非浮点数的数组),否则您可能需要检查这些值是否几乎相同,即在机器精度范围内。因此,在这种情况下,您不会使用np.array_equal(它检查精确的数值相等性),而是使用np.allclose(它使用有限容差来比较两个数组之间的相对误差和绝对误差)。
问题的前一个半部分很简单:检查字典的键是否一致,并使用生成器推导式比较每个值(并使用all在推导式外部验证每个项目是否相同):
import numpy as np

# some dummy data

# these are equal exactly
dct1 = {'a': np.array([2, 3, 4])}
dct2 = {'a': np.array([2, 3, 4])}

# these are equal _roughly_
dct3 = {'b': np.array([42.0, 0.2])}
dct4 = {'b': np.array([42.0, 3*0.1 - 0.1])}  # still 0.2, right?

def compare_exact(first, second):
    """Return whether two dicts of arrays are exactly equal"""
    if first.keys() != second.keys():
        return False
    return all(np.array_equal(first[key], second[key]) for key in first)

def compare_approximate(first, second):
    """Return whether two dicts of arrays are roughly equal"""
    if first.keys() != second.keys():
        return False
    return all(np.allclose(first[key], second[key]) for key in first)

# let's try them:
print(compare_exact(dct1, dct2))  # True
print(compare_exact(dct3, dct4))  # False
print(compare_approximate(dct3, dct4))  # True

如上例所示,整数数组可以进行精确比较,并且根据您的实际需求(或运气好的话),甚至可以使用浮点数进行比较。但是,如果您的浮点数是任何算术操作(例如线性变换)的结果,那么您应该使用近似检查。关于后一种选择的完整描述,请参见numpy.allclose的文档(以及其逐元素的朋友numpy.isclose),特别注意rtolatol关键字参数。

0
你可以分别提取两个字典的键和值,然后将键与键进行比较,将值与值进行比较: 以下是解决方案:
import numpy as np

def dic_to_keys_values(dic):
    keys, values = list(dic.keys()), list(dic.values())
    return keys, values

def numpy_assert_almost_dict_values(dict1, dict2):
    keys1, values1 = dic_to_keys_values(dict1)
    keys2, values2 = dic_to_keys_values(dict2)
    np.testing.assert_equal(keys1, keys2)
    np.testing.assert_almost_equal(values1, values2)

dict1 = {"b": np.array([1, 2, 0.2])}
dict2 = {"b": np.array([1, 2, 3 * 0.1 - 0.1])}  # almost 0.2, but not equal
dict3 = {"b": np.array([999, 888, 444])} # completely different

numpy_assert_almost_dict_values(dict1, dict2) # no exception because almost equal
# numpy_assert_almost_dict_values(dict1, dict3) # exception because not equal


(注意,上面的代码检查了确切的键和几乎相等的值)

-3

考虑以下代码

>>> import numpy as np
>>> np.identity(5)
array([[ 1.,  0.,  0.,  0.,  0.],
       [ 0.,  1.,  0.,  0.,  0.],
       [ 0.,  0.,  1.,  0.,  0.],
       [ 0.,  0.,  0.,  1.,  0.],
       [ 0.,  0.,  0.,  0.,  1.]])
>>> np.identity(5)+np.ones([5,5])
array([[ 2.,  1.,  1.,  1.,  1.],
       [ 1.,  2.,  1.,  1.,  1.],
       [ 1.,  1.,  2.,  1.,  1.],
       [ 1.,  1.,  1.,  2.,  1.],
       [ 1.,  1.,  1.,  1.,  2.]])
>>> np.identity(5) == np.identity(5)+np.ones([5,5])
array([[False, False, False, False, False],
       [False, False, False, False, False],
       [False, False, False, False, False],
       [False, False, False, False, False],
       [False, False, False, False, False]], dtype=bool)
>>> 

请注意,比较的结果是一个矩阵,而不是布尔值。字典比较将使用值的cmp方法来比较值,这意味着在比较矩阵值时,字典比较将得到一个复合结果。你想要做的是使用numpy.all将复合数组结果折叠成标量布尔结果。
>>> np.all(np.identity(5) == np.identity(5)+np.ones([5,5]))
False
>>> np.all(np.identity(5) == np.identity(5))
True
>>> 

您需要编写自己的函数来比较这些字典,测试值类型以查看它们是否为矩阵,然后使用numpy.all进行比较,否则使用==。当然,如果您想要更高级一些,可以开始子类化dict并重载cmp


我不太清楚,但我希望有一种通用的方式,而不需要显式地检查类型。今天可能是一个 numpy 数组,明天可能是我从未听说过的新类型。 - physicalattraction
我恐怕没有办法绕过它。如果你的(或numpy的或其他人的)类型覆盖__cmp__以返回非标量值,则标准Python比较将无法处理它。 - sirlark
你不需要编写自己的函数,因为numpy已经为你准备好了。请参考vitiral的答案。 - EL_DON

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接