NumPy 4D数组的高级索引(附例)

5

我正在阅读一些深度学习代码。我在numpy数组的高级索引方面遇到了问题。我测试的代码如下:

import numpy

x = numpy.arange(2 * 8 * 3 * 64).reshape((2, 8, 3, 64))
x.shape

p1 = numpy.arange(2)[:, None]
sd = numpy.ones(2 * 64, dtype=int).reshape((2, 64))
p4 = numpy.arange(128 // 2)[None, :]

y = x[p1, :, sd, p4]
y.shape

y 的形状为什么是 (2, 64, 8)

以下是上述代码的输出:

>>> x.shape
(2, 8, 3, 64)

>>> p1
array([[0], [1]])

>>> sd
array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
       [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])

>>> p4
array([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15,
        16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
        32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
        48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]])

>>> y.shape
(2, 64, 8)

我看了这个链接:https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing,我认为它与广播有关: x的形状为(2, 8, 3, 64)p1很简单,它是array([[0], [1]]),只是选择第一个维度中的ind 0,1。而双重数组是用于广播。 p2:,它表示选择第二个维度中的所有8个元素。 p3有些棘手,它包含两个“列表”,以从第三个维度中的3个元素中选择一个,因此新生成的第三个维度应为1。 p4表示选择第四个维度中的所有64个元素。
因此,我认为y.shape应该为(2, 8, 1, 64)
但正确的结果是(2, 64, 8)。为什么呢?
1个回答

4
我第一次遇到numpy中的花式索引时也遇到了同样的问题。简短的答案是,没有什么诀窍:花式索引只是选择与索引相同形状的元素作为输出。如果仅使用花式索引,则输出数组的形状将与广播后的索引数组相同(在此处描述)。输出的形状几乎与输入的形状无关,除非您还添加了一个常规切片索引(在此处描述)。你的情况属于后者,这增加了混淆。

让我们看一下你的索引以了解发生了什么:

y = x[p1, :, sd, p4]
x.shape -> 2, 8, 3, 64
p1.shape -> 2, 1
sd.shape -> 2, 64
p4.shape -> 1, 64

如何进行操作的具体文档在这里:

需要区分两种索引组合:

  • 高级索引由一个切片、Ellipsisnewaxis分隔。例如:x[arr1, :, arr2]
  • 高级索引都相互靠近。例如:x[..., arr1, arr2, :],但不是x[arr1, :, 1],因为在这种情况下1是一个高级索引。

在第一种情况下,高级索引操作导致的维度首先出现在结果数组中,然后是子空间维度。 在第二种情况下,来自高级索引操作的维度插入到结果数组中与它们在初始数组中的位置相同的位置(后面的逻辑使得简单高级索引行为就像切片一样)。

强调我的

请注意,在上述两种情况下,花式索引部分的维度是索引数组的维度,而不是您正在索引的数组。
因此,您应该期望看到的是具有广播维度为p1sdp42, 64)的内容,然后是x的第二个维度的大小(8)。这确实是您得到的结果。
>>> y.shape
(2, 64, 8)

非常感谢您的解释。我正在消化这些信息。我可以问几个后续问题吗?选择结果正确吗?代码正在尝试索引选择(2,8,3,64)的第三个维度,因此我必须将轴从(2,64,8)交换到(2,8,1,64)?有没有更直接的方法来获得所需的结果:(2,8,1,64)?谢谢! - manhon
我不理解这行代码"For example x[..., arr1, arr2, :]but not x[arr1, :, 1]since 1 is an advanced index in this regard.",如果1是高级索引,那么x[arr1, :, 1]不就和x[arr1, :, arr2]一样吗? - manhon
@mahon。先回答第二个问题:你说得完全正确。它警告你x[arr1, :, 1]x[arr1, :, arr2]是相同的。如果所有的索引都是切片,标量将会被视为一个切片,所以这是需要注意的。 - Mad Physicist
@Mahon。您可以用适当的切片替换生成的arange索引,一切都应该正常工作。 - Mad Physicist

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接