用浮点数标量乘以整数numpy数组而不需要中间的浮点数数组

4
我正在处理非常大的图像数组,其中包含uint16数据,我希望将其降采样并转换为uint8
我的初始方法会导致MemoryError,因为中间需要使用float64数组:
img = numpy.ones((29632, 60810, 3), dtype=numpy.uint16) 

if img.dtype == numpy.uint16:
    multiplier = numpy.iinfo(numpy.uint8).max / numpy.iinfo(numpy.uint16).max
    img = (img * multiplier).astype(numpy.uint8, order="C")

然后我尝试了原地进行乘法,具体操作如下:

if img.dtype == numpy.uint16:
    multiplier = numpy.iinfo(numpy.uint8).max / numpy.iinfo(numpy.uint16).max
    img *= multiplier
    img = img.astype(numpy.uint8, order="C")

但是我遇到了以下错误:

TypeError: 无法将dtype('float64')的ufunc乘法输出转换为dtype('uint16'),使用强制规则'same_kind'

你知道有什么方法可以在最小化内存占用的情况下执行此操作吗?

我在哪里可以更改错误消息中提到的强制规则?

6个回答

2

"你知道一种在最小化内存占用的情况下执行此操作的方法吗?"

首先,让我们正确计算[空间]域的大小。基本数组是一个29k6 x 60k8 x RGB x 2B的内存对象:

>>> 29632 * 60810 * 3 * 2 / 1E9         ~ 10.81 [GB]

已经使用了11 [GB]的内存。

任何操作都需要一些空间。有一个TB级别的[SPACE]-域,用于纯内存numpy向量化技巧,我们就完成了。

考虑到输出任务是最小化内存占用,将所有数组及其操作移入numpy.memmap()对象中即可解决问题。


如果我遇到无法适应我的16Gb内存的更大的图像,我会记住这个答案的。 :) - PiRK
下面发布了一个改进的解决方案,具有约4.3〜4.8倍更快的原地处理基准。 - user3666197

2

在阅读了numpy ufunc文档后,我终于找到了一个可行的解决方案。

    multiplier = numpy.iinfo(numpy.uint8).max / numpy.iinfo(numpy.uint16).max
    numpy.multiply(img, multiplier, out=img, casting="unsafe")
    img = img.astype(numpy.uint8, order="C")

我本应早些发现这个问题,但如果您不熟悉一些技术词汇,阅读起来并不容易。


1
这是我的最爱解决方案,因为我不在意性能损失,并且它适用于所有情况(许多不同的数组形状)。谢谢! - Masterfool

1
你在这种情况下也可以使用Numba或Cython。这样,你可以明确地避免任何临时数组。代码有点长,但非常容易理解和更快。 示例
import numpy as np
import numba as nb

@nb.njit(parallel=True)
def conv_numba(img):
    multiplier = np.iinfo(np.uint8).max / np.iinfo(np.uint16).max
    img_out=np.empty(img.shape,dtype=np.uint8)
    for i in nb.prange(img.shape[0]):
        for j in range(img.shape[1]):
            for k in range(img.shape[2]):
                img_out[i,j,k]=img[i,j,k]*multiplier
    return img_out

#img_in have to be contigous, otherwise reshape will fail
@nb.njit(parallel=True)
def conv_numba_opt(img_in):
    multiplier = np.iinfo(np.uint8).max / np.iinfo(np.uint16).max
    shape=img_in.shape

    img=img_in.reshape(-1)
    img_out=np.empty(img.shape,dtype=np.uint8)

    for i in nb.prange(img.shape[0]):
        img_out[i]=img[i]*multiplier
    return img_out.reshape(shape)

def conv_numpy(img):
    np.multiply(img, multiplier, out=img, casting="unsafe")
    img = img.astype(np.uint8, order="C")
    return img

时间安排
img = np.ones((29630, 6081, 3), dtype=np.uint16)

%timeit res_1=conv_numpy(img)
#990 ms ± 2.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit res_2=conv_numba(img)
#with parallel=True
#122 ms ± 17.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
#with parallel=False
#571 ms ± 2.99 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

非常好的解决方案,谢谢!我只需要编写第二个numba函数,其中只有2个循环级别,以处理灰度(2D)图像数组。我想知道numba是否也可以提供任意维数的解决方案,即无需显式循环? - PiRK
1
@PiRK 是的,但还需要更多信息。例如,您可以将数组展平(reshape(-1)),然后在最后再次进行reshape。但是,只支持连续数组。如果在numba函数之外进行重塑,则展平分步数组将导致复制,这是希望避免的。因此,对于完整的解决方案,您可能需要两个代码路径(一个用于分步数组,一个用于连续数组)。 - max9111
是的,那很有道理。我现在必须保留循环,直到我找到一个低级库以正确顺序读取数据。我正在使用一个库读取CZI文件,该库返回所有带有奇怪形状(2, 3, 29632, 60810, 1)(轴“SCYX0”)的数据,并且我必须执行一些转置才能得到我的RGB图像。无论如何,谢谢。 - PiRK

