为什么在将StratifiedKFold()作为GridSearchCV的参数传递时需要调用split()函数?

3

我想做什么?

我试图在 GridSearchCV() 中使用 StratifiedKFold()

那么,是什么让我感到困惑了?

当我们使用 K 折交叉验证时,只需在 GridSearchCV() 中传递 CV 的数量,如下所示。

grid_search_m = GridSearchCV(rdm_forest_clf, param_grid, cv=5, scoring='f1', return_train_score=True, n_jobs=2)

接下来,当我需要使用StratifiedKFold()时,我认为流程应该保持不变。也就是说,只需设置拆分数-StratifiedKFold(n_splits=5)cv

grid_search_m = GridSearchCV(rdm_forest_clf, param_grid, cv=StratifiedKFold(n_splits=5), scoring='f1', return_train_score=True, n_jobs=2)

但是这个答案说:

whatever the cross validation strategy used, all that is needed is to provide the generator using the function split, as suggested:

kfolds = StratifiedKFold(5)
clf = GridSearchCV(estimator, parameters, scoring=qwk, cv=kfolds.split(xtrain,ytrain))
clf.fit(xtrain, ytrain)
此外,这个问题的一个答案也建议这样做。 这意味着,他们建议在使用GridSearchCV()时调用split()函数:StratifiedKFold(n_splits=5).split(xtrain,ytrain)。 但是,我发现调用split()和不调用split()会给出相同的f1分数。
因此,我的问题是:
  • 我不明白为什么在分层K折交叉验证中需要调用split()函数,而在K折交叉验证中不需要这样做。

  • 如果调用了split()函数,GridSearchCV()如何工作,因为split()函数返回训练和测试数据集索引? 也就是说,我想知道GridSearchCV()将如何使用这些索引?

1个回答

3

GridSearchCV是聪明的,可以为cv参数提供多个选项 - 一个数字、一个拆分索引的迭代器或一个带有拆分函数的对象。你可以在这里看到代码 这里,下面是复制的代码。

cv = 5 if cv is None else cv
if isinstance(cv, numbers.Integral):
    if (classifier and (y is not None) and
            (type_of_target(y) in ('binary', 'multiclass'))):
        return StratifiedKFold(cv)
    else:
        return KFold(cv)

if not hasattr(cv, 'split') or isinstance(cv, str):
    if not isinstance(cv, Iterable) or isinstance(cv, str):
        raise ValueError("Expected cv as an integer, cross-validation "
                         "object (from sklearn.model_selection) "
                         "or an iterable. Got %s." % cv)
    return _CVIterableWrapper(cv)

return cv  # New style cv objects are passed without any modification

基本上,如果你没有传递任何参数,它将使用5折交叉验证(KFold)。如果是分类问题并且目标是二元/多元的,它还会自动聪明地使用分层k折(StratifedKFold)。

如果你传递了一个带有split函数的对象,它就会使用该函数。如果您没有传递任何参数,但传递了可迭代对象,则它会假定该对象是拆分索引的可迭代对象,并为您封装它。

所以,在您的情况下,假设这是一个分类问题,并且目标是二元/多元的,则以下所有选项都将给出完全相同的结果/拆分-使用哪个选项都没有关系!

cv=5
cv=StratifiedKFold(5)
cv=StratifiedKFold(5).split(xtrain,ytrain)

感谢您的回复。您提到:“如果传递一个带有分割函数的对象,它将使用该函数。”但我不明白“GridSearchCV()如何使用通过分割找到的那些索引?”请您详细说明一下吗? - Md. Sabbir Ahmed
因此,对于网格搜索中的每个参数集,它将使用拆分来运行交叉验证 - 因此,如果您在param grid中有2个参数的3个选项(6个集合),并且进行5倍交叉验证,则实际上您将训练和验证30个模型。然后,在交叉验证运行中具有最高平均验证分数的参数集被选为“最佳”。 - Ken Syme

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接