获取rpart节点(即:CART)中的观测数据

8

我希望能够查看到达rpart决策树某个节点的所有观测数据。例如,在以下代码中:

fit <- rpart(Kyphosis ~ Age + Start, data = kyphosis)
fit

n= 81 

node), split, n, loss, yval, (yprob)
      * denotes terminal node

 1) root 81 17 absent (0.79012346 0.20987654)  
   2) Start>=8.5 62  6 absent (0.90322581 0.09677419)  
     4) Start>=14.5 29  0 absent (1.00000000 0.00000000) *
     5) Start< 14.5 33  6 absent (0.81818182 0.18181818)  
      10) Age< 55 12  0 absent (1.00000000 0.00000000) *
      11) Age>=55 21  6 absent (0.71428571 0.28571429)  
        22) Age>=111 14  2 absent (0.85714286 0.14285714) *
        23) Age< 111 7  3 present (0.42857143 0.57142857) *
   3) Start< 8.5 19  8 present (0.42105263 0.57894737) *

我希望能够查看节点(5)中的所有观测值(即:开始时间大于等于8.5且小于14.5的33个观测值)。显然,我可以手动获取它们。但我希望有一个类似于“get_node_date”的函数,只需运行get_node_date(5)即可获取相关的观测值。
你有什么建议吗?
6个回答

5

似乎没有这样的功能可以从特定节点中提取观察结果。我会按照以下方式解决:首先确定您感兴趣的节点使用的规则是哪个/哪些。您可以使用 path.rpart 进行操作。然后,您可以逐一应用规则以提取观测结果。

此方法作为一个功能:

get_node_date <- function(tree = fit, node = 5){
  rule <- path.rpart(tree, node)
  rule_2 <- sapply(rule[[1]][-1], function(x) strsplit(x, '(?<=[><=])(?=[^><=])|(?<=[^><=])(?=[><=])', perl = TRUE))
  ind <- apply(do.call(cbind, lapply(rule_2, function(x) eval(call(x[2], kyphosis[,x[1]], as.numeric(x[3]))))), 1, all)
  kyphosis[ind,]
  }

对于节点5,您将获得:
get_node_date()

 node number: 5 
   root
   Start>=8.5
   Start< 14.5
   Kyphosis Age Number Start
2    absent 158      3    14
10  present  59      6    12
11  present  82      5    14
14   absent   1      4    12
18   absent 175      5    13
20   absent  27      4     9
23  present  96      3    12
26   absent   9      5    13
28   absent 100      3    14
32   absent 125      2    11
33   absent 130      5    13
35   absent 140      5    11
37   absent   1      3     9
39   absent  20      6     9
40  present  91      5    12
42   absent  35      3    13
46  present 139      3    10
48   absent 131      5    13
50   absent 177      2    14
51   absent  68      5    10
57   absent   2      3    13
59   absent  51      7     9
60   absent 102      3    13
66   absent  17      4    10
68   absent 159      4    13
69   absent  18      4    11
71   absent 158      5    14
72   absent 127      4    12
74   absent 206      4    10
77  present 157      3    13
78   absent  26      7    13
79   absent 120      2    13
81   absent  36      4    13

2
在rpart中,可以通过$where获得训练观测的终端节点分配。
fit <- rpart(Kyphosis ~ Age + Start, data = kyphosis)
fit$where

作为一个函数:
get_node <- function(rpart.object=fit, data=kyphosis, node.number=5) {
  data[which(fit$where == node.number),]  
}
get_node()

这仅适用于训练观测数据,而不适用于新的观测数据。也不适用于内部节点。

1
这仅适用于树的终端节点。问题要求非终端节点。 - riccardo-df
1
@riccardo-df 确实如此!我已经调整了答案。我仍然保留了答案,因为一些用户可能只需要终端节点,这只涉及有限的代码量。更高级别的答案显然提供了更全面的答案。 - Marjolein Fokkema

1
另一种方法是通过查找任何特定节点的所有终端节点,并返回调用中使用的数据子集来实现的。
fit <- rpart(Kyphosis ~ Age + Start, data = kyphosis)

head(subset.rpart(fit, 5))
#    Kyphosis Age Number Start
# 2    absent 158      3    14
# 10  present  59      6    12
# 11  present  82      5    14
# 14   absent   1      4    12
# 18   absent 175      5    13
# 20   absent  27      4     9


subset.rpart <- function(tree, node = 1L) {
  data <- eval(tree$call$data, parent.frame(1L))
  wh <- sapply(as.integer(rownames(tree$frame)), parent)
  wh <- unique(unlist(wh[sapply(wh, function(x) node %in% x)]))
  data[rownames(tree$frame)[tree$where] %in% wh[wh >= node], ]
}

parent <- function(x) {
  if (x[1] != 1)
    c(Recall(if (x %% 2 == 0L) x / 2 else (x - 1) / 2), x) else x
}

1

“partykit”包也提供了一个现成的解决方案。您只需要将“rpart”对象转换为“party”类,以便使用其统一的接口来处理树。然后,您可以使用“data_party()”函数。

使用问题中的“fit”,并加载“library(“partykit”)”,您可以首先将“rpart”树强制转换为“party”:

pfit <- as.party(fit)
plot(pfit)

full pfit tree

提取您想要的数据只有两个小问题:(1)原始拟合的model.frame()在强制转换中总是被删除,需要手动重新附加。(2)节点使用不同的编号方案。现在您想要节点4(而不是5)。

pfit$data <- model.frame(fit)
data4 <- data_party(pfit, 4)
dim(data4)
## [1] 33  5
head(data4)
##    Kyphosis Age Start (fitted) (response)
## 2    absent 158    14        7     absent
## 10  present  59    12        8    present
## 11  present  82    14        8    present
## 14   absent   1    12        5     absent
## 18   absent 175    13        7     absent
## 20   absent  27     9        5     absent

