使用长度可变的数组索引多维torch张量

3

我有一个索引列表和一个形状为:

shape = [batch_size, d_0, d_1, ..., d_k]
idx = [i_0, i_1, ..., i_k]

有没有一种有效的方法,可以在每个维度 d_0, ..., d_k 上使用索引 i_0, ..., i_k 进行高效的张量索引?(k 只能在运行时确定)

结果应该是:

tensor[:, i_0, i_1, ..., i_k] #tensor.shape = [batch_size]

目前我正在创建一个切片的元组,每个维度都有一个切片:

idx = (slice(tensor.shape[0]),) + tuple(slice(i, i+1) for i in idx)
tensor[idx]

但我更喜欢像这样的东西:

tensor[:, *idx]

抱歉,我只能使用英语回答问题。
a = torch.randint(0,10,[3,3,3,3])
indexes = torch.LongTensor([1,1,1])

我想仅索引最后 len(indexes) 维度,如下所示:
a[:, indexes[0], indexes[1], indexes[2]]

但是在一般情况下,我不知道 indexes 有多长。


注意:这个答案并不适用于所有维度,它索引了所有的维度,而且无法处理一个合适的子集!

1个回答

2
不幸的是,您无法同时将切片和迭代器提供给索引(例如,a[:,*idx])。然而,您可以通过将其包装在括号中以转换为迭代器来实现几乎相同的效果。
a[(slice(None), *idx)]

在Python中,`x[(exp1, exp2, ..., expN)]`等同于`x[exp1, exp2, ..., expN]`;后者只是前者的语法糖。详见https://numpy.org/doc/stable/reference/arrays.indexing.html

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接