序列中的前n个最大元素(需要保留重复项)

8

我需要在元组列表中找到前n个最大的元素。以下是查找前3个元素的示例。

# I have a list of tuples of the form (category-1, category-2, value)
# For each category-1, ***values are already sorted descending by default***
# The list can potentially be approximately a million elements long.
lot = [('a', 'x1', 10), ('a', 'x2', 9), ('a', 'x3', 9), 
       ('a', 'x4',  8), ('a', 'x5', 8), ('a', 'x6', 7),
       ('b', 'x1', 10), ('b', 'x2', 9), ('b', 'x3', 8), 
       ('b', 'x4',  7), ('b', 'x5', 6), ('b', 'x6', 5)]

# This is what I need. 
# A list of tuple with top-3 largest values for each category-1
ans = [('a', 'x1', 10), ('a', 'x2', 9), ('a', 'x3', 9), 
       ('a', 'x4', 8), ('a', 'x5', 8),
       ('b', 'x1', 10), ('b', 'x2', 9), ('b', 'x3', 8)]

我尝试使用heapq.nlargest函数。然而,该函数仅返回前三个最大的元素,并且不会返回重复的元素。例如:

heapq.nlargest(3, [10, 10, 10, 9, 8, 8, 7, 6])
# returns
[10, 10, 10]
# I need
[10, 10, 10, 9, 8, 8]

我只能想到一种蛮力的方法。这是我现在拥有的,它可以工作。

res, prev_t, count = [lot[0]], lot[0], 1
for t in lot[1:]:
    if t[0] == prev_t[0]:
        count = count + 1 if t[2] != prev_t[2] else count
        if count <= 3:
            res.append(t)   
    else:
        count = 1
        res.append(t)
    prev_t = t

print res

还有其他的想法可以实现这个吗?

编辑:timeit结果表明,对于1百万元素列表,mhyfritz的解决方案运行时间是暴力解决方案的三分之一。不想让问题太长,所以在我的回答中添加了更多细节。


1
@Toader Mihai Claudiu是正确的。你可以尝试一种优化方法,将所有的键分成不同的列表,并在从每个列表中选择前三个后退出循环。这样你就不必遍历整个列表了。(这是假设你在第一次排序时没有花费时间。在没有排序约束的情况下,堆解决方案应该是最好的) - GeneralBecos
6个回答

7
我从你的代码片段中了解到,lot 是按照 category-1 进行分组的。那么以下代码应该可以工作:
from itertools import groupby, islice
from operator import itemgetter

ans = []
for x, g1 in groupby(lot, itemgetter(0)):
    for y, g2 in islice(groupby(g1, itemgetter(2)), 0, 3):
        ans.extend(list(g2))

print ans
# [('a', 'x1', 10), ('a', 'x2', 9), ('a', 'x3', 9), ('a', 'x4', 8), ('a', 'x5', 8),
#  ('b', 'x1', 10), ('b', 'x2', 9), ('b', 'x3', 8)]

2
以下是一行代码: list(chain(*(list(g2) for x, g1 in groupby(lot, itemgetter(0)) for y, g2 in islice(groupby(g1, itemgetter(2)), 0, 3)))) - Andrew Clark
第二个循环中islicegroupby的组合真是太棒了!感谢您提供如此出色的解决方案! - Praveen Gollakota
是的。这看起来很漂亮 :) - Mihai Toader
主要注意事项:如果输入的顺序发生变化,代码将在未来出现错误。 - ninjagecko
@Praveen Gollakota:很高兴我能帮到你。同时,感谢您提供额外的细节(timeit比较,跟踪)。 - mhyfritz

2
如果您已经按照这种方式对输入数据进行了排序,那么您的解决方案很可能比基于heapq的解决方案好一点。
您的算法复杂度为O(n),而基于heapq的算法在概念上为O(n * log(3)),可能需要更多次地遍历数据以正确排列它。

1

一些额外的细节... 我计时了mhyfritz的优秀解决方案,它使用了itertools和我的代码(暴力解法)。

这里是n = 10和一个包含100万个元素的列表的timeit结果。

# Here's how I built the sample list of 1 million entries.
lot = []
for i in range(1001):
    for j in reversed(range(333)):
        for k in range(3):
            lot.append((i, 'x', j))

# timeit Results for n = 10
brute_force = 6.55s
itertools = 2.07s
# clearly the itertools solution provided by mhyfritz is much faster.

