我有一个3D数组
(z, y, x)
,其shape=(92, 4800, 4800)
,其中沿axis 0
的每个值代表不同时间点。在某些情况下,时间域内的值获取失败,导致一些值为np.NaN
。在其他情况下,没有获取到任何值,所有z
上的值都是np.NaN
。
最有效的方法是什么,可以使用线性插值来填充axis 0
上的np.NaN
,而忽略所有值都是np.NaN
的情况?
这里是一个工作示例,它使用pandas
包装器来调用scipy.interpolate.interp1d
。这需要大约2秒钟来处理原始数据集中的每个切片,这意味着整个数组需要2.6小时才能处理完毕。缩小大小后的示例数据集需要大约9.5秒钟。
import numpy as np
import pandas as pd
# create example data, original is (92, 4800, 4800)
test_arr = np.random.randint(low=-10000, high=10000, size=(92, 480, 480))
test_arr[1:90:7, :, :] = -32768 # NaN fill value in original data
test_arr[:, 1:90:6, 1:90:8] = -32768
def interpolate_nan(arr, method="linear", limit=3):
"""return array interpolated along time-axis to fill missing values"""
result = np.zeros_like(arr, dtype=np.int16)
for i in range(arr.shape[1]):
# slice along y axis, interpolate with pandas wrapper to interp1d
line_stack = pd.DataFrame(data=arr[:,i,:], dtype=np.float32)
line_stack.replace(to_replace=-37268, value=np.NaN, inplace=True)
line_stack.interpolate(method=method, axis=0, inplace=True, limit=limit)
line_stack.replace(to_replace=np.NaN, value=-37268, inplace=True)
result[:, i, :] = line_stack.values.astype(np.int16)
return result
使用示例数据集在我的计算机上的性能表现:
%timeit interpolate_nan(test_arr)
1 loops, best of 3: 9.51 s per loop
编辑:
我需要澄清的是,该代码正在生成我期望的结果。问题是-我如何优化这个过程?
(92, 480, 480)
。如果将其增加到真实数据集的大小(92, 4800, 4800)
并使用更多的NaN进行传播,则此方法需要更长时间。 - Kersten