如何获取 PySpark DataFrame 的引用列?

6

给定一个PySpark DataFrame,是否可能获得由该DataFrame引用的源列的列表?

也许更具体的例子可以帮助解释我的需求。假设我定义了一个DataFrame:

import pyspark.sql.functions as func
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
source_df = spark.createDataFrame(
    [("pru", 23, "finance"), ("paul", 26, "HR"), ("noel", 20, "HR")],
    ["name", "age", "department"],
)
source_df.createOrReplaceTempView("people")
sqlDF = spark.sql("SELECT name, age, department FROM people")
df = sqlDF.groupBy("department").agg(func.max("age").alias("max_age"))
df.show()

它返回:

+----------+--------+                                                           
|department|max_age |
+----------+--------+
|   finance|      23|
|        HR|      26|
+----------+--------+

df引用的列是[department,age]。是否可以以编程方式获取引用列的列表?

感谢在pyspark中捕获explain()的结果,我知道可以将计划提取为字符串:

df._sc._jvm.PythonSQLUtils.explainString(df._jdf.queryExecution(), "formatted")

该函数返回:

== Physical Plan ==
AdaptiveSparkPlan (6)
+- HashAggregate (5)
   +- Exchange (4)
      +- HashAggregate (3)
         +- Project (2)
            +- Scan ExistingRDD (1)


