GridSearchCV: "TypeError: 'StratifiedKFold' object is not iterable" 网格搜索交叉验证: "TypeError: 'StratifiedKFold'对象不可迭代"

11

我想在一个RandomForestClassifier中执行GridSearchCV,但是数据不平衡,所以我使用了StratifiedKFold:

from sklearn.model_selection import StratifiedKFold
from sklearn.grid_search import GridSearchCV
from sklearn.ensemble import RandomForestClassifier

param_grid = {'n_estimators':[10, 30, 100, 300], "max_depth": [3, None],
          "max_features": [1, 5, 10], "min_samples_leaf": [1, 10, 25, 50], "criterion": ["gini", "entropy"]}

rfc = RandomForestClassifier()

clf = GridSearchCV(rfc, param_grid=param_grid, cv=StratifiedKFold()).fit(X_train, y_train)

但是出现了一个错误:

TypeError                                 Traceback (most recent call last)
<ipython-input-597-b08e92c33165> in <module>()
     9 rfc = RandomForestClassifier()
     10 
---> 11 clf = GridSearchCV(rfc, param_grid=param_grid, cv=StratifiedKFold()).fit(X_train, y_train)

c:\python34\lib\site-packages\sklearn\grid_search.py in fit(self, X, y)
    811 
    812         """
--> 813         return self._fit(X, y, ParameterGrid(self.param_grid))

c:\python34\lib\site-packages\sklearn\grid_search.py in _fit(self, X, y, parameter_iterable)
    559                                     self.fit_params, return_parameters=True,
    560                                     error_score=self.error_score)
--> 561                 for parameters in parameter_iterable
    562                 for train, test in cv)

c:\python34\lib\site-packages\sklearn\externals\joblib\parallel.py in __call__(self, iterable)
    756             # was dispatched. In particular this covers the edge
    757             # case of Parallel used with an exhausted iterator.
--> 758             while self.dispatch_one_batch(iterator):
    759                 self._iterating = True
    760             else:

c:\python34\lib\site-packages\sklearn\externals\joblib\parallel.py in dispatch_one_batch(self, iterator)
    601 
    602         with self._lock:
--> 603             tasks = BatchedCalls(itertools.islice(iterator, batch_size))
    604             if len(tasks) == 0:
    605                 # No more tasks available in the iterator: tell caller to stop.

c:\python34\lib\site-packages\sklearn\externals\joblib\parallel.py in __init__(self, iterator_slice)
    125 
    126     def __init__(self, iterator_slice):
--> 127         self.items = list(iterator_slice)
    128         self._size = len(self.items)

c:\python34\lib\site-packages\sklearn\grid_search.py in <genexpr>(.0)
    560                                     error_score=self.error_score)
    561                 for parameters in parameter_iterable
--> 562                 for train, test in cv)
    563 
    564         # Out is a list of triplet: score, estimator, n_test_samples

TypeError: 'StratifiedKFold' object is not iterable
当我写cv=StratifiedKFold(y_train)时出现了ValueError: The number of folds must be of Integral type.,但是当我写cv=5时可以正常工作。我不明白StratifiedKFold有什么问题。
4个回答

10

我也遇到了完全相同的问题。对于我起作用的解决方案是替换

from sklearn.grid_search import GridSearchCV

with

:与,带有。
from sklearn.model_selection import GridSearchCV

那么它应该可以正常工作。


6
问题在于API更改,正如其他答案中提到的那样,但答案可以更加明确。 "cv"参数文档说明如下:
cv:int,交叉验证生成器或可迭代对象,可选。 确定交叉验证拆分策略。 cv的可能输入为: - None,使用默认的3倍交叉验证; - 整数,指定折叠次数; - 用作交叉验证生成器的对象; - 产生训练/测试拆分的可迭代对象。
对于整数/无输入,如果y是二进制或多类,则使用StratifiedKFold;如果估计量是分类器或者y既不是二进制也不是多类,则使用KFold。
因此,无论使用何种交叉验证策略,只需使用函数“split”提供生成器即可,如建议所示。
kfolds = StratifiedKFold(5)
clf = GridSearchCV(estimator, parameters, scoring=qwk, cv=kfolds.split(xtrain,ytrain))
clf.fit(xtrain, ytrain)

2
似乎应将cv=StratifiedKFold()).fit(X_train, y_train)更改为cv=StratifiedKFold()).split(X_train, y_train).,这里涉及到it技术。

这与错误无关。这一行代码:clf = GridSearchCV(rfc, param_grid=param_grid, cv=StratifiedKFold()).fit(X_train, y_train) 仅仅定义了 clf 对象,然后调用 fit 方法来训练/拟合 clf。 - seralouk
@rll还提到应该用split替换fit。 - ebrahimi

0

API在最新的版本中已经更改。创建stratifiedKFold对象时,您先前使用过y参数,现在只需传递数字即可。您稍后再传递y参数。


我写了 cv=StratifiedKFold(10),但是出现了 TypeError: 'StratifiedKFold' object is not iterable 的错误。我应该在什么时候传递 y 呢? - user183897
在当前版本中,您需要导入sklearn.model_selection.StratifiedKFold。然后,您可以执行cv=StratifiedKFold(10),这样就不会出现错误。但是,也许您正在从先前的模块中导入,该模块仍然存在于兼容性目的,直到20版。 - simon
我能再问一个问题吗?我从这个网站http://www.lfd.uci.edu/~gohlke/pythonlibs/#scikit-learn下载了文件scikit_learn-0.18-cp34-cp34m-win32.whl,安装后出现了“ImportError: DLL load failed: %1 is not a valid Win32 application.”的错误提示。这是怎么回事? - user183897
可能是某个依赖项缺失了。简单的方法是下载Anaconda,然后它就可以正常工作了。 - simon

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接