Group by a column and create lists of the other columns, preserving order


I have a PySpark DataFrame that looks like this:

Id               timestamp           col1               col2
abc                789                0                  1
def                456                1                  0
abc                123                1                  0
def                321                0                  1

I want to group or partition by the Id column and then create lists of col1 and col2 in timestamp order:
Id               timestamp            col1             col2
abc              [123,789]           [1,0]             [0,1]
def              [321,456]           [0,1]             [1,0]

My approach:

from pyspark.sql import functions as F
from pyspark.sql import Window as W

window_spec = W.partitionBy("Id").orderBy('timestamp')
ranged_spec = window_spec.rowsBetween(W.unboundedPreceding, W.unboundedFollowing)

df1 = df.withColumn("col1", F.collect_list("col1").over(window_spec))\
    .withColumn("col2", F.collect_list("col2").over(window_spec))
df1.show()

But this does not return the lists for col1 and col2.
1 Answer


I don't think a groupBy aggregation can reliably preserve order, so window functions seem to be the way to go.

Setup:

from pyspark.sql import functions as F, Window as W
df = spark.createDataFrame(
    [('abc', 789, 0, 1),
     ('def', 456, 1, 0),
     ('abc', 123, 1, 0),
     ('def', 321, 0, 1)],
    ['Id', 'timestamp', 'col1', 'col2'])

Script:

# w1 orders each Id partition by timestamp, so collect_list builds the lists in
# timestamp order (the default frame only reaches up to the current row, so the
# lists grow row by row). w2 orders descending, so row_number() = 1 marks the
# row whose lists are complete.
w1 = W.partitionBy('Id').orderBy('timestamp')
w2 = W.partitionBy('Id').orderBy(F.desc('timestamp'))
df = df.select(
    'Id',
    *[F.collect_list(c).over(w1).alias(c) for c in df.columns if c != 'Id']
)
df = (df
    .withColumn('_rn', F.row_number().over(w2))
    .filter('_rn=1')
    .drop('_rn')
)

Result:

df.show()
# +---+----------+------+------+
# | Id| timestamp|  col1|  col2|
# +---+----------+------+------+
# |abc|[123, 789]|[1, 0]|[0, 1]|
# |def|[321, 456]|[0, 1]|[1, 0]|
# +---+----------+------+------+

You were also quite close to what you need. The catch is that your code applies window_spec, whose default frame (with an orderBy) only extends up to the current row, so collect_list returns growing per-row lists; using the ranged_spec you already defined, plus drop_duplicates, gives the desired result. I tried it, and this also seems to work:

window_spec = W.partitionBy("Id").orderBy('timestamp')
ranged_spec = window_spec.rowsBetween(W.unboundedPreceding, W.unboundedFollowing)

df1 = (df
    .withColumn("timestamp", F.collect_list("timestamp").over(ranged_spec))
    .withColumn("col1", F.collect_list("col1").over(ranged_spec))
    .withColumn("col2", F.collect_list("col2").over(ranged_spec))
).drop_duplicates()
df1.show()
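
As an aside on the groupBy remark above: a plain groupBy with collect_list does not guarantee order, but a minimal sketch of a groupBy-based workaround (not part of the original answer) is to collect (timestamp, col1, col2) structs, sort the array, and split the fields back out. This assumes the original df from the Setup, relies on sort_array ordering structs by their first field (timestamp), and the names grouped and s are just illustrative.

from pyspark.sql import functions as F

# Collect structs per Id, sort them by timestamp (the first struct field),
# then extract each field as its own array column.
grouped = (df
    .groupBy('Id')
    .agg(F.sort_array(F.collect_list(F.struct('timestamp', 'col1', 'col2'))).alias('s'))
    .select(
        'Id',
        F.col('s.timestamp').alias('timestamp'),
        F.col('s.col1').alias('col1'),
        F.col('s.col2').alias('col2'),
    ))
grouped.show()
# Should produce the same two rows as the window-based result above.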
