NumPy中更高级的花式索引?

3

我正在使用NumPy实现颜色插值,使用查找表(LUT)。在某个时刻,我使用RGB值的4个最高有效位来从17x17x17x4 LUT中选择相应的CMYK值。目前看起来像这样:

import numpy as np
rgb = np.random.randint(16, size=(3, 1000, 1000))
lut = np.random.randint(256, size=(17, 17, 17, 4))
cmyk = lut[rgb[0], rgb[1], rgb[2]]

这里是第一个问题...有没有更好的方法?看起来很自然,你可以告诉NumPy,lut的索引存储在rgb的轴0上,而不必实际编写代码。那么,NumPy中是否有类似于cmyk = lut.fancier_take(rgb, axis=0)的东西呢?

此外,我得到了一个形状为(1000, 1000,4)的数组,因此为了与输入保持一致,我需要使用一些swapaxes进行全部旋转:

cmyk = cmyk.swapaxes(2, 1).swapaxes(1, 0).copy()

我还需要添加复制语句,因为如果不这样做,生成的数组在内存中不是连续的,这会在后面带来麻烦。

现在我倾向于在进行高级索引之前旋转LUT,然后进行以下操作:

swapped_lut = lut.swapaxes(2, 1).swapaxes(1, 0)
cmyk = swapped_lut[np.arange(4), rgb[0], rgb[1], rgb[2]]

但是,仍然感觉不太对劲……肯定有更加优雅的方法来做到这一点,对吧?比如像这样:cmyk = lut.even_fancier_take(rgb, in_axis=0, out_axis=0)……


我按照Bi Rico的建议进行了操作,使用了重新塑形后的arange。虽然效果很好,但我发现将lut(17, 17, 17, 4)的形状转换为(4, 4913),然后使用take和我自己版本的ravel_multi_index来提取所需值,速度可以快10倍左右... - Jaime
2个回答

3

我建议使用元组来强制按行进行索引,并使用np.rollaxistranspose代替swapaxes

lut[tuple(rgb)].transpose(2, 0, 1).copy()

或者

np.rollaxis(lut[tuple(rgb)], 2).copy()

首先滚动坐标轴,使用以下命令:

np.rollaxis(lut, -1)[(Ellipsis,) + tuple(rgb)]

1
可能会有点丑,但是在最后一个例子中,np.arange(4).reshape(4,1,1)将已经强制执行C连续的复制(因此避免了显式复制),而不是使用省略号。 - seberg
@Sebastian,如果作为完整答案,那会是什么样子呢? - ecatmur
np.rollaxis(lut, -1)[(np.arange(4).reshape(4,1,1),) + tuple(rgb)]np.rollaxis(lut, -1)[(np.arange(4).reshape(4,1,1),) + tuple(rgb)] - seberg
1
如果你要使用 np.arange(4).reshape(4, 1, 1),为什么还要滚动 LUT?直接使用 lut[tuple(rgb) + (np.arange(4).reshape(4, 1, 1),)] 即可。 - Bi Rico
@Sebastian,这是我测试时的情况。一般来说不是这样吗? - Bi Rico
显示剩余2条评论

1
如果您更换了lut,则需要执行以下操作:np.arange(4)将无法正常工作:
swapped_lut = np.rollaxis(lut, -1)
cmyk = swapped_lut[:, rgb[0], rgb[1], rgb[2]].copy()

或者您可以替换

cmyk = lut[rgb[0], rgb[1], rgb[2]]
cmyk = cmyk.swapaxes(2, 1).swapaxes(1, 0).copy()

使用:

cmyk = lut[tuple(rgb)]
cmyk = np.rollaxis(cmyk, -1).copy()

但是试图一步完成所有的工作可能会很困难,也许:

rng = np.arange(4).reshape(4, 1, 1)
cmyk = lut[rgb[0], rgb[1], rgb[2], rng]

这看起来一点也不易读,是吧?

看看这个问题的答案Numpy multi-dimensional array indexing swaps axis order。它很好地解释了如何使用numpy广播多个数组以获得输出大小。在这里,您想要创建索引到lut,以广播到(4,1000,1000)。希望这有些意义。


因此,在进行花式索引时也会发生广播...一旦你知道了这一点,这既聪明又显而易见!这看起来确实像我正在寻找的答案,但在接受它之前,请给我周末时间自己弄清楚细节。 - Jaime

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接