使用statsmodel估计与scikit-learn交叉验证,是否可行?

39

我将这个问题发布到了Cross Validated论坛,后来意识到可能在stackoverflow中会找到更合适的受众。

我正在寻找一种方法,可以使用从Python statsmodels得出的fit对象(结果),将其输入到scikit-learn交叉验证方法的cross_val_score中?附带的链接表明这可能是可能的,但我没有成功。

我遇到了以下错误:

评估器应该成为实现'fit'方法的评估器 传递了statsmodels.discrete.discrete_model.BinaryResultsWrapper对象于0x7fa6e801c590

请参阅此链接


我认为这是正确的:https://stackoverflow.com/questions/75535732/how-to-use-the-commonly-used-wrapper-for-models-from-statsmodels-to-apply-cross - Xtiaan
4个回答

50

确实,您不能直接在statsmodels对象上使用cross_val_score,因为它们的接口不同:

  • 在 statsmodels 中,训练数据直接传递到构造函数中。
  • 一个单独的对象包含模型估计的结果。

但是,您可以编写一个简单的包装器,使statsmodels对象看起来像sklearn评估器:

import statsmodels.api as sm
from sklearn.base import BaseEstimator, RegressorMixin

class SMWrapper(BaseEstimator, RegressorMixin):
    """ A universal sklearn-style wrapper for statsmodels regressors """
    def __init__(self, model_class, fit_intercept=True):
        self.model_class = model_class
        self.fit_intercept = fit_intercept
    def fit(self, X, y):
        if self.fit_intercept:
            X = sm.add_constant(X)
        self.model_ = self.model_class(y, X)
        self.results_ = self.model_.fit()
        return self
    def predict(self, X):
        if self.fit_intercept:
            X = sm.add_constant(X)
        return self.results_.predict(X)

这个类包含正确的fitpredict方法,并且可以与sklearn一起使用,例如交叉验证或包含到管道中。就像这样:

from sklearn.datasets import make_regression
from sklearn.model_selection import cross_val_score
from sklearn.linear_model import LinearRegression

X, y = make_regression(random_state=1, n_samples=300, noise=100)

print(cross_val_score(SMWrapper(sm.OLS), X, y, scoring='r2'))
print(cross_val_score(LinearRegression(), X, y, scoring='r2'))

你可以看到两个模型的输出是相同的,因为它们都是OLS模型,并以相同的方式进行了交叉验证。

[0.28592315 0.37367557 0.47972639]
[0.28592315 0.37367557 0.47972639]

1
我的交叉验证得分(cross_val_score())使用包装器后出现了NaN。有什么想法是什么原因? - Tony Ng
你正在每次在 cross_val_score 中重新初始化模型,我认为它应该在外面。 - Probhakar Sarkar
1
初始化的时刻不会影响结果。 - David Dale
我也得到了一系列NaN,因为我将一个已实例化的sm.OLS(y, X)版本传递给了SMWrapper()。只有当您像上面显示的那样传递SMWrapper(sm.OLS)时,它才起作用。 - fact_finder

13

参照David的建议(但是出现了错误,提示缺少函数get_parameters)以及scikit learn文档,我编写了以下线性回归包装器。它具有与sklearn.linear_model.LinearRegression相同的接口,并且还有summary()函数,该函数提供有关p值、R2和其他统计信息的信息,就像statsmodels.OLS一样。

import statsmodels.api as sm
from sklearn.base import BaseEstimator, RegressorMixin
import pandas as pd
import numpy as np

from sklearn.utils.multiclass import check_classification_targets
from sklearn.utils.validation import check_X_y, check_is_fitted, check_array
from sklearn.utils.multiclass import unique_labels
from sklearn.utils.estimator_checks import check_estimator



