不使用for循环,删除numpy数组中的前导零

8

如何在不使用循环的情况下仅从numpy数组中删除前导零?

import numpy as np

x = np.array([0,0,1,1,1,1,0,1,0,0])

# Desired output
array([1, 1, 1, 1, 0, 1, 0, 0])

我已经写下了以下代码。
x[min(min(np.where(x>=1))):] 

我想知道是否有更高效的解决方案。


3
np.trim_zeros(x, 'f') - 在内部,这是一个 for 循环(详见链接 https://github.com/numpy/numpy/blob/v1.14.0/numpy/lib/function_base.py#L2230-L2278),但在许多情况下,这可能是最有效的方法。该函数可以用来去除一维数组 x 中开头和结尾的零值。 - Alex Riley
x 的长度和前导零的数量通常是多少? - Warren Weckesser
3个回答

7
你可以使用 np.trim_zeros(x, 'f')
这里的 'f' 意味着从前面修剪零。 选项 'b' 将从后面修剪零。 默认选项 'fb' 会从两侧修剪。
x = np.array([0,0,1,1,1,1,0,1,0,0])
# [0 0 1 1 1 1 0 1 0 0]
np.trim_zeros(x, 'f')
# [1 1 1 1 0 1 0 0]

正如其他答案或评论所说,trim_zero()的实现似乎在内部使用了for循环,那么我们应该寻找替代方案吗?或者,如果OP只是在寻找一个库函数来避免手动编写循环,那么这个答案已经足够好了,但问题应该更加清晰明了。 - Kubuntuer82

4

由于np.trim_zeros使用了for循环,这里提供一个真正向量化的解决方案:

x = x[np.where(x != 0)[0][0]:]

然而,我不确定它何时开始比np.trim_zeros更高效。在最坏的情况下(即具有大量前导零的数组),它将更加高效。

无论如何,这可以作为一个有用的学习示例。

双边修剪:

>>> idx = np.where(x != 0)[0]
>>> x = x[idx[0]:1+idx[-1]]

你的回答引起了我的兴趣,所以我做了一些研究。切片在内部也被视为一个循环吗?参考:这篇文章 - Aechlys
那篇文章是关于“列表”的。你确定numpy也是这样吗?切片只是调用__getitem__方法并传入一个slice对象,例如x[2:4]x.__getitem__(slice(2,4))相同。这取决于numpy如何实现该方法。 - fferri
我不确定,所以才问。我对没有使用for的实现方式很好奇,想看看它是怎么做到的。但是我无法在numpy.ndarray对象的帖子中找到相关步骤。我也尝试过查看NumPy源代码,但没有发现。无论如何,我想我们不应该在这里讨论这个问题。 - Aechlys

3
这里是一个利用numpy的方法进行短路计算。它利用了任何 (?) dtype 的 0 表示都是零字节的事实。
import numpy as np
import itertools

# check assumption that for example 0.0f is represented as 00 00 00 00
allowed_dtypes = set()
for dt in map(np.dtype, itertools.chain.from_iterable(np.sctypes.values())):
    try:
        if not np.any(np.zeros((1,), dtype=dt).view(bool)):
            allowed_dtypes.add(dt)
    except:
        pass

def trim_fast(a):
    assert a.dtype in allowed_dtypes
    cut = a.view(bool).argmax() // a.dtype.itemsize
    if a[cut] == 0:
        return a[:0]
    else:
        return a[cut:]

与其他方法的比较:

在此输入图片描述

生成图表的代码:

def np_where(a):
    return a[np.where(a != 0)[0][0]:]

def np_trim_zeros(a):
    return np.trim_zeros(a, 'f')

import perfplot

tf, nt, nw = trim_fast, np_trim_zeros, np_where
def trim_fast(A): return [tf(a) for a in A]
def np_trim_zeros(A): return [nt(a) for a in A]
def np_where(A): return [nw(a) for a in A]

perfplot.save('tz.png',
    setup=lambda n: np.clip(np.random.uniform(-n, 1, (100, 20*n)), 0, None),
    n_range=[2**k for k in range(2, 11)],
    kernels=[
        trim_fast,
        np_where,
        np_trim_zeros
        ],
    logx=True,
    logy=True,
    xlabel='zeros per nonzero',
    equality_check=None
    )

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接