在Spark RDD和/或Spark DataFrames中重新塑形/透视数据

25

我有一些数据,格式如下(可以是RDD或Spark DataFrame):

from pyspark.sql import SQLContext
sqlContext = SQLContext(sc)

 rdd = sc.parallelize([('X01',41,'US',3),
                       ('X01',41,'UK',1),
                       ('X01',41,'CA',2),
                       ('X02',72,'US',4),
                       ('X02',72,'UK',6),
                       ('X02',72,'CA',7),
                       ('X02',72,'XX',8)])

# convert to a Spark DataFrame                    
schema = StructType([StructField('ID', StringType(), True),
                     StructField('Age', IntegerType(), True),
                     StructField('Country', StringType(), True),
                     StructField('Score', IntegerType(), True)])

df = sqlContext.createDataFrame(rdd, schema)

我想做的是“重塑”数据,将Country(特别是美国、英国和加拿大)中的某些行转换为列:

ID    Age  US  UK  CA  
'X01'  41  3   1   2  
'X02'  72  4   6   7   

基本上,我需要类似于Python的pivot工作流程:

categories = ['US', 'UK', 'CA']
new_df = df[df['Country'].isin(categories)].pivot(index = 'ID', 
                                                  columns = 'Country',
                                                  values = 'Score')

我的数据集非常大,所以我无法使用collect()将数据加载到内存中,并在Python中进行重塑。有没有一种方法可以将Python的.pivot()转换为可调用函数,并同时映射RDD或Spark DataFrame?任何帮助都将不胜感激!

6个回答

22
自从Spark 1.6版本起,您可以在GroupedData上使用pivot函数并提供聚合表达式。
pivoted = (df
    .groupBy("ID", "Age")
    .pivot(
        "Country",
        ['US', 'UK', 'CA'])  # Optional list of levels
    .sum("Score"))  # alternatively you can use .agg(expr))
pivoted.show()

## +---+---+---+---+---+
## | ID|Age| US| UK| CA|
## +---+---+---+---+---+
## |X01| 41|  3|  1|  2|
## |X02| 72|  4|  6|  7|
## +---+---+---+---+---+

等级可以省略,但如果提供可以提高性能并作为内部过滤器。

这种方法仍然相对缓慢,但肯定比手动在JVM和Python之间传递数据要好。


7
首先,这可能不是一个好主意,因为你没有获得任何额外的信息,但你会将自己与固定模式(即你必须知道期望的国家数量,当然,额外的国家意味着代码的变化)绑定。
话虽如此,这是一个SQL问题,如下所示。但是,如果你认为它不太像“软件”(说真的,我听过这样的话!),那么你可以参考第一个解决方案。
解决方案1:
def reshape(t):
    out = []
    out.append(t[0])
    out.append(t[1])
    for v in brc.value:
        if t[2] == v:
            out.append(t[3])
        else:
            out.append(0)
    return (out[0],out[1]),(out[2],out[3],out[4],out[5])
def cntryFilter(t):
    if t[2] in brc.value:
        return t
    else:
        pass

def addtup(t1,t2):
    j=()
    for k,v in enumerate(t1):
        j=j+(t1[k]+t2[k],)
    return j

def seq(tIntrm,tNext):
    return addtup(tIntrm,tNext)

def comb(tP,tF):
    return addtup(tP,tF)


countries = ['CA', 'UK', 'US', 'XX']
brc = sc.broadcast(countries)
reshaped = calls.filter(cntryFilter).map(reshape)
pivot = reshaped.aggregateByKey((0,0,0,0),seq,comb,1)
for i in pivot.collect():
    print i

现在,解决方案2:当然更好,因为SQL是处理这个问题的正确工具。
callRow = calls.map(lambda t:   

Row(userid=t[0],age=int(t[1]),country=t[2],nbrCalls=t[3]))
callsDF = ssc.createDataFrame(callRow)
callsDF.printSchema()
callsDF.registerTempTable("calls")
res = ssc.sql("select userid,age,max(ca),max(uk),max(us),max(xx)\
                    from (select userid,age,\
                                  case when country='CA' then nbrCalls else 0 end ca,\
                                  case when country='UK' then nbrCalls else 0 end uk,\
                                  case when country='US' then nbrCalls else 0 end us,\
                                  case when country='XX' then nbrCalls else 0 end xx \
                             from calls) x \
                     group by userid,age")
res.show()

数据设置:

