Spark 2.4引入了新的SQL函数slice
,可以用于从数组列中提取一定范围的元素。
然而,简单地将该列传递给slice函数会失败,该函数似乎期望开始和结束值为整数。有没有一种方法可以在不编写UDF的情况下完成这个任务?
通过以下示例可视化问题:我有一个包含数组列arr
的数据框,每行数组看起来像['a', 'b', 'c']
。还有一个end_idx
列,其中的元素分别是3
、1
和2
:
+---------+-------+
|arr |end_idx|
+---------+-------+
|[a, b, c]|3 |
|[a, b, c]|1 |
|[a, b, c]|2 |
+---------+-------+
我尝试创建一个名为
arr_trimmed
的新列,代码如下:import pyspark.sql.functions as F
l = [(['a', 'b', 'c'], 3), (['a', 'b', 'c'], 1), (['a', 'b', 'c'], 2)]
df = spark.createDataFrame(l, ["arr", "end_idx"])
df = df.withColumn("arr_trimmed", F.slice(F.col("arr"), 1, F.col("end_idx")))
我期望这段代码能够创建一个新列,其元素分别为['a', 'b', 'c']
、['a']
和['a', 'b']
。
但是实际上我收到了一个错误TypeError: Column is not iterable
。