Scikit-Learn决策树导出Graphviz图形 - 决策树中的类名错误

4

我从"scikit learn/decision tree/export graphviz"中得到的决策树类名不正确。代码如下:

import matplotlib.pyplot as plt
import matplotlib.image as img
import pydot
from sklearn import tree

digital_table = [[0, 0], [0, 1], [1, 0], [1, 1]]
digital_label = ['zero', 'one', 'two', 'three']
digital_name = ['idx-1', 'idx-2']

digital_tree = tree.DecisionTreeClassifier()
digital_tree.fit(digital_table, digital_label)

with open("digital.dot", 'w') as f:
    f = tree.export_graphviz(digital_tree, 
                            feature_names=digital_name,
                            class_names=digital_label,
                            filled=True, rounded=True,
                            out_file=f)
(graph,) = pydot.graph_from_dot_file("digital.dot")
graph.write_png("digital.png")

plt.imshow(img.imread('digital.png'))
plt.show()

输出如下: 决策树 问题在于叶子节点中显示的类名。例如,如果idx-1为1且idx-2为1,则绿色框应标记为“three”。但是,图像显示标签为“one”。有人可以发表评论吗?

如果有人还有疑问,那么问题在于该模型是针对字符串标签 ['zero', 'one', 'two', 'three'] 进行训练的。函数不知道哪个应该被称为零,哪个应该被称为一。因此,它最终按字母顺序使用它们,one 变成了 0,three 变成了 2,以此类推。处理这个问题的最佳方法是将标签转换为整数类,如 [0, 1, 2, 3] - betelgeuse
2个回答

4
当您使用DecisionTreeClassifier时,应该将类标签更改为数字,例如0、1、2。
然后使用:
classe_names = decision_tree_classifier.classes_

它将按升序给出类的标签。然后以相同的顺序指定您的class_label。它可以是字符串。


-2

在将类标签传递给export_graphviz之前,尝试按字母顺序对其进行排序。


1
谢谢您的评论。但是,我认为表格元素的顺序和标签元素的顺序应该保持同步,对吗? - Frank

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接