在PySpark中,计算Spark数据框中每列非NaN条目的数量

43

我有一个非常大的数据集,在Hive中加载(约190万行和1450列)。我需要确定每个列的“覆盖范围”,即每个列具有非NaN值的行的比例。

以下是我的代码:

from pyspark import SparkContext
from pyspark.sql import HiveContext
import string as string

sc = SparkContext(appName="compute_coverages") ## Create the context
sqlContext = HiveContext(sc)

df = sqlContext.sql("select * from data_table")
nrows_tot = df.count()

covgs = sc.parallelize(df.columns)
          .map(lambda x: str(x))
          .map(lambda x: (x, float(df.select(x).dropna().count()) / float(nrows_tot) * 100.))

在 PySpark shell 中尝试此操作,如果我执行 covgs.take(10),则会返回一个相当大的错误堆栈。它说在文件 /usr/lib64/python2.6/pickle.py 中的保存过程中存在问题。这是错误的最后一部分:

py4j.protocol.Py4JError: An error occurred while calling o37.__getnewargs__. Trace:
py4j.Py4JException: Method __getnewargs__([]) does not exist
        at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:333)
        at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:342)
        at py4j.Gateway.invoke(Gateway.java:252)
        at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:133)
        at py4j.commands.CallCommand.execute(CallCommand.java:79)
        at py4j.GatewayConnection.run(GatewayConnection.java:207)
        at java.lang.Thread.run(Thread.java:745)
有更好的方法来完成这个任务吗?我不能使用 pandas,因为它目前在我工作的集群上不可用,而且我也没有安装它的访问权限。
5个回答

89

让我们从虚拟数据开始:

from pyspark.sql import Row

row = Row("v", "x", "y", "z")
df = sc.parallelize([
    row(0.0, 1, 2, 3.0), row(None, 3, 4, 5.0),
    row(None, None, 6, 7.0), row(float("Nan"), 8, 9, float("NaN"))
]).toDF()

## +----+----+---+---+
## |   v|   x|  y|  z|
## +----+----+---+---+
## | 0.0|   1|  2|3.0|
## |null|   3|  4|5.0|
## |null|null|  6|7.0|
## | NaN|   8|  9|NaN|
## +----+----+---+---+

你只需要进行简单的聚合:
from pyspark.sql.functions import col, count, isnan, lit, sum

def count_not_null(c, nan_as_null=False):
    """Use conversion between boolean and integer
    - False -> 0
    - True ->  1
    """
    pred = col(c).isNotNull() & (~isnan(c) if nan_as_null else lit(True))
    return sum(pred.cast("integer")).alias(c)

df.agg(*[count_not_null(c) for c in df.columns]).show()

## +---+---+---+---+
## |  v|  x|  y|  z|
## +---+---+---+---+
## |  2|  3|  4|  4|
## +---+---+---+---+

如果您想将 NaN 视为 NULL

df.agg(*[count_not_null(c, True) for c in df.columns]).show()

## +---+---+---+---+
## |  v|  x|  y|  z|
## +---+---+---+---+
## |  1|  3|  4|  3|
## +---+---+---+---

您还可以利用 SQL 的 NULL 语义来实现相同的结果,而无需创建自定义函数:

df.agg(*[
    count(c).alias(c)    # vertical (column-wise) operations in SQL ignore NULLs
    for c in df.columns
]).show()

## +---+---+---+
## |  x|  y|  z|
## +---+---+---+
## |  1|  2|  3|
## +---+---+---+

但是这种方法无法处理 NaNs

如果你更喜欢分数:

exprs = [(count_not_null(c) / count("*")).alias(c) for c in df.columns]
df.agg(*exprs).show()

## +------------------+------------------+---+
## |                 x|                 y|  z|
## +------------------+------------------+---+
## |0.3333333333333333|0.6666666666666666|1.0|
## +------------------+------------------+---+

或者
# COUNT(*) is equivalent to COUNT(1) so NULLs won't be an issue
df.select(*[(count(c) / count("*")).alias(c) for c in df.columns]).show()

## +------------------+------------------+---+
## |                 x|                 y|  z|
## +------------------+------------------+---+
## |0.3333333333333333|0.6666666666666666|1.0|
## +------------------+------------------+---+

Scala的等效代码:

import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.{col, isnan, sum}

type JDouble = java.lang.Double

