Spark SQL 引用 UDT 的属性

6
我正在尝试实现自定义UDT,并能够从Spark SQL中引用它(正如Spark SQL白皮书第4.4.2节所述)。

真正的例子是拥有一个由Cap'n Proto或类似的离线数据结构支持的自定义UDT。

对于这篇文章,我举了一个牵强附会的例子。我知道我可以使用Scala case类而不需要做任何工作,但那不是我的目标。

例如,我有一个名为Person的对象,包含几个属性,我想能够执行SELECT person.first_name FROM person。我遇到了错误Can't extract value from person#1,不确定原因所在。

这里有完整的源代码(也可在https://github.com/andygrove/spark-sql-udt上获得)

package com.theotherandygrove

import org.apache.spark.sql.types._
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.{SparkConf, SparkContext}

object Example {

  def main(arg: Array[String]): Unit = {

    val conf = new SparkConf()
      .setAppName("Example")
      .setMaster("local[*]")

    val sc = new SparkContext(conf)

    val sqlContext = new SQLContext(sc)

    val schema = StructType(List(
      StructField("person_id", DataTypes.IntegerType, true),
      StructField("person", new MockPersonUDT, true)))

    // load initial RDD
    val rdd = sc.parallelize(List(
      MockPersonImpl(1),
      MockPersonImpl(2)
    ))

    // convert to RDD[Row]
    val rowRdd = rdd.map(person => Row(person.getAge, person))

    // convert to DataFrame (RDD + Schema)
    val dataFrame = sqlContext.createDataFrame(rowRdd, schema)

    // register as a table
    dataFrame.registerTempTable("person")

    // selecting the whole object works fine
    val results = sqlContext.sql("SELECT person.first_name FROM person WHERE person.age < 100")

    val people = results.collect

    people.map(row => {
      println(row)
    })

  }

}

trait MockPerson {
  def getFirstName: String
  def getLastName: String
  def getAge: Integer
  def getState: String
}

class MockPersonUDT extends UserDefinedType[MockPerson] {

  override def sqlType: DataType = StructType(List(
    StructField("firstName", StringType, nullable=false),
    StructField("lastName", StringType, nullable=false),
    StructField("age", IntegerType, nullable=false),
    StructField("state", StringType, nullable=false)
  ))

  override def userClass: Class[MockPerson] = classOf[MockPerson]

  override def serialize(obj: Any): Any = obj.asInstanceOf[MockPersonImpl].getAge

  override def deserialize(datum: Any): MockPerson = MockPersonImpl(datum.asInstanceOf[Integer])
}

@SQLUserDefinedType(udt = classOf[MockPersonUDT])
@SerialVersionUID(123L)
case class MockPersonImpl(n: Integer) extends MockPerson with Serializable {
  def getFirstName = "First" + n
  def getLastName = "Last" + n
  def getAge = n
  def getState = "AK"
}

如果我只是执行 SELECT person FROM person,那么查询就可以正常工作。尽管在模式中定义了这些属性,但我无法在SQL中引用它们。
1个回答

4

您会收到此错误,是因为sqlType定义的模式从未公开,并且不打算直接访问它。它只是提供了一种使用本地Spark SQL类型表示复杂数据类型的方法。

您可以使用UDF访问单个属性,但首先让我们展示内部结构确实未公开:

dataFrame.printSchema
// root
//  |-- person_id: integer (nullable = true)
//  |-- person: mockperso (nullable = true)

为了创建UDF,我们需要函数,该函数以所表示的给定UDT类型的对象作为参数。
import org.apache.spark.sql.functions.udf

val getFirstName = (person: MockPerson) => person.getFirstName
val getLastName = (person: MockPerson) => person.getLastName
val getAge = (person: MockPerson) => person.getAge

可以使用udf函数进行包装:

val getFirstNameUDF = udf(getFirstName)
val getLastNameUDF = udf(getLastName)
val getAgeUDF = udf(getAge)

dataFrame.select(
  getFirstNameUDF($"person").alias("first_name"),
  getLastNameUDF($"person").alias("last_name"),
  getAgeUDF($"person").alias("age")
).show()

// +----------+---------+---+
// |first_name|last_name|age|
// +----------+---------+---+
// |    First1|    Last1|  1|
// |    First2|    Last2|  2|
// +----------+---------+---+

要使用这些与原始SQL,您需要通过SQLContext注册函数:

sqlContext.udf.register("first_name", getFirstName)
sqlContext.udf.register("last_name", getLastName)
sqlContext.udf.register("age", getAge)

sqlContext.sql("""
  SELECT first_name(person) AS first_name, last_name(person) AS last_name
  FROM person
  WHERE age(person) < 100""").show

// +----------+---------+
// |first_name|last_name|
// +----------+---------+
// |    First1|    Last1|
// |    First2|    Last2|
// +----------+---------+

不幸的是,这需要付出代价。首先,每个操作都需要反序列化。它还大大限制了查询优化的方式。特别是在这些字段中任何一个上执行 join 操作都需要一个笛卡尔积。

实际上,如果要编码包含可以使用内置类型表示的属性的复杂结构,最好使用 StructType

case class Person(first_name: String, last_name: String, age: Int)

val df = sc.parallelize(
  (1 to 2).map(i => (i, Person(s"First$i", s"Last$i", i)))).toDF("id", "person")

df.printSchema

// root
//  |-- id: integer (nullable = false)
//  |-- person: struct (nullable = true)
//  |    |-- first_name: string (nullable = true)
//  |    |-- last_name: string (nullable = true)
//  |    |-- age: integer (nullable = false)

df
  .where($"person.age" < 100)
  .select($"person.first_name", $"person.last_name")
  .show

// +----------+---------+
// |first_name|last_name|
// +----------+---------+
// |    First1|    Last1|
// |    First2|    Last2|
// +----------+---------+

同时,应当保留UDTs(用户自定义类型)用于实际类型扩展,例如内置的VectorUDT或那些能够从特定表示中受益的事物,如枚举。


Stack Overflow不允许我在这里放太多评论,所以我要编辑原帖。我正在尝试创建一个UDT,如Spark SQL白皮书第4.4.2节所述。 - andygrove
1
我的答案仍然是一样的。如果你想要操作属性,你必须使用UDFs。 - zero323
1
@zero323 很好的回答!我仍然不太使用UDTs或UDFs,但这个解决方案听起来非常合理。 ;) - eliasah
据我所知,Spark SQL 论文中没有任何暗示可以直接访问内部表示的内容。相反,它将 UDF 描述为访问 UDT 的方法。@andygrove - zero323
我们可以定义自定义UDT,但不能自由和直接地使用它,这真的很糟糕。 - Tom

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接