有没有一种方法可以加快循环numpy.where的速度?

3
假设您有一个分割地图,每个对象都由唯一的索引标识,看起来类似于这样:enter image description here。对于每个对象,我想保存它覆盖的像素,但是到目前为止我只能想到标准的 for 循环。不幸的是,对于具有数千个单独对象的大型图像,这种方法非常缓慢--至少对于我的真实数据是这样。我该如何加速处理呢?
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from skimage.draw import random_shapes


# please ignore that this does not always produce 20 objects each with a
# unique color. it is simply a quick way to produce data that is similar to
# my problem that can also be visualized.
segmap, labels = random_shapes(
    (100, 100), 20, min_size=6, max_size=20, multichannel=False,
    intensity_range=(0, 20), num_trials=100,
)
segmap = np.ma.masked_where(segmap == 255, segmap)

object_idxs = np.unique(segmap)[:-1]
objects = np.empty(object_idxs.size, dtype=[('idx', 'i4'), ('pixels', 'O')])

# important bit here:

# this I can vectorize
objects['idx'] = object_idxs
# but this I cannot. and it takes forever.
for i in range(object_idxs.size):
    objects[i]['pixels'] = np.where(segmap == i)

# just plotting here
fig, ax = plt.subplots(constrained_layout=True)
image = ax.imshow(
    segmap, cmap='tab20', norm=mpl.colors.Normalize(vmin=0, vmax=20)
)
fig.colorbar(image)
fig.show()
2个回答

1
在循环中使用np.where并不高效,因为时间复杂度为O(s n m),其中s = object_idxs.sizen, m = segmap.shape。这个操作可以在O(n m)内完成。
使用Numpy的一个解决方案是首先选择所有对象像素位置,然后根据它们在segmap中关联的对象进行排序,并最终根据对象数量进行拆分。以下是代码:
background = np.max(segmap)
mask = segmap != background
objects = segmap[mask]
uniqueObjects, counts = np.unique(objects, return_counts=True)
ordering = np.argsort(objects)
i, j = np.where(mask)
indices = np.vstack([i[ordering], j[ordering]])
indicesPerObject = np.split(indices, counts.cumsum()[:-1], axis=1)

objects = np.empty(uniqueObjects.size, dtype=[('idx', 'i4'), ('pixels', 'O')])
objects['idx'] = uniqueObjects
for i in range(uniqueObjects.size):
    # Use `tuple(...)` to get the exact same type as the initial code here
    objects[i]['pixels'] = tuple(indicesPerObject[i])
# In case the conversion to tuple is not required, the loop can also be accelerated:
# objects['pixels'] = indicesPerObject

哇,太棒了!也许是个愚蠢的问题,但是我现在难道不应该能够省略这个 for 循环,直接使用 objects ['pixels'] = indicesPerObject 吗? - mapf
确实,它应该可以工作。顺便提一下,注意这肯定不是在“并行”中完成的。它被称为“矢量化”。感谢您的编辑。由于使用了SIMD指令(以及像所有程序一样的IPC),并行性可能会发生,“并行”通常指使用多个线程,而据我所知,Numpy在内部不使用多个线程(至少不是在默认实现中)。一些函数,如矩阵乘法,使用像BLAS这样的外部库,可以使用多个线程。 - Jérôme Richard
抱歉,我是指向量化。感谢您的澄清!我对这些术语的确切定义不是很确定。 - mapf

0

听起来你想知道任何物体的位置。所以如果我们从一个矩阵开始(也就是说,所有形状都在一个数组中,其中空白处为零,第一个对象由1组成,第二个对象由2组成等等),那么你可以创建一个掩码,显示哪些像素(或矩阵中的值)是非零的,如下所示:

my_array != 0


嗨,谢谢。我不确定你的意思。你能详细说明一下吗? - mapf
因此,假设对象4由零矩阵中的4组成,则可以使用my_array == 4来查找它 - 这将为对象4生成一个True / False掩码。 如果您想查找任何对象,只需使用代码!= 0即可查看任何矩阵值(或在您的情况下为像素),以查看对象的位置。 - w_sz
我明白了。这就是 segmap == i 的作用。但这只给了我掩码,而我想知道像素索引,这就是 np.where 的作用。 - mapf
好的,那么只需在整个矩阵上使用 np.where,您就可以得到结果。 - w_sz
这就是我正在做的,但速度非常慢。 - mapf
显示剩余2条评论

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接