Spark collect_list和limit操作的结果列表

8
我有一个如下格式的数据框:

name          merged
key1    (internalKey1, value1)
key1    (internalKey2, value2)
...
key2    (internalKey3, value3)
...

我希望您能将数据框按照 name 进行分组,收集列表并限制列表的大小。
以下是按照 name 进行分组和收集列表的方法:
val res = df.groupBy("name")
            .agg(collect_list(col("merged")).as("final"))

生成的数据框类似于:

 key1   [(internalKey1, value1), (internalKey2, value2),...] // Limit the size of this list 
 key2   [(internalKey3, value3),...]

我希望限制每个键产生的列表大小。 我尝试了多种方法,但没有成功。 我已经看到一些帖子建议使用第三方解决方案,但我想避免这样做。 有办法吗?

你能给一个限制列表大小的例子吗?这个限制是指元素数量吗?那么你的限制基于什么呢? - Chitral Verma
我已经更新了问题。 - pirox22
你能否使用UDF来截断过长的列表? - ayplam
4个回答

6

如果您正在寻找一种更高性能、内存敏感的方法,虽然UDF可以满足您的需求,但是编写UDAF的方式可能更适合您。不幸的是,UDAF API实际上不像随Spark一起提供的聚合函数那样可扩展。但是,您可以使用它们的内部API来构建所需的内部函数。

这里有一个实现collect_list_limit的示例,大部分内容都是从Spark的内部CollectList聚合函数复制粘贴而来。我只是需要覆盖更新和合并方法以考虑传入的限制:

case class CollectListLimit(
    child: Expression,
    limitExp: Expression,
    mutableAggBufferOffset: Int = 0,
    inputAggBufferOffset: Int = 0) extends Collect[mutable.ArrayBuffer[Any]] {

  val limit = limitExp.eval( null ).asInstanceOf[Int]

  def this(child: Expression, limit: Expression) = this(child, limit, 0, 0)

  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
    copy(mutableAggBufferOffset = newMutableAggBufferOffset)

  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
    copy(inputAggBufferOffset = newInputAggBufferOffset)

  override def createAggregationBuffer(): mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty

  override def update(buffer: mutable.ArrayBuffer[Any], input: InternalRow): mutable.ArrayBuffer[Any] = {
    if( buffer.size < limit ) super.update(buffer, input)
    else buffer
  }

  override def merge(buffer: mutable.ArrayBuffer[Any], other: mutable.ArrayBuffer[Any]): mutable.ArrayBuffer[Any] = {
    if( buffer.size >= limit ) buffer
    else if( other.size >= limit ) other
    else ( buffer ++= other ).take( limit )
  }

  override def prettyName: String = "collect_list_limit"
}

要实际注册它,我们可以通过Spark的内部FunctionRegistry来完成,该函数接受名称和构建器,这实际上是一种使用提供的表达式创建CollectListLimit的函数:

val collectListBuilder = (args: Seq[Expression]) => CollectListLimit( args( 0 ), args( 1 ) )
FunctionRegistry.builtin.registerFunction( "collect_list_limit", collectListBuilder )

编辑:

事实证明,如果您尚未创建SparkContext,则将其添加到内置中仅在启动时进行不可变克隆才起作用。如果您已经有现有上下文,则应使用反射来添加它:

val field = classOf[SessionCatalog].getFields.find( _.getName.endsWith( "functionRegistry" ) ).get
field.setAccessible( true )
val inUseRegistry = field.get( SparkSession.builder.getOrCreate.sessionState.catalog ).asInstanceOf[FunctionRegistry]
inUseRegistry.registerFunction( "collect_list_limit", collectListBuilder )

5
您可以创建一个函数来限制聚合ArrayType列的大小,如下所示:
import org.apache.spark.sql.functions._
import org.apache.spark.sql.Column

case class KV(k: String, v: String)

val df = Seq(
  ("key1", KV("internalKey1", "value1")),
  ("key1", KV("internalKey2", "value2")),
  ("key2", KV("internalKey3", "value3")),
  ("key2", KV("internalKey4", "value4")),
  ("key2", KV("internalKey5", "value5"))
).toDF("name", "merged")

def limitSize(n: Int, arrCol: Column): Column =
  array( (0 until n).map( arrCol.getItem ): _* )

df.
  groupBy("name").agg( collect_list(col("merged")).as("final") ).
  select( $"name", limitSize(2, $"final").as("final2") ).
  show(false)
