断言两个字典几乎相等

21
我想要确认两个字典几乎相等,但似乎无法做到。
以下是一个例子:
>>> import nose.tools as nt
>>> nt.assert_dict_equal({'a' : 12.4}, {'a' : 5.6 + 6.8})
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/lib/python2.7/unittest/case.py", line 838, in assertDictEqual
    self.fail(self._formatMessage(msg, standardMsg))
  File "/usr/lib/python2.7/unittest/case.py", line 413, in fail
    raise self.failureException(msg)
AssertionError: {'a': 12.4} != {'a': 12.399999999999999}
- {'a': 12.4}
+ {'a': 12.399999999999999}

我希望这个能够通过,就像那样:

>>> nt.assert_almost_equal(12.4, 5.6 + 6.8)

我希望自己只是漏掉了一些简单的东西,比如nt.assert_almost_dict_equal,或者可能有一个参数可以传递给nt.assert_dict_equal,用于指定浮点数应该接近多少,但我找不到任何相关内容。

当然,我可以循环遍历字典,并使用nt.assert_almost_equal逐个比较值;但在我的情况下,字典更加复杂,因此我希望避免这种方法。

有什么最好的方法来断言两个字典几乎相等吗?


2
我认为你需要自己迭代和比较这些值。assert_almost_equal仅适用于可以直接计算其差异的数字类型。 - BrenBarn
7
如果你发现需要自己编写代码,可以在这里查看"assertDeepAlmostEqual":https://github.com/larsbutler/oq-engine/blob/master/tests/utils/helpers.py - dano
@dano,很有趣,谢谢。 - Akavall
5个回答

22

评论@dano回答了我的问题:

我从dano提供的链接中复制了一个函数。

import unittest
import numpy

def assertDeepAlmostEqual(test_case, expected, actual, *args, **kwargs):
    """
    Assert that two complex structures have almost equal contents.

    Compares lists, dicts and tuples recursively. Checks numeric values
    using test_case's :py:meth:`unittest.TestCase.assertAlmostEqual` and
    checks all other values with :py:meth:`unittest.TestCase.assertEqual`.
    Accepts additional positional and keyword arguments and pass those
    intact to assertAlmostEqual() (that's how you specify comparison
    precision).

    :param test_case: TestCase object on which we can call all of the basic
    'assert' methods.
    :type test_case: :py:class:`unittest.TestCase` object
    """
    is_root = not '__trace' in kwargs
    trace = kwargs.pop('__trace', 'ROOT')
    try:
        if isinstance(expected, (int, float, long, complex)):
            test_case.assertAlmostEqual(expected, actual, *args, **kwargs)
        elif isinstance(expected, (list, tuple, numpy.ndarray)):
            test_case.assertEqual(len(expected), len(actual))
            for index in xrange(len(expected)):
                v1, v2 = expected[index], actual[index]
                assertDeepAlmostEqual(test_case, v1, v2,
                                      __trace=repr(index), *args, **kwargs)
        elif isinstance(expected, dict):
            test_case.assertEqual(set(expected), set(actual))
            for key in expected:
                assertDeepAlmostEqual(test_case, expected[key], actual[key],
                                      __trace=repr(key), *args, **kwargs)
        else:
            test_case.assertEqual(expected, actual)
    except AssertionError as exc:
        exc.__dict__.setdefault('traces', []).append(trace)
        if is_root:
            trace = ' -> '.join(reversed(exc.traces))
            exc = AssertionError("%s\nTRACE: %s" % (exc.message, trace))
        raise exc

# My part, using the function

class TestMyClass(unittest.TestCase):
    def test_dicts(self):
        assertDeepAlmostEqual(self, {'a' : 12.4}, {'a' : 5.6 + 6.8})
    def test_dicts_2(self):
        dict_1 = {'a' : {'b' : [12.4, 0.3]}}
        dict_2 = {'a' : {'b' : [5.6 + 6.8, 0.1 + 0.2]}}

        assertDeepAlmostEqual(self, dict_1, dict_2)

def main():
    unittest.main()

if __name__ == "__main__":
    main()

结果:

Ran 2 tests in 0.000s

OK

3
给我下投票的人,能否解释一下投票的原因?我认为自问自答是可以的。你认为我没有给予@dano足够的信用吗? - Akavall
2
这可能不是很好的编程风格,但是如果你通过“unittest.TestCase.assertDeepAlmostEqual = assertDeepAlmostEqual”来进行monkey-patch TestCase,那么你就可以像使用其他测试一样使用它,例如“self.assertDeepAlmostEqual(dict_1, dict_2)”。 - patricksurry
还应该注意的是,在我的情况下,我必须对所有嵌套集合进行排序,因此您可以将此处的代码与此有关排序嵌套集合的代码片段一起使用。 - Daniel Dror
@patricksurry 这可能是将其转换为mixin的情况。 - Elias Dorneles
2
不需要猴子补丁TestCase。你可以简单地继承它。实际上,这就是testtools.TestCase所做的。 - Helmut Grohne
2
Python 2 已经不再受支持,因此我会用 range 替换 xrange。此外,long 应该是 numpy.long,或者你可以明确地 from numpy import long, ndarray,以避免导入整个 numpy。 - Zemogle

4

Pytest的 "approx" 功能非常实用

In [10]: {'a': 2.000001} == pytest.approx({'a': 2}) Out[10]: True


3

我知道你不会为了这个而导入pandas,但如果你正在使用pandas,你可以将字典转换为序列,并使用来自 pandas.testingassert_series_equal ,默认情况下具有 check_exact=False

