多维度中的NumPy PolyFit和PolyVal是什么?

8
假设有一个n维数组,将其重塑为二维数组,其中每一行都是一个观察值集。使用这种重塑方法,np.polyfit可以计算整个ndarray(矢量化)的二次拟合系数。
fit = np.polynomial.polynomialpolyfit(X, Y, 2)

Y的形状为(304000, 21),X为向量,通过这种方式可以得到一个(304000,3)的系数数组fit。

使用迭代器,可以对每一行调用np.polyval(fit, X)。但是如果存在向量化的方法,此方法效率较低。是否能够将fit结果应用于整个观测值数组而不需要迭代?如果可以,请问如何做到?

这类似于此SO问题的内容。


1
请注意,不要在 np.polynomial.polynomial.polyfit 的结果上调用 np.polyval。请使用 np.polynomial.polynomial.polyval。参见https://dev59.com/rGMl5IYBdhLWcg3wAC6h#18767992。 - askewchan
дёҖз§Қжӣҙдјҳйӣ…дҪҶд»Қ然иҫғж…ўзҡ„ж–№жі•жҳҜдҪҝз”Ёnp.apply_along_axisе’Ңpolyfit...иҝҷйҮҢжңүдёҖдёӘдҫӢеӯҗгҖӮ - Saullo G. P. Castro
@askewchan 当然可以!这只是我的懒惰,没有写完整的调用路径。@Saullo Castro - 正如你所建议的那样,np.apply_along_axis 并不比 [i, j] 迭代器快。我在思考是否存在真正矢量化(在 C 级别)的方法。 - Jzl5325
2个回答

8

np.polynomial.polynomial.polyval 接受多维系数数组:

>>> x = np.random.rand(100)
>>> y = np.random.rand(100, 25)
>>> fit = np.polynomial.polynomial.polyfit(x, y, 2)
>>> fit.shape # 25 columns of 3 polynomial coefficients
(3L, 25L)
>>> xx = np.random.rand(50)
>>> interpol = np.polynomial.polynomial.polyval(xx, fit)
>>> interpol.shape # 25 rows, each with 50 evaluations of the polynomial
(25L, 50L)

当然,还有:
>>> np.all([np.allclose(np.polynomial.polynomial.polyval(xx, fit[:, j]),
...                     interpol[j]) for j in range(25)])
True

0

np.polynomial.polynomial.polyval 是一种非常好(且方便)的方法,用于高效地评估多项式拟合。

然而,如果您正在寻找“最快”的方法,那么简单地构建多项式输入并使用基本的numpy矩阵乘法函数会导致稍微更快的计算速度(大约快4倍)。

设置

使用与上文相同的设置,我们将创建25个不同的线性拟合。

>>> num_samples = 100000
>>> num_lines = 100
>>> x = np.random.randint(0,100,num_samples)
>>> y = np.random.randint(0,100,(num_samples, num_lines))
>>> fit = np.polyfit(x,y,deg=2)
>>> xx = np.random.randint(0,100,num_samples*10)

Numpy的polyval函数

res1 = np.polynomial.polynomial.polyval(xx, fit)

基本矩阵乘法

inputs = np.array([np.power(xx,d) for d in range(len(fit))])
res2 = fit.T.dot(inputs)

计时函数

使用上述相同的参数...

%timeit _ = np.polynomial.polynomial.polyval(xx, fit)
1 loop, best of 3: 247 ms per loop

%timeit inputs = np.array([np.power(xx, d) for d in range(len(fit))]);_ = fit.T.dot(inputs)
10 loops, best of 3: 72.8 ms per loop

反复强调一个已经解决的问题...

enter image description here

平均效率提升了约3.61倍。速度波动可能来自后台的随机计算机进程。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接