使用PySpark中的map迭代数组列

6
在PySpark中,我有一个由两列组成的数据框:
+-----------+----------------------+
| str1      | array_of_str         |
+-----------+----------------------+
| John      | [mango, apple, ...   |
| Tom       | [mango, orange, ...  |
| Matteo    | [apple, banana, ...  | 

我希望添加一个列concat_result,该列包含array_of_str中每个元素与str1列内字符串的拼接结果。
+-----------+----------------------+----------------------------------+
| str1      | array_of_str         | concat_result                    |
+-----------+----------------------+----------------------------------+
| John      | [mango, apple, ...   | [mangoJohn, appleJohn, ...       |
| Tom       | [mango, orange, ...  | [mangoTom, orangeTom, ...        |
| Matteo    | [apple, banana, ...  | [appleMatteo, bananaMatteo, ...  |

我正在尝试使用 map 来迭代数组:

from pyspark.sql import functions as F
from pyspark.sql.types import StringType, ArrayType

# START EXTRACT OF CODE
ret = (df
  .select(['str1', 'array_of_str'])
  .withColumn('concat_result', F.udf(
     map(lambda x: x + F.col('str1'), F.col('array_of_str')), ArrayType(StringType))
  )
)

return ret
# END EXTRACT OF CODE

但是我遇到了错误:
TypeError: argument 2 to map() must support iteration

我尝试了那个解决方案,但它并不起作用。如果你能写一个可行的解决方案,将不胜感激。 - Matteo Guarnerio
你需要定义一个带有两个参数的udf函数 - (除非你在Spark 2.4+中)。 - pault
可能是将PySpark数据框列从列表转换为字符串的重复问题。 - user10938362
1个回答

6
只需要进行小的调整即可让其正常工作:
from pyspark.sql.types import StringType, ArrayType
from pyspark.sql.functions import udf, col

concat_udf = udf(lambda con_str, arr: [x + con_str for x in arr],
                   ArrayType(StringType()))
ret = df \
  .select(['str1', 'array_of_str']) \
  .withColumn('concat_result', concat_udf(col("str1"), col("array_of_str")))

ret.show()

您不需要使用map,标准的列表推导式就足够了。

2
唯一的注意事项是,如果任何str1array_of_str值为null,则此代码将会出错。您需要在您的udf中添加明确的错误检查。 - pault

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接