PySpark: group by key, then get the max of each group

I want to group by a key using PySpark and then find the maximum value within each group. I have the following code, but I am stuck on how to extract the maximum.

# some file contains tuples ('user', 'item', 'occurrences')
data_file = sc.textFile('file:///some_file.txt')
# Create the triplet so I can index stuff
data_file = data_file.map(lambda l: l.split()).map(lambda l: (l[0], l[1], float(l[2])))
# Group by the user i.e. r[0]
grouped = data_file.groupBy(lambda r: r[0])
# Here is where I am stuck 
group_list = grouped.map(lambda x: (list(x[1]))) #?

which returns something like:
[[(u'u1', u's1', 20), (u'u1', u's2', 5)], [(u'u2', u's3', 5), (u'u2', u's2', 10)]]

I now want to find the maximum "occurrence" for each user. After taking the max, the final result would be an RDD like:
[[(u'u1', u's1', 20)], [(u'u2', u's2', 10)]]

That is, only the triplet with the maximum occurrences for each user in the file is kept. In other words, I want to change the RDD so that each user is represented by a single triplet: the one with the maximum occurrence count.
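For reference, one direction I was considering (a rough, untested sketch) was to take the max inside each group via mapValues, something like:

# Rough, untested sketch: from the (user, iterable-of-triplets) pairs produced
# by groupBy above, keep only the triplet with the largest occurrence per user.
max_per_user = (grouped
    .mapValues(lambda triplets: max(triplets, key=lambda t: t[2]))
    .values())  # one triplet per user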

2 Answers

There is no need for groupBy here. A simple reduceByKey will do, and in most cases it will be more efficient:

data_file = sc.parallelize([
   (u'u1', u's1', 20), (u'u1', u's2', 5),
   (u'u2', u's3', 5), (u'u2', u's2', 10)])

max_by_group = (data_file
  .map(lambda x: (x[0], x))  # Convert to a pair RDD keyed by user
  # Take maximum of the passed arguments by the last element (key)
  # equivalent to:
  # lambda x, y: x if x[-1] > y[-1] else y
  .reduceByKey(lambda x1, x2: max(x1, x2, key=lambda x: x[-1])) 
  .values()) # Drop keys

max_by_group.collect()
## [('u2', 's2', 10), ('u1', 's1', 20)]
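If the DataFrame API is available, a rough equivalent (a sketch only; it assumes a SQLContext/SparkSession exists so that toDF works on the RDD of tuples, and the column names are illustrative) is a groupBy aggregation followed by a join to recover the item column:

from pyspark.sql import functions as F

# Rough DataFrame sketch; column names are illustrative.
df = data_file.toDF(["user", "item", "occurrences"])
max_occ = df.groupBy("user").agg(F.max("occurrences").alias("occurrences"))
# Join back on (user, occurrences) to recover each user's item for the max row
result = df.join(max_occ, ["user", "occurrences"]).select("user", "item", "occurrences")
result.show()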

Can you explain this code (lambda x1, x2: max(x1, x2, key=lambda x: x[-1]))? - WoodChopper

@WoodChopper max is just the standard Python max. It takes the elements and returns the largest one. The key argument describes how the elements should be compared (here, by their last item). - zero323
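To make the key comparison concrete, here it is in plain Python (no Spark involved), applied to two of the triplets:

# max compares the tuples by their last element (the occurrence count)
a = ('u1', 's1', 20)
b = ('u1', 's2', 5)
print(max(a, b, key=lambda x: x[-1]))  # ('u1', 's1', 20)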

I think I found a solution:
from pyspark import SparkContext, SparkConf

def reduce_by_max(triplets):
    """
    Helper function to find the triplet with the maximum occurrence value
    in a list of ('user', 'item', 'occurrences') triplets.
    """
    max_val = triplets[0][2]
    the_index = 0

    for idx, val in enumerate(triplets):
        if val[2] > max_val:
            max_val = val[2]
            the_index = idx

    return triplets[the_index]

conf = SparkConf() \
    .setAppName("Collaborative Filter") \
    .set("spark.executor.memory", "5g")
sc = SparkContext(conf=conf)

# some file contains tuples ('user', 'item', 'occurrences')
data_file = sc.textFile('file:///some_file.txt')

# Create the triplet so I can index stuff
data_file = data_file.map(lambda l: l.split()).map(lambda l: (l[0], l[1], float(l[2])))

# Group by the user i.e. r[0]
grouped = data_file.groupBy(lambda r: r[0])

# Get the values as a list
group_list = grouped.map(lambda x: (list(x[1]))) 

# Get the max value for each user. 
max_list = group_list.map(reduce_by_max).collect()
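As a side note, the reduce_by_max helper could likely be replaced with Python's built-in max (a sketch, assuming the same group_list as above):

# Same step with the built-in max instead of the helper function
max_list = group_list.map(lambda triplets: max(triplets, key=lambda t: t[2])).collect()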
