不使用caret包实现模型选择的留一交叉验证(LOOCV)

4

我有一个数据集,我正在尝试使用R来找到最佳模型

数据集(以下是 dput(total) 的输出;执行 total <- structure(...) 即可还原后文代码使用的数据框 total):

structure(list(V1 = c(1.43359910241166, 0.411971077467806, 0.236361845246534, 
-0.289263426819727, -1.23202861459847, -0.738796384986188, 0.200420172968439, 
1.55763841132305, 0.306848974278087, -1.06336757529454, 0.208462982445177, 
-0.161933544137143, -0.529226737265933, 1.06311471300635, 0.154281146875831, 
0.609577869238014, -0.13720552696616, 0.920650581183744, 1.18282854178987, 
-0.792945405521446, -0.609722647650392, -0.21688852962299, -3.06426186175807, 
0.5498848363865), V2 = c(-0.322161064448354, -0.203202315321523, 
-1.37681357972322, -2.09183896169083, -1.73416522569493, -0.167163678879473, 
-0.496644140621754, -0.378640254832213, 1.71897857982319, 0.987886990249993, 
-0.464176577243306, 0.313599912560739, 0.279305189424942, 0.621879051693468, 
-1.35413705469938, -0.904307866112488, 0.563960402008738, 0.942178870082166, 
1.05504527675313, 1.72684177309, 0.487583880103759, 0.366982237506534, 
0.341207392409481, 0.0878011635613361), V3 = c(-3.06259779143185, 
0.113156471002083, -0.596111339640452, -0.0549465535711572, -0.941898864240695, 
-0.653015082018507, -0.169956676284042, -0.35411953808696, 0.713862293279259, 
1.20019049753438, 0.295042002436139, -0.248609439893179, 1.9312167684667, 
0.674670687298312, 0.224140747830105, -0.59349261052001, 0.0558808922143246, 
0.749007982254512, 1.04584894162381, 0.280651184742914, -0.313568542992107, 
-1.54267082673779, 0.397265080266878, 0.850053716467332), V4 = c(-2.72697312636474, 
0.851743869193346, -0.0599187094978506, 0.341978048955579, -0.484015693411596, 
-0.131475393689722, -0.021866557862478, -1.8191792655517, 1.74883985589495, 
-0.446343374015597, 0.0107633789594956, 0.55528371030783, 0.31132242799237, 
0.0710046563366782, 0.701388784100771, 1.56870481640847, -0.841113890934613, 
-0.881987858407386, 1.37693978208629, -0.488560120797117, 0.366895195216852, 
0.0627972059134885, -0.655416452787133, 0.589188711953821), V5 = c(-1.79836984688233, 
0.50295466271361, -1.17227869532777, 0.661412408202374, 0.853890060320874, 
0.349725611664228, -0.308069063888987, -0.433246608902138, -0.178767449882736, 
1.34125510863996, 0.206474174580616, -0.657831069822233, 0.215632332747088, 
0.573672331330443, -0.202823754124207, 0.609758501891791, 0.222044482387977, 
2.56037433110525, -1.29345283990688, 0.174550400877521, -0.174265216769768, 
0.55419775558349, -0.458225457879011, -2.14861215865916), V6 = c(-0.18026818728965, 
-0.480816154309526, -0.50256960223903, -1.31874854057412, -0.896086924318379, 
-1.79382217103909, -1.60213450587948, -0.481119812364401, 0.377075792056211, 
1.34981730088023, 0.0611706096060544, 0.83874651540465, 0.58899516399665, 
1.24066391945654, -1.08080170411743, 0.597620326597847, -1.21365483260366, 
0.230893469563153, -0.576677068566099, 1.31703258659203, 0.35136844419016, 
0.925208426922233, 1.73348977742475, 0.514617170610343), V7 = c(0.692646184527114, 
1.64958468445801, -0.722861261417701, -0.411292490473929, -1.73926867251488, 
0.479847732965793, 0.224291785874008, -0.650661070391403, -0.20779505689401, 
-0.900990363217965, 0.712570690351891, 0.0291624484927884, 0.613871305452367, 
-0.901767959624604, -0.184130922600279, 2.60941994159236, 0.0144701586285878, 
1.00941096184201, -1.07148389565784, -0.439790917550134, -0.786567592396622, 
0.926243735906836, -1.39392614240757, 0.449016715055174), V8 = c(-0.218730876718155, 
0.279536175230915, -0.860839531512879, 1.62382620633742, -0.656202640703168, 
-3.05801703213563, 0.243884147081474, 0.926579301241956, 0.58184138659717, 
-0.0814078168437784, -0.0668035158044736, 0.00153834639170001, 
0.806767895958209, 0.834326360087515, -0.0790896439523125, 0.07028192584928, 
-0.619273530317688, 1.07556660504801, -1.13473924521572, 0.668145147063421, 
0.758090513962191, 0.456430947715887, -1.73160959029873, 0.179898464937389
), V9 = c(2.56974590352874, -0.263155790779132, 0.646658371629822, 
-0.752843366448987, 0.200047856906594, 0.659371008337854, 1.24620285734473, 
0.94634794321528, -1.3304334794271, 1.33090401796431, -0.819840444239054, 
0.272969704571894, -0.486961950780986, 0.169639870524667, -0.451658048721127, 
-1.04537018765646, -1.16107891054576, -1.20995090654021, -0.839823653138378, 
0.62253221198192, 0.622634591405887, -0.547608828939565, 0.786557248787584, 
-1.16488601898254), V10 = c(-2.26412916115509, 0.67348993363598, 
-0.342027192999345, 0.249815496496033, 0.30352488488975, -0.744451635640458, 
1.58487417838063, -1.01570448604582, -0.541105970352036, 1.13647671257197, 
-0.54886598448313, -0.962789161396563, -0.538065955333129, 0.0781727823942247, 
0.0970193660300894, 1.18927210039089, -0.6957686086705, -0.386785336508124, 
-0.35257548033064, 2.31937096293864, -0.549132531058022, -0.0974568592721698, 
1.43853645612397, -0.0316945106071529), V11 = c(-1.86095070927053, 
0.573330283491408, -1.03183858717977, -1.83745190916475, -0.077180684913356, 
-0.94533768863225, -0.641638632478328, 0.154349543995556, 1.89664953662371, 
1.3494700201932, 1.04343452008192, 1.03948878970461, 0.394740150081754, 
1.24869842481551, 0.33270007318232, 0.373677276693529, 0.670774298645023, 
-0.0191045174843475, 0.0901593335518681, -0.813757209813031, 
-0.527741614949631, -1.55637393322463, -0.0817683516977811, 0.225671587747989
), V12 = c(0.235155165117673, 0.0334071835637513, 0.141983465568844, 
0.441692874434554, 0.0707526888389656, 0.332161357520943, 0.0735800395703528, 
-0.281305763416249, 0.16538364649173, -1.15487983901285, 1.56899928098857, 
-0.567750194144175, 0.541218236160627, 1.48159680904495, -0.568523352759803, 
-0.0545712227404042, -2.93340050534491, 0.662421496450859, 1.11729205722267, 
-0.581175560009803, 0.792548304722282, 0.955149345977461, -0.821090667653583, 
-1.65064484659245), V13 = c(-1.97412125867671, 0.44572205242864, 
-0.274712915255066, -1.44692140049933, -1.18035700830368, -0.260286573948736, 
-0.95815595797825, -0.242760674716397, 0.477953228907608, 0.992878959448502, 
0.48518262700317, -0.882424015844636, 2.03856721097186, 0.782640940939034, 
0.00789969362112054, -0.295894328060507, 1.27922468162261, 0.51472928905797, 
0.0447383908218823, 0.165638463053774, -0.263332324321804, -1.15204704327981, 
-0.258342890933598, 1.95418085394235), V14 = c(-0.181993529177506, 
1.39403983793056, -0.152733307069606, -1.52421030170283, -0.924924418962197, 
-0.364387222675804, 1.10283509955152, 0.0727783277608945, -1.77522562543095, 
1.08664918075833, -1.04803884297856, -0.940631906527986, 1.12617755875177, 
1.21705368328955, -0.279102677856877, 0.343713803473868, 1.26542530994074, 
-0.774396836280874, 0.417125600747737, 1.49096714826284, 0.284166748008431, 
-1.53295609357739, 0.105608954195959, -0.407940490431605), V15 = c(-1.46474265513464, 
1.19486941463858, 0.244933071673175, -0.459011700723317, 0.241718140420906, 
0.282959623977014, 0.00585677416957126, -2.03400384857495, 0.537918956631718, 
-1.04030075327707, -0.557219563096931, -0.252427064540924, 0.547956268292219, 
-0.526158422645334, 0.251554548033225, -0.745912076395139, -0.0351666299711204, 
1.15204026955591, 0.842246979246097, 1.52268303136091, -1.90156582122334, 
-0.142035061237368, 0.385224459566802, 1.94858205925399), V16 = c(0.828548104520814, 
0.713189024971904, 0.774573684318552, -0.425568343697551, 0.259608074896051, 
-1.22029633555545, -0.344755278537263, 0.973749897026122, -0.474553098183039, 
0.0257155566445092, -0.476287023663646, 0.974669054546108, -1.77164686907544, 
1.56028342699847, 1.24959541751606, -0.574201649578301, 1.2099741843225, 
-0.0750690376790856, -0.0110241372862062, -0.984530244128971, 
-2.52086075001167, 0.0287667805602271, 0.731343831738835, -0.451224270663529
), V17 = c(-0.681074029216176, -0.0390433509889875, 0.0328512523391066, 
1.12428796011696, 0.176765286103444, -0.222850967042728, 0.988520019729737, 
2.09179105565111, 0.116819106946508, 0.51447781508645, 1.87648378755979, 
-1.08036997332246, -0.418517756914466, 0.291253915397003, -0.355756145391065, 
0.874359244531183, -2.35192438381252, -0.200559130397419, -1.29305021151605, 
-0.216777649470054, -1.43207151780606, -0.392317470556723, 0.447601162558867, 
0.149101980414553), V18 = c(-1.96475300593026, 0.422711683040055, 
-1.12996029903421, -2.33587910613298, 0.179352498545959, -0.600058127770143, 
-1.35077156778998, -0.727365308346169, 1.43052873254504, 1.07048786910024, 
1.15649152054786, 0.702163956193049, 0.599458156020645, 0.489172517239038, 
0.957116387643539, 0.335186798948586, -0.598777825023964, 0.10012893280699, 
0.0822063408722808, 0.393896776121708, 0.968441995451939, -0.625513747288306, 
-0.437871585012806, 0.883606407251895), V19 = c(0.203243289070699, 
0.206783154660488, 0.0730205054389099, 0.151752499129077, 0.339065300597841, 
0.198750153846351, 0.246574181097875, 0.219716854159337, 0.112571755773366, 
0.108437458425644, 0.159923853880819, 0.198217376539615, 1.27794667790059, 
0.0628191359027579, -0.023668700184257, 0.0103470645871769, -4.55192891533295, 
0.0932248108210876, 0.0372915017676821, 0.103290843005291, 0.1485089149749, 
0.167015138770557, 0.258108289841612, 0.198988855325523), V20 = c(-0.6885610185506, 
0.215106818871655, -1.26229703607397, -1.15415874394993, -0.770942786330788, 
-1.07811513531511, -1.34581518035362, 0.296281823344214, -0.525449013409778, 
1.52659228597052, 1.66011376586839, 0.204981756466606, 2.25710524990656, 
0.850893107617607, 0.181598239123184, 0.0790398588000734, -0.0665218787774753, 
0.411298611581292, 0.0839458342094344, -0.122405563089466, -1.6897393933796, 
1.24061257187769, -0.157685318761091, -0.145878855645788), outcome_var = c(-3, 
4, 1, -1, -1, -3, -1, -3, 3, 2, -2, -3, 1, 0, 0, 0, 3, 0, 2, 
2, 1, -3, 1, 0)), class = "data.frame", row.names = c(NA, -24L
)) 

