如何在PySpark Dataframe中设置显示精度

17

在调用.show()函数时,如何设置PySpark中的显示精度?

考虑以下示例:

from math import sqrt
import pyspark.sql.functions as f

data = zip(
    map(lambda x: sqrt(x), range(100, 105)),
    map(lambda x: sqrt(x), range(200, 205))
)
df = sqlCtx.createDataFrame(data, ["col1", "col2"])
df.select([f.avg(c).alias(c) for c in df.columns]).show()

输出结果为:

#+------------------+------------------+
#|              col1|              col2|
#+------------------+------------------+
#|10.099262230352151|14.212583322380274|
#+------------------+------------------+

我该如何更改它以便只显示小数点后三位数字?

期望的输出:

#+------+------+
#|  col1|  col2|
#+------+------+
#|10.099|14.213|
#+------+------+

这是一个PySpark版本的这个scala问题。我在这里发布它,因为当我搜索PySpark解决方案时找不到答案,我认为它对未来的其他人有帮助。

2个回答

20

四舍五入

最简单的方法是使用pyspark.sql.functions.round()函数:

from pyspark.sql.functions import avg, round
df.select([round(avg(c), 3).alias(c) for c in df.columns]).show()
#+------+------+
#|  col1|  col2|
#+------+------+
#|10.099|14.213|
#+------+------+

这将保持数值类型的值不变。

格式化数字

functionsscalapython中是相同的,唯一的区别是import

您可以使用format_number将数字格式化为希望的小数位数,如官方api文档所述:

将数字列x格式化为类似于“#,###,###.##”的格式,舍入到d个小数位,并将结果作为字符串列返回。

from pyspark.sql.functions import avg, format_number 
df.select([format_number(avg(c), 3).alias(c) for c in df.columns]).show()
#+------+------+
#|  col1|  col2|
#+------+------+
#|10.099|14.213|
#+------+------+

变换后的列将会是 StringType 类型,并使用逗号作为千位分隔符:

#+-----------+--------------+
#|       col1|          col2|
#+-----------+--------------+
#|500,100.000|50,489,590.000|
#+-----------+--------------+

正如此答案中所述的Scala版本,我们可以使用regexp_replace,替换为任何您想要的字符串。

用rep替换匹配regexp的指定字符串值的所有子字符串。

from pyspark.sql.functions import avg, format_number, regexp_replace
df.select(
    [regexp_replace(format_number(avg(c), 3), ",", "").alias(c) for c in df.columns]
).show()
#+----------+------------+
#|      col1|        col2|
#+----------+------------+
#|500100.000|50489590.000|
#+----------+------------+

0

只需将答案封装到一个仅处理浮点和双精度列的函数中。

import pyspark.sql.functions as F
from pyspark.sql import DataFrame

def dataframe_format_float(df: DataFrame, num_decimals=4) -> DataFrame:
    r = []
    for c in df.dtypes:
        name, dtype = c[0], c[1]
        if dtype in ['float', 'double']:
            r.append(F.round(name, num_decimals).alias(name))
        else:
            r.append(name)
    df = df.select(r)
    return df

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接