使用caret::train获取预测置信区间

3

我试图弄清楚如何从caret::train线性模型中获得置信区间。

我的第一次尝试只是使用通常的lm置信区间参数运行预测:

m <- caret::train(mpg ~ poly(hp,2), data=mtcars, method="lm")
predict(m, newdata=mtcars, interval="confidence", level=0.95)

但是看起来从caret::train返回的对象没有实现这个功能。

我的第二次尝试是提取finalModel并在其上进行预测:

m <- caret::train(mpg ~ poly(hp,2), data=mtcars, method="lm")
fm <- m$finalModel
predict(fm, newdata=mtcars, interval="confidence", level=0.95)

但是我遇到了错误。

Error in eval(predvars, data, env) : object 'poly(hp, 2)1' not found

深入研究后发现,最终模型对公式有一些奇怪的表示,并在我的新数据中搜索“poly(hp, 2)1”列,而不是评估该公式。m$finalModel看起来像这样:
Call:
lm(formula = .outcome ~ ., data = dat)

Coefficients:
   (Intercept)  `poly(hp, 2)1`  `poly(hp, 2)2`  
         20.09          -26.05           13.15

我应该补充说明,我之所以使用lm,并不仅仅是因为我正在使用caret通过交叉验证来拟合模型。

我如何从通过caret::train拟合的线性模型中获取置信区间?


1
你从 formula(ms$finalModel) 得到了什么输出?所呈现的输出与你第一个框中的公式不相等。 - Oliver
1
@Oliver 抱歉,我在问题中的公式错了,输出应该是.outcome ~ \poly(hp, 2)1` + `poly(hp, 2)2` <environment: 0x0000000091b1a6e8>` - Fish11
1
看起来你已经更新了,这样更符合要求。你使用的是哪个版本的 Rcaret?执行 predict(fm, newdata=mtcars, interval="confidence", level=0.95) 可以正确地给出预测区间。 - Oliver
1
R version 3.6.0 (2019-04-26) and caret_6.0-84 - Fish11
1
这不应该是问题,但尝试将 R 更新到3.6.1。希望另一个善良的人能找到更好的答案。 - Oliver
显示剩余2条评论
1个回答

6

免责声明:

这是一个糟糕的答案,或者说caret包只是对这个特定问题有一个可怕的实现。在任何一种情况下,如果不存在更多多样化的predict函数或修复object$finalModel中使用的命名,似乎适合在它们的github上开立一个问题或愿望。

这个问题(发生在第二次尝试时)源于caret包内部处理多样化拟合过程的方式,基本上限制了预测函数,看起来是为了清理和标准化目的。

问题:

问题有两个方面:

  1. predict.train不允许进行预测/置信区间
  2. train(...)输出中包含的finalModel包含一个格式异常的公式。

这两个问题似乎源自train的格式和在predict.train中的使用。首先关注后一个问题,通过查看输出可以看出。

formula(m$finalModel)
#`.outcome ~ `poly(hp, 2)1` + `poly(hp, 2)2`)

很明显,在运行“train”时会执行一些格式化操作,因为期望的输出应该是“mpg~poly(hp,2)”,但是输出已经扩展了RHS(并添加了引号/标签),并更改了LHS。因此,要么修复公式,要么能够使用公式将是不错的选择。
研究“caret”包中的“predict.train”函数如何使用这个内容,会发现针对“newdata”输入的代码如下:
predict.formula
#output
--more code
if (!is.null(newdata)) {
    if (inherits(object, "train.formula")) {
        newdata <- as.data.frame(newdata)
        rn <- row.names(newdata)
        Terms <- delete.response(object$terms)
        m <- model.frame(Terms, newdata, na.action = na.action, 
            xlev = object$xlevels)
        if (!is.null(cl <- attr(Terms, "dataClasses"))) 
            .checkMFClasses(cl, m)
        keep <- match(row.names(m), rn)
        newdata <- model.matrix(Terms, m, contrasts = object$contrasts)
        xint <- match("(Intercept)", colnames(newdata), 
            nomatch = 0)
        if (xint > 0) 
            newdata <- newdata[, -xint, drop = FALSE]
    }
}
--more code
    out <- predictionFunction(method = object$modelInfo, 
                modelFit = object$finalModel, newdata = newdata, 
                preProc = object$preProcess)

