将 Spark DataFrame 列转换为 Python 列表

174

我在处理一个包含两列mvv和count的数据框。

+---+-----+
|mvv|count|
+---+-----+
| 1 |  5  |
| 2 |  9  |
| 3 |  3  |
| 4 |  1  |

我希望获得两个列表,分别包含MVV值和计数值。就像这样:

mvv = [1,2,3,4]
count = [5,9,3,1]
所以,我尝试了以下代码:第一行应返回一个Python列表中的行。我想看到第一个值:
mvv_list = mvv_count_df.select('mvv').collect()
firstvalue = mvv_list[0].getInt(0)

但是在第二行代码中我会收到一个错误消息:

AttributeError: getInt


2
从Spark 2.3开始,此代码是最快且最不可能导致OutOfMemory异常的:list(df.select('mvv').toPandas()['mvv'])Arrow已集成到PySpark,这显著加速了toPandas。如果您使用的是Spark 2.3+,请勿使用其他方法。有关更多基准测试细节,请参见我的答案。 - Powers
11个回答

222

注意,你目前的做法是不起作用的。首先,你试图从Row类型中获取整数,你的collect输出看起来像这样:

>>> mvv_list = mvv_count_df.select('mvv').collect()
>>> mvv_list[0]
Out: Row(mvv=1)

如果您有这样一个东西:

>>> firstvalue = mvv_list[0].mvv
Out: 1

您将获得mvv值。如果您想要数组的所有信息,可以采用以下方式:

>>> mvv_array = [int(row.mvv) for row in mvv_list.collect()]
>>> mvv_array
Out: [1,2,3,4]

但是如果你尝试用同样的方法来处理另一列,你会得到:

>>> mvv_count = [int(row.count) for row in mvv_list.collect()]
Out: TypeError: int() argument must be a string or a number, not 'builtin_function_or_method'

这是因为count是一个内置方法,而列名与count相同。解决此问题的方法是将count列的列名更改为_count

>>> mvv_list = mvv_list.selectExpr("mvv as mvv", "count as _count")
>>> mvv_count = [int(row._count) for row in mvv_list.collect()]

但是这种解决方法不再需要,因为您可以使用字典语法访问该列:

>>> mvv_array = [int(row['mvv']) for row in mvv_list.collect()]
>>> mvv_count = [int(row['count']) for row in mvv_list.collect()]

最终它会正常工作!


它对于第一列非常有效,但我认为由于Spark功能的计数,它不适用于列计数。 - a.moussa
你能在评论中添加一下你对计数正在做什么吗? - Thiago Baldim
不需要添加 select('count'),可以像这样使用:count_list = [int(i.count) for i in mvv_list.collect()] 我会在回复中添加示例。 - Thiago Baldim
#Thiago Baldim 我找到了一个解决方案,但它不够优雅。 count_list = [i[1] for i in mvv_list.collect()] 它能工作是因为我知道计数在索引=2的位置。 你知道有没有一种解决方案,可以根据列名'count'来精确指定,而不是依赖于索引? - a.moussa
1
@a.moussa [i.['count'] for i in mvv_list.collect()] 的作用是显式地使用名为 'count' 的列而不是 count 函数。 - user989762
显示剩余8条评论

183

以下一行代码即可获得您所需的列表。

mvv = mvv_count_df.select("mvv").rdd.flatMap(lambda x: x).collect()

7
就性能而言,这个解决方案比你的解决方案要快得多。mvv_list = [int(i.mvv) for i in mvv_count.select('mvv').collect()] - Chanaka Fernando
2
这个代码是否适用于楼主的问题?:mvv = mvv_count_df.select("mvv").rdd.flatMap(list).collect() - eemilk

44

这将以列表形式给出所有元素。

mvv_list = list(
    mvv_count_df.select('mvv').toPandas()['mvv']
)

5
这是适用于 Spark 2.3+ 的最快、最高效的解决方案。请查看我的回答中的基准测试结果。 - Powers

41

我进行了基准分析,list(mvv_count_df.select('mvv').toPandas()['mvv']) 是最快的方法。 我非常惊讶。

我在一个5节点i3.xlarge群集上使用Spark 2.4.5对10万/1亿行数据集运行了不同的方法(每个节点具有30.5 GB的RAM和4个内核)。 数据均匀分布在20个压缩了Snappy的Parquet文件中,每个文件只有一列。

