使用Scipy curve_fit和分段函数

14

我收到了一个优化警告:

OptimizeWarning: Covariance of the parameters could not be estimated
                 category=OptimizeWarning)

当我尝试使用scipy.optimize.curve_fit将我的分段函数拟合到数据时,意味着没有进行拟合。我可以轻松地将二次曲线拟合到我的数据上,并且我正在向curve_fit提供我认为是良好的初始参数。下面是完整的代码示例。有人知道为什么curve_fit可能与np.piecewise不兼容吗?还是我犯了其他错误?

import numpy as np
from scipy.optimize import curve_fit
import matplotlib.pyplot as plt


def piecewise_linear(x, x0, y0, k1, k2):
    y = np.piecewise(x, [x < x0, x >= x0],
                     [lambda x:k1*x + y0-k1*x0, lambda x:k2*x + y0-k2*x0])
    return y

def parabola(x, a, b):
    y = a * x**2 + b
    return y

x = np.array([-3, -2, -1, 0, 1, 2, 3])
y = np.array([9.15, 5.68, 2.32, 0.00, 2.05, 5.29, 8.62])


popt_piecewise, pcov = curve_fit(piecewise_linear, x, y, p0=[0.1, 0.1, -5, 5])
popt_parabola, pcov = curve_fit(parabola, x, y, p0=[1, 1])

new_x = np.linspace(x.min(), x.max(), 61)


fig, ax = plt.subplots()

ax.plot(x, y, 'o', ls='')
ax.plot(new_x, piecewise_linear(new_x, *popt_piecewise))
ax.plot(new_x, parabola(new_x, *popt_parabola))

ax.set_xlim(-4, 4)
ax.set_ylim(-2, 16)

输入图像描述

2个回答

10

这是一个类型问题,您需要更改以下行,以便将x给定为浮点数:

x = np.array([-3, -2, -1, 0, 1, 2, 3]).astype(np.float)

否则,piecewise_linear可能会强制转换类型。为了保险起见,你也可以在这里将初始点设置为浮点数:
popt_piecewise, pcov = curve_fit(piecewise_linear, x, y, p0=[0.1, 0.1, -5., 5.])

你是怎么得出这个结论的? - Bill Bell
我试图使用给定的数据点评估piecewise_linear,但它没有起作用,因此我得出结论问题肯定出在那里。我认为这可能与np.piecewise的某些奇怪行为有关。 - J. P. Petersen
我尝试了同样的事情,但完全错过了。非常好! - Bill Bell
3
我建议使用x = np.array([-3, -2, -1, 0, 1, 2, 3], dtype=np.float),这会直接告诉NumPy构建一个浮点数数组,而不是先构建整数数组再转换类型。 - user6655984
@zaq 是的,那更好。 - J. P. Petersen
你也可以在数字中的一个后面加一个点,例如 x = np.array([-3., -2, -1, 0, 1, 2, 3]),这会触发np将其转换为浮点数。 - J. P. Petersen

4

为了完整起见,我指出拟合分段线性函数并不需要使用np.piecewise:可以使用绝对值来构造任何这样的函数,对于每一个转折点,使用np.abs(x-x0)的倍数。以下代码可以很好地适配数据:

def pl(x, x0, a, b, c):
    y = a*np.abs(x-x0) + b*x + c
    return y

popt_pl, pcov = curve_fit(pl, x, y, p0=[0, 0, 0, 0])

print(pl(x, *popt_pl))

输出值接近于原始y值:

[ 8.90899998  5.828       2.74700002 -0.33399996  2.03499998  5.32
  8.60500002]

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接