在pyspark中检索DataFrame每个组的前n个

70

在pyspark中有一个DataFrame,数据如下:

user_id object_id score
user_1  object_1  3
user_1  object_1  1
user_1  object_2  2
user_2  object_1  5
user_2  object_2  2
user_2  object_2  6

我希望每组返回相同用户ID的2条记录,并且这些记录需要具有最高分数。因此,结果应如下所示:

user_id object_id score
user_1  object_1  3
user_1  object_2  2
user_2  object_2  6
user_2  object_1  5

我对pyspark很陌生,有人能提供一段代码片段或相关文档的链接吗?非常感谢!

6个回答

112

我认为你需要使用窗口函数,根据user_idscore来获得每一行的排名,并随后筛选结果仅保留前两个值。

from pyspark.sql.window import Window
from pyspark.sql.functions import rank, col

window = Window.partitionBy(df['user_id']).orderBy(df['score'].desc())

df.select('*', rank().over(window).alias('rank')) 
  .filter(col('rank') <= 2) 
  .show() 
#+-------+---------+-----+----+
#|user_id|object_id|score|rank|
#+-------+---------+-----+----+
#| user_1| object_1|    3|   1|
#| user_1| object_2|    2|   2|
#| user_2| object_2|    6|   1|
#| user_2| object_1|    5|   2|
#+-------+---------+-----+----+

一般来说,官方的编程指南是学习Spark的好起点。

数据

rdd = sc.parallelize([("user_1",  "object_1",  3), 
                      ("user_1",  "object_2",  2), 
                      ("user_2",  "object_1",  5), 
                      ("user_2",  "object_2",  2), 
                      ("user_2",  "object_2",  6)])
df = sqlContext.createDataFrame(rdd, ["user_id", "object_id", "score"])

我认为有些需要调整的地方。object_id对于groupbytop过程都没有影响。而我想要的是按user_id进行分组,在每个组中,分别检索前两条最高分记录,而不仅仅是第一条记录。非常感谢! - KAs
4
你可以在筛选器中使用窗口函数:df.filter(rank().over(window) <= 2) - Wilmerton
2
我感到非常惊讶... 我曾经确信我在过滤器中使用过窗口函数。但是我确实无法重现它(无论是在2还是1.6中)。我确实以一种奇特的方式使用了它,但我不记得是何时或如何使用的。抱歉。 - Wilmerton
6
如果您想获取相同排名的前n个结果,建议使用row_number而不是rank - Tomer Ben David
@TomerBenDavid 这条评论值得更多的赞,谢谢您。 - rluo

39

如果在获取排名相同的情况下使用row_number而不是rank,则Top-n更加准确:

val n = 5
df.select(col('*'), row_number().over(window).alias('row_number')) \
  .where(col('row_number') <= n) \
  .limit(20) \
  .toPandas()

注意在Jupyter笔记本中使用limit(20).toPandas()技巧代替show()以获得更好的格式。


2
请记得添加 from pyspark.sql.functions import row_number 以使其正常工作。 - Tapa Dipti Sitaula
哪种计算方式更高效(快速)?我怀疑它们差不多。有没有更高效的方法?我正在处理一个 110 GB 的数据集,其中包含 4.7M 个类别(要进行groupBy操作),每个类别大约有 4300 行数据,这在一个大型集群上花费了很长时间。 - JOSE DANIEL FERNANDEZ
1
这是描述排名、行号和密集排名之间差异的最佳链接:https://www.c-sharpcorner.com/blogs/difference-between-rownumber-rank-denserank-in-sql-server#:~:text=Row_Number%20%28%29%20will%20generate%20a%20unique%20number%20for,the%20same%20value%20without%20skipping%20the%20next%20number. - HT.

2
我知道这个问题是针对pyspark提出的,我正在寻找类似于Scala的答案,即:

在Scala中检索DataFrame每个组中的前n个值

这里是@mtoto答案的scala版本。

最初的回答
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.rank
import org.apache.spark.sql.functions.col

val window = Window.partitionBy("user_id").orderBy('score desc')
val rankByScore = rank().over(window)
df1.select('*, rankByScore as 'rank).filter(col("rank") <= 2).show() 
# you can change the value 2 to any number you want. Here 2 represents the top 2 values

更多的例子可以在这里找到。原始答案翻译成"最初的回答"。

1
使用Python 3和Spark 2.4
from pyspark.sql import Window
import pyspark.sql.functions as f

def get_topN(df, group_by_columns, order_by_column, n=1):
    window_group_by_columns = Window.partitionBy(group_by_columns)
    ordered_df = df.select(df.columns + [
        f.row_number().over(window_group_by_columns.orderBy(order_by_column.desc())).alias('row_rank')])
    topN_df = ordered_df.filter(f"row_rank <= {n}").drop("row_rank")
    return topN_df

top_n_df = get_topN(your_dataframe, [group_by_columns],[order_by_columns], 1) 

1
这是另一种不使用窗口函数从pySpark DataFrame获取前N条记录的解决方案。
# Import Libraries
from pyspark.sql.functions import col

# Sample Data
rdd = sc.parallelize([("user_1",  "object_1",  3), 
                      ("user_1",  "object_2",  2), 
                      ("user_2",  "object_1",  5), 
                      ("user_2",  "object_2",  2), 
                      ("user_2",  "object_2",  6)])
df = sqlContext.createDataFrame(rdd, ["user_id", "object_id", "score"])

# Get top n records as Row Objects
row_list = df.orderBy(col("score").desc()).head(5)

# Convert row objects to DF
sorted_df = spark.createDataFrame(row_list)

# Display DataFrame
sorted_df.show()

输出

+-------+---------+-----+
|user_id|object_id|score|
+-------+---------+-----+
| user_1| object_2|    2|
| user_2| object_2|    2|
| user_1| object_1|    3|
| user_2| object_1|    5|
| user_2| object_2|    6|
+-------+---------+-----+

如果您对Spark中的更多窗口函数感兴趣,可以参考我的博客之一:https://medium.com/expedia-group-tech/deep-dive-into-apache-spark-window-functions-7b4e39ad3c86


这个是否比窗口中必要的 order by 进行更少的计算? - Manaslu

0

使用ROW_NUMBER()函数在PYSPARK SQL查询中查找第N个最高值:

SELECT * FROM (
    SELECT e.*, 
    ROW_NUMBER() OVER (ORDER BY col_name DESC) rn 
    FROM Employee e
)
WHERE rn = N

N是从该列中所需的第n个最高值

输出:

[Stage 2:>               (0 + 1) / 1]++++++++++++++++
+-----------+
|col_name   |
+-----------+
|1183395    |
+-----------+

查询将返回N个最高值

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接