在最后一个轴上每隔n个元素进行采样构成数组。

3
a 成为一个(不一定是一维的)NumPy 数组,沿着其最后一个轴有 n * m 个元素。我希望能够沿着最后一个轴将该数组“拆分”,以便从 0 开始每隔 n 个元素取一个,直到 n
明确一下,假设 a 的形状为 (k, n * m),那么我希望构造一个形状为 (n, k, m) 的数组。
np.array([a[:, i::n] for i in range(n)])

我的问题是,尽管这确实返回了我寻找的数组,但我仍然觉得可能有更有效和更整洁的NumPy例程可用。

干杯!


你不能这样做:Numpy不支持不规则数组,也就是包含不同大小的数组的数组(你可以通过使用包含Numpy数组的Numpy数组来欺骗,但这并不是真正的唯一的Numpy数组,而且速度非常慢)。提供的代码在n=2,m=5和k=3时失败。 - Jérôme Richard
抱歉,我的错!范围应该停在 n 而不是 m。结果数组的形状应为 (n,k,m)。我已编辑原始问题以反映这一点。 - William Crawford
3个回答

1

我认为这可以满足你的需求,而且不需要循环。我已经测试了2D输入,但是对于更多维度可能需要进行一些调整。

indexes = np.arange(0, a.size*n, n) + np.repeat(np.arange(n), a.size/n)
np.take(a, indexes, mode='wrap').reshape(n, a.shape[0], -1)

在我的测试中,它比您的原始列表解决方案慢一些。

1

如果我没记错的话,这个可以做到你期望的,并且速度很快:

a.reshape(k, m, n).swapaxes(1, 2).swapaxes(0, 1)

例子:

import numpy as np
k=5; n=3; m=4
a = np.arange(k*n*m).reshape(k, n*m)
a.reshape(k, m, n).swapaxes(1, 2).swapaxes(0, 1)
"""
array([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11],
       [12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23],
       [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35],
       [36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],
       [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59]])

is transformed into:

array([[[ 0,  3,  6,  9],
        [12, 15, 18, 21],
        [24, 27, 30, 33],
        [36, 39, 42, 45],
        [48, 51, 54, 57]],

       [[ 1,  4,  7, 10],
        [13, 16, 19, 22],
        [25, 28, 31, 34],
        [37, 40, 43, 46],
        [49, 52, 55, 58]],

       [[ 2,  5,  8, 11],
        [14, 17, 20, 23],
        [26, 29, 32, 35],
        [38, 41, 44, 47],
        [50, 53, 56, 59]]])
"""

时间安排:

from time import time
k=37; n=42; m=53
a = np.arange(k*n*m).reshape(k, n*m)

start = time()
for _ in range(1_000_000):
    res = a.reshape(k, m, n).swapaxes(1, 2).swapaxes(0,1)
time() - start

# 0.95 s per 1 mil repetitions

1

要写出更快的Numpy实现是很困难的。一个高效的解决方案是使用Numba来加速。但是,内存访问模式可能是代码在相对较大的矩阵上运行缓慢的主要原因。因此,需要关注迭代顺序,以使访问相对缓存友好。此外,对于大型数组,使用多个线程可以更好地减轻由于相对较高的内存延迟(由于内存访问模式)而产生的开销。以下是一种实现方式:

import numba as nb

# The first call is slower due to the build.
# Please consider specifying the signature of the function (ie. input types)
# to precompile the function ahead of time.
@nb.njit # Use nb.njit(parallel=True) for the parallel version
def compute(arr, n):
    k, m = arr.shape[0], arr.shape[1] // n
    assert arr.shape[1] == n * m

    out = np.empty((n, k, m), dtype=arr.dtype)

    # Use nb.prange for the parallel version
    for i2 in range(k):
        for i1 in range(n):
            outView = out[i1, i2]
            inView = a[i2]
            cur = i1
            for i3 in range(m):
                outView[i3] = inView[cur]
                cur += n

    return out

以下是在我的机器上使用i5-9600KF处理器(6个核心)对k=37n=42m=53a.dtype=np.int32进行测试的结果:

John Zwinck's solution:    986.1 µs
Initial implementation:     91.7 µs
Sequential Numba:           62.9 µs
Parallel Numba:             14.7 µs
Optimal lower-bound:        ~7.0 µs

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接