NumPy中将4D数组重塑为2D数组的直觉和想法

25

出于教学目的(不使用明显且易得的np.kron()),我正在实现Kronecker-product,结果得到一个4维数组作为中间结果,我必须对其进行重塑以获得最终结果。

但是,我仍然无法理解如何重塑这些高维数组。我有这个4D数组:

array([[[[ 0,  0],
         [ 0,  0]],

        [[ 5, 10],
         [15, 20]]],


       [[[ 6, 12],
         [18, 24]],

        [[ 7, 14],
         [21, 28]]]])

这是一个形状为(2, 2, 2, 2)的数组,我想将其重塑为(4,4)。有人可能认为这很容易实现。
np.reshape(my4darr, (4,4))

但是,上面的reshape操作并没有给我期望的结果,期望的结果是:

array([[ 0,  5,  0, 10],
       [ 6,  7, 12, 14],
       [ 0, 15,  0, 20],
       [18, 21, 24, 28]])

正如您所看到的,期望结果中的所有元素都存在于4D数组中。我只是无法正确地进行所需的重塑操作。除了答案之外,一些关于如何对这样高维度的数组进行reshape的解释将非常有帮助。谢谢!
3个回答

47

ndnd转换的一般思路

这种ndnd转换的想法只使用两个东西 -

重新排列轴:以使扁平化版本对应于输出的扁平化版本的顺序。因此,如果您不知何故使用它两次,请再次查看,因为您不应该这样做。

重塑:将轴分割或将最终输出带到所需的形状。当输入较低维度且我们需要将其分成块时,大多数情况下需要拆分轴。同样,您不应该需要这超过两次。

因此,通常我们会有三个步骤:

    [ Reshape ]      --->  [ Permute axes ]   --->  [ Reshape ]

 Create more axes             Bring axes             Merge axes
                          into correct order

回溯法

在给定输入和输出的情况下,最安全的解决方法是使用回溯法,即将输入的轴(从较小的nd到较大的nd)或输出的轴(从较大的nd到较小的nd)进行分割。拆分的想法是将较小的nd的维数与较大的nd的维数相同。然后,研究输出的步幅,并将其与输入匹配以获得所需的排列顺序。最后,如果最终的nd是较小的,则可能需要进行重塑(默认方式或C顺序)以合并轴。

如果输入和输出的维数相同,则需要拆分两者并分成块,并将它们的步幅相互比较。在这种情况下,我们应该有块大小的额外输入参数,但这可能是离题。

示例

让我们使用这个特定案例来演示如何应用这些策略。在这里,输入是4D,而输出是2D。因此,我们很可能不需要重塑来拆分。因此,我们需要从排列轴开始。由于最终输出不是4D,而是2D,因此我们需要在最后进行重塑。

现在,这里的输入是:

In [270]: a
Out[270]: 
array([[[[ 0,  0],
         [ 0,  0]],

        [[ 5, 10],
         [15, 20]]],


       [[[ 6, 12],
         [18, 24]],

        [[ 7, 14],
         [21, 28]]]])

预期输出为:
In [271]: out
    Out[271]: 
    array([[ 0,  5,  0, 10],
           [ 6,  7, 12, 14],
           [ 0, 15,  0, 20],
           [18, 21, 24, 28]])

此外,这是从较大的nd到较小的nd的转换,因此回溯方法涉及将输出分割并研究其strides,并与输入中相应的值进行匹配。
                    axis = 3
                   ---      -->          
                                        
                    axis = 1                    
                   ------>           
axis=2|  axis=0|   [ 0,  5,  0, 10],        

               |   [ 6,  7, 12, 14],
               v  
      |            [ 0, 15,  0, 20],
      v
                   [18, 21, 24, 28]])

因此,所需的排列顺序为(2,0,3,1)
In [275]: a.transpose((2, 0, 3, 1))
Out[275]: 
array([[[[ 0,  5],
         [ 0, 10]],

        [[ 6,  7],
         [12, 14]]],


       [[[ 0, 15],
         [ 0, 20]],

        [[18, 21],
         [24, 28]]]])

