Spark Scala: 自定义聚合函数计算中位数

4

我正在尝试找到一种方法,来计算给定数据框的中位数。

val df = sc.parallelize(Seq(("a",1.0),("a",2.0),("a",3.0),("b",6.0), ("b", 8.0))).toDF("col1", "col2")

+----+----+
|col1|col2|
+----+----+
|   a| 1.0|
|   a| 2.0|
|   a| 3.0|
|   b| 6.0|
|   b| 8.0|
+----+----+

现在我想要做类似这样的事情:
df.groupBy("col1").agg(calcmedian("col2"))

结果应该看起来像这样:
+----+------+
|col1|median|
+----+------+
|   a|   2.0|
|   b|   7.0|
+----+------+` 

因此calcmedian()必须是一个UDAF,但问题是,UDAF的“evaluate”方法只接受一行数据,而我需要整个表来排序并返回中位数...
// Once all entries for a group are exhausted, spark will evaluate to get the final result  
def evaluate(buffer: Row) = {...}

有没有什么办法可以实现这个?或者有没有其他好的解决方案?我想强调一下,我知道如何在“一个组”数据集上计算中位数。但是我不想在“foreach”循环中使用此算法,因为这样效率低下!

谢谢!


编辑:

这是我到目前为止尝试过的:

object calcMedian extends UserDefinedAggregateFunction {
    // Schema you get as an input 
    def inputSchema = new StructType().add("col2", DoubleType)
    // Schema of the row which is used for aggregation
    def bufferSchema = new StructType().add("col2", DoubleType)
    // Returned type
    def dataType = DoubleType
    // Self-explaining 
    def deterministic = true
    // initialize - called once for each group
    def initialize(buffer: MutableAggregationBuffer) = {
        buffer(0) = 0.0
    }

    // called for each input record of that group
    def update(buffer: MutableAggregationBuffer, input: Row) = {
        buffer(0) = input.getDouble(0)
    }
    // if function supports partial aggregates, spark might (as an optimization) comput partial results and combine them together
    def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
      buffer1(0) = input.getDouble(0)   
    }
    // Once all entries for a group are exhausted, spark will evaluate to get the final result
    def evaluate(buffer: Row) = {
        val tile = 50
        var median = 0.0

        //PROBLEM: buffer is a Row --> I need DataFrame here???
        val rdd_sorted = buffer.sortBy(x => x)
        val c = rdd_sorted.count()
        if (c == 1){
            median = rdd_sorted.first()                
        }else{
            val index = rdd_sorted.zipWithIndex().map(_.swap)
            val last = c
            val n = (tile/ 100d) * (c*1d)
            val k = math.floor(n).toLong       
            val d = n - k
            if( k <= 0) {
                median = rdd_sorted.first()
            }else{
                if (k <= c){
                    median = index.lookup(last - 1).head
                }else{
                    if(k >= c){
                        median = index.lookup(last - 1).head
                    }else{
                        median = index.lookup(k-1).head + d* (index.lookup(k).head - index.lookup(k-1).head)
                    }
                }
            }
        }
    }   //end of evaluate

你需要使用 groupByKey 对数据进行分组,将聚合后的数据转换为 Buffer,可以使用一些 UDF 来实现这个过程,然后创建一个 UDF 来计算中位数。 - Alberto Bonsanto
UserDefinedAggregateFunction 基类有许多成员需要实现,而不仅仅是 evaluate。传递给 evaluateRow 缓冲区是最后一步。您是否尝试过任何实现?如果是,能否展示一下您目前的代码? - mattinbits
@mattinbits:我已经添加了到目前为止我所考虑的代码... - johntechendso
1
a) 在计算近似或精确中位数方面,已经有内置函数可用。 b) 无法在UDAF中访问数据框架。 c) 在分布式环境中计算精确中位数由于定义的原因非常低效。 - zero323
我使用的是Spark版本1.5.2,approxQuantile方法不可用!这些组不太大,因此一旦DFs被groupBy(...)分组,就不应该有太多数据需要重新排列。但是,如果有解决方案,我想尝试一下。可能仍然比foreach循环更有效率。 - johntechendso
1
然后 percentile_approx / percentile 是什么。数据框上的 groupBy 不会物理移动数据(这里洗牌没有区别)。它是 aggregate(ByKey) 的等效形式,这在 API 中清晰地反映出来。无论如何,您都无法在 UDAF 内部访问数据框。 - zero323
1个回答

7

试试这个:

import org.apache.spark.functions._

val result = data.groupBy("col1").agg(callUDF("percentile_approx", col("col2"), lit(0.5)))

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接