class MyLinearRegression(BaseEstimator, RegressorMixin):
    def __init__(self, fit_intercept=True):

        self.fit_intercept = fit_intercept


    """
    Parameters
    ------------
    column_names: list
            It is an optional value, such that this class knows 
            what is the name of the feature to associate to 
            each column of X. This is useful if you use the method
            summary(), so that it can show the feature name for each
            coefficient
    """ 
    def fit(self, X, y, column_names=() ):

        if self.fit_intercept:
            X = sm.add_constant(X)

        # Check that X and y have correct shape
        X, y = check_X_y(X, y)


        self.X_ = X
        self.y_ = y

        if len(column_names) != 0:
            cols = column_names.copy()
            cols = list(cols)
            X = pd.DataFrame(X)
            cols = column_names.copy()
            cols.insert(0,'intercept')
            print('X ', X)
            X.columns = cols

        self.model_ = sm.OLS(y, X)
        self.results_ = self.model_.fit()
        return self



    def predict(self, X):
        # Check is fit had been called
        check_is_fitted(self, 'model_')

        # Input validation
        X = check_array(X)

        if self.fit_intercept:
            X = sm.add_constant(X)
        return self.results_.predict(X)


    def get_params(self, deep = False):
        return {'fit_intercept':self.fit_intercept}


    def summary(self):
        print(self.results_.summary() )

使用示例:

cols = ['feature1','feature2']
X_train = df_train[cols].values
X_test = df_test[cols].values
y_train = df_train['label']
y_test = df_test['label']
model = MyLinearRegression()
model.fit(X_train, y_train)
model.summary()
model.predict(X_test)

如果想要显示列的名称,可以调用以下方法:

model.fit(X_train, y_train, column_names=cols)

要在交叉验证中使用它:

from sklearn.model_selection import cross_val_score
scores = cross_val_score(MyLinearRegression(), X_train, y_train, cv=10, scoring='neg_mean_squared_error')
scores

2
在最后一个评论中,“为什么在交叉验证中使用cross_val_score的时候要使用X_train和y_train而不是只用X和y?” - ctrl_z
4
因为我考虑以下协议: (i) 将样本分成训练和测试集 (ii) 仅使用训练集选择最佳模型,即给出最高交叉验证分数的模型,以避免任何数据泄漏 (iii) 在测试集中检查这种模型在“未见过”的数据上的性能。如果您在整个数据集上进行交叉验证,则会基于您对模型进行评估的同一数据选择模型。这将技术上造成数据泄漏。实际上,这不会向您展示模型如何在完全未见过的数据上运作。 - Andrea Araldo

6

为了参考,如果您使用 statsmodels 公式API 和/或使用 fit_regularized 方法,您可以按照以下方式修改 @David Dale 的包装器类。

import pandas as pd
from sklearn.base import BaseEstimator, RegressorMixin
from statsmodels.formula.api import glm as glm_sm

# This is an example wrapper for statsmodels GLM
class SMWrapper(BaseEstimator, RegressorMixin):
    def __init__(self, family, formula, alpha, L1_wt):
        self.family = family
        self.formula = formula
        self.alpha = alpha
        self.L1_wt = L1_wt
        self.model = None
        self.result = None
    def fit(self, X, y):
        data = pd.concat([pd.DataFrame(X), pd.Series(y)], axis=1)
        data.columns = X.columns.tolist() + ['y']
        self.model = glm_sm(self.formula, data, family=self.family)
        self.result = self.model.fit_regularized(alpha=self.alpha, L1_wt=self.L1_wt, refit=True)
        return self.result
    def predict(self, X):
        return self.result.predict(X)

-1

虽然我认为这不是技术上的scikit-learn,但有一个名为pmdarima(链接到PyPi上的pmdarima包)的软件包,它封装了statsmodel并提供了类似于scikit-learn的接口。


你好,安德烈。请考虑在你的回答中添加更多信息,而不是链接到外部来源。 - Igor Escodro
请仅返回翻译后的文本:请总结链接内容,以防链接失效。 - LudvigH

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接