使用rpart在回归树中搜索相应的节点

8
我很新于R语言,并卡在了一个相当愚蠢的问题上。
我正在使用rpart包来校准回归树,以进行分类和预测。
感谢R,校准部分易于完成且易于控制。
#the package rpart is needed
library(rpart)

# Loading of a big data file used for calibration
my_data <- read.csv("my_file.csv", sep=",", header=TRUE)

# Regression tree calibration
tree <- rpart(Ratio ~ Attribute1 + Attribute2 + Attribute3 + 
                      Attribute4 + Attribute5, 
                      method="anova", data=my_data, 
                      control=rpart.control(minsplit=100, cp=0.0001))

在校准了一棵大决策树之后,我想为给定的数据样本找到某些新数据的相应簇(从而得出预测值)。
predict函数似乎非常适合这个需求。

# read validation data
validationData <-read.csv("my_sample.csv", sep=",", header=TRUE)

# search for the probability in the tree
predict <- predict(tree, newdata=validationData, class="prob")

# dump them in a file
write.table(predict, file="dump.txt") 

然而,使用 predict 方法,我只能获得新元素的预测比率,但我找不到一种方法来获取新元素所属的决策树叶节点
我认为应该很容易获取,因为预测方法必须已经找到了这个叶节点才能返回这个比率。
通过class=参数,可以给预测方法提供几个参数,但是对于回归树,所有参数似乎都返回相同的结果(决策树目标属性的值)。
有人知道如何获得决策树中相应的节点吗?
通过分析使用path.rpart方法的节点,这将有助于我理解结果。

你尝试过使用“str()”探索你的对象吗? - Roman Luštrik
4个回答

13

不幸的是,Benjamin的答案并不起作用:type="vector"仍然返回预测值。

我的解决方案相当笨拙,但我认为没有更好的方法。诀窍是将模型框架中的预测y值替换为相应的节点编号。

tree2 = tree
tree2$frame$yval = as.numeric(rownames(tree2$frame))
predict = predict(tree2, newdata=validationData)

现在predict的输出将是节点编号,而不是预测的y值。

(需要注意的是,在我的案例中以上内容适用于回归树,而不是分类树。对于分类树,您可能需要省略as.numeric或将其替换为as.factor。)


2
您可以使用partykit软件包:
fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)

library("partykit")
fit.party <- as.party(fit)
predict(fit.party, newdata = kyphosis[1:4, ], type = "node")

对于你的例子,只需设置


predict(as.party(tree), newdata = validationData, type = "node")

1

我认为你想要的是 type="vector" 而不是 class="prob"(我认为 class 不是 predict 方法的可接受参数),正如 rpart 文档中所解释的:

如果 type="vector":预测响应的向量。对于回归树,这是节点处的平均响应;对于泊松树,它是估计的响应率;对于分类树,它是预测的类别(作为数字)。


1
  1. treeClust::rpart.predict.leaves(tree, validationData) 返回节点编号。
  2. 同时,tree$where 返回训练集的节点编号。

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接