(1) Scan ExistingRDD
Output [3]: [name#0, age#1L, department#2]
Arguments: [name#0, age#1L, department#2], MapPartitionsRDD[4] at applySchemaToPythonRDD at NativeMethodAccessorImpl.java:0, ExistingRDD, UnknownPartitioning(0)

(2) Project
Output [2]: [age#1L, department#2]
Input [3]: [name#0, age#1L, department#2]

(3) HashAggregate
Input [2]: [age#1L, department#2]
Keys [1]: [department#2]
Functions [1]: [partial_max(age#1L)]
Aggregate Attributes [1]: [max#22L]
Results [2]: [department#2, max#23L]

(4) Exchange
Input [2]: [department#2, max#23L]
Arguments: hashpartitioning(department#2, 200), ENSURE_REQUIREMENTS, [plan_id=60]

(5) HashAggregate
Input [2]: [department#2, max#23L]
Keys [1]: [department#2]
Functions [1]: [max(age#1L)]
Aggregate Attributes [1]: [max(age#1L)#12L]
Results [2]: [department#2, max(age#1L)#12L AS max_age#13L]

(6) AdaptiveSparkPlan
Output [2]: [department#2, max_age#13L]
Arguments: isFinalPlan=false

这很有用,但不是我需要的。我需要一个引用列的列表。这可行吗?

或许另一种提问方式是...有没有一种方法可以将执行计划作为对象获取,以便我可以遍历/探索它?


更新。感谢@matt-andruff的回复,我已经得到了这个:

df._jdf.queryExecution().executedPlan().treeString().split("+-")[-2]

返回:

' Project [age#1L, department#2]\n            '

我猜我可以从中解析出所需的信息,但这远非优雅的方法,而且容易出错。

实际上,我真正需要的是一种可靠、安全的API支持方式来获取这些信息。我开始觉得这似乎不可能。


也许只需要 df.columns - Steven
这将会给我最终数据框中的列,但这不是我想要的。 - jamiet
4个回答

4

有一个对象可以完成这个任务,但它是Java对象,而且没有被翻译成PySpark。

您仍然可以使用Spark构造函数来访问它:

>>> df._jdf.queryExecution().executedPlan().apply(0).output().apply(0).toString()
u'department#1621'
>>> df._jdf.queryExecution().executedPlan().apply(0).output().apply(1).toString()
u'max_age#1632L'

你可以通过循环上述的apply来获取你想要的信息,类似于这样:
plan = df._jdf.queryExecution().executedPlan()
steps = [ plan.apply(i) for i in range(1,100) if not isinstance(plan.apply(i), type(None)) ]
iterator = steps[0].inputSet().iterator()
>>> iterator.next().toString()
u'department#1621'
>>> iterator.next().toString()
u'max#1642L'

steps = [ plan.apply(i) for i in range(1,100) if not isinstance(plan.apply(i), type(None)) ]

projections = [ (steps[0].p(i).toJSON().encode('ascii','ignore')) for i in range(1,100) if not( isinstance(steps[0].p(i), type(None) )) and steps[0].p(i).nodeName().encode('ascii','ignore') == 'Project' ]
dd = spark.sparkContext.parallelize(projections)
df2 = spark.read.json(rdd)
>>> df2.show(1,False)
+-----+------------------------------------------+----+------------+------+--------------+------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----+
|child|class                                     |name|num-children|output|outputOrdering|outputPartitioning|projectList                                                                                                                                                                                                                                                                                                                                                                                              |rdd |
+-----+------------------------------------------+----+------------+------+--------------+------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----+
|0    |org.apache.spark.sql.execution.ProjectExec|null|1           |null  |null          |null              |[[[org.apache.spark.sql.catalyst.expressions.AttributeReference, long, [1620, 4ad48da6-03cf-45d4-9b35-76ac246fadac, org.apache.spark.sql.catalyst.expressions.ExprId], age, true, 0, [people]]], [[org.apache.spark.sql.catalyst.expressions.AttributeReference, string, [1621, 4ad48da6-03cf-45d4-9b35-76ac246fadac, org.apache.spark.sql.catalyst.expressions.ExprId], department, true, 0, [people]]]]|null|
+-----+------------------------------------------+----+------------+------+--------------+------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----+
df2.select(func.explode(func.col('projectList'))).select( func.col('col')[0]["name"] ) .show(100,False)
+-----------+
|col[0].name|
+-----------+
|age        |
|department |
+-----------+

range --> 有点取巧,但显然size不起作用。如果有更多时间,我肯定可以改进范围的技巧。

然后可以使用json以编程方式提取信息。


谢谢,这很有帮助,但它仍然没有给我所需的答案。你的代码返回了 departmentmax_age。那不是我想要的,我想要的是 departmentage。希望我能够探索 _jdf 对象以找到我需要的内容。 - jamiet
1
如果您想以编程方式访问它,您需要编写一个Java类,然后在pyspark中调用Java类。这将使您能够正确地集成系统,直到他们更新pyspark代码以允许访问该计划。 - Matt Andruff
1
我对你的字符串回答不满意,所以我深入研究了一下,现在你可以通过编程访问它。 - Matt Andruff
1
我已经修复了我的代码中的拼写错误。它应该是 [ plan.apply(i) for i in range(1,100) if not isinstance(plan.apply(i), type(None))] - Matt Andruff
1
让我们在聊天中继续这个讨论 - Matt Andruff
显示剩余3条评论

0

PySpark并不是为这种低级技巧而设计的(更适合Scala,因为Spark是用Scala开发的,因此提供了所有内部功能)。

访问QueryExecution的这一步是进入Spark SQL查询执行引擎机制的主要入口点。

问题在于py4j(用作JVM和Python环境之间的桥梁)在PySpark方面没有用处。

如果您需要访问最终查询计划(就在将其转换为RDD之前),可以使用以下内容:

df._jdf.queryExecution().executedPlan().prettyJson()

查看 QueryExecution API。

QueryExecutionListener

你应该考虑使用Scala来拦截关于查询的任何想法,QueryExecutionListener 似乎是一个相当可行的起点。

还有更多,但都在Scala中:)

我真正想要的是一种可靠、可靠的、由API支持的方法来获取这些信息。我开始觉得这是不可能的。

我并不惊讶,因为你正在放弃最好的答案:Scala。我建议你用它来进行概念验证,看看你能得到什么,只有在必要时(如果你必须)寻找Python解决方案(我认为这是可行的,但非常容易出错)。


1
感谢Jacek,我得出了类似的结论。我在这里发布了一个回答,委婉地描述为“不完美”,原因与您表达的相似。对我来说,我正在编写一个基于pyspark的库,因此如果有Python解决方案,我更喜欢使用它。 - jamiet

0

我有一些东西,虽然不是对我的原始问题的答案(请参见Matt Andruff的答案),但仍然可能在这里有用。它是一种获取由pyspark.sql.column.Column引用的所有源列的方法。

简单的复制:

from pyspark.sql import functions as f, SparkSession
SparkSession.builder.getOrCreate()
col = f.concat(f.col("A"), f.col("B"))
type(col)
col._jc.expr().references().toList().toString()

返回:

<class 'pyspark.sql.column.Column'>
"List('A, 'B)"

这肯定不是完美的,它仍然需要您从返回的字符串中解析出列名,但至少我需要的信息是可用的。也许从 references() 返回的对象上还有一些更简单的方法来解析返回的字符串,但如果有的话,我还没有找到!

这是我编写的一个函数来执行解析操作。

def parse_references(references: str):
    return sorted(
        "".join(
            references.replace("'", "")
            .replace("List(", "")
            .replace(")", "")
            .replace(")", "")
            .split()
        ).split(",")
    )

assert parse_references("List('A, 'B)") == ["A", "B"]

这里是我使用它的提交记录:https://github.com/jamiekt/jstark/commit/dd0991c1241cf6934ff4b5c9028465442f1670d0 - jamiet

-1
你可以尝试以下代码,这将为你提供数据框中的列列表和其数据类型。
for field in df.schema.fields:
    print(field.name +" , "+str(field.dataType))

这将给我最终数据框中的字段,这不是我想要的。 - jamiet

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接