Scikit-learn如何检查模型(例如TfidfVectorizer)是否已经拟合?

6

对于从文本中提取特征,如何检查矢量化器(例如TfIdfVectorizer或CountVectorizer)是否已经适合了训练数据?
特别是,我想让代码自动找出矢量化器是否已经适合。

from sklearn.feature_extraction.text import TfidfVectorizer

vectorizer = TfidfVectorizer()

def vectorize_data(texts):
  # if vectorizer has not been already fit
  vectorizer.fit_transform(texts)
  # else
  vectorizer.transform(texts)

请查看我的答案并告诉我。 - seralouk
2个回答

7
你可以使用专门用于此目的的check_is_fitted,具体可参考这里
TfidfVectorizer.transform()源代码中,你可以查看其用法。
def transform(self, raw_documents, copy=True):

    # This is what you need.
    check_is_fitted(self, '_tfidf', 'The tfidf vector is not fitted')

    X = super(TfidfVectorizer, self).transform(raw_documents)
    return self._tfidf.transform(X, copy=False)

所以在您的情况下,您可以这样做:
from sklearn.utils.validation import check_is_fitted

def vectorize_data(texts):

    try:
        check_is_fitted(vectorizer, '_tfidf', 'The tfidf vector is not fitted')
    except NotFittedError:
        vectorizer.fit(texts)

    # In all cases vectorizer if fit here, so just call transform()
    vectorizer.transform(texts)

4
我提出了两种检查方法:

适用于所有scikit-learn模型的个人代码:

import inspect

def my_inspector(model):
    return 0 < len( [k for k,v in inspect.getmembers(model) if k.endswith('_') and not k.startswith('__')] )

现在让我们测试这段代码:

from sklearn.feature_extraction.text import TfidfVectorizer
import inspect

vectorizer = TfidfVectorizer()

def my_inspector(model):
        return 0 < len( [k for k,v in inspect.getmembers(model) if k.endswith('_') and not k.startswith('__')] )

my_inspector(vectorizer)
# False

使用check_is_fitted的第二种方法

from sklearn.utils.validation import check_is_fitted

check_is_fitted(vectorizer, '_tfidf', 'The tfidf vector is not fitted')

明白了。我认为更好的方法是检查self.vocabulary_的存在,而不是所有类属性。 - CentAu
我发布了一个回答,涵盖了所有sklearn函数,并且对你的情况有效。请考虑给我的回答点赞并接受它。谢谢! - seralouk

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接