使用R中'rpart'包中的生存树来预测新观测数据。

14
我试图使用R中的“rpart”包来构建生存树,并希望使用此树来预测其他观测值。我知道有很多关于rpart和预测的SO问题,但是我没有找到任何一个解决使用“Surv”对象与rpart一起使用时(我认为)特定的问题。
我的特定问题涉及解释“predict”函数的结果。以下是一个示例:
library(rpart)
library(OIsurv)

# Make Data:
set.seed(4)
dat = data.frame(X1 = sample(x = c(1,2,3,4,5), size = 1000, replace=T))
dat$t = rexp(1000, rate=dat$X1)
dat$t = dat$t / max(dat$t)
dat$e = rbinom(n = 1000, size = 1, prob = 1-dat$t )

# Survival Fit:
sfit = survfit(Surv(t, event = e) ~ 1, data=dat)
plot(sfit)

# Tree Fit:
tfit = rpart(formula = Surv(t, event = e) ~ X1 , data = dat, control=rpart.control(minsplit=30, cp=0.01))
plot(tfit); text(tfit)

# Survival Fit, Broken by Node in Tree:
dat$node = as.factor(tfit$where)
plot( survfit(Surv(dat$t, event = dat$e)~dat$node) )

到目前为止一切都很好。我对这里正在发生的事情的理解是rpart试图将指数生存曲线拟合到我的数据子集上。基于这个理解,我认为当我调用predict(tfit)时,对于每个观察值,我会得到一个与该观察值的指数曲线相关的参数。例如,如果predict(fit)[1]是0.46,则这意味着对于原始数据集中的第一个观测值,曲线由方程P(s) = exp(−λt)给出,其中λ=0.46
这似乎正是我想要的。对于每个观察值(或任何新的观察值),我可以获得在给定时间点预测该观察值存活/死亡的概率。(编辑:我意识到这可能是一个误解——这些曲线并不给出存活/死亡的概率,而是在一段时间间隔内生存下来的概率。不过,这并不影响下面描述的问题。) 然而,当我尝试使用指数公式时...
# Predict:
# an attempt to use the rates extracted from the tree to
# capture the survival curve formula in each tree node.
rates = unique(predict(tfit))
for (rate in rates) {
  grid= seq(0,1,length.out = 100)
  lines(x= grid, y= exp(-rate*(grid)), col=2)
}

绘图

我所做的是按照生存树的方式将数据集分割,然后使用survfit为每个分区绘制非参数曲线。这就是黑色线条。我还画了对应于将“速率”参数(我认为)插入到(我认为)生存指数公式中的结果的线条。

我知道非参数拟合和参数拟合不一定相同,但这似乎不仅如此:我似乎需要缩放我的X变量或其他操作。

基本上,我似乎不理解rpart/survival在内部使用的公式。有谁能帮我从(1)rpart模型转换到(2)任意观测值的生存方程?

2个回答

13
存活数据在内部以指数形式缩放,因此根节点中预测的速率始终固定为1.000predict()方法报告的预测值总是相对于根节点中的生存情况,即高于或低于某个因子。有关更多详细信息,请参见vignette(“longintro”,package =“rpart”)第8.4节。无论如何,您报告的Kaplan-Meier曲线与rpart文献中报告的内容完全相同。
如果要直接获得树中Kaplan-Meier曲线的图形并获取预测的中位生存时间,则可以将rpart树强制转换为partykit包提供的constparty树。
library("partykit")
(tfit2 <- as.party(tfit))
## Model formula:
## Surv(t, event = e) ~ X1
## 
## Fitted party:
## [1] root
## |   [2] X1 < 2.5
## |   |   [3] X1 < 1.5: 0.192 (n = 213)
## |   |   [4] X1 >= 1.5: 0.082 (n = 213)
## |   [5] X1 >= 2.5: 0.037 (n = 574)
## 
## Number of inner nodes:    2
## Number of terminal nodes: 3
##
plot(tfit2)

生存树

打印输出显示了中位生存时间,可视化显示了相应的Kaplan-Meier曲线。两者也可以通过将type参数设置为"response""prob"来使用predict()方法获得。

predict(tfit2, type = "response")[1]
##          5 
## 0.03671885 
predict(tfit2, type = "prob")[[1]]
## Call: survfit(formula = y ~ 1, weights = w, subset = w > 0)
## 
##  records    n.max  n.start   events   median  0.95LCL  0.95UCL 
## 574.0000 574.0000 574.0000 542.0000   0.0367   0.0323   0.0408 

除了使用 rpart 生存树之外,您还可以考虑基于条件推断的非参数生存树,例如使用 ctree()(使用logrank分数)或完全参数化的生存树,使用来自 partykit 包的通用 mob() 基础设施。


