我已经使用scikit-learn训练了一个随机森林模型,现在我想要将其树结构保存到文本文件中,以便在其他地方使用。根据此链接,树对象由许多并行数组组成,每个数组都包含有关树的不同节点的一些信息(例如左子节点、右子节点、检查哪个特征等)。但是,在该链接提供的示例中甚至没有提到每个叶节点对应的类标签的任何信息!有人知道scikit-learn决策树结构中存储类标签的位置吗?
sklearn.tree.DecisionTreeClassifier.tree_.value
的文档:from sklearn.datasets import load_iris
from sklearn.cross_validation import cross_val_score
from sklearn.tree import DecisionTreeClassifier
clf = DecisionTreeClassifier(random_state=0)
iris = load_iris()
clf.fit(iris.data, iris.target)
print(clf.classes_)
[0, 1, 2]
print(clf.tree_.value)
[[[ 50. 50. 50.]]
[[ 50. 0. 0.]]
[[ 0. 50. 50.]]
[[ 0. 49. 5.]]
[[ 0. 47. 1.]]
[[ 0. 47. 0.]]
[[ 0. 0. 1.]]
[[ 0. 2. 4.]]
[[ 0. 0. 3.]]
[[ 0. 2. 1.]]
[[ 0. 2. 0.]]
[[ 0. 0. 1.]]
[[ 0. 1. 45.]]
[[ 0. 1. 2.]]
[[ 0. 0. 2.]]
[[ 0. 1. 0.]]
[[ 0. 0. 43.]]]
clf.tree_.value
中“包含每个节点的常数预测值”,(help(clf.tree_)
)这与clf.classes_
按索引对应。参见此答案以获取(仅有的)更多细节。
clf.classes_[np.argmax(value)]
来获取预测的类别标签。 - Vivek Kumar