如何使用Pyspark迭代一个分组并创建一个数组列?

3
我有一个包含组和百分比的数据框。
| Group | A % | B % | Target % |
| ----- | --- | --- | -------- |
| A     | .05 | .85 | 1.0      |
| A     | .07 | .75 | 1.0      |
| A     | .08 | .95 | 1.0      |
| B     | .03 | .80 | 1.0      |
| B     | .05 | .83 | 1.0      |
| B     | .04 | .85 | 1.0      |

我希望能够按照列A%迭代列Group,并找到从列B%中获取的一组值,当与列A%中的每个值相加时,总和小于或等于列Target %

| Group | A % | B % | Target % | SumArray     |
| ----- | --- | --- | -------- | ------------ |
| A     | .05 | .85 | 1.0      | [.85,.75,.95]|
| A     | .07 | .75 | 1.0      | [.85,.75]    |
| A     | .08 | .95 | 1.0      | [.85,.75]   |
| B     | .03 | .80 | 1.0      | [.80,.83,.85]|
| B     | .05 | .83 | 1.0      | [.80,.83,.85]|
| B     | .04 | .85 | 1.0      | [.80,.83,.85]|

我希望能够使用PySpark解决这个问题。你有什么想法吗?

1个回答

2
您可以使用collect_list函数获取按Group列分组的B %列值数组,然后使用条件A + B <= Target 过滤结果数组:
from pyspark.sql import Window
import pyspark.sql.functions as F

df2 = df.withColumn(
    "SumArray",
    F.collect_list(F.col("B")).over(Window.partitionBy("Group"))
).withColumn(
    "SumArray",
    F.expr("filter(SumArray, x -> x + A <= Target)")
)
df2.show()

# +-----+----+----+------+------------------+
# |Group|   A|   B|Target|          SumArray|
# +-----+----+----+------+------------------+
# |    B|0.03| 0.8|   1.0| [0.8, 0.83, 0.85]|
# |    B|0.05|0.83|   1.0| [0.8, 0.83, 0.85]|
# |    B|0.04|0.85|   1.0| [0.8, 0.83, 0.85]|
# |    A|0.05|0.85|   1.0|[0.85, 0.75, 0.95]|
# |    A|0.07|0.75|   1.0|      [0.85, 0.75]|
# |    A|0.08|0.95|   1.0|      [0.85, 0.75]|
# +-----+----+----+------+------------------+

谢谢!这个很好用! - Alex Triece

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接