在PySpark中计算特定值的连续出现次数

3

我有一个名为info的列也已经定义好了:

|     Timestamp     |   info   |
+-------------------+----------+
|2016-01-01 17:54:30|     0    |
|2016-02-01 12:16:18|     0    |
|2016-03-01 12:17:57|     0    |
|2016-04-01 10:05:21|     0    |
|2016-05-11 18:58:25|     1    |
|2016-06-11 11:18:29|     1    |
|2016-07-01 12:05:21|     0    |
|2016-08-11 11:58:25|     0    |
|2016-09-11 15:18:29|     1    |

我希望可以统计连续出现的1的次数并插入0。最终列应该是:
--------------------+----------+----------+
|     Timestamp     |   info   |    res   |
+-------------------+----------+----------+
|2016-01-01 17:54:30|     0    |     0    |
|2016-02-01 12:16:18|     0    |     0    |
|2016-03-01 12:17:57|     0    |     0    |
|2016-04-01 10:05:21|     0    |     0    |
|2016-05-11 18:58:25|     1    |     1    |
|2016-06-11 11:18:29|     1    |     2    |
|2016-07-01 12:05:21|     0    |     0    |
|2016-08-11 11:58:25|     0    |     0    |
|2016-09-11 15:18:29|     1    |     1    |

我尝试使用以下函数,但它没有起作用。
df_input = df_input.withColumn(
    "res",
    F.when(
        df_input.info == F.lag(df_input.info).over(w1),
        F.sum(F.lit(1)).over(w1)
    ).otherwise(0)
)

1
这里的w1是什么?是否有一个ID字段来记录info的顺序? - samkart
2
我不确定它是否有效。你需要一个列来orderBy你的值。当PySpark处理数据时,它不会保持顺序。 - Mykola Zotko
@samkart 我有一个时间戳列被我省略了。我的数据框以此为顺序排列。窗口被定义为w1= Window.partitionBy().orderBy('Timestamp')。 - Babbara
@Babbara 然后在你的示例中添加这一列。 - Mykola Zotko
@MykolaZotko 我做到了。 - Babbara
3个回答

3

来自向表中添加一个计算累积前面重复值的列,感谢@blackbishop

from pyspark.sql import functions as F, Window

df = spark.createDataFrame([0, 0, 0, 0, 1, 1, 0, 0, 1], 'int').toDF('info')

df.withColumn("ID", F.monotonically_increasing_id()) \
    .withColumn("group",
            F.row_number().over(Window.orderBy("ID"))
            - F.row_number().over(Window.partitionBy("info").orderBy("ID"))
    ) \
    .withColumn("Result", F.when(F.col('info') != 0, F.row_number().over(Window.partitionBy("group").orderBy("ID"))).otherwise(F.lit(0)))\
    .orderBy("ID")\
    .drop("ID", "group")\
    .show()

+----+------+
|info|Result|
+----+------+
|   0|     0|
|   0|     0|
|   0|     0|
|   0|     0|
|   1|     1|
|   1|     2|
|   0|     0|
|   0|     0|
|   1|     1|
+----+------+

在这个例子中,它也在计算0的出现次数。正如您在我的“结果”列中所看到的,我只计算连续的1。 - Babbara
@Babbara,你可以随时使用when()来摆脱它们。 - samkart
@samkart 在代码的哪个部分?在哪里创建了分组列? - Babbara
@Babbara使用Result创建你的res。查看withColumn()以获取group字段--最终会被删除。 - samkart
修改为仅考虑非零出现次数。 - Luiz Viola

3

简述 -- 复杂方法

我们遇到了类似的问题,希望通过逐行处理来查看前一行的计算字段。有多个计算需要跟踪,我们采用了一个 rdd 方法,并将 Python 函数发送到所有工作节点以进行最佳分布式处理。以下是基于该方法的示例:

创建与您的问题相同的虚拟数据。

data_ls = [
    (1, 0,),
    (2, 0,),
    (3, 0,),
    (4, 1,),
    (5, 1,),
    (6, 0,),
    (7, 1,)
]

data_sdf = spark.sparkContext.parallelize(data_ls).toDF(['ts', 'info'])

# +---+----+
# | ts|info|
# +---+----+
# |  1|   0|
# |  2|   0|
# |  3|   0|
# |  4|   1|
# |  5|   1|
# |  6|   0|
# |  7|   1|
# +---+----+

我们的方法是创建一个Python函数,以跟踪当前字段中先前计算的字段。使用该函数在数据框的rdd上使用flatMapValues()
def custom_cumcount(groupedRows):
    """
    keep track of the previously calculated result and use in the current calculation
    ship this for optimum resource usage
    """

    res = []
    prev_sumcol = 0

    for row in groupedRows:
        if row.info == 0:
            sum_col = 0
        else:
            sum_col = prev_sumcol + row.info
        
        prev_sumcol = sum_col

        res.append([col for col in row] + [sum_col])

    return res

# create a schema to be used for result's dataframe
data_sdf_schema_new = data_sdf.withColumn('dropme', func.lit(None).cast('int')). \
    drop('dropme'). \
    schema. \
    add('sum_col', 'integer')

# StructType(List(StructField(ts,LongType,true),StructField(info,LongType,true),StructField(sum_col,IntegerType,true)))

# run the function on the data
data_rdd = data_sdf.rdd. \
    groupBy(lambda i: 1). \
    flatMapValues(lambda k: custom_cumcount(sorted(k, key=lambda s: s.ts))). \
    values()

# create dataframe from resulting rdd
spark.createDataFrame(data_rdd, schema=data_sdf_schema_new). \
    show()

# +---+----+-------+
# | ts|info|sum_col|
# +---+----+-------+
# |  1|   0|      0|
# |  2|   0|      0|
# |  3|   0|      0|
# |  4|   1|      1|
# |  5|   1|      2|
# |  6|   0|      0|
# |  7|   1|      1|
# +---+----+-------+

2

这里有另外一种方法,使用条件运行总和来创建组,然后使用该列进行累积求和:

from pyspark.sql import Window, functions as F

w1 = Window.orderBy("Timestamp")
w2 = Window.partitionBy("grp").orderBy("Timestamp")

df1 = (df.withColumn("grp", F.sum(F.when(F.col("info") == 1, 0).otherwise(1)).over(w1))
       .withColumn("res", F.sum("info").over(w2))
       .drop("grp")
       )

df1.show()
# +-------------------+----+---+
# |          Timestamp|info|res|
# +-------------------+----+---+
# |2016-01-01 17:54:30|   0|  0|
# |2016-02-01 12:16:18|   0|  0|
# |2016-03-01 12:17:57|   0|  0|
# |2016-04-01 10:05:21|   0|  0|
# |2016-05-11 18:58:25|   1|  1|
# |2016-06-11 11:18:29|   1|  2|
# |2016-07-01 12:05:21|   0|  0|
# |2016-08-11 11:58:25|   0|  0|
# |2016-09-11 15:18:29|   1|  1|
# +-------------------+----+---+

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接