我该如何对 PySpark 程序进行单元测试?

54

我的目前的Java / Spark单元测试方法(详细说明在这里)是通过使用“local”实例化SparkContext并使用JUnit运行单元测试。

代码必须组织在一个函数中执行I/O,然后调用另一个带有多个RDD的函数。

这非常有效。我编写了一个使用Java + Spark编写的高度测试过的数据转换。

我能否使用Python实现相同的功能?

如何使用Python运行Spark单元测试?


3
您可以使用pySpark和unittest模块完成相同的操作。该项目的测试本身就使用了这个模块:https://github.com/apache/spark/blob/master/python/pyspark/tests.py - Paul K.
3
pytest + chispa 让 PySpark 代码的单元测试变得容易。避免使用 unittest。chispa 是 spark-fast-tests 的本地 PySpark 端口。请查看我的答案以获取更多详细信息。 - Powers
1
@PaulK。你好,你分享的链接无效 :) - wawawa
8个回答

33
我建议也使用py.test。py.test可以轻松创建可重用的SparkContext测试夹具,并使用它编写简洁的测试函数。您还可以专门制定夹具(例如创建StreamingContext),并在测试中使用一个或多个夹具。
我在Medium上撰写了一篇关于此主题的博客文章:

https://engblog.nextdoor.com/unit-testing-apache-spark-with-py-test-3b8970dc013b

这是一个来自帖子的片段:
pytestmark = pytest.mark.usefixtures("spark_context")
def test_do_word_counts(spark_context):
    """ test word couting
    Args:
       spark_context: test fixture SparkContext
    """
    test_input = [
        ' hello spark ',
        ' hello again spark spark'
    ]

    input_rdd = spark_context.parallelize(test_input, 1)
    results = wordcount.do_word_counts(input_rdd)

    expected_results = {'hello':2, 'spark':3, 'again':1}  
    assert results == expected_results

9
欢迎来到 Stack Overflow!我们不太支持仅包含链接的回答。也就是说,如果链接消失了,这个答案将没有持久价值。建议您在回答中添加一些有用的文字摘要或突出显示链接资源的关键点。 - sclv
@Vikas Kawadia,您能否请看一下https://dev59.com/Z1UM5IYBdhLWcg3wCcuz这个链接? - User12345
博客文章中概述的RDD测试还不错,但是DataFrame测试仅检查是否有两行数据。它没有验证DataFrame模式和内容是否相同,因此它不是一个强大的测试。请参阅我的答案,了解更好的DataFrame比较方法。 - Powers

30

如果您正在使用Spark 2.x和SparkSession,这是一个使用pytest的解决方案。我也在导入第三方包。

import logging

import pytest
from pyspark.sql import SparkSession

def quiet_py4j():
    """Suppress spark logging for the test context."""
    logger = logging.getLogger('py4j')
    logger.setLevel(logging.WARN)


@pytest.fixture(scope="session")
def spark_session(request):
    """Fixture for creating a spark context."""

    spark = (SparkSession
             .builder
             .master('local[2]')
             .config('spark.jars.packages', 'com.databricks:spark-avro_2.11:3.0.1')
             .appName('pytest-pyspark-local-testing')
             .enableHiveSupport()
             .getOrCreate())
    request.addfinalizer(lambda: spark.stop())

    quiet_py4j()
    return spark


def test_my_app(spark_session):
   ...

注意,如果使用Python 3,则需要将其作为PYSPARK_PYTHON环境变量进行指定:

import os
import sys

IS_PY2 = sys.version_info < (3,)

if not IS_PY2:
    os.environ['PYSPARK_PYTHON'] = 'python3'

否则,你会得到以下错误:

异常: worker 中的 Python 版本为 2.7,与 driver 中的版本 3.5 不同,PySpark 无法在不同的次要版本下运行。请检查环境变量 PYSPARK_PYTHON 和 PYSPARK_DRIVER_PYTHON 是否设置正确。


1
Avro插件可以在Spark 2.1中加载,但不能在Spark 2.0.2中加载。在尝试使用Avro格式之前,您不会收到任何错误。我亲自进行过测试。 - clay
5
设置PYSPARK_PYTHON的正确值的一种稍微简单的方法是:os.environ['PYSPARK_PYTHON'] = sys.executable-- 这将设置为当前正在运行的Python版本,也有望更好地处理虚拟环境。 - Ash Berlin-Taylor
@ksindi,您能否请看一下https://dev59.com/Z1UM5IYBdhLWcg3wCcuz - User12345
@user9367133回答了你的问题。 - Kamil Sindi
非常好的模板,包括运行测试所需的PYSPARK~设置,enableHiveSupport()quietLog4j。这些似乎仍然相关。 - WestCoastProjects
显示剩余2条评论

