Scikit-learn管道:连接轴的所有输入数组维度必须完全匹配。

6
import pandas as pd
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.pipeline import Pipeline
from sklearn.svm import LinearSVC
from sklearn.preprocessing import MinMaxScaler
from sklearn.compose import ColumnTransformer

data = [[1, 3, 4, 'text', 'pos'], [9, 3, 6, 'text more', 'neg']]
data = pd.DataFrame(data, columns=['Num1', 'Num2', 'Num3', 'Text field', 'Class'])

tweet_text_transformer = Pipeline(steps=[
    ('count_vectoriser', CountVectorizer()),
    ('tfidf', TfidfTransformer())
])

numeric_transformer = Pipeline(steps=[
    ('scaler', MinMaxScaler())
])

preprocessor = ColumnTransformer(transformers=[
    # (name, transformer, column(s))
    ('tweet', tweet_text_transformer, ['Text field']),
    ('numeric', numeric_transformer, ['Num1', 'Num2', 'Num3'])
])

pipeline = Pipeline(steps=[
    ('preprocessor', preprocessor),
    ('classifier', LinearSVC())
])

X_train = data.loc[:, 'Num1':'Text field']
y_train = data['Class']
pipeline.fit(X_train, y_train)

我不明白这个错误是从哪里来的:

值错误:连接轴上的所有输入数组维度必须完全匹配,但在第 0 维上,索引为 0 的数组大小为 1,索引为 1 的数组大小为 2。


请提供完整的错误输出。 - AMC
我也遇到了完全相同的错误。你最终解决了吗? - Josh
2个回答

5

原因

问题出现在preprocessor流水线中。这个流水线的工作方式是将tweet_text_transformer的输出和numeric_transformer的输出水平堆叠。为了成功实现这一点,这两个输出(tweet_text_transformer和numeric_transformer)必须具有相同的行数(即轴0或维度-0中的元素数量)。

但是,当上述流水线被执行时,尽管我们期望tweet_text_processor给出一个2 * 2矩阵,共4个元素,但实际上由于CountVectorizer将输出存储为稀疏矩阵,它会删除矩阵中的任何零(以节省内存),这将导致数组变成2*2矩阵,但只有3个元素。当这个矩阵与numeric_transformer的输出堆叠在一起时,它无法满足上述条件(因为numeric transformer在轴0上只有两个元素,而tweet_text_processor没有)。

解释的输出

解决方法

  • 创建一个自定义转换器,将这个稀疏矩阵转换为numpy数组
  • 由于只有一列,所以挤压Pandas数据框以将其转换为Panadas Series
import pandas as pd
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.pipeline import Pipeline
from sklearn.svm import LinearSVC
from sklearn.preprocessing import MinMaxScaler
from sklearn.compose import ColumnTransformer

data = [[1, 3, 4, 'text', 'pos'], [9, 3, 6, 'text more', 'neg']]
data = pd.DataFrame(data, columns=['Num1', 'Num2', 'Num3', 'Text field', 'Class'])



class TweetTextProcessor(BaseEstimator, TransformerMixin):
    def __init__(self):
        self.tweet_text_transformer = Pipeline(steps=[
        ('count_vectoriser', CountVectorizer()),
        ('tfidf', TfidfTransformer())    ])
       
        
    def fit(self, X, y=None):
        return self

    def transform(self, X, y=None):
       
        return  self.tweet_text_transformer.fit_transform(X.squeeze()).toarray()
        




numeric_transformer = Pipeline(steps=[
    ('scaler', MinMaxScaler())
])

preprocessor = ColumnTransformer(transformers=[
    ('tweet', TweetTextProcessor(), ['Text field']),
    ('numeric', numeric_transformer, ['Num1', 'Num2', 'Num3'])
])

pipeline = Pipeline(steps=[
    ('preprocessor', preprocessor),
    ('classifier', LinearSVC())
])

X_train = data.loc[:, 'Num1':'Text field']
y_train = data['Class']
pipeline.fit(X_train, y_train)

以上代码应该可以正常工作,如果有问题请告知,或者如果解释不够清晰(希望已经足够清晰了)


2

我实现了您的代码解决方案,将稀疏矩阵转换为数组,并修复了错误。但是,当我调用predict时,它显示另一个错误。

model = pipeline.fit(X_train,y_train)
y_pred = model.predict(X_test)

它给我这个错误

ValueError: X每个样本有574个特征; 期望是493个

我的理解是,在这种情况下,它没有使用训练好的向量化模型,而是在X_test数据集上重新训练一个新的模型。我该如何解决这个问题,我不知道。

注意:需要为BaseEstimator和TransformerMixin添加导入语句

更新:

为了解决这个问题,使用FunctionTransformer代替定义一个类

使用FunctionTransformer代替定义一个类

from sklearn.preprocessing import FunctionTransformer

vectorizer_params = dict(ngram_range=(1, 2), min_df=5, max_df=0.8)

TweetTextProcessor = Pipeline(steps=[
    ("squeez", FunctionTransformer(lambda x: x.squeeze())),
    ("vect", CountVectorizer(**vectorizer_params)),
    ("tfidf", TfidfTransformer()),
    ("toarray", FunctionTransformer(lambda x: x.toarray())),
])

numeric_transformer = Pipeline(steps=[
    ('scaler', MinMaxScaler())
])

preprocessor = ColumnTransformer(transformers=[
    ('tweet', TweetTextProcessor, ['Text field']),
    ('numeric', numeric_transformer, ['Num1', 'Num2', 'Num3'])
])

pipeline = Pipeline(steps=[
    ('preprocessor', preprocessor),
    ('classifier', LinearSVC())
])

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接