data=[('X01',41,'US',3),('X01',41,'UK',1),('X01',41,'CA',2),('X02',72,'US',4),('X02',72,'UK',6),('X02',72,'CA',7),('X02',72,'XX',8)]
 calls = sc.parallelize(data,1)
countries = ['CA', 'UK', 'US', 'XX']

结果:

来自第一种解决方案

(('X02', 72), (7, 6, 4, 8)) 
(('X01', 41), (2, 1, 3, 0))

从第二个解决方案开始:

root  |-- age: long (nullable = true)  
      |-- country: string (nullable = true)  
      |-- nbrCalls: long (nullable = true)  
      |-- userid: string (nullable = true)

userid age ca uk us xx 
 X02    72  7  6  4  8  
 X01    41  2  1  3  0

请告诉我这个是否有效,谢谢 :)

最好的 Ayan


谢谢!你的解决方案可行,更重要的是它们是可扩展的! - Jason
1
你能将其扩展为更通用的情况吗?例如,我的数据中有时可能有3个国家,另一次可能有5个。你上面的代码似乎是硬编码为4个特定的国家。我知道我需要提前知道有哪些国家,但随着时间的推移,这可能会发生变化。我如何将国家列表作为参数传递,并使其仍然起作用?在处理数据时,这是一件相当常见的事情,所以我希望这很快就可以成为内置功能。 - J Calbreath
正如我所指出的,这是一个模式设计上的问题。你不能简单地传递一个国家列表,因为你的模式将在下游发生改变。然而,你可以通过从reshape返回广义元组并设置aggregateByKey的零值来实现。在SQL方法中,你需要基本上按照这里描述的模式编程生成一个SQL语句。 - ayan guha
2
这是一个在大多数数据语言/框架中都存在的相当常见的功能:如SAS、Scalding、Pandas等。希望这个功能很快能在Spark中出现。 - J Calbreath
1
我基于你上面的回答创建了一个灵活的版本。您可以在此处查看:https://dev59.com/hF0a5IYBdhLWcg3whI4b。希望Spark尽快实现此解决方案,因为它是大多数其他数据操作语言/工具(Pandas、Scalding、SAS、Excel等)中相当基本的功能。 - J Calbreath

5
这里有一种本地的Spark方法,不硬编码列名。它基于aggregateByKey,并使用字典来收集每个键出现的列。然后我们收集所有列名来创建最终数据帧。[之前的版本在发出每个记录的字典后使用jsonRDD,但这更有效率。]限制特定列表中的列或排除像XX这样的列将是一个简单的修改。
即使在相当大的表上,性能似乎也很好。我正在使用一种变体,计算每个ID的多个事件发生次数,为每个事件类型生成一列。代码基本相同,只是在seqFn中使用collections.Counter来计算出现次数。
from pyspark.sql.types import *

rdd = sc.parallelize([('X01',41,'US',3),
                       ('X01',41,'UK',1),
                       ('X01',41,'CA',2),
                       ('X02',72,'US',4),
                       ('X02',72,'UK',6),
                       ('X02',72,'CA',7),
                       ('X02',72,'XX',8)])

schema = StructType([StructField('ID', StringType(), True),
                     StructField('Age', IntegerType(), True),
                     StructField('Country', StringType(), True),
                     StructField('Score', IntegerType(), True)])

df = sqlCtx.createDataFrame(rdd, schema)

def seqPivot(u, v):
    if not u:
        u = {}
    u[v.Country] = v.Score
    return u

def cmbPivot(u1, u2):
    u1.update(u2)
    return u1

pivot = (
    df
    .rdd
    .keyBy(lambda row: row.ID)
    .aggregateByKey(None, seqPivot, cmbPivot)
)
columns = (
    pivot
    .values()
    .map(lambda u: set(u.keys()))
    .reduce(lambda s,t: s.union(t))
)
result = sqlCtx.createDataFrame(
    pivot
    .map(lambda (k, u): [k] + [u.get(c) for c in columns]),
    schema=StructType(
        [StructField('ID', StringType())] + 
        [StructField(c, IntegerType()) for c in columns]
    )
)
result.show()

生成:

ID  CA UK US XX  
X02 7  6  4  8   
X01 2  1  3  null

不错的写作 - 顺便说一句,Spark 1.6数据框支持简单的透视表。 https://github.com/apache/spark/pull/7841 - meyerson
酷啊 - Spark 正在迅速变得更好。 - patricksurry
如果重新塑造的输出太大而无法适应内存,我该怎么办?我如何直接在磁盘上处理它? - skan