22

假设您已安装pyspark,您可以使用以下类在unittest中进行单元测试:

```python TBD ```
import unittest
import pyspark


class PySparkTestCase(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        conf = pyspark.SparkConf().setMaster("local[2]").setAppName("testing")
        cls.sc = pyspark.SparkContext(conf=conf)
        cls.spark = pyspark.SQLContext(cls.sc)

    @classmethod
    def tearDownClass(cls):
        cls.sc.stop()

例子:

class SimpleTestCase(PySparkTestCase):

    def test_with_rdd(self):
        test_input = [
            ' hello spark ',
            ' hello again spark spark'
        ]

        input_rdd = self.sc.parallelize(test_input, 1)

        from operator import add

        results = input_rdd.flatMap(lambda x: x.split()).map(lambda x: (x, 1)).reduceByKey(add).collect()
        self.assertEqual(results, [('hello', 2), ('spark', 3), ('again', 1)])

    def test_with_df(self):
        df = self.spark.createDataFrame(data=[[1, 'a'], [2, 'b']], 
                                        schema=['c1', 'c2'])
        self.assertEqual(df.count(), 2)

请注意这将为每个类创建一个上下文环境。使用setUp代替setUpClass可获得每个测试的上下文环境。由于创建新的Spark上下文环境目前较昂贵,因此这通常会增加测试执行的开销时间。


你好@Jorge。如果我想在一个新的测试类中创建一个额外的setUpClass,并且需要从PySparkTestCase访问sparkSession,该怎么办?我尝试调用super().setUpClass(),然后访问super().spark,但这并不起作用。 - itscarlayall
没关系,已经通过 cls.spark 解决了它! - itscarlayall

11

我使用pytest,它允许测试夹具,因此您可以实例化一个pyspark上下文,并将其注入到所有需要它的测试中。大致如下:

@pytest.fixture(scope="session",
                params=[pytest.mark.spark_local('local'),
                        pytest.mark.spark_yarn('yarn')])
def spark_context(request):
    if request.param == 'local':
        conf = (SparkConf()
                .setMaster("local[2]")
                .setAppName("pytest-pyspark-local-testing")
                )
    elif request.param == 'yarn':
        conf = (SparkConf()
                .setMaster("yarn-client")
                .setAppName("pytest-pyspark-yarn-testing")
                .set("spark.executor.memory", "1g")
                .set("spark.executor.instances", 2)
                )
    request.addfinalizer(lambda: sc.stop())

    sc = SparkContext(conf=conf)
    return sc

def my_test_that_requires_sc(spark_context):
    assert spark_context.textFile('/path/to/a/file').count() == 10

然后你可以通过调用 py.test -m spark_local 在本地模式下运行测试,或者通过 py.test -m spark_yarn 在YARN上运行。这对我来说效果很好。


请您看一下 https://dev59.com/Z1UM5IYBdhLWcg3wCcuz - User12345

10

您可以通过在测试套件中运行DataFrames上的代码并比较DataFrame列的相等性或两个完整DataFrames的相等性来测试PySpark代码。

quinn项目有几个示例

为测试套件创建SparkSession

创建一个tests/conftest.py文件,并使用此fixture,以便您可以轻松地在测试中访问SparkSession。

import pytest
from pyspark.sql import SparkSession

@pytest.fixture(scope='session')
def spark():
    return SparkSession.builder \
      .master("local") \
      .appName("chispa") \
      .getOrCreate()

列相等性示例

假设您想要测试以下函数,该函数从字符串中删除所有非单词字符。

def remove_non_word_characters(col):
    return F.regexp_replace(col, "[^\\w\\s]+", "")

你可以使用在chispa库中定义的assert_column_equality函数来测试此功能。
def test_remove_non_word_characters(spark):
    data = [
        ("jo&&se", "jose"),
        ("**li**", "li"),
        ("#::luisa", "luisa"),
        (None, None)
    ]
    df = spark.createDataFrame(data, ["name", "expected_name"])\
        .withColumn("clean_name", remove_non_word_characters(F.col("name")))
    assert_column_equality(df, "clean_name", "expected_name")

DataFrame相等性示例

一些函数需要通过比较整个DataFrame来进行测试。这里有一个对DataFrame列进行排序的函数。

def sort_columns(df, sort_order):
    sorted_col_names = None
    if sort_order == "asc":
        sorted_col_names = sorted(df.columns)
    elif sort_order == "desc":
        sorted_col_names = sorted(df.columns, reverse=True)
    else:
        raise ValueError("['asc', 'desc'] are the only valid sort orders and you entered a sort order of '{sort_order}'".format(
            sort_order=sort_order
        ))
    return df.select(*sorted_col_names)

下面是为此函数编写的一个测试用例。

def test_sort_columns_asc(spark):
    source_data = [
        ("jose", "oak", "switch"),
        ("li", "redwood", "xbox"),
        ("luisa", "maple", "ps4"),
    ]
    source_df = spark.createDataFrame(source_data, ["name", "tree", "gaming_system"])

    actual_df = T.sort_columns(source_df, "asc")

    expected_data = [
        ("switch", "jose", "oak"),
        ("xbox", "li", "redwood"),
        ("ps4", "luisa", "maple"),
    ]
    expected_df = spark.createDataFrame(expected_data, ["gaming_system", "name", "tree"])

    assert_df_equality(actual_df, expected_df)

测试I/O

通常最好将代码逻辑与I/O函数分离,这样更容易进行测试。

假设您有一个像这样的函数。

def your_big_function:
    df = spark.read.parquet("some_directory")
    df2 = df.withColumn(...).transform(function1).transform(function2)
    df2.write.parquet("other directory")

最好按照以下方式对代码进行重构:

def all_logic(df):
  return df.withColumn(...).transform(function1).transform(function2)

def your_formerly_big_function:
    df = spark.read.parquet("some_directory")
    df2 = df.transform(all_logic)
    df2.write.parquet("other directory")

像这样设计你的代码可以让你轻松地使用上面提到的列等式或DataFrame等式函数来测试all_logic函数。你可以使用mocking来测试your_formerly_big_function。通常最好在测试套件中避免I/O(但有时是不可避免的)。


1
未找到 assert_df_equality。 - Krishna Kumar Singh

5

pyspark有一个unittest模块,可以按以下方式使用

from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase

class MySparkTests(PySparkTestCase):
    def spark_session(self):
        return pyspark.SQLContext(self.sc)

    def createMockDataFrame(self):
         self.spark_session().createDataFrame(
            [
                ("t1", "t2"),
                ("t1", "t2"),
                ("t1", "t2"),
            ],
            ['col1', 'col2']
        )

2
不久前,我也遇到了同样的问题。经过阅读几篇文章、论坛和一些StackOverflow答案后,我编写了一个小插件用于pytest:pytest-spark
我已经使用它几个月了,在Linux上的通用工作流程看起来不错:
  1. 安装Apache Spark(设置JVM + 解压Spark的分发到某个目录)
  2. 安装 "pytest" + 插件 "pytest-spark"
  3. 在项目目录中创建 "pytest.ini" 并在其中指定Spark位置。
  4. 像往常一样使用pytest运行您的测试。
  5. 可选地,您可以在测试中使用插件提供的夹具“spark_context”,该夹具尝试最小化输出中的Spark日志。

1

结合其他答案,这是我在使用pyspark 3.3与fixtures(pytest)和TestCaseunittest)时找到的解决方法。首先设置一个spark session的fixture,稍后将为所有相关测试调用该fixture。通过使用fixture,我们避免了每次需要初始化会话时都需要进行设置的开销。这在src/tests/conftest.py中完成。

