高效地查找数组中所有值的索引

4

我有一个非常大的数组,包含介于0和N之间的整数,每个值至少出现一次。

我想知道对于每个值 k,我的数组中所有等于 k 的值的索引。

例如:

arr = np.array([0,1,2,3,2,1,0])
desired_output = {
    0: np.array([0,6]),
    1: np.array([1,5]),
    2: np.array([2,4]),
    3: np.array([3]),
    }

目前我是通过循环range(N+1),并且调用np.where N次来实现这个目标的。

indices = {}
for value in range(max(arr)+1):
    indices[value] = np.where(arr == value)[0]

这个循环是我代码中最慢的部分。(无论是arr==value评估还是np.where调用都占用了大量时间。)有没有更有效率的方法来完成这个任务?
我也尝试过使用np.unique(arr, return_index=True),但它只告诉我第一个索引,而不是所有的索引。

你可以使用range迭代数组中的唯一项,而不是循环。将数组转换为set,并在set上进行迭代。这可以减少开销。 - nishparadox
@nishparadox2 这样做没有任何作用,因为我已经知道唯一的值是什么:从0到N,你可以通过range比调用unique()更有效地获得它们。 - acdr
4个回答

7

方法一

以下是一个向量化的方法,可以将这些索引作为数组列表获取 -

sidx = arr.argsort()
unq, cut_idx = np.unique(arr[sidx],return_index=True)
indices = np.split(sidx,cut_idx)[1:]

如果您想要一个最终的字典,将每个唯一元素对应到它们的索引,我们可以使用循环推导式 -
dict_out = {unq[i]:iterID for i,iterID in enumerate(indices)}

方法二

如果你只对数组列表感兴趣,这里有一个专为性能而设计的替代方案 -

sidx = arr.argsort()
indices = np.split(sidx,np.flatnonzero(np.diff(arr[sidx])>0)+1)

这真是相当聪明的做法。我会尝试一下的。 - acdr
1
所以这比我实际接受的答案快了大约6.5倍。考虑到聪明才智会混淆代码的目的,我会坚持使用更简单的解决方案,但它绝对值得点赞和赞扬。 :) - acdr
@acdr,我很欣赏有关运行时间的任何反馈,因为我非常担心这个问题。感谢您带来的点赞;) - Divakar

3
一种使用Pythonic的方式是使用collections.defaultdict():
>>> from collections import defaultdict
>>> 
>>> d = defaultdict(list)
>>> 
>>> for i, j in enumerate(arr):
...     d[j].append(i)
... 
>>> d
defaultdict(<type 'list'>, {0: [0, 6], 1: [1, 5], 2: [2, 4], 3: [3]})

以下是使用字典推导式和 numpy.where() 的 Numpythonic 方式:

>>> {i: np.where(arr == i)[0] for i in np.unique(arr)}
{0: array([0, 6]), 1: array([1, 5]), 2: array([2, 4]), 3: array([3])}

如果您不想涉及字典,这里是一个纯粹的Numpythonic方法:

>>> uniq = np.unique(arr)
>>> args, indices = np.where((np.tile(arr, len(uniq)).reshape(len(uniq), len(arr)) == np.vstack(uniq)))
>>> np.split(indices, np.where(np.diff(args))[0] + 1)
[array([0, 6]), array([1, 5]), array([2, 4]), array([3])]

1
我的输入数组非常大,所以我本来以为这个解决方案会比我的更慢(因为它循环遍历所有元素),但事实证明它非常快。考虑到它比其他答案更简单,我会接受这个答案。 - acdr
1
这也值得点赞,因为正如OP所说,它更简单! :) - Divakar
1
@acdr,我根据你期望的输出(字典)给出了这个答案。请查看仅对NumPy中索引进行分类的答案。 - Mazdak
@Divakar 遵循KISS原则! - Mazdak

1

我不了解numpy,但你绝对可以用defaultdict在一次迭代中完成这个操作:

indices = defaultdict(list)
for i, val in enumerate(arr):
    indices[val].append(i)

0

使用numpy_indexed包的完全矢量化解决方案:

import numpy_indexed as npi
k, idx = npi.groupy_by(arr, np.arange(len(arr)))

从更高的层面来看,你为什么需要这些索引?通常可以使用group_by功能[例如,npi.group_by(arr).mean(someotherarray)]更有效地计算后续分组操作,而无需显式计算键的索引。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接