为什么在赋值/复制时NumPy数组会失去维度?

4

I have the following code:

print(type(a1), a1.shape)
a2 = a1                  #.reshape(-1,1,2) this solves my problem
print(type(a2), a2.shape)

输出结果如下:
<class 'numpy.ndarray'> (8, 1, 2)
<class 'numpy.ndarray'> (8, 2)

我知道(注释掉的)reshape 可以解决我的问题,但是我想了解为什么简单的赋值会导致数组失去中心维度。

有人知道发生了什么事吗?为什么使用另一个名称引用数组会改变其维度?


1
我无法想象任何一种方式,使得赋值a2 = a1会改变底层的NumPy数组对象。请问您能否发布完整的代码以重现您所看到的问题,同时附上您使用的任何库的版本?(我已经看到了您在Ajit的回答下面留下的opencv链接,但我无法准确地推断出这段代码与本问题的关系。) - Alex Riley
我注意到了reshape,但是在那段代码中发生的远不止简单的赋值或复制。NumPy数组不会因为赋值或使用copy()方法而失去维度,只有在特定的索引操作、显式重塑、沿轴缩减等情况下才会失去维度。在你提供的链接中,good_new被赋值为p1的索引结果,而p1cv.calcOpticalFlowPyrLK(..., p0, ...)的输出,所以如果需要引入额外的维度,我并不感到惊讶。我认为你问题的前提是不正确的,所以我不确定该如何回答。 - Alex Riley
good_new有一个中心维度,只需打印它即可。如果您在赋值语句上删除reshape语句,p0将会失去它。我也对此感到惊讶,但事实就是如此。干杯! - Tony Power
1
好的,我想我明白了:当我运行脚本时,“good_new”的形状为(17,2)。这似乎是因为“p0”和“p1”都具有形状(17,1,2),而“st”具有形状(17,1),脚本设置“good_new = p1 [st == 1]”(第43行)。由于布尔索引(即“p1 [st == 1]”),从“p1”中删除了中心维度,然后将此2D数组分配给“good_new”变量。这就是为什么在将其重新分配回名称“p0”之前必须重塑“good_new”(必须为3D)。因此,由于索引而不是赋值/复制,维度丢失了。 - Alex Riley
亲爱的Alex,我完全忘记了那个操作。你是对的!我需要更多的睡眠。非常感谢。 - Tony Power
显示剩余3条评论
2个回答

1

看一下评论中提到的openCV脚本,通过布尔索引会丢失一个维度,而不仅仅是赋值操作。

那个脚本中数组的名称是p0good_new

以下是该脚本中的操作步骤:

  1. p0 is a 3D array with shape (17, 1, 2).

  2. The line:

    p1, st, err = cv.calcOpticalFlowPyrLK(old_gray, frame_gray, p0, None, **lk_params)
    

    creates new arrays, with array p1 having shape (17, 1, 2) and array st having shape (17, 1).

  3. The assignment good_new = p1[st==1] creates a new array object by a Boolean indexing operation on p1. This is a 2D array has shape (17, 2). A dimension has been lost through the indexing operation.

  4. The name p0 needs to be assigned back to the array data contained in good_new, but p0 also needs to be 3D. To achieve this, the script uses p0 = good_new.reshape(-1, 1, 2).


为了完整起见,值得总结一下为什么步骤(3)中的布尔索引操作会导致维度消失。
布尔数组st == 1的形状为(17, 1),与p1的初始维度(17, 1, 2)相匹配。
这意味着选择发生在p1的第二个维度上:索引器数组st == 1决定了应该在结果数组中包含哪些形状为(2,)的数组。最终数组的形状将为(n, 2),其中n是布尔数组中True值的数量。
此行为在NumPy文档此处中详细说明。

1
我不确定为什么会出现这种情况,但它不应该返回这样的结果。请问您可以分享一下a1是如何创建的吗?
我尝试了下面的方法,但无法重新创建它。
a1=np.ones((8,1,2),dtype=np.uint8)
print(type(a1), a1.shape)

<class 'numpy.ndarray'> (8, 1, 2)

a2=a1

print(type(a2), a2.shape)

<class 'numpy.ndarray'> (8, 1, 2)`

嗨,Ajit。您可以在此处查看代码:https://github.com/opencv/opencv/blob/master/samples/python/tutorial_code/video/optical_flow/optical_flow.py。请注意代码中的最后一行。祝好。 - Tony Power

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接