从数组中选择大于模板的行

3
现在我有一个浮点值的 2D Numpy 数组,即a,它的形状为(10^6, 3)。 我想知道哪些行大于np.array([25.0, 25.0, 25.0])。然后输出满足此条件的行。 我的代码如下。
# Create an empty array
a_cut = np.empty(shape=(0, 3), dtype=float)

minimum = np.array([25.0, 25.0, 25.0])

for i in range(len(a)):
    if a[i,:].all() > minimum.all():
        a_cut = np.append(a_cut, a[i,:], axis=0)

然而,这段代码效率低下。几个小时后,结果仍未出现。 那么有没有办法提高循环速度?


@Chris,你在函数all()中设置了一个参数,即all(1)。这是什么意思? - Stephen Wong
@StephenWong 这意味着 axis=1,即所有的行都是 True,也就是说所有的行都大于 25 ;) - Chris
@Chris 因为我想把满足条件的所有行保存到另一个数组中,所以我猜我必须循环数组。但是,数量级达到了$10^6$,所以使用for循环是不太有效的。也许有一些方法可以实现我的目标。 - Stephen Wong
3
@StephenWong a[(a > minimum).all(1)] - Divakar
@Divakar 如果我采用你的代码,那么如何输出满足此条件的行? - Stephen Wong
显示剩余5条评论
1个回答

2

np.append 每次调用时都会重新分配整个数组。它基本上与 np.concatenate 相同:请非常节制地使用它。目标是批量执行整个操作。

您可以构造一个掩码:

mask = (a > minimum).all(axis=1)

然后选择:
a_cut = a[mask, :]

使用索引而不是布尔掩码可能会稍微改善性能:

a_cut = a[np.flatnonzero(mask), :]

若要使用少于维度数量的索引进行索引,则将这些索引应用于前导维度,因此您可以执行以下操作:

a_cut = a[mask]

因此,这个简短的表述是:
a_cut = a[(a > minimium).all(1)]

假设我有另一个名为maximum = np.array([50.0, 50.0, 50.0])的数组,我想让a满足maximum > a[i, :] > minimum。应该如何修改代码? - Stephen Wong
1
max > a & a > min - Mad Physicist

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接