Apache Spark -- 将UDF的结果分配给多个数据框列

65

我正在使用pyspark,使用spark-csv将一个大型csv文件加载到dataframe中,作为预处理步骤,我需要对包含json字符串的一个列中可用数据应用各种操作。这将返回X个值,每个值都需要存储在自己单独的列中。

该功能将在UDF中实现,但是我不确定如何从UDF返回值列表并将其提供给各个列。以下是一个简单的示例:

(...)
from pyspark.sql.functions import udf
def udf_test(n):
    return [n/2, n%2]

test_udf=udf(udf_test)


df.select('amount','trans_date').withColumn("test", test_udf("amount")).show(4)

这会产生以下结果:

+------+----------+--------------------+
|amount|trans_date|                test|
+------+----------+--------------------+
|  28.0|2016-02-07|         [14.0, 0.0]|
| 31.01|2016-02-07|[15.5050001144409...|
| 13.41|2016-02-04|[6.70499992370605...|
| 307.7|2015-02-17|[153.850006103515...|
| 22.09|2016-02-05|[11.0450000762939...|
+------+----------+--------------------+
only showing top 5 rows

现在这个示例中,UDF返回的两个值被视为字符串,如何将它们存储在不同的列中是最佳的方法?

df.select('amount','trans_date').withColumn("test", test_udf("amount")).printSchema()

root
 |-- amount: float (nullable = true)
 |-- trans_date: string (nullable = true)
 |-- test: string (nullable = true)
2个回答

103

无法从单个UDF调用中创建多个顶级列,但您可以创建一个新的struct。 这需要具有指定returnType的UDF:

from pyspark.sql.functions import udf
from pyspark.sql.types import StructType, StructField, FloatType

schema = StructType([
    StructField("foo", FloatType(), False),
    StructField("bar", FloatType(), False)
])

def udf_test(n):
    return (n / 2, n % 2) if n and n != 0.0 else (float('nan'), float('nan'))

test_udf = udf(udf_test, schema)
df = sc.parallelize([(1, 2.0), (2, 3.0)]).toDF(["x", "y"])

foobars = df.select(test_udf("y").alias("foobar"))
foobars.printSchema()
## root
##  |-- foobar: struct (nullable = true)
##  |    |-- foo: float (nullable = false)
##  |    |-- bar: float (nullable = false)

您可以通过简单的select进一步压缩模式:

foobars.select("foobar.foo", "foobar.bar").show()
## +---+---+
## |foo|bar|
## +---+---+
## |1.0|0.0|
## |1.5|1.0|
## +---+---+

另请参阅如何从Spark DataFrame的单个列中导出多个列


太棒了!这对我所需的功能非常有效。我已经完成了大部分工作,但是将StructType架构错误地输入到udf中,导致我的新列最终成为StringType。非常感谢! - Everaldo Aguiar
谢谢!这正是我在寻找的。 :) - dksahuji
7
你也可以使用foobars.select("foobar.*")来代替逐个列出每一列的名字。该方法不仅简洁,而且效果相同。 - pault
2
您还可以通过两步操作“混合”原始列和UDF列:df.select("x", test_udf("y").alias("foobar")).select("x", "foobar.*") - mjv
1
from pyspark.sql.types import StructType, StructField, FloatType - alvaro nortes
要将新列添加到现有的数据框中,可以尝试以下方法:df.withColumn('foobar', test_udf('y')).select( df.columns + [df['foobar.%s' % c.name] for c in schema] ) - undefined

2
你可以使用flatMap一次性获取所需数据框的列。
df=df.withColumn('udf_results',udf)  
df4=df.select('udf_results').rdd.flatMap(lambda x:x).toDF(schema=your_new_schema)

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接