我希望创建一个使用Numba编译的函数,该函数将接受指针或数组的内存地址作为参数,并对其进行计算,例如修改底层数据。
以下是用纯Python编写的示例:
import ctypes
import numba as nb
import numpy as np
arr = np.arange(5).astype(np.double) # create arbitrary numpy array
def modify_data(addr):
""" a function taking the memory address of an array to modify it """
ptr = ctypes.c_void_p(addr)
data = nb.carray(ptr, arr.shape, dtype=arr.dtype)
data += 2
addr = arr.ctypes.data
modify_data(addr)
arr
# >>> array([2., 3., 4., 5., 6.])
从示例中可以看到,数组arr
在未显式传递给函数的情况下被修改。在我的用例中,数组的形状和dtype已知且始终不变,这应该简化接口。
1. 尝试:天真的即时编译
我现在尝试编译modify_data
函数,但失败了。 我的第一次尝试是使用
shape = arr.shape
dtype = arr.dtype
@nb.njit
def modify_data_nb(ptr):
data = nb.carray(ptr, shape, dtype=dtype)
data += 2
ptr = ctypes.c_void_p(addr)
modify_data_nb(ptr) # <<< error
这个错误信息为“无法确定的Numba类型”,也就是说,它不知道如何解释指针。
尝试2:显式指定类型。
arr_ptr_type = nb.types.CPointer(nb.float64)
shape = arr.shape
@nb.njit(nb.types.void(arr_ptr_type))
def modify_data_nb(ptr):
""" a function taking the memory address of an array to modify it """
data = nb.carray(ptr, shape)
data += 2
但这并没有帮助。它没有报错,但我不知道如何调用函数modify_data_nb
。 我尝试了以下选项
modify_data_nb(arr.ctypes.data)
# TypeError: No matching definition for argument type(s) int64
ptr = arr.ctypes.data_as(ctypes.POINTER(ctypes.c_double))
modify_data_nb(ptr)
# TypeError: No matching definition for argument type(s) pyobject
ptr = ctypes.c_void_p(arr.ctypes.data)
modify_data_nb(ptr)
# TypeError: No matching definition for argument type(s) pyobject
有没有办法从arr
中获取正确的指针格式,以便我可以将其传递给Numba编译的modify_data_nb
函数?或者,有没有其他方法将内存位置传递给函数。
尝试3:使用scipy.LowLevelCallable
通过使用scipy.LowLevelCallable
及其魔力,我取得了一些进展:
arr = np.arange(3).astype(np.double)
print(arr)
# >>> array([0., 1., 2.])
# create the function taking a pointer
shape = arr.shape
dtype = arr.dtype
@nb.cfunc(nb.types.void(nb.types.CPointer(nb.types.double)))
def modify_data(ptr):
data = nb.carray(ptr, shape, dtype=dtype)
data += 2
modify_data_llc = LowLevelCallable(modify_data.ctypes).function
# create pointer to array
ptr = arr.ctypes.data_as(ctypes.POINTER(ctypes.c_double))
# call the function only with the pointer
modify_data_llc(ptr)
# check whether array got modified
print(arr)
# >>> array([2., 3., 4.])
我现在可以调用一个函数来访问这个数组,但是这个函数不再是Numba函数。特别地,它不能在其他Numba函数中使用。
@nb.njit() def call(arr): modify_data_nb(arr.ctypes)
- max9111call
外部arr
的内容。 - David Zwicker