决策树重复类名问题

3
我有一个非常简单的数据/标签样本,我的问题是生成的决策树(pdf)重复了类名:
from sklearn import tree
from sklearn.externals.six import StringIO  
import pydotplus

features_names = ['weight', 'texture']
features = [[140, 1], [130, 1], [150, 0], [110, 0]]
labels = ['apple', 'apple', 'orange', 'orange']

clf = tree.DecisionTreeClassifier()
clf.fit(features, labels)

dot_data = StringIO()
tree.export_graphviz(clf, out_file=dot_data, 
                         feature_names=features_names,  
                         class_names=labels,  
                         filled=True, rounded=True,  
                         special_characters=True,
                         impurity=False)

graph = pydotplus.graph_from_dot_data(dot_data.getvalue()) 
graph.write_pdf("apples_oranges.pdf")

生成的PDF文件如下:

enter image description here

所以,问题很明显,两种可能性都是苹果。我做错了什么?
来自 DOCS
列表字符串、布尔值或None,可选(默认为None) 每个目标类别的名称按升序排列。仅适用于分类,不支持多输出。如果为True,则显示类名的符号表示。
“…按升序排列”的意思对我来说不太清楚,如果我将kwarg更改为:
class_names=sorted(labels)

结果是一样的(在这种情况下很明显)。

类名字面上就是类的名称。它不是每个示例的标签。因此,类0是“apple”,类1是“orange”,所以我想你只需要传入['apple', 'orange']即可。 - Ken Syme
尝试使用class_names=unique(labels, 'stable') - Dan
@KenSyme,谢谢你。我想这个问题必须要排序对吧?就像 sorted(set(labels)) 这样,因为如果我不这样做的话,它会显示错误的位置(交换)。如果你愿意,你可以回答,我会尽快接受它。 - Hula Hula
2个回答

2
类名字面上就是类的名称。它不是每个示例的标签。
因此,一个类是“apple”,另一个类是“orange”,所以您只需要传入['apple','orange']
关于顺序,为了使其正确一致,您可以使用LabelEncoder将目标转换为整数int_labels = labelEncoder.fit_transform(labels),使用int_labels来拟合决策树,然后使用labelEncoder.classes_属性传递到您的图形可视化工具中。

0

类名应该是您的标签名称集合,并按升序传递。您可以直接这样做

labels_set = sorted(labels.unique())

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接