// +----+----------------------------------------------+
// |name|final2                                        |
// +----+----------------------------------------------+
// |key1|[[internalKey1,value1], [internalKey2,value2]]|
// |key2|[[internalKey3,value3], [internalKey4,value4]]|
// +----+----------------------------------------------+

我在Databricks上有一个名为“leo c”的文件夹。 - thebluephantom

1
你可以使用UDF。
这里是一个可能的示例,无需模式并具有有意义的缩减:
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.functions._

import scala.collection.mutable


object TestJob1 {

  def main (args: Array[String]): Unit = {

val sparkSession = SparkSession
  .builder()
  .appName(this.getClass.getName.replace("$", ""))
  .master("local")
  .getOrCreate()

val sc = sparkSession.sparkContext

import sparkSession.sqlContext.implicits._

val rawDf = Seq(
  ("key", 1L, "gargamel"),
  ("key", 4L, "pe_gadol"),
  ("key", 2L, "zaam"),
  ("key1", 5L, "naval")
).toDF("group", "quality", "other")

rawDf.show(false)
rawDf.printSchema

val rawSchema = rawDf.schema

val fUdf = udf(reduceByQuality, rawSchema)

val aggDf = rawDf
  .groupBy("group")
  .agg(
    count(struct("*")).as("num_reads"),
    max(col("quality")).as("quality"),
    collect_list(struct("*")).as("horizontal")
  )
  .withColumn("short", fUdf($"horizontal"))
  .drop("horizontal")


aggDf.printSchema

aggDf.show(false)
}

def reduceByQuality= (x: Any) => {

val d = x.asInstanceOf[mutable.WrappedArray[GenericRowWithSchema]]

val red = d.reduce((r1, r2) => {

  val quality1 = r1.getAs[Long]("quality")
  val quality2 = r2.getAs[Long]("quality")

  val r3 = quality1 match {
    case a if a >= quality2 =>
      r1
    case _ =>
      r2
  }

  r3
})

red
}
}

这是一个与您的数据类似的示例。
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types._
import org.apache.spark.sql.expressions._
import org.apache.spark.sql.functions._

import scala.collection.mutable


object TestJob {

  def main (args: Array[String]): Unit = {

val sparkSession = SparkSession
  .builder()
  .appName(this.getClass.getName.replace("$", ""))
  .master("local")
  .getOrCreate()

val sc = sparkSession.sparkContext

import sparkSession.sqlContext.implicits._


val df1 = Seq(
  ("key1", ("internalKey1", "value1")),
  ("key1", ("internalKey2", "value2")),
  ("key2", ("internalKey3", "value3")),
  ("key2", ("internalKey4", "value4")),
  ("key2", ("internalKey5", "value5"))
)
  .toDF("name", "merged")

//    df1.printSchema
//
//    df1.show(false)

val res = df1
  .groupBy("name")
  .agg( collect_list(col("merged")).as("final") )

res.printSchema

res.show(false)

def f= (x: Any) => {

  val d = x.asInstanceOf[mutable.WrappedArray[GenericRowWithSchema]]

  val d1 = d.asInstanceOf[mutable.WrappedArray[GenericRowWithSchema]].head

  d1.toString
}

val fUdf = udf(f, StringType)

val d2 = res
  .withColumn("d", fUdf(col("final")))
  .drop("final")

d2.printSchema()

d2
  .show(false)
 }
 }

1

我很感激这是一个老问题,但我也想做同样的事情,现在自从3.1.0版本,slice函数可以帮助解决这个问题:

val df = Seq(
  ("key1", ("internalKey1", "value1")),
  ("key1", ("internalKey2", "value2")),
  ("key1", ("internalKey3", "value1")),
  ("key1", ("internalKey4", "value1")),  
  ("key2", ("internalKey3", "value3")),
  ("key2", ("internalKey4", "value4")),
  ("key2", ("internalKey5", "value5"))
).toDF("name", "merged")

val result = df.groupBy("name").agg(slice(collect_list("merged"),1,2).as("limited_list"))

result.show(false)

输出:

+----+------------------------------------------------+
|name|limited_list                                    |
+----+------------------------------------------------+
|key1|[{internalKey1, value1}, {internalKey2, value2}]|
|key2|[{internalKey3, value3}, {internalKey4, value4}]|
+----+------------------------------------------------+

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接