代码:

# Sequential-replacement subset selection fitted through caret, with each
# candidate model size (nvmax = 1..5) scored by leave-one-out
# cross-validation. `total` is the data frame whose dput() output is shown
# above.
library(caret)

train.control <- trainControl(method = "LOOCV")

step.model <- train(
  outcome_var ~ .,
  data = total,
  method = "leapSeq",
  tuneGrid = data.frame(nvmax = 1:5),
  trControl = train.control
)

# Cross-validated error metrics for the tuning grid.
step.model$results

# Which variables enter the final model at each subset size.
summary(step.model$finalModel)

结果:

20 Variables  (and intercept)
Forced in Forced out
V1      FALSE      FALSE
V2      FALSE      FALSE
V3      FALSE      FALSE
V4      FALSE      FALSE
V5      FALSE      FALSE
V6      FALSE      FALSE
V7      FALSE      FALSE
V8      FALSE      FALSE
V9      FALSE      FALSE
V10     FALSE      FALSE
V11     FALSE      FALSE
V12     FALSE      FALSE
V13     FALSE      FALSE
V14     FALSE      FALSE
V15     FALSE      FALSE
V16     FALSE      FALSE
V17     FALSE      FALSE
V18     FALSE      FALSE
V19     FALSE      FALSE
V20     FALSE      FALSE
1 subsets of each size up to 3
Selection Algorithm: 'sequential replacement'
         V1  V2  V3  V4  V5  V6  V7  V8  V9  V10 V11 V12 V13 V14 V15 V16 V17 V18 V19 V20
