在Pyspark中高效计算加权滚动平均,需要注意以下几点。

4
我正在尝试在Pyspark中计算一个滚动加权平均值,窗口为(partition by id1, id2 ORDER BY unixTime)。想知道有没有人有关于如何做到这一点的想法。
滚动平均将采用当前行的列值,该列的前9个行值和后9个行值,并根据它们距离行的位置进行加权。因此,当前行的权重为10倍,滞后1 / lead 1值的权重为9倍。
如果所有值都不为空,则加权平均数的分母将为100。唯一的例外是,如果存在空值,我们仍然希望计算移动平均值(除非超过1/2的值为空)。
例如,如果当前值之前的9个值都为空,则分母将为55。如果超过1/2的值为空,那么我们将输出NULL作为加权平均值。我们还可以使用逻辑,即如果分母小于40或其他值,则输出null。
我已经附上了一张截图来解释我的意思,以防有些地方不太清楚,希望这能澄清问题:enter image description here 我知道我可以在SQL中做到这一点(我可以将数据框保存为临时视图),但由于我必须对多个列执行此滚动平均值(完全相同的逻辑),所以如果我能在Pyspark中做到这一点,最理想的情况是我将能够编写一个for循环,然后对每个列执行它。此外,我希望能够高效地完成此操作。我已经阅读了许多关于滚动平均数的线程,但我认为这种情况略有不同。
如果这样做不容易高效,我知道如何通过列出 lag(val, 10) over window... lag(val, 9) over window... 等来在sql中计算它,并且可以采用这种方法。抱歉如果我过于复杂化了,希望这有意义。 如果这样做不容易高效,我知道如何通过列出 lag(val, 10) over window... lag(val, 9) over window... 等来在sql中计算它,并且可以采用这种方法。

这个回答解决了你的问题吗?:https://dev59.com/Eqfja4cB1Zd3GeqPsj6_ - pissall
@pissall 不是的,那是我读过的帖子,但那个解决方案意味着 null 值将作为 0 来处理 -> 平均值会被扭曲,而不是从分母中删除 null 值。我相信我可能能够找到一种方法来修改那个解决方案并使其适用于我,但这将非常低效。 - WIT
我的建议是,您需要根据自己的需求对那个答案进行微调。 - pissall
1个回答

6

我理解的是,您可以尝试使用窗口函数 collect_list,对列表进行排序,使用 array_position 找到当前行的位置 idx (需要 Spark 2.4+ 版本),然后基于此计算权重。让我们以窗口大小为7(或下面代码中的 N=3)为例:

from pyspark.sql.functions import expr, sort_array, collect_list, struct
from pyspark.sql import Window

df = spark.createDataFrame([
    (0, 0.5), (1, 0.6), (2, 0.65), (3, 0.7), (4, 0.77),
    (5, 0.8), (6, 0.7), (7, 0.9), (8, 0.99), (9, 0.95)
], ["time", "val"])

N = 3

w1 = Window.partitionBy().orderBy('time').rowsBetween(-N,N)

# note that the index for array_position is 1-based, `i` in transform function is 0-based
df1 = df.withColumn('data', sort_array(collect_list(struct('time','val')).over(w1))) \
    .withColumn('idx', expr("array_position(data, (time,val))-1")) \
    .withColumn('weights', expr("transform(data, (x,i) ->  10 - abs(i-idx))"))

df1.show(truncate=False)
+----+----+-------------------------------------------------------------------------+---+----------------------+
|time|val |data                                                                     |idx|weights               |
+----+----+-------------------------------------------------------------------------+---+----------------------+
|0   |0.5 |[[0, 0.5], [1, 0.6], [2, 0.65], [3, 0.7]]                                |0  |[10, 9, 8, 7]         |
|1   |0.6 |[[0, 0.5], [1, 0.6], [2, 0.65], [3, 0.7], [4, 0.77]]                     |1  |[9, 10, 9, 8, 7]      |
|2   |0.65|[[0, 0.5], [1, 0.6], [2, 0.65], [3, 0.7], [4, 0.77], [5, 0.8]]           |2  |[8, 9, 10, 9, 8, 7]   |
|3   |0.7 |[[0, 0.5], [1, 0.6], [2, 0.65], [3, 0.7], [4, 0.77], [5, 0.8], [6, 0.7]] |3  |[7, 8, 9, 10, 9, 8, 7]|
|4   |0.77|[[1, 0.6], [2, 0.65], [3, 0.7], [4, 0.77], [5, 0.8], [6, 0.7], [7, 0.9]] |3  |[7, 8, 9, 10, 9, 8, 7]|
|5   |0.8 |[[2, 0.65], [3, 0.7], [4, 0.77], [5, 0.8], [6, 0.7], [7, 0.9], [8, 0.99]]|3  |[7, 8, 9, 10, 9, 8, 7]|
|6   |0.7 |[[3, 0.7], [4, 0.77], [5, 0.8], [6, 0.7], [7, 0.9], [8, 0.99], [9, 0.95]]|3  |[7, 8, 9, 10, 9, 8, 7]|
|7   |0.9 |[[4, 0.77], [5, 0.8], [6, 0.7], [7, 0.9], [8, 0.99], [9, 0.95]]          |3  |[7, 8, 9, 10, 9, 8]   |
|8   |0.99|[[5, 0.8], [6, 0.7], [7, 0.9], [8, 0.99], [9, 0.95]]                     |3  |[7, 8, 9, 10, 9]      |
|9   |0.95|[[6, 0.7], [7, 0.9], [8, 0.99], [9, 0.95]]                               |3  |[7, 8, 9, 10]         |
+----+----+-------------------------------------------------------------------------+---+----------------------+

