使用复杂类型查询Spark SQL DataFrame

79

我如何查询包含复杂类型(如maps/arrays)的RDD?例如,当我编写以下测试代码时:

case class Test(name: String, map: Map[String, String])
val map = Map("hello" -> "world", "hey" -> "there")
val map2 = Map("hello" -> "people", "hey" -> "you")
val rdd = sc.parallelize(Array(Test("first", map), Test("second", map2)))

我认为语法应该像这样:

sqlContext.sql("SELECT * FROM rdd WHERE map.hello = world")
或者
sqlContext.sql("SELECT * FROM rdd WHERE map[hello] = world")

但我得到

无法访问类型为MapType(StringType,StringType,true)的嵌套字段

org.apache.spark.sql.catalyst.errors.package$TreeNodeException:未解决的属性

分别。


我点赞了被接受的答案,它是查询复杂字段的所有方法的绝佳来源。对于那些正在进行此操作的人来说,以下是一个快速参考:map[hello]不起作用的原因是键是字符串字段,因此必须加引号:map['hello'] - Tristan Reid
3个回答

205

这取决于列的类型。让我们从一些虚拟数据开始:

import org.apache.spark.sql.functions.{udf, lit}
import scala.util.Try

case class SubRecord(x: Int)
case class ArrayElement(foo: String, bar: Int, vals: Array[Double])
case class Record(
  an_array: Array[Int], a_map: Map[String, String], 
  a_struct: SubRecord, an_array_of_structs: Array[ArrayElement])


val df = sc.parallelize(Seq(
  Record(Array(1, 2, 3), Map("foo" -> "bar"), SubRecord(1),
         Array(
           ArrayElement("foo", 1, Array(1.0, 2.0, 2.0)),
           ArrayElement("bar", 2, Array(3.0, 4.0, 5.0)))),
  Record(Array(4, 5, 6), Map("foz" -> "baz"), SubRecord(2),
         Array(ArrayElement("foz", 3, Array(5.0, 6.0)), 
               ArrayElement("baz", 4, Array(7.0, 8.0))))
)).toDF
df.registerTempTable("df")
df.printSchema

