有没有办法在决策树的每个叶子节点下获取样本?

8

我已经使用数据集训练了一棵决策树。现在我想看看哪些样本属于树的哪个叶子节点。

从这里开始,我想要红色圆圈中的样本。

enter image description here

我正在使用Python的Sklearn实现决策树。


1
这个:https://dev59.com/5FwY5IYBdhLWcg3wh4Pc 和这个:https://dev59.com/E2Ij5IYBdhLWcg3weFCq#42227468 可能是相关的。 - Miriam Farber
左上角的叶子是故意省略的吗? - Maximilian Peters
1个回答

12

如果您只想要每个样本的叶子节点,您可以直接使用:

clf.apply(iris.data)
array([1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,5,5,5,5,5,5,5,5,5,5,5,5,5, 5,5,5,5,5,5,14,5,5,5,5,5,5,10,5,5,5,5,5,10,5,5,5,5,5,5,5,5,5, 5,5,5,5,5,5,16,16,16,16,16,16,6,16,16,16,16,16,16,16,16,16,16, 16,8,16,16,16,16,16,16,15,16,16,11,16,16,16,8,8,16,16,16,15,16, 16,16,16,16,16,16,16,16,16])
dec_paths = clf.decision_path(iris.data)

然后循环遍历决策路径,使用 toarray() 将其转换为数组,并检查它们是否属于某个节点。所有内容都存储在 defaultdict 中,其中键是节点编号,值是样本编号。

for d, dec in enumerate(dec_paths):
    for i in range(clf.tree_.node_count):
        if dec.toarray()[0][i] == 1:
            samples[i].append(d)

完整的代码

import sklearn.datasets
import sklearn.tree
import collections

clf = sklearn.tree.DecisionTreeClassifier(random_state=42)
iris = sklearn.datasets.load_iris()
clf = clf.fit(iris.data, iris.target)

samples = collections.defaultdict(list)
dec_paths = clf.decision_path(iris.data)

for d, dec in enumerate(dec_paths):
    for i in range(clf.tree_.node_count):
        if dec.toarray()[0][i] == 1:
            samples[i].append(d) 

输出

print(samples[13])

[70, 126, 138]


1
@AlaaM. 你可以运行 clf.decision_path(my_test_samples) 并获得这些样本的决策路径。 - Maximilian Peters
1
@AlaaM 请看这个答案:https://dev59.com/gqDia4cB1Zd3GeqPGp4q#43218264,如果您传入一个样本,您可以将所有只有一个样本的节点着色,并且您可以可视化该特定样本的决策。 - Maximilian Peters
嗨!我正在尝试你的代码,但是我的绘制的随机森林决策树在节点0上有180个样本和29个节点。而你的代码返回23个样本[i]和651个样本。我错过了什么吗? - Noob Programmer
@NoobProgrammer:你能否开一个新问题,把你的代码放上去?我还没有尝试过随机森林的代码。 - Maximilian Peters
@MaximilianPeters https://stackoverflow.com/questions/69852142/extracting-samples-indices-of-decision-trees-in-random-forest - Noob Programmer
显示剩余6条评论

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接