使用GridSearchCV时跳过禁止的参数组合

20

我想使用GridSearchCV来贪婪地搜索支持向量分类器的整个参数空间,但是某些参数组合被LinearSVC所禁止,并且会抛出异常。特别地,dualpenaltyloss参数存在互斥的组合:

例如,以下代码:

from sklearn import svm, datasets
from sklearn.model_selection import GridSearchCV

iris = datasets.load_iris()
parameters = {'dual':[True, False], 'penalty' : ['l1', 'l2'], \
              'loss': ['hinge', 'squared_hinge']}
svc = svm.LinearSVC()
clf = GridSearchCV(svc, parameters)
clf.fit(iris.data, iris.target)

返回ValueError:不支持的参数组合:当dual=False时,penalty='l2'和loss='hinge'的组合不受支持。 参数:penalty='l2',loss='hinge',dual=False

我的问题是:是否可能使GridSearchCV跳过模型禁止的参数组合?如果不能,是否有一种简单的方法来构造不会违反规则的参数空间?


如果至少在这种情况下我们能够抑制FitFailedWarning语句,那么这仍然是一个问题,但问题会小一些。我面临着同样的问题,我知道某些组合是非法的,但是为了防止这些组合,逻辑(如下所述)过于丑陋。 - demongolem
2个回答

28

我通过将error_score=0.0传递给GridSearchCV来解决了这个问题:

error_score: ‘raise’(默认)或数字

如果在estimator拟合期间发生错误,则为分数分配的值。如果设置为‘raise’,则会引发错误。如果给定数字值,则会引发FitFailedWarning。该参数不影响refit步骤,后者将始终引发错误。

更新:sklearn的新版本会打印出一堆ConvergenceWarningFitFailedWarning。我很难用contextlib.suppress抑制它们,但有一个骚操作可以绕过这个问题,涉及到一种测试上下文管理器:

from sklearn import svm, datasets 
from sklearn.utils._testing import ignore_warnings 
from sklearn.exceptions import FitFailedWarning, ConvergenceWarning 
from sklearn.model_selection import GridSearchCV 

with ignore_warnings(category=[ConvergenceWarning, FitFailedWarning]): 
    iris = datasets.load_iris() 
    parameters = {'dual':[True, False], 'penalty' : ['l1', 'l2'], \ 
                 'loss': ['hinge', 'squared_hinge']} 
    svc = svm.LinearSVC() 
    clf = GridSearchCV(svc, parameters, error_score=0.0) 
    clf.fit(iris.data, iris.target)

1
有没有一种方法可以在它们实际输出任何错误之前避免这些组合(或任何其他组合)的解决方法? - GRoutar
@Khabz我的答案太长了,无法放在评论中,所以我将其作为另一个答案发布。 - crypdick
@crypdick 有没有办法避免在结果中看到 FitFailedWarning? - Nihat
1
@Nihat 我编辑了我的答案以消除新的警告。 - crypdick

5
如果您想完全避免探索特定组合(而不必等待出现错误),则必须自己构建网格。GridSearchCV可以接受一个字典列表,其中列表中的每个字典所覆盖的网格都会被探索。
在这种情况下,条件逻辑还好,但对于更复杂的情况,这将非常繁琐。
from sklearn import svm, datasets
from sklearn.model_selection import GridSearchCV
from itertools import product

iris = datasets.load_iris()

duals = [True, False]
penaltys = ['l1', 'l2']
losses = ['hinge', 'squared_hinge']
all_params = list(product(duals, penaltys, losses))
filtered_params = [{'dual': [dual], 'penalty' : [penalty], 'loss': [loss]}
                   for dual, penalty, loss in all_params
                   if not (penalty == 'l1' and loss == 'hinge') 
                   and not ((penalty == 'l1' and loss == 'squared_hinge' and dual is True))
                  and not ((penalty == 'l2' and loss == 'hinge' and dual is False))]

svc = svm.LinearSVC()
clf = GridSearchCV(svc, filtered_params)
clf.fit(iris.data, iris.target)

2
我很感激你的努力,但这似乎是一个有些靠不住的解决方案,会导致大量冗长的代码,对于一个有很多限制条件的问题来说并不适合。 - GRoutar
1
@Khabz同意,这段代码真是诅咒!如果有无数个条件语句,一种可能是在filtered_params中以编程方式构建条件语句列表,然后使用str.join(conditionals_list)将其连接起来,最后使用eval()函数执行字符串以进行列表推导。 - crypdick

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接