>>> import pandas as pd
>>> from pandas.testing import assert_series_equal
>>> a = pd.Series({'a' : 12.4})
>>> b = pd.Series({'a': 12.399999999999999})
>>> assert_series_equal(a, b)

1

我无法运行Akavall的函数,因此我自己写了一个。它有点太简单,但对我的目的有效。编写用于测试该函数是否工作的代码时使用pytest。

from numbers import Number
from math import isclose

def dictsAlmostEqual(dict1, dict2, rel_tol=1e-8):
    """
    If dictionary value is a number, then check that the numbers are almost equal, otherwise check if values are exactly equal
    Note: does not currently try converting strings to digits and comparing them. Does not care about ordering of keys in dictionaries
    Just returns true or false
    """
    if len(dict1) != len(dict2):
        return False
    # Loop through each item in the first dict and compare it to the second dict
    for key, item in dict1.items():
        # If it is a nested dictionary, need to call the function again
        if isinstance(item, dict):
            # If the nested dictionaries are not almost equal, return False
            if not dictsAlmostEqual(dict1[key], dict2[key], rel_tol=rel_tol):
                return False
        # If it's not a dictionary, then continue comparing
        # Put in else statement or else the nested dictionary will get compared twice and
        # On the second time will check for exactly equal and will fail
        else:
            # If the value is a number, check if they are approximately equal
            if isinstance(item, Number):
                # if not abs(dict1[key] - dict2[key]) <= rel_tol:
                # https://dev59.com/tG035IYBdhLWcg3wJcjT
                if not isclose(dict1[key], dict2[key], rel_tol=rel_tol):
                    return False
            else:
                if not (dict1[key] == dict2[key]):
                    return False
    return True

使用pytest验证函数输出

import pytest
import dictsAlmostEqual
def test_dictsAlmostEqual():
    a = {}
    b = {}
    assert dictsAlmostEqual(a, b)
    a = {"1": "a"}
    b = {}
    assert not dictsAlmostEqual(a, b)
    a = {"1": "a"}
    b = {"1": "a"}
    assert dictsAlmostEqual(a, b)
    a = {"1": "a"}
    b = {"1": "b"}
    assert not dictsAlmostEqual(a, b)
    a = {"1": "1.23"}
    b = {"1": "1.23"}
    assert dictsAlmostEqual(a, b)
    a = {"1": "1.234"}
    b = {"1": "1.23"}
    assert not dictsAlmostEqual(a, b)
    a = {"1": 1.000000000000001, "2": "a"}
    b = {"1": 1.000000000000002, "2": "a"}
    assert not dictsAlmostEqual(a, b, rel_tol=1e-20)
    assert dictsAlmostEqual(a, b, rel_tol=1e-8)
    assert dictsAlmostEqual(a, b)
    # Nested dicts
    a = {"1": {"2": 1.000000000000001}}
    b = {"1": {"2": 1.000000000000002}}
    assert not dictsAlmostEqual(a, b, rel_tol=1e-20)
    assert dictsAlmostEqual(a, b, rel_tol=1e-8)
    assert dictsAlmostEqual(a, b)
    a = {"1": {"2": 1.000000000000001, "3": "a"}, "2": "1.23"}
    b = {"1": {"2": 1.000000000000002, "3": "a"}, "2": "1.23"}
    assert not dictsAlmostEqual(a, b, rel_tol=1e-20)
    assert dictsAlmostEqual(a, b, rel_tol=1e-8)
    assert dictsAlmostEqual(a, b)

0
我需要一个递归版本的round来进行漂亮的打印。
def recursive_round(data, ndigits=5):
    if isinstance(data, dict):
        return {k: recursive_round(v, ndigits) for k, v in data.items()}
    if isinstance(data, list):
        return [recursive_round(v, ndigits) for v in data]
    if isinstance(data, tuple):
        return tuple(recursive_round(v, ndigits) for v in data)
    if isinstance(data, set):
        return {recursive_round(v, ndigits) for v in data}
    if isinstance(data, float):
        if data.is_integer():
            return int(data)
        return round(data, ndigits)
    return data

作为一个例子:
from pprint import pprint
import math

DATA = {
    'test': [1.23456, 'whatever', (2.34567, math.pi)],
    'another_key': [1.0, 2.0, 0.0, {math.e, math.inf}],
    'last_key': [0.123456789, 9.87654321]
}

pprint(recursive_round(DATA, 3))
# {'another_key': [1, 2, 0, {2.718, inf}],
#  'last_key': [0.123, 9.877],
#  'test': [1.235, 'whatever', (2.346, 3.142)]}

它也可以用于单元测试,使用assertEqual函数:
import unittest
import math


def recursive_round(data, ndigits=5):
    if isinstance(data, dict):
        return {k: recursive_round(v, ndigits) for k, v in data.items()}
    if isinstance(data, list):
        return [recursive_round(v, ndigits) for v in data]
    if isinstance(data, tuple):
        return tuple(recursive_round(v, ndigits) for v in data)
    if isinstance(data, set):
        return {recursive_round(v, ndigits) for v in data}
    if isinstance(data, float):
        if data.is_integer():
            return int(data)
        return round(data, ndigits)
    return data


DATA = {
    'test': [1.23456, 'whatever', (2.34567, math.pi)],
    'another_key': [1.0, 2.0, 0.0, {math.e, math.inf}],
    'last_key': [0.123456789, 9.87654321]
}


class TestClass(unittest.TestCase):

    def test_method(self):
        expected = {'another_key': [1, 2, 0, {2.718, math.inf}],
                    'last_key': [0.123, 9.877],
                    'test': [1.235, 'whatever', (2.346, 3.142)]}
        self.assertEqual(expected, recursive_round(DATA, 3))


if __name__ == '__main__':
    unittest.main()

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接