如何使用Numba加速Python字典速度

4

我需要将一些单元格存储在布尔值的数组中。起初我使用了numpy,但当数组开始占用大量内存时,我想到了将非零元素存储在以元组为键(因为它是可哈希类型)的字典中。例如: {(0, 0, 0): True, (1, 2, 3): True}(这是“3D数组”中索引为0,0,0和1,2,3的两个单元格,但维数在运行我的算法时未知且定义)。 这项技术有很大的帮助,因为非零单元格仅占数组的一小部分。

要从此字典中写入和获取值,我需要使用循环:

def fill_cells(indices, area_dict):
    for i in indices:
        area_dict[tuple(i)] = 1

def get_cells(indices, area_dict):
    n = len(indices)
    out = np.zeros(n, dtype=np.bool)
    for i in range(n):
        out[i] = tuple(indices[i]) in area_dict.keys()
    return out

现在我需要用Numba来加速它。Numba不支持Python的原生dict(),所以我使用了numba.typed.Dict。问题是,在定义函数时,Numba要知道键的大小,因此我甚至无法创建字典(键的长度在事先未知,并且在调用函数时定义):
@njit
def make_dict(n):
    out = {(0,)*n:True}
    return out

Numba无法正确推断字典键的类型并返回错误:
Compilation is falling back to object mode WITH looplifting enabled because Function "make_dict" failed type inference due to: Invalid use of Function(<built-in function mul>) with argument(s) of type(s): (tuple(int64 x 1), int64)

如果我在函数中将参数n更改为具体的数字,它就可以正常工作。我用以下小技巧解决了这个问题:

n = 10
s = '@njit\ndef make_dict():\n\tout = {(0,)*%s:True}\n\treturn out' % n
exec(s)

但我认为这是一种低效的方法。我仍然需要使用我的fill_cells和get_cells函数,并使用@njit修饰符,但Numba返回相同的错误,因为我在这些函数中尝试从numpy数组创建元组。

我理解Numba(以及编译)的基本限制,但也许有一些方法可以加速函数,或者你有另一种解决我的单元格存储问题的方法?


2
你考虑过稀疏矩阵吗? - Marat
@Marat 是的,我基于键字典(fill_cells和get_cells函数是该实现的一部分)制作了自己的稀疏矩阵实现。我意识到这是一种相当常见的稀疏矩阵解决方案。问题在于我需要加速这个实现。此外,我不需要对其进行矩阵运算,只需存储和获取值,也许可以扩展可能的解决方案集。 - True do day
本地数据结构,如字典,效率相当低下。scipy.sparse 提供了 C 实现,可能会比本地结构快上一个数量级。 - Marat
我的错,我完全忘记了scipy稀疏矩阵只有2D。在eager模式下,Tensorflow稀疏张量可能是任意维度的合理替代品。 - Marat
1
你看过 https://gist.github.com/sklam/830fe01343ba95828c3b24c391855c86 吗?当我想要使用数组作为矩阵的索引时,我遇到了同样的问题。只需要在顶部进行小的调整,因为字典没有 ndim。 - tobiasraabe
显示剩余2条评论
1个回答

0

最终解决方案:

主要问题在于Numba需要在定义创建元组的函数时知道元组的长度。诀窍是每次重新定义函数。我需要生成包含定义函数代码的字符串,并使用exec()运行它:

n = 10
s = '@njit\ndef arr_to_tuple(a):\n\treturn (' + ''.join('a[%i],' % i for i in range(n)) + ')'
exec(s)

接下来,我可以调用arr_to_tuple(a)创建元组,这些元组可以在另一个使用@njit修饰的函数中使用。

例如,创建元组键的空字典,需要解决问题:

@njit
def make_empty_dict():
    tpl = arr_to_tuple(np.array([0]*5))
    out = {tpl:True}
    del out[tpl]
    return out

我在字典中写入一个元素,因为这是Numba推断类型的一种方式之一。

此外,我需要使用在问题描述中提到的fill_cellsget_cells函数。这是我用Numba重写它们的方式:

编写元素。只需将tuple()更改为arr_to_tuple():

@njit
def fill_cells_nb(indices, area_dict):
    for i in range(len(indices)):
        area_dict[arr_to_tuple(indices[i])] = True

从字典中获取元素需要一些令人毛骨悚然的代码:

@njit
def get_cells_nb(indices, area_dict):
    n = len(indices)
    out = np.zeros(n, dtype=np.bool_)
    for i in range(n):
        new_len = len(area_dict)
        tpl = arr_to_tuple(indices[i])
        area_dict[tpl] = True
        old_len = len(area_dict)
        if new_len == old_len:
            out[i] = True
        else:
            del area_dict[tpl]
    return out

我的Numba版本(0.46)不支持.contains(in)运算符和try-except结构。如果您有支持它的版本,您可以为其编写更“常规”的解决方案。

因此,当我想要检查字典中是否存在某个索引的元素时,我会记住它的长度,然后在字典中写入具有提到的索引的内容。如果长度改变了,我就得出结论该元素不存在。否则,该元素存在。看起来非常慢,但实际上并不是。

速度测试:

这些解决方案的速度非常快。我使用%timeit进行了测试,与本地Python优化代码进行比较:

  1. arr_to_tuple() 比常规的tuple()函数快5倍
  2. get_cells with numba 对于一个元素,比native-Python written get_cells快3倍,对于大量元素的数组,快40倍
  3. fill_cells with numba 对于一个元素,比native-Python written fill_cells快4倍,对于大量元素的数组,快40倍

1
你是否与有类型的列表进行了性能比较?似乎你不需要存储“True”,因为已经通过存储索引来隐含了它。你还可以考虑编写一个unravel_indexravel_multi_index函数,类似于Numpy,使得存储的索引始终为1D。 - Rutger Kassies
@RutgerKassies 在你的评论后,我进行了比较。使用类型化列表的工作速度明显较慢,因为需要在循环中检查列表的元素才能得到检查结果。函数的执行时间取决于列表的大小,而字典由于哈希键具有恒定的时间。 - True do day
顺便说一句,我发现对于这样的问题,使用原生的Python hash()函数并将其用于循环而不是ravel_multi_index可能是合理的(更快)。这些循环可以使用@njit修饰,而不需要显著改变代码。 - True do day

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接