具有n个拐点的分段线性拟合

8

我使用了在问题如何在Python中应用分段线性拟合?中找到的一些代码,以执行具有单个断点的分段线性逼近。

代码如下:

from scipy import optimize
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline

x = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10 ,11, 12, 13, 14, 15], dtype=float)
y = np.array([5, 7, 9, 11, 13, 15, 28.92, 42.81, 56.7, 70.59, 84.47, 98.36, 112.25, 126.14, 140.03])

def piecewise_linear(x, x0, y0, k1, k2):
    return np.piecewise(x, 
                       [x < x0], 
                       [lambda x:k1*x + y0-k1*x0, lambda x:k2*x + y0-k2*x0])

p , e = optimize.curve_fit(piecewise_linear, x, y)
xd = np.linspace(0, 15, 100)
plt.plot(x, y, "o")
plt.plot(xd, piecewise_linear(xd, *p))

我正在尝试找出如何扩展此代码以处理n个断点。我尝试了以下代码来处理2个断点的piecewise_linear()方法,但它并没有改变任何方式中断点的值。
x = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20], dtype=float)
y = np.array([5, 7, 9, 11, 13, 15, 28.92, 42.81, 56.7, 70.59, 84.47, 98.36, 112.25, 126.14, 140.03, 150, 152, 154, 156, 158])

def piecewise_linear(x, x0, x1, a1, b1, a2, b2, a3, b3):
    return np.piecewise(x,
                       [x < x0, np.logical_and(x >= x0, x < x1), x >= x1 ], 
                       [lambda x:a1*x + b1, lambda x:a2*x+b2, lambda x: a3*x + b3])

p , e = optimize.curve_fit(piecewise_linear, x, y)
xd = np.linspace(0, 20, 100)
plt.plot(x, y, "o")
plt.plot(xd, piecewise_linear(xd, *p))

非常感谢您提供任何意见。


“它不起作用”是一个几乎没有用的描述。我认为您也无法通过curve_fit()实现这一点,当有多个断点时,它会变得更加复杂(需要线性约束来处理b0 < b1;不支持;在np.piecewise触及最后一个参数之前忽略此问题并排序)。这也是一个非凸优化问题,因此scipy中所有这些可用的优化器只能达到局部最小值(如果它们真的能够实现)。话虽如此,我也怀疑使用curve-fit的单断点方法的有效性(因为它是不光滑的)。 - sascha
我认为,如果我最初在x轴上均匀分布断点,那么找到局部最小值就足以提供一个不错的非最优解决方案。你知道是否有另一个支持线性约束的优化模块吗? - Erlend Vollset
就像我告诉你的那样,这不仅仅是关于那个问题。忽略平滑性和潜在的非凸性,您可以使用scipy的更一般的优化函数来解决此问题,即COBYLA和SQSLP(仅支持约束的两个函数)。我唯一看到的真正方法是混合整数凸规划,但软件很稀少(bonmin和couenne是两个开源求解器,从python中使用不太好;pajarito @ julialang;但这种方法通常需要一些非平凡的公式)。 - sascha
1个回答

7

NumPy拥有一个polyfit函数,它可以非常容易地通过一组点找到最佳拟合直线:

coefs = npoly.polyfit(xi, yi, 1)

实际上,唯一的困难在于找到断点。对于给定的断点集,找到适合给定数据的最佳拟合线是微不足道的。

因此,不需要同时找到断点的位置和线性部分的系数,只需在断点的参数空间内进行最小化即可。

由于断点可以通过它们在x数组中的整数索引值来指定,因此参数空间可以被视为具有N维度的整数网格上的点,其中N是断点的数量。

optimize.curve_fit 不是解决这个问题的好选择,因为参数空间是整数值。如果你使用 curve_fit,算法将调整参数以确定向哪个方向移动。如果调整小于1个单位,则断点的x值不会更改,误差也不会更改,因此算法无法获得有关正确方向的任何信息,所以当参数空间基本上是整数值时,curve_fit 倾向于失败。

一个更好但速度不太快的优化器是暴力网格搜索。如果断点的数量较少(x-值的参数空间较小),则可能足以用网格搜索方法。如果断点的数量较多和/或参数空间较大,则可以设置多阶段的粗/细(暴力)网格搜索。或者,也许有人会提出比暴力搜索更聪明的优化器...


import numpy as np
import numpy.polynomial.polynomial as npoly
from scipy import optimize
import matplotlib.pyplot as plt
np.random.seed(2017)

def f(breakpoints, x, y, fcache):
    breakpoints = tuple(map(int, sorted(breakpoints)))
    if breakpoints not in fcache:
        total_error = 0
        for f, xi, yi in find_best_piecewise_polynomial(breakpoints, x, y):
            total_error += ((f(xi) - yi)**2).sum()
        fcache[breakpoints] = total_error
    # print('{} --> {}'.format(breakpoints, fcache[breakpoints]))
    return fcache[breakpoints]

def find_best_piecewise_polynomial(breakpoints, x, y):
    breakpoints = tuple(map(int, sorted(breakpoints)))
    xs = np.split(x, breakpoints)
    ys = np.split(y, breakpoints)
    result = []
    for xi, yi in zip(xs, ys):
        if len(xi) < 2: continue
        coefs = npoly.polyfit(xi, yi, 1)
        f = npoly.Polynomial(coefs)
        result.append([f, xi, yi])
    return result

x = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 
              18, 19, 20], dtype=float)
y = np.array([5, 7, 9, 11, 13, 15, 28.92, 42.81, 56.7, 70.59, 84.47, 98.36, 112.25, 
              126.14, 140.03, 150, 152, 154, 156, 158])
# Add some noise to make it exciting :)
y += np.random.random(len(y))*10

num_breakpoints = 2
breakpoints = optimize.brute(
    f, [slice(1, len(x), 1)]*num_breakpoints, args=(x, y, {}), finish=None)

plt.scatter(x, y, c='blue', s=50)
for f, xi, yi in find_best_piecewise_polynomial(breakpoints, x, y):
    x_interval = np.array([xi.min(), xi.max()])
    print('y = {:35s}, if x in [{}, {}]'.format(str(f), *x_interval))
    plt.plot(x_interval, f(x_interval), 'ro-')


plt.show()

打印
y = poly([ 4.58801083  2.94476604])    , if x in [1.0, 6.0]
y = poly([-70.36472935  14.37305793])  , if x in [7.0, 15.0]
y = poly([ 123.24565235    1.94982153]), if x in [16.0, 20.0]

并绘制图表

在此输入图片描述


很好的答案... 我已经尝试了所有可能的leastsqminimize方法,但是分段参数x0x1并没有被正确优化。 - Saullo G. P. Castro
完美。谢谢! - Erlend Vollset
如何更改代码以使curve_fit工作(并假设数据不是严格的整数间隔)? - genjong
此外,这里的 fcache 代表什么? - genjong
请注意,成功使用 curve_fit 通常需要 f 具有足够平滑的变化,同时也具有足够的变化量,以便梯度指向最小值方向。 如果您修改 f,使其仅将连续值断点“弹回”到最近的整数,则 f 将在局部上变得平坦,因此优化器将无法找到最小值,因为梯度在所有方向上都为零。 - unutbu
显示剩余2条评论

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接