我需要将一些单元格存储在布尔值的数组中。起初我使用了numpy,但当数组开始占用大量内存时,我想到了将非零元素存储在以元组为键(因为它是可哈希类型)的字典中。例如:
{(0, 0, 0): True, (1, 2, 3): True}
(这是“3D数组”中索引为0,0,0和1,2,3的两个单元格,但维数在运行我的算法时未知且定义)。
这项技术有很大的帮助,因为非零单元格仅占数组的一小部分。
要从此字典中写入和获取值,我需要使用循环:
def fill_cells(indices, area_dict):
for i in indices:
area_dict[tuple(i)] = 1
def get_cells(indices, area_dict):
n = len(indices)
out = np.zeros(n, dtype=np.bool)
for i in range(n):
out[i] = tuple(indices[i]) in area_dict.keys()
return out
现在我需要用Numba来加速它。Numba不支持Python的原生dict(),所以我使用了numba.typed.Dict。问题是,在定义函数时,Numba要知道键的大小,因此我甚至无法创建字典(键的长度在事先未知,并且在调用函数时定义):
@njit
def make_dict(n):
out = {(0,)*n:True}
return out
Numba无法正确推断字典键的类型并返回错误:
Compilation is falling back to object mode WITH looplifting enabled because Function "make_dict" failed type inference due to: Invalid use of Function(<built-in function mul>) with argument(s) of type(s): (tuple(int64 x 1), int64)
如果我在函数中将参数n更改为具体的数字,它就可以正常工作。我用以下小技巧解决了这个问题:
n = 10
s = '@njit\ndef make_dict():\n\tout = {(0,)*%s:True}\n\treturn out' % n
exec(s)
但我认为这是一种低效的方法。我仍然需要使用我的fill_cells和get_cells函数,并使用@njit修饰符,但Numba返回相同的错误,因为我在这些函数中尝试从numpy数组创建元组。
我理解Numba(以及编译)的基本限制,但也许有一些方法可以加速函数,或者你有另一种解决我的单元格存储问题的方法?