Scala中“collect”函数的Spark Dataset等价函数,使用部分函数

5
正常的Scala集合有一个很有趣的collect方法,可以使用部分函数在一次遍历中执行filter-map操作。Spark Dataset上有等效的操作吗?我想出于两个原因:
  • 语法简洁
  • 它将filter-map样式操作减少为单次遍历(尽管在Spark中,我猜测有优化可以为您发现这些内容)
以下是一个示例,显示我的意思。假设我有一个选项序列,我想提取和加倍定义的整数(即Some中的那些)。
val input = Seq(Some(3), None, Some(-1), None, Some(4), Some(5)) 

Method 1 - collect

input.collect {
  case Some(value) => value * 2
} 
// List(6, -2, 8, 10)
collect 这个方法语法很简洁,只需要一次遍历即可完成。 方法2 - filter-map
input.filter(_.isDefined).map(_.get * 2)

我可以将这种模式应用到Spark中,因为数据集和数据框具有类似的方法。
但我不太喜欢这种方法,因为`isDefined`和`get`看起来像是代码异味。这里存在着一个隐含的假设,即map函数仅接收`Some`。编译器无法验证这一点。在一个更大的例子中,开发者很难发现这个假设,例如开发者可能会交换过滤和映射操作的顺序,而不会得到语法错误。
第三种方法-`fold*`操作
input.foldRight[List[Int]](Nil) {
  case (nextOpt, acc) => nextOpt match {
    case Some(next) => next*2 :: acc
    case None => acc
  }
}

我没有足够的Spark使用经验,不知道fold是否有一个等效的函数,因此这可能有点离题。

总之,模式匹配、折叠模板和列表重建都混杂在一起,很难阅读。


因此,总的来说,我认为collect语法最好,希望Spark也有类似的东西。


collect方法定义在RDDDataset上,用于将数据实现化到驱动程序中。尽管没有类似于集合API的collect方法,但您的直觉是正确的:由于两个操作都是惰性评估的,因此引擎有机会优化这些操作并将它们链接起来,以便以最大的局部性执行。 - stefanobaghino
5个回答

4

这里的答案是不正确的,至少在当前的Spark版本中是如此。

实际上,RDD确实有一个collect方法,它接受一个部分函数并对数据应用过滤器和映射。这与无参数的.collect()方法完全不同。请参阅Spark源代码RDD.scala @ line 955:

/**
 * Return an RDD that contains all matching values by applying `f`.
 */
def collect[U: ClassTag](f: PartialFunction[T, U]): RDD[U] = withScope {
  val cleanF = sc.clean(f)
  filter(cleanF.isDefinedAt).map(cleanF)
}

这不会从RDD中提取数据,相反,它与RDD.scala文件中的无参.collect()方法有所区别(位于第923行)。
/**
 * Return an array that contains all of the elements in this RDD.
 */
def collect(): Array[T] = withScope {
  val results = sc.runJob(this, (iter: Iterator[T]) => iter.toArray)
  Array.concat(results: _*)
}

在文档中,注意看一下

的用法。
def collect[U](f: PartialFunction[T, U]): RDD[U]

该方法没有与其关联的警告,提示数据将被加载到驱动程序的内存中:

https://spark.apache.org/docs/latest/api/scala/index.html#org.apache.spark.rdd.RDD@collect[U](f:PartialFunction[T,U])(implicitevidence$29:scala.reflect.ClassTag[U]):org.apache.spark.rdd.RDD[U]

Spark拥有这些不同功能的重载方法非常令人困惑。


编辑:我犯了错误!我误读了问题,我们正在谈论的是DataSets而不是RDDs。然而,接受的答案说:

"然而,Spark文档指出,“只有在预期结果数组很小且所有数据都会被加载到驱动程序的内存中时才应使用此方法。”

这是不正确的!当调用带有部分函数的.collect()版本时,数据不会被加载到驱动程序的内存中 - 只有在调用无参数版本时才会加载。如上面的源代码所示,调用.collect(partial_function)应该与依次调用.filter()和.map()具有相同的性能。


1
谢谢你的回答。不过问题是关于数据集,而不是RDD。另一个答案提到了如何将数据集转换为RDD,然后调用collect函数。 - rmin
啊,抱歉,是我的错误!我会编辑答案,但对于一些人来说,看到.collect()和.collect(pf)之间的区别仍然可能有用。 - shoffing