# src/tests/conftest.py

import pytest

from pyspark.sql import SparkSession


@pytest.fixture(scope="session")
def spark_session():
    spark = (
        SparkSession.builder.master("local[1]")  # run on local machine
        .appName("local-tests")
        .config("spark.executor.cores", "1")
        .config("spark.executor.instances", "1")
        .config("spark.sql.shuffle.partitions", "1")
        .config("spark.driver.bindAddress", "127.0.0.1")
        .getOrCreate()
    )
    yield spark
    spark.stop()

具有以下功能:
# src/utils/spark_utils.py
from pyspark.sql import DataFrame

def my_spark_function(df: DataFrame) -> bool:
   ...

测试:

# src/tests/utils/test_spark_utils.py

from unittest import TestCase

import pytest
from utils.spark_utils import my_spark_function

columns_underscore = ["the", "watchtower"]
data = [("joker", 1), ("thief", 2), ("princes", 3)]


class TestMySparkFunction(TestCase):
    @pytest.fixture(autouse=True)
    def prepare_fixture(self, spark_session):
        self.spark_session = spark_session

    def test_function_okay(self):
        df = self.spark_session.createDataFrame(data=data, schema=columns)
        self.assertEqual(my_spark_function(df), True)

最后,可以使用 pytest 运行测试了。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接