xarray:重新塑造数据,拆分维度

8

我有一个在xarray中的数据集,具有以下维度:

Dimensions:      (subject: 30, session: 5, time: 45000)
Coordinates:
  * subject      (subject) object '110' '112' '114' '117' ...
  * session      (session) object 'week1' 'week2' 'week3' ...
  * time         (time) timedelta64[ns] 00:00:00 00:00:00.040000 ...

我希望将每个试验(主题/会话组合)分成更小的时间段,例如分成每个15000个值的3个段。结果的维度可能如下所示:

(subject: 30, session: 5, segment: 3, time: 15000)

我已经搜索并尝试了很多方法,但都未成功,这怎么能做到?
其中一个我一直在尝试的方法似乎很接近,就是创建一个新的MultiIndex并将其解压缩。
segment_data = np.repeat(range(3),len(ds.time)//3)
segment = xr.Variable(dims='time',data=segment_data)
newtime_data = np.tile(ds.time[:len(ds.time)//3],3)
newtime = xr.Variable(dims='time',data=newtime_data)
dsr = ds.assign_coords(segment=segment,newtime=newtime)
dsr = dsr.set_index(segment='segment',newtime='newtime')
dsr = dsr.stack(fragment=['segment','newtime'])

然而,最后一行需要大量的内存,并且似乎会创建一个维度fragment: len(ds.time)**2,这看起来不太正常。我也不确定在此之后我需要做什么(unstack('fragment')?)。
编辑:一些尝试使我到达了这里:
x = np.repeat(range(3),15000)
y = np.tile(ds.time[:len(ds.time)//3],3)
dsr = (ds.assign_coords(segment=x,time2=y)
      .set_index(fragment=['segment','time2'])
      .unstack('fragment'))

这将产生以下结果:
(subject: 30, segment: 3, session: 5, time: 45000, time2: 15000)

虽然看起来很接近,但是每个time2点现在有45000个值,而应该只有一个值:

dsr.isel(subject=0,segment=0,session=0,time2=0)
# (time: 45000)

编辑:最终我找到了一个方法来做到这一点,请参见我的回答。欢迎进一步的建议!

1个回答

11

首先确保您有两个新维度的标签。在这种情况下,如下所示:

x = range(3) # 3 segments
y = ds.time[:len(ds.time)//3] # the first 1/3rd of the time labels

然后从这些标签*创建一个 pandas 多级索引。
ind = pd.MultiIndex.from_product((x,y),names=('segment','new_time'))

最后,用这个新索引替换数据集中的time索引,然后展开其级别以创建两个所需维度。
dsr = ds.assign(time=ind).unstack('time')

您可能需要使用rename来重命名新维度:

dsr = dsr.rename({'new_time':'time'})

结果尺寸:

(subject: 30, segment: 3, session: 5, time: 15000)

现在唯一的问题是维度的顺序(理想情况下,应该交换segmentsession)。我认为使用transpose可以解决这个问题,但是"尽管每个数组的维度顺序会改变,但数据集的维度本身仍将保持固定(排序)顺序。" 所以我可能会接受这种情况。

* 请注意,您无法使用要拆分的维度名称,因此我们在这里使用'new_time'。这是assign的一个不必要的限制吗?

** 我无法解释的另一个限制。


我发现创建MultiIndex时出现问题,报错为'unhashable type: 'DataArray'' - 针对y。需要将其转换为numpy数组才能继续。 - creanion

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接