Python numba 列表的列表的元组

3

我正在尝试使用numba加速我的算法,在使用numpy和优化后没有更多的改进。

我有一个函数,在一个大的2层嵌套循环中进行一些计算:

import random
from numba import njit

@njit()
def decide_if_vaild():
    return bool(random.getrandbits(1))

@njit()
def decide_what_bin(bins):
    return random.randint(0, bins-1)

@njit()
def foo(bins, loops):
    results = [[] for _ in range(bins)]

    for i in range(loops):
        for j in range(i+1, loops):
            happy = decide_if_vaild()
            bin = decide_what_bin(bins)
            if happy:
                results[bin].append( (i,j) )
                # or
                # results[bin].append( [i,j] )
    return results

if __name__ == '__main__':
    x = foo(3,100)

如果我运行上面的最小示例,我会得到一个类型错误(如预期所示):
  File "C:\Users\xxx\AppData\Local\Programs\Python\Python36\lib\site-packages\numba\typeinfer.py", line 104, in getone
    assert self.type is not None
numba.errors.InternalError: 
[1] During: typing of call at C:/Users/xxx/minimal_example.py (21)
--%<-----------------------------------------------------------------

File "minimal_example.py", line 21

问题在于:“results[bin].append((i,j))”这段代码,我试图将一个元组(列表也不行)添加到列表中,但是出现了问题。
箱子的数量事先已知,但是有多少个元素(2元组、列表或np.array)取决于“decide_if_vaild”评估为True的次数。由于我不知道这会发生多少次,并且计算非常昂贵,因此我不知道是否有其他解决方法。
有什么好主意可以在jitted函数中生成结果并返回它,或者传递全局容器来填充此函数吗?
这可能会退回到:
numba.errors.LoweringError: Failed at nopython (nopython mode backend)
list(list(list(int64))): unsupported nested memory-managed object

在list(list(int64))中出现了类似的问题(https://github.com/numba/numba/issues/2560),在numba 0.39.0版本中已经得到解决,具体信息请参考https://github.com/numba/numba/pull/2840


1
这段代码可以很容易地使用numpy进行重写,而不需要使用for循环。此外,通过results = [[] for _ in range(bins)],您正在创建一个空的未定义类型列表。Numba不喜欢这样。最后但并非最不重要的是,您不应该覆盖内部变量,例如bin - JE_Muc
@Scotty1- 这个怎么用numpy重写呢?据我所知,numpy不太喜欢使用动态大小的向量。由于我不知道结果bin的大小,因此无法确定。关于内部的“bin”,这只是为了快速展示可重现的最小示例而发生的。 - iR0Nic
我不太清楚变量size是什么,而且在您的示例中,foo只接收一个参数。因此,我无法运行您的代码进行测试。但您可以使用happy = np.random.randint(0, 2, size, dtype=np.bool)让自己感到开心,然后根据happy的大小创建结果。只需尝试“重新思考”您的代码。开始时需要一些时间,但是一旦您习惯了向量化的代码,就会发现它更易于阅读和理解(并且更快,更高效...)。 - JE_Muc
你是完全正确的,我很抱歉。我已经更正了这个例子。我明白你的意思,我会认真考虑并尝试一下的,谢谢! - iR0Nic
1个回答

1

我现在已经实施了以下解决方法,尽管它并没有完全回答这个问题,但对于其他遇到此问题的人来说,这可能是一个合适的方法:

@njit()
def foo(bins, loops):
    results = []
    mapping = []

    for i in range(loops):
        for j in range(loops+1, size):
            happy = decide_if_vaild()
            bin = decide_what_bin(bins)
            if happy:
                results.append( (i,j) )
                mapping.append( bin )
    return results, mapping

这将返回一个元组列表(numba 0.39.0支持),以及一个映射列表,其中mapping [i]包含results [i]的bin。现在jit编译器可以平稳工作,我可以在jit之外解压缩结果。

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接