NumPy中高级索引不一致性问题

7
为什么以下索引形式会产生不同形状的输出?
a = np.zeros((5, 5, 5, 5))
print(a[:, :, [1, 2], [3, 4]].shape)
# (5, 5, 2)

print(a[:, :, 1:3, [3, 4]].shape)
#(5, 5, 2, 2)

几乎可以确定我错过了一些显而易见的东西。

这就是如何使用__广播__进行高级索引的工作原理。 - VimNing
4个回答

5

[1, 2], [3, 4]并不意味着“在一个维度中选择索引1和2,在另一个维度中选择3和4”。它的意思是“选择索引对(1, 3)和(2, 4)”。

你的第一个表达式选择形如a,b,c,d位置上的所有元素,其中ab可以是任何索引,而cd必须是(1,3)(2,4)这一对。

你的第二个表达式选择形如a,b,c,d位置上的所有元素,其中ab可以是任何索引,c必须在半开区间[1,3)内,d必须是3或4。与第一个表达式不同的是,c和d允许是(2,3)(1,4)


请注意,在同一个索引表达式中同时使用基础索引和高级索引(其中大多数情况下是混合使用:和高级索引)会对结果的轴顺序产生不直观的影响。最好避免混合使用它们。

2
这比 pairwise 要微妙一些;列表会被 'broadcast' 相互匹配。在这种情况下 (2,),(2,)->(2,)。如果其中一个是 '列向量',我们将得到一个块:a[:, :, [[1], [2],[3]], [3, 4]] 会产生一个 (5,5,3,2) 的数组。 - hpaulj
@hpaulj:是的,在它的全面性方面,它比这个答案描述的要复杂得多。由于在这个例子中广播是一个NOP,所以我没有提到它。 - user2357112
@hpaulj:在这种情况下,广播(3,1),(2,)->(3,1),(1,2)->(3,2),对吗? - VimNing
1
@Rainning,没错,广播的第一步是根据需要添加前导大小为1的维度。这里将(2,)转换为(1,2)。 - hpaulj

0
在第一种情况下: [1,2][3,4] 都是形状为 (2,) 的数组,它们合在一起会得到一个形状为 (2,)单一(数组)维度。因此,在第一个结果中,您得到了 (5,5,2),其中最后的 (2,) 是在过程中新创建的。
在第二种情况下:唯一的列表 [3,4] 本身就会产生一个形状为 (2,) 的(数组)维度。而切片 1:3 只会将其自己(数组)维度的长度更改为 2。因此结果是 (5,5,2,2)

0

第一个,

a[:, :, [1, 2], [3, 4]]

将索引成对取出并选择以下子数组:

a[:, :, 1, 3]
a[:, :, 2, 4]

而第二个则生成所有可能的组合(并相应地进行形状调整),即

a[:, :, 1, 3]
a[:, :, 1, 4]
a[:, :, 2, 3]
a[:, :, 2, 4]

可以通过运行以下练习来验证这一点。不要将a初始化为零阵列,而是使用np.arange并重塑它。
a = np.arange(5**4).reshape((5, 5, 5, 5))
print(a[:, :, [1, 2], [3, 4]])

输出的前几行是:

[[[  8  14]
  [ 33  39]
  [ 58  64]...

而数组a本身是

[[[[  0   1   2   3   4]
   [  5   6   7   8   9]
   [ 10  11  12  13  14]
   [ 15  16  17  18  19]
   [ 20  21  22  23  24]]...

所以8出现在(1,3)(在最内层的二维数组中,1表示第二行,3表示第四列),与预期相符;14出现在(2,4)。同样地,33也在下一个二维子数组中的索引(1,3)处,39在(2,4)处。


我不确定你在说什么,但是 list 不是那样工作的。 - user2357112
@user2357112,我认为NumPy数组确实可以。那么为什么这个练习会给出那个结果呢? - R. S. Nikhil Krishna
不,我的意思是像list(a[:, :, 1, 3], a[:, :, 2, 4])list(a[:, :, 1, 3], a[:, :, 2, 3], a[:, :, 2, 3], a[:, :, 2, 4])这样的表达式没有任何意义。 - user2357112
如果你是指顺序的话,那是我的错。我只是想强调使用索引会选择所有可能的组合,但定义显式列表可能会生成对。已经将更改纳入考虑。感谢你的指出。 - R. S. Nikhil Krishna

0

当您在高级索引中有多个列表时,它表示这些列表应该成对取出。相比之下,当您使用切片时,您会得到每个列表中所有切片的元素。

为了看到区别,请考虑以下示例:

>>> print(a[:, :, [1, 2, 3], [3, 4]].shape)
IndexError: shape mismatch: indexing arrays could not be broadcast together with shapes (3,) (2,) 

这是因为第一个列表长度为3,而第二个长度为2。它们不匹配,因此会出现错误。

相比之下,如果您使用切片,它可以完美地工作:

>>> print(a[:, :, 1:4, [3, 4]].shape)
(5, 5, 3, 2)

说明

为了明白为什么会出现这种情况,我们查阅the numpy indexing documentation,其中提到:

When the index consists of as many integer arrays as the array being indexed has dimensions, the indexing is straight forward, but different from slicing.

Advanced indexes always are broadcast and iterated as one:

result[i_1, ..., i_M] == x[ind_1[i_1, ..., i_M], ind_2[i_1, ..., i_M],
                       ..., ind_N[i_1, ..., i_M]]

因为从你的回答中不清楚“taken pairwise”的含义,所以将另一个答案标记为解决方案。无论如何感谢你!真是疯狂,只用了几分钟就有了三个答案。 - Vadim Kantorov

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接