Numpy:如何快速替换矩阵中相等的值?

4
假设我们有一个秩为2的数组 a,其中包含n个整数值,这些值在 {0,1,2,...,m} 中。现在对于每个整数,我想找到具有该值的 a 条目的索引(在以下示例中称为 index_i、index_j)。 (因此,我要寻找类似于 np.unique(..., return_index=True) 的东西,但是针对 2d 数组,并且可以返回每个唯一值的所有索引。)
一个天真的方法涉及布尔索引,这将导致 O(m*n) 的操作(见下文),但我只想进行 O(n) 操作。虽然我找到了一种解决方案,但我觉得应该有一个内置方法或至少一个简化这个过程的东西 - 或者至少可以消除这些丑陋的循环:
import numpy as np
a = np.array([[0,0,1],[0,2,1],[2,2,1]])
m = a.max()


#"naive" in O(n*m)
i,j = np.mgrid[range(a.shape[0]), range(a.shape[1])]
index_i = [[] for _ in range(m+1)]
index_j = [[] for _ in range(m+1)]
for k in range(m+1):
  index_i[k] = i[a==k]
  index_j[k] = j[a==k]

#all the zeros:
print(a[index_i[0], index_j[0]])
#all the ones:
print(a[index_i[1], index_j[1]])
#all the twos:
print(a[index_i[2], index_j[2]])


#"sophisticated" in O(n)

index_i = [[] for _ in range(m+1)]
index_j = [[] for _ in range(m+1)]
for i in range(a.shape[0]):
  for j in range(a.shape[1]):
    index_i[a[i,j]].append(i)
    index_j[a[i,j]].append(j)

#all the zeros:
print(a[index_i[0], index_j[0]])
#all the ones:
print(a[index_i[1], index_j[1]])
#all the twos:
print(a[index_i[2], index_j[2]])

(Note: 我随后需要这些索引以进行写入访问,也就是替换数组中存储的值。但在这些操作之间,我确实需要这个二维结构。)

在线尝试!


1
也许可以使用 numpy.argwhere?类似于 [np.argwhere(a == x) for x in np.unique(a)] 这样的代码? - Chris Adams
最终输出会是什么样子?考虑到元素数量可能不同,这可能会使在寻找删除循环时变得更加复杂。 - Divakar
@ChrisA 谢谢,这很接近我想要的,至少简化了代码。但问题在于你仍然需要进行 O(n*m) 次操作,因为你将每个值 x 与整个输入数组进行比较。 - flawr
理想情况下是一组索引列表的集合,其中每个列表对应一个值。(就像我示例中的 index_i, index_j 一样。) - flawr
温馨提醒 - 对已发布的解决方案有任何反馈吗? - Divakar
抱歉,我只能快速测试它,没有时间理解发生了什么。我需要一段时间才能再次处理这个问题。我很感激你的回答,但现在我只是缺乏时间去仔细研究它。 - flawr
1个回答

2

这是一个基于排序的方法,旨在在迭代时最小化工作量,以便将其保存为一个字典,其中键是唯一元素,值是索引 -

最初的回答:

shp = a.shape
idx = a.ravel().argsort()
idx_sorted = np.c_[np.unravel_index(idx,shp)]
count = np.bincount(a.ravel())
valid_idx = np.flatnonzero(count!=0)
cs = np.r_[0,count[valid_idx].cumsum()]
out = {e:idx_sorted[i:j] for (e,i,j) in zip(valid_idx,cs[:-1],cs[1:])}

样例输入,输出 -

In [155]: a
Out[155]: 
array([[0, 2, 6],
       [0, 2, 6],
       [2, 2, 1]])

In [156]: out
Out[156]: 
{0: array([[0, 0],
        [1, 0]]), 1: array([[2, 2]]), 2: array([[0, 1],
        [1, 1],
        [2, 0],
        [2, 1]]), 6: array([[0, 2],
        [1, 2]])}

如果序列中的所有整数都在数组中出现,我们可以简化一下 -
shp = a.shape
idx = a.ravel().argsort()
idx_sorted = np.c_[np.unravel_index(idx,shp)]
cs = np.r_[0,np.bincount(a.ravel()).cumsum()]
out = {iterID:idx_sorted[i:j] for iterID,(i,j) in enumerate(zip(cs[:-1],cs[1:]))}

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接