从PySpark DataFrame中的Python列表中删除一个元素。

7

我正在尝试从一个Python列表的列表中删除一个元素:

+---------------+
|        sources|
+---------------+
|           [62]|
|        [7, 32]|
|           [62]|
|   [18, 36, 62]|
|[7, 31, 36, 62]|
|    [7, 32, 62]|

我希望能够从以上列表中的每个列表中删除一个元素rm。我编写了一个可以对列表中的列表执行此操作的函数:

def asdf(df, rm):
    temp = df
    for n in range(len(df)):
        temp[n] = [x for x in df[n] if x != rm]
    return(temp)

这将去除rm = 1:

a = [[1,2,3],[1,2,3,4],[1,2,3,4,5]]
In:  asdf(a,1)
Out: [[2, 3], [2, 3, 4], [2, 3, 4, 5]]

但是我无法将其用于DataFrame:

asdfUDF = udf(asdf, ArrayType(IntegerType()))

In: df.withColumn("src_ex", asdfUDF("sources", 32))

Out: Py4JError: An error occurred while calling z:org.apache.spark.sql.functions.col. Trace:
py4j.Py4JException: Method col([class java.lang.Integer]) does not exist

期望的行为:

In: df.withColumn("src_ex", asdfUDF("sources", 32))
Out: 

+---------------+
|         src_ex|
+---------------+
|           [62]|
|            [7]|
|           [62]|
|   [18, 36, 62]|
|[7, 31, 36, 62]|
|        [7, 62]|

除了将新列添加到 PySpark DataFrame(df)之外,您有任何建议或想法吗?
1个回答

19

Spark >= 2.4

您可以使用array_remove

from pyspark.sql.functions import array_remove

df.withColumn("src_ex", array_remove("sources", 32)).show()
+---------------+---------------+
|        sources|         src_ex|
+---------------+---------------+
|           [62]|           [62]|
|        [7, 32]|            [7]|
|           [62]|           [62]|
|   [18, 36, 62]|   [18, 36, 62]|
|[7, 31, 36, 62]|[7, 31, 36, 62]|
|    [7, 32, 62]|        [7, 62]|
+---------------+---------------+

或者filter

from pyspark.sql.functions import expr

df.withColumn("src_ex", expr("filter(sources, x -> not(x <=> 32))")).show()
+---------------+---------------+
|        sources|         src_ex|
+---------------+---------------+
|           [62]|           [62]|
|        [7, 32]|            [7]|
|           [62]|           [62]|
|   [18, 36, 62]|   [18, 36, 62]|
|[7, 31, 36, 62]|[7, 31, 36, 62]|
|    [7, 32, 62]|        [7, 62]|
+---------------+---------------+

Spark < 2.4

有几件事情:

  • DataFrame 不是 列表的列表。实际上它甚至不是一个普通的 Python 对象,它没有 len 方法也不是 可迭代对象
  • 你的列看起来像一个普通的 数组 类型。
  • 你不能在 UDF 内部引用 DataFrame(或任何其他分布式数据结构)。
  • 直接传递到 UDF 调用的每个参数都必须是一个 str (列名)或 Column 对象。要传递字面值,请使用 lit 函数。

现在唯一剩下的就是使用列表推导式:

from pyspark.sql.functions import lit, udf

def drop_from_array_(arr, item):
    return [x for x in arr if x != item]

drop_from_array = udf(drop_from_array_, ArrayType(IntegerType()))

使用示例:

df = sc.parallelize([
    [62], [7, 32], [62], [18, 36, 62], [7, 31, 36, 62], [7, 32, 62]
]).map(lambda x: (x, )).toDF(["sources"])

df.withColumn("src_ex", drop_from_array("sources", lit(32)))
结果为:

The result:

+---------------+---------------+
|        sources|         src_ex|
+---------------+---------------+
|           [62]|           [62]|
|        [7, 32]|            [7]|
|           [62]|           [62]|
|   [18, 36, 62]|   [18, 36, 62]|
|[7, 31, 36, 62]|[7, 31, 36, 62]|
|    [7, 32, 62]|        [7, 62]|
+---------------+---------------+

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接