Scala Spark UDF 类型转换异常:WrappedArray$ofRef 无法转换为 [Lscala.Tuple2。

3

所以我执行必要的导入等操作。

import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types._
import spark.implicits._

然后定义一些经纬度点。

val london = (1.0, 1.0)
val suburbia = (2.0, 2.0)
val southampton = (3.0, 3.0)  
val york = (4.0, 4.0)  

我接下来创建了一个 Spark Dataframe,如下所示,并检查其是否有效:

val exampleDF = Seq((List(london,suburbia),List(southampton,york)),
    (List(york,london),List(southampton,suburbia))).toDF("AR1","AR2")
exampleDF.show()

数据框架包含以下类型

DataFrame = [AR1: array<struct<_1:double,_2:double>>, AR2: array<struct<_1:double,_2:double>>]

我创建了一个函数来创建点的组合

// function to do what I want
val latlongexplode =  (x: Array[(Double,Double)], y: Array[(Double,Double)]) => {
 for (a <- x; b <-y) yield (a,b)
}

我检查函数是否正常工作

latlongexplode(Array(london,york),Array(suburbia,southampton))

它确实能够实现。但是,当我将此函数创建为用户定义函数后

// declare function into a Spark UDF
val latlongexplodeUDF = udf (latlongexplode) 

当我尝试像这样在我创建的Spark DataFrame中使用它:

exampleDF.withColumn("latlongexplode", latlongexplodeUDF($"AR1",$"AR2")).show(false)

我得到了一个非常长的堆栈跟踪,基本上可以归结为:

java.lang.ClassCastException: scala.collection.mutable.WrappedArray$ofRef无法转换为 [Lscala.Tuple2;
org.apache.spark.sql.catalyst.expressions.ScalaUDF.$anonfun$f$3(ScalaUDF.scala:121) org.apache.spark.sql.catalyst.expressions.ScalaUDF.eval(ScalaUDF.scala:1063) org.apache.spark.sql.catalyst.expressions.Alias.eval(namedExpressions.scala:151) org.apache.spark.sql.catalyst.expressions.InterpretedProjection.apply(Projection.scala:50) org.apache.spark.sql.catalyst.expressions.InterpretedProjection.apply(Projection.scala:32) scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:273)

如何在Scala Spark中使这个UDF工作?(我目前使用2.4版本)

编辑:可能是我构建示例数据框的方式有问题。但是我实际数据中每列都是大小不定的经纬度元组数组。


1
你可能想联系 Raphael Roth,他似乎比大多数人更进一步。 - thebluephantom
这与数组的结构方面有关,但我不确定如何解决这个问题。 - thebluephantom
@raphaelroth,您能否发表评论? - thebluephantom
2
@thebluephantom 不需要 Raphael,我已经解决了 :) - mck
@mck 感谢您的解释和解决方案。非常感激。 - Mamonu
1个回答

3

在使用UDF时,结构体类型被表示为Row对象,数组列则表示为Seq。同时,需要以Row的形式返回一个结构体,并定义一个模式来返回此结构体。

import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

val london = (1.0, 1.0)
val suburbia = (2.0, 2.0)
val southampton = (3.0, 3.0)  
val york = (4.0, 4.0)
val exampleDF = Seq((List(london,suburbia),List(southampton,york)),
    (List(york,london),List(southampton,suburbia))).toDF("AR1","AR2")
exampleDF.show(false)
+------------------------+------------------------+
|AR1                     |AR2                     |
+------------------------+------------------------+
|[[1.0, 1.0], [2.0, 2.0]]|[[3.0, 3.0], [4.0, 4.0]]|
|[[4.0, 4.0], [1.0, 1.0]]|[[3.0, 3.0], [2.0, 2.0]]|
+------------------------+------------------------+

val latlongexplode = (x: Seq[Row], y: Seq[Row]) => {
    for (a <- x; b <- y) yield Row(a, b)
}

val udf_schema = ArrayType(
    StructType(Seq(
        StructField(
            "city1",
            StructType(Seq(
                StructField("lat", FloatType),
                StructField("long", FloatType)
            ))
        ),
        StructField(
            "city2",
            StructType(Seq(
                StructField("lat", FloatType),
                StructField("long", FloatType)
            ))
        )
    ))
)

// include this line if you see errors like 
// "You're using untyped Scala UDF, which does not have the input type information."
// spark.sql("set spark.sql.legacy.allowUntypedScalaUDF = true")

val latlongexplodeUDF = udf(latlongexplode, udf_schema)
result = exampleDF.withColumn("latlongexplode", latlongexplodeUDF($"AR1",$"AR2"))

result.show(false)
+------------------------+------------------------+--------------------------------------------------------------------------------------------------------+
|AR1                     |AR2                     |latlongexplode                                                                                          |
+------------------------+------------------------+--------------------------------------------------------------------------------------------------------+
|[[1.0, 1.0], [2.0, 2.0]]|[[3.0, 3.0], [4.0, 4.0]]|[[[1.0, 1.0], [3.0, 3.0]], [[1.0, 1.0], [4.0, 4.0]], [[2.0, 2.0], [3.0, 3.0]], [[2.0, 2.0], [4.0, 4.0]]]|
|[[4.0, 4.0], [1.0, 1.0]]|[[3.0, 3.0], [2.0, 2.0]]|[[[4.0, 4.0], [3.0, 3.0]], [[4.0, 4.0], [2.0, 2.0]], [[1.0, 1.0], [3.0, 3.0]], [[1.0, 1.0], [2.0, 2.0]]]|
+------------------------+------------------------+--------------------------------------------------------------------------------------------------------+

印象深刻,我差不多快完成了。明天再试。 - thebluephantom
刚刚看到了。我觉得错误信息很难理解。 - thebluephantom
1
我本以为需要一个 case class。这就是它。 - thebluephantom
1
@thebluephantom 是的,我想 case class 可能更好 - 定义 udf schema 已经被弃用了。但是对于一个需要定义复杂结构的情况来说,case class 似乎有点过于复杂,所以我选择了 udf schema。不过 OP 正在使用 spark 2.4,所以弃用并不是问题。 - mck
2
@mck + thebluephantom,非常感谢你们的帮助!我和Mamonu一起开发了一款名为Splink的开源数据链接软件,它使用Spark,这对我们非常有帮助! - RobinL

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接