我正在阅读一些深度学习代码。我在numpy数组的高级索引方面遇到了问题。我测试的代码如下:
import numpy
x = numpy.arange(2 * 8 * 3 * 64).reshape((2, 8, 3, 64))
x.shape
p1 = numpy.arange(2)[:, None]
sd = numpy.ones(2 * 64, dtype=int).reshape((2, 64))
p4 = numpy.arange(128 // 2)[None, :]
y = x[p1, :, sd, p4]
y.shape
y
的形状为什么是 (2, 64, 8)
?
以下是上述代码的输出:
>>> x.shape
(2, 8, 3, 64)
>>> p1
array([[0], [1]])
>>> sd
array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
>>> p4
array([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]])
>>> y.shape
(2, 64, 8)
我看了这个链接:https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing,我认为它与广播有关:
x
的形状为(2, 8, 3, 64)
。
p1
很简单,它是array([[0], [1]])
,只是选择第一个维度中的ind 0,1
。而双重数组是用于广播。
p2
是:
,它表示选择第二个维度中的所有8个元素。
p3
有些棘手,它包含两个“列表”,以从第三个维度中的3个元素中选择一个,因此新生成的第三个维度应为1。
p4
表示选择第四个维度中的所有64个元素。因此,我认为
y.shape
应该为(2, 8, 1, 64)
。但正确的结果是
(2, 64, 8)
。为什么呢?
x[arr1, :, 1]
与x[arr1, :, arr2]
是相同的。如果所有的索引都是切片,标量将会被视为一个切片,所以这是需要注意的。 - Mad Physicistarange
索引,一切都应该正常工作。 - Mad Physicist