使用元组创建任意嵌套的字典

3

给定一个字典,其中键是元组(值为数字或标量),如何用Python的方式将其转换为嵌套字典?问题在于,从输入到输入,元组的长度是任意的。

下面的代码示例中,d1、d2和d3表现出逐渐增加的嵌套性:

from itertools import product

d1 = dict(zip(product('AB', [0, 1]), range(2*2)))
d2 = dict(zip(product('AB', [0, 1], [True, False]), range(2*2*2)))
d3 = dict(zip(product('CD', [0, 1], [True, False], 'AB'), range(2*2*2*2)))

它们所产生的嵌套版本将是:

# For d1
{'A': {0: 0, 1: 1}, 'B': {0: 2, 1: 3}}

# For d2
{'A': {0: {True: 0, False: 1}, 1: {True: 2, False: 3}},
 'B': {0: {True: 4, False: 5}, 1: {True: 6, False: 7}}}

# Beginning of result for d3
{
'C': {
    0: {
        True: {
            'A': 0
            'B': 1
        },
        False: {
            'A': 2,
            'B': 3
        },
    1: # ...

我的尝试:我喜欢在许多其他SO答案中提供的初始化空数据结构的技巧:

from collections import defaultdict

def nested_dict():
    return defaultdict(nested_dict)

但是我在实施时遇到了困难,因为层级数不确定。我可以使用类似以下的内容:

def nest(d: dict) -> dict:
    res = nested_dict()
    for (i, j, k), v in d.items():
        res[i][j][k] = v
    return res

但是这只适用于d2,因为其键名拥有3个级别(i、j、k)。

以下是我试图推广此解决方案的尝试,但我猜测还有更简单的方法。

def set_arbitrary_nest(keys, value):
    """
    >>> keys = 1, 2, 3
    >>> value = 5
    result --> {1: {2: {3: 5}}}
    """

    it = iter(keys)
    last = next(it)
    res = {last: {}}
    lvl = res
    while True:
        try:
            k = next(it)
            lvl = lvl[last]
            lvl[k] = {}
            last = k
        except StopIteration:
            lvl[k] = value
            return res

>>> set_arbitrary_nest([1, 2, 3], 5)
{1: {2: {3: 5}}}
2个回答

3

只需循环遍历每个键,并使用除了最后一个元素之外的所有元素添加字典。保留对最后一个字典的引用,然后使用键元组中的最后一个元素在输出字典中实际设置键-值对:

def nest(d: dict) -> dict:
    result = {}
    for key, value in d.items():
        target = result
        for k in key[:-1]:  # traverse all keys but the last
            target = target.setdefault(k, {})
        target[key[-1]] = value
    return result

你可以使用 functools.reduce() 函数来处理向下遍历字典的工作。
from functools import reduce

def nest(d: dict) -> dict:
    result = {}
    traverse = lambda r, k: r.setdefault(k, {})
    for key, value in d.items():
        reduce(traverse, key[:-1], result)[key[-1]] = value
    return result

我使用了 dict.setdefault() 而不是自动填充的选项 defaultdict(nested_dict),因为这样可以生成一个常规的字典,不会在键不存在时隐式地添加键。

演示:

>>> from pprint import pprint
>>> pprint(nest(d1))
{'A': {0: 0, 1: 1}, 'B': {0: 2, 1: 3}}
>>> pprint(nest(d2))
{'A': {0: {False: 1, True: 0}, 1: {False: 3, True: 2}},
 'B': {0: {False: 5, True: 4}, 1: {False: 7, True: 6}}}
>>> pprint(nest(d3))
{'C': {0: {False: {'A': 2, 'B': 3}, True: {'A': 0, 'B': 1}},
       1: {False: {'A': 6, 'B': 7}, True: {'A': 4, 'B': 5}}},
 'D': {0: {False: {'A': 10, 'B': 11}, True: {'A': 8, 'B': 9}},
       1: {False: {'A': 14, 'B': 15}, True: {'A': 12, 'B': 13}}}}

请注意,上述解决方案是一个干净的O(N)循环(N为输入字典的长度),而由Ajax1234提出的groupby解决方案需要对输入进行排序,使其成为一个O(NlogN)的解决方案。这意味着对于具有1000个元素的字典,groupby()需要10000步骤来产生输出,而O(N)线性循环只需要1000步骤。对于一百万个键,这增加到2000万步骤等等。
此外,Python中的递归速度较慢,因为Python无法将这样的解决方案优化为迭代方法。函数调用相对较昂贵,因此使用递归可能会带来显着的性能成本,因为您大大增加了函数调用次数和框架堆栈操作次数。
时间试验表明,这点问题会导致很大影响;使用您的示例d3和100k运行,我们轻松看到5倍的速度差异:
>>> from timeit import timeit
>>> timeit('n(d)', 'from __main__ import create_nested_dict as n, d3; d=d3.items()', number=100_000)
8.210276518017054
>>> timeit('n(d)', 'from __main__ import nest as n, d3 as d', number=100_000)
1.6089267160277814

0

您可以使用递归的方式来使用 itertools.groupby

from itertools import groupby
def create_nested_dict(d):
  _c = [[a, [(c, d) for (_, *c), d in b]] for a, b in groupby(sorted(d, key=lambda x:x[0][0]), key=lambda x:x[0][0])]
  return {a:b[0][-1] if not any([c for c, _ in b]) else create_nested_dict(b) for a, b in _c}

from itertools import product

d1 = dict(zip(product('AB', [0, 1]), range(2*2)))
d2 = dict(zip(product('AB', [0, 1], [True, False]), range(2*2*2)))
d3 = dict(zip(product('CD', [0, 1], [True, False], 'AB'), range(2*2*2*2)))
print(create_nested_dict(d1.items()))
print(create_nested_dict(d2.items()))
print(create_nested_dict(d3.items())) 

输出:

{'A': {0: 0, 1: 1}, 'B': {0: 2, 1: 3}}
{'A': {0: {False: 1, True: 0}, 1: {False: 3, True: 2}}, 'B': {0: {False: 5, True: 4}, 1: {False: 7, True: 6}}}
{'C': {0: {False: {'A': 2, 'B': 3}, True: {'A': 0, 'B': 1}}, 1: {False: {'A': 6, 'B': 7}, True: {'A': 4, 'B': 5}}}, 'D': {0: {False: {'A': 10, 'B': 11}, True: {'A': 8, 'B': 9}}, 1: {False: {'A': 14, 'B': 15}, True: {'A': 12, 'B': 13}}}}

在这里使用groupby()昂贵的,因为输入保证需要排序!排序会增加O(NlogN)的成本,当一个简单的O(N)循环就可以完成时,这是过度的。此外,把所有内容都放在一行上也非常昂贵。 - Martijn Pieters
@MartijnPieters 是的,然而对于 itertools.groupby 来说,有序输入是必要的。不过,我认为 itertools.groupby 使递归逻辑变得更加容易,也许更加清晰。 - Ajax1234
我强烈反对让逻辑更清晰或更简单。那段代码非常难以阅读。 - Martijn Pieters

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接