如何在Spark SQL中对多个列进行数据透视?

32

我需要在PySpark数据框中旋转多列。示例数据框:

from pyspark.sql import functions as F
d = [(100,1,23,10),(100,2,45,11),(100,3,67,12),(100,4,78,13),(101,1,23,10),(101,2,45,13),(101,3,67,14),(101,4,78,15),(102,1,23,10),(102,2,45,11),(102,3,67,16),(102,4,78,18)]
mydf = spark.createDataFrame(d,['id','day','price','units'])
mydf.show()
# +---+---+-----+-----+
# | id|day|price|units|
# +---+---+-----+-----+
# |100|  1|   23|   10|
# |100|  2|   45|   11|
# |100|  3|   67|   12|
# |100|  4|   78|   13|
# |101|  1|   23|   10|
# |101|  2|   45|   13|
# |101|  3|   67|   14|
# |101|  4|   78|   15|
# |102|  1|   23|   10|
# |102|  2|   45|   11|
# |102|  3|   67|   16|
# |102|  4|   78|   18|
# +---+---+-----+-----+t

现在,如果我需要按天将每个id的价格列转换为行,则可以使用pivot方法:

pvtdf = mydf.withColumn('combcol', F.concat(F.lit('price_'), mydf['day'])).groupby('id').pivot('combcol').agg(F.first('price'))
pvtdf.show()
# +---+-------+-------+-------+-------+
# | id|price_1|price_2|price_3|price_4|
# +---+-------+-------+-------+-------+
# |100|     23|     45|     67|     78|
# |101|     23|     45|     67|     78|
# |102|     23|     45|     67|     78|
# +---+-------+-------+-------+-------+

所以当我需要将单位列和价格一起转置时,我必须像上面那样创建另一个数据框来处理单位,然后使用"id"将它与价格数据框进行join。但是,如果我有更多类似的列,则可以尝试使用函数来处理。

def pivot_udf(df, *cols):
    mydf = df.select('id').drop_duplicates()
    for c in cols:
       mydf = mydf.join(df.withColumn('combcol', F.concat(F.lit('{}_'.format(c)), df['day'])).groupby('id').pivot('combcol').agg(F.first(c)),' id')
    return mydf

pivot_udf(mydf, 'price', 'units').show()
# +---+-------+-------+-------+-------+-------+-------+-------+-------+
# | id|price_1|price_2|price_3|price_4|units_1|units_2|units_3|units_4|
# +---+-------+-------+-------+-------+-------+-------+-------+-------+
# |100|     23|     45|     67|     78|     10|     11|     12|     13|
# |101|     23|     45|     67|     78|     10|     13|     14|     15|
# |102|     23|     45|     67|     78|     10|     11|     16|     18|
# +---+-------+-------+-------+-------+-------+-------+-------+-------+

这样做是一个好习惯吗?还有没有其他更好的方法?


请参考此链接,希望能对您有所帮助![https://dev59.com/OloU5IYBdhLWcg3wamca][1] - Manu Gupta
4个回答

29

这里有一种非UDF的方法,它只涉及一个数据透视表(因此仅需进行一次单列扫描即可识别所有唯一日期)。

dff = mydf.groupBy('id').pivot('day').agg(F.first('price').alias('price'),F.first('units').alias('unit'))

这是结果(由于顺序和名称不匹配而道歉):

+---+-------+------+-------+------+-------+------+-------+------+               
| id|1_price|1_unit|2_price|2_unit|3_price|3_unit|4_price|4_unit|
+---+-------+------+-------+------+-------+------+-------+------+
|100|     23|    10|     45|    11|     67|    12|     78|    13|
|101|     23|    10|     45|    13|     67|    14|     78|    15|
|102|     23|    10|     45|    11|     67|    16|     78|    18|
+---+-------+------+-------+------+-------+------+-------+------+

在按日期进行数据透视后,我们只是对 priceunit 两列进行聚合。

如果问题需要命名,请参考以下内容:

dff.select([F.col(c).name('_'.join(x for x in c.split('_')[::-1])) for c in dff.columns]).show()

+---+-------+------+-------+------+-------+------+-------+------+
| id|price_1|unit_1|price_2|unit_2|price_3|unit_3|price_4|unit_4|
+---+-------+------+-------+------+-------+------+-------+------+
|100|     23|    10|     45|    11|     67|    12|     78|    13|
|101|     23|    10|     45|    13|     67|    14|     78|    15|
|102|     23|    10|     45|    11|     67|    16|     78|    18|
+---+-------+------+-------+------+-------+------+-------+------+

5

这个问题的解决方案是我所能得到的最好的。唯一的改进就是缓存输入数据集以避免双重扫描,即:

mydf.cache
pivot_udf(mydf,'price','units').show()

3

这是一个示例,展示了如何使用多个列进行分组、旋转和聚合

在使用多个列进行旋转时,并不是一件简单的事情,你首先需要创建另外一列用于旋转。

输入:

from pyspark.sql import functions as F
df = spark.createDataFrame(
    [('clsA', 'id1', 'a', 'x', 100, 15),
     ('clsA', 'id1', 'a', 'x', 110, 16),
     ('clsA', 'id1', 'a', 'y', 105, 14),
     ('clsA', 'id2', 'a', 'y', 110, 14),
     ('clsA', 'id1', 'b', 'y', 100, 13),
     ('clsA', 'id1', 'b', 'x', 120, 16),
     ('clsA', 'id2', 'b', 'y', 120, 17)],
    ['cls', 'id', 'grp1', 'grp2', 'price', 'units'])

聚合:

df = df.withColumn('_pivot', F.concat_ws('_', 'grp1', 'grp2'))
df = df.groupBy('cls', 'id').pivot('_pivot').agg(
    F.first('price').alias('price'),
    F.first('units').alias('unit')
)
df.show()
# +----+---+---------+--------+---------+--------+---------+--------+---------+--------+
# | cls| id|a_x_price|a_x_unit|a_y_price|a_y_unit|b_x_price|b_x_unit|b_y_price|b_y_unit|
# +----+---+---------+--------+---------+--------+---------+--------+---------+--------+
# |clsA|id2|     null|    null|      110|      14|     null|    null|      120|      17|
# |clsA|id1|      100|      15|      105|      14|      120|      16|      100|      13|
# +----+---+---------+--------+---------+--------+---------+--------+---------+--------+

2

在 Spark 1.6 版本中,我认为这是唯一的方法,因为 pivot 只接受一个列名,而第二个属性值可以传递该列的不同值,这将使您的代码运行更快,否则 Spark 必须为您运行,所以是的,这是正确的方法。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接