免责声明:
这是一个糟糕的答案,或者说caret
包只是对这个特定问题有一个可怕的实现。在任何一种情况下,如果不存在更多多样化的predict
函数或修复object$finalModel
中使用的命名,似乎适合在它们的github上开立一个问题或愿望。
这个问题(发生在第二次尝试时)源于caret
包内部处理多样化拟合过程的方式,基本上限制了预测函数,看起来是为了清理和标准化目的。
问题:
问题有两个方面:
predict.train
不允许进行预测/置信区间
train(...)
输出中包含的finalModel
包含一个格式异常的公式。
这两个问题似乎源自train
的格式和在predict.train
中的使用。首先关注后一个问题,通过查看输出可以看出。
formula(m$finalModel)
#`.outcome ~ `poly(hp, 2)1` + `poly(hp, 2)2`)
很明显,在运行“train”时会执行一些格式化操作,因为期望的输出应该是“mpg~poly(hp,2)”,但是输出已经扩展了RHS(并添加了引号/标签),并更改了LHS。因此,要么修复公式,要么能够使用公式将是不错的选择。
研究“caret”包中的“predict.train”函数如何使用这个内容,会发现针对“newdata”输入的代码如下:
predict.formula
--more code
if (!is.null(newdata)) {
if (inherits(object, "train.formula")) {
newdata <- as.data.frame(newdata)
rn <- row.names(newdata)
Terms <- delete.response(object$terms)
m <- model.frame(Terms, newdata, na.action = na.action,
xlev = object$xlevels)
if (!is.null(cl <- attr(Terms, "dataClasses")))
.checkMFClasses(cl, m)
keep <- match(row.names(m), rn)
newdata <- model.matrix(Terms, m, contrasts = object$contrasts)
xint <- match("(Intercept)", colnames(newdata),
nomatch = 0)
if (xint > 0)
newdata <- newdata[, -xint, drop = FALSE]
}
}
--more code
out <- predictionFunction(method = object$modelInfo,
modelFit = object$finalModel, newdata = newdata,
preProc = object$preProcess)
对于经验不太丰富的
R
用户来说,我们基本上看到一个
model.matrix
是从头开始构建的,而没有使用
formula(m$finalModel)
的输出(我们可以使用这个!),随后调用一些函数来根据
m$finalModel
进行预测。查看同一软件包中的
predictionFunction
函数会发现,该函数只是简单地调用
m$modelInfo$predict(m$finalModel, newdata)
(针对我们的示例)。
最后,查看
m$modelInfo$predict
可以发现以下代码片段。
m$modelInfo$predict
function(modelFit, newdata, submodels = NULL) {
if(!is.data.frame(newdata))
newdata <- as.data.frame(newdata)
predict(modelFit, newdata)
}
请注意,
modelFit = m$finalModel
,而
newdata
是使用上面的输出创建的。另外
注意,调用
predict
时不能指定
interval = "confidence"
,这就是第一个问题的原因。
解决问题(有点):
有多种方法可以解决这个问题。其中一种是使用
lm(...)
代替
train(...)
。另一种是利用该函数的内部机制创建数据对象,适应奇怪的模型规范,这样我们就可以像预期的那样使用
predict(m$finalModel, newdata = newdata, interval = "confidence")
。
我选择后者。
caretNewdata <- caretTrainNewdata(m, mtcars)
preds <- predict(m$finalModel, caretNewdata, interval = "confidence")
head(preds, 3)
fit lwr upr
Mazda RX4 22.03708 20.74297 23.33119
Mazda RX4 Wag 22.03708 20.74297 23.33119
Datsun 710 24.21108 22.77257 25.64960
以下是提供的函数。对于技术控来说,我基本上从“预测”、“predictionFunction”和“m$modelInfo$predict”中提取了“model.matrix”构建过程。 我不能保证这个函数适用于每个“caret”模型的一般用法,但它是一个起点。
caretTrainNewdata
函数:
caretTrainNewdata <- function(object, newdata, na.action = na.omit){
if (!is.null(object$modelInfo$library))
for (i in object$modelInfo$library) do.call("requireNamespaceQuietStop",
list(package = i))
if (!is.null(newdata)) {
if (inherits(object, "train.formula")) {
newdata <- as.data.frame(newdata)
rn <- row.names(newdata)
Terms <- delete.response(object$terms)
m <- model.frame(Terms, newdata, na.action = na.action,
xlev = object$xlevels)
if (!is.null(cl <- attr(Terms, "dataClasses")))
.checkMFClasses(cl, m)
keep <- match(row.names(m), rn)
newdata <- model.matrix(Terms, m, contrasts = object$contrasts)
xint <- match("(Intercept)", colnames(newdata),
nomatch = 0)
if (xint > 0)
newdata <- newdata[, -xint, drop = FALSE]
}
}
else if (object$control$method != "oob") {
if (!is.null(object$trainingData)) {
if (object$method == "pam") {
newdata <- object$finalModel$xData
}
else {
newdata <- object$trainingData
newdata$.outcome <- NULL
if ("train.formula" %in% class(object) &&
any(unlist(lapply(newdata, is.factor)))) {
newdata <- model.matrix(~., data = newdata)[,
-1]
newdata <- as.data.frame(newdata)
}
}
}
else stop("please specify data via newdata")
} else
stop("please specify data data via newdata")
if ("xNames" %in% names(object$finalModel) & is.null(object$preProcess$method$pca) &
is.null(object$preProcess$method$ica))
newdata <- newdata[, colnames(newdata) %in% object$finalModel$xNames,
drop = FALSE]
if(!is.null(object$preProcess))
newdata <- predict(preProc, newdata)
if(!is.data.frame(newdata) &&
!is.null(object$modelInfo$predict) &&
any(grepl("as.data.frame", as.character(body(object$modelInfo$predict)))))
newdata <- as.data.frame(newdata)
newdata
}
formula(ms$finalModel)
得到了什么输出?所呈现的输出与你第一个框中的公式不相等。 - Oliver.outcome ~ \
poly(hp, 2)1` + `poly(hp, 2)2` <environment: 0x0000000091b1a6e8>` - Fish11R
和caret
?执行predict(fm, newdata=mtcars, interval="confidence", level=0.95)
可以正确地给出预测区间。 - OliverR version 3.6.0 (2019-04-26)
andcaret_6.0-84
- Fish11R
更新到3.6.1。希望另一个善良的人能找到更好的答案。 - Oliver