如果我没记错的话,这个可以做到你期望的,并且速度很快:
a.reshape(k, m, n).swapaxes(1, 2).swapaxes(0, 1)
例子:
import numpy as np
k=5; n=3; m=4
a = np.arange(k*n*m).reshape(k, n*m)
a.reshape(k, m, n).swapaxes(1, 2).swapaxes(0, 1)
"""
array([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
[12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23],
[24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35],
[36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],
[48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59]])
is transformed into:
array([[[ 0, 3, 6, 9],
[12, 15, 18, 21],
[24, 27, 30, 33],
[36, 39, 42, 45],
[48, 51, 54, 57]],
[[ 1, 4, 7, 10],
[13, 16, 19, 22],
[25, 28, 31, 34],
[37, 40, 43, 46],
[49, 52, 55, 58]],
[[ 2, 5, 8, 11],
[14, 17, 20, 23],
[26, 29, 32, 35],
[38, 41, 44, 47],
[50, 53, 56, 59]]])
"""
时间安排:
from time import time
k=37; n=42; m=53
a = np.arange(k*n*m).reshape(k, n*m)
start = time()
for _ in range(1_000_000):
res = a.reshape(k, m, n).swapaxes(1, 2).swapaxes(0,1)
time() - start