如何在Pyspark中将列表分成多列?

35

我有:

key   value
a    [1,2,3]
b    [2,3,4]

我想要:

key value1 value2 value3
a     1      2      3
b     2      3      4

在Scala中,似乎可以这样编写:

df.select($“ value._1”,$“ value._2”,$“ value._3”)

但在Python中不可能。那么有没有好的方法来实现类似的功能呢?

5个回答

74

这取决于你的“列表”类型:

  • 如果它是ArrayType()类型:

df = hc.createDataFrame(sc.parallelize([['a', [1,2,3]], ['b', [2,3,4]]]), ["key", "value"])
df.printSchema()
df.show()
root
 |-- key: string (nullable = true)
 |-- value: array (nullable = true)
 |    |-- element: long (containsNull = true)

您可以像使用 Python 一样使用 [] 访问值:

df.select("key", df.value[0], df.value[1], df.value[2]).show()
+---+--------+--------+--------+
|key|value[0]|value[1]|value[2]|
+---+--------+--------+--------+
|  a|       1|       2|       3|
|  b|       2|       3|       4|
+---+--------+--------+--------+

+---+-------+
|key|  value|
+---+-------+
|  a|[1,2,3]|
|  b|[2,3,4]|
+---+-------+
  • 如果它的类型是 StructType():(也许你是通过读取JSON构建你的数据框)

  • df2 = df.select("key", psf.struct(
            df.value[0].alias("value1"), 
            df.value[1].alias("value2"), 
            df.value[2].alias("value3")
        ).alias("value"))
    df2.printSchema()
    df2.show()
    root
     |-- key: string (nullable = true)
     |-- value: struct (nullable = false)
     |    |-- value1: long (nullable = true)
     |    |-- value2: long (nullable = true)
     |    |-- value3: long (nullable = true)
    
    +---+-------+
    |key|  value|
    +---+-------+
    |  a|[1,2,3]|
    |  b|[2,3,4]|
    +---+-------+
    

    你可以直接使用*对列进行分割:

    df2.select('key', 'value.*').show()
    +---+------+------+------+
    |key|value1|value2|value3|
    +---+------+------+------+
    |  a|     1|     2|     3|
    |  b|     2|     3|     4|
    +---+------+------+------+
    

    3
    当我使用*拆分一个StructType列时,能否重命名列? - Benjamin Du
    补充答案,对于数组类型要动态实现,可以这样做:df2.select(['key'] + [df2.features[x] for x in range(0,3)]) - VarunKumar

    12

    我希望在pault的回答中补充关于大小为固定值的列表(数组)的情况。

    如果我们的列包含中等大小的数组(或大型数组),仍然可以将它们分割成列。

    from pyspark.sql.types import *          # Needed to define DataFrame Schema.
    from pyspark.sql.functions import expr   
    
    # Define schema to create DataFrame with an array typed column.
    mySchema = StructType([StructField("V1", StringType(), True),
                           StructField("V2", ArrayType(IntegerType(),True))])
    
    df = spark.createDataFrame([['A', [1, 2, 3, 4, 5, 6, 7]], 
                                ['B', [8, 7, 6, 5, 4, 3, 2]]], schema= mySchema)
    
    # Split list into columns using 'expr()' in a comprehension list.
    arr_size = 7
    df = df.select(['V1', 'V2']+[expr('V2[' + str(x) + ']') for x in range(0, arr_size)])
    
    # It is posible to define new column names.
    new_colnames = ['V1', 'V2'] + ['val_' + str(i) for i in range(0, arr_size)] 
    df = df.toDF(*new_colnames)
    

    结果为:

    df.show(truncate= False)
    
    +---+---------------------+-----+-----+-----+-----+-----+-----+-----+
    |V1 |V2                   |val_0|val_1|val_2|val_3|val_4|val_5|val_6|
    +---+---------------------+-----+-----+-----+-----+-----+-----+-----+
    |A  |[1, 2, 3, 4, 5, 6, 7]|1    |2    |3    |4    |5    |6    |7    |
    |B  |[8, 7, 6, 5, 4, 3, 2]|8    |7    |6    |5    |4    |3    |2    |
    +---+---------------------+-----+-----+-----+-----+-----+-----+-----+
    

    3
    对于数组类型的数据,要进行动态处理,可以这样做:
    df2.select(['key'] + [df2.features[x] for x in range(0,3)])
    

    2
    我需要将一个712维数组转换为列,以便将其写入csv文件。我首先尝试了@MaFF的解决方案,但似乎会导致很多错误和额外的计算时间。我不确定是什么原因导致了这种情况,所以我使用了另一种方法,它显著减少了计算时间(22分钟与4个多小时相比)!
    @MaFF的方法:
    length = len(dataset.head()["list_col"])
    dataset = dataset.select(dataset.columns + [dataset["list_col"][k] for k in range(length)])
    

    我使用的工具:

    dataset = dataset.rdd.map(lambda x: (*x, *x["list_col"])).toDF()
    

    如果有人知道是什么原因导致计算时间的差异,请告诉我!我怀疑在我的情况下,瓶颈在于调用head()来获取列表长度(我希望它是自适应的)。由于(i)我的数据管道非常长且详尽,以及(ii)我必须取消多个列的列表,因此缓存整个数据集不是一个选项。

    1

    @jordi Aceiton 感谢您提供的解决方案。 我试图让它更加简洁,尝试在创建列时删除重命名新创建的列名的循环。 使用 df.columns 获取所有列名而不是手动创建。

    from pyspark.sql.types import *          
    from pyspark.sql.functions import * 
    from pyspark import Row
    
    df = spark.createDataFrame([Row(index=1, finalArray = [1.1,2.3,7.5], c =4),Row(index=2, finalArray = [9.6,4.1,5.4], c= 4)])
    #collecting all the column names as list
    dlist = df.columns
    #Appending new columns to the dataframe
    df.select(dlist+[(col("finalArray")[x]).alias("Value"+str(x+1)) for x in range(0, 3)]).show()
    

    输出:

     +---------------+-----+------+------+------+
     |  finalArray   |index|Value1|Value2|Value3|
     +---------------+-----+------+------+------+
     |[1.1, 2.3, 7.5]|  1  |   1.1|   2.3|   7.5|
     |[9.6, 4.1, 5.4]|  2  |   9.6|   4.1|   5.4|
     +---------------+-----+------+------+------+
    

    NameError: 名称 'col' 未定义 - testin3r

    网页内容由stack overflow 提供, 点击上面的
    可以查看英文原文,
    原文链接