0

这些是在将 uint16 转换为 uint8 的特定情况下可行的替代解决方案,建议参考 https://github.com/silx-kit/silx/pull/2889,利用可以只读取第一个字节并忽略第二个字节的事实:

>>> import numpy
>>> img = numpy.ones((29632, 30000, 3), dtype=numpy.uint16)

>>> %timeit img2 = numpy.ascontiguousarray(img.astype(dtype=('u2', [('lo','u1'), ('hi', 'u1')])))["hi"]
2.43 s ± 34.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

>>> %timeit img2 = numpy.ascontiguousarray(img.view(numpy.uint8)[..., 1::2])
2.58 s ± 165 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

结合numba和并行计算,它比之前的使用算术乘法的numba解决方案稍微快一点(在相同的数组上耗时750毫秒):
In [11]: @numba.njit(parallel=True)
...: def uint16_to_uint8_shift(img):
...:     img_out = numpy.empty(img.shape, dtype=numpy.uint8)
...:     for i in numba.prange(img.shape[0]):
...:         for j in range(img.shape[1]):
...:             for k in range(img.shape[2]):
...:                 img_out[i, j, k] = img[i, j, k] >> 8
...:     return img_out
...:
In [12]: %timeit img2 = uint16_to_uint8_shift(img)
650 ms ± 43.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) 

0
使用整数除法(而不是乘以倒数)应该可以避免中间的浮点数组(如果我错了请纠正我),并允许您在原地执行操作。
divisor = numpy.iinfo(numpy.uint16).max // numpy.iinfo(numpy.uint8).max
img //= divisor
img = img.astype(numpy.uint8, order="C")

谢谢您的建议。但我遇到了类似的错误。 “TypeError: ufunc 'floor_divide' 输出(类型代码为'd')无法根据强制转换规则''same_kind''强制转换为提供的输出参数(类型代码为'H')。” - PiRK

0

想要获得一个4.3 ~ 4.7 x更快的解决方案怎么样?

PiRKnumpy.ufunc基于原地处理的解决方案进行了一些改进,如此处所示,
这里有一个大约4.3 ~ 4.7 x更快的修改版本:

>>> from zmq import Stopwatch; aClk = Stopwatch() # a trivial [us]-resolution clock
>>> ###
>>> ############################################### ORIGINAL ufunc()-code:
>>> ###
>>> ###   np.ones( ( 29632, 608, 3 ), dtype = np.uint16 ) ## SIZE >> CPU CACHE SIZE
>>> I   = np.ones( ( 29632, 608, 3 ), dtype = np.uint16 )
>>> #mg = np.ones( ( 29632, 608, 3 ), dtype = np.uint16 ); aClk.start(); _ = np.multiply( img, fMUL, out = img, casting = 'unsafe' ); img = img.astype( np.uint8, order = 'C' );aClk.stop() ########## a one-liner for fast re-testing on CLI console
>>> img = I.copy();aClk.start();_= np.multiply( img,
...                                             fMUL,
...                                             out     =  img,
...                                             casting = 'unsafe'
...                                             ); img  =  img.astype( np.uint8,
...                                                                    order = 'C'
...                                                                    );aClk.stop()

312802 [us]
320087 [us]
329401 [us]
317346 [us]

在第一次尝试中使用更多文档记录的 ufunc-smart kwargs,性能提升~ 4.3 ~ 4.7 x

>>> ### = I.copy(); aClk.start(); _ = np.multiply( img, fMUL, out = img, casting = 'unsafe', dtype = np.uint8, order = 'C'  ); aClk.stop() ########## a one-liner for fast re-testing on CLI console
>>> img = I.copy(); aClk.start(); _ = np.multiply( img,
...                                                fMUL,
...                                                out     =  img,
...                                                casting = 'unsafe',
...                                                dtype   =  np.uint8,
...                                                order   = 'C'
...                                                ); aClk.stop()
69812 [us]
71335 [us]
73112 [us]
70171 [us]

我在哪里可以更改错误消息中提到的转换规则?

隐式(默认)模式用于casting参数,在numpy版本1.10〜1.11之间的某个地方已更改,但在已发布的numpy.ufunc API文档中有很好的记录。


这个不起作用。数组类型没有改变。看起来当提供了out参数时,dtype参数被忽略了。在原地更改输出数组的类型是不可行的,这是有道理的。对于这一步骤,复制是不可避免的。 - PiRK
@PiRK 我的错 - 我只测试了.flags中的“F”到“C”的排序,以便在原地转换,而不是dtype。 - user3666197
我确认这个不起作用。请尝试将输入改为img=np.uint16((0, 500, 65535))。在调用原始解决方案后,我们得到了array([ 0, 1, 255], dtype=uint8)。在调用您的新解决方案后,我们得到了array([0, 0, 0], dtype=uint8) - Masterfool

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接