在NumPy数组中获取第k维度的第i个切片。

9

我有一个 n 维的numpy数组,我想要获取第 k 维中的第 i 个切片。肯定有更好的方法。

# ... 
elif k == 5:
    b = a[:, :, :, :, :, i, ...]
# ...
5个回答

16
b = a[(slice(None),) * k + (i,)]
手动构建索引元组。
正如Python 语言参考手册中所述,一个表达式的形式为:
a[:, :, :, :, :, i]

被转换成

a[(slice(None), slice(None), slice(None), slice(None), slice(None), i)]

我们可以直接构建元组来实现相同的效果,而不是使用切片符号。 (有一个小问题,直接构建元组会产生a[(i,)]而不是a[i]对于k=0,但是NumPy对标量i处理方式相同。)


哇,我很乐意复制粘贴这个,但是文档链接可能仍然有用。 - Nico Schlömer
切片符号,例如 a[<something>] 只是对 __getitem__ 方法的调用的语法糖,传递给它一个元组,这种情况下是一个由切片对象组成的元组(: 符号是 slice 对象的语法糖)。 - juanpa.arrivillaga
你也可以这样写:b[np.index_exp[:] * k + np.index_exp[i]],这样可能更易读一些。 - Eric

4

我不确定它是否适用于k维,但对于2维是有效的。

a.take(i,axis=k)

2
这样做的缺点是会创建一个副本而不是视图。 - user2357112
@user2357112:哎呀,有点出乎意料。 - Eric

2
基本上,您希望能够以编程方式创建元组:,:,:,:,:,i,...,以便将其作为a的索引传递。不幸的是,您不能直接在冒号运算符上使用普通元组乘法(即(:,) * k无法生成k个冒号操作符的元组)。但是,您可以通过使用colon = slice(None)来获得“冒号切片”的实例。然后,您可以执行b = a[(colon,) * k + (i,)],这将有效地在a的第k维中的第i列处索引a

将其封装在一个函数中,您将得到:

def nDimSlice(a, k, i):
    colon = slice(None)
    return a[(colon,) * k + (i,)]

2
对于NumPy来说,i之后的额外切片是不必要的。 - Aaron
@Aaron 谢谢,我不知道! - user108471
这个答案真的很有帮助!如果有人想知道,在该索引上提取N个切片,只需将(i,)替换为(slice(0, n, 1),)即可。 - Roy

2

这里有一个可以处理负轴参数的延迟输入,而不需要事先知道它的操作数的形状:

def put_at(inds, axis=-1, slc=(slice(None),)):
    return (axis<0)*(Ellipsis,) + axis*slc + (inds,) + (-1-axis)*slc

用于以下情况:

a[put_at(ind_list,axis=axis)]

ind_list可以像您的情况一样是标量,也可以是更有趣的东西。

我的这条评论中复制。


0

我不确定这种方法是否会创建数组的整个副本*,但我会取传输矩阵的一个切片以获取第k轴:

import numpy as np

def get_slice(arr, k, i):
    if k >= arr.ndim: #we need at least k dimensions (0 indexed)
        raise ValueError("arr is of lower dimension than {}".format(k))

    axes_reorder = list(range(arr.ndim)) #order of axes for transpose
    axes_reorder.remove(k) #remove original position of k
    axes_reorder.insert(0,k) #insert k at beginning of order

    return arr.transpose(axes_reorder)[i] #k is first axis now

这还有一个额外的好处,就是在尝试切片之前更容易检查维数的数量。

* 根据文档,只要可能,就会创建一个内存视图。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接