从PySpark数组列中删除重复项

4

我有一个 PySpark 数据帧,其中包含一个 ArrayType(StringType()) 列。这个列中包含重复的字符串数组,我需要将它们移除。例如,一行条目可能看起来像 [milk, bread, milk, toast]。假设我的数据框名为 df,我的列名为 arraycol。我需要像这样的东西:

df = df.withColumn("arraycol_without_dupes", F.remove_dupes_from_array("arraycol"))

我的直觉告诉我,这个问题一定有简单的解决方法,但是在stackoverflow上浏览了15分钟后,我没有找到比爆炸列、在整个数据框上去重然后再分组更好的方法。肯定有一种更简单的方法,只是我没想到对吧?

我正在使用Spark 2.4.0版本。


为什么你不能这样做:df = df.dropDuplicates(subset = ["arraycol"]) - YOLO
@YOLO:重复项在单行数组中...我会重新表述我的问题,以使其更加精确。 - Thomas
1个回答

22

对于pyspark版本2.4+,你可以使用pyspark.sql.functions.array_distinct函数:

from pyspark.sql.functions import array_distinct
df = df.withColumn("arraycol_without_dupes", array_distinct("arraycol"))
对于旧版本,您可以使用API函数使用explode + groupBy和collect_set来执行此操作,但在这里可能使用udf更有效:
from pyspark.sql.functions import udf

remove_dupes_from_array = udf(lambda row: list(set(row)), ArrayType(StringType()))
df = df.withColumn("arraycol_without_dupes", remove_dupes_from_array("arraycol"))

谢谢,这正是我在寻找的!我会看看是否可以将我的集群升级到2.4.0。 - Thomas
我已经成功升级了我的集群,并且可以确认这正好满足了我的需求。 - Thomas

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接