DataFrame / Dataset groupBy behaviour/optimization

Suppose we have a DataFrame df with the following columns:

Name, Surname, Size, Width, Length, Weight

Now we want to perform a couple of operations, for example create a couple of DataFrames containing data about Size and Width.
val df1 = df.groupBy("surname").agg( sum("size") )
val df2 = df.groupBy("surname").agg( sum("width") )

As you can notice, the other columns (e.g. Length) are not used anywhere. Is Spark smart enough to drop the redundant columns before the shuffling phase, or are they carried along? Would running:

val dfBasic = df.select("surname", "size", "width")

before grouping somehow affect the performance?


Spark will only select the columns it was asked to group by. You can use explain to get the physical plan of your query. - eliasah
The overall plan depends on the input data. If the input is Parquet (columnar), Spark can access the columns directly. But if the input is CSV, Spark has to read the whole file and all of its columns before projecting down to only the required ones. - Kent Pawar
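
To illustrate the point from the last comment, one can compare the physical plans for a columnar and a text source. A minimal sketch, assuming hypothetical file paths that contain the columns above:

import org.apache.spark.sql.functions.sum

// Parquet is columnar: the FileScan's ReadSchema typically lists only surname and size.
val parquetDf = spark.read.parquet("/data/people.parquet")
parquetDf.groupBy("surname").agg(sum("size")).explain()

// CSV rows are still read in full; a Project node then drops the unused columns.
val csvDf = spark.read.option("header", "true").csv("/data/people.csv")
csvDf.groupBy("surname").agg(sum("size")).explain()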
1 Answer


Yes, it is "smart enough". groupBy performed on a DataFrame is not the same operation as groupBy performed on a plain RDD. In the scenario you describe there is no need to move the raw data at all. Let's create a small example to illustrate that:

val df = sc.parallelize(Seq(
   ("a", "foo", 1), ("a", "foo", 3), ("b", "bar", 5), ("b", "bar", 1)
)).toDF("x", "y", "z")

df.groupBy("x").agg(sum($"z")).explain

// == Physical Plan ==
// *HashAggregate(keys=[x#148], functions=[sum(cast(z#150 as bigint))])
// +- Exchange hashpartitioning(x#148, 200)
//    +- *HashAggregate(keys=[x#148], functions=[partial_sum(cast(z#150 as bigint))])
//       +- *Project [_1#144 AS x#148, _3#146 AS z#150]
//          +- *SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple3, true])._1, true, false) AS _1#144, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple3, true])._2, true, false) AS _2#145, assertnotnull(input[0, scala.Tuple3, true])._3 AS _3#146]
//             +- Scan ExternalRDDScan[obj#143]

As you can see, the first phase is a projection where only the required columns are kept. Next, the data is aggregated locally and finally transferred and aggregated globally. The output will look a little different if you use Spark <= 1.4, but the general structure should be exactly the same.

Finally, a DAG visualization of the actual job matches the description above:

[Figure: group by and agg DAG]
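
As for the second part of the question: adding a manual select before the groupBy should not change the plan, since Catalyst's column pruning already inserts the same projection. A quick sketch to check this on the toy DataFrame above with explain:

df.groupBy("x").agg(sum($"z")).explain
df.select("x", "z").groupBy("x").agg(sum($"z")).explain

// Both queries should yield essentially the same physical plan, with a Project
// keeping only x and z before the partial HashAggregate.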

Similarly, Dataset.groupByKey followed by reduceGroups contains both the map-side (ObjectHashAggregate with partial_reduceaggregator) and the reduce-side (ObjectHashAggregate with reduceaggregator) aggregation:

case class Foo(x: String, y: String, z: Int)

val ds = df.as[Foo]
ds.groupByKey(_.x).reduceGroups((x, y) => x.copy(z = x.z + y.z)).explain

// == Physical Plan ==
// ObjectHashAggregate(keys=[value#126], functions=[reduceaggregator(org.apache.spark.sql.expressions.ReduceAggregator@54d90261, Some(newInstance(class $line40.$read$$iw$$iw$Foo)), Some(class $line40.$read$$iw$$iw$Foo), Some(StructType(StructField(x,StringType,true), StructField(y,StringType,true), StructField(z,IntegerType,false))), input[0, scala.Tuple2, true]._1 AS value#128, if ((isnull(input[0, scala.Tuple2, true]._2) || None.equals)) null else named_struct(x, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, scala.Tuple2, true]._2)).x, true, false) AS x#25, y, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, scala.Tuple2, true]._2)).y, true, false) AS y#26, z, assertnotnull(assertnotnull(input[0, scala.Tuple2, true]._2)).z AS z#27) AS _2#129, newInstance(class scala.Tuple2), staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, $line40.$read$$iw$$iw$Foo, true])).x, true, false) AS x#25, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, $line40.$read$$iw$$iw$Foo, true])).y, true, false) AS y#26, assertnotnull(assertnotnull(input[0, $line40.$read$$iw$$iw$Foo, true])).z AS z#27, StructField(x,StringType,true), StructField(y,StringType,true), StructField(z,IntegerType,false), true, 0, 0)])
// +- Exchange hashpartitioning(value#126, 200)
//    +- ObjectHashAggregate(keys=[value#126], functions=[partial_reduceaggregator(org.apache.spark.sql.expressions.ReduceAggregator@54d90261, Some(newInstance(class $line40.$read$$iw$$iw$Foo)), Some(class $line40.$read$$iw$$iw$Foo), Some(StructType(StructField(x,StringType,true), StructField(y,StringType,true), StructField(z,IntegerType,false))), input[0, scala.Tuple2, true]._1 AS value#128, if ((isnull(input[0, scala.Tuple2, true]._2) || None.equals)) null else named_struct(x, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, scala.Tuple2, true]._2)).x, true, false) AS x#25, y, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, scala.Tuple2, true]._2)).y, true, false) AS y#26, z, assertnotnull(assertnotnull(input[0, scala.Tuple2, true]._2)).z AS z#27) AS _2#129, newInstance(class scala.Tuple2), staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, $line40.$read$$iw$$iw$Foo, true])).x, true, false) AS x#25, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, $line40.$read$$iw$$iw$Foo, true])).y, true, false) AS y#26, assertnotnull(assertnotnull(input[0, $line40.$read$$iw$$iw$Foo, true])).z AS z#27, StructField(x,StringType,true), StructField(y,StringType,true), StructField(z,IntegerType,false), true, 0, 0)])
//       +- AppendColumns <function1>, newInstance(class $line40.$read$$iw$$iw$Foo), [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#126]
//          +- *Project [_1#4 AS x#8, _2#5 AS y#9, _3#6 AS z#10]
//             +- *SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple3, true])._1, true, false) AS _1#4, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple3, true])._2, true, false) AS _2#5, assertnotnull(input[0, scala.Tuple3, true])._3 AS _3#6]
//                +- Scan ExternalRDDScan[obj#3]

[Figure: groupByKey + reduceGroups DAG]

However, other methods of KeyValueGroupedDataset may work similarly to RDD.groupByKey. For example, mapGroups (or flatMapGroups) doesn't use partial aggregation:

ds.groupByKey(_.x)
  .mapGroups((_, iter) => iter.reduce((x, y) => x.copy(z = x.z + y.z)))
  .explain

//== Physical Plan ==
//*SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, $line15.$read$$iw$$iw$Foo, true]).x, true, false) AS x#37, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, $line15.$read$$iw$$iw$Foo, true]).y, true, false) AS y#38, assertnotnull(input[0, $line15.$read$$iw$$iw$Foo, true]).z AS z#39]
//+- MapGroups <function2>, value#32.toString, newInstance(class $line15.$read$$iw$$iw$Foo), [value#32], [x#8, y#9, z#10], obj#36: $line15.$read$$iw$$iw$Foo
//   +- *Sort [value#32 ASC NULLS FIRST], false, 0
//      +- Exchange hashpartitioning(value#32, 200)
//         +- AppendColumns <function1>, newInstance(class $line15.$read$$iw$$iw$Foo), [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#32]
//            +- *Project [_1#4 AS x#8, _2#5 AS y#9, _3#6 AS z#10]
//               +- *SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple3, true])._1, true, false) AS _1#4, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple3, true])._2, true, false) AS _2#5, assertnotnull(input[0, scala.Tuple3, true])._3 AS _3#6]
//                  +- Scan ExternalRDDScan[obj#3]

[Figure: groupByKey + mapGroups DAG]
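
If the per-group logic can be expressed as a fold, a typed Aggregator passed to KeyValueGroupedDataset.agg keeps the map-side partial aggregation that mapGroups gives up. A minimal sketch reusing the Foo class defined above (the SumZ name is made up for illustration):

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

// Sums the z field of Foo; both the buffer and the output are plain Longs.
object SumZ extends Aggregator[Foo, Long, Long] {
  def zero: Long = 0L
  def reduce(acc: Long, f: Foo): Long = acc + f.z
  def merge(a: Long, b: Long): Long = a + b
  def finish(acc: Long): Long = acc
  def bufferEncoder: Encoder[Long] = Encoders.scalaLong
  def outputEncoder: Encoder[Long] = Encoders.scalaLong
}

ds.groupByKey(_.x).agg(SumZ.toColumn.name("sum_z")).explain

// The plan should again contain a partial aggregate before the Exchange,
// unlike the mapGroups variant above.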


@Niemand I suggest reading this article for a deeper look at Catalyst. - eliasah
@A.B As said in the answer, no! This group by doesn't work in the same way as the group by at the RDD level. - eliasah
@eliasah, thanks for the info. I've tried to find any source that explains the shuffle performance and distribution across nodes for DataFrames (specifically) versus RDDs, but I can only find examples and outputs. Could you point me to any course that teaches concepts like these (e.g. why groupByKey on an RDD is expensive while groupBy on a DataFrame is not)? - A.B
The only literature I can think of that discusses this topic is @holden's book High Performance Spark. - eliasah
