快速插值三维数组以获取三维原点x。

3
这个问题与之前回答的一个类似的问题 Fast interpolation over 3D array,但不能解决我的问题。
我有一个四维数组,尺寸为 (time,altitude,latitude, longitude),表示为 y.shape=(nt,nalt,nlat,nlon)。x 是高度,随着 (time, latitude, longtitude) 变化,这意味着 x.shape=(nt,nalt,nlat,nlon)。我想在每个 (nt,nlat,nlon) 中的高度上进行插值计算。插值后的 x_new 应该是1维的,不会随着 (time, latitude, longtitude) 变化。
我使用了 numpy.interp,与 scipy.interpolate.interp1d相同,并考虑了之前帖子中的答案。但这些答案都无法减少循环次数。
我只能这样做:
# y is a 4D ndarray
# x is a 4D ndarray
# new_y is a 4D array
for i in range(nlon):
    for j in range(nlat):
        for k in range(nt):
            y_new[k,:,j,i] = np.interp(new_x, x[k,:,j,i], y[k,:,j,i])

这些循环使得这种插值计算过于缓慢。是否有人有好的想法?非常感谢您的帮助。

如果new_x是升序的,我认为你可以编写一个Cython函数来提高计算速度。 - HYRY
@HYRY 新的new_x是升序排列的。但原始的x可能不是升序排列的。这可行吗?我不熟悉Cython,也许需要一些时间来学习它,或者可以学习前面帖子中Tiago的代码。 - Hao
但是,numpy.interp 需要 x 以升序排列,因此您需要首先按照 xxy 进行排序。 - HYRY
是的,我会检查原始的 x 和 y。非常感谢~ - Hao
原始的x是132个时间片的66层潜在温度。水平网格为96*144。感谢您的耐心等待。 - Hao
显示剩余4条评论
1个回答

1
这是我的使用numba的解决方案,速度大约快了3倍。
首先创建测试数据,x需要按升序排列:
import numpy as np
rows = 200000
cols = 66
new_cols = 69
x = np.random.rand(rows, cols)
x.sort(axis=-1)
y = np.random.rand(rows, cols)
nx = np.random.rand(new_cols)
nx.sort() 

使用numpy对interp进行200000次操作:

%%time
ny = np.empty((x.shape[0], len(nx)))
for i in range(len(x)):
    ny[i] = np.interp(nx, x[i], y[i])

我使用合并方法而不是二分查找方法,因为nx是有序的,并且nx的长度与x的长度大致相同。
  • interp()使用二分查找,时间复杂度为O(len(nx)*log2(len(x))
  • 合并方法:时间复杂度为O(len(nx) + len(x))

这是 numba 代码:

import numba

@numba.jit("f8[::1](f8[::1], f8[::1], f8[::1], f8[::1])")
def interp2(x, xp, fp, f):
    n = len(x)
    n2 = len(xp)
    j = 0
    i = 0
    while x[i] <= xp[0]:
        f[i] = fp[0]
        i += 1

    slope = (fp[j+1] - fp[j])/(xp[j+1] - xp[j])        
    while i < n:
        if x[i] >= xp[j] and x[i] < xp[j+1]:
            f[i] = slope*(x[i] - xp[j]) + fp[j]
            i += 1
            continue
        j += 1
        if j + 1 == n2:
            break
        slope = (fp[j+1] - fp[j])/(xp[j+1] - xp[j])   

    while i < n:
        f[i] = fp[n2-1]
        i += 1

@numba.jit("f8[:, ::1](f8[::1], f8[:, ::1], f8[:, ::1])")
def multi_interp(x, xp, fp):
    nrows = xp.shape[0]
    f = np.empty((nrows, x.shape[0]))
    for i in range(nrows):
        interp2(x, xp[i, :], fp[i, :], f[i, :])
    return f

然后调用numba函数:
%%time
ny2 = multi_interp(nx, x, y)

检查结果:

np.allclose(ny, ny2)

On my pc, the time is:

python version: 3.41 s
numba version: 1.04 s

这种方法需要一个数组,其中最后一个轴是要进行interp()的轴。


非常感谢。太好了!我明天会检查它。 - Hao

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接