从数组中删除所有零元素

3

我有一个形状为[120000, 3]的数组,其中只有前1500个元素有用,其余都是0。

这里是一个例子

[15.0, 14.0, 13.0]
[11.0, 7.0, 8.0]
[4.0, 1.0, 3.0]
[0.0, 0.0, 0.0]
[0.0, 0.0, 0.0]
[0.0, 0.0, 0.0]
[0.0, 0.0, 0.0]

我需要找到一种方法来删除所有元素为 [0.0, 0.0, 0.0]的内容。我尝试了写下以下代码,但是它并没有起作用。

for point in points:
        if point[0] == 0.0 and point[1] == 0.0 and point[2] == 0.0:
            np.delete(points, point)

编辑

评论区中的所有解决方案都可行,但我会选择使用一个来标记为已解决。感谢大家。

6个回答

8
不要使用for循环,它们速度较慢。在for循环中反复调用np.delete会导致性能下降。
相反,创建一个掩码(mask):
zero_rows = (points == 0).all(1)

这是一个长度为120000的数组,其中所有元素都为0的行会被标记为True。

然后找到第一行满足条件:

first_invalid = np.where(zero_rows)[0][0]

最后,对数组进行切片:

points[:first_invalid]

6
有几种相关的方法,分为两个阵营。你可以通过计算单个布尔数组并使用 np.ndarray.all 的向量化方法。或者你可以通过使用 for 循环或带有生成器表达式的 next 计算仅包含 0 元素的第一行的索引。
为了性能,我建议您使用手动 for 循环和 numba。这是一个例子,但请参见下面的基准测试以获得更有效的变体:
from numba import jit

@jit(nopython=True)
def trim_enum_nb(A):
    for idx in range(A.shape[0]):
        if (A[idx]==0).all():
            break
    return A[:idx]

性能基准测试

# python 3.6.5, numpy 1.14.3

%timeit trim_enum_loop(A)     # 9.09 ms
%timeit trim_enum_nb(A)       # 193 µs
%timeit trim_enum_nb2(A)      # 2.2 µs
%timeit trim_enum_gen(A)      # 8.89 ms
%timeit trim_vect(A)          # 3.09 ms
%timeit trim_searchsorted(A)  # 7.67 µs

测试代码

设置

import numpy as np
from numba import jit

np.random.seed(0)

n = 120000
k = 1500

A = np.random.randint(1, 10, (n, 3))
A[k:, :] = 0

函数

def trim_enum_loop(A):
    for idx, row in enumerate(A):
        if (row==0).all():
            break
    return A[:idx]

@jit(nopython=True)
def trim_enum_nb(A):
    for idx in range(A.shape[0]):
        if (A[idx]==0).all():
            break
    return A[:idx]

@jit(nopython=True)
def trim_enum_nb2(A):
    for idx in range(A.shape[0]):
        res = False
        for col in range(A.shape[1]):
            res |= A[idx, col]
            if res:
                break
            return A[:idx]

def trim_enum_gen(A):
    idx = next(idx for idx, row in enumerate(A) if (row==0).all())
    return A[:idx]

def trim_vect(A):
    idx = np.where((A == 0).all(1))[0][0]
    return A[:idx]

def trim_searchsorted(A):
    B = np.frombuffer(A, 'S12')
    idx = A.shape[0] - np.searchsorted(B[::-1], B[-1:], 'right')[0]
    return A[:idx]

检查

# check all results are the same
assert (trim_vect(A) == trim_enum_loop(A)).all()
assert (trim_vect(A) == trim_enum_nb(A)).all()
assert (trim_vect(A) == trim_enum_nb2(A)).all()
assert (trim_vect(A) == trim_enum_gen(A)).all()
assert (trim_vect(A) == trim_searchsorted(A)).all()

你能解释一下trim_enum_gen中的if吗?.all()是什么意思? - User
1
请参阅np.ndarray.allnext生成器表达式。如果您对它们在此处的使用有具体问题,我可以尝试进一步解释。 - jpp
1
numba对于numba的优化:将if (A[idx]==0).all():替换为 for j in range(3):\ if v[j]!=0:\ break\ if v[j]==0:\ break,速度提升了四倍 ;) - B. M.
@B.M.,说得好,已更新。尝试编写嵌套循环需要一些时间来适应。我现在的解决方案似乎比以前的numba快了约100倍! - jpp

2
x = [[15.0, 14.0, 13.0],
[11.0, 7.0, 8.0],
[4.0, 1.0, 3.0],
[0.0, 0.0, 0.0],
[0.0, 0.0, 0.0],
[0.0, 0.0, 0.0],
[0.0, 0.0, 0.0]]

简单的迭代解决方案:
y = [i for i in x if i != [0.0, 0.0, 0.0]]

更好的解决方案(Python 3.x):
y = list(filter(lambda a: a != [0.0, 0.0, 0.0], x))

输出:

[[15.0, 14.0, 13.0], [11.0, 7.0, 8.0], [4.0, 1.0, 3.0]]

很有趣,想知道这与John Zwinck关于性能的回答相比如何。 - sobek
为什么您的(Python 3.x)解决方案更好(为什么在Python 2中它不能同样有效地工作)? - thebjorn

2

知道这个问题已经结束了,只是想给出我的答案 :)

x = [[15.0, 14.0, 13.0],
[11.0, 7.0, 8.0],
[4.0, 1.0, 3.0],
[0.0, 0.0, 0.0],
[0.0, 0.0, 0.0],
[0.0, 0.0, 0.0],
[0.0, 0.0, 0.0]]

那么可以使用简单的列表推导式

[i for i in x if all(i)]

并输出:

[[15.0, 14.0, 13.0],[11.0, 7.0, 8.0],[4.0, 1.0, 3.0]]

需要

0.0000010866 # seconds or 1.0866 microseconds

不要太过相信时间,因为它非常不确定。给我两秒钟来得到更好的估计。

何时:

x = [[15.0, 14.0, 13.0],
[11.0, 7.0, 8.0],
[4.0, 1.0, 3.0],
[0.0, 0.0, 0.0],
[0.0, 0.0, 0.0],
[0.0, 0.0, 0.0],
[0.0, 0.0, 0.0]]*(120000//7)

我得到时间了

0.01199 # seconds

这次的速度很大程度上取决于它们是否为0,因为0被忽略掉了,所以速度更快。


2
对于对数复杂度,您可以在按行转换数据后使用numpy.searchsorted
B=np.frombuffer(A,'S12')
index=B.size-np.searchsorted(B[::-1],B[-1:],'right')[0]

"index"将是非空项的数量,如果前面的项都不为空。
测试:
>>>> %timeit B.size-searchsorted(B[::-1],B[-1:],'right')[0]
2.2 µs 

1
我认为你需要 A.shape[0]-np.searchsorted(B[::-1],B[-1:],'right')[0] - jpp
不管怎样,很好的解决方案 +1,我也将这个解决方案添加到我的帖子时间中,希望你没关系。 - jpp

1

使用vstack的简单迭代解决方案

import numpy as np
b = np.empty((0,3), float)
for elem in a:
    toRemove = np.array([0.0, 0.0, 0.0])
    if(not np.array_equal(elem,toRemove)):
        b=np.vstack((b, elem))

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接