Python xarray索引/切片非常缓慢

3

我目前正在处理一些海洋模型输出。每个时间步长,它有42*1800*3600个网格点。

我发现我的程序瓶颈在于切片和调用xarray内置方法来提取值。更有趣的是,同样的语法有时需要非常不同的时间。

ds = xarray.open_dataset(filename, decode_times=False)
vvel0=ds.VVEL.sel(lat=slice(-60,-20),lon=slice(0,40))/100    #in CCSM output, unit is cm/s convert to m/s
uvel0=ds.UVEL.sel(lat=slice(-60,-20),lon=slice(0,40))/100   ## why the speed is that different? now it's regional!!
temp0=ds.TEMP.sel(lat=slice(-60,-20),lon=slice(0,40)) #de

例如,阅读VVEL和UVEL需要大约4秒的时间,而仅阅读TEMP只需要大约6毫秒。如果没有切片,VVEL和UVEL需要大约1秒的时间,而TEMP只需要120纳秒。
我一直以为,当我只输入完整数组的一部分时,我需要更少的内存,因此需要更少的时间。事实证明,XARRAY加载完整数组并且任何额外的切片都需要更多时间。但是,请问为什么从同一netcdf文件中读取不同的变量需要如此不同的时间?
该程序旨在提取逐步截面,并计算横截面热输运,因此我需要挑选出UVEL或VVEL,将其与沿截面的TEMP相乘。所以,似乎快速加载TEMP很好,不是吗?
不幸的是,情况并非如此。当我循环遍历约250个网格点沿着预定的截面时...
# Calculate VT flux orthogonal to the chosen grid cells, which is the heat transport across GOODHOPE line
vtflux=[]
utflux=[]
vap = vtflux.append
uap = utflux.append
#for i in range(idx_north,idx_south+1):
for i in range(10):
    yidx=gh_yidx[i]
    xidx=gh_xidx[i]
    lon_next=ds_lon[i+1].values
    lon_current=ds_lon[i].values
    lat_next=ds_lat[i+1].values
    lat_current=ds_lat[i].values
    tt=np.squeeze(temp[:,yidx,xidx].values)  #<< calling values is slow
    if (lon_next<lon_current) and (lat_next==lat_current):   # The condition is incorrect
        dxlon=Re*np.cos(lat_current*np.pi/180.)*0.1*np.pi/180.
        vv=np.squeeze(vvel[:,yidx,xidx].values)  
        vt=vv*tt
        vtdxdz=np.dot(vt[~np.isnan(vt)],layerdp[0:len(vt[~np.isnan(vt)])])*dxlon
        vap(vtdxdz)
        #del  vtdxdz
    elif (lon_next==lon_current) and (lat_next<lat_current):
        #ut=np.array(uvel[:,gh_yidx[i],gh_xidx[i]].squeeze().values*temp[:,gh_yidx[i],gh_xidx[i]].squeeze().values) # slow
        uu=np.squeeze(uvel[:,yidx,xidx]).values  # slow
        ut=uu*tt
        utdxdz=np.dot(ut[~np.isnan(ut)],layerdp[0:len(ut[~np.isnan(ut)])])*dxlat
        uap(utdxdz) #m/s*degC*m*m ## looks fine, something wrong with the sign
        #del utdxdz
total_trans=(np.nansum(vtflux)-np.nansum(utflux))*3996*1026/1e15

特别是这一行:
tt=np.squeeze(temp[:,yidx,xidx].values)

它需要大约3.65秒,但现在必须重复大约250次。如果我删除 .values,这个时间会减少到约4毫秒。但我需要将 tt 计时到 vt,因此必须提取值。奇怪的是,类似的表达式 vv=np.squeeze(vvel[:,yidx,xidx].values) 所需的时间更少,只有约1.3毫秒。


总结我的问题:

  1. 为什么从同一netcdf文件中加载不同变量需要不同的时间?
  2. 是否有更有效的方法来选择多维数组中的单个列?(不一定是xarray结构,也可以是numpy.ndarray)
  3. 为什么从Xarray结构提取值需要不同的时间,对于完全相同的语法?

谢谢!

1个回答

5
当您索引从netCDF文件加载的变量时,xarray不会立即将其加载到内存中。相反,我们创建了一个懒惰的数组,支持任意数量的进一步的索引操作。即使您没有使用dask.array(通过在open_dataset中设置chunks=或使用open_mfdataset触发),这也是正确的。
这解释了您观察到的令人惊讶的性能。计算temp0很快,因为它不需要从磁盘加载任何数据。vvel0很慢,因为除以100需要将数据作为numpy数组加载到内存中。
稍后,索引temp0会变慢,因为每个操作都会从磁盘加载数据,而不是索引已经在内存中的numpy数组。
解决方法是首先显式地将您需要的数据集部分加载到内存中,例如,通过编写temp0.load()。 xarray文档的netCDF部分也给出了这个提示。

非常感谢!根据您的建议,现在我将程序加速了3倍。将数据读入内存是不可避免的,只需不将该步骤放入任何循环中即可。 - Yu Cheng

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接