如何在Spark中动态切分数组列?

7

Spark 2.4引入了新的SQL函数slice,可以用于从数组列中提取一定范围的元素。

然而,简单地将该列传递给slice函数会失败,该函数似乎期望开始和结束值为整数。有没有一种方法可以在不编写UDF的情况下完成这个任务?

通过以下示例可视化问题:我有一个包含数组列arr的数据框,每行数组看起来像['a', 'b', 'c']。还有一个end_idx列,其中的元素分别是312

+---------+-------+
|arr      |end_idx|
+---------+-------+
|[a, b, c]|3      |
|[a, b, c]|1      |
|[a, b, c]|2      |
+---------+-------+

我尝试创建一个名为arr_trimmed的新列,代码如下:
import pyspark.sql.functions as F

l = [(['a', 'b', 'c'], 3), (['a', 'b', 'c'], 1), (['a', 'b', 'c'], 2)]
df = spark.createDataFrame(l, ["arr", "end_idx"])

df = df.withColumn("arr_trimmed", F.slice(F.col("arr"), 1, F.col("end_idx")))

我期望这段代码能够创建一个新列,其元素分别为['a', 'b', 'c']['a']['a', 'b']

但是实际上我收到了一个错误TypeError: Column is not iterable


可能是 将列值用作spark DataFrame函数的参数 的重复问题。 - pault
2个回答

16

你可以通过传递以下SQL表达式来完成:

df.withColumn("arr_trimmed", F.expr("slice(arr, 1, end_idx)"))

这是完整的工作示例:

import pyspark.sql.functions as F

l = [(['a', 'b', 'c'], 3), (['a', 'b', 'c'], 1), (['a', 'b', 'c'], 2)]

df = spark.createDataFrame(l, ["arr", "end_idx"])

df.withColumn("arr_trimmed", F.expr("slice(arr, 1, end_idx)")).show(truncate=False)

+---------+-------+-----------+
|arr      |end_idx|arr_trimmed|
+---------+-------+-----------+
|[a, b, c]|3      |[a, b, c]  |
|[a, b, c]|1      |[a]        |
|[a, b, c]|2      |[a, b]     |
+---------+-------+-----------+

2

自Spark 2.4.0版本开始,slice函数接受列作为参数。因此,可以按以下方式使用:

df.withColumn("arr_trimmed", F.slice(arr, F.lit(1), end_idx))

David Vrba的示例可以这样重写:

import pyspark.sql.functions as F

l = [(['a', 'b', 'c'], 3), (['a', 'b', 'c'], 1), (['a', 'b', 'c'], 2)]

df = spark.createDataFrame(l, ["arr", "end_idx"])

df.withColumn("arr_trimmed", F.slice("arr", F.lit(1), F.col("end_idx"))).show(truncate=False)


+---------+-------+-----------+
|arr      |end_idx|arr_trimmed|
+---------+-------+-----------+
|[a, b, c]|3      |[a, b, c]  |
|[a, b, c]|1      |[a]        |
|[a, b, c]|2      |[a, b]     |
+---------+-------+-----------+

这个答案是正确的,应该被接受为最佳答案,并且需要以下澄清 - slice 函数接受列作为参数,只要 startlength 都作为列表达式给出。例如,如果 start 作为整数而没有使用 lit(),就像原始问题中一样,我会得到 py4j.Py4JException: Method slice([class org.apache.spark.sql.Column, class java.lang.Integer, class org.apache.spark.sql.Column]) does not exist 的错误(即使在 Spark 3.2.1 中也是如此)。 - Vic

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接