根据另一列的元素,从Pyspark数组中删除元素。

4

我希望验证一个数组是否包含一个字符串在Pyspark中(Spark < 2.4)。

示例数据框:

column_1 <Array>           |    column_2 <String>
--------------------------------------------
["2345","98756","8794"]    |       8794
--------------------------------------------
["8756","45678","987563"]  |       1234
--------------------------------------------
["3475","8956","45678"]    |       3475
--------------------------------------------

我希望比较两列column_1和column_2。如果column_1中包含column_2,则应跳过column_1的值。我编写了一个UDF以从column_1中减去column_2,但它没有起作用。

def contains(x, y):
        try:
            sx, sy = set(x), set(y)
            if len(sx) == 0:
                return sx
            elif len(sy) == 0:
                return sx
            else:
                return sx - sy            
        # in exception, for example `x` or `y` is None (not a list)
        except:
            return sx
    udf_contains = udf(contains, 'string')
    new_df = my_df.withColumn('column_1', udf_contains(my_df.column_1, my_df.column_2))  

期望结果:

column_1 <Array>           |    column_2 <String>
--------------------------------------------------
["2345","98756"]           |       8794
--------------------------------------------------
["8756","45678","987563"]  |       1234
--------------------------------------------------
["8956","45678"]           |       3475
--------------------------------------------------

如何处理 column_1 为空数组且 column_2 为 null 的情况?谢谢。


1
检查 udf_contains = udf(lambda x,y: [e for e in x if e != y], 'array<string>') - jxc
3
如果x可以为null或非列表,那么这段代码的作用是创建一个UDF(用户定义函数),其功能是:如果x是一个列表,则返回除y以外的所有元素组成的新列表;如果x不是列表,则直接返回x。这个UDF的数据类型是string类型的数组。 - jxc
@jxc 我需要你的帮助 :) https://stackoverflow.com/questions/58875531/concatenate-array-pyspark/58875920#58875920 - verojoucla
1个回答

5

Spark 2.4.0+

尝试使用array_remove函数,它自Spark 2.4.0版本开始提供:

val df = Seq(
    (Seq("2345","98756","8794"), "8794"), 
    (Seq("8756","45678","987563"), "1234"), 
    (Seq("3475","8956","45678"), "3475"),
    (Seq(), "empty"),
    (null, "null")
).toDF("column_1", "column_2")
df.show(5, false)

df
    .select(
        $"column_1",
        $"column_2",
        array_remove($"column_1", $"column_2") as "diff"
    ).show(5, false)

它将返回:

它将返回:

+---------------------+--------+
|column_1             |column_2|
+---------------------+--------+
|[2345, 98756, 8794]  |8794    |
|[8756, 45678, 987563]|1234    |
|[3475, 8956, 45678]  |3475    |
|[]                   |empty   |
|null                 |null    |
+---------------------+--------+

+---------------------+--------+---------------------+
|column_1             |column_2|diff                 |
+---------------------+--------+---------------------+
|[2345, 98756, 8794]  |8794    |[2345, 98756]        |
|[8756, 45678, 987563]|1234    |[8756, 45678, 987563]|
|[3475, 8956, 45678]  |3475    |[8956, 45678]        |
|[]                   |empty   |[]                   |
|null                 |null    |null                 |
+---------------------+--------+---------------------+

很抱歉,对于Scala来说,我认为使用Pyspark做同样的事情应该很容易。 Spark < 2.4.0
%pyspark

from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, StringType


data = [
    (["2345","98756","8794"], "8794"), 
    (["8756","45678","987563"], "1234"), 
    (["3475","8956","45678"], "3475"),
    ([], "empty"),
    (None,"null")    
    ]
df = spark.createDataFrame(data, ['column_1', 'column_2'])
df.printSchema()
df.show(5, False)

def contains(x, y):
    if x is None or y is None:
        return x
    else:
        sx, sy = set(x), set([y])
        return list(sx - sy)
udf_contains = udf(contains, ArrayType(StringType()))

df.select("column_1", "column_2", udf_contains("column_1", "column_2")).show(5, False)

result:

root
 |-- column_1: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- column_2: string (nullable = true)
+---------------------+--------+
|column_1             |column_2|
+---------------------+--------+
|[2345, 98756, 8794]  |8794    |
|[8756, 45678, 987563]|1234    |
|[3475, 8956, 45678]  |3475    |
|[]                   |empty   |
|null                 |null    |
+---------------------+--------+
+---------------------+--------+----------------------------+
|column_1             |column_2|contains(column_1, column_2)|
+---------------------+--------+----------------------------+
|[2345, 98756, 8794]  |8794    |[2345, 98756]               |
|[8756, 45678, 987563]|1234    |[8756, 987563, 45678]       |
|[3475, 8956, 45678]  |3475    |[8956, 45678]               |
|[]                   |empty   |[]                          |
|null                 |null    |null                        |
+---------------------+--------+----------------------------+

谢谢你的帮助,我是这样做的:df.select(array_remove(df.data, 1)).collect(),但是我遇到了“TypeError: 'Column' object is not callable”的问题,可能是因为我使用的是Spark < 2.4版本。我已经在我的问题中提到了这一点。 - verojoucla
1
@verojoucla,我使用pyspark添加了Spark < 2.4版本。你的代码片段不起作用,因为将字符串设置为集合会返回只包含单个字符的集合,例如 set("abc") > set(['a', 'c', 'b']) - shuvalov

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接