接下来,我们可以使用SparkSQL内置函数aggregate来计算权重和加权值的总和:

N = 9

w1 = Window.partitionBy().orderBy('time').rowsBetween(-N,N)

df_new = df.withColumn('data', sort_array(collect_list(struct('time','val')).over(w1))) \
    .withColumn('idx', expr("array_position(data, (time,val))-1")) \
    .withColumn('weights', expr("transform(data, (x,i) ->  10 - abs(i-idx))"))\
    .withColumn('sum_weights', expr("aggregate(weights, 0D, (acc,x) -> acc+x)")) \
    .withColumn('weighted_val', expr("""
      aggregate(
        zip_with(data,weights, (x,y) -> x.val*y),
        0D, 
        (acc,x) -> acc+x,
        acc -> acc/sum_weights
      )""")) \
    .drop("data", "idx", "sum_weights", "weights")

df_new.show()
+----+----+------------------+
|time| val|      weighted_val|
+----+----+------------------+
|   0| 0.5|0.6827272727272726|
|   1| 0.6|0.7001587301587302|
|   2|0.65|0.7169565217391304|
|   3| 0.7|0.7332876712328767|
|   4|0.77|            0.7492|
|   5| 0.8|0.7641333333333333|
|   6| 0.7|0.7784931506849315|
|   7| 0.9|0.7963768115942028|
|   8|0.99|0.8138095238095238|
|   9|0.95|0.8292727272727273|
+----+----+------------------+

注意事项:

  • you can calculate multiple columns by setting struct('time','val1', 'val2') in the first line of calculating df_new and then adjust the corresponding calculation of idx and x.val*y in weighted_val etc.

  • to set NULL when less than half values are not able to be collected, add a IF(size(data) <= 9, NULL, ...) or IF(sum_weights < 40, NULL, ...) statement to the following:

      df_new = df.withColumn(...) \
      ...
          .withColumn('weighted_val', expr(""" IF(size(data) <= 9, NULL, 
            aggregate( 
              zip_with(data,weights, (x,y) -> x.val*y), 
              0D,  
              (acc,x) -> acc+x, 
              acc -> acc/sum_weights 
           ))""")) \
          .drop("data", "idx", "sum_weights", "weights")
    

编辑:如果需要多列,可以尝试以下方法:

cols = ['val1', 'val2', 'val3']

# function to set SQL expression to calculate weighted values for the field `val`
weighted_vals = lambda val: """
    aggregate(
      zip_with(data,weights, (x,y) -> x.{0}*y),
      0D,
      (acc,x) -> acc+x,
      acc -> acc/sum_weights
    ) as weighted_{0}
""".format(val)

df_new = df.withColumn('data', sort_array(collect_list(struct('time',*cols)).over(w1))) \
  .withColumn('idx', expr("array_position(data, (time,{}))-1".format(','.join(cols)))) \
  .withColumn('weights', expr("transform(data, (x,i) ->  10 - abs(i-idx))")) \
  .withColumn('sum_weights', expr("aggregate(weights, 0D, (acc,x) -> acc+x)")) \
  .selectExpr(df.columns + [ weighted_vals(c) for c in cols ])

如果列的数量有限,我们可以编写 SQL 表达式来使用一个聚合函数计算加权值:
df_new = df.withColumn('data', sort_array(collect_list(struct('time',*cols)).over(w1))) \
  .withColumn('idx', expr("array_position(data, (time,{}))-1".format(','.join(cols)))) \
  .withColumn('weights', expr("transform(data, (x,i) ->  10 - abs(i-idx))")) \
  .withColumn('sum_weights', expr("aggregate(weights, 0D, (acc,x) -> acc+x)")) \
  .withColumn("vals", expr(""" 
   aggregate( 
     zip_with(data, weights, (x,y) -> (x.val1*y as val1, x.val2*y as val2)),
     (0D as val1, 0D as val2), 
     (acc,x) -> (acc.val1 + x.val1, acc.val2 + x.val2),
     acc -> (acc.val1/sum_weights as weighted_val1, acc.val2/sum_weights as weighted_val2)
   )     
   """)).select(*df.columns, "vals.*")

这非常有帮助,而且完美运作,非常感谢!一个快速的问题 - 如果我想要计算多列并使用结构体方法,如何调整idx和x.val*y?因为它将是x.val1或x.val2(例如),所以它几乎就像一个for each语句。 - WIT
你的方法在计算多个列时似乎要好得多。我原本计划做的是列出这些列,然后执行for col in col_list: df = df.withColumn(...),但我认为这种方法效率更低。 - WIT
1
太棒了 - 这个可行。感谢您在这里寻找清洁解决方案时的帮助! - WIT
@jxc 我现在意识到的一件事是,如果存在空值,sum_weights 仍然会将其包括在分母中。有没有办法在聚合函数中执行 sum_weights,以便它对每个 val 都是唯一的? - WIT
嗨,@WIT,看起来你可以尝试在计算中使用SQL的IFNULLIF语句。 (1) 将sum_weights更改为aggregate(weights, 0D, (acc,x) -> acc+ifnull(x,0)),然后 (2) 在计算weighted_val时,将聚合函数的第一个参数更改为zip_with(data,weights, (x,y) -> if(x.val is null, 0, x.val*y)) - jxc
显示剩余2条评论

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接