Tidymodels: How to extract importance from the training data

I have the following code, where I do a grid search over different values of mtry and min_n. I know how to extract the parameters that give the highest accuracy (see the second code box). How can I also extract the importance of each feature in the training data set? The guides I have found online only show how to do this on the test data set using "last_fit", e.g. this guide: https://www.tidymodels.org/start/case-study/#data-split
set.seed(seed_number)
data_split <- initial_split(node_strength, prop = 0.8, strata = Group)

train <- training(data_split)
test <- testing(data_split)
train_folds <- vfold_cv(train, v = 10)

rfc <- rand_forest(mode = "classification", mtry = tune(),
                   min_n = tune(), trees = 1500) %>%
    set_engine("ranger", num.threads = 48, importance = "impurity")

rfc_recipe <- recipe(data = train, Group ~ .)

rfc_workflow <- workflow() %>%
    add_model(rfc) %>%
    add_recipe(rfc_recipe)

rfc_result <- rfc_workflow %>%
    tune_grid(train_folds, grid = 40,
              control = control_grid(save_pred = TRUE),
              metrics = metric_set(accuracy))


best <- rfc_result %>%
    select_best(metric = "accuracy")
1 Answer

To do that, you need to create a custom extract function, as outlined in this documentation.
For random forest variable importance, your function would look something like this:
get_rf_imp <- function(x) {
    x %>% 
        extract_fit_parsnip() %>% 
        vip::vi()
}

Then you can use that function when fitting your resamples, like this (notice that you get a new .extracts column):

library(tidymodels)
data(cells, package = "modeldata")

set.seed(123)
cell_split <- cells %>% select(-case) %>%
    initial_split(strata = class)
cell_train <- training(cell_split)
cell_test  <- testing(cell_split)
folds <- vfold_cv(cell_train)            

rf_spec <- rand_forest(mode = "classification") %>%
    set_engine("ranger", importance = "impurity")

ctrl_imp <- control_grid(extract = get_rf_imp)

cells_res <-
    workflow(class ~ ., rf_spec) %>%
    fit_resamples(folds, control = ctrl_imp)
cells_res
#> # Resampling results
#> # 10-fold cross-validation 
#> # A tibble: 10 × 5
#>    splits             id     .metrics         .notes           .extracts       
#>    <list>             <chr>  <list>           <list>           <list>          
#>  1 <split [1362/152]> Fold01 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#>  2 <split [1362/152]> Fold02 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#>  3 <split [1362/152]> Fold03 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#>  4 <split [1362/152]> Fold04 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#>  5 <split [1363/151]> Fold05 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#>  6 <split [1363/151]> Fold06 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#>  7 <split [1363/151]> Fold07 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#>  8 <split [1363/151]> Fold08 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#>  9 <split [1363/151]> Fold09 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#> 10 <split [1363/151]> Fold10 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>

Created on 2022-06-19 with the reprex package (v2.0.1).
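The .extracts column here is nested two levels deep, which is why the next step unnests twice. As a quick way to see this (not part of the original answer), you can pull out the first resample's extract, assuming the cells_res object from above:

# Each element of .extracts is itself a small tibble whose own .extracts
# list-column holds the vip::vi() output for that resample
cells_res$.extracts[[1]]
cells_res$.extracts[[1]]$.extracts[[1]]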

Once you have those variable importance scores, you can unnest() them (right now you have to do this twice, because they are deeply nested) and then summarize and visualize however you prefer:

cells_res %>%
    select(id, .extracts) %>%
    unnest(.extracts) %>%
    unnest(.extracts) %>%
    group_by(Variable) %>%
    summarise(Mean = mean(Importance),
              Variance = sd(Importance)) %>%
    slice_max(Mean, n = 15) %>%
    ggplot(aes(Mean, reorder(Variable, Mean))) +
    geom_crossbar(aes(xmin = Mean - Variance, xmax = Mean + Variance)) +
    labs(x = "Variable importance", y = NULL)

Created on 2022-06-19 with the reprex package (v2.0.1).
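To tie this back to the tuning setup in the question: control_grid() also accepts an extract argument, so the same helper should work during tune_grid() as well. A minimal sketch (not from the original answer), reusing get_rf_imp() together with the rfc_workflow and train_folds objects from the question:

# Collect ranger importance scores for every resample and candidate
# parameter set while tuning; they land in an .extracts column as above
rfc_result <- rfc_workflow %>%
    tune_grid(train_folds,
              grid = 40,
              control = control_grid(save_pred = TRUE, extract = get_rf_imp),
              metrics = metric_set(accuracy))

You can still call select_best() afterwards as in the question; the importance scores for the winning parameter combination should then be obtainable by filtering the unnested .extracts on its .config value.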

