ValueError:x 和 y 不能超过二维,但它们的形状分别为 (2, 1, 1) 和 (2,)。

3

源代码

cyprus_predicted_life_satisfaction = lin1.predict(cyprus_gdp_per_capita)[0][0]  # index [0][0] to pull a single element out of the predict() result
#cyprus_predicted_life_satisfaction = lin1.predict(cyprus_gdp_per_capita)
# NOTE(review): the output quoted below, array([[5.96244744]]), is a nested (2-D)
# array, which matches the un-indexed variant commented out above rather than the
# [0][0] line -- confirm which assignment was actually executed.

cyprus_predicted_life_satisfaction  # bare expression: displays the value in the notebook

OUT: array([[5.96244744]])

sample_data.plot(kind='scatter', x="GDP per capita", y='Life satisfaction', figsize=(5,3), s=1)  # scatter plot of the sample data
X=np.linspace(0, 60000, 1000)  # 1000 evenly spaced x values from 0 to 60000
plt.plot(X, t0 + t1*X, "b")  # blue line t0 + t1*X over that range
plt.axis([0, 60000, 0, 10])  # axis limits: x in [0, 60000], y in [0, 10]
plt.text(5000, 7.5, r"$\theta_0 = 4.85$", fontsize=14, color="b")  # annotate the intercept
plt.text(5000, 6.6, r"$\theta_1 = 4.91 \times 10^{-5}$", fontsize=14, color="b")  # annotate the slope
plt.plot([cyprus_gdp_per_capita, cyprus_gdp_per_capita], [0, cyprus_predicted_life_satisfaction], "r--")  # this is the call that raises the ValueError: x has shape (2, 1, 1), y has shape (2,) -- presumably cyprus_gdp_per_capita is a (1, 1) array rather than a scalar; TODO confirm
plt.text(25000, 5.0, r"Prediction = 5.96", fontsize=14, color="b")  # label for the predicted value
plt.plot(cyprus_gdp_per_capita, cyprus_predicted_life_satisfaction, "ro")  # red dot at the predicted point
save_fig('cyprus_prediction_plot')  # save_fig is defined outside this snippet
plt.show()

错误

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-36-c7d7af89ced4> in <module>
      5 plt.text(5000, 7.5, r"$\theta_0 = 4.85$", fontsize=14, color="b")
      6 plt.text(5000, 6.6, r"$\theta_1 = 4.91 \times 10^{-5}$", fontsize=14, color="b")
----> 7 plt.plot([cyprus_gdp_per_capita, cyprus_gdp_per_capita], [0, cyprus_predicted_life_satisfaction], "r--")
      8 plt.text(25000, 5.0, r"Prediction = 5.96", fontsize=14, color="b")
      9 plt.plot(cyprus_gdp_per_capita, cyprus_predicted_life_satisfaction, "ro")

~/anaconda3/lib/python3.7/site-packages/matplotlib/pyplot.py in plot(scalex, scaley, data, *args, **kwargs)
   2809     return gca().plot(
   2810         *args, scalex=scalex, scaley=scaley, **({"data": data} if data
-> 2811         is not None else {}), **kwargs)
   2812 
   2813 

~/anaconda3/lib/python3.7/site-packages/matplotlib/__init__.py in inner(ax, data, *args, **kwargs)
   1808                         "the Matplotlib list!)" % (label_namer, func.__name__),
   1809                         RuntimeWarning, stacklevel=2)
-> 1810             return func(ax, *args, **kwargs)
   1811 
   1812         inner.__doc__ = _add_data_doc(inner.__doc__,

~/anaconda3/lib/python3.7/site-packages/matplotlib/axes/_axes.py in plot(self, scalex, scaley, *args, **kwargs)
   1609         kwargs = cbook.normalize_kwargs(kwargs, mlines.Line2D._alias_map)
   1610 
-> 1611         for line in self._get_lines(*args, **kwargs):
   1612             self.add_line(line)
   1613             lines.append(line)

~/anaconda3/lib/python3.7/site-packages/matplotlib/axes/_base.py in _grab_next_args(self, *args, **kwargs)
    391                 this += args[0],
    392                 args = args[1:]
--> 393             yield from self._plot_args(this, kwargs)
    394 
    395 

~/anaconda3/lib/python3.7/site-packages/matplotlib/axes/_base.py in _plot_args(self, tup, kwargs)
    368             x, y = index_of(tup[-1])
    369 
--> 370         x, y = self._xy_from_xy(x, y)
    371 
    372         if self.command == 'plot':

~/anaconda3/lib/python3.7/site-packages/matplotlib/axes/_base.py in _xy_from_xy(self, x, y)
    232         if x.ndim > 2 or y.ndim > 2:
    233             raise ValueError("x and y can be no greater than 2-D, but have "
--> 234                              "shapes {} and {}".format(x.shape, y.shape))
    235 
    236         if x.ndim == 1:

ValueError: x and y can be no greater than 2-D, but have shapes (2, 1, 1) and (2,)

尝试的解决方案

1个回答

3
你需要使用 reshape 函数来完成:
a = np.random.random(size=(2, 1, 1))
a.shape
>> (2, 1, 1)

a = a.reshape(-1,)
a.shape
>> (2,)

如果这不能解决您的问题,请提供样本数据。

网页内容由 Stack Overflow 提供,点击原文链接
可以查看英文原文。
原文链接