我需要过滤一个数组,以删除低于某个阈值的元素。我的当前代码如下:
threshold = 5
a = numpy.array(range(10)) # testing data
b = numpy.array(filter(lambda x: x >= threshold, a))
问题在于这会创建一个临时列表,使用带有lambda函数的过滤器(速度较慢)。由于这是一个相当简单的操作,也许有一种numpy函数可以以有效的方式执行它,但我找不到它。
我认为实现这个的另一种方法可能是对数组进行排序,找到阈值的索引并从该索引开始返回一个切片,但即使对于小输入来说这可能更快(而且也不会显著),随着输入大小的增长,它肯定渐进地不够有效。
更新:我也进行了一些测量,在输入为100,000,000条目时,排序+切片仍然比纯Python过滤器快两倍。
r = numpy.random.uniform(0, 1, 100000000)
%timeit test1(r) # filter
# 1 loops, best of 3: 21.3 s per loop
%timeit test2(r) # sort and slice
# 1 loops, best of 3: 11.1 s per loop
%timeit test3(r) # boolean indexing
# 1 loops, best of 3: 1.26 s per loop
%timeit
使用内置的timeit
模块。也可以看一下它,文档链接:http://docs.python.org/library/timeit.html。 - Joe Kington