高效的NumPy行旋转,距离可变。

3

给定一个2D的M x N的NumPy数组和一个旋转距离列表,我想要将所有M行在列表中的距离上进行旋转。这是我目前拥有的代码:

import numpy as np

M = 6
N = 8
dists = [2,0,2,1,4,2] # for example
matrix = np.random.randint(0,2,(M,N))

for i in range(M):
    matrix[i] = np.roll(matrix[i], -dists[i])

最后两行实际上是内部循环的一部分,会被执行数十万次,并且根据cProfile测量,这成为了我的性能瓶颈。例如,是否有可能避免for循环并更加高效地执行此操作?

1个回答

1
我们可以通过将distsrange(0...N)数组相加来模拟滚动行为,从而为我们提供每一行的列索引,从中选择并在同一行中洗牌元素时使用模运算。我们可以通过broadcasting帮助将此过程向量化到所有行中。因此,我们将有如下实现-
M,N = matrix.shape # Store matrix shape

# Get column indices for all elems for a rolled version with modulus operation
col_idx = np.mod(np.arange(N) + dists[:,None],N)

# Index into matrix with ranged row indices and col indices to get final o/p
out = matrix[np.arange(M)[:,None],col_idx]

虽然使用 timeit 模块测试这些代码行的速度快了近65%,但当我将其放入完整代码中时,它实际上变慢了20%(多次运行平均值,标准偏差不高)。有什么猜测是为什么吗? - Gekke Boer
@GekkeBoer 你实际情况下的矩阵形状是什么? - Divakar
我有更多的案例,但是我现在只是在使用4x4进行基准测试。然而,通过您的方法,我可以在某些循环之外计算col_idx,这样速度就快了很多!谢谢!我会稍等一下,以便在有人能够进一步改进它的情况下再接受它作为正确答案。 - Gekke Boer
@GekkeBoer 如果您正在使用相同形状和“dists”的数组,则应该可以重复使用此“col_idx”。 - Divakar

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接