scikit-learn转换器中的数据不持久化

3

我想在scikit-learn中向变换器传递额外的数据:

from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.ensemble import RandomForestClassifier

from sklearn.pipeline import Pipeline
import numpy as np
from sklearn.model_selection import GridSearchCV

class myTransformer(BaseEstimator, TransformerMixin):
    def __init__(self, my_np_array):
        self.data = my_np_array
        print self.data

    def transform(self, X):
        return X

    def fit(self, X, y=None):
        return self

data = np.random.rand(20,20)
data2 = np.random.rand(6,6)
y = np.array([1, 2, 3, 1, 2, 3, 1, 1, 1, 2, 2, 2, 2, 2, 2, 1, 3, 3, 3, 3])

pipe = Pipeline(steps=[('myt', myTransformer(data2)), ('randforest', RandomForestClassifier())])
params = {"randforest__n_estimators": [100, 1000]}
estimators = GridSearchCV(pipe, param_grid=params, verbose=True)
estimators.fit(data, y)

然而,在 scikit-learn 管道中使用时,似乎会消失。

在 init 方法中打印出来的是 None,我该如何修复它?


你确定在传递mydata时它不是空的吗? - aberger
是的,它不是无。 - Bob
我猜你应该为估计器添加.fit,这样错误才会出现。 - lejlot
是的,谢谢。刚刚添加了。 - Bob
1个回答

6
这是因为sklearn以一种非常特定的方式处理估算器。通常情况下,它会为像网格搜索这样的操作创建该类的新实例,并将参数传递给构造函数。这是因为sklearn有自己的clone操作(在base.py中定义),它接受您的估算器类,获取参数(由get_params返回)并将其传递给您的类的构造函数。
klass = estimator.__class__
new_object_params = estimator.get_params(deep=False)
for name, param in six.iteritems(new_object_params):
    new_object_params[name] = clone(param, safe=False)
new_object = klass(**new_object_params) 

为了支持您的对象必须重写get_params(deep=False)方法,该方法应返回一个字典,该字典将传递给构造函数。

from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.pipeline import Pipeline
from sklearn.model_selection import GridSearchCV
class myTransformer(BaseEstimator, TransformerMixin):
    def __init__(self, my_np_array):
        self.data = my_np_array
        print self.data

    def transform(self, X):
        return X

    def fit(self, X, y=None):
        return self

    def get_params(self, deep=False):
        return {'my_np_array': self.data}

将按预期工作。


这不适用于整型和浮点型,对吗?所有变量都必须是实例本地的(self.variable)吗?例如,我需要它来获取最佳估计器。 - Bob
“不适用于整数和浮点数”?我的意思是你需要一个 get_params 方法,它将提供克隆转换器实例所需的一切。你可以把它们存储在任何地方,get_params 需要返回它,但你可以把它们作为静态属性、全局变量或任何你想要的东西。 - lejlot
抱歉再次打扰你,但我不确定我是否理解。我不太明白为什么参数不可见,因为网格搜索的目的是尝试不同的设置并找到最佳设置。那么官方库中的估算器是如何工作的呢?谢谢。 - Bob
这与可见性无关。这是scikit-learn开发人员决定克隆对象的方式(它有优点和缺点,但这是他们的设计决策)。在此没有其他要补充的。它们(估计器)都实现了此接口 - 它们具有get_params,调用它将返回复制对象所需的所有内容。还有set_param等具有对称含义的函数。 - lejlot

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接