交换numpy数组的维度

66

我想要做以下事情:

for i in dimension1:
  for j in dimension2:
    for k in dimension3:
      for l in dimension4:
        B[k,l,i,j] = A[i,j,k,l]

不使用循环。最终,A和B都包含相同的信息,但索引方式不同。

我必须指出,维度1、2、3和4可以相同也可以不同。因此,使用numpy.reshape()似乎有些困难。


什么是“dimension”变量?你是否忘记了一些“range”调用? - user2357112
5个回答

130

在numpy中完成此操作的规范方法是使用np.transpose的可选排列参数。对于您的情况,从ijklklij,排列方式为(2, 3, 0, 1),例如:

In [16]: a = np.empty((2, 3, 4, 5))

In [17]: b = np.transpose(a, (2, 3, 0, 1))

In [18]: b.shape
Out[18]: (4, 5, 2, 3)

43
请注意:Jaime的回答更好。NumPy专门提供了np.transpose来实现这个目的。
或者使用np.einsum;虽然这可能是对其预期用途的扭曲,但语法非常好:
In [195]: A = np.random.random((2,4,3,5))

In [196]: B = np.einsum('klij->ijkl', A)

In [197]: A.shape
Out[197]: (2, 4, 3, 5)

In [198]: B.shape
Out[198]: (3, 5, 2, 4)

In [199]: import itertools as IT    
In [200]: all(B[k,l,i,j] == A[i,j,k,l] for i,j,k,l in IT.product(*map(range, A.shape)))
Out[200]: True

2
@DSM - 我想你获得了“今日最差的双关语”奖! - Joe Kington
3
对我来说,这是最好的解决方案,因为实际上我有7个维度,多次交换或滚动很困难。我曾经考虑过使用einsum,但某种方式上我使用了错误的语法,无法完成。非常感谢!!! 我非常喜欢einsum <3 - sponce

16

您可以两次使用rollaxis

>>> A = np.random.random((2,4,3,5))
>>> B = np.rollaxis(np.rollaxis(A, 2), 3, 1)
>>> A.shape
(2, 4, 3, 5)
>>> B.shape
(3, 5, 2, 4)
>>> from itertools import product
>>> all(B[k,l,i,j] == A[i,j,k,l] for i,j,k,l in product(*map(range, A.shape)))
True

或者可能两次使用swapaxes更容易理解:
>>> A = np.random.random((2,4,3,5))
>>> C = A.swapaxes(0, 2).swapaxes(1,3)
>>> C.shape
(3, 5, 2, 4)
>>> all(C[k,l,i,j] == A[i,j,k,l] for i,j,k,l in product(*map(range, A.shape)))
True

1
顺便说一句,product(*map(range, A.shape)) 可以更简洁地写成 np.ndindex(*A.shape) - Joe Kington

6

可以利用numpy.moveaxis()移动所需的轴到所需的位置。以下是一个示例,借鉴自Jaime's answer

In [160]: a = np.empty((2, 3, 4, 5))

# move the axes that are originally at positions [0, 1] to [2, 3]
In [161]: np.moveaxis(a, [0, 1], [2, 3]).shape 
Out[161]: (4, 5, 2, 3)

你能看一下我在这里的问题吗:https://datascience.stackexchange.com/questions/58123/np-shape-1-reversal-of-dims-and-np-moveaxis-splitting-array np.moveaxis() 是如何将数组分割成r、b、g的?我的导师告诉我是np.split(),但它是否隐式地与np.movieaxis一起调用? - mLstudent33

2
我会查看numpy.ndarray.shape和itertools.product:
import numpy, itertools
A = numpy.ones((10,10,10,10))
B = numpy.zeros((10,10,10,10))

for i, j, k, l in itertools.product(*map(xrange, A.shape)):
    B[k,l,i,j] = A[i,j,k,l]

如果你说“不使用循环”,我假设你的意思是“不使用嵌套循环”。当然,除非有一些内置的numpy函数可以实现这个需求,否则我认为这是最佳选择。


1
在NumPy中,避免循环的一个普遍目标是让代码的C部分来完成重活。这意味着嵌套循环或itertools.product都是不可取的。 - user2357112
谢谢metaperture! 你在建议什么,user2357112?哪个C例程可以做到这一点? - sponce

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接