如何在Numba jitted函数中将float numpy数组的值转换为int,且要在nopython模式下实现。

5
在一个使用numba编译的nopython函数中,我需要使用另一个数组中的值作为索引来索引一个数组。这两个数组都是numpy浮点数数组。
例如:
@numba.jit("void(f8[:], f8[:], f8[:])", nopython=True)
def need_a_cast(sources, indices, destinations):
    for i in range(indices.size):
        destinations[i] = sources[indices[i]]

我的代码有所不同,但让我们假设这个愚蠢的例子可以重现问题(即,我不能使用 int 类型的索引)。据我所知,在 nopython jit 函数内部,我不能使用 int(indices[i]) 或 indices[i].astype("int")。

我该怎么做呢?

2个回答

3

至少使用 numba 0.24,你可以进行简单的强制类型转换:

import numpy as np
import numba as nb

@nb.jit(nopython=True)
def need_a_cast(sources, indices, destinations):
    for i in range(indices.size):
        destinations[i] = sources[int(indices[i])]

sources = np.arange(10, dtype=np.float64)
indices = np.arange(10, dtype=np.float64)
np.random.shuffle(indices)
destinations = np.empty_like(sources)

print indices
need_a_cast(sources, indices, destinations)
print destinations

# Result
# [ 3.  2.  8.  1.  5.  6.  9.  4.  0.  7.]
# [ 3.  2.  8.  1.  5.  6.  9.  4.  0.  7.]

3
如果你真的不能使用int(indices[i])(这对于JoshAdel和我都有效),你应该能够用math.truncmath.floor来解决它:
import math

...

destinations[i] = sources[math.trunc(indices[i])] # truncate (py2 and py3)
destinations[i] = sources[math.floor(indices[i])] # round down (only py3)

据我所知,math.floor 只适用于 Python3,因为在 Python2 中它返回一个 float。但是另一方面,math.trunc 对于负值向上舍入。

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接