在Python中加速读取非常大的netcdf文件

19

我使用Python中的netCDF4读取一个非常大的netCDF文件。

由于该文件的维度(1200 x 720 x 1440)太大了,无法一次性将整个文件都存储在内存中。第1维表示时间,下一个2维分别表示纬度和经度。

import netCDF4 
nc_file = netCDF4.Dataset(path_file, 'r', format='NETCDF4')
for yr in years:
    nc_file.variables[variable_name][int(yr), :, :]

然而,每年阅读一次是极其缓慢的。如何才能加快以下用例的速度?

--编辑

块大小为1

  1. 我可以阅读一系列年份:nc_file.variables[variable_name][0:100, :, :]

  2. 有几种用例:

    for yr in years:

numpy.ma.sum(nc_file.variables[variable_name][int(yr), :, :])

# Multiply each year by a 2D array of shape (720 x 1440)
for yr in years:
    numpy.ma.sum(nc_file.variables[variable_name][int(yr), :, :] * arr_2d)

# Add 2 netcdf files together 
for yr in years:
    numpy.ma.sum(nc_file.variables[variable_name][int(yr), :, :] + 
                 nc_file2.variables[variable_name][int(yr), :, :])

你确定以其他方式读取文件(例如一次性读取整个文件)会更快吗?你可以尝试使用裁剪后的文件吗? - ivan_pozdeev
有进行任何必要的性能分析吗? - ivan_pozdeev
你读取数据后是否对该年份的数据进行任何处理?你能否读取一定范围内的年份,例如 [1997:2007,:,:] - hpaulj
感谢@hapulj,我可以读取一系列年份。有几种用例。编辑问题以反映它们。 - user308827
3个回答

35

我强烈建议你查看xarraydask项目。使用这些强大的工具可以让您轻松地将计算分成块。这有两个优点:您可以计算不适合内存的数据,还可以使用机器中的所有核心以获得更好的性能。通过适当选择块大小(请参阅文档),您可以优化性能。

只需执行如下简单操作即可从netCDF加载数据:

import xarray as xr
ds = xr.open_dataset(path_file)

如果你想将数据沿时间维度分块为年份,那么你需要指定chunks参数(假设年份坐标被命名为“year”):

ds = xr.open_dataset(path_file, chunks={'year': 10})

由于其他坐标未出现在 chunks 字典中,因此它们将使用单个数据块。(更多细节请参见文档这里。) 这对于您的第一个需求非常有用,您希望将每年乘以一个 2D 数组。您只需执行以下操作:

ds['new_var'] = ds['var_name'] * arr_2d

现在,xarraydask 正在延迟计算您的结果。为了触发实际计算,您只需要求xarray将结果保存回 netCDF:

ds.to_netcdf(new_file)

计算是通过 dask 触发的,它负责将处理分成块,从而使得可以处理不适合内存的数据。此外,dask 会利用所有处理器核心来计算这些块。

xarraydask 项目仍然无法很好地处理那些不适合并行计算的块。由于在本例中我们只在“年份”维度上进行了分块,因此我们不希望出现任何问题。

如果你想将两个不同的 netCDF 文件相加,只需执行以下简单操作:

ds1 = xr.open_dataset(path_file1, chunks={'year': 10})
ds2 = xr.open_dataset(path_file2, chunks={'year': 10})
(ds1 + ds2).to_netcdf(new_file)

我提供了一个使用在线可用数据集的完全可工作示例。

In [1]:

import xarray as xr
import numpy as np

# Load sample data and strip out most of it:
ds = xr.open_dataset('ECMWF_ERA-40_subset.nc', chunks = {'time': 4})
ds.attrs = {}
ds = ds[['latitude', 'longitude', 'time', 'tcw']]
ds

Out[1]:

<xarray.Dataset>
Dimensions:    (latitude: 73, longitude: 144, time: 62)
Coordinates:
  * latitude   (latitude) float32 90.0 87.5 85.0 82.5 80.0 77.5 75.0 72.5 ...
  * longitude  (longitude) float32 0.0 2.5 5.0 7.5 10.0 12.5 15.0 17.5 20.0 ...
  * time       (time) datetime64[ns] 2002-07-01T12:00:00 2002-07-01T18:00:00 ...
Data variables:
    tcw        (time, latitude, longitude) float64 10.15 10.15 10.15 10.15 ...

In [2]:

arr2d = np.ones((73, 144)) * 3.
arr2d.shape

Out[2]:

(73, 144)

In [3]:

myds = ds
myds['new_var'] = ds['tcw'] * arr2d

