使用NumPy进行高效的阈值滤波数组

86

我需要过滤一个数组,以删除低于某个阈值的元素。我的当前代码如下:

threshold = 5
a = numpy.array(range(10)) # testing data
b = numpy.array(filter(lambda x: x >= threshold, a))
问题在于这会创建一个临时列表,使用带有lambda函数的过滤器(速度较慢)。
由于这是一个相当简单的操作,也许有一种numpy函数可以以有效的方式执行它,但我找不到它。
我认为实现这个的另一种方法可能是对数组进行排序,找到阈值的索引并从该索引开始返回一个切片,但即使对于小输入来说这可能更快(而且也不会显著),随着输入大小的增长,它肯定渐进地不够有效。
更新:我也进行了一些测量,在输入为100,000,000条目时,排序+切片仍然比纯Python过滤器快两倍。
r = numpy.random.uniform(0, 1, 100000000)

%timeit test1(r) # filter
# 1 loops, best of 3: 21.3 s per loop

%timeit test2(r) # sort and slice
# 1 loops, best of 3: 11.1 s per loop

%timeit test3(r) # boolean indexing
# 1 loops, best of 3: 1.26 s per loop

2
是的,这很不错 :-) 它甚至会自动计算需要执行多少次迭代来平均测量值,如果代码执行时间非常短的话。 - fortran
5
@yosukesabai - IPython的%timeit使用内置的timeit模块。也可以看一下它,文档链接:http://docs.python.org/library/timeit.html。 - Joe Kington
2个回答

114

b = a[a>threshold] 这样就可以了。

我进行了以下测试:

import numpy as np, datetime
# array of zeros and ones interleaved
lrg = np.arange(2).reshape((2,-1)).repeat(1000000,-1).flatten()

t0 = datetime.datetime.now()
flt = lrg[lrg==0]
print datetime.datetime.now() - t0

t0 = datetime.datetime.now()
flt = np.array(filter(lambda x:x==0, lrg))
print datetime.datetime.now() - t0

我得到了

$ python test.py
0:00:00.028000
0:00:02.461000

http://docs.scipy.org/doc/numpy/user/basics.indexing.html#boolean-or-mask-index-arrays


1
添加测试结果,而不仅仅是我认为它应该做什么。 :p - yosukesabai
3
这种索引方式不会保持数组的大小,如何才能保持相同数量的元素并将亚阈值的值归零? - linello
9
@linello,a[a<=threshold] = 0 将屏蔽掉未超过阈值的部分。 - yosukesabai
4
我遇到了基于两个条件进行过滤的问题。这是解决方案:https://dev59.com/bHA75IYBdhLWcg3wg5fs#3248599 - Robin Newhouse
@yosukesabai 这是否可能实现,而不实际更改原始值。如果 np.ma 就是为此而设计的,我无法弄清楚如何使用。 - embert
显示剩余7条评论

0

你也可以使用np.where来获取条件为True的索引,并使用高级索引。

import numpy as np
b = a[np.where(a >= threshold)]
< p > np.where 的一个有用的功能是,您可以使用它来替换值(例如,在未满足阈值的情况下替换值)。虽然 a[a <= 5] = 0 修改了 a,但 np.where 返回一个新数组,其形状相同,只有一些值(可能)被更改。 < / p >
a = np.array([3, 7, 2, 6, 1])
b = np.where(a >= 5, a, 0)       # array([0, 7, 0, 6, 0])

在性能方面,它也非常具有竞争力。

a, threshold = np.random.uniform(0,1,100000000), 0.5

%timeit a[a >= threshold]
# 1.22 s ± 92.2 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

%timeit a[np.where(a >= threshold)]
# 1.34 s ± 258 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接