Spark UDAF:如何在UDAF(用户定义的聚合函数)中通过列字段名称从输入中获取值?

3

我正在尝试使用Spark UDAF将两个现有的列汇总到一个新列中。大部分关于Spark UDAF的教程都使用索引来获取输入行中每个列中的值。像这样:

input.getAs[String](1)

在我的更新方法中(override def update(buffer: MutableAggregationBuffer, input: Row): Unit),我使用了一个变量

,它在我的情况下也有效。然而,我想使用该列的字段名来获取该值。就像这样:

input.getAs[String](ColumnNames.BehaviorType)

其中ColumnNames.BehaviorType是一个在对象中定义的字符串对象:

 /**
    * Column names in the original dataset
    */
  object ColumnNames {
    val JobSeekerID = "JobSeekerID"
    val JobID = "JobID"
    val Date = "Date"
    val BehaviorType = "BehaviorType"
  }

This time it does not work. I got the following exception:
java.lang.IllegalArgumentException: 字段 "BehaviorType" 不存在。 at org.apache.spark.sql.types.StructType$$anonfun$fieldIndex$1.apply(StructType.scala:292) ... at org.apache.spark.sql.Row$class.getAs(Row.scala:333) at org.apache.spark.sql.catalyst.expressions.GenericRow.getAs(rows.scala:165) at com.recsys.UserBehaviorRecordsUDAF.update(UserBehaviorRecordsUDAF.scala:44)
Some relevant code segments:
This is part of my UDAF:
class UserBehaviorRecordsUDAF extends UserDefinedAggregateFunction {


  override def inputSchema: StructType = StructType(
    StructField("JobID", IntegerType) ::
      StructField("BehaviorType", StringType) :: Nil)

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    println("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX")
    println(input.schema.treeString)
    println
    println(input.mkString(","))
    println
    println(this.inputSchema.treeString)
//    println
//    println(bufferSchema.treeString)

    input.getAs[String](ColumnNames.BehaviorType) match { //ColumnNames.BehaviorType //1 //TODO WHY??
      case BehaviourTypes.viewed_job =>
        buffer(0) =
          buffer.getAs[Seq[Int]](0) :+ //Array[Int]  //TODO WHY??
          input.getAs[Int](0) //ColumnNames.JobID
      case BehaviourTypes.bookmarked_job =>
        buffer(1) =
          buffer.getAs[Seq[Int]](1) :+ //Array[Int]
            input.getAs[Int](0)//ColumnNames.JobID
      case BehaviourTypes.applied_job =>
        buffer(2) =
          buffer.getAs[Seq[Int]](2) :+  //Array[Int]
            input.getAs[Int](0) //ColumnNames.JobID
    }
  }

以下是调用 UDAF 的代码部分:
val ubrUDAF = new UserBehaviorRecordsUDAF

val userProfileDF = userBehaviorDS
  .groupBy(ColumnNames.JobSeekerID)
  .agg(
    ubrUDAF(
      userBehaviorDS.col(ColumnNames.JobID), //userBehaviorDS.col(ColumnNames.JobID)
      userBehaviorDS.col(ColumnNames.BehaviorType) //userBehaviorDS.col(ColumnNames.BehaviorType)
    ).as("profile str"))

看起来输入行的模式中的字段名没有传递到UDAF中:

XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
root
 |-- input0: integer (nullable = true)
 |-- input1: string (nullable = true)


30917,viewed_job

root
 |-- JobID: integer (nullable = true)
 |-- BehaviorType: string (nullable = true)

我的代码存在什么问题?


1
UADF应该是通用的函数,只要满足UADF要求,就可以与任何“DateSet”一起使用。 UADF的示例包括countavg等。您不需要定义一个UADF来与单个数据集一起使用,这就是为什么UADF没有设计支持列名称的原因。此外,在update中的“row”不是数据集的实际行,而只是传递给UADF的部分。例如,如果您对具有名为“num_users”的列的任何数据集执行“avg(num_users)”,那么更新函数中的“row”将只有1列,即原始数据集行的“num_users”列。 - sarveshseri
但是我已经在UDAF中定义了inputSchema: override def inputSchema: StructType = StructType( StructField("JobID", IntegerType) :: StructField("BehaviorType", StringType) :: Nil) 有没有关于“更新中的行不是数据集的实际行,而只是传递给UADF的部分”的参考资料?谢谢! - CyberPlayerOne
@Tyler提督九门步军巡捕五营统领,你解决了这个问题吗? - Ivan Balashov
@IvanBalashov 我使用了Aggregator而不是UDAF。 Aggregator是强类型的。请参见:https://dev59.com/kKjja4cB1Zd3GeqP5hlk - CyberPlayerOne
@Tyler提督九门步军巡捕五营统领,好的,我会尽力的。谢谢! - Ivan Balashov
1个回答

0

我也想在我的更新方法中使用输入模式中的字段名称,以创建易于维护的代码。

import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
class MyUDAF extends UserDefinedAggregateFunction {
  def update(buffer: MutableAggregationBuffer, input: Row) = {
    val inputWSchema = new GenericRowWithSchema(input.toSeq.toArray, inputSchema)

最终转换为聚合器,运行时间减半。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接