// root
// |-- an_array: array (nullable = true)
// |    |-- element: integer (containsNull = false)
// |-- a_map: map (nullable = true)
// |    |-- key: string
// |    |-- value: string (valueContainsNull = true)
// |-- a_struct: struct (nullable = true)
// |    |-- x: integer (nullable = false)
// |-- an_array_of_structs: array (nullable = true)
// |    |-- element: struct (containsNull = true)
// |    |    |-- foo: string (nullable = true)
// |    |    |-- bar: integer (nullable = false)
// |    |    |-- vals: array (nullable = true)
// |    |    |    |-- element: double (containsNull = false)
  • 数组(ArrayType)列:

    • Column.getItem 方法

      df.select($"an_array".getItem(1)).show
      
      // +-----------+
      // |an_array[1]|
      // +-----------+
      // |          2|
      // |          5|
      // +-----------+
      
    • Hive括号语法:

      sqlContext.sql("SELECT an_array[1] FROM df").show
      
      // +---+
      // |_c0|
      // +---+
      // |  2|
      // |  5|
      // +---+
      
      一个UDF。
      val get_ith = udf((xs: Seq[Int], i: Int) => Try(xs(i)).toOption)
      
      df.select(get_ith($"an_array", lit(1))).show
      
      // +---------------+
      // |UDF(an_array,1)|
      // +---------------+
      // |              2|
      // |              5|
      // +---------------+
      
      此外,除了上述列出的方法之外,Spark还支持对复合类型进行操作的日益增长的内置函数列表。值得注意的示例包括高阶函数,如transform(SQL 2.4+、Scala 3.0+、PySpark/SparkR 3.1+)。
      df.selectExpr("transform(an_array, x -> x + 1) an_array_inc").show
      // +------------+
      // |an_array_inc|
      // +------------+
      // |   [2, 3, 4]|
      // |   [5, 6, 7]|
      // +------------+
      
      import org.apache.spark.sql.functions.transform
      
      df.select(transform($"an_array", x => x + 1) as "an_array_inc").show
      // +------------+
      // |an_array_inc|
      // +------------+
      // |   [2, 3, 4]|
      // |   [5, 6, 7]|
      // +------------+
      
    • 筛选器(SQL 2.4+,Scala 3.0+,Python / SparkR 3.1+)

    • df.selectExpr("filter(an_array, x -> x % 2 == 0) an_array_even").show
      // +-------------+
      // |an_array_even|
      // +-------------+
      // |          [2]|
      // |       [4, 6]|
      // +-------------+
      
      import org.apache.spark.sql.functions.filter
      
      df.select(filter($"an_array", x => x % 2 === 0) as "an_array_even").show
      // +-------------+
      // |an_array_even|
      // +-------------+
      // |          [2]|
      // |       [4, 6]|
      // +-------------+
      
    • aggregate(SQL 2.4+,Scala 3.0+,PySpark / SparkR 3.1+):

    • df.selectExpr("aggregate(an_array, 0, (acc, x) -> acc + x, acc -> acc) an_array_sum").show
      // +------------+
      // |an_array_sum|
      // +------------+
      // |           6|
      // |          15|
      // +------------+
      
      import org.apache.spark.sql.functions.aggregate
      
      df.select(aggregate($"an_array", lit(0), (x, y) => x + y) as "an_array_sum").show
      // +------------+                                                                  
      // |an_array_sum|
      // +------------+
      // |           6|
      // |          15|
      // +------------+
      
    • 数组处理函数(array_*),例如 array_distinct (2.4+):

    • import org.apache.spark.sql.functions.array_distinct
      
      df.select(array_distinct($"an_array_of_structs.vals"(0))).show
      // +-------------------------------------------+
      // |array_distinct(an_array_of_structs.vals[0])|
      // +-------------------------------------------+
      // |                                 [1.0, 2.0]|
      // |                                 [5.0, 6.0]|
      // +-------------------------------------------+
      
    • array_max (array_min, 2.4+):

    • import org.apache.spark.sql.functions.array_max
      
      df.select(array_max($"an_array")).show
      // +-------------------+
      // |array_max(an_array)|
      // +-------------------+
      // |                  3|
      // |                  6|
      // +-------------------+
      
    • flatten (2.4+)

      import org.apache.spark.sql.functions.flatten
      
      df.select(flatten($"an_array_of_structs.vals")).show
      // +---------------------------------+
      // |flatten(an_array_of_structs.vals)|
      // +---------------------------------+
      // |             [1.0, 2.0, 2.0, 3...|
      // |             [5.0, 6.0, 7.0, 8.0]|
      // +---------------------------------+
      
    • arrays_zip (2.4+):

      import org.apache.spark.sql.functions.arrays_zip
      
      df.select(arrays_zip($"an_array_of_structs.vals"(0), $"an_array_of_structs.vals"(1))).show(false)
      // +--------------------------------------------------------------------+
      // |arrays_zip(an_array_of_structs.vals[0], an_array_of_structs.vals[1])|
      // +--------------------------------------------------------------------+
      // |[[1.0, 3.0], [2.0, 4.0], [2.0, 5.0]]                                |
      // |[[5.0, 7.0], [6.0, 8.0]]                                            |
      // +--------------------------------------------------------------------+
      
    • array_union(2.4+):

      import org.apache.spark.sql.functions.array_union
      
      df.select(array_union($"an_array_of_structs.vals"(0), $"an_array_of_structs.vals"(1))).show
      // +---------------------------------------------------------------------+
      // |array_union(an_array_of_structs.vals[0], an_array_of_structs.vals[1])|
      // +---------------------------------------------------------------------+
      // |                                                 [1.0, 2.0, 3.0, 4...|
      // |                                                 [5.0, 6.0, 7.0, 8.0]|
      // +---------------------------------------------------------------------+
      
    • slice(2.4+):

      import org.apache.spark.sql.functions.slice
      
      df.select(slice($"an_array", 2, 2)).show
      // +---------------------+
      // |slice(an_array, 2, 2)|
      // +---------------------+
      // |               [2, 3]|
      // |               [5, 6]|
      // +---------------------+
      
    • 映射(MapType)列

      • 使用Column.getField方法:

  • df.select($"a_map".getField("foo")).show
    
    // +----------+
    // |a_map[foo]|
    // +----------+
    // |       bar|
    // |      null|
    // +----------+
    
    使用Hive括号语法:
    sqlContext.sql("SELECT a_map['foz'] FROM df").show
    
    // +----+
    // | _c0|
    // +----+
    // |null|
    // | baz|
    // +----+
    
    使用点语法和完整路径:
    df.select($"a_map.foo").show
    
    // +----+
    // | foo|
    // +----+
    // | bar|
    // |null|
    // +----+
    
    使用用户自定义函数(UDF)
    val get_field = udf((kvs: Map[String, String], k: String) => kvs.get(k))
    
    df.select(get_field($"a_map", lit("foo"))).show
    
    // +--------------+
    // |UDF(a_map,foo)|
    // +--------------+
    // |           bar|
    // |          null|
    // +--------------+
    
    不断增加的map_*函数,如map_keys(2.3+)
    import org.apache.spark.sql.functions.map_keys
    
    df.select(map_keys($"a_map")).show
    // +---------------+
    // |map_keys(a_map)|
    // +---------------+
    // |          [foo]|
    // |          [foz]|
    // +---------------+
    
  • 或者 map_values (2.3+)

    import org.apache.spark.sql.functions.map_values
    
    df.select(map_values($"a_map")).show
    // +-----------------+
    // |map_values(a_map)|
    // +-----------------+
    // |            [bar]|
    // |            [baz]|
    // +-----------------+
    
    请查看SPARK-23899获取详细列表。
  • 使用点语法的完整路径来操作结构体 (StructType) 列:

    • 使用 DataFrame API

