使用Numpy进行类似MATLAB的数组索引

13
在MATLAB和Numpy中,都可以用数组来索引其他数组。但是它们的行为不同。让我通过一个例子来解释一下。
MATLAB:
>> A = rand(5,5)

A =

    0.1622    0.6020    0.4505    0.8258    0.1067
    0.7943    0.2630    0.0838    0.5383    0.9619
    0.3112    0.6541    0.2290    0.9961    0.0046
    0.5285    0.6892    0.9133    0.0782    0.7749
    0.1656    0.7482    0.1524    0.4427    0.8173

>> A([1,3,5],[1,3,5])

ans =

    0.1622    0.4505    0.1067
    0.3112    0.2290    0.0046
    0.1656    0.1524    0.8173

Numpy:

In [2]: A = arange(25).reshape((5,5))

In [3]: A
Out[3]: 
array([[ 0,  1,  2,  3,  4],
       [ 5,  6,  7,  8,  9],
       [10, 11, 12, 13, 14],
       [15, 16, 17, 18, 19],
       [20, 21, 22, 23, 24]])

In [6]: A[[0,2,4], [0,2,4]]
Out[6]: array([ 0, 12, 24])

简而言之:MATLAB选择行和列,Numpy "压缩"这两个索引数组,并使用元组指向条目。

如何在Numpy中实现与MATLAB相同的行为?

3个回答

20

您可以使用辅助函数numpy.ix_来获得Matlab的行为:

from numpy import ix_
A[ ix_( [0,2,4], [0,2,4] ) ]

13

你可以这样做:

A[[0,2,4],:][:,[0,2,4]]

这将会给你想要的类似于MATLAB的结果。

需要注意的是,如果你使用切片进行索引,则可以获得类似于MATLAB的结果,而无需任何此类技巧:

>>> A[1:3,1:3]
array([[ 6, 7],
       [11,12]])
在NumPy中,与MATLAB不同,1:3不仅仅是[1,2]或类似内容的缩写。(这时我觉得有必要提一下你可能已经知道的事情,即Python的1:3有点像[1,2],而MATLAB的则有点像[1,2,3]:在MATLAB中,右端点包含在内,在Python中则不包含。)

3
实际上,这样做效率相当低下。它需要在内存中创建一个临时数组,这个数组的大小取决于你正在处理的数组的大小,可能非常大。有几种更高效的方法可以实现这一点,包括使用ix_辅助函数。 - Bi Rico
2
是的,一切都正确。另一方面,由ix_构建的东西也相当大,尽管是临时的。我对一个5x5的数组进行了一些时间实验,就像原来的问题一样,得到了以下结果。[,:][:,][ix_()]快约25%,但如果您每次使用相同的索引,则使用ix_构造一个索引数组一次并重复使用它约快10倍 - 尽管当然您会付出内存使用成本。 - Gareth McCaughan
2
np.ix_ 在大多数情况下只使用微不足道的内存,因为它返回其参数的视图。此外,np.ix_ 是一个常量时间操作,而 A[I1, :][:, I2] 的时间和内存使用都是 ~ n^2。但是,如果你真的需要在 5x5 数组上获得 25% 的性能提升,那么你必须去做。 - Bi Rico
你关于 ix_ 几乎不占用内存的说法是正确的;我之前并没有意识到这一点。谢谢。如果你是在暗示索引一个小数组,使用不可预测的索引集合,并且关心每一微秒的情况很少见,那当然我同意!如果你关心性能,那么 ix_ 方法通常会更好。 - Gareth McCaughan
感谢!关于如何使用 A[[0,2,4],:][:,[0,2,4]] 在Numpy中获取正确的类似于Matlab的索引,请参阅最后的解释,我会为您翻译原始文本。 - linello
只是想补充一下..你不能使用这种方法来设置值。相反,请使用@User2593047的方法。 - zwep

2

使用numpy进行高效操作的方法是将索引数组重塑为与它们索引的轴相匹配。

In [103]: a=numpy.arange(100).reshape(10,10)

In [104]: a
Out[104]: 
array([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],
   [10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
   [20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
   [30, 31, 32, 33, 34, 35, 36, 37, 38, 39],
   [40, 41, 42, 43, 44, 45, 46, 47, 48, 49],
   [50, 51, 52, 53, 54, 55, 56, 57, 58, 59],
   [60, 61, 62, 63, 64, 65, 66, 67, 68, 69],
   [70, 71, 72, 73, 74, 75, 76, 77, 78, 79],
   [80, 81, 82, 83, 84, 85, 86, 87, 88, 89],
   [90, 91, 92, 93, 94, 95, 96, 97, 98, 99]])

In [105]: x=numpy.array([3,6,9])

In [106]: y=numpy.array([2,7,8])

In [107]: a[x[:,numpy.newaxis],y[numpy.newaxis,:]]
Out[107]: 
array([[32, 37, 38],
      [62, 67, 68],
      [92, 97, 98]])

Numpy的广播规则是您的好朋友(比Matlab好得多)...
希望对您有所帮助。

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接