如何在Java / Kotlin中创建一个返回复杂类型的Spark UDF?

7

我正在尝试编写一个返回复杂类型的UDF:

private val toPrice = UDF1<String, Map<String, String>> { s ->
    val elements = s.split(" ")
    mapOf("value" to elements[0], "currency" to elements[1])
}


val type = DataTypes.createStructType(listOf(
        DataTypes.createStructField("value", DataTypes.StringType, false),
        DataTypes.createStructField("currency", DataTypes.StringType, false)))
df.sqlContext().udf().register("toPrice", toPrice, type)

但是每当我使用这个:

df = df.withColumn("price", callUDF("toPrice", col("price")))

我遇到了一个神秘的错误:

Caused by: org.apache.spark.SparkException: Failed to execute user defined function($anonfun$28: (string) => struct<value:string,currency:string>)
    at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source)
    at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
    at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$10$$anon$1.hasNext(WholeStageCodegenExec.scala:614)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:253)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:247)
    at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:830)
    at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:830)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
    at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
    at org.apache.spark.scheduler.Task.run(Task.scala:109)
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
    at java.lang.Thread.run(Thread.java:748)
Caused by: scala.MatchError: {value=138.0, currency=USD} (of class java.util.LinkedHashMap)
    at org.apache.spark.sql.catalyst.CatalystTypeConverters$StructConverter.toCatalystImpl(CatalystTypeConverters.scala:236)
    at org.apache.spark.sql.catalyst.CatalystTypeConverters$StructConverter.toCatalystImpl(CatalystTypeConverters.scala:231)
    at org.apache.spark.sql.catalyst.CatalystTypeConverters$CatalystTypeConverter.toCatalyst(CatalystTypeConverters.scala:103)
    at org.apache.spark.sql.catalyst.CatalystTypeConverters$$anonfun$createToCatalystConverter$2.apply(CatalystTypeConverters.scala:379)
    ... 19 more

我尝试使用自定义数据类型:

class Price(val value: Double, val currency: String) : Serializable

使用返回该类型的UDF:
private val toPrice = UDF1<String, Price> { s ->
    val elements = s.split(" ")
    Price(elements[0].toDouble(), elements[1])
}

但是,我得到了另一个MatchError,它抱怨Price类型。

如何正确编写可以返回复杂类型的UDF?

2个回答

12

TL;DR 该函数应返回一个org.apache.spark.sql.Row类的对象。

Spark提供了两种主要的UDF定义变体。

  1. 使用Scala反射的udf变体:

    • def udf[RT](f: () ⇒ RT)(implicit arg0: TypeTag[RT]): UserDefinedFunction
    • def udf[RT, A1](f: (A1) ⇒ RT)(implicit arg0: TypeTag[RT], arg1: TypeTag[A1]): UserDefinedFunction
    • ...
    • def udf[RT, A1, A2, ..., A10](f: (A1, A2, ..., A10) ⇒ RT)(implicit arg0: TypeTag[RT], arg1: TypeTag[A1], arg2: TypeTag[A2], ..., arg10: TypeTag[A10])

    这些变量是在没有模式的情况下使用原子或代数数据类型定义的Scala闭包作为用户定义函数(UDF)。根据Scala闭包的签名自动推断数据类型。

    这些变量用于原子或代数数据类型,不需要模式。例如,所讨论的函数将在Scala中定义:

    case class Price(value: Double, currency: String) 
    
    val df = Seq("1 USD").toDF("price")
    
    val toPrice = udf((s: String) => scala.util.Try { 
      s split(" ") match {
        case Array(price, currency) => Price(price.toDouble, currency)
      }
    }.toOption)
    
    df.select(toPrice($"price")).show
    // +----------+
    // |UDF(price)|
    // +----------+
    // |[1.0, USD]|
    // +----------+
    

    在这种变体中,返回类型会自动编码。

    由于它依赖反射,所以这个变体主要是为Scala用户设计的。

  2. udf变体提供模式定义(您在此处使用的一种)。此变体的返回类型应与Dataset [Row]的返回类型相同:

    • 正如其他答案指出的那样,您只能使用SQL类型映射表中列出的类型(原子类型或包装或未包装的类型、java.sql.Timestamp/java.sql.Date,以及高级集合)。

    • 复杂结构(structs/StructTypes)使用org.apache.spark.sql.Row表示。不允许与代数数据类型或等效物混合使用。例如(Scala代码)

      struct<_1:int,_2:struct<_1:string,_2:struct<_1:double,_2:int>>>
      

      应表达为

      Row(1, Row("foo", Row(-1.0, 42))))
      
      (1, ("foo", (-1.0, 42))))
      

      或者任何混合变体,比如

      Row(1, Row("foo", (-1.0, 42))))
      

    这个变量主要是为了确保与Java的互操作性。

    在这种情况下(相当于问题中的那个),定义应该类似于以下定义:

    import org.apache.spark.sql.types._
    import org.apache.spark.sql.functions.udf
    import org.apache.spark.sql.Row
    
    
    val schema = StructType(Seq(
      StructField("value", DoubleType, false),
      StructField("currency", StringType, false)
    ))
    
    val toPrice = udf((s: String) => scala.util.Try { 
      s split(" ") match {
        case Array(price, currency) => Row(price.toDouble, currency)
      }
    }.getOrElse(null), schema)
    
    df.select(toPrice($"price")).show
    // +----------+
    // |UDF(price)|
    // +----------+
    // |[1.0, USD]|
    // |      null|
    // +----------+
    

    除了所有异常处理细节(一般来说,UDFs 应该检查 null 输入,并按照约定优雅地处理畸形数据),Java 的等价物看起来应该是这样的:

    UserDefinedFunction price = udf((String s) -> {
        String[] split = s.split(" ");
        return RowFactory.create(Double.parseDouble(split[0]), split[1]);
    }, DataTypes.createStructType(new StructField[]{
        DataTypes.createStructField("value", DataTypes.DoubleType, true),
        DataTypes.createStructField("currency", DataTypes.StringType, true)
    }));
    

上下文:

为了让你更好地理解,这种差异也反映在API的其他部分中。例如,您可以根据模式和Rows序列创建DataFrame

def createDataFrame(rows: List[Row], schema: StructType): DataFrame 

或者使用反射,结合一系列的Products

def createDataFrame[A <: Product](data: Seq[A])(implicit arg0: TypeTag[A]): DataFrame 

但不支持混合变体。

换言之,您应提供可使用RowEncoder编码的输入。

当然,您通常不会像这样使用udf来完成此任务:

import org.apache.spark.sql.functions._

df.withColumn("price", struct(
  split($"price", " ")(0).cast("double").alias("price"),
  split($"price", " ")(1).alias("currency")
))

相关文章:


2
很简单。前往数据类型参考,找到对应的类型。
在Spark 2.3中:
- 如果您将返回类型声明为StructType,则函数必须返回org.apache.spark.sql.Row。 - 如果您返回Map<String, String>,函数返回类型应为MapType - 显然不是您想要的内容。

我并没有转换整个“行”。我只是像这样转换列中的一个值:5.0 USD 转换成一个结构体:Price(5.0, "USD")。请再次阅读我的问题:我也尝试了使用正确的 StructType 声明来使用 Price,但仍然不起作用。 - Hexworks

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接