使用Matplotlib在三维空间中绘制线性模型图。

6

我正在尝试创建一个线性模型的三维图,用于数据集拟合。我在R中能够相对容易地完成这个任务,但是我在Python中做同样的事情真的很困难。这是我在R中所做的:

3d plot

这是我在Python中所做的:

from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import statsmodels.formula.api as sm

csv = pd.read_csv('http://www-bcf.usc.edu/~gareth/ISL/Advertising.csv', index_col=0)
model = sm.ols(formula='Sales ~ TV + Radio', data = csv)
fit = model.fit()

fit.summary()

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

ax.scatter(csv['TV'], csv['Radio'], csv['Sales'], c='r', marker='o')

xx, yy = np.meshgrid(csv['TV'], csv['Radio'])

# Not what I expected :(
# ax.plot_surface(xx, yy, fit.fittedvalues)

ax.set_xlabel('TV')
ax.set_ylabel('Radio')
ax.set_zlabel('Sales')

plt.show()

我做错了什么,应该怎么做?谢谢。

你介意提供你的代码吗? - singularli
2个回答

12

搞定了!

我评论mdurant的答案时提到的问题是,曲面没有像结合散点图和曲面图本身那样呈现为一个漂亮的方格图案。

我意识到问题在于我的meshgrid,所以我纠正了两个范围(xy),并对np.arange使用了比例步长。

这使我能够使用mdurant答案中提供的代码,而且完美地运行了!

下面是结果:

带OLS平面的三维散点图

以下是代码:

from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import statsmodels.formula.api as sm
from matplotlib import cm

csv = pd.read_csv('http://www-bcf.usc.edu/~gareth/ISL/Advertising.csv', index_col=0)
model = sm.ols(formula='Sales ~ TV + Radio', data = csv)
fit = model.fit()

fit.summary()

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

x_surf = np.arange(0, 350, 20)                # generate a mesh
y_surf = np.arange(0, 60, 4)
x_surf, y_surf = np.meshgrid(x_surf, y_surf)

exog = pd.core.frame.DataFrame({'TV': x_surf.ravel(), 'Radio': y_surf.ravel()})
out = fit.predict(exog = exog)
ax.plot_surface(x_surf, y_surf,
                out.reshape(x_surf.shape),
                rstride=1,
                cstride=1,
                color='None',
                alpha = 0.4)

ax.scatter(csv['TV'], csv['Radio'], csv['Sales'],
           c='blue',
           marker='o',
           alpha=1)

ax.set_xlabel('TV')
ax.set_ylabel('Radio')
ax.set_zlabel('Sales')

plt.show()

4

你正确地假设plot_surface需要一个网格坐标来工作,但是predict需要像你拟合的那个数据结构一样的数据结构(即“exog”)。

exog = pd.core.frame.DataFrame({'TV':xx.ravel(),'Radio':yy.ravel()})
out = fit.predict(exog=exog)
ax.plot_surface(xx, yy, out.reshape(xx.shape), color='None')

1
这种方法可以用,但生成的平面图非常糟糕。我很困惑为什么似乎没有更好、更简单的方法来完成这个任务。 - ivanmp
糟糕在哪里?您可以向plot_surface提供许多参数。但是,如果您想要花哨的3D效果,可能需要考虑使用mayavi。 - mdurant
这是我得到的内容(我删除了 color='None')。我正在寻找像这样的东西:https://dev59.com/WmUp5IYBdhLWcg3wf3ys。顺便说一句,谢谢你的帮助。 - ivanmp
明白了!请看下面的答案。我会将您的标记为正确的。再次感谢。 - ivanmp

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接