R - 从caret和glmnet lasso模型对象中提取因子预测变量名称

3
在下面的例子中,我设置了一个包含3个变量的数据框,分别是predict、var1和var2(一个因子)。
当我在caret或glmnet中运行模型时,这个因子会被转换成一个虚拟变量,例如var2b。
我想以程序化的方式提取变量名,并匹配原始变量名,而不是虚拟变量名--有没有办法做到这一点?
这只是一个例子,我的实际问题有许多具有不同级别的变量,因此,我希望避免手动操作,比如尝试提取“b”。
谢谢!
library(caret)
library(glmnet)

df <- data.frame(predict = c('Y','Y','N','Y','N','Y','Y','N','Y','N'), var1 = c(1,2,5,1,6,7,3,4,5,6),
              var2 = c('a','a','b','b','a','a','a','b','b','a'))

str(df)

# 'data.frame': 10 obs. of  3 variables:
# $ predict: Factor w/ 2 levels "N","Y": 2 2 1 2 1 2 2 1 2 1
# $ var1   : num  1 2 5 1 6 7 3 4 5 6
# $ var2   : Factor w/ 2 levels "a","b": 1 1 2 2 1 1 1 2 2 1

test <- train(predict ~ .,
           data = df,
           method = 'glmnet',
           trControl = trainControl(classProbs = TRUE,
                                    summaryFunction = twoClassSummary,
                                    allowParallel = FALSE),
           metric = 'ROC',
           tuneGrid = expand.grid(alpha = 1,
                                  lambda = .005))

predictors(test)
# [1] "var1"  "var2b"
varImp(test)
# glmnet variable importance

# Overall
# var2b     100
# var1        0

coef(test)
# NULL
#################
x <- model.matrix(as.formula(predict~.),data=df)
x <-  x[,-1] ##remove intercept

df$predict <- ifelse(df$predict == 'Y', TRUE, FALSE)

glmnet1 <- glmnet::cv.glmnet(x = x,
                          y = df$predict,
                          type.measure='auc',
                          nfolds=3,
                          alpha=1,
                          parallel = FALSE)

rownames(coef(glmnet1))
# [1] "(Intercept)" "var1"        "var2b

请注意,df$predict <- df$predict == 'Y' 等同于您的 ifelse 调用,但更高效。 - CSJCampbell
谢谢,我会记住的。 - BigTimeStats
2个回答

1
< p >“train”对象的< code > formula 方法返回一个“公式”对象,该对象具有您正在寻找的属性。

f1 <- formula(test)
f1
# predict ~ var1 + var2
# attr(,"variables")
# list(predict, var1, var2)
# attr(,"factors")
#         var1 var2
# predict    0    0
# var1       1    0
# var2       0    1
# attr(,"term.labels")
# [1] "var1" "var2"
# attr(,"order")
# [1] 1 1
# attr(,"intercept")
# [1] 1
# attr(,"response")
# [1] 1
# attr(,"predvars")
# list(predict, var1, var2)
# attr(,"dataClasses")
#   predict      var1      var2 
#  "factor" "numeric"  "factor" 
attr(f1, "term.labels")
# [1] "var1" "var2"

似乎变量名在'cv.glmnet'对象中不可用。我不知道有没有一种优雅的方法来收集它们。 'glmnetUtils'包可能具有一些实用功能。
以下是一些您可以尝试的代码;请注意,这将返回误报,因为它正在从输入数据中按模式搜索列名(例如,“var11”将匹配“var1”)。
# a generic method
termLabels <- function(object, ...) {
    UseMethod("termLabels")
}
# add for the train object too to save typing
termLabels.train <- function(object, ...) {
    attr(formula(object), "term.labels")
}
# try to find term labels for cv.glmnet object
# lambda must be provided and snaps to search grid
# allowed column names must be provided from corresponding data object
termLabels.cv.glmnet <- function(object, lambda, names, ...) {
    if (missing(lambda)) { stop("lambda is missing") }
    if (missing(names)) { stop("names is missing") }
    # match lambda
    lambdaArray <- object$glmnet.fit$a0
    if (lambda > max(lambdaArray) || lambda < min(lambdaArray)) {
        stop(paste("lambda must be in range", 
            paste(range(lambdaArray), collapse = ":")))
    }
    # find closest lambda
    whichLambda <- which.min(abs(lambdaArray - lambda))
    message(paste("using lambda", lambdaArray[whichLambda]))
    # matrix of parameter estimates
    betaLambda <- object$glmnet.fit$beta[, whichLambda, drop = FALSE]
    # non-zero estimates
    betaLambda <- betaLambda[betaLambda[, 1] != 0, , drop = FALSE]
    vars <- rownames(betaLambda)
    # search with names as pattern
    # note, does not account for nested names, e.g. var1 and var11
    matchNames <- apply(matrix(names), MARGIN = 1, FUN = grepl, x = vars)
    names[apply(matchNames, MARGIN = 2, FUN = any)]
}
termLabels(glmnet1, lambda = 1, names = colnames(df))
# using lambda 0.998561314952713
# [1] "var1" "var2"

感谢您提供的代码,非常感谢您抽出时间来做这件事。我认为我可以使用predictors(test)代码从caret模型对象中提取变量名,然后使用您的匹配名称代码将其与原始df的列名匹配:matchNames <- apply(matrix(names), MARGIN = 1, FUN = grepl, x = vars) names[apply(matchNames, MARGIN = 2, FUN = any)],但需要注意模糊匹配的问题。 - BigTimeStats

1
根据 @CSJCampbell 的回答,glmnetUtils 软件包可以让您使用 glmnet 和 cv.glmnet 对象来实现此操作。
library(glmnetUtils)
m <- glmnet(mpg ~ ., data=mtcars)
all.vars(m$terms)

m2 <- cv.glmnet(mpg ~ ., data=mtcars)
all.vars(m2$terms)

请注意,all.vars 在大多数其他 R 模型对象中也适用:
m3 <- lm(mpg ~ ., data=mtcars)
all.vars(delete.response(m3$terms))

glmnetUtils可以在CRAN上获取,或者您可以从Github获取dev版本。我目前正在完成一次重大更新,很快将发布到CRAN。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接