如何在scikit-learn中实现多项式逻辑回归?

12

我试图使用scikit-learn创建一个非线性逻辑回归,即多项式逻辑回归。但我找不到如何定义多项式的次数。有人尝试过吗?谢谢!

1个回答

24

您需要分两步进行。假设您正在使用鸢尾花数据集(因此您有一个可重现的示例):

from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import PolynomialFeatures
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline

data = load_iris()
X = data.data
y = data.target
X_train, X_test, y_train, y_test = train_test_split(X, y)

步骤1

首先需要将您的数据转换为多项式特征。原始数据有4列:

X_train.shape
>>> (112,4)

您可以使用scikit learn创建多项式特征(以下是二次多项式特征的示例):

poly = PolynomialFeatures(degree = 2, interaction_only=False, include_bias=False)
X_poly = poly.fit_transform(X_train)
X_poly.shape
>>> (112,14)
我们现在有14个特征(原始的4个特征,它们的平方,以及6种交叉组合)。
第二步
在此基础上,现在您可以构建您的逻辑回归模型,调用X_poly
lr = LogisticRegression()
lr.fit(X_poly,y_train)
注意:如果您想在测试数据上评估您的模型,则还需要遵循这两个步骤并执行以下操作:
lr.score(poly.transform(X_test), y_test)

将所有内容整合到管道中(可选)

您可能希望使用管道而不是构建中间对象来处理这两个步骤的方式,以便将它们处理在一个对象中:

pipe = Pipeline([('polynomial_features',poly), ('logistic_regression',lr)])
pipe.fit(X_train, y_train)
pipe.score(X_test, y_test)

1
非常感谢您的详细解释! - Inna

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接