从任务中调用Java/Scala函数

45

背景

我最初的问题是,为什么在map函数中使用DecisionTreeModel.predict会引发异常?这与如何在Spark中使用MLlib生成(原始标签,预测标签)元组有关?

当我们使用Scala API 建议的方式获取RDD[LabeledPoint]的预测值时,只需对RDD进行映射即可使用DecisionTreeModel

val labelAndPreds = testData.map { point =>
  val prediction = model.predict(point.features)
  (point.label, prediction)
}

很不幸,PySpark中类似的方法效果并不好:
labelsAndPredictions = testData.map(
    lambda lp: (lp.label, model.predict(lp.features))
labelsAndPredictions.first()

异常:您似乎正在尝试从广播变量、操作或转换中引用SparkContext。SparkContext只能在驱动程序上使用,而不能在运行在工作节点上的代码中使用。有关更多信息,请参见SPARK-5063

相反,官方文档建议这样做:

predictions = model.predict(testData.map(lambda x: x.features))
labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)

那么这里发生了什么?这里没有广播变量,Scala API 定义了 predict 如下:

/**
 * Predict values for a single data point using the model trained.
 *
 * @param features array representing a single data point
 * @return Double prediction from the trained model
 */
def predict(features: Vector): Double = {
  topNode.predict(features)
}

/**
 * Predict values for the given data set using the model trained.
 *
 * @param features RDD representing data points to be predicted
 * @return RDD of predictions for each of the given data points
 */
def predict(features: RDD[Vector]): RDD[Double] = {
  features.map(x => predict(x))
}

因此,至少乍一看从操作或转换中调用似乎不是问题,因为预测似乎是一个本地操作。

解释

经过一番探索,我发现问题的根源是从DecisionTreeModel.predict调用的JavaModelWrapper.call方法。它访问SparkContext,这是调用Java函数所必需的:

callJavaFunc(self._sc, getattr(self._java_model, name), *a)

问题

DecisionTreeModel.predict的情况下,有一个推荐的解决方法,所有必需的代码已经是Scala API的一部分,但是是否有任何优雅的方式来处理这种问题?

我现在能想到的唯一解决方案都比较繁重:

  • 通过隐式转换扩展Spark类或添加某种包装器将所有内容下推到JVM
  • 直接使用Py4j网关

2
在这一部分上是正确的。我在尝试将Scala中相同的代码实现放入Python的决策树时遇到了相同的广播问题,因此不得不使用.zip函数将标签组合回来。感谢解释! - Anchit
1个回答

55
使用默认的Py4J网关进行通信是不可能的。要理解原因,我们需要查看来自PySpark内部文档[1]的以下图表:

enter image description here

由于Py4J网关运行在驱动程序上,因此无法通过套接字与JVM工作进程通信的Python解释器访问它(例如,请参见PythonRDD/rdd.py)。
理论上,为每个工作进程创建单独的Py4J网关是可能的,但实际上这不太有用。忽略可靠性问题,Py4J简单地不适合执行数据密集型任务。
是否有任何变通方法?
  1. 使用Spark SQL数据源API封装JVM代码。

    优点:受支持,高级别,不需要访问内部PySpark API

    缺点:相对冗长且文档不太完善,主要限于输入数据

  2. 使用Scala UDF在数据帧上操作。

    优点:易于实现(参见Spark:如何将Python与Scala或Java用户定义的函数映射?),如果数据已存储在数据帧中,则无需在Python和Scala之间进行数据转换,最小化访问Py4J

    缺点:需要访问Py4J网关和内部方法,仅限于Spark SQL,难以调试,不受支持

  3. 以类似于MLlib中所做的方式创建高级别Scala接口。

    优点:灵活,能够执行任意复杂的代码。它可以直接在RDD上执行(例如参见MLlib模型包装器)或者使用DataFrames(参见如何在Pyspark中使用Scala类)。后一种解决方案似乎更加友好,因为所有的序列化和反序列化细节都由现有的API处理。

    缺点:低级别,需要数据转换,与UDF相同需要访问Py4J和内部API,不受支持

    一些基本示例可以在使用Scala转换PySpark RDD中找到

  4. 使用外部工作流管理工具在Python和Scala / Java作业之间切换并将数据传递给DFS。

    优点:易于实现,对代码本身的更改最小

    缺点:读取/写入数据的成本(Alluxio?)

  5. 使用共享的SQLContext(例如参见Apache ZeppelinLivy)通过注册临时表在客户语言之间传递数据。

    优点:非常适合交互式分析

    缺点:不太适合批处理作业(Zeppelin)或可能需要额外的编排(Livy)


  1. Joshua Rosen. (2014年8月4日) PySpark 内部原理。检索自https://cwiki.apache.org/confluence/display/SPARK/PySpark+Internals

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接