如何以编程方式检测Scikit-learn警告

8

使用sklearn.neural_network.MLPClassifier拟合模型时,有时会在控制台打印一个警告:

ConvergenceWarning: Stochastic Optimizer:达到最大迭代次数(300),但优化尚未收敛。

有没有方法在运行时检测警告,以便我可以采取行动?

3个回答

9
你可以使用 warnings.catch_warnings 实时捕获警告。
import warnings

with warnings.catch_warnings()
    warnings.filterwarnings('error')
    try:
        model.fit(X, y)
    except Warning:
        # do something in response

这个结构将捕获任何在代码行中出现的警告,使你能够根据需要回应它。在这种情况下,你可以修改一些超参数使模型更容易收敛。

你也可以使用 warnings.filterwarnings 忽略警告,并指定要忽略的警告类型。

要忽略 ConvergenceWarning

from sklearn.execpetions import ConvergenceWarning

warnings.filterwarnings('ignore', category=ConvergenceWarning)

...

警告是否是API的一部分?文档似乎没有给出任何保证,即警告在各个版本中都将保持不变,因此这种方法对我来说似乎有点脆弱。 - Juan I Carrano
在“with”行末缺少分号吗? - Dusan Kojic
这个答案有两个问题:(1) 'with'行应该以冒号结尾。(2) 使用'warnings.filterwarnings('error')'将警告转换为异常,所以捕获子句'except Warning'不起作用。应该改为'except Exception'。 - undefined

2

在拟合后检查n_iter_属性。如果它小于您配置的最大迭代次数(max_iter),则说明已经收敛。


你可以从哪个对象/类中在块内获取n_iter?一些代码会有所帮助。 - Shawn Cicoria
这个答案是不正确的。例如,AffinityPropagation可能会提前停止,发出警告并返回一些任意结果,而不是实际收敛。 - Artur Pschybysz
1
assert model.n_iter_ < MAX_ITER, "Convergence failed!" - Rune Kaagaard

1
假设您想训练您的scikit-learn模型,并且希望能够存储警告(如果有的话)。 假设您按以下方式拟合模型:
clf.fit(X_train,y)
如果您想捕获警告,则可以运行该模型:
with warnings.catch_warnings(record=True) as caught_warnings:
    clf.fit(X_train, y)

最后,您可以通过以下方式迭代caught_warnings来获取警告:
for warn in caught_warnings:
    print(warn)

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接