在PySpark数据框中使用flatmap和collect_set函数

5

我有两个数据框,使用groupby后,在agg中使用collect_set()。在聚合后,flatMap结果数组的最佳方法是什么。

schema = ['col1', 'col2', 'col3', 'col4']

a = [[1, [23, 32], [11, 22], [9989]]]

df1 = spark.createDataFrame(a, schema=schema)

b = [[1, [34], [43, 22], [888, 777]]]

df2 = spark.createDataFrame(b, schema=schema)

df = df1.union(
        df2
    ).groupby(
        'col1'
    ).agg(
        collect_set('col2').alias('col2'),
        collect_set('col3').alias('col3'),
        collect_set('col4').alias('col4')
    )

df.collect()

我得到这个输出:
[Row(col1=1, col2=[[34], [23, 32]], col3=[[11, 22], [43, 22]], col4=[[9989], [888, 777]])]

但是,我希望输出如下所示:
[Row(col1=1, col2=[23, 32, 34], col3=[11, 22, 43], col4=[9989, 888, 777])]
2个回答

3
你可以使用udf:
from itertools import chain
from pyspark.sql.types import *
from pyspark.sql.functions import udf

flatten = udf(lambda x: list(chain.from_iterable(x)), ArrayType(IntegerType()))

df.withColumn('col2_flat', flatten('col2'))

1

没有使用UDF的话,我认为这应该可以工作:

from pyspark.sql.functions import array_distinct, flatten

df.withColumn('col2_flat', array_distinct(flatten('col2')))

它将展平嵌套的数组,然后去重。

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接