df.select($"a_struct.x").show

// +---+
// |  x|
// +---+
// |  1|
// |  2|
// +---+
  • 使用原始的SQL语句

  • sqlContext.sql("SELECT a_struct.x FROM df").show
    
    // +---+
    // |  x|
    // +---+
    // |  1|
    // |  2|
    // +---+
    
  • 可以使用点语法、名称和标准的Column方法访问structs数组内的字段:

  • df.select($"an_array_of_structs.foo").show
    
    // +----------+
    // |       foo|
    // +----------+
    // |[foo, bar]|
    // |[foz, baz]|
    // +----------+
    
    sqlContext.sql("SELECT an_array_of_structs[0].foo FROM df").show
    
    // +---+
    // |_c0|
    // +---+
    // |foo|
    // |foz|
    // +---+
    
    df.select($"an_array_of_structs.vals".getItem(1).getItem(1)).show
    
    // +------------------------------+
    // |an_array_of_structs.vals[1][1]|
    // +------------------------------+
    // |                           4.0|
    // |                           8.0|
    // +------------------------------+
    
  • 用户定义类型(UDT)字段可以通过UDF进行访问。有关详细信息,请参见Spark SQL referencing attributes of UDT

  • 注意事项:

    • 根据Spark版本的不同,其中一些方法可能仅在HiveContext中可用。UDF应该独立于版本使用标准的SQLContextHiveContext来工作。
    • 一般来说,嵌套值是二等公民。并非所有典型的操作都支持嵌套字段。根据上下文,将模式展平和/或展开集合可能更好

    df.select(explode($"an_array_of_structs")).show
    
    // +--------------------+
    // |                 col|
    // +--------------------+
    // |[foo,1,WrappedArr...|
    // |[bar,2,WrappedArr...|
    // |[foz,3,WrappedArr...|
    // |[baz,4,WrappedArr...|
    // +--------------------+
    
  • 使用通配符(*)可以与点语法结合使用,选择(可能多个)字段而不需要明确指定名称:

  • df.select($"a_struct.*").show
    // +---+
    // |  x|
    // +---+
    // |  1|
    // |  2|
    // +---+
    
  • 使用get_json_objectfrom_json函数可以查询JSON列。有关详细信息,请参见如何使用Spark DataFrames查询JSON数据列?


  • 能否获取结构体数组中的所有元素?是否可以像这样做... sqlContext.sql("SELECT an_array_of_structs[0].foo FROM df").show - user1384205
    如何使用代码而非Spark SQL执行与SELECT an_array_of_structs[0].foo FROM df相同的操作?并且是否支持在一个结构体数组列(an_array_of_structs)上执行UDF?例如,使用代码执行SELECT max(an_array_of_structs.bar) FROM df - DeepNightTwo
    哇,太棒了,开放式的回答。非常感谢你。 - Pasha
    1
    哇,十分惊人的答案! - SkyWalker
    当我尝试导入org.apache.spark.sql.functions.transform时,出现了错误。所有其他的导入似乎都可以工作,你有什么想法为什么会发生这种情况? - Benji Kok

    2
    一旦您将其转换为DF,就可以简单地获取数据,如下所示:
      val rddRow= rdd.map(kv=>{
        val k = kv._1
        val v = kv._2
        Row(k, v)
      })
    
    val myFld1 =  StructField("name", org.apache.spark.sql.types.StringType, true)
    val myFld2 =  StructField("map", org.apache.spark.sql.types.MapType(StringType, StringType), true)
    val arr = Array( myFld1, myFld2)
    val schema = StructType( arr )
    val rowrddDF = sqc.createDataFrame(rddRow, schema)
    rowrddDF.registerTempTable("rowtbl")  
    val rowrddDFFinal = rowrddDF.select(rowrddDF("map.one"))
    or
    val rowrddDFFinal = rowrddDF.select("map.one")
    

    当我尝试这个时,我得到了 error: value _1 is not a member of org.apache.spark.sql.Row 的错误。 - Paul

    -1

    这是我所做的,它起作用了

    case class Test(name: String, m: Map[String, String])
    val map = Map("hello" -> "world", "hey" -> "there")
    val map2 = Map("hello" -> "people", "hey" -> "you")
    val rdd = sc.parallelize(Array(Test("first", map), Test("second", map2)))
    val rdddf = rdd.toDF
    rdddf.registerTempTable("mytable")
    sqlContext.sql("select m.hello from mytable").show
    

    结果

    +------+
    | hello|
    +------+
    | world|
    |people|
    +------+
    

    网页内容由stack overflow 提供, 点击上面的
    可以查看英文原文,
    原文链接