In [4]:

myds

Out[4]:

<xarray.Dataset>
Dimensions:    (latitude: 73, longitude: 144, time: 62)
Coordinates:
  * latitude   (latitude) float32 90.0 87.5 85.0 82.5 80.0 77.5 75.0 72.5 ...
  * longitude  (longitude) float32 0.0 2.5 5.0 7.5 10.0 12.5 15.0 17.5 20.0 ...
  * time       (time) datetime64[ns] 2002-07-01T12:00:00 2002-07-01T18:00:00 ...
Data variables:
    tcw        (time, latitude, longitude) float64 10.15 10.15 10.15 10.15 ...
    new_var    (time, latitude, longitude) float64 30.46 30.46 30.46 30.46 ...

In [5]:

myds.to_netcdf('myds.nc')
xr.open_dataset('myds.nc')

Out[5]:

<xarray.Dataset>
Dimensions:    (latitude: 73, longitude: 144, time: 62)
Coordinates:
  * latitude   (latitude) float32 90.0 87.5 85.0 82.5 80.0 77.5 75.0 72.5 ...
  * longitude  (longitude) float32 0.0 2.5 5.0 7.5 10.0 12.5 15.0 17.5 20.0 ...
  * time       (time) datetime64[ns] 2002-07-01T12:00:00 2002-07-01T18:00:00 ...
Data variables:
    tcw        (time, latitude, longitude) float64 10.15 10.15 10.15 10.15 ...
    new_var    (time, latitude, longitude) float64 30.46 30.46 30.46 30.46 ...

In [6]:

(myds + myds).to_netcdf('myds2.nc')
xr.open_dataset('myds2.nc')

Out[6]:

<xarray.Dataset>
Dimensions:    (latitude: 73, longitude: 144, time: 62)
Coordinates:
  * time       (time) datetime64[ns] 2002-07-01T12:00:00 2002-07-01T18:00:00 ...
  * latitude   (latitude) float32 90.0 87.5 85.0 82.5 80.0 77.5 75.0 72.5 ...
  * longitude  (longitude) float32 0.0 2.5 5.0 7.5 10.0 12.5 15.0 17.5 20.0 ...
Data variables:
    tcw        (time, latitude, longitude) float64 20.31 20.31 20.31 20.31 ...
    new_var    (time, latitude, longitude) float64 60.92 60.92 60.92 60.92 ...

2

检查文件的块大小。使用ncdump -s <infile>命令可以得到答案。如果时间维度的块大小大于1,您应该一次读取同样数量的年份,否则您将从磁盘中一次读取多年并仅使用一个年份。

慢有多慢?每个时间步骤最多几秒钟对于这个大小的数组听起来是合理的。

提供更多关于您以后如何处理数据的信息可能会为我们提供更多指导,以确定问题出现在哪里。


在时间维度上是1,其他维度上是什么?您能否澄清在您的情况下“慢”有多慢! - kakk11
块大小为720和1440,适用于其他维度。每次循环迭代只需要几分之一秒的时间。但当你需要迭代1200年时,这些时间会累加起来。 - user308827
那么您可能已经达到了当前文件和硬件的速度。如果您有重写数据的选项,可以尝试使用PyTables并将文件转换为blosc压缩的HDF5格式。这应该比zlib压缩的NetCDF4更快,尽管文件会稍微大一些。由于在您的问题中重写文件不是一个选项,因此我暂时不会将其添加到答案中,但是由于我最近将NetCDF转换为PyTables,我可以给您一些提示。 - kakk11
谢谢@kakk11,重写选项有多慢/快?也就是说,将netcdf重写为hdf5需要花费很长时间,以至于后续的速度优势毫无意义吗? - user308827
在尝试之前很难估计时间,也许需要15-30分钟?但是您只需对每个文件执行一次,所有后续分析都可以在hdf文件上完成,并且所有分析都将运行得更快。您还可以在转换过程中重新划分数据,这将使读取空间子集更快。因此,这取决于您计划读取文件的次数。您还可以尝试并行化读取,但再次取决于IO速度,它可能无法证明额外的编码工作。 - kakk11

0
这种方法有点取巧,但可能是最简单的解决方案:
将文件的子集读入内存,然后使用cPickle(https://docs.python.org/3/library/pickle.html)将文件重新保存到磁盘以供将来使用。从pickled数据结构加载数据很可能比每次解析netCDF更快。

很可能使用Blosc压缩(例如在PyTables中)来读写HDF5比使用cPickle更快。更不用说未压缩的数值数据文件大小可能会非常大了! - kakk11

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接