使用索引列表按行选择NumPy中特定的列索引

151

我很难选择NumPy矩阵中每行的特定列。

假设我有以下矩阵,我会称其为X:


[1, 2, 3]
[4, 5, 6]
[7, 8, 9]

我还有一份每行的列索引列表,我称之为Y

[1, 0, 2]

我需要获取这些值:

[2]
[4]
[9]

我可以生成一个与X形状相同的矩阵,其中每列都是一个bool / int值,范围为0-1,表示是否为所需列,而不是具有索引Ylist

[0, 1, 0]
[1, 0, 0]
[0, 0, 1]

我知道可以通过迭代数组并选择所需的列值来完成此操作。但是,这将在大量数据的大型数组上频繁执行,因此必须尽可能快地运行。

因此,我想知道是否有更好的解决方案?


1
对您来说,这个答案是否更好?https://dev59.com/HGQn5IYBdhLWcg3wIkMF#17081678 - GoingMyWay
7个回答

155

如果您有一个布尔数组,您可以直接基于该数组进行选择,如下所示:

>>> a = np.array([True, True, True, False, False])
>>> b = np.array([1,2,3,4,5])
>>> b[a]
array([1, 2, 3])

跟随你最初的例子,你可以执行以下操作:

>>> a = np.array([[1,2,3], [4,5,6], [7,8,9]])
>>> b = np.array([[False,True,False],[True,False,False],[False,False,True]])
>>> a[b]
array([2, 4, 9])

您还可以添加 arange 并对其进行直接选择,但这取决于您生成布尔数组的方式以及您的代码外观如何。 YMMV。

>>> a = np.array([[1,2,3], [4,5,6], [7,8,9]])
>>> a[np.arange(len(a)), [1,0,2]]
array([2, 4, 9])

28
感谢赞赏arange的例子。对我来说,它特别有用,因为可以从多个矩阵中检索不同的块(基本上是这个例子的三维情况)。 - Gerrit-K
2
你好,能否解释一下为什么我们要使用 arange 而不是 :?我知道你的方法可行而我的不行,但我想了解原因。 - marcotama
1
因为它是一个numpy数组而不是普通的Python列表,所以“:”语法的工作方式不同。 - Slater Victoroff
3
@SlaterTyranus,谢谢您的回复。经过一些阅读后,我的理解是,将 : 与高级索引混合使用意味着:“对于沿着 : 的每个子空间,应用给定的高级索引”。我的理解正确吗? - marcotama
@tamzord,请解释一下你所说的“子空间”是什么意思。 - Slater Victoroff
@SlaterTyranus 我的意思是数学定义:当你有一个n维空间并固定一个坐标时,你会得到一个子空间;也就是说,一个(n-1)维空间。文档使用了这个术语,我认为它也使用了相同的定义:(ctrl-F“subspace”)http://docs.scipy.org/doc/numpy/reference/arrays.indexing.html 希望这有意义。 - marcotama

56

4
为什么需要使用arange而不是仅使用':'或range,我很难理解。 - MadmanLee
2
@MadmanLee 你好,使用:会输出多个len(a)次的结果,而指定每行的索引会打印预期的结果。 - GoingMyWay
1
我认为这是解决这个问题的恰当而优雅的方式。 - GoingMyWay

40

最近的numpy版本新添加了take_along_axis(和put_along_axis)函数可以更干净地进行索引。

In [101]: a = np.arange(1,10).reshape(3,3)                                                             
In [102]: b = np.array([1,0,2])                                                                        
In [103]: np.take_along_axis(a, b[:,None], axis=1)                                                     
Out[103]: 
array([[2],
       [4],
       [9]])

它的运作方式和以下相同:

In [104]: a[np.arange(3), b]                                                                           
Out[104]: array([2, 4, 9])

但是轴处理不同。它的特别目标是应用argsortargmax的结果。


非常感谢您提供这个出色的答案! - gsandhu

7
一个简单的方法可能如下所示:
In [1]: a = np.array([[1, 2, 3],
   ...: [4, 5, 6],
   ...: [7, 8, 9]])

In [2]: y = [1, 0, 2]  #list of indices we want to select from matrix 'a'

range(a.shape[0]) will return array([0, 1, 2])

In [3]: a[range(a.shape[0]), y] #we're selecting y indices from every row
Out[3]: array([2, 4, 9])

3
您可以使用迭代器来实现。像这样:

np.fromiter((row[index] for row, index in zip(X, Y)), dtype=int)

时间:

N = 1000
X = np.zeros(shape=(N, N))
Y = np.arange(N)

#@Aशwini चhaudhary
%timeit X[np.arange(len(X)), Y]
10000 loops, best of 3: 30.7 us per loop

#mine
%timeit np.fromiter((row[index] for row, index in zip(X, Y)), dtype=int)
1000 loops, best of 3: 1.15 ms per loop

#mine
%timeit np.diag(X.T[Y])
10 loops, best of 3: 20.8 ms per loop

1
OP提到它应该在大型数组上运行得很快,所以你的基准测试并不是很有代表性。我很好奇你的最后一种方法在(更)大的数组上的表现如何! - user2379410
@moarningsun:已更新。np.diag(X.T[Y])非常慢...但是np.diag(X.T)非常快(10微秒)。我不知道为什么。 - Kei Minagawa

1

hpaulj使用take_along_axis的答案应该被接受。

这里是一个带有N维索引数组的派生版本:

>>> arr = np.arange(20).reshape((2,2,5))
>>> idx = np.array([[1,0],[2,4]])
>>> np.take_along_axis(arr, idx[...,None], axis=-1)
array([[[ 1],
        [ 5]],

       [[12],
        [19]]])

请注意,选择操作对形状是无知的。我使用这个方法来通过拟合抛物线来优化可能是向量值的argmax结果,该结果来自于histogram函数:
def interpol(arr):
    i = np.argmax(arr, axis=-1)
    a = lambda Δ: np.squeeze(np.take_along_axis(arr, i[...,None]+Δ, axis=-1), axis=-1)
    frac = .5*(a(1) - a(-1)) / (2*a(0) - a(-1) - a(1)) # |frac| < 0.5
    return i + frac

请注意squeeze的使用,以去除大小为1的维度,从而得到与峰值位置的整数和小数部分ifrac相同的形状。
我很确定可以避免使用lambda,但是插值公式是否仍然美观呢?

0
另一个聪明的方法是先转置数组,然后再对其进行索引。最后,取出对角线,这总是正确的答案。
X = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
Y = np.array([1, 0, 2, 2])

np.diag(X.T[Y])

步骤:

原始数组:

>>> X
array([[ 1,  2,  3],
       [ 4,  5,  6],
       [ 7,  8,  9],
       [10, 11, 12]])

>>> Y
array([1, 0, 2, 2])

转置以使其能够正确索引。

>>> X.T
array([[ 1,  4,  7, 10],
       [ 2,  5,  8, 11],
       [ 3,  6,  9, 12]])

按 Y 轴顺序获取行。

>>> X.T[Y]
array([[ 2,  5,  8, 11],
       [ 1,  4,  7, 10],
       [ 3,  6,  9, 12],
       [ 3,  6,  9, 12]])

现在对角线应该变得清晰了。

>>> np.diag(X.T[Y])
array([ 2,  4,  9, 12]

2
这个技术上看起来非常优雅并且能够正常工作。然而,当你处理大型数组时,我发现这种方法会完全崩溃。在我的情况下,NumPy占用了30GB的交换空间并填满了我的SSD。我建议使用高级索引方法代替。 - 5nefarious

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接