sklearn.tree.export_graphviz替代方案

5
可以使用pypi中的pydotplus来可视化决策树,但我的机器上出现了问题(它说没有使用libexpat构建,因此它只显示节点上的一个数字而不是带有一些信息的表格),我想使用另一种替代方法。我已经尝试使用networkx,但它需要pygraphviz来读取.dot文件并制作它们的networkx图形。当我试图使用pip安装它时也失败了。
所以现在我正在寻找另一种可视化决策树的替代方法,可以使用pip或anaconda安装。
有哪些替代方案? 编辑#1 conda list的输出:
# packages in environment at /home/xiaolong/development/anaconda3/envs/coursera_ml_classification:
#
alabaster                 0.7.7                    py34_0    defaults
awscli                    1.6.2                     <pip>
babel                     2.3.3                    py34_0    defaults
backports                 1.0                      py34_0    defaults
backports-abc             0.4                       <pip>
backports.shutil-get-terminal-size 1.0.0                     <pip>
backports_abc             0.4                      py34_0    defaults
bcdoc                     0.12.2                    <pip>
boto                      2.33.0                    <pip>
botocore                  0.73.0                    <pip>
cairo                     1.12.18                       6    defaults
certifi                   2015.4.28                 <pip>
colorama                  0.2.5                     <pip>
cycler                    0.10.0                   py34_0    defaults
decorator                 4.0.9                    py34_0    defaults
docutils                  0.12                     py34_0    defaults
entrypoints               0.2                      py34_1    defaults
expat                     2.1.0                         0    defaults
fontconfig                2.11.1                        5    defaults
freetype                  2.5.5                         0    defaults
get_terminal_size         1.0.0                    py34_0    defaults
glib                      2.43.0                        2    asmeurer
graphviz                  2.38.0                        1    defaults
harfbuzz                  0.9.39                        0    defaults
imagesize                 0.7.0                    py34_0    defaults
ipykernel                 4.3.1                    py34_0    defaults
ipython                   4.2.0                    py34_0    defaults
ipython-genutils          0.1.0                     <pip>
ipython_genutils          0.1.0                    py34_0    defaults
ipywidgets                4.1.1                    py34_0    defaults
jedi                      0.9.0                    py34_0    defaults
jinja2                    2.8                      py34_0    defaults
jmespath                  0.5.0                     <pip>
jsonschema                2.5.1                    py34_0    defaults
jupyter                   1.0.0                    py34_2    defaults
jupyter-client            4.2.2                     <pip>
jupyter-console           4.1.1                     <pip>
jupyter-core              4.1.0                     <pip>
jupyter_client            4.2.2                    py34_0    defaults
jupyter_console           4.1.1                    py34_0    defaults
jupyter_core              4.1.0                    py34_0    defaults
libffi                    3.2.1                         0    defaults
libgcc                    5.2.0                         0    defaults
libgfortran               3.0.0                         1    defaults
libpng                    1.6.17                        0    defaults
libsodium                 1.0.3                         0    defaults
libxml2                   2.9.2                         0    defaults
llvmlite                  0.10.0                   py34_0    defaults
markupsafe                0.23                     py34_0    defaults
matplotlib                1.5.1               np111py34_0    defaults
mistune                   0.7.2                    py34_0    defaults
mkl                       11.3.1                        0    defaults
multipledispatch          0.4.8                     <pip>
nbconvert                 4.2.0                    py34_0    defaults
nbformat                  4.0.1                    py34_0    defaults
notebook                  4.2.0                    py34_0    defaults
numpy                     1.11.0                   py34_0    defaults
openssl                   1.0.2h                        0    defaults
pandas                    0.18.1              np111py34_0    defaults
pango                     1.39.0                        0    defaults
path.py                   8.2.1                    py34_0    defaults
pep8                      1.7.0                    py34_0    defaults
pexpect                   4.0.1                    py34_0    defaults
pickleshare               0.5                      py34_0    defaults
pip                       8.1.1                    py34_1    defaults
pixman                    0.32.6                        0    defaults
prettytable               0.7.2                     <pip>
psutil                    4.1.0                    py34_0    defaults
ptyprocess                0.5                      py34_0    defaults
pyasn1                    0.1.9                     <pip>
pydotplus                 2.0.2                    py34_0    file:///home/xiaolong/development/anaconda3/conda-bld/linux-64/pydotplus-2.0.2-py34_0.tar.bz2
pyflakes                  1.1.0                    py34_0    defaults
pygments                  2.1.3                    py34_0    defaults
pyparsing                 2.1.1                    py34_0    defaults
pyqt                      4.11.4                   py34_1    defaults
python                    3.4.4                         0    defaults
python-contrib-nbextensions alpha                     <pip>
python-dateutil           2.5.2                    py34_0    defaults
pytz                      2016.3                   py34_0    defaults
pyyaml                    3.11                      <pip>
pyzmq                     15.2.0                   py34_0    defaults
qt                        4.8.7                         1    defaults
qtconsole                 4.2.1                    py34_0    defaults
readline                  6.2                           2    defaults
requests                  2.9.1                     <pip>
rope                      0.9.4                    py34_1    defaults
rope-py3k                 0.9.4.post1               <pip>
rsa                       3.1.2                     <pip>
scikit-learn              0.17.1              np111py34_0    defaults
scipy                     0.17.0              np111py34_3    defaults
setuptools                20.7.0                   py34_0    defaults
sframe                    1.8.5                     <pip>
simplegeneric             0.8.1                    py34_0    defaults
sip                       4.16.9                   py34_0    defaults
six                       1.10.0                   py34_0    defaults
snowballstemmer           1.2.1                    py34_0    defaults
sphinx                    1.4.1                    py34_0    defaults
sphinx-rtd-theme          0.1.9                     <pip>
sphinx_rtd_theme          0.1.9                    py34_0    defaults
spyder                    2.3.8                    py34_1    defaults
sqlite                    3.9.2                         0    defaults
terminado                 0.5                      py34_1    defaults
tk                        8.5.18                        0    defaults
tornado                   4.3                      py34_0    defaults
traitlets                 4.2.1                    py34_0    defaults
wheel                     0.29.0                   py34_0    defaults
xz                        5.0.5                         1    defaults
zeromq                    4.1.3                         0    defaults
zlib                      1.2.8                         0    defaults

