在R中,我可以使用API直接绘制与CART模型对应的决策树的图形表示。例如,prp
将生成类似于以下内容:
sklearn
的RandomForestClassifier
和DecisionTreeClassifier
都没有绘制树的方法或功能。我该如何在Python中获得CART或随机森林树的图形表示?
在R中,我可以使用API直接绘制与CART模型对应的决策树的图形表示。例如,prp
将生成类似于以下内容:
sklearn
的RandomForestClassifier
和DecisionTreeClassifier
都没有绘制树的方法或功能。export_graphviz
函数。from sklearn.tree import DecisionTreeClassifier, export_graphviz
np.random.seed(0)
X = np.random.randn(10, 4)
y = array(["foo", "bar", "baz"])[np.random.randint(0, 3, 10)]
clf = DecisionTreeClassifier(random_state=42).fit(X, y)
export_graphviz(clf)
dotty tree.dot
应该显示类似于以下内容的东西:
这里有一个笔记本。请注意保留 HTML 标记。X[.]
。有这样的选项吗? - oromeRandomForestClassifier
树怎么样?(2)我如何在您的示例中使用漂亮的字体? - oromeRandomForestClassifier.estimators_
上运行导出器 (2) 我认为我在图片上使用了 dot -Tpng
而不是 dotty
。 - Fred Foo除了这里列出的其他方法外,从scikit-learn 21.0版本开始(大约在2019年5月),可以使用scikit-learn的tree.plot_tree函数,无需依赖graphviz即可使用matplotlib绘制决策树。
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree
X, y = load_iris(return_X_y=True)
# Make an instance of the Model
clf = DecisionTreeClassifier()
# Train the model on the data
clf.fit(X, y)
fn=['sepal length (cm)','sepal width (cm)','petal length (cm)','petal width (cm)']
cn=['setosa', 'versicolor', 'virginica']
# Setting dpi = 300 to make image clearer than default
fig, axes = plt.subplots(nrows = 1,ncols = 1,figsize = (4,4), dpi=300)
tree.plot_tree(clf,
feature_names = fn,
class_names=cn,
filled = True);
fig.savefig('imagename.png')
# Imports
from sklearn.tree import DecisionTreeClassifier, export_graphviz
from sklearn.externals.six import StringIO
from IPython.display import Image, display
import pydotplus
def jupyter_graphviz(m, **kwargs):
dot_data = StringIO()
export_graphviz(m, dot_data, **kwargs)
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
display(Image(graph.create_png()))
例如:
import sklearn.datasets as datasets
import pandas as pd
iris = datasets.load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)
y = iris.target
dtree = DecisionTreeClassifier(random_state=42)
dtree.fit(df, y)
jupyter_graphviz(dtree, filled=True, rounded=True, special_characters=True)