如何理解Python NumPy数组中的空维度?

11
在Python的numpy包中,我不太理解当一个ndarray的第二个维度为空时的情况。以下是一个例子:
    In[1]: d2 = np.random.rand(10)
    In[2]: d2.shape = (-1, 1)

    In[3]: print d2.shape
    In[4]: print(d2)

    In[5]: print d2[::2, 0].shape
    In[6]: print d2[::2, 0]

    Out[3]:(10, 1)
    Out[4]:
[[ 0.12362278]
 [ 0.26365227]
 [ 0.33939172]
 [ 0.91501369]
 [ 0.97008342]
 [ 0.95294087]
 [ 0.38906367]
 [ 0.1012371 ]
 [ 0.67842086]
 [ 0.23711077]]

    Out[5]: (5,)
    Out[6]: [ 0.12362278  0.33939172  0.97008342  0.38906367  0.67842086]

我理解d2是一个10行1列的ndarray。Out[6]很明显是一个1行5列的数组,怎么可能维度是(5,)呢?空的第二个维度代表什么意思?


7
(5,) 是 Python 中表示只有一个元素的元组的方式,因为(5)可能被解释为仅仅是数字5。 - jojonas
3
“Out[6] 显然是一个 1x5 的数组” - 不,该数组上没有“1x”的描述。它是一维数组,其唯一的维度长度为5。 - user2357112
3个回答

13

让我举一个例子,说明一个重要的区别。

d1 = np.array([1,2,3,4,5]) # array([1, 2, 3, 4, 5])
d1.shape -> (5,) # row array.    
d1.size -> 5
# Note: d1.T is the same as d1.

d2 = d1[np.newaxis] # array([[1, 2, 3, 4, 5]]). Note extra []
d2.shape -> (1,5) 
d2.size -> 5
# Note: d2.T will give a column array
array([[1],
       [2],
       [3],
       [4],
       [5]])
d2.T.shape -> (5,1)

6
我认为ndarrays应该将一维数组表示为具有1的厚度的二维数组。可能是因为“ndarray”这个名称让我们想到高维,但n可以是1,所以ndarrays可以只有一个维度。
比较以下内容:
x = np.array([[1], [2], [3], [4]])
x.shape
# (4, 1)
x = np.array([[1, 2, 3, 4]])
x.shape
#(1, 4)
x = np.array([1, 2, 3, 4])
x.shape
#(4,)

而 (4,) 表示 (4)。

如果我将 x 重新塑形并返回到 (4),它会恢复原状。

x.shape = (2,2)
x
# array([[1, 2],
#       [3, 4]])
x.shape = (4)
x
# array([1, 2, 3, 4])

1

这里需要理解的主要是,使用整数进行索引与使用切片进行索引是不同的。例如,当您使用整数对1d数组或列表进行索引时,您会得到一个标量值,但是当您使用切片进行索引时,您将分别得到一个数组或列表。对于2d+数组也是如此。例如:

# Make a 3d array:
import numpy as np
array = np.arange(60).reshape((3, 4, 5))

# Indexing with ints gives a scalar
print array[2, 3, 4] == 59
# True

# Indexing with slices gives a 3d array
print array[:2, :2, :2].shape
# (2, 2, 2)

# Indexing with a mix of slices and ints will give an array with < 3 dims
print array[0, :2, :3].shape
# (2, 3)
print array[:, 2, 0:1].shape
# (3, 1)

这个概念在理论上非常有用,因为有时候把一个数组看作向量的集合会非常方便。例如,我可以把空间中的N个点表示为一个(N, 3)的数组:
n_points = np.random.random([10, 3])
point_2 = n_points[2]
print all(point_2 == n_points[2, :])
# True

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接