对于经验不太丰富的 R 用户来说,我们基本上看到一个 model.matrix 是从头开始构建的,而没有使用 formula(m$finalModel) 的输出(我们可以使用这个!),随后调用一些函数来根据 m$finalModel 进行预测。查看同一软件包中的 predictionFunction 函数会发现,该函数只是简单地调用 m$modelInfo$predict(m$finalModel, newdata)(针对我们的示例)。
最后,查看 m$modelInfo$predict 可以发现以下代码片段。
m$modelInfo$predict
#output
function(modelFit, newdata, submodels = NULL) {
                    if(!is.data.frame(newdata)) 
                        newdata <- as.data.frame(newdata)
                    predict(modelFit, newdata)
                  }

请注意,modelFit = m$finalModel,而newdata是使用上面的输出创建的。另外注意,调用predict时不能指定interval = "confidence",这就是第一个问题的原因。
解决问题(有点):
有多种方法可以解决这个问题。其中一种是使用lm(...)代替train(...)。另一种是利用该函数的内部机制创建数据对象,适应奇怪的模型规范,这样我们就可以像预期的那样使用predict(m$finalModel, newdata = newdata, interval = "confidence")
我选择后者。
caretNewdata <- caretTrainNewdata(m, mtcars)
preds <- predict(m$finalModel, caretNewdata, interval = "confidence")
head(preds, 3)
#output
                         fit      lwr      upr
Mazda RX4           22.03708 20.74297 23.33119
Mazda RX4 Wag       22.03708 20.74297 23.33119
Datsun 710          24.21108 22.77257 25.64960

以下是提供的函数。对于技术控来说,我基本上从“预测”、“predictionFunction”和“m$modelInfo$predict”中提取了“model.matrix”构建过程。 我不能保证这个函数适用于每个“caret”模型的一般用法,但它是一个起点。

caretTrainNewdata函数:

caretTrainNewdata <- function(object, newdata, na.action = na.omit){
    if (!is.null(object$modelInfo$library)) 
        for (i in object$modelInfo$library) do.call("requireNamespaceQuietStop", 
                                                    list(package = i))
    if (!is.null(newdata)) {
        if (inherits(object, "train.formula")) {
            newdata <- as.data.frame(newdata)
            rn <- row.names(newdata)
            Terms <- delete.response(object$terms)
            m <- model.frame(Terms, newdata, na.action = na.action, 
                             xlev = object$xlevels)
            if (!is.null(cl <- attr(Terms, "dataClasses"))) 
                .checkMFClasses(cl, m)
            keep <- match(row.names(m), rn)
            newdata <- model.matrix(Terms, m, contrasts = object$contrasts)
            xint <- match("(Intercept)", colnames(newdata), 
                          nomatch = 0)
            if (xint > 0) 
                newdata <- newdata[, -xint, drop = FALSE]
        }
    }
    else if (object$control$method != "oob") {
        if (!is.null(object$trainingData)) {
            if (object$method == "pam") {
                newdata <- object$finalModel$xData
            }
            else {
                newdata <- object$trainingData
                newdata$.outcome <- NULL
                if ("train.formula" %in% class(object) && 
                    any(unlist(lapply(newdata, is.factor)))) {
                    newdata <- model.matrix(~., data = newdata)[, 
                                                                -1]
                    newdata <- as.data.frame(newdata)
                }
            }
        }
        else stop("please specify data via newdata")
    } else
        stop("please specify data data via newdata")
    if ("xNames" %in% names(object$finalModel) & is.null(object$preProcess$method$pca) & 
        is.null(object$preProcess$method$ica)) 
        newdata <- newdata[, colnames(newdata) %in% object$finalModel$xNames, 
                           drop = FALSE]
    if(!is.null(object$preProcess))
       newdata <- predict(preProc, newdata)
    if(!is.data.frame(newdata) && 
      !is.null(object$modelInfo$predict) && 
      any(grepl("as.data.frame", as.character(body(object$modelInfo$predict)))))
           newdata <- as.data.frame(newdata)
    newdata
}

非常感谢,看起来他们的Github上有一个未解决的问题 https://github.com/topepo/caret/issues/187 - Fish11
2
找得好。我已经在问题上添加了一条评论(自2015年最后一次讨论以来似乎未得到解决),引用了这个答案。 - Oliver
干得好。不过我有一个问题。我使用train()方法和"glmnet"模型创建了一个模型。一开始出现了一个错误,说requireNamespaceQuietStop不存在,但我用caret::"requireNamespaceQuietStop"解决了这个问题。然而,现在我在数据集的每一行中都得到了s0-s48(特征数量)列的数字。你有什么想法是怎么出错的吗? - Wilkit

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接