检查一个数组中的所有元素是否都存在于另一个数组中

3

I have a df1 Spark dataframe

id     transactions
1      [1, 2, 3, 5]
2      [1, 2, 3, 6]
3      [1, 2, 9, 8]
4      [1, 2, 5, 6]

root
 |-- id: int (nullable = true)
 |-- transactions: array (nullable = false)
     |-- element: int(containsNull = true)
 None

I have a df2 Spark dataframe

items   cost
  [1]    1.0
  [2]    1.0
 [2, 1]  2.0
 [6, 1]  2.0

root
 |-- items: array (nullable = false)
    |-- element: int (containsNull = true)
 |-- cost: int (nullable = true)
 None

我想检查items列中的所有数组元素是否都在transactions列中。

第一行([1, 2, 3, 5])包含来自items列的[1],[2],[2, 1]。因此,我需要对它们对应的成本进行求和:1.0 + 1.0 + 2.0 = 4.0

我想要的输出结果是

"最初的回答"

id     transactions    score
1      [1, 2, 3, 5]   4.0
2      [1, 2, 3, 6]   6.0
3      [1, 2, 9, 8]   4.0
4      [1, 2, 5, 6]   6.0

我尝试使用collect()/toLocalIterator循环,但效率似乎不高。我将有大量数据。

我认为创建一个像这样的UDF可以解决它。但是它会抛出错误。

修改后:

我尝试使用collect()/toLocalIterator循环来处理数据,但效率不高,因为数据量太大。我认为创建一个UDF可以解决这个问题,但是执行时出现了错误。

from pyspark.sql.functions import udf
def containsAll(x, y):
  result = all(elem in x for elem in y)

  if result:
    print("Yes, transactions contains all items")    
  else :
    print("No")

contains_udf = udf(containsAll)
dataFrame.withColumn("result", contains_udf(df2.items, df1.transactions)).show()

有其他解决办法吗?最初的回答。

你需要将两个DataFrames进行连接,使用groupbysum函数(不要使用循环或collect函数)。你的数据框架模式是什么?请在问题中编辑并添加df.printSchema()。我假设这些列表是整数数组 - 如果是这样,请参考以下帖子了解如何连接这两个数据框架:PySpark Join on Values Within A List - pault
@priya,请问df1df2的大小关系如何? - cph_sto
@cph_sto df1 可能有 100000 行,交易中的元素数量可能在 1000 到 10000 之间。df2 可以包含与 df1 相同的行数的两倍或三倍。 - priya
你正在使用哪个版本的Spark? - Shaido
@Shaido Spark 2.3.3 - priya
2个回答

7

在2.4版本之前,有效的udf(用户自定义函数)必须返回某些内容。

from pyspark.sql.functions import udf

@udf("boolean")
def contains_all(x, y):
    if x is not None and y is not None:
        return set(y).issubset(set(x))

在2.4或更高版本中,不需要使用UDF:

from pyspark.sql.functions import array_intersect, size

def contains_all(x, y):
    return size(array_intersect(x, y)) == size(y)

使用方法:

from pyspark.sql.functions import col, sum as sum_, when

df1 = spark.createDataFrame(
   [(1, [1, 2, 3, 5]), (2, [1, 2, 3, 6]), (3, [1, 2, 9, 8]), (4, [1, 2, 5, 6])],
   ("id", "transactions")
)

df2 = spark.createDataFrame(
    [([1], 1.0), ([2], 1.0), ([2, 1], 2.0), ([6, 1], 2.0)],
    ("items", "cost")
)


(df1
    .crossJoin(df2).groupBy("id", "transactions")
    .agg(sum_(when(
        contains_all("transactions", "items"), col("cost")
    )).alias("score"))
    .show())

结果如下:
+---+------------+-----+                                                        
| id|transactions|score|
+---+------------+-----+
|  1|[1, 2, 3, 5]|  4.0|
|  4|[1, 2, 5, 6]|  6.0|
|  2|[1, 2, 3, 6]|  6.0|
|  3|[1, 2, 9, 8]|  4.0|
+---+------------+-----+

如果df2比较小,可以考虑将其作为局部变量使用:
items = sc.broadcast([
    (set(items), cost) for items, cost in df2.select("items", "cost").collect()
])

def score(y):
    @udf("double")
    def _(x):
        if x is not None:
            transactions = set(x)
            return sum(
                cost for items, cost in y.value 
                if items.issubset(transactions))
    return _


