为tidymodel对象创建SHAP图表。

9
这个问题涉及到如何使用tidymodels在R中获取catboost模型的shap值摘要图表。根据问题下面的评论,OP已找到解决方案,但迄今为止还没有与社区分享。
我想分析使用tidymodels包拟合的我的树集合,并生成SHAP值图,例如单个观察的图表。

ttps://prnt.sc/CO_PC4aDUQA0

总结一下我的数据集所有特征的影响,例如:

enter image description here

DALEXtra提供了一个函数,用于为tidymodels explain.tidymodels()创建SHAP值。来自fastshap包的force_plot提供了一个包装器,用于底层python包SHAP的绘图函数。但我不知道如何使该函数与explain.tidymodels()函数的输出配合使用。

问题:如何使用tidymodelsexplain.tidymodels在R中生成这样的SHAP图?

MWE(对于使用explain.tidymodels的SHAP值)

library(MASS)
library(tidyverse)
library(tidymodels)
library(parsnip)
library(treesnip)
library(catboost)
library(fastshap)
library(DALEXtra)
set.seed(1337)
rec <-  recipe(crim ~ ., data = Boston)

split <- initial_split(Boston)

train_data <- training(split)

test_data <- testing(split) %>% dplyr::select(-crim) %>% as.matrix()

model_default<-
  parsnip::boost_tree(
    mode = "regression"
  ) %>%
  set_engine(engine = 'catboost', loss_function = 'RMSE')
#sometimes catboost is not loaded correctly the following two lines
#ensure prevent fitting errors
#https://github.com/curso-r/treesnip/issues/21 error is mentioned on last post
set_dependency("boost_tree", eng = "catboost", "catboost")
set_dependency("boost_tree", eng = "catboost", "treesnip")

model_fit_wf <- model_fit_wf <- workflow() %>% add_model(model_tune) %>%  add_recipe(rec) %>% {parsnip::fit(object = ., data =  train_data)}

SHAP_wf <- explain_tidymodels(model_fit_wf, data = X, y = train_data$crim, new_data = test_data

3
我个人在使用catboost和treesnip方面并没有太好的运气,但你可能会发现查看这篇博客文章有帮助:this blog post。特别注意如何将tidymodels输出作为SHAPforxgboost等函数的输入,使用extract_fit_engine()bake() - Julia Silge
我猜catboost的主要问题之一是,据我所知,原始作者仍未在CRAN中发布catboost在R中的实现,而且我怀疑他们是否有意这样做。 - mugdi
3
对于SHAP值而言,需要考虑的一个重要问题是这场持续争论,关于Lundberg等人原论文中主要假设之一在使用Tree算法时被违反了!如果你从事科学领域的工作,可能需要限制结果的有效性! - mugdi
1个回答

6

或许这将有所帮助。至少,这是朝着正确方向迈出的一步。

第一步,确保已安装 fastshap 和 reticulate(即 install.packages("..."))。接下来,设置虚拟环境并安装 shap(pip install ...)。另外,为了依赖图,需要安装 matplotlib 3.2.2(查看 GitHub 上的问题 -- 需要旧版本的 matplotlib)。

RStudio 提供了关于虚拟环境设置的很好的信息。尽管如此,根据使用的 IDE 的不同,虚拟环境设置需要更多或更少的故障排除。(遗憾的是,某些工作环境由于许可证的限制而限制了开源 RStudio 的使用。)

library(fastshap) 的文档对此也很有帮助。

以下是 lightgbm 的工作流程(来自 treesnip 文档,稍作修改)。

library(tidymodels)
library(treesnip)

data("diamonds", package = "ggplot2")
diamonds <- diamonds %>% sample_n(1000)

# vfold resamples
diamonds_splits <- vfold_cv(diamonds, v = 5)

model_spec <- boost_tree(mtry = 5, trees = 500) %>% set_mode("regression")

# model specs
lightgbm_model <- model_spec %>% 
    set_engine("lightgbm", nthread = 6)

#workflows
lightgbm_wf <- workflow() %>% 
    add_model(
       lightgbm_model
    )

rec_ordered <- recipe(
    price ~ .
      , data = diamonds
) 

lightgbm_fit_ordered <- fit_resamples(
  add_recipe(
    lightgbm_wf, rec_ordered
    ), resamples = diamonds_splits)

在预测之前,我们需要优化我们的工作流程

fit_workflow <- lightgbm_wf %>% 
     add_recipe(rec_ordered) %>% 
     fit(data = diamonds)

现在我们有了一个合适的工作流程并可以进行预测。为了使用fastshap::explain函数,我们需要创建一个预测函数(这并不总是成立:根据所使用的引擎,它可能或可能不可以直接使用 - 请参见文档)。

predict_function_gbm <-  function(model, newdata) {
    predict(model, newdata) %>% pluck(.,1)
}

我们顺便得到平均预测值(在下面使用),这也作为检查函数是否正常工作的验证。

mean_preds <- mean(
    predict_function_gbm(
       fit_workflow, diamonds %>% select(-price)
   )
)

现在我们创建我们的解释(shap值)。请注意这里的pred_wrapper和X参数(有关其他示例,请参见fastshap github问题,例如glmnet)。
fastshap::explain( 
    fit_workflow, 
    X = as.data.frame(diamonds %>% select(-price)),
    pred_wrapper = predict_function_gbm, 
    nsim = 10
) -> explanations_gbm

这应该生成一个力图。

fastshap::force_plot(
    object = explanations_gbm[1,], 
    feature_values = as.data.frame(diamonds %>% select(-price))[1,], 
    display = "viewer", 
    baseline = mean_preds) 

这样可以允许多个垂直堆叠:

fastshap::force_plot(
    object = explanations_gbm[1:20,], 
    feature_values = as.data.frame(diamonds %>% select(-price))[1:20,], 
    display = "viewer", 
    baseline = mean_preds) 

添加link="logit"用于分类。将显示更改为“html”以进行Rmarkdown渲染。

现在来看摘要图和依赖图。

技巧是使用reticulate直接访问函数。请注意,对于诸如transformers、numpy等库,同样适用相同的逻辑。

首先,对于依赖关系图。

library(reticulate)
shap = import("shap")
np = import("numpy") 

shap$dependence_plot(
     "rank(3)", 
     data.matrix(explanations_gbm),
     data.matrix(diamond %>% select(-price))
)

请参考shap文档中有关rank(3) -- rank(1)等的解释,这些也适用。

不幸的是,当我尝试直接命名特征(即“cut”)时,它抛出了一个错误。

现在来看总结图:

shap$summary_plot( 
    data.matrix(explanations_gbm),
    data.matrix(diamond %>% select(-price))
)

最后注意:重复渲染图表将导致视觉效果错误。希望这为catboost可视化提供了一个出发点。


非常好的工作,感谢您的答案!对于那些在保存和加载包含已拟合的lightgbm模型的工作流程时遇到问题的人,我有一个简短的补充。似乎由于某种原因,使用write_rds()保存这样的模型不会保存实际的lightgbm模型。必须单独提取和保存模型,并在加载后组合它们以继续使用。要单独保存这样的模型,可以执行以下操作:(1/2) - mugdi
pull_lightgbm <- extract_fit_parsnip(final_model_cv) lightgbm::lgb.save(pull_lightgbm$fit, file = str_c(here::here(),'/data/shap/lightgbm/','lightgbm.model_',dataset_type,'_',y))将 pull_lightgbm 提取并挑出 fit 参数, 然后使用 lightgbm::lgb.save 保存模型至指定路径下的文件,文件名为 'lightgbm.model_数据集类型_目标变量'。 - mugdi

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接