如果有人感兴趣,这里是他代码的追踪。
+ Outer loop - x, g1
| a [('a', 'x1', 10), ('a', 'x2', 9), ('a', 'x3', 9), ('a', 'x4', 8), ('a', 'x5', 8), ('a', 'x6', 7)]
+-- Inner loop - y, g2
  |- 10 [('a', 'x1', 10)]
  |- 9 [('a', 'x2', 9), ('a', 'x3', 9)]
  |- 8 [('a', 'x4', 8), ('a', 'x5', 8)]
+ Outer loop - x, g1
| b [('b', 'x1', 10), ('b', 'x2', 9), ('b', 'x3', 8), ('b', 'x4', 7), ('b', 'x5', 6), ('b', 'x6', 5)]
+-- Inner loop - y, g2
  |- 10 [('b', 'x1', 10)]
  |- 9 [('b', 'x2', 9)]
  |- 8 [('b', 'x3', 8)]

0
这是一个想法,创建一个字典,以您想要排序的值作为键,具有该值的元组列表作为值。
然后按键对字典中的项目进行排序,从顶部获取项目,提取它们的值并将它们连接起来。
快速而丑陋的代码:
>>> sum(
        map(lambda x: x[1],
            sorted(dict([(x[2], filter(lambda y: y[2] == x[2], lot))
                for x in lot]).items(),
                reverse=True)[:3]),
    [])

7: [('a', 'x1', 10),
 ('b', 'x1', 10),
 ('a', 'x2', 9),
 ('a', 'x3', 9),
 ('b', 'x2', 9),
 ('a', 'x4', 8),
 ('a', 'x5', 8),
 ('b', 'x3', 8)]

仅供参考,希望有所帮助。如果需要澄清,请在评论中提问。


0
这个怎么样?它并不完全返回您想要的结果,因为它是按照 y 的反向排序。
# split lot by first element of values
lots = defaultdict(list)
for x, y, z in lot:
    lots[x].append((y, z))

ans = []
for x, l in lots.iteritems():
    # find top-3 unique values
    top = nlargest(3, set(z for (y, z) in l))
    ans += [(x, y, z) for (z, y) in sorted([(z, y) for (y, z) in l
                                                   if z in top],
                                           reverse=True)]

print ans

0
from collections import *

categories = defaultdict(lambda: defaultdict(lambda: set()))
for t in myTuples:
    cat1,cat2,val = t
    categories[cat1][val].add(t)

def onlyTopThreeKeys(d):
    keys = sorted(d.keys())[-3:]
    return {k:d[k] for k in keys}

print( {cat1:onlyTopThreeKeys(sets) for cat1,sets in categories.items()} )

结果:

{'a': {8: {('a', 'x5', 8), ('a', 'x4', 8)},
       9: {('a', 'x3', 9), ('a', 'x2', 9)},
       10: {('a', 'x1', 10)}},
 'b': {8: {('b', 'x3', 8)}, 
       9: {('b', 'x2', 9)}, 
       10: {('b', 'x1', 10)}}}

平面列表:我使用上述方法是因为它可以提供更多的信息。如果你仅需要一个平面列表,可以使用闭包来发出onlyTopThreeKeys的结果:

from collections import *

def topTiedThreeInEachCategory(tuples):
    categories = defaultdict(lambda: defaultdict(lambda: set()))
    for t in myTuples:
        cat1,cat2,val = t
        categories[cat1][val].add(t)

    reap = set()

    def sowTopThreeKeys(d):
        keys = sorted(d.keys())[-3:]
        for k in keys:
            for x in d[k]:
                reap.add(x)
    for sets in categories.values():
        sowTopThreeKeys(sets)

    return reap

结果:

>>> topTiedThreeInEachCategory(myTuples)
{('b', 'x2', 9), ('a', 'x1', 10), ('b', 'x3', 8), ('a', 'x2', 9), ('a', 'x4', 8), ('a', 'x3', 9), ('a', 'x5', 8), ('b', 'x1', 10)}

如果你的输入保证像样例输入一样是有序的,你也可以使用 itertools.groupby,但是如果排序方式改变了,这会导致你的代码出错。


这样做是否无法提取这两个条目:('a', 'x2', 9), ('a', 'x3', 9)? - Mihai Toader
@Toader:这就是为什么我已经提供了样本输出,你可以自己查看。 - ninjagecko
现在它能工作是因为这个:categories[cat1][val].add(t)。当我注释掉时,它是这样的:categories[cat1][val] = t :) - Mihai Toader
@Toader:哦,糟糕,抱歉,我没有意识到你的评论是20分钟之前的。=) - ninjagecko

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接