PySpark - preserve null values when using collect_list

According to the accepted answer in pyspark collect_set or collect_list with groupby, when you do a collect_list on a column, the null values in that column are dropped. I have verified this myself.

But in my case I need to keep the null values. How can I achieve this?

I have not found any information about a variant of collect_list that does this.


Background on why I need the null values:

I have a dataframe df as below:

cId   |  eId  |  amount  |  city
1     |  2    |   20.0   |  Paris
1     |  2    |   30.0   |  Seoul
1     |  3    |   10.0   |  Phoenix
1     |  3    |   5.0    |  null

I want to write this to an Elasticsearch index with the following mapping:

"mappings": {
    "doc": {
        "properties": {
            "eId": { "type": "keyword" },
            "cId": { "type": "keyword" },
            "transactions": {
                "type": "nested", 
                "properties": {
                    "amount": { "type": "keyword" },
                    "city": { "type": "keyword" }
                }
            }
        }
    }
}
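
For completeness, a sketch of creating an index with this mapping via the low-level elasticsearch Python client; the client choice, the host, and the index name transactions_idx are assumptions, not part of the question:

```
from elasticsearch import Elasticsearch  # assumed client; a plain HTTP PUT works too

es = Elasticsearch("localhost:9200")  # assumed host
mapping = {
    "mappings": {
        "doc": {
            "properties": {
                "eId": {"type": "keyword"},
                "cId": {"type": "keyword"},
                "transactions": {
                    "type": "nested",
                    "properties": {
                        "amount": {"type": "keyword"},
                        "city": {"type": "keyword"},
                    },
                },
            }
        }
    }
}
# Create the index (Elasticsearch 6.x single-type mapping style).
es.indices.create(index="transactions_idx", body=mapping)
```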

To conform to the nested mapping above, I transformed my dataframe so that for each combination of eId and cId I get an array of transactions, like this:
from pyspark.sql.functions import collect_list, struct

df_nested = df.groupBy('eId','cId').agg(collect_list(struct('amount','city')).alias("transactions"))
df_nested.printSchema()
root
 |-- cId: integer (nullable = true)
 |-- eId: integer (nullable = true)
 |-- transactions: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- amount: float (nullable = true)
 |    |    |-- city: string (nullable = true)

After saving df_nested as a JSON file, these are the JSON records I get:
{"cId":1,"eId":2,"transactions":[{"amount":20.0,"city":"Paris"},{"amount":30.0,"city":"Seoul"}]}
{"cId":1,"eId":3,"transactions":[{"amount":10.0,"city":"Phoenix"},{"amount":30.0}]}

As you can see, for cId=1 and eId=3, the array element with amount=5.0 has no city attribute, because city was null in my original data (df). The null value was dropped when the collect_list function was applied.
However, when I try to write df_nested to Elasticsearch with the index above, I get a schema mismatch error. That is essentially why I want to keep my null values after applying collect_list.
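
For context, the failing write would look roughly like the sketch below, assuming the elasticsearch-hadoop connector; the node address and the transactions_idx/doc resource are placeholders, not values from the question:

```
# Write df_nested via the elasticsearch-hadoop connector; es.nodes and
# es.resource below are placeholder assumptions for illustration.
df_nested.write \
    .format("org.elasticsearch.spark.sql") \
    .option("es.nodes", "localhost:9200") \
    .option("es.resource", "transactions_idx/doc") \
    .mode("append") \
    .save()
```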

Is it possible to replace the null values with something else, like the string 'null'? - pault
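
A minimal sketch of that suggestion, with f.lit("null") as the assumed placeholder: coalesce the nullable column before building the structs, so every element keeps a city key.

```
from pyspark.sql import functions as f

# Replace null city values with a placeholder string before collecting,
# so the serialized structs always carry a city field.
df_filled = df.withColumn("city", f.coalesce(f.col("city"), f.lit("null")))
df_nested2 = df_filled.groupBy("eId", "cId") \
    .agg(f.collect_list(f.struct("amount", "city")).alias("transactions"))
```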
1 Answer

    from pyspark.sql import SparkSession
    from pyspark.sql import functions as f
    from pyspark.sql.functions import create_map, collect_list, lit, col

    app_name = "CollList"
    spark = SparkSession.builder.appName(app_name).getOrCreate()

    df = spark.createDataFrame(
        [[1, 2, 20.0, "Paris"], [1, 2, 30.0, "Seoul"],
         [1, 3, 10.0, "Phoenix"], [1, 3, 5.0, None]],
        ["cId", "eId", "amount", "city"])
    print("Actual data")
    df.show(10, False)
```
Actual data
+---+---+------+-------+
|cId|eId|amount|city   |
+---+---+------+-------+
|1  |2  |20.0  |Paris  |
|1  |2  |30.0  |Seoul  |
|1  |3  |10.0  |Phoenix|
|1  |3  |5.0   |null   |
+---+---+------+-------+
```
    # First attempt: collect_list over to_json(struct(...)). to_json omits
    # struct fields whose value is null, so the city key is lost for the null row.
    df1 = df.groupBy(f.col('city')) \
            .agg(f.collect_list(
                     f.to_json(f.struct([f.col(c) for c in df.columns
                                         if c not in ('cId', 'eId')]))
                 ).alias('newcol'))
    print("Collect List Data - Missing Null Columns in the list")
    df1.show(10, False)
```
Collect List Data - Missing Null Columns in the list
+-------+----------------------------------+
|city   |newcol                            |
+-------+----------------------------------+
|Phoenix|[{"amount":10.0,"city":"Phoenix"}]|
|null   |[{"amount":5.0}]                  |
|Paris  |[{"amount":20.0,"city":"Paris"}]  |
|Seoul  |[{"amount":30.0,"city":"Seoul"}]  |
+-------+----------------------------------+
``` 
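The workaround: serialize a map instead of a struct. As the output below shows, to_json keeps a map key even when its value is null, so city survives.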
    # Build the alternating (key, value) argument list that create_map
    # expects, skipping the group-by key columns.
    my_list = []
    for c in (c for c in df.columns if c not in ('cId', 'eId')):
        my_list.append(lit(c))
        my_list.append(col(c))

    grp_by = ["eId", "cId"]
    df_nested = df.withColumn("transactions", create_map(my_list)) \
                  .groupBy(grp_by) \
                  .agg(collect_list(f.to_json("transactions")).alias("transactions"))
    
    print("collect list after create_map")
    df_nested.show(10,False)
```
collect list after create_map
+---+---+--------------------------------------------------------------------+
|eId|cId|transactions                                                        |
+---+---+--------------------------------------------------------------------+
|2  |1  |[{"amount":"20.0","city":"Paris"}, {"amount":"30.0","city":"Seoul"}]|
|3  |1  |[{"amount":"10.0","city":"Phoenix"}, {"amount":"5.0","city":null}]  |
+---+---+--------------------------------------------------------------------+
```   

Note that create_map will coerce key: value to string: string, so the amount values will be strings rather than floats. - pakobill
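
If the string-typed amounts matter downstream, a hedged follow-up (assuming Spark 2.4+ for the transform higher-order function; df_parsed and df_typed are illustrative names) is to parse each JSON string back into a typed struct:

```
from pyspark.sql import functions as f

# Parse every array element back into struct<amount:string, city:string>...
df_parsed = df_nested.withColumn(
    "transactions",
    f.expr("transform(transactions, x -> from_json(x, 'amount STRING, city STRING'))"))

# ...then cast amount back to float, keeping city (including its nulls) as-is.
df_typed = df_parsed.withColumn(
    "transactions",
    f.expr("transform(transactions, t -> named_struct("
           "'amount', cast(t.amount AS float), 'city', t.city))"))
df_typed.printSchema()
```

Whether the null city then survives a later JSON serialization again depends on the writer's null handling, so this fixes the types but not the original null-dropping concern on its own.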