3

出于完整性考虑:

RDD API确实有这样的方法,因此将给定的Dataset/DataFrame转换为RDD进行收集操作,然后再转换回来始终是一种选择,例如:

val dataset = Seq(Some(1), None, Some(2)).toDS()
val dsResult = dataset.rdd.collect { case Some(i) => i * 2 }.toDS()
然而,这种方法的性能可能会比在数据集上使用map和filter差(原因已经在@stefanobaghino的答案中解释了)。 至于DataFrames,这个特定的例子(使用Option)有点误导人,因为将其转换为DataFrame实际上会将Options“展平”为它们的值(或对于None的情况下为null),所以等价表达式将是:
val dataframe = Seq(Some(1), None, Some(2)).toDF("opt")
dataframe.withColumn("opt", $"opt".multiply(2)).filter(not(isnull($"opt")))

我认为,这种方法不会受到您所担心的“假设”输入内容的影响,特别是在涉及地图操作时。

2
RDDDataset上定义的collect方法用于将数据实现化到驱动程序中。
尽管没有类似于集合API的collect方法,但你的直觉是正确的:由于两种操作都是惰性求值的,引擎有机会优化操作并链接它们,以便以最大的局部性能执行它们。
对于你特别提到的用例,我建议你考虑使用flatMap,它适用于RDDDataset
// Assumes the usual spark-shell environment
// sc: SparkContext, spark: SparkSession
val collection = Seq(Some(1), None, Some(2), None, Some(3))
val rdd = sc.parallelize(collection)
val dataset = spark.createDataset(rdd)

// Both operations will yield `Array(2, 4, 6)`
rdd.flatMap(_.map(_ * 2)).collect
dataset.flatMap(_.map(_ * 2)).collect

// You can also express the operation in terms of a for-comprehension
(for (option <- rdd; n <- option) yield n * 2).collect
(for (option <- dataset; n <- option) yield n * 2).collect

// The same approach is valid for traditional collections as well
collection.flatMap(_.map(_ * 2))
for (option <- collection; n <- option) yield n * 2

编辑

正如另一个问题中正确指出的那样,RDDs实际上有collect方法,它通过应用部分函数来转换RDD,就像在普通集合中一样。然而,正如Spark文档所指出的那样,"只有在预计结果数组很小的情况下才应使用此方法,因为所有数据都加载到驱动程序的内存中。"


感谢您的回答@stefanobaghino!目前看来只有方法2可用,但我不太喜欢。即使没有收集,是否有更符合惯用法且更简洁的方法来解决我在Spark数据集上的示例问题? - rmin
对于你的回答中的这种情况,flatMap 是最好的选择。:-) val rdd = sc.parallelize(Seq(Some(1), None, Some(2), None, Some(3))); rdd.flatMap(_.map(_ * 2)).collect 将输出 Array(2, 4, 6)。你也可以使用 for-comprehension。我会在我的回答中加上这个。 - stefanobaghino
感谢您更新答案!我忘记了 for - rmin
5
我认为使用PartialFunction的collect方法并不会出现这个问题......警告是在另一个没有参数的collect方法中。 - Carlos Verdes

1

您可以随时创建自己的扩展方法:

implicit class DatasetOps[T](ds: Dataset[T]) {

  def collectt[U](pf: PartialFunction[T, U])(implicit enc: Encoder[U]): Dataset[U] = {
    ds.flatMap(pf.lift(_))
  }
}

如下所示:

// val ds = Dataset(1, 2, 3)
ds.collectt { case x if x % 2 == 1 => x * 3 }
// Dataset(3, 9)

请注意,我不幸地无法将其命名为collect(因此出现了可怕的后缀t),因为签名否则会与现有的Dataset#collect方法冲突,该方法将Dataset转换为Array

1

我希望扩展stefanobaghino的答案,包括使用案例类作为for推导表达式的示例,因为许多用例都涉及案例类。

此外,选项是单子类型,这使得接受的答案在这种情况下非常简单,因为for可以干净地删除None值,但这种方法不会扩展到非单子类型(例如案例类):

case class A(b: Boolean, i: Int, d: Double)

val collection = Seq(A(true, 3), A(false, 10), A(true, -1))
val rdd = ...
val dataset = ...

// Select out and double all the 'i' values where 'b' is true:
for {
  A(b, i, _) <- dataset
  if b
} yield i * 2

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接