If you only want the leaf node for each sample, you can simply use:
clf.apply(iris.data)
array([1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,
1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,5,5,5,5,5,5,5,5,5,5,5,5,5,
5,5,5,5,5,5,14,5,5,5,5,5,5,10,5,5,5,5,5,10,5,5,5,5,5,5,5,5,5,
5,5,5,5,5,5,16,16,16,16,16,16,6,16,16,16,16,16,16,16,16,16,16,
16,8,16,16,16,16,16,16,15,16,16,11,16,16,16,8,8,16,16,16,15,16,
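16,16,16,16,16,16,16,16,16])
If you want all the samples that pass through each node, you can first compute all the decision paths:
dec_paths = clf.decision_path(iris.data)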
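Then loop over the decision paths, convert each one to an array with toarray(), and check whether the sample passes through a given node. Everything is stored in a defaultdict, where the key is the node number and the value is the list of sample indices.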
for d, dec in enumerate(dec_paths):
    for i in range(clf.tree_.node_count):
        if dec.toarray()[0][i] == 1:
            samples[i].append(d)
Complete code:
import sklearn.datasets
import sklearn.tree
import collections
clf = sklearn.tree.DecisionTreeClassifier(random_state=42)
iris = sklearn.datasets.load_iris()
clf = clf.fit(iris.data, iris.target)
# Maps node number -> list of sample indices that pass through that node
samples = collections.defaultdict(list)
# Sparse indicator matrix of shape (n_samples, n_nodes)
dec_paths = clf.decision_path(iris.data)
for d, dec in enumerate(dec_paths):
    for i in range(clf.tree_.node_count):
        # Entry is 1 if sample d passes through node i
        if dec.toarray()[0][i] == 1:
            samples[i].append(d)
Output:
print(samples[13])
[70, 126, 138]
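The nested loop above converts every row to a dense array once per node, which can get slow on larger datasets. A possible alternative, shown here only as a sketch, reads the nonzero entries of the sparse matrix directly. The name samples_by_node is made up for this example, and the per-node lists are sorted explicitly because the code does not rely on the order in which nonzero() returns entries.

```python
import collections

import sklearn.datasets
import sklearn.tree

iris = sklearn.datasets.load_iris()
clf = sklearn.tree.DecisionTreeClassifier(random_state=42)
clf = clf.fit(iris.data, iris.target)

# Sparse indicator matrix of shape (n_samples, n_nodes)
dec_paths = clf.decision_path(iris.data)

# nonzero() returns the row (sample) and column (node) index of every
# nonzero entry, so no dense conversion is needed.
sample_idx, node_idx = dec_paths.nonzero()

# samples_by_node is a hypothetical name: node number -> sample indices
samples_by_node = collections.defaultdict(list)
for s, n in zip(sample_idx, node_idx):
    samples_by_node[int(n)].append(int(s))

# Sort explicitly rather than assume an ordering from nonzero()
for n in samples_by_node:
    samples_by_node[n].sort()

print(samples_by_node[13])
```

This should build the same mapping as the loop version, with the root node (node 0) containing every sample.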
Call clf.decision_path(my_test_samples) to get the decision paths for those samples. - Maximilian Peters
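As a rough illustration of that comment, the sketch below fits the tree on a training split and then asks for the decision paths of held-out samples. The 80/20 split and the name node_ids_per_sample are assumptions made for this example and are not part of the original answer. It reads the node numbers for each sample from the indices and indptr arrays of the CSR matrix.

```python
import sklearn.datasets
import sklearn.model_selection
import sklearn.tree

iris = sklearn.datasets.load_iris()

# Assumed 80/20 split; my_test_samples plays the role of unseen data
X_train, my_test_samples, y_train, _ = sklearn.model_selection.train_test_split(
    iris.data, iris.target, test_size=0.2, random_state=42
)

clf = sklearn.tree.DecisionTreeClassifier(random_state=42)
clf = clf.fit(X_train, y_train)

# One sparse row per test sample, one column per tree node
test_paths = clf.decision_path(my_test_samples)
# Leaf node reached by each test sample
test_leaves = clf.apply(my_test_samples)

# In CSR format, indices[indptr[i]:indptr[i + 1]] holds the column
# (node) numbers of the nonzero entries in row i.
node_ids_per_sample = [
    test_paths.indices[test_paths.indptr[i]:test_paths.indptr[i + 1]].tolist()
    for i in range(test_paths.shape[0])
]

for i in range(3):
    print(i, node_ids_per_sample[i], "leaf:", int(test_leaves[i]))
```

Each path should include the root node 0 and the leaf reported by clf.apply for the same sample.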