Pyspark:如何在数据框中将一行复制n次?

16

我有一个这样的数据框,如果列n大于1,我想要将该行复制n次:

A   B   n  
1   2   1  
2   9   1  
3   8   2    
4   1   1    
5   3   3 

然后变成这样:

A   B   n  
1   2   1  
2   9   1  
3   8   2
3   8   2       
4   1   1    
5   3   3 
5   3   3 
5   3   3 

我认为我应该使用explode,但我不明白它是如何工作的...
谢谢


1
@Learningstatsbyexample:这是用Python编写的。 - Chjul
3个回答

21

使用Spark 2.4.0+的内置函数,这变得更加容易:array_repeat + explode:

from pyspark.sql.functions import expr

df = spark.createDataFrame([(1,2,1), (2,9,1), (3,8,2), (4,1,1), (5,3,3)], ["A", "B", "n"])

new_df = df.withColumn('n', expr('explode(array_repeat(n,int(n)))'))

>>> new_df.show()
+---+---+---+
|  A|  B|  n|
+---+---+---+
|  1|  2|  1|
|  2|  9|  1|
|  3|  8|  2|
|  3|  8|  2|
|  4|  1|  1|
|  5|  3|  3|
|  5|  3|  3|
|  5|  3|  3|
+---+---+---+

在使用PySpark中的array_repeat API函数时,如何引用第二个参数(count)的n?当我尝试使用F.col()时,会收到“Column is not iterable”的错误提示。 - David Foster
1
@DavidFoster,使用pyspark API函数无法完成此操作,请查看此常见 问题。使用SQL表达式,这不是问题,而且我认为代码通常更简洁易于维护。 - jxc
感谢确认这不受支持。在底层 SQL 中很简单,但 Python API 不支持似乎有点奇怪。关于简洁性的观点是公正的,然而在这种情况下,Python 解决方案看起来与 SQL 相同,我的 IDE 会将其视为字符串而不是函数/方法等。 - David Foster
@jxc,array_repeat在长类型列上不起作用。 - Dariusz Krynicki

10

explode函数用于将数组或映射中的每个元素提取出来,生成新的行。

利用该函数的一种方法是使用udf为每一行创建一个大小为n的列表,然后将结果数组进行展开。

from pyspark.sql.functions import udf, explode
from pyspark.sql.types import ArrayType, IntegerType
    
df = spark.createDataFrame([(1,2,1), (2,9,1), (3,8,2), (4,1,1), (5,3,3)] ,["A", "B", "n"]) 

+---+---+---+
|  A|  B|  n|
+---+---+---+
|  1|  2|  1|
|  2|  9|  1|
|  3|  8|  2|
|  4|  1|  1|
|  5|  3|  3|
+---+---+---+

# use udf function to transform the n value to n times
n_to_array = udf(lambda n : [n] * n, ArrayType(IntegerType()))
df2 = df.withColumn('n', n_to_array(df.n))

+---+---+---------+
|  A|  B|        n|
+---+---+---------+
|  1|  2|      [1]|
|  2|  9|      [1]|
|  3|  8|   [2, 2]|
|  4|  1|      [1]|
|  5|  3|[3, 3, 3]|
+---+---+---------+ 

# now use explode  
df2.withColumn('n', explode(df2.n)).show()

+---+---+---+ 
| A | B | n | 
+---+---+---+ 
|  1|  2|  1| 
|  2|  9|  1| 
|  3|  8|  2| 
|  3|  8|  2| 
|  4|  1|  1| 
|  5|  3|  3| 
|  5|  3|  3| 
|  5|  3|  3| 
+---+---+---+ 

1
我更喜欢@jxc的答案,适用于Spark 2.4+,因为它使用所有内置函数而不是UDF。 - DVL
@DVL 不幸的是,我们中的一些人仍然被困在之前的Spark版本中... - Mehdi LAMRANI
@DVL array_repeat 在 long 类型列上不起作用。 - Dariusz Krynicki

3

我认为@Ahmed提供的udf答案是最好的方法,但这里有一种替代方法,对于小的n可能同样好或更好:

首先,收集整个DataFrame中n的最大值:

max_n = df.select(f.max('n').alias('max_n')).first()['max_n']
print(max_n)
#3

现在为每一行创建一个长度为max_n的数组,其中包含range(max_n)中的数字。这个中间步骤的输出将会得到一个类似如下的DataFrame:
df.withColumn('n_array', f.array([f.lit(i) for i in range(max_n)])).show()
#+---+---+---+---------+
#|  A|  B|  n|  n_array|
#+---+---+---+---------+
#|  1|  2|  1|[0, 1, 2]|
#|  2|  9|  1|[0, 1, 2]|
#|  3|  8|  2|[0, 1, 2]|
#|  4|  1|  1|[0, 1, 2]|
#|  5|  3|  3|[0, 1, 2]|
#+---+---+---+---------+

现在我们要将 n_array 列拆分,然后过滤只保留数组中小于 n 的值。这样可以确保每行有 n 个副本。最后我们删除拆分的列,得到最终结果:

df.withColumn('n_array', f.array([f.lit(i) for i in range(max_n)]))\
    .select('A', 'B', 'n', f.explode('n_array').alias('col'))\
    .where(f.col('col') < f.col('n'))\
    .drop('col')\
    .show()
#+---+---+---+
#|  A|  B|  n|
#+---+---+---+
#|  1|  2|  1|
#|  2|  9|  1|
#|  3|  8|  2|
#|  3|  8|  2|
#|  4|  1|  1|
#|  5|  3|  3|
#|  5|  3|  3|
#|  5|  3|  3|
#+---+---+---+

然而,对于每一行我们正在创建一个长度为max_n的数组- 而不是udf解决方案中的长度为n的数组。我不确定对于大型max_n,这种方法如何才能够具有可伸缩性,但我认为udf将获胜。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接