对于非递减的
x
样条,如果您将
x
和
y
函数作为另一个参数
t
的函数进行计算,则可以轻松地计算出它们:
x(t)
,
y(t)
。
在您的情况下,有5个点,因此
t
应该是这些点的枚举,即
t = 0, 1, 2, 3, 4
代表5个点。
所以如果
x = [5, 2, 7, 3, 6]
,那么
x(t) = x(0) = 5
,
x(1) = 2
,
x(2) = 7
,
x(3) = 3
,
x(4) = 6
。同理可得
y
。
然后为
x(t)
和
y(t)
计算样条函数。然后在所有中间的
t
点计算样条值。最后只需使用所有计算出的值
x(t)
和
y(t)
作为函数
y(x)
的值。
以前,我使用Numpy从头实现了三次样条插值计算,所以如果你不介意的话,我在下面的示例中使用了这段代码(它可能对你学习样条数学有用),请用你的库函数替换。此外,在我的代码中,你可以看到已经注释掉了numba
行,如果你想的话,可以使用这些Numba注释来加速计算。
你需要查看代码底部的main()
函数,它展示了如何计算和使用x(t)
和y(t)
。
在线试玩!
import numpy as np, matplotlib.pyplot as plt
def tri_diag_solve(A, B, C, F):
n = B.size
assert A.ndim == B.ndim == C.ndim == F.ndim == 1 and (
A.size == B.size == C.size == F.size == n
)
Bs, Fs = np.zeros_like(B), np.zeros_like(F)
Bs[0], Fs[0] = B[0], F[0]
for i in range(1, n):
Bs[i] = B[i] - A[i] / Bs[i - 1] * C[i - 1]
Fs[i] = F[i] - A[i] / Bs[i - 1] * Fs[i - 1]
x = np.zeros_like(B)
x[-1] = Fs[-1] / Bs[-1]
for i in range(n - 2, -1, -1):
x[i] = (Fs[i] - C[i] * x[i + 1]) / Bs[i]
return x
def calc_spline_params(x, y):
a = y
h = np.diff(x)
c = np.concatenate((np.zeros((1,), dtype = y.dtype),
np.append(tri_diag_solve(h[:-1], (h[:-1] + h[1:]) * 2, h[1:],
((a[2:] - a[1:-1]) / h[1:] - (a[1:-1] - a[:-2]) / h[:-1]) * 3), 0)))
d = np.diff(c) / (3 * h)
b = (a[1:] - a[:-1]) / h + (2 * c[1:] + c[:-1]) / 3 * h
return a[1:], b, c[1:], d
def func_spline(x, ix, x0, a, b, c, d):
dx = x - x0[1:][ix]
return a[ix] + (b[ix] + (c[ix] + d[ix] * dx) * dx) * dx
def piece_wise_spline(x, x0, a, b, c, d):
xsh = x.shape
x = x.ravel()
ix = np.searchsorted(x0[1 : -1], x)
y = func_spline(x, ix, x0, a, b, c, d)
y = y.reshape(xsh)
return y
def main():
x0 = np.array([4.0, 5.638304088577984, 6.785456961280076, 5.638304088577984, 4.0])
y0 = np.array([0.0, 1.147152872702092, 2.7854569612800755, 4.423761049858059, 3.2766081771559668])
t0 = np.arange(len(x0)).astype(np.float64)
plt.plot(x0, y0)
vs = []
for e in (x0, y0):
a, b, c, d = calc_spline_params(t0, e)
x = np.linspace(0, t0[-1], 100)
vs.append(piece_wise_spline(x, t0, a, b, c, d))
plt.plot(vs[0], vs[1])
plt.show()
if __name__ == '__main__':
main()
输出:
![enter image description here](https://istack.dev59.com/zgCEb.webp)