在R中使用Caret保存和加载catboost模型

3

我能够使用caret(在Rstudio中)训练Catboost模型,效果非常好。

my_catboost <- caret::train(x, y, 

              method=catboost.caret, 
              trControl=fitControl, 
              tuneGrid = param,
              metric = "ROC")

如果我在同一会话中使用模型来预测新数据,没有问题,它可以正常工作:
output <- caret::predict.train(my_catboost, newdata=x_testing, type="prob")

然而,如果我保存模型并稍后加载它(或者保存它,删除“my_catboost”并加载),函数“predict”将在没有错误消息的情况下崩溃R和Rstudio,并且无法在Rstudio日志中找到任何信息。 加载后,我可以在全局环境中看到创建的模型,它看起来很好。
我尝试了R函数save和load,saveRDS和readRDS,但两者都崩溃了。
谢谢!

在使用caret调整模型后,只需保存最终模型:my_catboost$finalModel,使用catboost::catboost.save_model进行保存,并使用catboost::catboost.load_model进行加载。如果您在caret::train中使用了preProcess参数进行数据转换,则情况会稍微复杂一些。如果您对此感兴趣,我可以提供一个可重现的示例,使用内置数据集。此外,这个问题可能应该发布在Github上的问题 - undefined
谢谢@missuse。RStudio在catboost::catboost.save_model时崩溃。我在这里打开了一个GitHub问题:https://github.com/catboost/catboost/issues/342#issuecomment-431256587 - undefined
请告诉我我在 GitHub 上发布的测试是否有意义。否则,我们将尝试一种带有内置数据的方法。 - undefined
你误解了我的评论,请检查我添加的答案。这应该可以运行。 - undefined
1个回答

6
你误解了我的评论。这里是使用内置数据集Sonar的答案:
library(caret)
library(catboost)
library(mlbench)
data(Sonar)

创建训练和测试数据集:

set.seed(1)

tr <- createDataPartition(Sonar$Class, p = 0.7, list = FALSE)

trainer <- Sonar[tr,]
tester <- Sonar[-tr,]

训练模型:

fitControl <- trainControl(method = "cv",
                           number = 3,
                           savePredictions = TRUE,
                           summaryFunction = twoClassSummary,
                           classProbs = TRUE)

model <- train(x = trainer[,1:60],
               y = trainer$Class,
               method = catboost.caret, 
               trControl = fitControl, 
               tuneLength = 5,
               metric = "ROC")

使用caret进行预测:

preds1 <- predict(model, tester, type = "prob")

保存最终模型:

catboost::catboost.save_model(model$finalModel, "model")

加载已保存的模型:

model2 <- catboost::catboost.load_model("model")

使用已保存的模型进行预测:

preds2 <- catboost.predict(model2,
                           catboost.load_pool(tester),
                           prediction_type = "Probability")

检查预测结果是否相等

all.equal(preds1[,2], preds2)

编辑:在此期间:

saveRDS(model, "caret.model.rds")
model3 <- readRDS("caret.model.rds")
preds3 <- predict(model3, tester, type = "prob")

R会话崩溃的原因

R version 3.5.0 (2018-04-23)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows >= 8 x64 (build 9200)

Matrix products: default

locale:
[1] LC_COLLATE=English_United States.1252  LC_CTYPE=English_United States.1252    LC_MONETARY=English_United States.1252
[4] LC_NUMERIC=C                           LC_TIME=English_United States.1252    

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
[1] mlbench_2.1-1        catboost_0.10.3      caret_6.0-80         ggplot2_2.2.1        lattice_0.20-35      RevoUtils_11.0.0    
[7] RevoUtilsMath_11.0.0

loaded via a namespace (and not attached):
 [1] httr_1.3.1         magic_1.5-8        ddalpha_1.3.3      tidyr_0.8.1        sfsmisc_1.1-2      jsonlite_1.5      
 [7] viridisLite_0.3.0  splines_3.5.0      foreach_1.5.0      prodlim_2018.04.18 assertthat_0.2.0   stats4_3.5.0      
[13] DRR_0.0.3          yaml_2.1.19        robustbase_0.93-0  ipred_0.9-6        pillar_1.2.3       glue_1.2.0        
[19] digest_0.6.15      colorspace_1.3-2   recipes_0.1.2      htmltools_0.3.6    Matrix_1.2-14      plyr_1.8.4        
[25] psych_1.8.4        timeDate_3043.102  pkgconfig_2.0.1    CVST_0.2-2         broom_0.4.4        purrr_0.2.4       
[31] scales_0.5.0       gower_0.1.2        lava_1.6.1         tibble_1.4.2       withr_2.1.2        nnet_7.3-12       
[37] lazyeval_0.2.1     mnormt_1.5-5       survival_2.41-3    magrittr_1.5       nlme_3.1-137       MASS_7.3-49       
[43] dimRed_0.1.0       foreign_0.8-70     class_7.3-14       tools_3.5.0        data.table_1.11.4  stringr_1.3.1     
[49] plotly_4.7.1       kernlab_0.9-26     munsell_0.4.3      bindrcpp_0.2.2     compiler_3.5.0     RcppRoll_0.2.2    
[55] rlang_0.2.0        grid_3.5.0         iterators_1.0.10   htmlwidgets_1.2    geometry_0.3-6     gtable_0.2.0      
[61] ModelMetrics_1.1.0 codetools_0.2-15   abind_1.4-5        reshape2_1.4.3     R6_2.2.2           lubridate_1.7.4   
[67] dplyr_0.7.5        bindr_0.1.1        stringi_1.1.7      parallel_3.5.0     Rcpp_0.12.17       rpart_4.1-13      
[73] DEoptimR_1.0-8     tidyselect_0.2.4  

谢谢!它运行正常。但是我很好奇为什么它在使用其他保存方法时会崩溃。 - undefined
1
很高兴能帮到你。正如我在对你问题的评论中提到的,当保存一个插入符模型时无法工作的事实值得在 GitHub 上的问题中提及。 - undefined

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接