有没有一种方法可以在scikit-learn中打印已训练好的决策树? 我想为我的论文训练一棵决策树,并将树的图片放入论文中。这是否可能?
有没有一种方法可以在scikit-learn中打印已训练好的决策树? 我想为我的论文训练一棵决策树,并将树的图片放入论文中。这是否可能?
有一种方法可以将数据导出为 graph_viz 格式:http://scikit-learn.org/stable/modules/generated/sklearn.tree.export_graphviz.html
因此,根据在线文档:
>>> from sklearn.datasets import load_iris
>>> from sklearn import tree
>>>
>>> clf = tree.DecisionTreeClassifier()
>>> iris = load_iris()
>>>
>>> clf = clf.fit(iris.data, iris.target)
>>> tree.export_graphviz(clf,
... out_file='tree.dot')
然后您可以使用graph viz进行加载,或者如果您已经安装了pydot,则可以更直接地执行此操作:http://scikit-learn.org/stable/modules/tree.html
>>> from sklearn.externals.six import StringIO
>>> import pydot
>>> dot_data = StringIO()
>>> tree.export_graphviz(clf, out_file=dot_data)
>>> graph = pydot.graph_from_dot_data(dot_data.getvalue())
>>> graph.write_pdf("iris.pdf")
生成一个svg,无法在此处显示,因此您需要按照链接进行查看:http://scikit-learn.org/stable/_images/iris.svg
更新
自我回答这个问题以来,似乎行为已经发生了变化,现在返回一个list
,因此您会收到以下错误:
AttributeError: 'list' object has no attribute 'write_pdf'
首先,当您看到此内容时,值得打印该对象并检查该对象,很可能您想要的是第一个对象:
graph[0].write_pdf("iris.pdf")
感谢@NickBraunagel的评论
graphviz
。我使用conda的安装包这里( 推荐使用 pip install graphviz
不包含实际的 GraphViz 可执行文件)pydot
(pip install pydot
)graph
是一个包含 pydot.Dot
对象的列表):from sklearn.datasets import load_iris
from sklearn import tree
from sklearn.externals.six import StringIO
import pydot
clf = tree.DecisionTreeClassifier()
iris = load_iris()
clf = clf.fit(iris.data, iris.target)
dot_data = StringIO()
tree.export_graphviz(clf, out_file=dot_data)
graph = pydot.graph_from_dot_data(dot_data.getvalue())
graph[0].write_pdf("iris.pdf") # must access graph's first element
我知道有以下4种方法可以绘制scikit-learn决策树:
sklearn.tree.export_text
方法打印树的文本表示sklearn.tree.plot_tree
方法绘制(需要matplotlib
)sklearn.tree.export_graphviz
方法绘制(需要graphviz
)dtreeviz
包绘制(需要dtreeviz
和graphviz
)最简单的方法是将决策树导出为文本表示。示例决策树如下所示:
|--- feature_2 <= 2.45
| |--- class: 0
|--- feature_2 > 2.45
| |--- feature_3 <= 1.75
| | |--- feature_2 <= 4.95
| | | |--- feature_3 <= 1.65
| | | | |--- class: 1
| | | |--- feature_3 > 1.65
| | | | |--- class: 2
| | |--- feature_2 > 4.95
| | | |--- feature_3 <= 1.55
| | | | |--- class: 2
| | | |--- feature_3 > 1.55
| | | | |--- feature_0 <= 6.95
| | | | | |--- class: 1
| | | | |--- feature_0 > 6.95
| | | | | |--- class: 2
| |--- feature_3 > 1.75
| | |--- feature_2 <= 4.85
| | | |--- feature_1 <= 3.10
| | | | |--- class: 2
| | | |--- feature_1 > 3.10
| | | | |--- class: 1
| | |--- feature_2 > 4.85
| | | |--- class: 2
如果您已安装matplotlib
,那么您可以使用sklearn.tree.plot_tree
绘制图表:
tree.plot_tree(clf) # the clf is your decision tree model
export_graphviz
获得的输出类似:
dtreeviz
包。 它将为您提供更多信息。 示例:
AttributeError: 'list' object has no attribute 'write_pdf'
。我该怎么解决? - Ernest Soopydot.graph_from_dot_data()
返回所需的graph
(即pydot.Dot
对象),但它将其返回到一个list
中:因此,访问列表的第一个对象以访问pydot.Dot
对象:graph[0].write_pdf("iris.pdf")
。 - NickBraunagel