scikit-learn中是否可以打印决策树?

16

有没有一种方法可以在scikit-learn中打印已训练好的决策树? 我想为我的论文训练一棵决策树,并将树的图片放入论文中。这是否可能?

3个回答

17

有一种方法可以将数据导出为 graph_viz 格式:http://scikit-learn.org/stable/modules/generated/sklearn.tree.export_graphviz.html

因此,根据在线文档:

>>> from sklearn.datasets import load_iris
>>> from sklearn import tree
>>>
>>> clf = tree.DecisionTreeClassifier()
>>> iris = load_iris()
>>>
>>> clf = clf.fit(iris.data, iris.target)
>>> tree.export_graphviz(clf,
...     out_file='tree.dot')    

然后您可以使用graph viz进行加载,或者如果您已经安装了pydot,则可以更直接地执行此操作:http://scikit-learn.org/stable/modules/tree.html

>>> from sklearn.externals.six import StringIO  
>>> import pydot 
>>> dot_data = StringIO() 
>>> tree.export_graphviz(clf, out_file=dot_data) 
>>> graph = pydot.graph_from_dot_data(dot_data.getvalue()) 
>>> graph.write_pdf("iris.pdf") 

生成一个svg,无法在此处显示,因此您需要按照链接进行查看:http://scikit-learn.org/stable/_images/iris.svg

更新

自我回答这个问题以来,似乎行为已经发生了变化,现在返回一个list,因此您会收到以下错误:

AttributeError: 'list' object has no attribute 'write_pdf'

首先,当您看到此内容时,值得打印该对象并检查该对象,很可能您想要的是第一个对象:

graph[0].write_pdf("iris.pdf")

感谢@NickBraunagel的评论


7
我遇到了这个错误:AttributeError: 'list' object has no attribute 'write_pdf'。我该怎么解决? - Ernest Soo
@EdChum,您能否帮忙查看一下以下链接:https://dev59.com/Sanka4cB1Zd3GeqPKDKu - user9238790
@ErnestSoo(以及遇到相同错误的任何人):pydot.graph_from_dot_data()返回所需的graph(即pydot.Dot对象),但它将其返回到一个list中:因此,访问列表的第一个对象以访问pydot.Dot对象:graph[0].write_pdf("iris.pdf") - NickBraunagel
1
@NickBraunagel,看起来很多人都遇到了这个错误,我会将此作为更新添加进去。似乎自从我在3年前回答这个问题以来,行为发生了一些变化,谢谢。 - EdChum
1
你会如何在测试数据上执行相同的操作? - bernando_vialli

9
虽然我来晚了,但以下全面的说明对于希望显示决策树输出的其他人可能会有用:
安装必要的模块:
  1. 安装 graphviz。我使用conda的安装包这里( 推荐使用 pip install graphviz 不包含实际的 GraphViz 可执行文件)
  2. 通过 pip 安装 pydot (pip install pydot)
  3. 将包含 .exe 文件(例如 dot.exe)的 graphviz 文件夹目录添加到环境变量 PATH 中。
  4. 运行 EdChum 上述代码 (注意:graph 是一个包含 pydot.Dot 对象的列表):
from sklearn.datasets import load_iris
from sklearn import tree
from sklearn.externals.six import StringIO  
import pydot 

clf = tree.DecisionTreeClassifier()
iris = load_iris()
clf = clf.fit(iris.data, iris.target)

dot_data = StringIO() 
tree.export_graphviz(clf, out_file=dot_data) 
graph = pydot.graph_from_dot_data(dot_data.getvalue()) 

graph[0].write_pdf("iris.pdf")  # must access graph's first element

现在你会在你的环境默认目录中找到“iris.pdf”文件。

7

我知道有以下4种方法可以绘制scikit-learn决策树:

  • 使用sklearn.tree.export_text方法打印树的文本表示
  • 使用sklearn.tree.plot_tree方法绘制(需要matplotlib
  • 使用sklearn.tree.export_graphviz方法绘制(需要graphviz
  • 使用dtreeviz包绘制(需要dtreevizgraphviz

最简单的方法是将决策树导出为文本表示。示例决策树如下所示:

|--- feature_2 <= 2.45
|   |--- class: 0
|--- feature_2 >  2.45
|   |--- feature_3 <= 1.75
|   |   |--- feature_2 <= 4.95
|   |   |   |--- feature_3 <= 1.65
|   |   |   |   |--- class: 1
|   |   |   |--- feature_3 >  1.65
|   |   |   |   |--- class: 2
|   |   |--- feature_2 >  4.95
|   |   |   |--- feature_3 <= 1.55
|   |   |   |   |--- class: 2
|   |   |   |--- feature_3 >  1.55
|   |   |   |   |--- feature_0 <= 6.95
|   |   |   |   |   |--- class: 1
|   |   |   |   |--- feature_0 >  6.95
|   |   |   |   |   |--- class: 2
|   |--- feature_3 >  1.75
|   |   |--- feature_2 <= 4.85
|   |   |   |--- feature_1 <= 3.10
|   |   |   |   |--- class: 2
|   |   |   |--- feature_1 >  3.10
|   |   |   |   |--- class: 1
|   |   |--- feature_2 >  4.85
|   |   |   |--- class: 2

如果您已安装matplotlib,那么您可以使用sklearn.tree.plot_tree绘制图表:

tree.plot_tree(clf) # the clf is your decision tree model

示例输出与您使用export_graphviz获得的输出类似: sklearn决策树可视化 您还可以尝试使用dtreeviz包。 它将为您提供更多信息。 示例: dtreeviz示例决策树 您可以在此博客文章中找到不同可视化方法的sklearn决策树比较和代码片段:链接

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接