我该如何在Python中绘制一个CART树,就像在R中一样?

6

在R中,我可以使用API直接绘制与CART模型对应的决策树的图形表示。例如,prp将生成类似于以下内容:

但我找不到与Python中等效功能相似的API。例如,据我所知,sklearnRandomForestClassifierDecisionTreeClassifier都没有绘制树的方法或功能。
我该如何在Python中获得CART或随机森林树的图形表示?

回归曲面(不管它们是什么,我刚刚发现它们)也会很棒。 - orome
1
看起来在Python/scikit中有一种绘制树的方法。http://scikit-learn.org/stable/modules/tree.html - hrbrmstr
3个回答

7
使用export_graphviz函数。
from sklearn.tree import DecisionTreeClassifier, export_graphviz
np.random.seed(0)
X = np.random.randn(10, 4)
y = array(["foo", "bar", "baz"])[np.random.randint(0, 3, 10)]
clf = DecisionTreeClassifier(random_state=42).fit(X, y)
export_graphviz(clf)

现在,dotty tree.dot 应该显示类似于以下内容的东西:

tree visualization

这里有一个笔记本。请注意保留 HTML 标记。

这个括号里的话让我很好奇!链接在哪里?时间范围?软件包? - orome
我无法与GraphViz正常配合。例如,要将导出的dot转换为PNG,需要重新启动我的IPython内核。是否有一种方法可以创建dot、生成PDF或PNG(或SVG),并将其加载到我的IPython笔记本中,以便我可以查看它? - orome
我还希望能够用我的数据中相应的列名替换X[.]。有这样的选项吗? - orome
2
一些后续问题:(1)RandomForestClassifier树怎么样?(2)我如何在您的示例中使用漂亮的字体? - orome
1
@raxacoricofallapatorius (1) 在所有 RandomForestClassifier.estimators_ 上运行导出器 (2) 我认为我在图片上使用了 dot -Tpng 而不是 dotty - Fred Foo
我发布了一个适用于Jupyter的答案:https://dev59.com/Nn7aa4cB1Zd3GeqPoU__#53546533 - Max Ghenis

2

除了这里列出的其他方法外,从scikit-learn 21.0版本开始(大约在2019年5月),可以使用scikit-learn的tree.plot_tree函数,无需依赖graphviz即可使用matplotlib绘制决策树。

import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree

X, y = load_iris(return_X_y=True)

# Make an instance of the Model
clf = DecisionTreeClassifier()

# Train the model on the data
clf.fit(X, y)

fn=['sepal length (cm)','sepal width (cm)','petal length (cm)','petal width (cm)']
cn=['setosa', 'versicolor', 'virginica']

# Setting dpi = 300 to make image clearer than default
fig, axes = plt.subplots(nrows = 1,ncols = 1,figsize = (4,4), dpi=300)

tree.plot_tree(clf,
           feature_names = fn, 
           class_names=cn,
           filled = True);

fig.savefig('imagename.png')

下面的图片是保存的内容。 enter image description here 这段代码是从这篇文章中改编而来的:post

1
这个函数将在Jupyter笔记本中显示图形:
# Imports
from sklearn.tree import DecisionTreeClassifier, export_graphviz
from sklearn.externals.six import StringIO
from IPython.display import Image, display
import pydotplus

def jupyter_graphviz(m, **kwargs):
    dot_data = StringIO()
    export_graphviz(m, dot_data, **kwargs)
    graph = pydotplus.graph_from_dot_data(dot_data.getvalue())  
    display(Image(graph.create_png()))

例如:

import sklearn.datasets as datasets
import pandas as pd

iris = datasets.load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)
y = iris.target
dtree = DecisionTreeClassifier(random_state=42)
dtree.fit(df, y)

jupyter_graphviz(dtree, filled=True, rounded=True, special_characters=True)

Tree visualization

这里有一个笔记本电脑的操作示例,改编自此篇文章


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接