Numpy: 获取值大于某数且满足条件的索引数组

3

我有以下数组:

a = np.array([6,5,4,3,4,5,6])

现在我想获取所有大于4且索引值大于2的元素。 我找到了以下方法:

a[2:][a[2:]>4]

有更好或更易读的方法来完成这个吗?

更新:这是一个简化版。实际上,索引是通过多个变量进行算术运算来完成的,就像这样:

a[len(trainPredict)+(look_back*2)+1:][a[len(trainPredict)+(look_back*2)+1:]>4]

trainPredict 是一个numpy数组,look_back 是一个整数。
我想知道是否有已经确定的方法或其他人如何做到这一点。


你是在寻找元素、元素的索引(在原始数组中)、还是元素的掩码? - Mad Physicist
@MadPhysicist,这与我在问题中写的内容相同:a[2:][a[2:]>4],只是分成了三行而不是一行。如果没有其他更好的方法,那么我将采纳此答案并结束问题。 - Code Pope
我能想到的其他方法都要低得多。我会写一个答案来证明它。现有的答案比一行代码更清晰,因为它避免了冗余的临时数组。 - Mad Physicist
@CodePope,您在帖子中的代码似乎是标准/惯用的做法。然而,我有一个问题:trainPredict是numpy数组、纯Python列表还是完全不同的东西?如果我们能够获得更多上下文,比如涉及到的3个对象的示例,那么这可能也会有所帮助。 - AMC
@CodePope,你是用 len() 函数来找到沿着第 0 维的数组数量 (len(arr) == arr.shape[0]),而不是总元素数量,对吗? - AMC
显示剩余3条评论
2个回答

2
如果你担心切片的复杂性和/或条件的数量,你可以将它们分开处理:
a = np.array([6,5,4,3,4,5,6])

a_slice = a[2:]

cond_1 = a_slice > 4

res = a_slice[cond_1]

你的例子很简化吗?对于更复杂的操作,可能会有更好的解决方案。


1

@AlexanderCécile的回答不仅比你发布的一行代码更易读,而且还消除了临时数组的冗余计算。尽管如此,它似乎并没有比你原来的方法更快。

下面的时间都是在预备设置下运行的。

import numpy as np
np.random.seed(0xDEADBEEF)
a = np.random.randint(8, size=N)

N的值从1e3到1e8,每次增加10倍。我尝试了四种代码变体:

  1. CodePope:result = a[2:][a[2:] > 4]
  2. AlexanderCécile:s = a[2:]; result = s[s > 4]
  3. MadPhysicist1:result = a[np.flatnonzero(a[2:]) + 2]
  4. MadPhysicist2:result = a[(a > 4) & (np.arange(a.size) >= 2)]

在所有情况下,通过在命令行上运行以下命令来获取时间:

python -m timeit -s 'import numpy as np; np.random.seed(0xDEADBEEF); a = np.random.randint(8, size=N)' '<X>'

在这里,N 是介于 3 和 8 之间的10的幂次方,<X> 是上述表达式之一。时间如下:

enter image description here

方法1和方法2几乎无法区分。令人惊讶的是,在约5e3到1e6个元素之间,方法3似乎略微但明显地更快。我通常不会期望从花式索引中获得这种结果。当然,方法4将是最慢的。
为了完整起见,以下是数据:
           CodePope  AlexanderCécile  MadPhysicist1  MadPhysicist2
1000       3.77e-06         3.69e-06       5.48e-06       6.52e-06
10000       4.6e-05         4.59e-05       3.97e-05       5.93e-05
100000     0.000484         0.000483         0.0004       0.000592
1000000     0.00513          0.00515        0.00503        0.00675
10000000     0.0529           0.0525         0.0617          0.102
100000000     0.657            0.658          0.782           1.09

实际上,我的回答只是为了提高可读性。由于numpy数组切片是视图,创建一个新变量的开销可能超过了不切片两次所带来的小的性能提升。 - AMC
然而,在他更新的代码中,按照我的答案分离部分可能会提高性能,因为数值操作会创建新的数组。 - AMC

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接