SciPy 版本: 0.17.0

digraph Tree {
node [shape=box, style="filled", color="black"] ;
0 [label="grade.B <= 0.5\ngini = 0.5\nsamples = 37224\nvalue = [18476, 18748]", fillcolor="#399de504"] ;
1 [label="grade.C <= 0.5\ngini = 0.4973\nsamples = 32094\nvalue = [17218, 14876]", fillcolor="#e5813923"] ;
0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"] ;
2 [label="gini = 0.4829\nsamples = 21728\nvalue = [12875, 8853]", fillcolor="#e5813950"] ;
1 -> 2 ;
3 [label="gini = 0.4869\nsamples = 10366\nvalue = [4343, 6023]", fillcolor="#399de547"] ;
1 -> 3 ;
4 [label="grade.A <= 14.8301\ngini = 0.3702\nsamples = 5130\nvalue = [1258, 3872]", fillcolor="#399de5ac"] ;
0 -> 4 [labeldistance=2.5, labelangle=-45, headlabel="False"] ;
5 [label="gini = 0.3555\nsamples = 4987\nvalue = [1153, 3834]", fillcolor="#399de5b2"] ;
4 -> 5 ;
6 [label="gini = 0.3902\nsamples = 143\nvalue = [105, 38]", fillcolor="#e58139a3"] ;
4 -> 6 ;
}

编辑#2

我在Jupyter笔记本中编写了这个程序,但是如果您尝试使用以下方式显示SVG,则会出现无法着色的错误:

![Decision Tree]('dtree.svg')

我在这里找到了一个解决办法:链接

from IPython.display import HTML

svg = None
with open('dtree.svg') as svg_file:
    svg = svg_file.read()

HTML(svg)
2个回答

6

虽然不是最性感的解决方案,但我使用Grapviz CLI(它被称为dot)通过subprocess来调用它。我用的是Mac系统,所以我使用homebrew安装了它,但你可以从它们的下载页面下载其他平台的二进制文件。以下是一个使用Titanic数据集的示例:

import pandas as pd
import subprocess
import seaborn.apionly as sns
fromwd sklearn.preprocessing import Imputer
from sklearn.tree import DecisionTreeClassifier, export_graphviz

raw_data = sns.load_dataset('titanic')
predictors = ['pclass','sex','age','sibsp','parch','fare','embarked','alone','adult_male']
categorical = ['sex','embarked']
numeric = [c for c in predictors if c not in categorical]
target='survived'

encoded_data = pd.get_dummies(raw_data[predictors], columns=categorical)

imputer = Imputer()
X = imputer.fit_transform(encoded_data).astype('float32')
Y = raw_data[target].astype('float32')

model = DecisionTreeClassifier(min_samples_leaf=10, max_depth=3)
model.fit(X, Y)

export_graphviz(model,
                out_file='tree.dot',
                feature_names=encoded_data.columns,
                proportion=True,
                filled=True,
                impurity=False)

subprocess.call(['dot', '-Tpdf', 'tree.dot', '-o' 'tree.pdf'])

尝试过并且有效。不过图表缺少着色。这个有简单的解决方法吗?别管了,着色是由pydotplus生成的,但是它无法正确地渲染标签,所以这听取你的选择。 - Zelphir Kaltstahl
我这边能看到颜色,你用的是哪个平台? - maxymoo
我使用的是Linux系统,但问题在于pydotplus也给了我一个警告,它没有使用libexpat进行构建,我没有找到解决方法,仍然可以在virtualenv中使用pip或更好的anaconda进行安装。因此,如果没有libexpat,它无法将表格数据显示为节点标签。但是,如果像答案中建议的那样使用dot工具本身,则我的dot文件中似乎没有颜色信息,因此似乎只有逻辑上不存在任何从dot文件生成的图形中的颜色信息。我做错了什么吗?毕竟,颜色确实很不错:) - Zelphir Kaltstahl
你正在使用最新的sklearn版本吗?我的点文件中有颜色信息,第一个节点是[label="adult_male <= 0.5\nsamples = 100.0%\nvalue = [0.62, 0.38]", fillcolor="#e5813960"] ; - maxymoo
1
请查看原帖或者这个链接:https://github.com/scikit-learn/scikit-learn/issues/6522 :) - Zelphir Kaltstahl
显示剩余3条评论

4

从版本0.21开始,scikit-learn拥有了plot_tree方法,可以使用matplotlib绘制树形图。

使用plot_tree的代码:

from sklearn import tree
# the clf is Decision Tree object
tree.plot_tree(clf,feature_names=iris.feature_names,  
                   class_names=iris.target_names,
                   filled=True)

替代sklearn图表的选择可以是dtreeviz包。 下面是一棵树的示例。 使用dtreeviz的代码:

from dtreeviz.trees import dtreeviz # remember to load the package
# the clf is Decision Tree object
viz = dtreeviz(clf, X, y,
                target_name="target",
                feature_names=iris.feature_names,
                class_names=list(iris.target_names))

viz

您可以在这里找到不同的scikit-learn树绘图技术的比较。

dtreeviz决策树可视化


1
看起来很不错。你能引用生成代码吗?这会提高你的回答质量。 - Zelphir Kaltstahl

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接