获取rpart模型节点的id/name

7
如何获取每行模型的终端节点的ID(或名称)?predict.rpart只能返回分类树的预测类别(数字或因子)、类别概率或某些组合(使用type="matrix")。 我想做如下操作:
fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)
plot(fit) # there are 5 terminal nodes
predict(fit, type = "node_id")   # should return IDs of terminal nodes (e.g. 1-5) (does not work)
4个回答

7

partykit包支持predict(..., type = "node"),无论是在样本内还是样本外。你可以简单地转换rpart对象以使用它:

library("partykit")
predict(as.party(fit), type = "node")  
## 9 7 9 9 3 3 3 3 3 8 8 3 9 5 3 3 3 7 3 5 3 9 8 9 9 5 9 8 3 3 3 7 7 3 7 3 5 ## 9 5 8 
## 9 7 9 9 3 3 3 3 3 8 8 3 9 5 3 3 3 7 3 5 3 9 8 9 9 5 9 8 3 3 3 7 7 3 7 3 5 ## 9 5 8 
## 9 5 9 9 3 7 3 7 9 7 8 3 9 3 3 3 5 9 5 8 9 9 9 3 3 5 3 7 5 3 7 7 3 7 3 3 7 ## 5 7 9 
## 9 5 9 9 3 7 3 7 9 7 8 3 9 3 3 3 5 9 5 8 9 9 9 3 3 5 3 7 5 3 7 7 3 7 3 3 7 ## 5 7 9 
## 5 
## 5 
table(predict(as.party(fit), type = "node")) 
##  3  5  7  8  9 
## 29 12 14  7 19 

6

对于这个模型,有4个分割,产生了5个“终端节点”,或者用rpart术语来说:<leaf>。我不明白为什么会有5个预测。预测是针对特定情况的,而叶子是用于作出这些预测所需的可变数量的分割的结果。原始数据集中最终落入叶子的行数可能是您想要的内容,在这种情况下,以下是获取这些数字的方法:

# Row-wise predicted class
fit$where

# counts of cases in leaves of prediction rules
table(fit$where)
 3  5  7  8  9 
29 12 14  7 19 

为了组装适用于特定叶子的标签(fit),您需要遍历规则树并累积应用于生成特定叶子的所有分裂的标签。您可能需要查看以下内容:
?print.rpart    
?rpart.object
?text.rpart
?labels.rpart

1
谢谢,我所说的1-5是指终端节点的ID。您的答案有效,我可以简单地使用kyphosis["id_node"] <-fit$where将叶子ID分配给原始数据框。 - Tomas Greif

3
使用$where方法只会在树框架中弹出行号。因此,当使用kyphosis$ID = fit$where时,一些观测可能会被分配为节点ID而不是叶节点ID。要获取实际的叶节点ID,请使用以下方法:
MyID <- row.names(fit$frame)
kyphosis$ID <- MyID[fit$where]

0

对于预测新数据上的叶子节点,可以使用来自rpart.plot包的rpart.predict(fit, newdata, nn = TRUE)将节点名称添加到输出中。

这是一个独立的rpart叶子节点预测器:

rpart_leaves <- function(fit, newdata, type = c("where", "leaf"), na.action = na.pass) {
  if (is.null(attr(newdata, "terms"))) {
    Terms <- delete.response(fit$terms)
    newdata <- model.frame(Terms, newdata, na.action = na.action,
                           xlev = attr(fit, "xlevels"))
    if (!is.null(cl <- attr(Terms, "dataClasses")))
      .checkMFClasses(cl, newdata, TRUE)
  }
  newdata <- rpart:::rpart.matrix(newdata)
  where <- unname(rpart:::pred.rpart(fit, newdata))
  if (match.arg(type) == "where")
    return(where)
  rownames(fit$frame)[where]
}

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接