感谢您详细的回复!但我这里的目标是在任何时间点对任何实例获得一个P(alive),因此这似乎应该比仅提取与每个实例的树节点相关的中位生存时间更有用。我唯一能做到这一点的方法是使用“pec”包中的predictSurvProb函数,但这个函数有些错误,并且我还希望通过计算生存曲线本身来计算生存概率,而不是依赖于这个函数,以此来提高效率。 - jwdink
是的,Kaplan-Meier函数确实是(非参数)生存函数S(t)的估计量,即在时间t仍然存活的概率。 Kaplan-Meier函数可以通过手动使用survfit()和基于$where的因子进行计算 - 或者通过type =“prob”使用partykit来计算。如果您想在每个叶子中拟合参数模型(例如指数或Weibull),则可以使用survreg()而不是survfit() - Achim Zeileis
抱歉,我不是很明白:您能否编辑您的帖子,提供实际的代码,以便为给定的t和给定的实例提供S(t)?例如,给定一个rpart对象“tfit”和一个实例“dat[1,]”,以及一个时间“dat[1,'t']”,我应该使用什么代码来获取该实例和该时间的S(t)? - jwdink
我不明白为什么你想编辑我的回答。上面展示的代码片段 predict(tfit2, type = "prob")[[1]] 提取了第一个观测值的拟合 survfit 对象。从中,您可以提取所有您喜欢的“通常”数量。例如,查看对象的 summary(),它会向您显示完整的Kaplan-Meier曲线坐标以及其他几个附加信息。 - Achim Zeileis
我希望进行编辑,因为从survfit对象的摘要中提取概率并不容易(正如您所指示的那样),但我认为我应该能够弄清楚。 - jwdink
2
但这实际上是关于survfitsurvival的问题,对此有很多有用的书籍、教程等。但是我认为,如果您执行以下操作:km1 <- predict(tfit2, type = "prob")[[1]],然后summary(km1),您应该会看到您需要的一切。您可以轻松地从中获取分位数,例如quantile(km1, c(0.2, 0.5, 0.8)),它会给出S(t)分别为0.8、0.5和0.2的时间。或者,如果您想要一个函数,您可以执行km1f <- approxfun(km1$time, km1$surv),然后执行km1f(c(0.011, 0.037, 0.094))等。 - Achim Zeileis

3
@Achim Zeileis的回答很有帮助,但似乎没有确切回答@jwdink的问题。我理解的是“如果RPart树通过最佳指数生存拟合进行分割,则这些拟合的Lambdas在绝对值上是多少,以便我们可以使用这些指数生存函数进行预测”。RPart摘要确实显示了估计速率,但只是相对于整个人口速率为1的情况下。为了克服这一点,可以适配指数survreg,从中取得引用的lambda,然后将RPart预测速率乘以该数字(参见下面的代码)。
话虽如此,这不是RPart中树的生存率预测方法。我没有直接在RPart中找到生存预测功能,但正如Achim所说,partykit使用Kaplan-Meier估计,即来自处于相应最终叶子节点的非参数生存情况。我认为在生存随机森林树中也是一样的,在最终叶子中使用K-M曲线。
这个问题中的模拟数据使用指数分布,因此K-M和指数生存曲线在设计上是类似的,但对于不同的模拟或真实分布,通过RPart树估计的指数率和使用最终叶子(相同树)中的K-M曲线将给出不同的生存率。
sfit = survfit(Surv(t, event = e) ~ 1, data=dat)
tfit = rpart(formula = Surv(t, event = e) ~ X1 , data = dat, control=rpart.control(minsplit=30, cp=0.01))
plot(tfit); text(tfit)

# Survival Fit, Broken by Node in Tree:
dat$node = as.factor(tfit$where)
table(dat$node)
s0 = survreg(Surv(t,e)~ 1, data =  dat, dist = "exponential") #-0.6175
e0 = exp(-summary(s0)$coefficients[1]); e0 #1.854
rates = unique(predict(tfit))
#1) plot K-M curves by node (black):
plot( survfit(Surv(dat$t, event = dat$e)~dat$node) )

#2) plot exponential survival with rates = e0 * RPart rates (red):
for (rate in rates) {
  grid= seq(0,1,length.out = 100)
  lines(x= grid, y= exp(-e0*rate*(grid)), col=2)
}
#3) plot partykit survival curves based on RPart tree (green)
library(partykit)
tfit2 <- as.party(tfit)
col_n = 1
for (node in names(table(dat$node))){
  predict_curve = predict(tfit2, newdata = dat[dat$node == node, ], type = "prob")  
  surv_esitmated = approxfun(predict_curve[[1]]$time, predict_curve[[1]]$surv)
  lines(x= grid, y= surv_esitmated(grid), col = 2+col_n)
  col_n=+1
}

enter image description here


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接