Sklearn线性回归 - “索引错误:元组索引超出范围”

10

我有一个“.dat”文件,其中保存了X和Y的值(因此是一个元组(n,2),其中n是行数)。

import numpy as np
import matplotlib.pyplot as plt
import scipy.interpolate as interp
from sklearn import linear_model

in_file = open(path,"r")
text = np.loadtxt(in_file)
in_file.close()
x = np.array(text[:,0])
y = np.array(text[:,1])

我为linear_model.LinearRegression()创建了一个实例,但是当我调用.fit(x,y)方法时出现了

IndexError: tuple index out of range

regr = linear_model.LinearRegression()
regr.fit(x,y)

我做错了什么?


抱歉我完全误读了你的问题 :( 我已经删除了答案,如果我能得到修复,那么我会取消删除编辑后的答案。但是你可以提供更多信息吗?比如你的完整代码? - Ffisegydd
这就是你需要的代码,没有其他重要的东西。 - JackLametta
真的吗?linear_model是什么?你是怎么得到它的? - Ffisegydd
现在就这些了,感谢您的帮助。 - JackLametta
x和Y的长度是否相同? - Santi Peñate-Vera
text.shape 是 (n,2),因此 x 和 y 都有 (n,)。 - JackLametta
1个回答

17

线性回归期望X是一个具有两个维度的数组,并且在内部需要X.shape[1]来初始化一个np.ones数组。因此,将X转换为nx1数组就可以解决问题。所以,请替换:

regr.fit(x,y)

作者:

regr.fit(x[:,np.newaxis],y)

这将解决问题。演示:

>>> from sklearn import datasets
>>> from sklearn import linear_model
>>> clf = linear_model.LinearRegression()
>>> iris=datasets.load_iris()
>>> X=iris.data[:,3]
>>> Y=iris.target
>>> clf.fit(X,Y)  # This will throw an error
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/lib/python2.7/dist-packages/sklearn/linear_model/base.py", line 363, in fit
    X, y, self.fit_intercept, self.normalize, self.copy_X)
  File "/usr/lib/python2.7/dist-packages/sklearn/linear_model/base.py", line 103, in center_data
    X_std = np.ones(X.shape[1])
IndexError: tuple index out of range
>>> clf.fit(X[:,np.newaxis],Y)  # This will work properly
LinearRegression(copy_X=True, fit_intercept=True, normalize=False)

使用以下代码绘制回归线:

>>> from matplotlib import pyplot as plt
>>> plt.scatter(X, Y, color='red')
<matplotlib.collections.PathCollection object at 0x7f76640e97d0>
>>> plt.plot(X, clf.predict(X[:,np.newaxis]), color='blue')
<matplotlib.lines.Line2D object at 0x7f7663f9eb90>
>>> plt.show()

在此输入图片描述


非常感谢您的帮助!另一个问题:现在我只得到线性回归的系数,这正常吗?我该如何绘制它的线条? - JackLametta
@JackLametta,这是非常正常的。这些系数用于根据给定的Y值预测X值。我已经上传了绘制线条的代码。 - Irshad Bhat

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接