使用pySpark - 在滚动窗口中获取最大值行

5

我有一个pyspark数据框,包含以下示例行。我正在尝试在10分钟的时间跨度内获取最大平均值。我尝试使用窗口函数,但无法实现结果。

下面是我的数据框,包含30分钟的随机数据。我期望输出3行,每10分钟1行。

+-------------------+---------+
|         event_time|avg_value|
+-------------------+---------+
|2019-12-29 00:01:00|      9.5|
|2019-12-29 00:02:00|      9.0|
|2019-12-29 00:04:00|      8.0|
|2019-12-29 00:06:00|     21.0|
|2019-12-29 00:08:00|      7.0|
|2019-12-29 00:11:00|      8.5|
|2019-12-29 00:12:00|     11.5|
|2019-12-29 00:14:00|      8.0|
|2019-12-29 00:16:00|     31.0|
|2019-12-29 00:18:00|      8.0|
|2019-12-29 00:21:00|      8.0|
|2019-12-29 00:22:00|     16.5|
|2019-12-29 00:24:00|      7.0|
|2019-12-29 00:26:00|     14.0|
|2019-12-29 00:28:00|      7.0|
+-------------------+---------+

我将使用以下代码进行此操作。
window_spec = Window.partitionBy('event_time').orderBy('event_time').rangeBetween(-60*10,0)
new_df = data.withColumn('rank', rank().over(window_spec))
new_df.show()

但是这段代码给了我以下错误:
pyspark.sql.utils.AnalysisException: 'Window Frame specifiedwindowframe(RangeFrame, -600, currentrow$()) must match the required frame specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$());'

我的期望输出是

+-------------------+---------+
|         event_time|avg_value|
+-------------------+---------+
|2019-12-29 00:06:00|     21.0|
|2019-12-29 00:16:00|     31.0|
|2019-12-29 00:22:00|     16.5|
+-------------------+---------+

有人可以帮我解决这个问题吗?

谢谢。

2个回答

7
你可以使用 groupBy 结合 window 使用。
from pyspark.sql import functions as F
df.groupBy(F.window("event_time","10 minutes"))\
  .agg(F.max("avg_value").alias("avg_value")).show()

#+--------------------+---------+
#|              window|avg_value|
#+--------------------+---------+
#|[2019-12-29 00:20...|     16.5|
#|[2019-12-29 00:10...|     31.0|
#|[2019-12-29 00:00...|     21.0|
#+--------------------+---------+

要获取您所需的 event_time 列的精确输出,您可以使用 collect_listarray_sortelement_atspark2.4+)函数。
from pyspark.sql import functions as F
df.groupBy(F.window("event_time","10 minutes"))\
  .agg(F.element_at(F.array_sort(F.collect_list("event_time")),-2).alias("event_time"),\
       F.max("avg_value").alias("avg_value")).drop("window").orderBy("event_time").show()

#+-------------------+---------+
#|event_time         |avg_value|
#+-------------------+---------+
#|2019-12-29 00:06:00|21.0     |
#|2019-12-29 00:16:00|31.0     |
#|2019-12-29 00:26:00|16.5     |
#+-------------------+---------+

更新:

df.groupBy(F.window("event_time","10 minutes"))\
  .agg(F.collect_list(F.struct("event_time","avg_value")).alias("event_time")\
       ,F.max("avg_value").alias("avg_value"))\
  .withColumn("event_time", F.expr("""filter(event_time, x-> x.avg_value=avg_value)"""))\
        .select((F.col("event_time.event_time")[0]).alias("event_time"),"avg_value").orderBy("event_time").show()

#+-------------------+---------+
#|         event_time|avg_value|
#+-------------------+---------+
#|2019-12-29 00:06:00|     21.0|
#|2019-12-29 00:16:00|     31.0|
#|2019-12-29 00:22:00|     16.5|
#+-------------------+---------+

2
谢谢。这真的很有帮助。唯一我能看到的问题是最大值与事件时间不匹配。例如:对于21-30分钟的窗口,最大值为第22分钟的16.5,而不是第26分钟。当我在整个数据框上运行此代码时,我在多个地方都看到了这种情况。 - user3497321
1
非常完美..!! 非常感谢。我会尽力理解你的代码。 - user3497321
1
没问题。我使用了 高阶函数 filter 来遍历这两列的结构,然后获取我们所需的时间。看一下这些函数,它们非常有帮助。 - murtihash

5

您的数据

data = [
    ('2019-12-29 00:01:00', 9.5,),
    ('2019-12-29 00:02:00', 9.0,),
    ('2019-12-29 00:04:00', 8.0,),
    ('2019-12-29 00:06:00', 21.0,),
    ('2019-12-29 00:08:00', 7.0,),
    ('2019-12-29 00:11:00', 8.5,),
    ('2019-12-29 00:12:00', 11.5,),
    ('2019-12-29 00:14:00', 8.0,),
    ('2019-12-29 00:16:00', 31.0,),
    ('2019-12-29 00:18:00', 8.0,),
    ('2019-12-29 00:21:00', 8.0,),
    ('2019-12-29 00:22:00', 16.5,),
    ('2019-12-29 00:24:00', 7.0,),
    ('2019-12-29 00:26:00', 14.0,),
    ('2019-12-29 00:28:00', 7.0,),
]
df = spark.createDataFrame(data, ['event_time', 'avg_value'])

解决方案

from pyspark.sql import Window
from pyspark.sql.functions import window, max, col

w = Window().partitionBy('group_col')

(
    df.
        withColumn(
            'group_col',
            window('event_time', '10 minutes')
        ).
        withColumn(
            'max_val',
            max(col('avg_value')).over(w)
        ).
        where(
            col('avg_value') == col('max_val')
        ).
        drop(
            'max_val',
            'group_col'
        ).
        orderBy('event_time').
        show(truncate=False)
)

+-------------------+---------+                                                 
|event_time         |avg_value|
+-------------------+---------+
|2019-12-29 00:06:00|21.0     |
|2019-12-29 00:16:00|31.0     |
|2019-12-29 00:22:00|16.5     |
+-------------------+---------+

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接