另一种方法是对从节点4开始的子树进行子集,然后获取其中的数据:

pfit4 <- pfit[4]
plot(pfit4)

subtree of pfit from node 4

然后,data_party(pfit4)将给你与上面的data4相同的结果。而pfit4$data将给你没有(fitted)节点和预测的(response)的数据。


如果您使用了 ptree$data <- model.frame(eval(tree$call$data)),那么在公式中没有使用的变量将不会被删除。 - rawr
如果data包含了formula中的所有变量,那么返回True,但这并不一定是情况。使用model.frame()函数,您还可以获得转换后的变量,例如log()Surv()factor()版本的变量,这些变量通常是即时创建的。 - Achim Zeileis
顺便说一句:对于 rpart 对象,as.party() 强制转换现在默认 _保留数据_! 因此,您可以执行 as.party(fit, data = TRUE)(这是新的默认设置)或 as.party(fit, data = FALSE)(与旧行为相对应)。 - Achim Zeileis

0

另一种方法是从给定节点 n 找到所有子节点。 我们可以使用 rpart 对象来找到它们。将这些信息与数据集中每个点的终节点(在此问题中为驼背症)结合起来, 可以通过 fit$where 获得,如 @rawar 所解释, 您可以获取涉及给定节点的数据集中的所有点,不一定是终节点。

步骤摘要:

  1. 找到节点编号并识别那些是末端节点("叶子节点")。这些信息可以在rpart对象的frame元素中找到。
  2. 计算给定节点n的所有子节点。可以使用递归计算,因为节点n的子节点是2*n2*n+1,详见rpart.plot包的vignette第26页。
  3. 一旦知道了从节点n下垂的叶子,就可以使用rpart对象的where元素选择数据集中的那些叶子节点中的数据点。

我在函数get_children_nodes()中编写了步骤1和2,并在函数get_node_data()中编写了步骤3来回答提出的问题。在这个函数中,我已经包括了打印相应节点规则(rule = TRUE)以获得与@datamineR相同的答案的可能性。

library(rpart)
library(rpart.plot)

fit <- rpart(Kyphosis ~ Age + Start, data = kyphosis)

get_children_nodes <- function(tree, node){
  # check if node is a leaf based in rpart object (tree) information (step 1)
  z <- tree$frame
  is_leaf <- z$var == "<leaf>"
  nodes <- as.integer(row.names(z))
  
  # find recursively all children nodes (step 2)
  find_children <- function(node, nodes, is_leaf){
    condition <- is_leaf[nodes == node]
    if (condition) {
      # If node is leaf, return it
      v1 <- node
    } else {
      # If node is not leaf, search children leaf recursively
      v1 <- c(find_children(2 * node, nodes, is_leaf), 
              find_children(2 * node + 1, nodes, is_leaf))
    } 
    return(v1)
  }
  
  return(find_children(node, nodes, is_leaf))
}

get_node_data <- function(dataset, tree, node, rule = FALSE) {
  # Find children nodes of the node
  children_nodes <- get_children_nodes(tree, node)
  # match those nodes into the rpart node identification
  id_nodes <- which(as.integer(row.names(tree$frame)) %in% children_nodes)
  # Get the elements in the datset involved in the node (step 3)
  filtered_dataset <- dataset[tree$where %in% id_nodes, ]
  
  # print the node rule if needed
  if(rule) {
    rpart::path.rpart(tree, node, pretty = TRUE)
    cat("  \n")
  }
  return( filtered_dataset)
}

# Get the children nodes
get_children_nodes(fit, 5)
#> [1] 10 22 23

# Complete function to return the elements of node 5
get_node_data(kyphosis, fit, 5, rule = TRUE) 
#> 
#>  node number: 5 
#>    root
#>    Start>=8.5
#>    Start< 14.5
#> 
#>    Kyphosis Age Number Start
#> 2    absent 158      3    14
#> 10  present  59      6    12
#> 11  present  82      5    14
#> 14   absent   1      4    12
#> 18   absent 175      5    13
#> 20   absent  27      4     9
#> 23  present  96      3    12
#> 26   absent   9      5    13
#> 28   absent 100      3    14
#> 32   absent 125      2    11
#> 33   absent 130      5    13
#> 35   absent 140      5    11
#> 37   absent   1      3     9
#> 39   absent  20      6     9
#> 40  present  91      5    12
#> 42   absent  35      3    13
#> 46  present 139      3    10
#> 48   absent 131      5    13
#> 50   absent 177      2    14
#> 51   absent  68      5    10
#> 57   absent   2      3    13
#> 59   absent  51      7     9
#> 60   absent 102      3    13
#> 66   absent  17      4    10
#> 68   absent 159      4    13
#> 69   absent  18      4    11
#> 71   absent 158      5    14
#> 72   absent 127      4    12
#> 74   absent 206      4    10
#> 77  present 157      3    13
#> 78   absent  26      7    13
#> 79   absent 120      2    13
#> 81   absent  36      4    13

创建于2023-08-14,使用reprex v2.0.2生成


0

rpart返回rpart.object元素,其中包含您需要的信息:

require(rpart)
fit2 <- rpart(Kyphosis ~ Age + Start, data = kyphosis)
fit2

get_node_date <-function(nodeId,fit)
{  
  fit$frame[toString(nodeId),"n"]
}


for (i in c(1,2,4,5,10,11,22,23,3) )
  cat(get_node_date(i,fit2),"\n")

1
你只能得到落入某个类别的观测数量,而不能得到这些观测本身。 - DatamineR

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接