在一个二维 NumPy 数组中,创建一个作为分组查找的字典的最快方法是什么?

3
假设我有一个二维 numpy 数组,其值对应于标签或类别。例如,如果 A = [[0, 0, 1, 1], [1, 1, 1, 0]],则位置 (0, 0), (0, 1), (1, 3) 对应于类别 '0',而位置 (0, 2), (0, 3), (1, 0) 等对应于类别 '1'。这是一个非常简单的例子,但是一般情况下,我将处理更多项的矩阵。
我想做的是基本上构建一个字典,其中一个键对应于每个类别,其相应的值是一个元组列表,其中每个元组对应于输入矩阵的一个具有该键值的位置。换句话说,按其值对输入矩阵进行分组,并获得每个唯一值出现的位置列表。
目前为止,我有以下代码:
S = {i: [] for i in range(A.max() + 1)}
for i in range(A.shape[0]):
    index = np.arange(A[i].shape[0])
    sort_idx = np.argsort(A[i])
    cnt = np.bincount(A[i])
    result = np.split(index[sort_idx], np.cumsum(cnt[:-1]))
    for j, k in enumerate(result):
        S[j] += [(i, z) for z in k]

A是我的输入矩阵。在一个500x500的矩阵上运行平均需要大约0.4毫秒。尽管如此,我觉得可以通过更好地利用向量化来进一步提高性能。

请问有人可以指导我如何使它变得更简单和/或更快吗?感谢任何帮助!


你确切想要用这个字典做什么?这听起来有点像 XY 问题。 - anon01
这是分割问题的一部分,所以我更感兴趣的是知道哪些(X,Y)属于哪个类别。 - Camilo Martinez M.
如果你想进一步扩展,可能有更适合你所做的工具。 - anon01
是的,我相信scikit有一些内置库,比如"labels",或多或少允许这样做。但是,我没有去寻找在我的代码中实现它的方法,我想直接这样做。 - Camilo Martinez M.
1个回答

5
你可以使用np.argwherenp.unique更简单地完成这个任务:
S = {}
for key in np.unique(A):
    S[key] = np.argwhere(A==key)

请注意,这将返回一个二维的numpy数组。

谢谢!这正是我在寻找的。直接将numpy数组作为键的值得到是更好的选择。 - Camilo Martinez M.

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接