scikit-learn在其树结构中保存每个叶节点的决策标签放在哪里?

10
我已经使用scikit-learn训练了一个随机森林模型,现在我想要将其树结构保存到文本文件中,以便在其他地方使用。根据此链接,树对象由许多并行数组组成,每个数组都包含有关树的不同节点的一些信息(例如左子节点、右子节点、检查哪个特征等)。但是,在该链接提供的示例中甚至没有提到每个叶节点对应的类标签的任何信息!有人知道scikit-learn决策树结构中存储类标签的位置吗?
1个回答

7
请参阅 sklearn.tree.DecisionTreeClassifier.tree_.value 的文档:
from sklearn.datasets import load_iris
from sklearn.cross_validation import cross_val_score
from sklearn.tree import DecisionTreeClassifier

clf = DecisionTreeClassifier(random_state=0)
iris = load_iris()

clf.fit(iris.data, iris.target)

print(clf.classes_)

[0, 1, 2]

print(clf.tree_.value)

[[[ 50.  50.  50.]]

 [[ 50.   0.   0.]]

 [[  0.  50.  50.]]

 [[  0.  49.   5.]]

 [[  0.  47.   1.]]

 [[  0.  47.   0.]]

 [[  0.   0.   1.]]

 [[  0.   2.   4.]]

 [[  0.   0.   3.]]

 [[  0.   2.   1.]]

 [[  0.   2.   0.]]

 [[  0.   0.   1.]]

 [[  0.   1.  45.]]

 [[  0.   1.   2.]]

 [[  0.   0.   2.]]

 [[  0.   1.   0.]]

 [[  0.   0.  43.]]]

每行在clf.tree_.value中“包含每个节点的常数预测值”,(help(clf.tree_))这与clf.classes_按索引对应。参见此答案以获取(仅有的)更多细节。

7
除了上述答案,对于数组中的每一行,您可以执行 clf.classes_[np.argmax(value)] 来获取预测的类别标签。 - Vivek Kumar
@not_a_robot 谢谢。你解释得很好。然而,我仍然找不到 clf.tree_.value 在文档中的提及。我猜我不再需要它,因为你的答案正是我要找的。 - whoAmI
1
只是另一个快速问题。看起来 clf.classes_ 给我标签 [0,...,n-1],无论我使用什么标签。我是对的吗?在我的情况下,我期望是 [1,...,n]。 - whoAmI
1
我相信标签是从零开始索引的,这就是为什么它是[0,n-1]。 - blacksite

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接