1

首先,我需要对你的RDD进行更正(这与你的实际输出相匹配):

rdd = sc.parallelize([('X01',41,'US',3),
                      ('X01',41,'UK',1),
                      ('X01',41,'CA',2),
                      ('X02',72,'US',4),
                      ('X02',72,'UK',6),
                      ('X02',72,'CA',7),
                      ('X02',72,'XX',8)])

一旦我进行了更正,这个就起作用了:

df.select($"ID", $"Age").groupBy($"ID").agg($"ID", first($"Age") as "Age")
.join(
    df.select($"ID" as "usID", $"Country" as "C1",$"Score" as "US"),
    $"ID" === $"usID" and $"C1" === "US"
)
.join(
    df.select($"ID" as "ukID", $"Country" as "C2",$"Score" as "UK"),
    $"ID" === $"ukID" and $"C2" === "UK"
)
.join(
    df.select($"ID" as "caID", $"Country" as "C3",$"Score" as "CA"), 
    $"ID" === $"caID" and $"C3" === "CA"
)
.select($"ID",$"Age",$"US",$"UK",$"CA")

肯定不如您的透视表那么优雅。


David,我无法让它工作。首先,Spark不接受$作为引用列的方式。在删除所有$符号后,我仍然得到一个语法错误,指向你上面代码中最后一行的.select表达式。 - Jason
抱歉,我正在使用Scala。它是直接从spark-shell中复制粘贴的。如果你将最后一个select()去掉,你应该可以得到正确的结果,只是列数太多了。你能做到这一点并发布结果吗? - David Griffin

1

关于 patricksurry 的非常有帮助的答案,我想提出一些评论:

  • 列“Age”缺失,只需在函数seqPivot中添加u["Age"] = v.Age即可。
  • 结果发现,遍历两个列元素的循环以不同的顺序给出了元素。这些列的值是正确的,但它们的名称却不正确。要避免这种情况,只需对列进行排序。

以下是稍加修改后的代码:

from pyspark.sql.types import *

rdd = sc.parallelize([('X01',41,'US',3),
                       ('X01',41,'UK',1),
                       ('X01',41,'CA',2),
                       ('X02',72,'US',4),
                       ('X02',72,'UK',6),
                       ('X02',72,'CA',7),
                       ('X02',72,'XX',8)])

schema = StructType([StructField('ID', StringType(), True),
                     StructField('Age', IntegerType(), True),
                     StructField('Country', StringType(), True),
                     StructField('Score', IntegerType(), True)])

df = sqlCtx.createDataFrame(rdd, schema)

# u is a dictionarie
# v is a Row
def seqPivot(u, v):
    if not u:
        u = {}
    u[v.Country] = v.Score
    # In the original posting the Age column was not specified
    u["Age"] = v.Age
    return u

# u1
# u2
def cmbPivot(u1, u2):
    u1.update(u2)
    return u1

pivot = (
    rdd
    .map(lambda row: Row(ID=row[0], Age=row[1], Country=row[2],  Score=row[3]))
    .keyBy(lambda row: row.ID)
    .aggregateByKey(None, seqPivot, cmbPivot)
)

columns = (
    pivot
    .values()
    .map(lambda u: set(u.keys()))
    .reduce(lambda s,t: s.union(t))
)

columns_ord = sorted(columns)

result = sqlCtx.createDataFrame(
    pivot
    .map(lambda (k, u): [k] + [u.get(c, None) for c in columns_ord]),
        schema=StructType(
            [StructField('ID', StringType())] + 
            [StructField(c, IntegerType()) for c in columns_ord]
        )
    )

print result.show()

最终,输出应为:

+---+---+---+---+---+----+
| ID|Age| CA| UK| US|  XX|
+---+---+---+---+---+----+
|X02| 72|  7|  6|  4|   8|
|X01| 41|  2|  1|  3|null|
+---+---+---+---+---+----+

0

在Hive中有一个JIRA,可以原生地执行PIVOT操作,而不需要为每个值编写大量的CASE语句:

https://issues.apache.org/jira/browse/HIVE-3776

请投票支持 JIRA,这样它就能尽早实现。 一旦在 Hive SQL 中实现了它,Spark 通常也不会落后太多,并最终在 Spark 中实现。

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接