我的目标是使用Spark DataFrames对一组分类列进行One-Hot编码。例如,就像Pandas
中的get_dummies()
函数一样。
数据集bureau.csv
最初来自于Kaggle的一个比赛Home Credit Default Risk。这里是我的入口表示例entryData
,仅筛选出KEY = 100001
的数据。
# primary key
KEY = 'SK_ID_CURR'
data = spark.read.csv("bureau.csv", header=True, inferSchema=True)
# sample data from bureau.csv of 1716428 rows
entryData = data.select(columnList).where(F.col(KEY) == 100001).show()
print(entryData)
+----------+-------------+---------------+---------------+
|SK_ID_CURR|CREDIT_ACTIVE|CREDIT_CURRENCY| CREDIT_TYPE|
+----------+-------------+---------------+---------------+
| 100001| Closed| currency 1|Consumer credit|
| 100001| Closed| currency 1|Consumer credit|
| 100001| Closed| currency 1|Consumer credit|
| 100001| Closed| currency 1|Consumer credit|
| 100001| Active| currency 1|Consumer credit|
| 100001| Active| currency 1|Consumer credit|
| 100001| Active| currency 1|Consumer credit|
+----------+-------------+---------------+---------------+
我想通过创建函数 catg_encode(entryData, columnList)
来对列表 columnList
进行独热编码。
columnList = cols_type(entryData, obj=True)[1:]
print(columnList)
['CREDIT_ACTIVE', 'CREDIT_CURRENCY', 'CREDIT_TYPE']
注意 cols_type()
是一个返回列列表的函数,如果obj=True
则返回分类列,否则返回数值列。
我已经成功地对第一列'CREDIT_ACTIVE'
进行了独热编码,但我无法同时对所有列进行编码,也就是说无法构建catg_encode
函数。
# import necessary modules
from pyspark.sql import functions as F
# look for all distinct categoris within a given feature (here 'CREDIT_ACTIVE')
categories = entryData.select(columnList[0]).distinct().rdd.flatMap(lambda x: x).collect()
# one-hot encode the categories
exprs = [F.when(F.col(columnList[0]) == category, 1).otherwise(0).alias(category) for category in categories]
# nice table with encoded feature 'CREDIT_ACTIVE'
oneHotEncode = entryData.select(KEY, *exprs)
print(oneHotEncode)
+----------+--------+----+------+------+
|SK_ID_CURR|Bad debt|Sold|Active|Closed|
+----------+--------+----+------+------+
| 100001| 0| 0| 0| 1|
| 100001| 0| 0| 0| 1|
| 100001| 0| 0| 0| 1|
| 100001| 0| 0| 0| 1|
| 100001| 0| 0| 1| 0|
| 100001| 0| 0| 1| 0|
| 100001| 0| 0| 1| 0|
+----------+--------+----+------+------+
这里的特征'CREDIT_ACTIVE'
有4个不同的类别; ['坏账', '已售出', '活跃', '关闭']
.
注意: 我甚至尝试了IndexToString
和OneHotEncoderEstimator
,但对于这个特定的任务没有帮助。
我期望得到以下输出:
+----------+--------+----+------+------+----------+----------+----------+----------+----------+---
|SK_ID_CURR|Bad debt|Sold|Active|Closed|currency 1|currency 2|currency 3|currency 4|..........|...
+----------+--------+----+------+------+----------+----------+----------+----------+----------+---
| 100001| 0| 0| 0| 1| 1| 0| 0| 0| ..|
| 100001| 0| 0| 0| 1| 1| 0| 0| 0| ..|
| 100001| 0| 0| 0| 1| 1| 0| 0| 0| ..|
| 100001| 0| 0| 0| 1| 1| 0| 0| 0| ..|
| 100001| 0| 0| 1| 0| 1| 0| 0| 0| ..|
| 100001| 0| 0| 1| 0| 1| 0| 0| 0| ..|
| 100001| 0| 0| 1| 0| 1| 0| 0| 0| ..|
+----------+--------+----+------+------+----------+----------+----------+----------+----------+---
连续的点
...
代表了特征'CREDIT_TYPE'
的其余类别,包括:['购买设备贷款', '现金贷款(非指定用途)', '小额贷款', '消费信贷', '移动运营商贷款', '其他类型贷款', '抵押贷款', '银行间信贷', '流动资金贷款', '汽车贷款', '房地产贷款', '未知类型贷款', '业务发展贷款', '信用卡', '股票购买贷款(保证金借贷)']
。
注意:我看到了这篇文章E-num / get Dummies in pyspark
,但它不能自动化处理多个列的情况,而我的问题正是如此。该文章提供了一种解决方案,即为每个分类特征编写单独的代码,但这不是我的问题。
df.withColumn('arr', F.array(columnList))
,然后使用 CountVectorizer 一次性创建独热编码。以下是我以前的一篇帖子作为示例:https://stackoverflow.com/questions/58010126 - jxc