How do I extract model fit results from a tidymodels workflow set?

I'm trying to learn tidymodels and DALEXtra. I've successfully built a set of models with workflow_map():

grid_results <-
   all_workflows %>%
   workflow_map(
      seed = 1503,
      resamples = the_folds,
      grid = 100,
      control = grid_ctrl,
      verbose=TRUE
   )

grid_results %>% 
  rank_results() %>% 
  filter(.metric == "roc_auc") %>% 
  select(model, .config, roc_auc = mean, rank) |> 
  head()

One of my BART models looks like the "winner":

# A tibble: 6 × 4
  model        .config                roc_auc  rank
  <chr>        <chr>                    <dbl> <int>
1 bart         Preprocessor1_Model046   0.656     1
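Rather than reading the winner off the printed table, the top-ranked workflow can be pulled out programmatically. Note that extract_workflow() and extract_workflow_set_result() take the wflow_id column, not the model column shown above. A minimal sketch, run on the chi_features_res example set that ships with workflowsets (a regression set, so it ranks by rmse); for the question's set you would use grid_results and "roc_auc" instead:

```r
library(tidymodels)  # attaches dplyr, tune and workflowsets

# chi_features_res: example workflow set (regression, metrics rmse and rsq)
# shipped with workflowsets; swap in grid_results and "roc_auc" for your own set
best_id <-
  chi_features_res %>%
  # keep only the best configuration of each workflow, ranked by rmse
  rank_results(rank_metric = "rmse", select_best = TRUE) %>%
  filter(.metric == "rmse") %>%
  # rank 1 is the overall winner; wflow_id is what the extract_*() functions need
  slice_min(rank, n = 1, with_ties = FALSE) %>%
  pull(wflow_id)

best_id
```

The resulting id can then be passed as id = best_id to extract_workflow() and extract_workflow_set_result().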

I want to feed that model into DALEXtra:

library(DALEXtra)

explainer_bart <- 
  explain_tidymodels(
    x, # <--------------- what goes here?
    data = the_train %>% select(-adherence_group),
    y = the_train$adherence_group,
    label = "bart",
    verbose = FALSE
  )



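I think explain_tidymodels() needs a fitted model. How do I extract it from the workflow set results?
I'm a beginner, so hints that a newcomer can follow (ideally with links) would be much appreciated.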

See Julia's blog post. You need to use the extract_workflow() function on the model. - Desmond
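On its own, extract_workflow() returns the workflow as it was defined in the set: untrained and still carrying any tune() placeholders, so it is not yet something explain_tidymodels() can use. A quick check, sketched on the chi_features_res example set that ships with workflowsets (assuming, as in the answer's example, that its plus_pca_lm workflow has a tuned PCA step):

```r
library(tidymodels)

# The workflow exactly as stored in the set: a specification, not a fitted model
wf <- extract_workflow(chi_features_res, id = "plus_pca_lm")

# Nothing has been estimated yet, so this should be FALSE
isTRUE(wf$trained)

# Parameters still marked with tune() that must be filled in before fitting
extract_parameter_set_dials(wf)
```

This is why the answer below adds finalize_workflow() and fit() after the extraction.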
1 Answer

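If you tuned the BART model, you need to get a fitted workflow object. That is what you pass to the DALEX functions.
Here is an example using chi_features_res, the example workflow set for the Chicago data that ships with workflowsets: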
library(tidymodels)
library(DALEXtra)
#> Loading required package: DALEX
#> Welcome to DALEX (version: 2.4.2).
#> Find examples and detailed introduction at: http://ema.drwhy.ai/
#> 
#> Attaching package: 'DALEX'
#> The following object is masked from 'package:dplyr':
#> 
#>     explain

# Pull out the workflow that you want
workflow_object <- 
  chi_features_res %>% 
  extract_workflow(id = "plus_pca_lm") 

# If there are tuning parameters, get the best results
best_results <- 
  chi_features_res %>% 
  extract_workflow_set_result(id = "plus_pca_lm") %>% 
  select_best(metric = "rmse")

# Update your workflow and fit: 
fitted_workflow_object <- 
  workflow_object %>% 
  finalize_workflow(best_results) %>% 
  fit(data = Chicago)

fitted_workflow_object
#> ══ Workflow [trained] ══════════════════════════════════════════════════════════
#> Preprocessor: Recipe
#> Model: linear_reg()
#> 
#> ── Preprocessor ────────────────────────────────────────────────────────────────
#> 5 Recipe Steps
#> 
#> • step_date()
#> • step_holiday()
#> • step_dummy()
#> • step_zv()
#> • step_pca()
#> 
#> ── Model ───────────────────────────────────────────────────────────────────────
#> 
#> Call:
#> stats::lm(formula = ..y ~ ., data = data)
#> 
#> Coefficients:
#>       (Intercept)           temp_min               temp           temp_max  
#>        -6.250e+02          1.624e-02          2.639e-02          4.314e-03  
#>       temp_change                dew           humidity           pressure  
#>                NA         -2.514e-02          1.124e-02         -1.116e-04  
#>   pressure_change               wind           wind_max               gust  
#>         8.107e-02         -1.912e-02          1.091e-03         -1.107e-02  
#>          gust_max             percip         percip_max       weather_rain  
#>         4.430e-03         -1.127e+01         -1.470e-01         -7.882e-01  
#>      weather_snow      weather_cloud      weather_storm    Blackhawks_Away  
#>        -7.061e-01         -3.149e-01          5.831e-02         -1.395e-01  
#>   Blackhawks_Home         Bulls_Away         Bulls_Home         Bears_Away  
#>        -3.423e-02          3.554e-02          3.418e-01          3.287e-01  
#>        Bears_Home      WhiteSox_Away      WhiteSox_Home          Cubs_Away  
#>         2.740e-01         -4.920e-01                 NA                 NA  
#>         Cubs_Home          date_year      date_LaborDay   date_NewYearsDay  
#>                NA          3.121e-01          8.171e-01         -1.004e+01  
#> date_ChristmasDay       date_dow_Mon       date_dow_Tue       date_dow_Wed  
#>        -1.127e+01          1.244e+01          1.384e+01          1.385e+01  
#>      date_dow_Thu       date_dow_Fri       date_dow_Sat     date_month_Feb  
#>         1.361e+01          1.302e+01          1.435e+00          3.602e-01  
#>    date_month_Mar     date_month_Apr     date_month_May     date_month_Jun  
#>         6.938e-01          9.297e-01          5.221e-01          1.397e+00  
#>    date_month_Jul     date_month_Aug     date_month_Sep     date_month_Oct  
#>         7.532e-01          9.335e-01          9.002e-01          1.545e+00  
#>    date_month_Nov     date_month_Dec                PC1                PC2  
#>         2.633e-01         -3.567e-01          6.636e-04          1.461e-01  
#>               PC3                PC4                PC5                PC6  
#>         4.950e-01         -1.577e-01         -4.550e-02          4.059e-01  
#>               PC7                PC8                PC9  
#>        -1.665e-01         -6.379e-02          2.689e-01

# Put that in the explainer
# predict.lm() warns because some coefficients above are NA (a rank-deficient
# fit); for this example you can disregard the warnings
explainer_obj <- 
  explain_tidymodels(
    fitted_workflow_object, 
    data = Chicago %>% select(-ridership),
    y = Chicago$ridership,
    label = "model",
    verbose = FALSE
  )
#> Warning in predict.lm(object = object$fit, newdata = new_data, type =
#> "response"): prediction from a rank-deficient fit may be misleading

#> Warning in predict.lm(object = object$fit, newdata = new_data, type =
#> "response"): prediction from a rank-deficient fit may be misleading

#> Warning in predict.lm(object = object$fit, newdata = new_data, type =
#> "response"): prediction from a rank-deficient fit may be misleading

Created on 2022-10-11 by the reprex package (v2.0.1)
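If you need to do this for several workflows, the three steps above (extract the workflow, pick its best parameters, finalize and fit) can be wrapped in a small function. A sketch under that assumption; fit_from_workflow_set() is a hypothetical helper, not part of workflowsets or tune:

```r
library(tidymodels)

# Hypothetical helper (not a tidymodels function): extract one workflow from a
# workflow set, plug in its best tuning parameters, and fit it on `data`.
fit_from_workflow_set <- function(wf_set, id, metric, data) {
  # Best parameter combination found during resampling for this workflow
  best_params <-
    wf_set %>%
    extract_workflow_set_result(id = id) %>%
    select_best(metric = metric)

  # Unfitted workflow -> finalized with the best parameters -> fitted
  wf_set %>%
    extract_workflow(id = id) %>%
    finalize_workflow(best_params) %>%
    fit(data = data)
}

fitted_wf <- fit_from_workflow_set(
  chi_features_res,
  id = "plus_pca_lm",
  metric = "rmse",
  data = Chicago
)
```

The result can go straight into explain_tidymodels() as its first argument, just like fitted_workflow_object above. For the question's set the call would look like fit_from_workflow_set(grid_results, id = <your BART wflow_id>, metric = "roc_auc", data = the_train).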

