在GridSearchCV中是否有一种方法可以查看交叉验证的折叠方式?

11

我目前正在使用Python中的GridSearchCV进行3倍交叉验证,以优化超参数。我想知道是否有办法查看在GridSearchCV中使用的交叉验证中的训练和测试数据的索引?

1个回答

6

如果您不想在CV阶段折叠之前洗牌样本,那么可以这样做。您可以将KFold(或其他CV类)的实例传递给GridSearchCV构造函数,并像这样访问其折叠:

import pandas as pd
import numpy as np
from sklearn.model_selection import GridSearchCV
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold

params = {'penalty' : ['l1', 'l2'], 'C' : [1,2,3]}
grid = GridSearchCV(LogisticRegression(), params, cv=KFold(n_splits=3))

X = np.array([[1, 2], [3, 4], [1, 2], [3, 4], [5, 6], [7, 8]])

for train, test in grid.cv.split(X):
    print('TRAIN: ', train, ' TEST: ', test)

输出结果为:

TRAIN:  [2 3 4 5]  TEST:  [0 1]
TRAIN:  [0 1 4 5]  TEST:  [2 3]
TRAIN:  [0 1 2 3]  TEST:  [4 5]

对于非随机的CV,交叉验证的折叠总是相同的,因此您可以确信这些折叠将在网格搜索期间使用。
如果您想在折叠之前对样本进行洗牌,则会更加复杂,因为每次调用cv.split()都会生成不同的拆分。我能想到两种方法:
1. 您可以向CV对象提供固定的random_state,例如KFold(n_splits=3, shuffle=True, random_state=42)。 2. 在创建GridSearchCV对象之前,从KFold迭代器中创建一个列表。
因此,对于第二种方法,请执行:
grid = GridSearchCV(LogisticRegression(), params, 
                    cv=list(KFold(n_splits=3, shuffle=True).split(X)))

除了迭代器外,列表是一个固定的对象,除非您手动操作它,否则它将在所有GridSearch迭代中保持相同的值。


非常感谢您的帮助。它可以在sklearn v0.18中工作,但不幸的是我目前正在使用v0.17。有没有办法在v0.17中实现这个? - Frederica
它应该是一样的,只需记住您需要从sklearn.grid_searchsklearn.cross_validation导入GridSearchCVKFold,而不是从model_selection。在0.18中,模块组织方式发生了变化。 - Toterich
@Toterich - 我有点惊讶你的第二种方法“在创建GridSearchCV对象之前,从KFold迭代器创建一个列表。”能够工作,因为根据文档(http://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html),cv应该是None、int、generator或iterable。将列表传递给它是否显而易见,还是未经记录? - DJBunk
列表是可迭代对象,因为它们实现了 __iter__()。请参见这里的“可迭代对象” - Toterich

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接