使用复杂条件的Spark SQL窗口函数

25

这可能最容易通过示例来解释。假设我有一个网站用户登录的DataFrame,例如:

scala> df.show(5)
+----------------+----------+
|       user_name|login_date|
+----------------+----------+
|SirChillingtonIV|2012-01-04|
|Booooooo99900098|2012-01-04|
|Booooooo99900098|2012-01-06|
|  OprahWinfreyJr|2012-01-10|
|SirChillingtonIV|2012-01-11|
+----------------+----------+
only showing top 5 rows

我想在表格中添加一列,用于指示用户何时成为了该网站的活跃用户。但是有一个限制条件:在一段时间内,用户才被认为是活跃用户,在此期间之后,如果他们再次登录,则他们的 became_active 日期将被重置。假设这段时间是5天。那么从上面的表格中得到的所需表格将类似于以下内容:

+----------------+----------+-------------+
|       user_name|login_date|became_active|
+----------------+----------+-------------+
|SirChillingtonIV|2012-01-04|   2012-01-04|
|Booooooo99900098|2012-01-04|   2012-01-04|
|Booooooo99900098|2012-01-06|   2012-01-04|
|  OprahWinfreyJr|2012-01-10|   2012-01-10|
|SirChillingtonIV|2012-01-11|   2012-01-11|
+----------------+----------+-------------+

因此,特别地,SirChillingtonIV的became_active日期被重置,因为他们的第二次登录在活动期限过期之后,但是Booooooo99900098的became_active日期在第二次登录时没有被重置,因为它在活动期间内。

我的初步想法是使用带有lag的窗口函数,然后使用lag值来填充became_active列。例如,大致如下:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

val window = Window.partitionBy("user_name").orderBy("login_date")
val df2 = df.withColumn("tmp", lag("login_date", 1).over(window))

那么,填写became_active日期的规则将是,如果tmpnull(即,如果这是第一次登录),或者如果login_date - tmp >= 5,则 became_active = login_date ; 否则,转到tmp中下一个最近的值,并应用相同的规则。这表明了一种递归方法,我无法想象如何实现。

我的问题是:这是可行的方法吗?如果是,我如何“返回”并查看tmp之前的早期值,直到找到一个停止的值?据我所知,我不能迭代Spark SQL Column的值。是否有另一种方法来实现此结果?

2个回答

44

Spark >= 3.2

最近的 Spark 发布版本提供了原生支持会话窗口的功能,可以在批处理和结构化流查询中使用(请参见 SPARK-10816 和其子任务,特别是 SPARK-34893)。

官方文档提供了一个很好的用法示例

Spark < 3.2

这里有个技巧。导入一堆函数:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{coalesce, datediff, lag, lit, min, sum}

定义Windows操作系统:

val userWindow = Window.partitionBy("user_name").orderBy("login_date")
val userSessionWindow = Window.partitionBy("user_name", "session")

找到新会话开始的点:

val newSession =  (coalesce(
  datediff($"login_date", lag($"login_date", 1).over(userWindow)),
  lit(0)
) > 5).cast("bigint")

val sessionized = df.withColumn("session", sum(newSession).over(userWindow))

找出每个会话的最早日期:

val result = sessionized
  .withColumn("became_active", min($"login_date").over(userSessionWindow))
  .drop("session")

数据集定义为:

val df = Seq(
  ("SirChillingtonIV", "2012-01-04"), ("Booooooo99900098", "2012-01-04"),
  ("Booooooo99900098", "2012-01-06"), ("OprahWinfreyJr", "2012-01-10"), 
  ("SirChillingtonIV", "2012-01-11"), ("SirChillingtonIV", "2012-01-14"),
  ("SirChillingtonIV", "2012-08-11")
).toDF("user_name", "login_date")

结果是:

+----------------+----------+-------------+
|       user_name|login_date|became_active|
+----------------+----------+-------------+
|  OprahWinfreyJr|2012-01-10|   2012-01-10|
|SirChillingtonIV|2012-01-04|   2012-01-04| <- The first session for user
|SirChillingtonIV|2012-01-11|   2012-01-11| <- The second session for user
|SirChillingtonIV|2012-01-14|   2012-01-11| 
|SirChillingtonIV|2012-08-11|   2012-08-11| <- The third session for user
|Booooooo99900098|2012-01-04|   2012-01-04|
|Booooooo99900098|2012-01-06|   2012-01-04|
+----------------+----------+-------------+

1
我知道已经很久了,但你能帮我理解解决方案中的coalesce部分吗? - Sanchit Grover
2
如果datediff($"login_date", lag($"login_date", 1).over(userWindow))的结果为null(在窗口的第一行),则返回0。 - zero323
那么这个 val sessionized = df.withColumn("session", sum(newSession).over(userWindow)) 是如何增加计数的呢? - Sanchit Grover
这是集合{0, 1}中数值的累加和。 - zero323

6

重构其他答案,使其适用于Pyspark

Pyspark中,您可以像下面这样做。

创建数据框

df = sqlContext.createDataFrame(
[
("SirChillingtonIV", "2012-01-04"), 
("Booooooo99900098", "2012-01-04"), 
("Booooooo99900098", "2012-01-06"), 
("OprahWinfreyJr", "2012-01-10"), 
("SirChillingtonIV", "2012-01-11"), 
("SirChillingtonIV", "2012-01-14"), 
("SirChillingtonIV", "2012-08-11")
], 
("user_name", "login_date"))

以上代码创建一个如下所示的数据框:
+----------------+----------+
|       user_name|login_date|
+----------------+----------+
|SirChillingtonIV|2012-01-04|
|Booooooo99900098|2012-01-04|
|Booooooo99900098|2012-01-06|
|  OprahWinfreyJr|2012-01-10|
|SirChillingtonIV|2012-01-11|
|SirChillingtonIV|2012-01-14|
|SirChillingtonIV|2012-08-11|
+----------------+----------+

现在我们想要首先找出login_date超过5天的差异。

请按照以下步骤进行操作。

必要的导入

from pyspark.sql import functions as f
from pyspark.sql import Window


# defining window partitions  
login_window = Window.partitionBy("user_name").orderBy("login_date")
session_window = Window.partitionBy("user_name", "session")

session_df = df.withColumn("session", f.sum((f.coalesce(f.datediff("login_date", f.lag("login_date", 1).over(login_window)), f.lit(0)) > 5).cast("int")).over(login_window))

当我们运行上述代码时,如果date_diffNULL,那么coalesce函数将会把NULL替换为0
+----------------+----------+-------+
|       user_name|login_date|session|
+----------------+----------+-------+
|  OprahWinfreyJr|2012-01-10|      0|
|SirChillingtonIV|2012-01-04|      0|
|SirChillingtonIV|2012-01-11|      1|
|SirChillingtonIV|2012-01-14|      1|
|SirChillingtonIV|2012-08-11|      2|
|Booooooo99900098|2012-01-04|      0|
|Booooooo99900098|2012-01-06|      0|
+----------------+----------+-------+


# add became_active column by finding the `min login_date` for each window partitionBy `user_name` and `session` created in above step
final_df = session_df.withColumn("became_active", f.min("login_date").over(session_window)).drop("session")

+----------------+----------+-------------+
|       user_name|login_date|became_active|
+----------------+----------+-------------+
|  OprahWinfreyJr|2012-01-10|   2012-01-10|
|SirChillingtonIV|2012-01-04|   2012-01-04|
|SirChillingtonIV|2012-01-11|   2012-01-11|
|SirChillingtonIV|2012-01-14|   2012-01-11|
|SirChillingtonIV|2012-08-11|   2012-08-11|
|Booooooo99900098|2012-01-04|   2012-01-04|
|Booooooo99900098|2012-01-06|   2012-01-04|
+----------------+----------+-------------+

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接