Numpy: 多数组的高级索引

3

有没有一种高效的方法可以在多个数组之间进行索引?

例如,我有一个想要索引的数组:

a = [[1,2,3],[4,5,6]]

另一个数组包含索引。 b = [[0, 1], [1,2]] 我期望得到 [[1, 2], [5, 6]],它用[0,1]索引了a的第一行,并用[1,2]索引了a的第二行。
谢谢。

我认为https://dev59.com/b2Ik5IYBdhLWcg3wke3M不是这个问题的好重复引用。 - hpaulj
2个回答

2
如果变量 ab 长度相同,您可以尝试使用以下方式中的np.take
import numpy as np

a = [[1,2,3],[4,5,6]]
b = [[0, 1], [1,2]]
result = [np.take(a[i],b[i]).tolist() for i in range(len(a))]

print(result)
# result: [[1, 2], [5, 6]]

2
In [107]: a = [[1,2,3],[4,5,6]]
In [108]: b = [[0, 1], [1,2]]

ab都是列表。适当的解决方案是嵌套列表推导式。

In [111]: [[a[i][j] for j in x] for i,x in enumerate(b)]
Out[111]: [[1, 2], [5, 6]]

现在,如果将a转换为numpy数组:
In [112]: np.array(a)[np.arange(2)[:,None], b]
Out[112]: 
array([[1, 2],
       [5, 6]])

在这里,数组的第一维使用 (2,1) 数组进行索引,第二维使用 (2,2)。它们一起广播以生成一个 (2,2) 的结果。

Numpy提取子矩阵

是在相同的方向上工作,但被接受的答案使用了 ix_

Y[np.ix_([0,3],[0,3])]

在 (2,2) 情况下将不起作用,b

In [113]: np.array(a)[np.ix_(np.arange(2), b)]
ValueError: Cross index must be 1 dimensional

ix_ 将把第一个维度的 np.arange(2) 变成右边的 (2,1)。


这可能会使广播更加明显:

In [114]: np.array(a)[[[0,0],[1,1]], [[0,1],[1,2]]]
Out[114]: 
array([[1, 2],
       [5, 6]])

它选择元素(0,0)、(0,1)、(1,1)和(1,2)


为了进一步测试这个,让b不对称:

In [138]: b = [[0, 1,1], [1,2,0]]       # (2,3)
In [139]: np.array(a)[np.arange(2)[:,None], b]
Out[139]: 
array([[1, 2, 2],
       [5, 6, 4]])

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接