查找一个字典列表的所有排列组合

3

我有一个包含字母及其频率的字典列表。

letter_freq = [
   {'a': 10, 'b': 7},
   {'d': 15, 'g': 8},
   {'a': 12, 'q': 2}
]

我想找到这些字典的所有可能组合,以及它们的值的总和:
perms = {
   'ada': 37, 'adq': 27, 'aga': 30, 'agq': 20, 'bda': 34, 'bdq': 24, 'bga': 27, 'bgq': 17
}

我研究了itertools.product(),但我不知道如何将其应用于这种特定用例。我的直觉告诉我,实现这个最简单的方法是制作一个递归函数,但我还不知道如何添加值和键的字符串,并使它们都起作用。

此外,这个列表和字典可以是任意长度。有没有一种我还没有发现的简单方法来做到这一点?谢谢!


2
所有字典中的键是否保证唯一? - Jon Clements
不,键可以重复。 - Jakob
2
@Jakob 如果是这样的话,那么调整 perms 并展示 letter_freq 的示例会更有用。 - Brad Solomon
adq 应该是 39 而不是 27 吗? - Jon Clements
@JonClements 我不这么认为 - 10 + 15 + 2 = 27 - Jakob
3个回答

5

解决方案基准测试:

是的,itertools.product可行:

from itertools import product

perms = {
    ''.join(keys): sum(vals)
    for prod in product(*map(dict.items, letter_freq))
    for keys, vals in [zip(*prod)]
}

另外一种方法是将产品的键和值分别构建,这样我们就不必再把它们分开了:

perms = {
    ''.join(keys): sum(vals)
    for keys, vals in zip(product(*letter_freq),
                          product(*map(dict.values, letter_freq)))
}

或完全分离它们的结构(我最喜欢的方法):

keys = map(''.join, product(*letter_freq))
vals = map(sum, product(*map(dict.values, letter_freq)))
perms = dict(zip(keys, vals))

性能测试会很有趣,我猜我的最后一个测试将是这些测试中最快的,也比Samwise的测试要快。

另一个灵感来自于对constantstranger的一瞥(但在一些初始基准测试中比他们快得多):

items = [('', 0)]
for d in letter_freq:
    items = [(k0+k, v0+v)
             for k, v in d.items()
             for k0, v0 in items]
perms = dict(items)

性能测试:

以您的字典列表为例:

  6.6 μs  perms1
  4.5 μs  perms2
  4.1 μs  perms3
  4.0 μs  perms4
 11.0 μs  perms_Samwise
 12.7 μs  perms_constantstranger

有一个包含七个字典的列表,每个字典有四个条目。
 15.5 ms  perms1
  7.6 ms  perms2
  5.5 ms  perms3
  4.8 ms  perms4
 27.2 ms  perms_Samwise
 42.2 ms  perms_constantstranger

代码(在线尝试!):

def perms1(letter_freq):
    return {
        ''.join(keys): sum(vals)
        for prod in product(*map(dict.items, letter_freq))
        for keys, vals in [zip(*prod)]
    }

def perms2(letter_freq):
    return {
        ''.join(keys): sum(vals)
        for keys, vals in zip(product(*letter_freq),
                              product(*map(dict.values, letter_freq)))
    }

def perms3(letter_freq):
    keys = map(''.join, product(*letter_freq))
    vals = map(sum, product(*map(dict.values, letter_freq)))
    return dict(zip(keys, vals))

def perms4(letter_freq):
    items = [('', 0)]
    for d in letter_freq:
        items = [(k0+k, v0+v)
                 for k, v in d.items()
                 for k0, v0 in items]
    return dict(items)

def perms_Samwise(letter_freq):
    return {''.join(k for k, _ in p): sum(v for _, v in p) for p in itertools.product(*(d.items() for d in letter_freq))}

def perms_constantstranger(letter_freq):
    stack = [['', 0]]
    [stack.append((stack[i][0] + k, stack[i][1] + v)) for row in letter_freq if (lenStack := len(stack)) for k, v in row.items() for i in range(lenStack)]
    return dict(row for row in stack if len(row[0]) == len(letter_freq))

funcs = perms1, perms2, perms3, perms4, perms_Samwise, perms_constantstranger

letter_freq = [
   {'a': 10, 'b': 7, 'c': 5, 'd': 2},
   {'d': 15, 'g': 8, 'j': 6, 'm': 3},
   {'a': 12, 'q': 2, 'x': 1, 'z': 4},
   {'a': 10, 'b': 7, 'c': 5, 'd': 2},
   {'d': 15, 'g': 8, 'j': 6, 'm': 3},
   {'a': 12, 'q': 2, 'x': 1, 'z': 4},
   {'a': 10, 'b': 7, 'c': 5, 'd': 2},
]

from timeit import repeat
import itertools
from itertools import product

expect = funcs[0](letter_freq)
for func in funcs:
    result = func(letter_freq)
    assert result == expect

for _ in range(3):
    for func in funcs:
        t = min(repeat(lambda: func(letter_freq), number=1))
        print('%5.1f ms ' % (t * 1e3), func.__name__)
    print()

3

itertools.product 确实是你想要的。

>>> letter_freq = [
...    {'a': 10, 'b': 7},
...    {'d': 15, 'g': 8},
...    {'a': 12, 'q': 2}
... ]
>>> import itertools
>>> {''.join(k for k, _ in p): sum(v for _, v in p) for p in itertools.product(*(d.items() for d in letter_freq))}
{'ada': 37, 'adq': 27, 'aga': 30, 'agq': 20, 'bda': 34, 'bdq': 24, 'bga': 27, 'bgq': 17}

1

如果出于任何原因,您想使用列表推导式而不是product()map()来生成自己的排列,可以按照以下方式进行:

        letter_freq = [
           {'a': 10, 'b': 7},
           {'d': 15, 'g': 8},
           {'a': 12, 'q': 2}
        ]       
        stack = [['', 0]]
        [stack.append((stack[i][0] + k, stack[i][1] + v)) for row in letter_freq if (lenStack := len(stack)) for k, v in row.items() for i in range(lenStack)]
        perms = dict(row for row in stack if len(row[0]) == len(letter_freq))
        print(perms)

输出:

{'ada': 37, 'bda': 34, 'aga': 30, 'bga': 27, 'adq': 27, 'bdq': 24, 'agq': 20, 'bgq': 17}

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接