如何从rpart对象中获取新观测的终端节点?

6

假设我有一个

head(kyphosis)
inTrain <- sample(1:nrow(kyphosis), 45, replace = F)
TRAIN_KYPHOSIS <- kyphosis[inTrain,]
TEST_KYPHOSIS <- kyphosis[-inTrain,]

(kyph_tree <- rpart(Number ~ ., data = TRAIN_KYPHOSIS))

如何从已拟合的TEST_KYPHOSIS对象中获取每个观测值的终端节点?

如何获得汇总信息,例如偏差和预测值,以及与每个测试观测值对应的终端节点?

2个回答

8

rpart实际上具有这个功能,但它没有公开(令人奇怪的是,这是一个相当明显的要求)。

predict_nodes <-
    function (object, newdata, na.action = na.pass) {
        where <-
            if (missing(newdata)) 
                object$where
            else {
                if (is.null(attr(newdata, "terms"))) {
                    Terms <- delete.response(object$terms)
                    newdata <- model.frame(Terms, newdata, na.action = na.action, 
                                           xlev = attr(object, "xlevels"))
                    if (!is.null(cl <- attr(Terms, "dataClasses"))) 
                        .checkMFClasses(cl, newdata, TRUE)
                }
                rpart:::pred.rpart(object, rpart:::rpart.matrix(newdata))
            }
        as.integer(row.names(object$frame))[where]
    }

然后:
> predict_nodes(kyph_tree, TEST_KYPHOSIS)
 [1] 5 3 4 3 3 5 5 3 3 3 3 5 5 4 3 5 4 3 3 3 3 4 3 4 4 5 5 3 4 4 3 5 3 5 5 5

1
为什么 rpart:::pred.rpart(object, rpart:::rpart.matrix(newdata)) 会导致一个预测节点? - goldisfine
2
@goldisfine,因为这是rpart在内部计算预测节点的方式。这个功能在内部使用,但不会被公开。 - VitoshKa
@VitoshKa,感谢您发布解决方案。这是树的基本部分!没有这个部分,它几乎无法使用。 - user1700890

5

一种选择是将rpart对象转换为partykit包中的party类对象。这提供了一个通用工具箱来处理递归分区。转换很简单:

library("partykit")
(kyph_party <- as.party(kyph_tree))

Model formula:
Number ~ Kyphosis + Age + Start

Fitted party:
[1] root
|   [2] Start >= 15.5: 2.933 (n = 15, err = 10.9)
|   [3] Start < 15.5
|   |   [4] Age >= 112.5: 3.714 (n = 14, err = 18.9)
|   |   [5] Age < 112.5: 5.125 (n = 16, err = 29.8)

Number of inner nodes:    2
Number of terminal nodes: 3

(为确保精确重现,请在运行我的代码之前使用set.seed(1)运行您问题中的代码。)
对于这个类的对象,有一些更灵活的方法可以用于plot()predict()fitted()等。例如,plot(kyph_party)比默认的plot(kyph_tree)提供了更多信息的显示。fitted()方法提取一个具有拟合节点编号和训练数据上观察到的响应的双列data.frame
kyph_fit <- fitted(kyph_party)
head(kyph_fit, 3)

  (fitted) (response)
1        5          6
2        2          2
3        4          3

使用此功能,您可以轻松计算您感兴趣的任何量,例如每个节点内的平均值、中位数或残差平方和。

tapply(kyph_fit[,2], kyph_fit[,1], mean)

       2        4        5 
2.933333 3.714286 5.125000 

tapply(kyph_fit[,2], kyph_fit[,1], median)

2 4 5 
3 4 5 

tapply(kyph_fit[,2], kyph_fit[,1], function(x) sum((x - mean(x))^2))

       2        4        5 
10.93333 18.85714 29.75000 

除了简单的tapply(),您可以使用任何其他函数来计算分组统计表。

现在,要了解测试数据TEST_KYPHOSIS中哪个观察结果对应于树中的哪个节点,您可以简单地使用predict(..., type = "node")方法:

kyph_pred <- predict(kyph_party, newdata = TEST_KYPHOSIS, type = "node")
head(kyph_pred)

 2  3  4  6  7 10 
 4  4  5  2  2  5 

1
你的解决方案产生了与 kyph_tree$where 相同的结果,但与下面的 VitoshKa 解决方案所得到的结果不同。 - user1700890
VitoshKa的predict_nodes()解决方案和partykit中的predict(..., type="node")解决方案由于节点ID的分配略有不同,因此不能产生完全相同的标签。但是实际上信息是等价的。请查看:table(predict_nodes(kyph_tree, TEST_KYPHOSIS), predict(kyph_party, newdata = TEST_KYPHOSIS, type = "node"))。由于标签不同,它可能不是对角线,但存在1:1的匹配。但这是因为partykit为递归分区提供了通用解决方案,这些解决方案不特定于rpart - Achim Zeileis
感谢您的回复。我很难理解kyph_tree$where返回的是什么。它似乎不会返回叶节点标签。 - user1700890
实际上,它只是训练数据“TRAIN_KYPHOSIS”上的终端节点标签/ID。例如,请查看“table(kyph_tree$where)” 。如果您将其与“table(fitted(kyph_party)[,1])”或“table(predict(kyph_party, type = "node"))”进行比较,则会看到相同的绝对频率,但可能有不同的标签(取决于树的结构)。 - Achim Zeileis
再次感谢您。我不确定您所说的标签/ID是什么意思。unique(kyph_tree$where)返回3,7,5,6,但是如果您查看生成的树的终端节点,则为:4, 10, 11, 3'。我使用了'fancyRpartPlot(kyph_tree)'来绘制这棵树。 - user1700890
1
好的,$where 中的 ID 对应于 $frame 中的行,即不是用于打印/绘图的相同标签,而是对应于摘要表中的 ID。如果您使用该摘要表中的行名称,则可以获得与打印/绘图中使用的相同标签:table(as.numeric(rownames(kyph_tree$frame))[kyph_tree$where]) - Achim Zeileis

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接