然后,只需将其重新塑造为预期的形状:

In [276]: a.transpose((2, 0, 3, 1)).reshape(4,4)
Out[276]: 
array([[ 0,  5,  0, 10],
       [ 6,  7, 12, 14],
       [ 0, 15,  0, 20],
       [18, 21, 24, 28]])

更多例子

我挖掘了我的历史记录,找到了一些基于ndnd转换的Q&A。虽然这些例子的解释较少,但它们可以作为其他案例。正如之前提到的那样,在每个地方最多只需要两个reshapes和至多一个swapaxes/transpose即可完成任务。它们列在下面:


11

看起来您需要进行 transpose 操作,然后再进行 reshape 操作。

x.transpose((2, 0, 3, 1)).reshape(np.prod(x.shape[:2]), -1)

array([[ 0,  5,  0, 10],
       [ 6,  7, 12, 14],
       [ 0, 15,  0, 20],
       [18, 21, 24, 28]])
为了帮助您理解为什么需要转置,让我们分析一下您错误的输出结果(通过单个reshape调用获得),以理解其错误之处。
这个结果的简单二维重塑版本(没有任何转置)看起来像这样 -
x.reshape(4, 4)

array([[ 0,  0,  0,  0],
       [ 5, 10, 15, 20],
       [ 6, 12, 18, 24],
       [ 7, 14, 21, 28]])

现在考虑这个输出与您预期输出之间的关系 -

array([[ 0,  5,  0, 10],
       [ 6,  7, 12, 14],
       [ 0, 15,  0, 20],
       [18, 21, 24, 28]])

你会注意到,要获取实际结果需要对错误形状的输出进行类似 Z 字形的遍历 -

start
    | /|     /| /|
    |/ |    / |/ |
      /    /    / 
     /    /    /
    | /| /    | /|
    |/ |/     |/ |
                 end

这意味着您必须以不同的步长在数组上移动,以获得您的实际结果。简单的重新整形是不够的,您必须转置原始数组,使这些Z状元素连续在一起,以便后续的重塑调用可以给您所需的输出。

要正确理解如何进行转置,您应该沿输入跟踪元素,并确定需要跳转到哪个轴才能到达输出中的每个元素。然后相应地进行转置。 Divakar's answer 对此进行了很好的解释。


@juanpa.arrivillaga 你为什么删除了?它看起来是正确的。 - cs95
2
因为直接使用.transpose(2, 0, 3, 1)比使用.transpose(0,2,1,3)并采用Fortran顺序进行重塑更加优雅。 - juanpa.arrivillaga
@kmario23 嗯,我添加了一些解释,希望这就是你要找的。 - cs95
1
@cᴏʟᴅsᴘᴇᴇᴅ 需要使用您的解决方案来解释一个通用案例。希望这没问题。 - Divakar
1
@kmario23 没问题。我的解释与Divakar的不同,因为我想纠正你的误解,即简单的reshape就足够了。为此,我选择分析错误重塑的输出而不是原始输入。我对接受没有任何抱怨,他的答案是金标准。 - cs95
显示剩余5条评论

0

Divarkar的回答很好, 但有时候对我来说,检查transposereshape所涵盖的所有可能情况更容易。

例如,以下代码:

n, m = 4, 2
arr = np.arange(n*n*m*m).reshape(n,n,m,m)
for permut in itertools.permutations(range(4)):
    arr2 = (arr.transpose(permut)).reshape(n*m, n*m)
    print(permut, arr2[0])

使用transpose+reshape可以让我从四维数组中获取所有需要的内容。由于我知道输出应该是什么样子,所以我只需选择正确答案所显示的排列方式。如果我没有得到想要的结果,那么transpose+reshape就不够通用,无法满足我的需求,我需要做一些更复杂的操作。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接