How to extract, step by step, what a graph tuned with mlr3 did?


Here is my code:

library(mlr3verse)
library(mlr3pipelines)
library(mlr3filters)
library(paradox)
filter_importance = mlr_pipeops$get(
  "filter",
  filter = FilterImportance$new(learner = lrn("classif.ranger", importance = "impurity")),
  param_vals = list(filter.frac = 0.7)
)

learner_classif = lrn(
  "classif.ranger",
  predict_type = "prob",
  importance = "impurity",
  num.trees = 500
)
polrn_classif = PipeOpLearner$new(learner_classif)

# create learner graph 
glrn_classif = filter_importance %>>%  polrn_classif
glrn_classif = GraphLearner$new(glrn_classif)
glrn_classif$predict_type = "prob"

# task 

task = tsk("german_credit")

# set search_space
ps_classif = ParamSet$new(list(
  ParamInt$new("classif.ranger.num.trees", lower = 300, upper = 500),
  ParamDbl$new("classif.ranger.sample.fraction", lower = 0.7, upper = 0.8)
))

# auto tuning
at = AutoTuner$new(
  learner = glrn_classif, 
  resampling = rsmp("cv", folds = 3),
  measure = msr("classif.auc"), 
  search_space = ps_classif, 
  terminator = trm("evals", n_evals = 3), 
  tuner = tnr("random_search")
)

# resampling (outer loop)
rr = resample(task, at, rsmp("cv", folds = 2))
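A side note on the parameter names in ps_classif: inside a GraphLearner, every hyperparameter id is prefixed with the id of the PipeOp it belongs to. That is why the search space uses classif.ranger.num.trees rather than num.trees. The sketch below rebuilds the same graph with the po()/flt() shortcuts, which should be equivalent to the code above, and lists the relevant ids. The object names filter_step, ranger_step and glrn are only illustrative.

```r
library(mlr3verse)

# Same two-step graph as in the question: importance filter, then ranger
filter_step <- po(
  "filter",
  filter = flt("importance", learner = lrn("classif.ranger", importance = "impurity")),
  filter.frac = 0.7
)
ranger_step <- po(lrn("classif.ranger", predict_type = "prob",
                      importance = "impurity", num.trees = 500))
glrn <- GraphLearner$new(filter_step %>>% ranger_step)

# Hyperparameter ids follow the pattern "<PipeOp id>.<parameter>"
ids <- glrn$param_set$ids()
ids[grepl("filter.frac|num.trees|sample.fraction", ids)]
```

The filter PipeOp takes its id (importance) from the filter, and the learner PipeOp takes its id (classif.ranger) from the learner.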

After getting the rr object from resampling and the trained learner at, could you tell me how to extract what each of these steps did?

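For example:

  • Given the result in the at object, how can I re-run it manually?
  • Which samples (train_index, test_index) were used in each step?
  • Which variables were selected in the filter_importance step, and what scores did they get in that step?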

Thank you very much!!!

1 Answer

To be able to inspect the fitted models after resampling, it is best to call resample with store_models = TRUE.
Using your example:
library(mlr3verse)

set.seed(1)
rr <- resample(task,
               at,
               rsmp("cv", folds = 2),
               store_models = TRUE)
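The following sketch shows why store_models matters. A plain classif.rpart learner stands in for the AutoTuner at so that the snippet runs on its own. Without store_models = TRUE, the learners kept in the ResampleResult have no fitted model, so there is nothing to inspect afterwards.

```r
library(mlr3verse)

task <- tsk("german_credit")
learner <- lrn("classif.rpart")  # hypothetical stand-in for the AutoTuner `at`

set.seed(1)
rr_without <- resample(task, learner, rsmp("cv", folds = 2))
set.seed(1)
rr_with <- resample(task, learner, rsmp("cv", folds = 2), store_models = TRUE)

# Fitted model of outer fold 1 (NULL when models were not stored)
is.null(rr_without$learners[[1]]$model)
is.null(rr_with$learners[[1]]$model)
```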

After resampling, you can access the internals of the resulting object like this:

Get the row IDs in each fold:

rr$resampling$instance
#output
      row_id fold
   1:      5    1
   2:      8    1
   3:      9    1
   4:     12    1
   5:     13    1
  ---            
 996:    989    2
 997:    993    2
 998:    994    2
 999:    995    2
1000:    996    2
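The instantiated Resampling object also returns the train_index and test_index of each outer iteration directly, through $train_set(i) and $test_set(i). The sketch below uses classif.rpart as a stand-in for at. These are the outer splits only. The inner 3-fold splits used during tuning are handled inside each AutoTuner and are not covered here.

```r
library(mlr3verse)

task <- tsk("german_credit")
set.seed(1)
rr <- resample(task, lrn("classif.rpart"), rsmp("cv", folds = 2), store_models = TRUE)

# train_index / test_index of every outer fold
idx <- lapply(seq_len(rr$resampling$iters), function(i) {
  list(train_index = rr$resampling$train_set(i),
       test_index  = rr$resampling$test_set(i))
})

# The test indices agree with splitting the instance table by fold
inst <- rr$resampling$instance
test_by_split <- split(inst$row_id, inst$fold)

str(idx)
```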

With these tuned AutoTuners we can generate the predictions manually.

Create a list of test indices, one element per fold:

rsample <- split(rr$resampling$instance$row_id,
                 rr$resampling$instance$fold)

Loop over the cross-validation folds, take the tuned AutoTuner of each fold and predict:

lapply(1:2, function(i){
  x <- rsample[[i]] #get the test row ids
  task_test <- task$clone() #clone the task so we don't change the original task
  task_test$filter(x) #filter on the test row ids
  preds <- rr$learners[[i]]$predict(task_test) #use the trained autotuner and above filtered task
  preds
  }) -> preds_manual

To check that these predictions match the output of resample:

all.equal(preds_manual,
          rr$predictions())
#output
TRUE
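The loop above reuses models that were already trained. To re-run a fold from scratch, you can train an untrained copy of the learner on that fold's training rows and predict its test rows. This sketch uses classif.rpart as a hypothetical stand-in. For the AutoTuner, calling $train() on a copy would rerun the whole inner tuning. A cheaper route, suggested here but not tested, is to assign tuning_result$learner_param_vals[[1]] to a clone of glrn_classif via $param_set$values and then train that. With ranger, which is random, the results only match if the RNG state matches too.

```r
library(mlr3verse)

task <- tsk("german_credit")
learner <- lrn("classif.rpart")  # hypothetical stand-in for `at` / glrn_classif

set.seed(1)
rr <- resample(task, learner, rsmp("cv", folds = 2), store_models = TRUE)

# Re-run outer fold 1 from scratch: untrained copy, fit on the training rows
i <- 1
manual <- learner$clone(deep = TRUE)
manual$train(task, row_ids = rr$resampling$train_set(i))

# Predict the test rows of the same fold
pred_manual <- manual$predict(task, row_ids = rr$resampling$test_set(i))
pred_rr <- rr$predictions()[[i]]

# Compare predicted classes after sorting both by row id
same <- all(pred_manual$response[order(pred_manual$row_ids)] ==
            pred_rr$response[order(pred_rr$row_ids)])
same
```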

Get information about the tuning:

zz <- rr$data$learners()$learner

lapply(zz, function(x) x$tuning_result)
#output
[[1]]
   classif.ranger.num.trees classif.ranger.sample.fraction learner_param_vals
1:                      342                      0.7931022          <list[7]>
    x_domain classif.auc
1: <list[2]>   0.7981283

[[2]]
   classif.ranger.num.trees classif.ranger.sample.fraction learner_param_vals
1:                      407                      0.7964164          <list[7]>
    x_domain classif.auc
1: <list[2]>   0.7706533
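Recent mlr3tuning versions provide a helper, extract_inner_tuning_results(), that stacks these per-fold results into one table. Each tuned AutoTuner also keeps an archive of every configuration it evaluated. The sketch below uses a small rpart AutoTuner that tunes cp as a stand-in for at. The learner, search space and cp range are assumptions chosen to keep it fast.

```r
library(mlr3verse)
library(data.table)

task <- tsk("german_credit")

# Hypothetical small stand-in for `at`: tune rpart's cp with 3 random-search evaluations
at_small <- AutoTuner$new(
  learner = lrn("classif.rpart", predict_type = "prob"),
  resampling = rsmp("cv", folds = 3),
  measure = msr("classif.auc"),
  search_space = ps(cp = p_dbl(lower = 0.001, upper = 0.1)),
  terminator = trm("evals", n_evals = 3),
  tuner = tnr("random_search")
)

set.seed(1)
rr <- resample(task, at_small, rsmp("cv", folds = 2), store_models = TRUE)

# Best configuration per outer fold, stacked into one table
best <- extract_inner_tuning_results(rr)
best

# Every configuration evaluated in the inner loop of outer fold 1
archive_1 <- as.data.table(rr$learners[[1]]$archive)
archive_1[, c("cp", "classif.auc")]
```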

The slot

zz[[1]]$learner$state$model$importance

contains the state of the filter_importance step (inside the graph, this PipeOp has the id importance).
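The filter wrapped by this PipeOp can also be run on its own, which shows what the step computes. This sketch uses flt() and $calculate() from mlr3filters. It scores all 1000 rows. Inside the pipeline, the scores are computed on each training set instead (the outer folds, and again the inner folds during tuning), and ranger importance is random, so the numbers will differ from the output below.

```r
library(mlr3verse)

task <- tsk("german_credit")

# Same filter as in the question, run outside the pipeline on all rows
filter <- flt("importance", learner = lrn("classif.ranger", importance = "impurity"))
set.seed(1)
filter$calculate(task)

scores <- filter$scores  # named numeric vector, one score per feature
head(scores)
```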

Specifically:

lapply(zz, function(x) x$learner$state$model$importance$scores)
#output
[[1]]
                 amount                  status                     age 
              27.491369               25.776145               22.021369 
               duration                 purpose          credit_history 
              18.732521               16.251643               14.884843 
    employment_duration                 savings                property 
              11.225678               10.796583                9.078619 
    personal_status_sex       present_residence        installment_rate 
               8.914802                7.875384                7.491573 
                    job          number_credits other_installment_plans 
               6.293323                5.662485                5.345666 
                housing               telephone           other_debtors 
               4.869471                3.742213                3.548856 
          people_liable          foreign_worker 
               2.632163                1.054919 

[[2]]
                 amount                duration                     age 
              26.764389               22.139400               20.749865 
                 status                 purpose     employment_duration 
              20.524764               11.793789               10.962301 
         credit_history        installment_rate                 savings 
              10.416572                9.597835                9.491894 
               property       present_residence                     job 
               9.403157                7.877391                6.760945 
    personal_status_sex                 housing other_installment_plans 
               6.699065                5.811131                5.710761 
              telephone           other_debtors          number_credits 
               4.716322                4.318972                3.974793 
          people_liable          foreign_worker 
               3.196563                0.846520 

contains the feature scores on which the ranking is based, while

lapply(zz, function(x) x$learner$state$model$importance$outtasklayout)
#output
[[1]]
                     id    type
 1:                 age integer
 2:              amount integer
 3:      credit_history  factor
 4:            duration integer
 5: employment_duration  factor
 6:    installment_rate ordered
 7:                 job  factor
 8:      number_credits ordered
 9: personal_status_sex  factor
10:   present_residence ordered
11:            property  factor
12:             purpose  factor
13:             savings  factor
14:              status  factor

[[2]]
                     id    type
 1:                 age integer
 2:              amount integer
 3:      credit_history  factor
 4:            duration integer
 5: employment_duration  factor
 6:             housing  factor
 7:    installment_rate ordered
 8:                 job  factor
 9: personal_status_sex  factor
10:   present_residence ordered
11:            property  factor
12:             purpose  factor
13:             savings  factor
14:              status  factor

contains the features that were kept after the filter step.
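The two outputs can be linked by hand. With filter.frac = 0.7 and 20 features, keeping the 14 highest scores of fold 1 gives exactly the 14 features listed in outtasklayout for fold 1. This is a base R sketch that uses the printed scores. The rounding rule for the number of kept features is an assumption, but it does not matter here because 20 × 0.7 = 14.

```r
# Scores of outer fold 1, copied from the printed output above
scores <- c(
  amount = 27.491369, status = 25.776145, age = 22.021369,
  duration = 18.732521, purpose = 16.251643, credit_history = 14.884843,
  employment_duration = 11.225678, savings = 10.796583, property = 9.078619,
  personal_status_sex = 8.914802, present_residence = 7.875384,
  installment_rate = 7.491573, job = 6.293323, number_credits = 5.662485,
  other_installment_plans = 5.345666, housing = 4.869471,
  telephone = 3.742213, other_debtors = 3.548856,
  people_liable = 2.632163, foreign_worker = 1.054919
)

filter_frac <- 0.7
# Assumption: number of kept features = round(n_features * filter.frac)
n_keep <- round(length(scores) * filter_frac)

# Keep the n_keep highest-scoring features
kept <- names(sort(scores, decreasing = TRUE))[seq_len(n_keep)]
sort(kept)
```

Applied to the fold 2 scores, the same steps give the fold 2 list, which contains housing instead of number_credits.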


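Thanks a lot @missuse. Where do I need to call set.seed: before the auto-tuning, before the resampling, or both? Can I reproduce the results from the hash? - BinhNN
First, I would instantiate the outer resampling folds so that you can reuse exactly the same data splits. If you plan to compare several pipelines, use a benchmark call. Set the seed right before the actual computation, i.e. before calling resample or benchmark. I am not sure what you mean by "reproduce from the hash"? - missuse
Oh, sorry: the hash is the unique identifier of an object. I thought of it as another kind of "set.seed". - BinhNN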
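The following sketch follows the workflow from the comments: instantiate the outer resampling once, then call set.seed() directly before resample() or benchmark(). To keep the example small, classif.rpart and classif.featureless stand in for the pipelines being compared, and the seed values are arbitrary.

```r
library(mlr3verse)

task <- tsk("german_credit")

# Fix the outer split once; every later run reuses this instance
set.seed(1)
outer_cv <- rsmp("cv", folds = 2)
outer_cv$instantiate(task)

# Seed right before the actual computation
set.seed(42)
rr_a <- resample(task, lrn("classif.rpart"), outer_cv, store_models = TRUE)
set.seed(42)
rr_b <- resample(task, lrn("classif.rpart"), outer_cv, store_models = TRUE)

# Comparing several learners on the same instantiated splits
design <- benchmark_grid(
  tasks = task,
  learners = list(lrn("classif.rpart"), lrn("classif.featureless")),
  resamplings = outer_cv
)
set.seed(42)
bmr <- benchmark(design)
bmr$aggregate(msr("classif.ce"))
```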

Content from Stack Overflow.