在numpy矩阵中交换零值

3

我有一个像这样的numpy矩阵:

array([[2,  1, 23, 32],
       [34, 3, 3, 0],
       [3, 33, 0, 0],
       [32, 0, 0, 0]], dtype=int32)

现在我想把所有数字向右移,并将零交换到左边,就像这样:
array([[2, 1,  23, 32],
       [0, 34, 3,  3],
       [0, 0,  3,  33],
       [0, 0,  0,  32]], dtype=int32)

有没有一种简洁的Pythonic方式来完成这个任务,也许可以使用numpy、pandas或scikit-learn的API方法?


你尝试过已发布的解决方案吗?是否有效? - Divakar
是的!我正在努力决定接受什么答案,这很困难。 - user1506145
6个回答

3
这里提供一种使用向量化和掩码的方法来解决问题,具体实现可以参考掩码
valid_mask = a!=0
flipped_mask = valid_mask.sum(1,keepdims=1) > np.arange(a.shape[1]-1,-1,-1)
a[flipped_mask] = a[valid_mask]
a[~flipped_mask] = 0

示例运行 -

In [90]: a
Out[90]: 
array([[ 2,  1, 23, 32],
       [34,  0,  3,  0],  # <== Added a zero in between for variety
       [ 3, 33,  0,  0],
       [32,  0,  0,  0]])

# After code run -

In [92]: a
Out[92]: 
array([[ 2,  1, 23, 32],
       [ 0,  0, 34,  3],
       [ 0,  0,  3, 33],
       [ 0,  0,  0, 32]])

再来举一个通用的示例:

In [94]: a
Out[94]: 
array([[1, 1, 2, 3, 1, 0, 3, 0, 2, 1],
       [2, 1, 0, 1, 2, 0, 1, 3, 1, 1],
       [1, 2, 0, 3, 0, 3, 2, 0, 2, 2]])

# After code run -

In [96]: a
Out[96]: 
array([[0, 0, 1, 1, 2, 3, 1, 3, 2, 1],
       [0, 0, 2, 1, 1, 2, 1, 3, 1, 1],
       [0, 0, 0, 1, 2, 3, 3, 2, 2, 2]])

运行时测试

适用于通用情况的方法 -

# Proposed in this post
def masking_based(a):
    valid_mask = a!=0
    flipped_mask = valid_mask.sum(1,keepdims=1) > np.arange(a.shape[1]-1,-1,-1)
    a[flipped_mask] = a[valid_mask]
    a[~flipped_mask] = 0
    return a

# @Psidom's soln            
def sort_based(a):
    return a[np.arange(a.shape[0])[:, None], (a != 0).argsort(1, kind="mergesort")]

时间 -

In [205]: a = np.random.randint(0,4,(1000,1000))

In [206]: %timeit sort_based(a)
10 loops, best of 3: 30.8 ms per loop

In [207]: %timeit masking_based(a)
100 loops, best of 3: 6.46 ms per loop

In [208]: a = np.random.randint(0,4,(5000,5000))

In [209]: %timeit sort_based(a)
1 loops, best of 3: 961 ms per loop

In [210]: %timeit masking_based(a)
1 loops, best of 3: 151 ms per loop

时间方面怎么样?你能比较一下解决方案吗?谢谢。 - jezrael
@jezrael 添加了通用情况的解决方案。 - Divakar

2

pandas方法:

In [181]:
# construct df from array
df = pd.DataFrame(a)
# call apply and call np.roll rowise and roll by the number of zeroes
df.apply(lambda x: np.roll(x, (x == 0).sum()), axis=1).values

Out[181]:
array([[ 2,  1, 23, 32],
       [ 0, 34,  3,  3],
       [ 0,  0,  3, 33],
       [ 0,  0,  0, 32]])

使用apply函数,可以对每一行调用np.roll函数,移动的数量为该行中零的数量。


df.apply不只是遍历行吗?如果不需要进行pd转换,如何将相同的lambda应用于数组的每一行? - hpaulj
@hpaulj 是的,它确实如此,np.roll 只接受标量,这就是为什么我用这种方式做的原因。 - EdChum

1

您也可以使用 numpy.argsort 进行高级索引

arr[np.arange(arr.shape[0])[:, None], (arr != 0).argsort(1, kind="mergesort")]

#array([[ 2,  1, 23, 32],
#       [ 0, 34,  3,  3],
#       [ 0,  0,  3, 33],
#       [ 0,  0,  0, 32]], dtype=int32)

那样做无法保持顺序。你需要使用“mergesort”来实现。顺便说一下,这是一个很好的想法,也符合问题中要求的简短。 - Divakar
@Divakar 你说得对。我没有注意到那个。 - Psidom

0

使用非NumPy基础的Python进行的轻松尝试 -

>>> arr = [[2,  1, 23, 32],
...        [34, 3, 3, 0],
...        [3, 33, 0, 0],
...        [32, 0, 0, 0]]
... 
>>> t_arr = [[0 for _ in range(cur_list.count(0))]\
            + [i for i in cur_list if i!=0]\
            for cur_list in arr]
>>> t_arr
[[2, 1, 23, 32], [0, 34, 3, 3], [0, 0, 3, 33], [0, 0, 0, 32]]

0

你还可以使用numpy.ma.sort()对掩码数组进行排序,该函数沿着最后一个轴就地排序,axis=-1,示例如下:

np.ma.array(a, mask=a!=0).sort()

现在,a 变成了:
array([[ 2,  1, 23, 32],
       [ 0, 34,  3,  3],
       [ 0,  0,  3, 33],
       [ 0,  0,  0, 32]])

唯一的缺点是它不像前面提到的一些方法那么快,但无论如何都是一个简短的单行代码。


0
一种基于行滚动的解决方案,灵感来自于@EDChum的pandas版本:
def rowroll(arr):
    for row in arr:
        row[:] = np.roll(row,-np.count_nonzero(row))
    return arr
In [221]: rowroll(arr.copy())
Out[221]: 
array([[ 2,  1, 23, 32],
       [ 0, 34,  3,  3],
       [ 0,  0,  3, 33],
       [ 0,  0,  0, 32]])

np.count_nonzero 是一种快速编译的查找非零数的方法。它被 np.where 用来查找其返回大小。

但是看着 np.roll 的代码,我认为这对于任务来说过于复杂了,因为它可以处理多个轴。

尽管这看起来比较混乱,但我怀疑它与 roll 相比是一样快甚至更快的:

def rowroll(arr):
    for row in arr:
        n = np.count_nonzero(row)
        temp = np.zeros_like(row)
        temp[-n:] = row[:n]
        row[:] = temp
    return arr

roll 的解决方案需要在原始数据中有连续的 0,而不是分散的 0。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接