Here's a vectorized approach to the problem that uses masking -
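# True wherever an element is non-zero
valid_mask = a != 0
# True for the last k slots of each row, where k is that row's non-zero count
flipped_mask = valid_mask.sum(1, keepdims=True) > np.arange(a.shape[1] - 1, -1, -1)
# Copy the non-zeros (in row-major order) into the right-aligned slots, then zero the rest
a[flipped_mask] = a[valid_mask]
a[~flipped_mask] = 0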
Sample run -
In [90]: a
Out[90]:
array([[ 2,  1, 23, 32],
       [34,  0,  3,  0],   # <== Added a zero in between for variety
       [ 3, 33,  0,  0],
       [32,  0,  0,  0]])
# After code run -
In [92]: a
Out[92]:
array([[ 2,  1, 23, 32],
       [ 0,  0, 34,  3],
       [ 0,  0,  3, 33],
       [ 0,  0,  0, 32]])
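To see why the masking works, it can help to look at the intermediate arrays for the sample above. The following is only an illustrative sketch: the names `counts`, `thresholds`, and `out` are not part of the original code, and it writes into a fresh array instead of modifying `a` in place.

```python
import numpy as np

a = np.array([[ 2,  1, 23, 32],
              [34,  0,  3,  0],
              [ 3, 33,  0,  0],
              [32,  0,  0,  0]])

valid_mask = a != 0

# Number of non-zeros per row, kept as a column so it broadcasts across columns
counts = valid_mask.sum(1, keepdims=True)        # shape (4, 1)

# Descending column offsets: n-1, n-2, ..., 0
thresholds = np.arange(a.shape[1] - 1, -1, -1)   # shape (4,)

# Row i gets True in its last counts[i] columns
flipped_mask = counts > thresholds

# Boolean indexing reads and writes in row-major order, and each row has the
# same number of True entries in both masks, so the values stay in their own
# row and keep their relative order.
out = np.zeros_like(a)                           # hypothetical non-mutating variant
out[flipped_mask] = a[valid_mask]

print(counts.ravel())
print(flipped_mask.astype(int))
print(out)
```

Because `out` starts as all zeros, the final `a[~flipped_mask] = 0` step from the in-place version is not needed here.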
One more sample run, on a more generic case -
In [94]: a
Out[94]:
array([[1, 1, 2, 3, 1, 0, 3, 0, 2, 1],
       [2, 1, 0, 1, 2, 0, 1, 3, 1, 1],
       [1, 2, 0, 3, 0, 3, 2, 0, 2, 2]])
# After code run -
In [96]: a
Out[96]:
array([[0, 0, 1, 1, 2, 3, 1, 3, 2, 1],
       [0, 0, 2, 1, 1, 2, 1, 3, 1, 1],
       [0, 0, 0, 1, 2, 3, 3, 2, 2, 2]])
Runtime test
Approaches that work for the generic case -
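import numpy as np

def masking_based(a):
    # Note: modifies `a` in place and returns the same array
    valid_mask = a != 0
    flipped_mask = valid_mask.sum(1, keepdims=True) > np.arange(a.shape[1] - 1, -1, -1)
    a[flipped_mask] = a[valid_mask]
    a[~flipped_mask] = 0
    return a

def sort_based(a):
    # A stable sort on the boolean mask moves zeros left and keeps the non-zeros in order
    return a[np.arange(a.shape[0])[:, None], (a != 0).argsort(1, kind="mergesort")]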
Timings -
In [205]: a = np.random.randint(0,4,(1000,1000))
In [206]: %timeit sort_based(a)
10 loops, best of 3: 30.8 ms per loop
In [207]: %timeit masking_based(a)
100 loops, best of 3: 6.46 ms per loop
In [208]: a = np.random.randint(0,4,(5000,5000))
In [209]: %timeit sort_based(a)
1 loops, best of 3: 961 ms per loop
In [210]: %timeit masking_based(a)
1 loops, best of 3: 151 ms per loop