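I have the following code, which creates a tidymodels workflow that uses a lightgbm model. However, when I try to save it as an .rds object and then predict from it, I run into problems.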
library(AmesHousing)
library(treesnip)
library(lightgbm)
library(tidymodels)
tidymodels_prefer()
### Model ###
# data
data <- make_ames() %>%
  janitor::clean_names()
data <- subset(data, select = c(sale_price, bedroom_abv_gr, bsmt_full_bath, bsmt_half_bath, enclosed_porch, fireplaces,
                                full_bath, half_bath, kitchen_abv_gr, garage_area, garage_cars, gr_liv_area, lot_area,
                                lot_frontage, year_built, year_remod_add, year_sold))
data$id <- c(1:nrow(data))
data <- data %>%
  mutate(id = as.character(id)) %>%
  select(id, everything())
# model specification
lgbm_model <- boost_tree(
  mtry = 7,
  trees = 347,
  min_n = 10,
  tree_depth = 12,
  learn_rate = 0.0106430579211173,
  loss_reduction = 0.000337948798058139
) %>%
  set_mode("regression") %>%
  set_engine("lightgbm", objective = "regression")
# recipe and workflow
lgbm_recipe <- recipe(sale_price ~ ., data = data) %>%
  update_role(id, new_role = "ID") %>%
  step_corr(all_predictors(), threshold = 0.7) %>%
  prep()
lgbm_workflow <- workflow() %>%
  add_recipe(lgbm_recipe) %>%
  add_model(lgbm_model)
# fit workflow
fit_lgbm_workflow <- lgbm_workflow %>%
  fit(data = data)
# predict
data_predict <- subset(data, select = -c(sale_price))
predict(fit_lgbm_workflow, new_data = data_predict)
### CASE 1: Save the workflow with saveRDS()
saveRDS(object = fit_lgbm_workflow, file = "lgbm_workflow.rds")
new_lgbm_workflow <- readRDS(file = "lgbm_workflow.rds")
# Predict - error: Attempting to use a Booster which no longer exists
predict(new_lgbm_workflow, new_data = data_predict)
### CASE 2: Save the workflow and the fitted model separately
fitted_model <- (fit_lgbm_workflow %>% extract_fit_parsnip())$fit
saveRDS(object = fit_lgbm_workflow, file = "lgbm_workflow.rds")
lightgbm::saveRDS.lgb.Booster(object = fitted_model, file = "lgbm_model.rds")
new_lgbm_workflow <- readRDS(file = "lgbm_workflow.rds")
new_lgbm_model <- lightgbm::readRDS.lgb.Booster(file = "lgbm_model.rds")
new_lgbm_workflow$fit$fit <- new_lgbm_model
# Predict - error: cannot predict on data of class ‘tbl_df’‘tbl’‘data.frame’
predict(new_lgbm_workflow, new_data = data_predict)
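Only workflows that use a lightgbm model seem to have this problem. With other model types (random forest, xgboost, glm, etc.), I can save the fitted workflow with saveRDS(), read it back with readRDS(), and predict on new data without any trouble.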
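In the second case, the underlying prediction function apparently becomes predict.lgb.Booster(), which takes a matrix as input. But my id variable is of type character, and all columns of a matrix must share the same type.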
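The type constraint can be checked in base R. The sketch below uses a small made-up data frame (the name `toy` and the values are illustrative, not taken from the Ames data) to show that one character column turns the whole matrix into character, while dropping that column first keeps it numeric.

```r
# Toy data frame: one character column and one numeric column (illustrative values).
toy <- data.frame(
  id = c("1", "2"),
  lot_area = c(8450, 9600),
  stringsAsFactors = FALSE
)

# A matrix holds a single type, so the numeric column is coerced to character.
mixed <- as.matrix(toy)
typeof(mixed)         # "character"

# Dropping the character column first keeps the matrix numeric.
numeric_only <- as.matrix(toy[, "lot_area", drop = FALSE])
typeof(numeric_only)  # "double"
```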
Is there a way to save the entire workflow for future use?
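One possible direction, offered as a sketch rather than a confirmed fix: a fitted workflow keeps the parsnip fit in `$fit$fit` and the engine object (here the `lgb.Booster`) one level deeper, in `$fit$fit$fit`, which is what `extract_fit_engine()` returns. Assigning the reloaded booster to `$fit$fit`, as in Case 2, overwrites the parsnip wrapper, so `predict()` passes the tibble straight to lightgbm. The helpers below save the two pieces separately and put the booster back in the deeper slot. `save_workflow_parts()` and `load_workflow_parts()` are hypothetical names, not part of any package. The defaults assume `lightgbm::lgb.save()` and `lightgbm::lgb.load()`; whether treesnip's prediction code accepts a booster restored this way is an assumption that has not been verified here.

```r
# Hypothetical helper: save a fitted workflow and its engine-level fit separately.
save_workflow_parts <- function(wf, wf_path, engine_path,
                                save_engine = function(fit, path) lightgbm::lgb.save(fit, path)) {
  # wf$fit$fit$fit is the engine-level fit (the object extract_fit_engine(wf) returns).
  save_engine(wf$fit$fit$fit, engine_path)
  # The workflow itself (recipe, parsnip wrapper) is saved as a regular .rds file.
  saveRDS(wf, wf_path)
  invisible(wf)
}

# Hypothetical helper: read both pieces back and reattach the engine fit.
load_workflow_parts <- function(wf_path, engine_path,
                                load_engine = function(path) lightgbm::lgb.load(filename = path)) {
  wf <- readRDS(wf_path)
  # Restore the engine fit in the innermost slot, leaving the parsnip wrapper intact.
  wf$fit$fit$fit <- load_engine(engine_path)
  wf
}
```

With the objects from the question, this would be called as `save_workflow_parts(fit_lgbm_workflow, "lgbm_workflow.rds", "lgbm_model.txt")`, followed later by `predict(load_workflow_parts("lgbm_workflow.rds", "lgbm_model.txt"), new_data = data_predict)`. The save and load functions are passed as arguments so that a different engine's serializer can be swapped in.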
I've never had any problems saving workflow objects with readr::write_rds(); perhaps you could try that function. - Mark Rieke