背景
我最初的问题是,为什么在map函数中使用DecisionTreeModel.predict
会引发异常?这与如何在Spark中使用MLlib生成(原始标签,预测标签)元组有关?
当我们使用Scala API 建议的方式获取RDD[LabeledPoint]
的预测值时,只需对RDD
进行映射即可使用DecisionTreeModel
。
val labelAndPreds = testData.map { point =>
val prediction = model.predict(point.features)
(point.label, prediction)
}
很不幸,PySpark中类似的方法效果并不好:
labelsAndPredictions = testData.map(
lambda lp: (lp.label, model.predict(lp.features))
labelsAndPredictions.first()
异常:您似乎正在尝试从广播变量、操作或转换中引用SparkContext。SparkContext只能在驱动程序上使用,而不能在运行在工作节点上的代码中使用。有关更多信息,请参见SPARK-5063。
相反,官方文档建议这样做:
predictions = model.predict(testData.map(lambda x: x.features))
labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
那么这里发生了什么?这里没有广播变量,Scala API 定义了 predict
如下:
/**
* Predict values for a single data point using the model trained.
*
* @param features array representing a single data point
* @return Double prediction from the trained model
*/
def predict(features: Vector): Double = {
topNode.predict(features)
}
/**
* Predict values for the given data set using the model trained.
*
* @param features RDD representing data points to be predicted
* @return RDD of predictions for each of the given data points
*/
def predict(features: RDD[Vector]): RDD[Double] = {
features.map(x => predict(x))
}
因此,至少乍一看从操作或转换中调用似乎不是问题,因为预测似乎是一个本地操作。
解释
经过一番探索,我发现问题的根源是从DecisionTreeModel.predict调用的JavaModelWrapper.call
方法。它访问了SparkContext
,这是调用Java函数所必需的:
callJavaFunc(self._sc, getattr(self._java_model, name), *a)
问题
在DecisionTreeModel.predict
的情况下,有一个推荐的解决方法,所有必需的代码已经是Scala API的一部分,但是是否有任何优雅的方式来处理这种问题?
我现在能想到的唯一解决方案都比较繁重:
- 通过隐式转换扩展Spark类或添加某种包装器将所有内容下推到JVM
- 直接使用Py4j网关