在pyspark数据框的列中统计一个子字符串列表出现的次数

6
我想要计算一组子字符串的出现次数,并基于pyspark数据框中包含的长字符串列创建一列。
Input:          
       ID    History

       1     USA|UK|IND|DEN|MAL|SWE|AUS
       2     USA|UK|PAK|NOR
       3     NOR|NZE
       4     IND|PAK|NOR

 lst=['USA','IND','DEN']


Output :
       ID    History                      Count

       1     USA|UK|IND|DEN|MAL|SWE|AUS    3
       2     USA|UK|PAK|NOR                1
       3     NOR|NZE                       0
       4     IND|PAK|NOR                   1
2个回答

5
# Importing requisite packages and creating a DataFrame
from pyspark.sql.functions import split, col, size, regexp_replace
values = [(1,'USA|UK|IND|DEN|MAL|SWE|AUS'),(2,'USA|UK|PAK|NOR'),(3,'NOR|NZE'),(4,'IND|PAK|NOR')]
df = sqlContext.createDataFrame(values,['ID','History'])
df.show(truncate=False)
+---+--------------------------+
|ID |History                   |
+---+--------------------------+
|1  |USA|UK|IND|DEN|MAL|SWE|AUS|
|2  |USA|UK|PAK|NOR            |
|3  |NOR|NZE                   |
|4  |IND|PAK|NOR               |
+---+--------------------------+

这个想法是基于这三个分隔符lst=['USA','IND','DEN']来拆分字符串,然后统计产生的子字符串数量。

例如; 字符串USA|UK|IND|DEN|MAL|SWE|AUS被拆分为 - ,, |UK|, |, |MAL|SWE|AUS。由于产生了4个子字符串并且有3个分隔符匹配,所以4-1 = 3表示此类字符串在列字符串中出现的次数。

我不确定Spark是否支持多字符分隔符,因此作为第一步,我们用标志/虚拟值%替换列表['USA','IND','DEN']中的任何这三个子字符串。您也可以使用其他内容。以下代码执行此replacement-

df = df.withColumn('History_X',col('History'))
lst=['USA','IND','DEN']
for i in lst:
    df = df.withColumn('History_X', regexp_replace(col('History_X'), i, '%'))
df.show(truncate=False)
+---+--------------------------+--------------------+
|ID |History                   |History_X           |
+---+--------------------------+--------------------+
|1  |USA|UK|IND|DEN|MAL|SWE|AUS|%|UK|%|%|MAL|SWE|AUS|
|2  |USA|UK|PAK|NOR            |%|UK|PAK|NOR        |
|3  |NOR|NZE                   |NOR|NZE             |
|4  |IND|PAK|NOR               |%|PAK|NOR           |
+---+--------------------------+--------------------+

最后,我们首先使用以 % 为分隔符的方式对其进行分割字符串,计算所创建的子串数,然后使用size函数计算所创建的子串数目,并最终将其减去1。

df = df.withColumn('Count', size(split(col('History_X'), "%")) - 1).drop('History_X')
df.show(truncate=False)
+---+--------------------------+-----+
|ID |History                   |Count|
+---+--------------------------+-----+
|1  |USA|UK|IND|DEN|MAL|SWE|AUS|3    |
|2  |USA|UK|PAK|NOR            |1    |
|3  |NOR|NZE                   |0    |
|4  |IND|PAK|NOR               |1    |
+---+--------------------------+-----+

我遇到了一个错误:类“org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator”超过了64 KB。 - Faliha Zikra
你使用的是哪个Spark版本? - cph_sto

5
如果您使用的是Spark 2.4+,您可以尝试使用SPARK SQL高阶函数filter()
from pyspark.sql import functions as F

>>> df.show(5,0)
+---+--------------------------+
|ID |History                   |
+---+--------------------------+
|1  |USA|UK|IND|DEN|MAL|SWE|AUS|
|2  |USA|UK|PAK|NOR            |
|3  |NOR|NZE                   |
|4  |IND|PAK|NOR               |
+---+--------------------------+

df_new = df.withColumn('data', F.split('History', '\|')) \
           .withColumn('cnt', F.expr('size(filter(data, x -> x in ("USA", "IND", "DEN")))'))

>>> df_new.show(5,0)
+---+--------------------------+----------------------------------+---+
|ID |History                   |data                              |cnt|
+---+--------------------------+----------------------------------+---+
|1  |USA|UK|IND|DEN|MAL|SWE|AUS|[USA, UK, IND, DEN, MAL, SWE, AUS]|3  |
|2  |USA|UK|PAK|NOR            |[USA, UK, PAK, NOR]               |1  |
|3  |NOR|NZE                   |[NOR, NZE]                        |0  |
|4  |IND|PAK|NOR               |[IND, PAK, NOR]                   |1  |
+---+--------------------------+----------------------------------+---+

在这里,我们首先将字段 History 拆分为一个名为 data 的数组列,然后使用筛选函数:

filter(data, x -> x in ("USA", "IND", "DEN"))

为了仅检索满足条件的数组元素:IN ("USA", "IND", "DEN"),之后,我们可以使用size()函数对结果数组进行计数。 更新:添加另一种使用array_contains()函数的方法,适用于旧版本Spark。
lst = ["USA", "IND", "DEN"]

df_new = df.withColumn('data', F.split('History', '\|')) \
           .withColumn('Count', sum([F.when(F.array_contains('data',e),1).otherwise(0) for e in lst]))

注意: 数组中的重复项将被跳过,此方法仅计算唯一国家/地区代码。


这会抛出一个错误。ParseException: "\n多余的输入 '>'。 - Faliha Zikra
@FalihaZikra 你的 Spark 版本是多少?filter() 函数仅在 2.40 及以上版本可用。 - jxc
@FalihaZikra,我添加了另一种方法,适用于旧版本的Spark,你可以测试一下是否可行。 - jxc
更新的版本可以工作。谢谢你。不幸的是,由于它不计算重复出现的情况,这可能对我的用例无效。 - Faliha Zikra
@FalihaZikra,如果您可以将Spark升级到2.4版本,那么您的任务将会变得更加容易 :) - jxc

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接