一种选择是将rpart
对象转换为partykit
包中的party
类对象。这提供了一个通用工具箱来处理递归分区。转换很简单:
library("partykit")
(kyph_party <- as.party(kyph_tree))
Model formula:
Number ~ Kyphosis + Age + Start
Fitted party:
[1] root
| [2] Start >= 15.5: 2.933 (n = 15, err = 10.9)
| [3] Start < 15.5
| | [4] Age >= 112.5: 3.714 (n = 14, err = 18.9)
| | [5] Age < 112.5: 5.125 (n = 16, err = 29.8)
Number of inner nodes: 2
Number of terminal nodes: 3
(为确保精确重现,请在运行我的代码之前使用
set.seed(1)
运行您问题中的代码。)
对于这个类的对象,有一些更灵活的方法可以用于
plot()
、
predict()
、
fitted()
等。例如,
plot(kyph_party)
比默认的
plot(kyph_tree)
提供了更多信息的显示。
fitted()
方法提取一个具有拟合节点编号和训练数据上观察到的响应的双列
data.frame
。
kyph_fit <- fitted(kyph_party)
head(kyph_fit, 3)
(fitted) (response)
1 5 6
2 2 2
3 4 3
使用此功能,您可以轻松计算您感兴趣的任何量,例如每个节点内的平均值、中位数或残差平方和。
tapply(kyph_fit[,2], kyph_fit[,1], mean)
2 4 5
2.933333 3.714286 5.125000
tapply(kyph_fit[,2], kyph_fit[,1], median)
2 4 5
3 4 5
tapply(kyph_fit[,2], kyph_fit[,1], function(x) sum((x - mean(x))^2))
2 4 5
10.93333 18.85714 29.75000
除了简单的tapply()
,您可以使用任何其他函数来计算分组统计表。
现在,要了解测试数据TEST_KYPHOSIS
中哪个观察结果对应于树中的哪个节点,您可以简单地使用predict(..., type = "node")
方法:
kyph_pred <- predict(kyph_party, newdata = TEST_KYPHOSIS, type = "node")
head(kyph_pred)
2 3 4 6 7 10
4 4 5 2 2 5
rpart:::pred.rpart(object, rpart:::rpart.matrix(newdata))
会导致一个预测节点? - goldisfine