如何在pyspark中可视化决策树模型/对象?

4
有没有办法在Pyspark中使用mllib或ml库创建的决策树进行可视化/绘图。还想知道如何获取叶节点中记录数等信息。谢谢。

使用Python绘制Pyspark决策树也是可以的。 - Neo
2个回答

3

首先,您需要使用model.toDebugString来获取与您的随机森林模型类似的输出:

 "RandomForestClassificationModel (uid=rfc_6c4ceb92ba78) with 20 trees
  Tree 0 (weight 1.0):
    If (feature 0 <= 3="" 10="" 1.0)="" if="" (feature="" <="0.0)" predict:="" 0.0="" else=""> 6.0)
       Predict: 0.0
     Else (feature 10 > 0.0)
      If (feature 12 <= 12="" 63.0)="" predict:="" 0.0="" else="" (feature=""> 63.0)
       Predict: 0.0
    Else (feature 0 > 1.0)
     If (feature 13 <= 3="" 1.0)="" if="" (feature="" <="3.0)" predict:="" 0.0="" else=""> 3.0)
       Predict: 1.0
     Else (feature 13 > 1.0)
      If (feature 7 <= 7="" 1.0)="" predict:="" 0.0="" else="" (feature=""> 1.0)
       Predict: 0.0
  Tree 1 (weight 1.0):
    If (feature 2 <= 11="" 15="" 1.0)="" if="" (feature="" <="0.0)" predict:="" 0.0="" else=""> 0.0)
       Predict: 1.0
     Else (feature 15 > 0.0)
      If (feature 11 <= 11="" 0.0)="" predict:="" 0.0="" else="" (feature=""> 0.0)
       Predict: 1.0
    Else (feature 2 > 1.0)
     If (feature 12 <= 5="" 31.0)="" if="" (feature="" <="0.0)" predict:="" 0.0="" else=""> 0.0)
       Predict: 0.0
     Else (feature 12 > 31.0)
      If (feature 3 <= 3="" 4.0)="" predict:="" 0.0="" else="" (feature=""> 4.0)
       Predict: 0.0
  Tree 2 (weight 1.0):
    If (feature 8 <= 4="" 6="" 1.0)="" if="" (feature="" <="2.0)" predict:="" 0.0="" else=""> 10875.0)
       Predict: 1.0
     Else (feature 6 > 2.0)
      If (feature 1 <= 1="" 36.0)="" predict:="" 0.0="" else="" (feature=""> 36.0)
       Predict: 1.0
    Else (feature 8 > 1.0)
     If (feature 5 <= 4="" 0.0)="" if="" (feature="" <="4113.0)" predict:="" 0.0="" else=""> 4113.0)
       Predict: 1.0
     Else (feature 5 > 0.0)
      If (feature 11 <= 11="" 2.0)="" predict:="" 0.0="" else="" (feature=""> 2.0)
       Predict: 0.0
  Tree 3 ...

将其保存为某个.txt文件,然后使用:https://github.com/tristaneljed/Decision-Tree-Visualization-Spark


我们在pyspark中有没有仅使用d3.js的替代方法?另外,model.toDebugString无法提供每个节点接收到的数据百分比。 - Neo

1
你可以获取所有叶子节点的统计数字,例如不纯度、增益、基尼指数,以及由模型数据文件分类为每个标签的元素数组。
数据文件位于您保存模型/数据/的位置。
model.save(location)
modeldf = spark.read.parquet(location+"data/*")

这个文件包含决策树或随机森林所需的许多元数据。您可以提取所有所需的信息,例如:
noderows = modeldf.select("id","prediction","leftChild","rightChild","split").collect()
df = pd.Dataframe([[rw['id'],rw['gain],rw['impurity'],rw['gini']] for rw in noderows if rw['leftChild'] < 0 and rw['rightChild'] < 0])
df.show()

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接