加速numpy整数数组的深度索引

3
假设我有一个数组:
 [[0 2 1]
  [1 0 1]
  [2 1 1]]

我希望将其转换为以下形式的张量

[[[1 0 0]
  [0 1 0]
  [0 0 0]]
 [[0 0 1]
  [1 0 1]
  [0 1 1]]
 [[0 1 0]
  [0 0 0]
  [1 0 0]]]

每个深度层(索引i)都是一个二进制掩码,显示i在输入中出现的位置。

我已经编写了这个功能的代码,虽然能够正常工作,但速度太慢而无法使用。我可以用另一个向量化操作替换此函数中的循环吗?

def im2segmap(im, depth):
    tensor = np.zeros((im.shape[0], im.shape[1], num_classes))

    for c in range(depth):
        rows, cols = np.argwhere(im==c).T
        tensor[c, rows, cols] = 1

    return tensor
1个回答

5

使用broadcasting

(a==np.arange(num_classes)[:,None,None]).astype(int)

或者使用内置的比较函数 builtin 进行外部比较 -

(np.equal.outer(range(num_classes),a)).astype(int)

如果必须使用int数据类型,请使用uint8,或者完全跳过int转换并保持为boolean以进一步提高性能。

示例运行 -

In [42]: a = np.array([[0,2,1],[1,0,1],[2,1,1]])

In [43]: num_classes = 3 # or depth

In [44]: (a==np.arange(num_classes)[:,None,None]).astype(int)
Out[44]: 
array([[[1, 0, 0],
        [0, 1, 0],
        [0, 0, 0]],

       [[0, 0, 1],
        [1, 0, 1],
        [0, 1, 1]],

       [[0, 1, 0],
        [0, 0, 0],
        [1, 0, 0]]])

要将depth/num_classes作为第三个维度,请扩展输入数组,然后与范围数组进行比较 -

(a[...,None]==np.arange(num_classes)).astype(int)
(np.equal.outer(im, range(num_classes))).astype(int)
(np.equal.outer(im, range(num_classes))).astype(np.uint8) # lower prec

这真的很酷。如何编辑以将深度作为第三维度输出?(x,y,depth)而不是(depth,x,y)。我尝试了几件事情,但没有广播。 - Rabeez Riaz
1
@RabeezRiaz 如果是针对编程的话,应该是这样写:a[...,None]==np.arange(num_classes) - Divakar

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接