如何使用scikit-learn API实现元估计器?

8

我希望实现一个简单的包装器/元估计器,它与scikit-learn兼容。很难找到对我所需内容的全面描述。

目标是拥有一个回归器,还要学习一个阈值以成为分类器。因此,我想到:

from sklearn.base import BaseEstimator, ClassifierMixin, clone

class Thresholder(BaseEstimator, ClassifierMixin):
    def __init__(self, regressor):
        self.regressor = regressor
        # threshold_ does not get initialized in __init__ ??

    def fit(self, X, y, optimal_threshold):
        self.regressor = clone(self.regressor)    # is this required my sklearn??
        self.regressor.fit(X, y)

        y_raw = self.regressor.predict()
        self.threshold_ = optimal_threshold(y_raw)

    def predict(self, X):
        y_raw = self.regressor.predict(X)

        y = np.digitize(y_raw, [self.threshold_])

        return y

这个实现是否包含我需要的完整 API?

我的主要问题是在哪里放置 threshold。我希望它只被学习一次,并且可以在后续的 .fit 调用中使用新数据而不需要重新调整。但是当前版本需要在每次 .fit 调用时重新调整 - 这不是我想要的。

另一方面,如果我将其作为固定参数 self.threshold 并传递给 __init__,那么我就不能使用数据更改它了吗?

如何创建一个 threshold 参数,它可以在一次 .fit 调用中进行调整,并在后续的 .fit 调用中保持不变?


请问为什么要进行多次“fit”调用?是因为在线学习吗?还是由于交叉验证?或者其他原因? - Shihab Shahriar Khan
@ShihabShahriarKhan 只有一个适合特定数据集(以及存储在 optimal_threshold 中的某些测试数据)的阈值才能用于确定阈值。从那时起,我希望不再重新调整阈值,只调整我的数据折叠的回归器。 - Gere
1
我可能误解了什么...如果你在init中将threshold_初始化为None,并在fit中检查它的值是否被设置,那么它不会起作用吗?有点类似于warm_start参数。 - Shihab Shahriar Khan
3
很难说它是否有效,因为sklearn有一定的API,并且像clone和其他元估计器之类的函数在幕后执行某种魔法。这就是为什么我想知道使用sklearn的正确方式。 - Gere
1
为什么不在构造函数中初始化 self.threshold = None,然后加上一个 if 语句 - if self.threshold is not None: self.threshold = optimal_threshold(y_raw)?虽然我认为更好的方法是在 fit 方法中添加一个布尔值,指示是否更新阈值。 - Rotem Tal
2个回答

1

我实际上前几天写了一篇关于这个的博客文章。我猜你正在尝试构建类似于TransformedTargetRegressor的东西,我建议查看它的源代码以构建类似的东西。

您当前的实现似乎正确。就此问题而言:

如何使阈值参数可以在一个.fit调用中进行调整,并在随后的.fit调用中保持不变?

我建议不要这样做,因为scikit-learn的API是基于fit方法重新拟合模型的所有可调参数。在这里有两种路线可以选择,一种是向fit添加一个**kwarg,明确保护threshold不受更新的影响,或者你可以选择@rotem-tal建议的方法。如果你选择后者,可能会像这样:
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin

def optimal_threshold(y_raw: np.ndarray) -> np.ndarray:
    return np.array([0.1, 0.5, 1])  # some implementation here

class Thresholder(BaseEstimator, ClassifierMixin):
    def __init__(self, regressor):
        self.regressor = regressor
        self.threshold = None

    def fit(self, X, y, optimal_threshold):
        # you don't need to clone the regressor
        self.regressor.fit(X, y)

        y_raw = self.regressor.predict()
        if self.threshold is None:
            self.threshold = optimal_threshold(y_raw)

    def predict(self, X):
        y_raw = self.regressor.predict(X)

        y = np.digitize(y_raw, [self.threshold_])

        return y

你需要遵循sklearn的API,否则可能会出现问题。我不知道sklearn的意图。你的版本将无法正常运行,因为 clone 不会复制 .threshold 属性。然而, cross_validate 使用了 clone。因此,使用我的修复阈值时 cross_validate 将无法正常工作! - Gere
我不确定你的意思,正如我在帖子中提到的那样,你试图做的事情是非标准的,可能无法与生态系统的其他部分(例如管道等)很好地集成,但这肯定是可能的。 - Adithya
1
我正在尝试找到一种方法来创建符合标准的估算器。并不是所有还未包含在sklearn中的内容都自动成为非标准。 - Gere
以上代码符合API标准,就编程接口而言,请试用一下。不过,“从概念上讲”,它并不符合标准,因为当调用“fit”时,阈值不会更新。 - Adithya
1
如果它不能与sklearn的关键函数配合使用,那么它就不符合API标准。仅仅使用类似名称的函数是不够的,因为还必须满足概念要求才能遵循API标准。这个代码版本在使用固定阈值的cross_validate时无法正常工作。我尝试过了,但出于已经解释的原因它就是无法正常工作。cross_validate将重新拟合每一折的阈值。 - Gere
从机器学习的角度来看,跨折叠保持阈值不变是没有意义的。在sklearn底层所做的是克隆估计器的结构并运行拟合,这会丢失之前折叠中学习到的阈值。虽然我不建议这样做,但解决这个问题的一种方法可能是使用Python全局值或类变量,这可以避开克隆问题。 - Adithya

0

一道完全合理的问题。是的,为了确保兼容性,必须:

  1. 除了参数持久化以外,不要在 init 中执行任何操作
  2. fit 中克隆内部评估器。只需使用下划线: self.regressor_ = clone(self.regressor)
  3. 为了提高灵活性,最好采用以下方式:def fit(self, X, y, **fit_params): optimal_threshold=fit_params.get('optimal_threshold',0.5) self.regressor_.fit(X, y, **fit_params),而不是仅使用 fit(self, X, y, optimal_threshold)

以及(可能)更好的性能

y_raw = self.regressor_.fit_predict(X, y, **fit_params)
  1. 添加到适配器

    # 检查 X 和 y 的形状是否正确
    
     X,y = check_X_y(X,y)
    
  2. 添加到预测器

    检查是否已经拟合

    check_is_fitted(self)

    输入验证

    X = check_array(X)

  3. 确保一致性检查通过,例如:

    from sklearn.utils.estimator_checks import check_estimator from sklearn.linear_model import LinearRegression

    check_estimator(Thresholder(regressor=LinearRegression()))

当然,阅读开发指南是必要的,特别是如果你需要设置random_state


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接