我目前正在使用Python中的GridSearchCV进行3倍交叉验证,以优化超参数。我想知道是否有办法查看在GridSearchCV中使用的交叉验证中的训练和测试数据的索引?
我目前正在使用Python中的GridSearchCV进行3倍交叉验证,以优化超参数。我想知道是否有办法查看在GridSearchCV中使用的交叉验证中的训练和测试数据的索引?
如果您不想在CV阶段折叠之前洗牌样本,那么可以这样做。您可以将KFold
(或其他CV类)的实例传递给GridSearchCV
构造函数,并像这样访问其折叠:
import pandas as pd
import numpy as np
from sklearn.model_selection import GridSearchCV
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold
params = {'penalty' : ['l1', 'l2'], 'C' : [1,2,3]}
grid = GridSearchCV(LogisticRegression(), params, cv=KFold(n_splits=3))
X = np.array([[1, 2], [3, 4], [1, 2], [3, 4], [5, 6], [7, 8]])
for train, test in grid.cv.split(X):
print('TRAIN: ', train, ' TEST: ', test)
输出结果为:
TRAIN: [2 3 4 5] TEST: [0 1]
TRAIN: [0 1 4 5] TEST: [2 3]
TRAIN: [0 1 2 3] TEST: [4 5]
grid = GridSearchCV(LogisticRegression(), params,
cv=list(KFold(n_splits=3, shuffle=True).split(X)))
除了迭代器外,列表是一个固定的对象,除非您手动操作它,否则它将在所有GridSearch迭代中保持相同的值。
sklearn.grid_search
和sklearn.cross_validation
导入GridSearchCV
和KFold
,而不是从model_selection
。在0.18中,模块组织方式发生了变化。 - Toterich__iter__()
。请参见这里的“可迭代对象”。 - Toterich