Scipy curve_fit 使用错误拟合

6

我正在尝试将一些数据拟合到带指数截止的幂律函数上。 我使用numpy生成了一些数据,并尝试使用scipy.optimization来拟合这些数据。

以下是我的代码:

import numpy as np
from scipy.optimize import curve_fit

def func(x, A, B, alpha):
    return A * x**alpha * np.exp(B * x)

xdata = np.linspace(1, 10**8, 1000)
ydata = func(xdata, 0.004, -2*10**-8, -0.75)
popt, pcov = curve_fit(func, xdata, ydata)
print popt

我得到的结果是:[1, 1, 1],这与数据不符。我做错了什么吗?

1
我的方法出错了,因为你的最后一个数据点(在655642210处)的值为0。当你对它取对数时,会得到NaN。我使用我的方法计算拟合,不包括那个点,并获得了看起来合理的结果:A = 0.00326,alpha = -0.767,B = -1.88e-8。 - Simon Gibbons
是的,你说得对!我注意到了我的错误并删除了我的第二个问题。非常感谢。 - ivangtorre
2个回答

5
虽然xnx已经告诉你为什么curve_fit失败了,但我想提供一种不依赖于梯度下降(因此需要合理的初始猜测)来解决拟合函数形式问题的不同方法。
请注意,如果您对要拟合的函数取对数,则会得到以下形式: \log f = \log A + \alpha \log x + B x 其中每个未知参数(log A、alpha、B)都是线性的。
因此,我们可以使用线性代数的工具来通过将方程写成矩阵的形式来解决这个问题: log y = M p 其中log y是您的ydata点的对数列向量,p是未知参数的列向量,M是矩阵[[1], [log x], [x]]
或者明确地表达为: enter image description here 然后可以通过使用np.linalg.lstsq高效地找到最佳拟合参数向量。
因此,您的示例问题的代码可以编写为:
import numpy as np

def func(x, A, B, alpha):
    return A * x**alpha * np.exp(B * x)

A_true = 0.004
alpha_true = -0.75
B_true = -2*10**-8

xdata = np.linspace(1, 10**8, 1000)
ydata = func(xdata, A_true, B_true, alpha_true)

M = np.vstack([np.ones(len(xdata)), np.log(xdata), xdata]).T

logA, alpha, B = np.linalg.lstsq(M, np.log(ydata))[0]

print "A =", np.exp(logA)
print "alpha =", alpha
print "B =", B

这样可以很好地恢复初始参数:

A = 0.00400000003736
alpha = -0.750000000928
B = -1.9999999934e-08

还要注意的是,这种方法比使用curve_fit在手头问题上快大约20倍。

In [8]: %timeit np.linalg.lstsq(np.vstack([np.ones(len(xdata)), np.log(xdata), xdata]).T, np.log(ydata))
10000 loops, best of 3: 169 µs per loop


In [2]: %timeit curve_fit(func, xdata, ydata, [0.01, -5e-7, -0.4])
100 loops, best of 3: 4.44 ms per loop

2
显然,您的初始猜测(默认为[1,1,1],因为您没有给出一个 -- 请参见文档)距离实际参数太远,导致算法无法收敛。主要问题可能在于B,如果是正数,则会将指数函数发送到提供的xdata的非常大的值。
尝试提供一些更接近实际参数的内容,就可以解决问题:
p0 = 0.01, -5e-7, -0.4    # Initial guess for the parameters
popt, pcov = curve_fit(func, xdata, ydata, p0)
print popt

输出:

[  4.00000000e-03  -2.00000000e-08  -7.50000000e-01]

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接