How to flatten a struct in a Spark DataFrame?

77

I have a DataFrame with the following structure:

 |-- data: struct (nullable = true)
 |    |-- id: long (nullable = true)
 |    |-- keyNote: struct (nullable = true)
 |    |    |-- key: string (nullable = true)
 |    |    |-- note: string (nullable = true)
 |    |-- details: map (nullable = true)
 |    |    |-- key: string
 |    |    |-- value: string (valueContainsNull = true)

How can I flatten the struct and create a new DataFrame like this:

     |-- id: long (nullable = true)
     |-- keyNote: struct (nullable = true)
     |    |-- key: string (nullable = true)
     |    |-- note: string (nullable = true)
     |-- details: map (nullable = true)
     |    |-- key: string
     |    |-- value: string (valueContainsNull = true)

Is there something similar to explode, but for structs?


1
The answers at https://dev59.com/9FoU5IYBdhLWcg3wamga are helpful too. - erwaman
1
There's also a nice solution here: https://dev59.com/JRD6s4cB2Jgan1znjkP8?rq=1 - TobiStraub
15 Answers

130

This should work in Spark 1.6 or later:

df.select(df.col("data.*"))

or

df.select(df.col("data.id"), df.col("data.keyNote"), df.col("data.details"))

16
Exception in thread "main" org.apache.spark.sql.AnalysisException: No such struct field *. - djWann
But selecting all the columns like df.select(df.col1, df.col2, df.col3) works, so I'll accept this answer. - djWann
I was just editing and got something strange. Maybe a version issue keeps the * symbol from working? - user6022341
Yeah, maybe. I'm using Spark 1.6.1 and Scala 2.10. - djWann
How do you select key or note under the nested struct keyNote? - SparkleGoat
When I select df.select("data.*") it gives me n*n rows (each row repeated n times). My DataFrame has n-2 distinct rows, so using distinct gives me n-2 results, but I want the n rows actually present in my data. How can I do that with the select command above? - Madman

40
Here's a function that does what you want, and that can deal with multiple nested columns containing columns with the same name:
import pyspark.sql.functions as F

def flatten_df(nested_df):
    # Split the top-level columns into plain columns and struct columns.
    flat_cols = [c[0] for c in nested_df.dtypes if c[1][:6] != 'struct']
    nested_cols = [c[0] for c in nested_df.dtypes if c[1][:6] == 'struct']

    # Keep the plain columns and promote every struct field to a top-level
    # column named <struct>_<field>.
    flat_df = nested_df.select(flat_cols +
                               [F.col(nc + '.' + c).alias(nc + '_' + c)
                                for nc in nested_cols
                                for c in nested_df.select(nc + '.*').columns])
    return flat_df

Before:

root
 |-- x: string (nullable = true)
 |-- y: string (nullable = true)
 |-- foo: struct (nullable = true)
 |    |-- a: float (nullable = true)
 |    |-- b: float (nullable = true)
 |    |-- c: integer (nullable = true)
 |-- bar: struct (nullable = true)
 |    |-- a: float (nullable = true)
 |    |-- b: float (nullable = true)
 |    |-- c: integer (nullable = true)

After:

root
 |-- x: string (nullable = true)
 |-- y: string (nullable = true)
 |-- foo_a: float (nullable = true)
 |-- foo_b: float (nullable = true)
 |-- foo_c: integer (nullable = true)
 |-- bar_a: float (nullable = true)
 |-- bar_b: float (nullable = true)
 |-- bar_c: integer (nullable = true)
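
Note that this flattens a single level of nesting per call. One hedged way to handle deeper schemas (not part of the original answer, and assuming your original DataFrame is named nested_df) is to reapply the function until no struct columns remain:

# Sketch: repeat flatten_df until the dtypes contain no struct columns.
flat_df = nested_df
while any(t.startswith('struct') for _, t in flat_df.dtypes):
    flat_df = flatten_df(flat_df)
flat_df.printSchema()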

19

For Spark 2.4.5,

然而,df.select(df.col("data.*"))会导致异常org.apache.spark.sql.AnalysisException: No such struct field * in

This will work:

df.select($"data.*")

1
This also works in Spark 3.1.0, but it doesn't keep the selected data column or its parent, and it doesn't descend into further nested structs if there are any. - Nevermore
当我选择 df.select("data.") 时,它会给我 nn 行(每行重复 n 次)。我的数据框中有 n-2 个不同的行,所以如果我使用 distinct,它会给我 n-2 个结果。但是我想要 n 个实际存在于我的数据中的结果。如何使用上述 select 命令来实现这一点? - Madman
4
The dollar sign can be omitted :) - meniluca
Tested and working on DBR 9.1, Spark 3.1.2: df.select("data.*") - Aman Sehgal

15

This version of flatten_df flattens the DataFrame at every level, using a stack to avoid recursive calls:

from pyspark.sql.functions import col


def flatten_df(nested_df):
    # Each stack entry pairs the path of parent struct names with the
    # DataFrame projected at that level.
    stack = [((), nested_df)]
    columns = []

    while len(stack) > 0:
        parents, df = stack.pop()

        # Non-struct columns: select by dotted path, alias with underscores.
        flat_cols = [
            col(".".join(parents + (c[0],))).alias("_".join(parents + (c[0],)))
            for c in df.dtypes
            if c[1][:6] != "struct"
        ]

        # Struct columns: project one level down and push onto the stack.
        nested_cols = [
            c[0]
            for c in df.dtypes
            if c[1][:6] == "struct"
        ]

        columns.extend(flat_cols)

        for nested_col in nested_cols:
            projected_df = df.select(nested_col + ".*")
            stack.append((parents + (nested_col,), projected_df))

    return nested_df.select(columns)

Example:

from pyspark.sql.types import StringType, StructField, StructType


schema = StructType([
    StructField("some", StringType()),

    StructField("nested", StructType([
        StructField("nestedchild1", StringType()),
        StructField("nestedchild2", StringType())
    ])),

    StructField("renested", StructType([
        StructField("nested", StructType([
            StructField("nestedchild1", StringType()),
            StructField("nestedchild2", StringType())
        ]))
    ]))
])

data = [
    {
        "some": "value1",
        "nested": {
            "nestedchild1": "value2",
            "nestedchild2": "value3",
        },
        "renested": {
            "nested": {
                "nestedchild1": "value4",
                "nestedchild2": "value5",
            }
        }
    }
]

df = spark.createDataFrame(data, schema)
flat_df = flatten_df(df)
print(flat_df.collect())

Prints:

[Row(some=u'value1', renested_nested_nestedchild1=u'value4', renested_nested_nestedchild2=u'value5', nested_nestedchild1=u'value2', nested_nestedchild2=u'value3')]

1
This doesn't seem to recurse into structs nested inside arrays. - malthe
1
I eventually worked out how to do it and posted a script here: https://dev59.com/Lr_qa4cB1Zd3GeqPOL7B#66482320 - malthe
1
If this code was original to this answer, you should get credit here! https://learn.microsoft.com/en-us/azure/synapse-analytics/how-to-analyze-complex-schema#define-a-function-to-flatten-the-nested-schema - Martin Smith
1
@MartinSmith Yes, I wrote that snippet for this answer. - federicojasson
https://github.com/MicrosoftDocs/azure-docs/issues/113914 - undefined

8

A more compact and efficient implementation:

There's no need to build lists and iterate over them. You can "act" on fields depending on their type (struct or not).

If a column is nested (a struct), you flatten it (.*); otherwise you access it with dot notation (parent.child) and replace the dot with an underscore (parent_child).

from pyspark.sql.functions import col

def flatten_df(nested_df):
    stack = [((), nested_df)]
    columns = []
    while len(stack) > 0:
        parents, df = stack.pop()
        for column_name, column_type in df.dtypes:
            if column_type[:6] == "struct":
                projected_df = df.select(column_name + ".*")
                stack.append((parents + (column_name,), projected_df))
            else:
                columns.append(col(".".join(parents + (column_name,))).alias("_".join(parents + (column_name,))))
    return nested_df.select(columns)

1
Welcome to Stack Overflow. Code-only answers are discouraged here because they don't explain how the problem is solved. Please edit your answer to explain what this code does and why it's more efficient than the other answers, so users with similar issues can learn from it. - FluffyKitten
1
Worked perfectly for me. You'd surely get more upvotes with a bit more explanation. - user2201536
"parents" 的数据类型是什么? @Domenico Di Nicola - soMuchToLearnAndShare
当列名包含“点”时,我遇到了问题。我知道如果列名是显式的,我可以使用反引号进行转义,但在上面的情况下,我无法这样做。 - soMuchToLearnAndShare
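
As a hedged aside on the last comment: one workaround for column names that themselves contain dots is to backtick-quote every path segment before handing the path to col(), so Spark treats the embedded dot as part of the name rather than as a nesting separator:

def quoted_path(parts):
    # ('parent', 'odd.name') -> '`parent`.`odd.name`'
    return ".".join("`{}`".format(p) for p in parts)

# In flatten_df above, the column reference would then be built as
#   col(quoted_path(parents + (column_name,)))
#       .alias("_".join(parents + (column_name,)))
print(quoted_path(("parent", "odd.name")))  # `parent`.`odd.name`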

5
I generalized stecos's solution a bit so the flattening can be done over more than one struct level:
from pyspark.sql.functions import col

def flatten_df(nested_df, layers):
    flat_cols = []
    nested_cols = []
    flat_df = []

    flat_cols.append([c[0] for c in nested_df.dtypes if c[1][:6] != 'struct'])
    nested_cols.append([c[0] for c in nested_df.dtypes if c[1][:6] == 'struct'])

    flat_df.append(nested_df.select(flat_cols[0] +
                               [col(nc+'.'+c).alias(nc+'_'+c)
                                for nc in nested_cols[0]
                                for c in nested_df.select(nc+'.*').columns])
                  )
    for i in range(1, layers):
        print (flat_cols[i-1])
        flat_cols.append([c[0] for c in flat_df[i-1].dtypes if c[1][:6] != 'struct'])
        nested_cols.append([c[0] for c in flat_df[i-1].dtypes if c[1][:6] == 'struct'])

        flat_df.append(flat_df[i-1].select(flat_cols[i] +
                                [col(nc+'.'+c).alias(nc+'_'+c)
                                    for nc in nested_cols[i]
                                    for c in flat_df[i-1].select(nc+'.*').columns])
        )

    return flat_df[-1]

Just call it with:

my_flattened_df = flatten_df(my_df_having_nested_structs, 3)

(the second parameter is the number of layers to flatten; in my case it was 3)
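
If you'd rather not hard-code the number of layers, a hedged helper (not part of the original answer) can derive it from the schema:

from pyspark.sql.types import StructType

def struct_depth(schema):
    # Longest chain of nested StructTypes, e.g. 2 for the question's
    # data.keyNote hierarchy.
    depths = [1 + struct_depth(field.dataType)
              for field in schema.fields
              if isinstance(field.dataType, StructType)]
    return max(depths) if depths else 0

my_flattened_df = flatten_df(my_df_having_nested_structs,
                             struct_depth(my_df_having_nested_structs.schema))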

4

A PySpark solution for flattening a nested df with both struct and array types, at any depth of nesting. It improves on this: https://dev59.com/dlkT5IYBdhLWcg3wV-CY#56533459

from pyspark.sql.types import *
from pyspark.sql import functions as f

def flatten_structs(nested_df):
    stack = [((), nested_df)]
    columns = []

    while len(stack) > 0:
        
        parents, df = stack.pop()
        
        flat_cols = [
            f.col(".".join(parents + (c[0],))).alias("_".join(parents + (c[0],)))
            for c in df.dtypes
            if c[1][:6] != "struct"
        ]

        nested_cols = [
            c[0]
            for c in df.dtypes
            if c[1][:6] == "struct"
        ]
        
        columns.extend(flat_cols)

        for nested_col in nested_cols:
            projected_df = df.select(nested_col + ".*")
            stack.append((parents + (nested_col,), projected_df))
        
    return nested_df.select(columns)

def flatten_array_struct_df(df):
    
    array_cols = [
            c[0]
            for c in df.dtypes
            if c[1][:5] == "array"
        ]
    
    while len(array_cols) > 0:
        
        for array_col in array_cols:
            
            df = df.withColumn(array_col, f.explode(f.col(array_col)))
            
        df = flatten_structs(df)
        
        array_cols = [
            c[0]
            for c in df.dtypes
            if c[1][:5] == "array"
        ]
    return df

flat_df = flatten_array_struct_df(df)
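
One caveat worth adding (not from the original answer): f.explode silently drops rows whose array is null or empty. If those rows must survive, f.explode_outer (available since Spark 2.2) is a drop-in replacement in the loop above:

from pyspark.sql import functions as f

# Hedged illustration: explode() would drop row id=2 entirely, while
# explode_outer() keeps it with a single null entry.
df = spark.createDataFrame([(1, [10, 20]), (2, None)], ["id", "arr"])
df.withColumn("arr", f.explode_outer(f.col("arr"))).show()
# id=1 yields rows (1, 10) and (1, 20); id=2 yields (2, null)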

1
This works really well. I would use it carefully and at my own risk, since flattening an "array of structs" can produce duplicate rows, as others have said. - soMuchToLearnAndShare
Can you explain what each for loop is doing? In particular, why did you define those specific numbers (:5, :6) in the functions? - user9532692
"array" -> 数组
"struct" -> 结构体
- Narahari B M
啊,字符串“数组”的长度为5,而“结构体”的长度为6。谢谢! - user9532692

2

Based on https://dev59.com/dlkT5IYBdhLWcg3wV-CY#49532496, here is a solution for struct and multi-level nested array fields.

from pyspark.sql.functions import col, explode


def type_cols(df_dtypes, filter_type):
    cols = []
    for col_name, col_type in df_dtypes:
        if col_type.startswith(filter_type):
            cols.append(col_name)
    return cols


def flatten_df(nested_df, sep='_'):
    nested_cols = type_cols(nested_df.dtypes, "struct")
    flatten_cols = [fc for fc, _ in nested_df.dtypes if fc not in nested_cols]
    for nc in nested_cols:
        for cc in nested_df.select(f"{nc}.*").columns:
            if sep is None:
                flatten_cols.append(col(f"{nc}.{cc}").alias(f"{cc}"))
            else:
                flatten_cols.append(col(f"{nc}.{cc}").alias(f"{nc}{sep}{cc}"))
    return nested_df.select(flatten_cols)


def explode_df(nested_df):
    nested_cols = type_cols(nested_df.dtypes, "array")
    exploded_df = nested_df
    for nc in nested_cols:
        exploded_df = exploded_df.withColumn(nc, explode(col(nc)))
    return exploded_df


def flatten_explode_df(nested_df):
    df = nested_df
    struct_cols = type_cols(nested_df.dtypes, "struct")
    array_cols = type_cols(nested_df.dtypes, "array")
    if struct_cols:
        df = flatten_df(df)
        return flatten_explode_df(df)
    if array_cols:
        df = explode_df(df)
        return flatten_explode_df(df)
    return df


df = flatten_explode_df(nested_df)

2

You can use this approach if you only need to transform the struct types. I don't recommend converting arrays, because it can lead to duplicate records.

from pyspark.sql.functions import col
from pyspark.sql.types import StructType


def flatten_schema(schema, prefix=""):
    return_schema = []
    for field in schema.fields:
        if isinstance(field.dataType, StructType):
            if prefix:
                return_schema = return_schema + flatten_schema(field.dataType, "{}.{}".format(prefix, field.name))
            else:
                return_schema = return_schema + flatten_schema(field.dataType, field.name)
        else:
            if prefix:
                field_path = "{}.{}".format(prefix, field.name)
                return_schema.append(col(field_path).alias(field_path.replace(".", "_")))
            else:
                return_schema.append(field.name)
    return return_schema

You can use it as:
new_schema = flatten_schema(df.schema)
df1 = df.select(new_schema)
df1.show()

1
This one is in Scala, for Spark.
val totalMainArrayBuffer = collection.mutable.ArrayBuffer[String]()

def flatten_df_Struct(dfTemp: org.apache.spark.sql.DataFrame, dfTotalOuter: org.apache.spark.sql.DataFrame): org.apache.spark.sql.DataFrame = {
  //dfTemp.printSchema
  val totalStructCols = dfTemp.dtypes.map(x => x.toString.substring(1, x.toString.size - 1)).filter(_.split(",", 2)(1).contains("Struct")) // in case the column names come with the word Struct embedded in them
  val mainArrayBuffer = collection.mutable.ArrayBuffer[String]()
  for (totalStructCol <- totalStructCols) {
    val tempArrayBuffer = collection.mutable.ArrayBuffer[String]()
    tempArrayBuffer += s"${totalStructCol.split(",")(0)}.*"
    //tempArrayBuffer.toSeq.toDF.show(false)
    val columnsInside = dfTemp.selectExpr(tempArrayBuffer: _*).columns
    for (column <- columnsInside)
      mainArrayBuffer += s"${totalStructCol.split(",")(0)}.${column} as ${totalStructCol.split(",")(0)}_${column}"
    //mainArrayBuffer.toSeq.toDF.show(false)
  }
  //dfTemp.selectExpr(mainArrayBuffer:_*).printSchema
  val nonStructCols = dfTemp.selectExpr(mainArrayBuffer: _*).dtypes.map(x => x.toString.substring(1, x.toString.size - 1)).filter(!_.split(",", 2)(1).contains("Struct")) // in case the column names come with the word Struct embedded in them
  for (nonStructCol <- nonStructCols)
    totalMainArrayBuffer += s"${nonStructCol.split(",")(0).replace("_", ".")} as ${nonStructCol.split(",")(0)}" // replace _ by . in the original select clause if it's an already-nested column
  dfTemp.selectExpr(mainArrayBuffer: _*).dtypes.map(x => x.toString.substring(1, x.toString.size - 1)).filter(_.split(",", 2)(1).contains("Struct")).size match {
    case value if value == 0 => dfTotalOuter.selectExpr(totalMainArrayBuffer: _*)
    case _ => flatten_df_Struct(dfTemp.selectExpr(mainArrayBuffer: _*), dfTotalOuter)
  }
}

def flatten_df(dfTemp: org.apache.spark.sql.DataFrame): org.apache.spark.sql.DataFrame = {
  var totalArrayBuffer = collection.mutable.ArrayBuffer[String]()
  val totalNonStructCols = dfTemp.dtypes.map(x => x.toString.substring(1, x.toString.size - 1)).filter(!_.split(",", 2)(1).contains("Struct"))
  for (totalNonStructCol <- totalNonStructCols)
    totalArrayBuffer += s"${totalNonStructCol.split(",")(0)}"
  totalMainArrayBuffer.clear
  flatten_df_Struct(dfTemp, dfTemp) // the flattened schema is now in totalMainArrayBuffer
  totalArrayBuffer = totalArrayBuffer ++ totalMainArrayBuffer
  dfTemp.selectExpr(totalArrayBuffer: _*)
}


flatten_df(dfTotal.withColumn("tempStruct",lit(5))).printSchema




Input file:
{"num1":1,"num2":2,"bool1":true,"bool2":false,"double1":4.5,"double2":5.6,"str1":"a","str2":"b","arr1":[3,4,5],"map1":{"cool":1,"okay":2,"normal":3},"carInfo":{"Engine":{"Make":"sa","Power":{"IC":"900","battery":"165"},"Redline":"11500"} ,"Tyres":{"Make":"Pirelli","Compound":"c1","Life":"120"}}}
{"num1":3,"num2":4,"bool1":false,"bool2":false,"double1":4.2,"double2":5.5,"str1":"u","str2":"n","arr1":[6,7,9],"map1":{"fast":1,"medium":2,"agressive":3},"carInfo":{"Engine":{"Make":"na","Power":{"IC":"800","battery":"150"},"Redline":"10000"} ,"Tyres":{"Make":"Pirelli","Compound":"c2","Life":"100"}}}
{"num1":8,"num2":4,"bool1":true,"bool2":true,"double1":5.7,"double2":7.5,"str1":"t","str2":"k","arr1":[11,12,23],"map1":{"preserve":1,"medium":2,"fast":3},"carInfo":{"Engine":{"Make":"ta","Power":{"IC":"950","battery":"170"},"Redline":"12500"} ,"Tyres":{"Make":"Pirelli","Compound":"c3","Life":"80"}}}
{"num1":7,"num2":9,"bool1":false,"bool2":true,"double1":33.2,"double2":7.5,"str1":"b","str2":"u","arr1":[12,14,5],"map1":{"normal":1,"preserve":2,"agressive":3},"carInfo":{"Engine":{"Make":"pa","Power":{"IC":"920","battery":"160"},"Redline":"11800"} ,"Tyres":{"Make":"Pirelli","Compound":"c4","Life":"70"}}}

Before:

root
 |-- arr1: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- bool1: boolean (nullable = true)
 |-- bool2: boolean (nullable = true)
 |-- carInfo: struct (nullable = true)
 |    |-- Engine: struct (nullable = true)
 |    |    |-- Make: string (nullable = true)
 |    |    |-- Power: struct (nullable = true)
 |    |    |    |-- IC: string (nullable = true)
 |    |    |    |-- battery: string (nullable = true)
 |    |    |-- Redline: string (nullable = true)
 |    |-- Tyres: struct (nullable = true)
 |    |    |-- Compound: string (nullable = true)
 |    |    |-- Life: string (nullable = true)
 |    |    |-- Make: string (nullable = true)
 |-- double1: double (nullable = true)
 |-- double2: double (nullable = true)
 |-- map1: struct (nullable = true)
 |    |-- agressive: long (nullable = true)
 |    |-- cool: long (nullable = true)
 |    |-- fast: long (nullable = true)
 |    |-- medium: long (nullable = true)
 |    |-- normal: long (nullable = true)
 |    |-- okay: long (nullable = true)
 |    |-- preserve: long (nullable = true)
 |-- num1: long (nullable = true)
 |-- num2: long (nullable = true)
 |-- str1: string (nullable = true)
 |-- str2: string (nullable = true)

After:

root
 |-- arr1: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- bool1: boolean (nullable = true)
 |-- bool2: boolean (nullable = true)
 |-- double1: double (nullable = true)
 |-- double2: double (nullable = true)
 |-- num1: long (nullable = true)
 |-- num2: long (nullable = true)
 |-- str1: string (nullable = true)
 |-- str2: string (nullable = true)
 |-- map1_agressive: long (nullable = true)
 |-- map1_cool: long (nullable = true)
 |-- map1_fast: long (nullable = true)
 |-- map1_medium: long (nullable = true)
 |-- map1_normal: long (nullable = true)
 |-- map1_okay: long (nullable = true)
 |-- map1_preserve: long (nullable = true)
 |-- carInfo_Engine_Make: string (nullable = true)
 |-- carInfo_Engine_Redline: string (nullable = true)
 |-- carInfo_Tyres_Compound: string (nullable = true)
 |-- carInfo_Tyres_Life: string (nullable = true)
 |-- carInfo_Tyres_Make: string (nullable = true)
 |-- carInfo_Engine_Power_IC: string (nullable = true)
 |-- carInfo_Engine_Power_battery: string (nullable = true)


Tried it with 2 levels of nesting and it worked.