1  ( 1 ) " " " " "*" " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " "
2  ( 1 ) " " " " "*" " " " " " " " " " " " " " " " " " " " " " " " " " " "*" " " " " " "
3  ( 1 ) " " " " " " " " " " " " " " " " " " "*" " " " " "*" " " " " " " "*" " " " " " "

这使我得到了我想要的结果,但现在我正在尝试制作自己的LOOCV函数,而不使用caret包。但我没有得到相同的结果。

#' Leave-one-out cross-validated mean squared error of a fitted model
#'
#' Refits `fit` once per observation with that observation held out,
#' predicts the held-out response and averages the squared errors.
#' The model specification is kept fixed: no variable selection is
#' repeated inside the folds.
#'
#' @param fit A fitted model that stores `residuals` and its model frame
#'   in `model` (e.g. the result of `lm()`) and works with `update()` and
#'   `predict()`.
#' @return The mean of the squared leave-one-out prediction errors.
loocv <- function(fit) {
  n <- length(fit$residuals)
  yvar <- fit$model[, 1]
  # seq_len() yields an empty sequence when n is 0, unlike 1:n.
  index <- seq_len(n)
  e <- rep(NA_real_, n)
  for (i in index) {
    # Refit the same model with observation i excluded.
    refit <- update(fit, subset = index != i)
    # Base row indexing replaces dplyr::slice(), so no extra package is
    # required just to pick out the held-out row.
    held_out <- fit$model[i, , drop = FALSE]
    e[i] <- yvar[i] - predict(refit, newdata = held_out)
  }
  mean(e^2)
}

