PySpark创建DataFrame列之间的关系

3

我正在尝试实现一些逻辑,以便根据以下逻辑获取ID和链接之间的关系。

逻辑 -

  • 如果ID 1与2有链接,2与3有链接,则关系为1->2、1->3、2->1、2->3、3->1、3->2
  • 同样,如果1与4、4与7、7与5,则关系为1->4、1->5、1->7、4->1、4->5、4->7、5->1、5->4、5->7

输入数据框 -

+---+----+
| id|link|
+---+----+
|  1|   2|
|  3|   1|
|  4|   2|
|  6|   5|
|  9|   7|
|  9|  10|
+---+----+

我正在尝试实现以下输出 -

+---+----+
| Id|Link|
+---+----+
|  1|   2|
|  1|   3|
|  1|   4|
|  2|   1|
|  2|   3|
|  2|   4|
|  3|   1|
|  3|   2|
|  3|   4|
|  4|   1|
|  4|   2|
|  4|   3|
|  5|   6|
|  6|   5|
|  7|   9|
|  7|  10|
|  9|   7|
|  9|  10|
| 10|   7|
| 10|   9|
+---+----+

我已经尝试了很多方法,但它根本不起作用。我也尝试了以下代码

df = spark.createDataFrame([(1, 2), (3, 1), (4, 2), (6, 5), (9, 7), (9, 10)], ["id", "link"])
ids = df.select("Id").distinct().rdd.flatMap(lambda x: x).collect()
links = df.select("Link").distinct().rdd.flatMap(lambda x: x).collect()
combinations = [(id, link) for id in ids for link in links]
df_combinations = spark.createDataFrame(combinations, ["Id", "Link"])
result = df_combinations.join(df, ["Id", "Link"], "left_anti").union(df).dropDuplicates()
result = result.sort(asc("Id"), asc("Link"))

并且

df = spark.createDataFrame([(1, 2), (3, 1), (4, 2), (6, 5), (9, 7), (9, 10)], ["id", "link"])

combinations = df.alias("a").crossJoin(df.alias("b")) \
    .filter(F.col("a.id") != F.col("b.id"))\
    .select(col("a.id").alias("a_id"), col("b.id").alias("b_id"), col("a.link").alias("a_link"), col("b.link").alias("b_link"))

window = Window.partitionBy("a_id").orderBy("a_id", "b_link")
paths = combinations.groupBy("a_id", "b_link") \
    .agg(F.first("b_id").over(window).alias("id")) \
    .groupBy("id").agg(F.collect_list("b_link").alias("links"))

result = paths.select("id", F.explode("links").alias("link"))
result = result.union(df.selectExpr("id as id_", "link as link_"))

任何帮助都将不胜感激。

1个回答

2

这不是一种通用方法,但您可以使用graphframes包。您可能会在设置时遇到困难,但是可以使用它,结果很简单。

import os
sc.addPyFile(os.path.expanduser('graphframes-0.8.1-spark3.0-s_2.12.jar'))

from graphframes import *

e = df.select('id', 'link').toDF('src', 'dst')
v = e.select('src').toDF('id') \
  .union(e.select('dst')) \
  .distinct()

g = GraphFrame(v, e)

sc.setCheckpointDir("/tmp/graphframes")
df = g.connectedComponents()

df.join(df.withColumnRenamed('id', 'link'), ['component'], 'inner') \
  .drop('component') \
  .filter('id != link') \
  .show()

+---+----+
| id|link|
+---+----+
|  7|  10|
|  7|   9|
|  3|   2|
|  3|   4|
|  3|   1|
|  5|   6|
|  6|   5|
|  9|  10|
|  9|   7|
|  1|   2|
|  1|   4|
|  1|   3|
| 10|   9|
| 10|   7|
|  4|   2|
|  4|   1|
|  4|   3|
|  2|   4|
|  2|   1|
|  2|   3|
+---+----+

connectedComponents方法返回每个顶点的组件ID,对于每个顶点组(由边连接并且如果没有边与其他组件分离),该ID是唯一的。因此,您可以对每个组执行笛卡尔积而不考虑顶点本身。

额外回答

受到上述方法的启发,我查找并找到了networkx包。

import networkx as nx

df = df.toPandas()
G = nx.from_pandas_edgelist(df, 'id', 'link')
components = [[list(c)] for c in nx.connected_components(G)]

df2 = spark.createDataFrame(components, ['array']) \
  .withColumn('component', f.monotonically_increasing_id()) \
  .select('component', f.explode('array').alias('id'))

df2.join(df2.withColumnRenamed('id', 'link'), ['component'], 'inner') \
  .drop('component') \
  .filter('id != link') \
  .show()

+---+----+
| id|link|
+---+----+
|  1|   2|
|  1|   3|
|  1|   4|
|  2|   1|
|  2|   3|
|  2|   4|
|  3|   1|
|  3|   2|
|  3|   4|
|  4|   1|
|  4|   2|
|  4|   3|
|  5|   6|
|  6|   5|
|  9|  10|
|  9|   7|
| 10|   9|
| 10|   7|
|  7|   9|
|  7|  10|
+---+----+

你是对的。我正在努力设置这个库,已经添加了jar文件。但是在Pycharm上运行这行代码“from graphframes import *”时出现编译错误。 - Avijit
你知道如何在Pycharm上从Jar中设置Python库的过程吗?目前我正在Pycharm上运行整个项目。一旦在本地Pycharm上成功测试,我将在EMR上运行它。 - Avijit

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接