如何解释sklearn决策树中的children_left属性?

3
我正在尝试使用sklearn DecisionTreeClassifier中的'tree_'方法提取最深节点的规则。我很难理解模型中'children_left'和'children_right'数组的含义。有人能帮忙解释一下吗?
estimator = DecisionTreeClassifier(max_depth=4, random_state=0)
estimator.fit(X_train, y_train)
estimator.tree_.children_left

[6] array([ 1,  2,  3,  4,  5, -1, -1,  8, -1, -1, 11, 12, -1, -1, 15, -1, -1,
   18, 19, 20, -1, -1, 23, -1, -1, 26, 27, -1, -1, 30, -1, -1, 33, 34,
   35, 36, -1, -1, 39, -1, -1, 42, 43, -1, -1, 46, -1, -1, 49, 50, 51,
   -1, -1, 54, -1, -1, 57, 58, -1, -1, 61, -1, -1])

tree_model.tree_.children_right

[7] array([32, 17, 10,  7,  6, -1, -1,  9, -1, -1, 14, 13, -1, -1, 16, -1, -1,
   25, 22, 21, -1, -1, 24, -1, -1, 29, 28, -1, -1, 31, -1, -1, 48, 41,
   38, 37, -1, -1, 40, -1, -1, 45, 44, -1, -1, 47, -1, -1, 56, 53, 52,
   -1, -1, 55, -1, -1, 60, 59, -1, -1, 62, -1, -1])

在Sklearn的示例中,http://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html,它说:
`# The decision estimator has an attribute called tree_  which stores the    entire
# tree structure and allows access to low level attributes. The binary tree
# tree_ is represented as a number of parallel arrays. The i-th element of  each
# array holds information about the node `i`. Node 0 is the tree's root. NOTE:
# Some of the arrays only apply to either leaves or split nodes, resp.`

但是它并没有解释 children_left 数组中数字的含义。
3个回答

6
from sklearn.datasets import load_iris
from sklearn import tree
iris = load_iris()
clf = tree.DecisionTreeClassifier()
clf = clf.fit(iris.data, iris.target)
children_left = clf.tree_.children_left
print (children_left)

它会打印出以下内容:
[ 1 -1  3  4  5 -1 -1  8 -1 10 -1 -1 13 14 -1 -1 -1]

您可以在Google上找到17个使用鸢尾花数据的节点决策树。查看并将其与解释进行比较。
现在开始解释:
- 它只代表节点的左子节点。 - 如果值为-1,则表示该节点没有左子节点。这是该决策树的叶节点。在这里,我们可以看到有9个叶节点。 - 如果值是> 0,则它具有左节点。因此,它不是叶节点。这里有8个节点不是叶节点。
- 根节点有一个左节点。它是1。 - 现在1没有任何左节点。所以它是叶子节点。因此为-1。如果一个节点没有左节点,则也会增加节点计数为1。所以现在节点数为2。 - 现在我们回溯到根节点。然后我们转到根的右节点。它有左节点。现在节点计数为3。 - 节点3有另一个左节点。节点计数为4。 - 节点4有另一个左节点。节点计数为5。 - 节点5没有任何左节点。所以它是叶子,并显示为-1。但现在节点计数为6。 - 我们回溯到节点4。我们转到它的右子节点。再次没有任何左子节点。所以它是叶子节点,显示为-1。节点计数为7。 - 我们再次回溯到节点3。我们转到它的右子节点。这个节点有左节点。所以现在节点计数为8。
它继续进行。希望您能理解这个解释。

0

只是想向您展示一个可视化决策树的小技巧。您可以在所选绘图函数(在我的情况下为export_graphviz)中指定参数node_ids = True,它将在树的图像上显示节点ID!

    export_graphviz(clf, out_file=dot_data, node_ids=True, 
            filled=True, rounded=True,
            special_characters=True,feature_names = feature_cols,class_names=['0','1'])

鸢尾花图

!!! :)


0

来自帖子:https://github.com/scikit-learn/scikit-learn/blob/4907029b1ddff16b111c501ad010d5207e0bd177/sklearn/tree/_tree.pyx

   children_left : array of int, shape [node_count]
    children_left[i] holds the node id of the left child of node i.
    For leaves, children_left[i] == TREE_LEAF. Otherwise,
    children_left[i] > i. This child handles the case where
    X[:, feature[i]] <= threshold[i].
children_right : array of int, shape [node_count]
    children_right[i] holds the node id of the right child of node i.
    For leaves, children_right[i] == TREE_LEAF. Otherwise,
    children_right[i] > i. This child handles the case where
    X[:, feature[i]] > threshold[i].

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接