如何在不使用caret包的情况下使用LOOCV并找到最佳拟合模型?


很好的问题,希望有人能够帮助你,不幸的是我不知道你问题的答案。 - NotRikBurgers
@IanCampbell 感谢您的推荐,我已经添加了 dput(total) 的输出。 - FlubberBeer
除非是我看漏了,否则我在你的代码中没有看到变量I的定义。正如你可能知道的那样,R中的变量是区分大小写的。 - Ian Campbell
是的,从Rstudio复制到这里出了问题,但应该使用小写字母i - FlubberBeer
1个回答

0
对于LOOCV这类交叉验证,每一折(fold)都应该只用该折的训练数据从头重新构建模型(包括重新做变量选择)。通过反复试验,我认为caret是使用leaps::regsubsets来进行逐步模型选择的。
library(leaps)

nvmax <- 3 # maximum number of predictors in the selected model
response <- "outcome_var"
predictors <- setdiff(names(total), response)
n_obs <- nrow(total)
pred <- rep(NA_real_, n_obs)

for (i in seq_len(n_obs)) { # LOOCV: hold out row i
  train <- total[-i, , drop = FALSE]

  # Redo the variable selection from scratch on the training rows only.
  tem <- regsubsets(
    x = train[, predictors, drop = FALSE],
    y = train[[response]],
    nvmax = nvmax,
    method = "seqrep"
  )

  # Predictors in the size-`nvmax` subset; the first column of `which`
  # is the intercept, so it is dropped before indexing.
  chosen <- predictors[summary(tem)$which[nvmax, -1]]

  fit <- lm(
    outcome_var ~ .,
    data = train[, c(chosen, response), drop = FALSE]
  )

  # Predict the held-out row.
  pred[i] <- predict(fit, newdata = total[i, , drop = FALSE])
}

obs <- total[[response]]

# Error metrics computed in base R so the whole procedure is caret-free.
sqrt(mean((pred - obs)^2)) # RMSE
#1.945036

mean(abs(pred - obs)) # MAE
#1.442353

caret 的结果:

# Resampling summary from the caret fit, shown for comparison with the
# RMSE and MAE computed by the manual loop.
step.model$results
# nvmax     RMSE    Rsquared      MAE
#     3 1.945036 0.238655497 1.442353

网页内容由 Stack Overflow 提供,点击原文链接可以查看英文原文。