可视化决策树(来自scikit-learn的示例)

10

我是scikit-learn的新手,请多包容。

我正在学习这个示例: http://scikit-learn.org/stable/modules/tree.html#tree

>>> from sklearn.datasets import load_iris
>>> from sklearn import tree
>>> iris = load_iris()
>>> clf = tree.DecisionTreeClassifier()
>>> clf = clf.fit(iris.data, iris.target)
>>> from StringIO import StringIO
>>> out = StringIO()
>>> out = tree.export_graphviz(clf, out_file=out)

显然,graphiz文件已经可以使用。

但是我该如何使用graphiz文件绘制树?(示例没有详细说明如何绘制树)。

示例代码和提示将不胜感激!

谢谢!


更新

我正在使用Ubuntu 12.04,Python 2.7.3。


Scikit-learn从0.21版本开始提供了plot_tree方法,比导出到graphviz更易于使用。不过,还有一个非常好的包dtreeviz。这里是sklearn树可视化方法的比较:博客文章链接 - pplonski
2个回答

5
你运行的操作系统是什么?是否已安装graphviz
在你的示例中,StringIO()对象保存了graphviz数据,以下是一种检查数据的方法:
...
>>> print out.getvalue()

digraph Tree {
0 [label="X[2] <= 2.4500\nerror = 0.666667\nsamples = 150\nvalue = [ 50.  50.  50.]", shape="box"] ;
1 [label="error = 0.0000\nsamples = 50\nvalue = [ 50.   0.   0.]", shape="box"] ;
0 -> 1 ;
2 [label="X[3] <= 1.7500\nerror = 0.5\nsamples = 100\nvalue = [  0.  50.  50.]", shape="box"] ;
0 -> 2 ;
3 [label="X[2] <= 4.9500\nerror = 0.168038\nsamples = 54\nvalue = [  0.  49.   5.]", shape="box"] ;
2 -> 3 ;
4 [label="X[3] <= 1.6500\nerror = 0.0407986\nsamples = 48\nvalue = [  0.  47.   1.]", shape="box"] ;
3 -> 4 ;
5 [label="error = 0.0000\nsamples = 47\nvalue = [  0.  47.   0.]", shape="box"] ;
4 -> 5 ;
6 [label="error = 0.0000\nsamples = 1\nvalue = [ 0.  0.  1.]", shape="box"] ;
4 -> 6 ;
7 [label="X[3] <= 1.5500\nerror = 0.444444\nsamples = 6\nvalue = [ 0.  2.  4.]", shape="box"] ;
3 -> 7 ;
8 [label="error = 0.0000\nsamples = 3\nvalue = [ 0.  0.  3.]", shape="box"] ;
7 -> 8 ;
9 [label="X[0] <= 6.9500\nerror = 0.444444\nsamples = 3\nvalue = [ 0.  2.  1.]", shape="box"] ;
7 -> 9 ;
10 [label="error = 0.0000\nsamples = 2\nvalue = [ 0.  2.  0.]", shape="box"] ;
9 -> 10 ;
11 [label="error = 0.0000\nsamples = 1\nvalue = [ 0.  0.  1.]", shape="box"] ;
9 -> 11 ;
12 [label="X[2] <= 4.8500\nerror = 0.0425331\nsamples = 46\nvalue = [  0.   1.  45.]", shape="box"] ;
2 -> 12 ;
13 [label="X[0] <= 5.9500\nerror = 0.444444\nsamples = 3\nvalue = [ 0.  1.  2.]", shape="box"] ;
12 -> 13 ;
14 [label="error = 0.0000\nsamples = 1\nvalue = [ 0.  1.  0.]", shape="box"] ;
13 -> 14 ;
15 [label="error = 0.0000\nsamples = 2\nvalue = [ 0.  0.  2.]", shape="box"] ;
13 -> 15 ;
16 [label="error = 0.0000\nsamples = 43\nvalue = [  0.   0.  43.]", shape="box"] ;
12 -> 16 ;
}

你可以将其写成 .dot文件 的形式,并生成图像输出,如您提供的源代码所示: $ dot -Tpng tree.dot -o tree.png (PNG格式输出)

嗨,谢谢!我正在使用Ubuntu 12.04,Python版本为2.7.3。我想知道是否有办法在Python脚本中完成而不是在命令行中完成? - DjangoRocks
1
当然,只需获取其中一个可用的Python绑定到Graphviz,您就可以在Python shell中完成它。 - theta
有没有在Python3中执行该任务的方法? - soupault

4
你离正确答案很近了!只需要执行以下操作:
graph_from_dot_data(out.getvalue()).write_pdf("somefile.pdf")

1
只有当 #classes 足够小,以至于文本中的 nvalue 数组不会跨行断开时,此方法才有效...在这种情况下,我不得不手动搜索/替换 \n 为 ''(当然要保留合法的)...有点麻烦。同样适用于 one-hot 编码标签...它们会立即引发错误。 - user1269942

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接