如何过滤numpy数组的行

40
我想对numpy数组的每一行应用一个函数。如果这个函数的结果为True,我将保留这一行,否则我将舍弃它。例如,我的函数可能是:
def f(row):
    if sum(row)>10: return True
    else: return False

我想知道是否有类似于以下内容的东西:
np.apply_over_axes()

这个函数可以应用于numpy数组的每一行,并返回结果。我希望得到类似这样的东西:

np.filter_over_axes()

将函数应用于numpy数组的每一行,并仅返回函数返回True的行。是否有类似的功能?还是我应该使用for循环?

2个回答

41

理想情况下,您应该能够实现一个向量化版本的函数,并使用它来进行布尔索引。对于绝大多数问题而言,这是正确的解决方案。Numpy提供了许多可以作用于不同轴的函数,以及所有基本操作和比较运算,因此大多数有用的条件都可以进行向量化处理。

import numpy as np

x = np.random.randn(20, 3)
x_new = x[np.sum(x, axis=1) > .5]

如果您完全确定无法执行上述操作,我建议使用列表推导式(或np.apply_along_axis)创建一个布尔数组以用作索引。

def myfunc(row):
    return sum(row) > .5

bool_arr = np.array([myfunc(row) for row in x])
x_new = x[bool_arr]

这种方法可以相对清晰地完成工作,但比向量化版本慢得多。 一个例子:

x = np.random.randn(5000, 200)

%timeit x[np.sum(x, axis=1) > .5]
# 100 loops, best of 3: 5.71 ms per loop

%timeit x[np.array([myfunc(row) for row in x])]
# 1 loops, best of 3: 217 ms per loop

谢谢罗杰,我想要使用的函数比简单的求和要复杂一些,所以我可能会使用列表推导解决方案。 - killajoule

0
正如@Roger Fan所提到的,逐行应用函数实际上应该在整个数组上以矢量化方式完成。过滤的规范方法是构造一个布尔掩码并将其应用于数组。话虽如此,如果函数非常复杂而无法进行向量化,则最好/更快地将数组转换为Python列表(特别是如果它使用Python函数,例如sum()),然后在其上应用该函数。
msk = arr.sum(axis=1)>10                # best way to create a boolean mask

msk = [f(row) for row in arr.tolist()]  # second best way
#                            ^^^^^^^^   <---- convert to list

filtered_arr = arr[msk]                 # filtered via boolean indexing
一个可工作的示例和性能测试

从下面的timeit测试中可以看出,循环遍历列表(arr.tolist())比循环遍历numpy数组(arr)要快得多,部分原因是在函数f()中调用了Python的sum()而不是np.sum()。尽管如此,向量化方法比两者都要快得多。

def f(row):
    if sum(row)>10: return True
    else: return False
    
arr = np.random.rand(10000, 200)

%timeit arr[[f(row) for row in arr]]
# 260 ms ± 14 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

%timeit arr[[f(row) for row in arr.tolist()]]
# 114 ms ± 4.22 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

%timeit arr[arr.sum(axis=1)>10]
# 10.8 ms ± 2.03 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接