如何访问由RandomForestClassifier(spark.ml版本)创建的模型中的个别树?

4
如何访问由Spark ML的RandomForestClassifier生成的模型中的个别树?我正在使用RandomForestClassifier的Scala版本。
1个回答

6

实际上它有trees属性:

import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.classification.{
  RandomForestClassificationModel, RandomForestClassifier, 
  DecisionTreeClassificationModel
}

val meta = NominalAttribute
  .defaultAttr
  .withName("label")
  .withValues("0.0", "1.0")
  .toMetadata

val data = sqlContext.read.format("libsvm")
  .load("data/mllib/sample_libsvm_data.txt")
  .withColumn("label", $"label".as("label", meta))

val rf: RandomForestClassifier = new RandomForestClassifier()
  .setLabelCol("label")
  .setFeaturesCol("features")

val trees: Array[DecisionTreeClassificationModel] = rf.fit(data).trees.collect {
  case t: DecisionTreeClassificationModel => t
}

正如您所看到的,唯一的问题是正确获取类型,以便我们实际使用这些类型:

trees.head.transform(data).show(3)
// +-----+--------------------+-------------+-----------+----------+
// |label|            features|rawPrediction|probability|prediction|
// +-----+--------------------+-------------+-----------+----------+
// |  0.0|(692,[127,128,129...|   [33.0,0.0]|  [1.0,0.0]|       0.0|
// |  1.0|(692,[158,159,160...|   [0.0,59.0]|  [0.0,1.0]|       1.0|
// |  1.0|(692,[124,125,126...|   [0.0,59.0]|  [0.0,1.0]|       1.0|
// +-----+--------------------+-------------+-----------+----------+
// only showing top 3 rows

注意:

如果你使用管道,你也可以提取单独的树:

import org.apache.spark.ml.Pipeline

val model = new Pipeline().setStages(Array(rf)).fit(data)

// There is only one stage and know its type 
// but lets be thorough
val rfModelOption = model.stages.headOption match {
  case Some(m: RandomForestClassificationModel) => Some(m)
  case _ => None
}

val trees = rfModelOption.map {
  _.trees //  ... as before
}.getOrElse(Array())

嗨zero323,感谢您的帮助。我有一个后续问题。我想从具有高预测概率(例如0.3以上)的树节点中提取规则。在spark.ml中,对象impurityStats是树的内部节点中的私有对象,方法toOld和fromOld也是如此。我需要那些细节(因为它们是私有的,所以我无法访问)才能提取任何内容。同样,节点的分裂不提供有关其类别和特征阈值的任何信息。是否有任何方法可以从spark.ml中的高概率节点中提取规则? - machine_learner
我不知道是否有简单的解决方案。你应该将其作为一个单独的问题进行提问 - 也许已经有人找到了解决方案。如果你找到了,请通过链接联系我。 - zero323
谢谢 zero323。我刚刚发布了一个问题"如何从Spark ML RandomForestClassifier模型(Scala版本)中提取规则?" 如果我得到答案,我会把消息告诉你的。 - machine_learner

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接