df1.withColumn("score", score(items)("transactions")).show()

+---+------------+-----+
| id|transactions|score|
+---+------------+-----+
|  1|[1, 2, 3, 5]|  4.0|
|  2|[1, 2, 3, 6]|  6.0|
|  3|[1, 2, 9, 8]|  4.0|
|  4|[1, 2, 5, 6]|  6.0|
+---+------------+-----+

终于可以拆分和连接了

from pyspark.sql.functions import explode

costs = (df1
    # Explode transactiosn
    .select("id", explode("transactions").alias("item"))
    .join(
        df2 
            # Add id so we can later use it to identify source
            .withColumn("_id", monotonically_increasing_id().alias("_id"))
             # Explode items
            .select(
                "_id", explode("items").alias("item"), 
                # We'll need size of the original items later
                size("items").alias("size"), "cost"), 
         ["item"])
     # Count matches in groups id, items
     .groupBy("_id", "id", "size", "cost")
     .count()
     # Compute cost
     .groupBy("id")
     .agg(sum_(when(col("size") == col("count"), col("cost"))).alias("score")))

costs.show()

+---+-----+                                                                      
| id|score|
+---+-----+
|  1|  4.0|
|  3|  4.0|
|  2|  6.0|
|  4|  6.0|
+---+-----+

然后将结果与原始的df1重新合并,

df1.join(costs, ["id"])

但这种方法不够直接,需要多次洗牌。它可能仍然比笛卡尔积(crossJoin)更可取,但这取决于实际数据。


非常感谢您的帮助。在使用2.4独立版时,我尝试了您的代码(包含使用array_intersect方法的contains_all函数)。但是它抛出了Py4JJavaError错误:调用o718.showString时发生错误。原因是:java.net.SocketTimeoutException:接受超时。 - priya
你使用哪个JDK和Spark版本? - priya
2.4.0,JDK8(这是目前Apache Spark支持的最新版本)。 - user10938362
我更喜欢使用Explode和Join方法。笛卡尔积和广播对我来说太昂贵了。感谢所有的解释! - priya
但是当项目和交易中存在重复实体时,代码无法运行。例如,Transactions = [1,2,1] items=[1,2,1],尽管items [1,2,1] 存在于transactions中。 - priya
在广播方法中,我们可以使用 df1 作为局部变量而不是 df2 并实现结果吗? - priya

2

Spark 3.0+有一个更多的选项,使用forall

F.expr("forall(look_for, x -> array_contains(look_in, x))")

Spark 3.1+的语法替代方案 - F.forall('look_for', lambda x: F.array_contains('look_in', x))


与Spark 2.4中的选项(array_intersect)进行比较。
F.size(F.array_intersect('look_for', 'look_in')) == F.size('look_for')

它们在处理重复值和空值方面有所不同。
from pyspark.sql import functions as F
df = spark.createDataFrame(
    [(['a', 'b', 'c'], ['a']),
     (['a', 'b', 'c'], ['d']),
     (['a', 'b', 'c'], ['a', 'b']),
     (['a', 'b', 'c'], ['c', 'd']),
     (['a', 'b', 'c'], ['a', 'b', 'c']),
     (['a', 'b', 'c'], ['a', None]),
     (['a', 'b',None], ['a', None]),
     (['a', 'b',None], ['a']),
     (['a', 'b',None], [None]),
     (['a', 'b', 'c'], None),
     (None, ['a']),
     (None, None),
     (['a', 'b', 'c'], ['a', 'a']),
     (['a', 'a', 'a'], ['a']),
     (['a', 'a', 'a'], ['a', 'a', 'a']),
     (['a', 'a', 'a'], ['a', 'a',None]),
     (['a', 'a',None], ['a', 'a', 'a']),
     (['a', 'a',None], ['a', 'a',None])],
    ['look_in', 'look_for'])
df = df.withColumn('spark_3_0', F.expr("forall(look_for, x -> array_contains(look_in, x))"))
df = df.withColumn('spark_2_4', F.size(F.array_intersect('look_for', 'look_in')) == F.size('look_for'))

enter image description here

从数组中删除空值在某些情况下可能很有用,最简单的方法是使用 Spark 3.4+ 中的 array_compact 函数。

如果我的"look in"是在一个变量中定义的单独列表,如何实现相同的结果。 - undefined

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接