使用numpy.take进行更快的花式索引

15

编辑 下面我将保留我遇到的更加复杂的问题,但是我使用 np.take 函数遇到的问题可以简单总结如下:假设你有一个形状为 (planes, rows) 的数组 img,以及一个形状为 (planes, 256) 的数组 lut,并且你想用它们创建一个形状为 (planes, rows) 的新数组 out,其中 out[p, j] = lut[p, img[p, j]]。你可以使用高级索引来实现:

In [4]: %timeit lut[np.arange(planes).reshape(-1, 1), img]
1000 loops, best of 3: 471 us per loop

但是,如果您使用take和Python循环而不是使用繁琐的索引,可以大大加快速度:

In [6]: %timeit for _ in (lut[j].take(img[j]) for j in xrange(planes)) : pass
10000 loops, best of 3: 59 us per loop

能否以某种方式重新排列lutimg,使整个操作不需要使用Python循环,而是使用numpy.take(或其他替代方法)而不是传统的花式索引来保持速度优势?

planes, rows, cols, n = 3, 4000, 4000, 4
lut = np.random.randint(-2**31, 2**31 - 1,
                        size=(planes * 256 * n // 4,)).view('uint8')
lut = lut.reshape(planes, 256, n)
img = np.random.randint(-2**31, 2**31 - 1,
                    size=(planes * rows * cols // 4,)).view('uint8')
img = img.reshape(planes, rows, cols)

我可以使用类似这样的高级索引来实现我的目标

out = lut[np.arange(planes).reshape(-1, 1, 1), img]

它给我一个形状为(planes, rows, cols, n)的数组,其中out[i, :, :, j]保存了img的第i层平面经过i层LUT的第j个 LUT的结果。

一切都很好,只有这个东西不太行:

In [2]: %timeit lut[np.arange(planes).reshape(-1, 1, 1), img]
1 loops, best of 3: 5.65 s per loop

这完全是不可接受的,因为我可以使用 np.take 的以下不太美观的替代方案中任何一个都会运行得更快:

  1. 单个平面上的单个查找表运行速度约为x70倍:

In [2]: %timeit np.take(lut[0, :, 0], img[0])
10 loops, best of 3: 78.5 ms per loop
一个运行所有所需组合的 Python 循环完成速度几乎是原来的6倍:
In [2]: %timeit for _ in (np.take(lut[j, :, k], img[j]) for j in xrange(planes) for k in xrange(n)) : pass
1 loops, best of 3: 947 ms per loop
  • 即使是在查找表(LUT)和图像中运行所有平面的组合,然后丢弃planes**2 - planes 个不需要的平面,也比使用高级索引要快:

  • In [2]: %timeit np.take(lut, img, axis=1)[np.arange(planes), np.arange(planes)]
    1 loops, best of 3: 3.79 s per loop
    
  • 目前我想到的最快组合是使用Python循环迭代平面,速度比原先快13倍:

  • In [2]: %timeit for _ in (np.take(lut[j], img[j], axis=0) for j in xrange(planes)) : pass
    1 loops, best of 3: 434 ms per loop
    
    当然,问题是如果没有使用任何Python循环,是否无法使用np.take完成此操作?理想情况下,任何所需的重塑或调整大小都应该发生在LUT上,而不是图像上,但我可以接受你们能够想出来的任何方法...

    这整行不应该写成 lut = lut.reshape(planes, 256, 4) 吗?最后一个维度是 4 - tzelleke
    @TheodrosZelleke 感谢您发现了这些问题!我的 lut 实际上是一个 断点表,所以在我的代码中它被称为 bkpt,而我在翻译问题时漏掉了这个。 - Jaime
    2
    真的,不知道。看起来都很丑陋。有一件事,np.take 目前只有在两个输入都是 C 连续时才快(否则会复制它们)。你可能可以手动将 2D 数组转换为 1D 数组,但如果 img 很大,这可能并不重要,而且如果值得折腾的话... - seberg
    2
    嗨,你应该提供一个完全可工作的示例,如果太长了,可以在Github上创建一个Gist。否则,人们很难重现你的问题并尝试帮助你。 - Andrea Zonca
    (针对您的简化编辑):您提供的两个示例并不相等。第二个示例的输出不会像您所需的那样是一个numpy数组“out”。您能否计时一个等效的示例?它可能包括调用“np.concatenate”,我想速度优势可能会变得更小。 - Juan
    显示剩余2条评论
    1个回答

    6
    首先我必须说我真的很喜欢你的问题。不重新排列 LUTIMG,以下解决方案有效:
    %timeit a=np.take(lut, img, axis=1)
    # 1 loops, best of 3: 1.93s per loop
    

    但是从结果中你需要查询对角线:a[0,0],a[1,1],a[2,2];来获得你想要的内容。我试图找到一种方法只针对对角线元素进行索引,但仍未成功。
    以下是一些重新排列LUT和IMG的方法: 如果IMG中的索引为0-255(第1平面),256-511(第2平面)和512-767(第3平面),则以下方法有效,但这将阻止您使用'uint8',这可能是一个大问题...:
    lut2 = lut.reshape(-1,4)
    %timeit np.take(lut2,img,axis=0)
    # 1 loops, best of 3: 716 ms per loop
    # or
    %timeit np.take(lut2, img.flatten(), axis=0).reshape(3,4000,4000,4)
    # 1 loops, best of 3: 709 ms per loop
    

    在我的机器上,你的解决方案仍然是最好的选择,并且非常适当,因为你只需要对角线评估,也就是plane1-plane1,plane2-plane2和plane3-plane3:
    %timeit for _ in (np.take(lut[j], img[j], axis=0) for j in xrange(planes)) : pass
    # 1 loops, best of 3: 677 ms per loop
    

    我希望这能让您对更好的解决方案有所了解。可以尝试使用flatten()np.apply_over_axes()np.apply_along_axis()等方法,它们似乎是很有前途的选择。

    我使用下面的代码生成数据:

    import numpy as np
    num = 4000
    planes, rows, cols, n = 3, num, num, 4
    lut = np.random.randint(-2**31, 2**31-1,size=(planes*256*n//4,)).view('uint8')
    lut = lut.reshape(planes, 256, n)
    img = np.random.randint(-2**31, 2**31-1,size=(planes*rows*cols//4,)).view('uint8')
    img = img.reshape(planes, rows, cols)
    

    网页内容由stack overflow 提供, 点击上面的
    可以查看英文原文,
    原文链接