我正在尝试使用Numpy和Matplotlib可视化一个二维平面通过一个三维图形,以解释偏导数的直观感受。
具体而言,我使用的函数是J(θ1,θ2) = θ1^2 + θ2^2,并且我想在θ2=0处绘制一个θ1-J(θ1,θ2)平面。
我已经成功绘制了一个二维平面,但是二维平面和三维图形的重叠部分并不完全正确,并且二维平面略微偏离,因为我希望该平面看起来像在θ2=0处穿过了三维图形。
如果您能在此方面提供专业意见,那将是非常好的,谢谢。
def f(theta1, theta2):
return theta1**2 + theta2**2
fig, ax = plt.subplots(figsize=(6, 6),
subplot_kw={'projection': '3d'})
x,z = np.meshgrid(np.linspace(-1,1,100), np.linspace(0,2,100))
X = x.T
Z = z.T
Y = 0 * np.ones((100, 100))
ax.plot_surface(X, Y, Z)
r = np.linspace(-1,1,100)
theta1_grid, theta2_grid = np.meshgrid(r,r)
J_grid = f(theta1_grid, theta2_grid)
ax.contour3D(theta1_grid,theta2_grid,J_grid,500,cmap='binary')
ax.set_xlabel(r'$\theta_1$',fontsize='large')
ax.set_ylabel(r'$\theta_2$',fontsize='large')
ax.set_zlabel(r'$J(\theta_1,\theta_2)$',fontsize='large')
ax.set_title(r'Fig.2 $J(\theta_1,\theta_2)=(\theta_1^2+\theta_2^2)$',fontsize='x-large')
plt.tight_layout()
plt.show()
这是代码输出的图像: