Scala Spark中比较两个数组列

3

我有一个以下格式的数据框。

movieId1 | genreList1              | genreList2
--------------------------------------------------
1        |[Adventure,Comedy]       |[Adventure]
2        |[Animation,Drama,War]    |[War,Drama]
3        |[Adventure,Drama]        |[Drama,War]

尝试创建另一个标志列,以显示genreList2是否是genreList1的子集。

movieId1 | genreList1              | genreList2        | Flag
---------------------------------------------------------------
1        |[Adventure,Comedy]       | [Adventure]       |1
2        |[Animation,Drama,War]    | [War,Drama]       |1
3        |[Adventure,Drama]        | [Drama,War]       |0

我尝试过这个:

我尝试过这个:

def intersect_check(a: Array[String], b: Array[String]): Int = {
  if (b.sameElements(a.intersect(b))) { return 1 } 
  else { return 2 }
}

def intersect_check_udf =
  udf((colvalue1: Array[String], colvalue2: Array[String]) => intersect_check(colvalue1, colvalue2))

data = data.withColumn("Flag", intersect_check_udf(col("genreList1"), col("genreList2")))

但是这会引发错误

org.apache.spark.SparkException:无法执行用户定义的函数。

附言:上述函数(intersect_check)适用于数组

5个回答

6
我们可以定义一个udf,计算两个Array列之间的intersection长度,并检查它是否等于第二列的长度。如果是,则第二个数组是第一个数组的子集。
此外,您的udf的输入需要是WrappedArray[String]类,而不是Array[String]
import scala.collection.mutable.WrappedArray
import org.apache.spark.sql.functions.col

val same_elements = udf { (a: WrappedArray[String], 
                           b: WrappedArray[String]) => 
  if (a.intersect(b).length == b.length){ 1 }else{ 0 }  
}

df.withColumn("test",same_elements(col("genreList1"),col("genreList2")))
  .show(truncate = false)
+--------+-----------------------+------------+----+
|movieId1|genreList1             |genreList2  |test|
+--------+-----------------------+------------+----+
|1       |[Adventure, Comedy]    |[Adventure] |1   |
|2       |[Animation, Drama, War]|[War, Drama]|1   |
|3       |[Adventure, Drama]     |[Drama, War]|0   |
+--------+-----------------------+------------+----+

数据

val df = List((1,Array("Adventure","Comedy"), Array("Adventure")),
              (2,Array("Animation","Drama","War"), Array("War","Drama")),
              (3,Array("Adventure","Drama"),Array("Drama","War"))).toDF("movieId1","genreList1","genreList2")

3

以下是使用 subsetOf 进行转换的解决方案。

  val spark =
    SparkSession.builder().master("local").appName("test").getOrCreate()

  import spark.implicits._

  val data = spark.sparkContext.parallelize(
  Seq(
    (1,Array("Adventure","Comedy"),Array("Adventure")),
  (2,Array("Animation","Drama","War"),Array("War","Drama")),
  (3,Array("Adventure","Drama"),Array("Drama","War"))
  )).toDF("movieId1", "genreList1", "genreList2")


  val subsetOf = udf((col1: Seq[String], col2: Seq[String]) => {
    if (col2.toSet.subsetOf(col1.toSet)) 1 else 0
  })

  data.withColumn("flag", subsetOf(data("genreList1"), data("genreList2"))).show()

希望这能帮到您!

1

Spark 3.0+ (forall)

Spark 3.0+{{link1:forall}}

forall($"genreList2", x => array_contains($"genreList1", x)).cast("int")

完整示例:

val df = Seq(
     (1, Seq("Adventure", "Comedy"), Seq("Adventure")),
     (2, Seq("Animation", "Drama","War"), Seq("War", "Drama")),
     (3, Seq("Adventure", "Drama"), Seq("Drama", "War"))
     ).toDF("movieId1", "genreList1", "genreList2")

val df2 = df.withColumn("Flag", forall($"genreList2", x => array_contains($"genreList1", x)).cast("int"))

df2.show()
// +--------+--------------------+------------+----+
// |movieId1|          genreList1|  genreList2|Flag|
// +--------+--------------------+------------+----+
// |       1| [Adventure, Comedy]| [Adventure]|   1|
// |       2|[Animation, Drama...|[War, Drama]|   1|
// |       3|  [Adventure, Drama]|[Drama, War]|   0|
// +--------+--------------------+------------+----+

0

这也可以在这里工作,而且不使用UDF

 import spark.implicits._
 val data = Seq(
        (1,Array("Adventure","Comedy"),Array("Adventure")),
        (2,Array("Animation","Drama","War"),Array("War","Drama")),
        (3,Array("Adventure","Drama"),Array("Drama","War"))
      ).toDF("movieId1", "genreList1", "genreList2")

 data
     .withColumn("size",size(array_except($"genreList2",$"genreList1")))
     .withColumn("flag",when($"size" === lit(0), 1) otherwise(0))
     .show(false)

0
一个解决方案可能是利用Spark数组内置函数:如果两者之间的交集等于genreList2,则genreList2genreList1的子集。在下面的代码中,添加了一个sort_array操作,以避免具有不同排序但相同元素的两个数组不匹配。
val spark = {
    SparkSession
    .builder()
    .master("local")
    .appName("test")
    .getOrCreate()
}

import spark.implicits._
import org.apache.spark.sql._
import org.apache.spark.sql.functions._

val df = Seq(
    (1, Array("Adventure","Comedy"), Array("Adventure")),
    (2, Array("Animation","Drama","War"), Array("War","Drama")),
    (3, Array("Adventure","Drama"), Array("Drama","War"))
).toDF("movieId1", "genreList1", "genreList2")

df
.withColumn("flag",
 sort_array(array_intersect($"genreList1",$"genreList2"))
 .equalTo(
   sort_array($"genreList2")
 )
.cast("integer")
)
.show()

输出结果为

+--------+--------------------+------------+----+
|movieId1|          genreList1|  genreList2|flag|
+--------+--------------------+------------+----+
|       1| [Adventure, Comedy]| [Adventure]|   1|
|       2|[Animation, Drama...|[War, Drama]|   1|
|       3|  [Adventure, Drama]|[Drama, War]|   0|
+--------+--------------------+------------+----+

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接