以下是基准测试结果(运行时间以秒为单位):

+-------------------------------------------------------------+---------+-------------+
|                          Code                               | 100,000 | 100,000,000 |
+-------------------------------------------------------------+---------+-------------+
| df.select("col_name").rdd.flatMap(lambda x: x).collect()    |     0.4 | 55.3        |
| list(df.select('col_name').toPandas()['col_name'])          |     0.4 | 17.5        |
| df.select('col_name').rdd.map(lambda row : row[0]).collect()|     0.9 | 69          |
| [row[0] for row in df.select('col_name').collect()]         |     1.0 | OOM         |
| [r[0] for r in mid_df.select('col_name').toLocalIterator()] |     1.2 | *           |
+-------------------------------------------------------------+---------+-------------+

* cancelled after 800 seconds

在收集驱动节点上的数据时需要遵循的黄金法则:

  • 尝试使用其他方法解决问题。收集驱动程序节点的数据成本高,无法利用Spark集群的强大功能,应尽可能避免。
  • 在收集数据之前尽量减少行数。聚合,去重,过滤和修剪列以在收集数据时发送尽可能少的数据到驱动节点。

toPandas 在Spark 2.3中得到了显着改进。如果您使用的是早于2.3的Spark版本,则可能不是最佳方法。

有关更多详细信息/基准测试结果,请参见此处


4
这真的很惊奇,因为我本来以为 toPandas 的性能会很差,因为我们需要进行另一个数据结构的转换。Spark团队必须在优化方面做了非常出色的工作。感谢基准测试! - THIS USER NEEDS HELP
1
你能测试一下 @phgui 的答案吗? 它看起来也相当高效。 mvv_list = df.select(collect_list("mvv")).collect()[0][0] - Bohdan Pylypenko

24

在我的数据中,我得到了以下基准:

>>> data.select(col).rdd.flatMap(lambda x: x).collect()

0.52秒

>>> [row[col] for row in data.collect()]

0.271秒

>>> list(data.select(col).toPandas()[col])

0.427秒

结果相同


2
如果你使用 toLocalIterator 而不是 collect,它应该会更加节省内存 [row[col] for row in data.toLocalIterator()] - oglop

21

以下代码将帮助您

mvv_count_df.select('mvv').rdd.map(lambda row : row[0]).collect()

3
这应该是被接受的答案。原因是您在整个过程中都停留在一个Spark上下文中,然后在最后进行收集,而不是早期退出Spark上下文,这可能会导致更大的收集量,具体取决于您正在做什么。 - AntiPawn79

8
一种可能的解决方案是使用pyspark.sql.functions下的collect_list()函数。它将所有列值聚合到一个pyspark数组中,在收集时转换为python列表:
mvv_list   = df.select(collect_list("mvv")).collect()[0][0]
count_list = df.select(collect_list("count")).collect()[0][0] 

6

您可以先使用collect()函数获取DataFrame,该函数返回Row类型的列表

row_list = df.select('mvv').collect()

迭代行以将其转换为列表

sno_id_array = [ int(row.mvv) for row in row_list]

sno_id_array 
[1,2,3,4]

使用 flatMap
sno_id_array = df.select("mvv").rdd.flatMap(lambda x: x).collect()

6
如果您遇到以下错误:
AttributeError: 'list' object has no attribute 'collect'
使用以下代码可以解决问题:
mvv_list = mvv_count_df.select('mvv').collect()

mvv_array = [int(i.mvv) for i in mvv_list]

我也遇到了这个错误,这个解决方案解决了问题。但是为什么会出现这个错误呢?(许多其他人似乎没有遇到这个问题!) - Bikash Gyawali

6

让我们创建所需的数据框

df_test = spark.createDataFrame(
    [
        (1, 5),
        (2, 9),
        (3, 3),
        (4, 1),
    ],
    ['mvv', 'count']
)
df_test.show()

这提供了

+---+-----+
|mvv|count|
+---+-----+
|  1|    5|
|  2|    9|
|  3|    3|
|  4|    1|
+---+-----+

然后应用rdd.flatMap(f).collect()以获得列表

test_list = df_test.select("mvv").rdd.flatMap(list).collect()
print(type(test_list))
print(test_list)

这提供了

<type 'list'>
[1, 2, 3, 4]

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接