val df = Seq[(JDouble, JDouble, JDouble, JDouble)](
  (0.0, 1, 2, 3.0), (null, 3, 4, 5.0),
  (null, null, 6, 7.0), (java.lang.Double.NaN, 8, 9, java.lang.Double.NaN)
).toDF()


def count_not_null(c: Column, nanAsNull: Boolean = false) = {
  val pred = c.isNotNull and (if (nanAsNull) not(isnan(c)) else lit(true))
  sum(pred.cast("integer"))
}

df.select(df.columns map (c => count_not_null(col(c)).alias(c)): _*).show
// +---+---+---+---+                                                               
// | _1| _2| _3| _4|
// +---+---+---+---+
// |  2|  3|  4|  4|
// +---+---+---+---+

 df.select(df.columns map (c => count_not_null(col(c), true).alias(c)): _*).show
 // +---+---+---+---+
 // | _1| _2| _3| _4|
 // +---+---+---+---+
 // |  1|  3|  4|  3|
 // +---+---+---+---+

这里的return sum(col(c).isNotNull().cast("integer")).alias(c)是如何自动知道要访问哪个dataframe的?是因为我们从特定的dataframe获取了列名吗? - Roshini
@Roshini 列仅在特定的 SQL 表达式作用域内具有意义,该表达式定义了绑定。换句话说,给定 select 的上下文定义了如何解析列。 - zero323
如果nan计数大于阈值,如何选择列? - rosefun
第一次尝试时出现“TypeError: Column is not iterable”错误。 - Jérémy

0
对于字符串和数字列,使用summary非常方便。
  • 计算非空值:

    df.summary("count").show()
    
  • 计算非NaN值:

    df.replace(float("nan"), None).summary("count").show()
    

注意。 summary 不会返回除字符串或数字类型之外的列(例如,日期类型的列将从结果中省略)。


完整测试:

df = spark.createDataFrame(
    [(0.0, 1, 2, float("Nan")),
     (None, 3, 4, 5.0),
     (None, None, 6, 7.0),
     (float("Nan"), 8, 9, 7.0)],
    ["v", "x", "y", "z"])
df.show()
# +----+----+---+---+
# |   v|   x|  y|  z|
# +----+----+---+---+
# | 0.0|   1|  2|NaN|
# |null|   3|  4|5.0|
# |null|null|  6|7.0|
# | NaN|   8|  9|7.0|
# +----+----+---+---+

df.summary("count").show()
# +-------+---+---+---+---+
# |summary|  v|  x|  y|  z|
# +-------+---+---+---+---+
# |  count|  2|  3|  4|  4|
# +-------+---+---+---+---+

df.replace(float("nan"), None).summary("count").show()
# +-------+---+---+---+---+
# |summary|  v|  x|  y|  z|
# +-------+---+---+---+---+
# |  count|  1|  3|  4|  3|
# +-------+---+---+---+---+

0

您可能会遇到 数据类型不匹配 异常:

org.apache.spark.sql.AnalysisException: cannot resolve 'isnan(`date_hour`)' due to data type mismatch: argument 1 requires (double or float) type, however, '`date_hour`' is of timestamp type.;

最好先选择数字列:

from pyspark.sql.functions import *

def get_numerical_cols(df):
    return [i.name for i in df.schema  if str(i.dataType) in ('IntegerType', 'LongType', 'FloatType', 'DoubleType') ]

numcols = get_numerical_cols(df)
df_nan_rate = df.select([(count(when(isnan(c) | col(c).isNull(), c))/count(lit(1))).alias(c) for c in numcols])

0
from pyspark.sql import functions as F

z = df.count()
(df.replace(float('nan'), None)
 .agg(*[F.expr(f'count({col})/{z} as {col}') for col in df.columns])
).show()

请不要仅仅发布代码作为答案,还要提供解释您的代码是如何解决问题的。带有解释的答案通常更有帮助和更高质量,并且更有可能吸引赞同。 - Mark Rotteveel

0
你可以使用 isNotNull()
df.where(df[YOUR_COLUMN].isNotNull()).select(YOUR_COLUMN).show()

为什么会有踩票?这段代码非常优雅,至少和上面的Spark SQL代码一样Pythonic(也很出色,但在许多更简单的情况下,这段代码也能胜任)。点个赞。给大家都点个赞! - eric
1
Nulls和nans具有不同的功能。 - Tanner Clark

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接