使用字典在numba njit并行化加速代码时出现问题

3

我已经编写了一段代码,并尝试使用numba来加速代码。该代码的主要目标是根据条件对一些值进行分组。为此,iter_用于使代码收敛以满足条件。我准备了一个小案例来复制样本代码:

import numpy as np
import numba as nb

rng = np.random.default_rng(85)

# --------------------------------------- small data volume ---------------------------------------
# values_ = {'R0': np.array([0.01090976, 0.01069902, 0.00724112, 0.0068463 , 0.01135723, 0.00990762,
#                                        0.01090976, 0.01069902, 0.00724112, 0.0068463 , 0.01135723]),
#            'R1': np.array([0.01836379, 0.01900166, 0.01864162, 0.0182823 , 0.01840322, 0.01653088,
#                                        0.01900166, 0.01864162, 0.0182823 , 0.01840322, 0.01653088]),
#            'R2': np.array([0.02430913, 0.02239156, 0.02225379, 0.02093393, 0.02408692, 0.02110411,
#                                        0.02239156, 0.02225379, 0.02093393, 0.02408692, 0.02110411])}
#
# params = {'R0': [3, 0.9490579204466154, 1825, 7.070272000000002e-05],
#           'R1': [0, 0.9729203826820172, 167 , 7.070272000000002e-05],
#           'R2': [1, 0.6031363088057902, 1316, 8.007296000000003e-05]}
#
# Sno, dec_, upd_ = 2, 100, 200
# -------------------------------------------------------------------------------------------------

# ----------------------------- UPDATED (medium and large data volumes) ---------------------------
# values_ = np.load("values_med.npy", allow_pickle=True)[()]
# params = np.load("params_med.npy", allow_pickle=True)[()]
values_ = np.load("values_large.npy", allow_pickle=True)[()]
params = np.load("params_large.npy", allow_pickle=True)[()]

Sno, dec_, upd_ = 2000, 1000, 200
# -------------------------------------------------------------------------------------------------

# values_ = [*values_.values()]
# params = [*params.values()]


# @nb.jit(forceobj=True)
# def test(values_, params, Sno, dec_, upd_):

final_dict = {}
for i, j in enumerate(values_.keys()):
    Rand_vals = []
    goal_sum = params[j][1] * params[j][3]
    tel = goal_sum / dec_ * 10
    if params[j][0] != 0:
        for k in range(Sno):
            final_sum = 0.0
            iter_ = 0
            t = 1
            while not np.allclose(goal_sum, final_sum, atol=tel):
                iter_ += 1
                vals_group = rng.choice(values_[j], size=params[j][0], replace=False)
                # final_sum = 0.0016 * np.sum(vals_group)  # -----> For small data volume
                final_sum = np.sum(vals_group ** 3)        # -----> UPDATED For med or large data volume
                if iter_ == upd_:
                    t += 1
                    tel = t * tel
            values_[j] = np.delete(values_[j], np.where(np.in1d(values_[j], vals_group)))
            Rand_vals.append(vals_group)
    else:
        Rand_vals = [np.array([])] * Sno
    final_dict["R" + str(i)] = Rand_vals

#    return final_dict


# test(values_, params, Sno, dec_, upd_)
起初,对于这段代码,使用了@nb.jit(使用forceobj=True来避免警告等...),但这会对性能产生不利影响。也检查了nopython,使用@nb.njit,但由于输入的字典类型,导致出现以下错误:不支持(如12):

cannot determine Numba type of <class 'dict'>

我不知道是否可以通过numba.typed中的Dict(将创建的Python字典转换为numba Dict)来处理它,或者将字典转换为数组列表是否有任何优势。我认为,如果某些代码行(例如Rand_vals.append(vals_group)else部分等)被拿出函数并进行修改,则可能可以进行并行化,以获得与之前相同的结果,但是我不知道该如何做。 如果能帮助利用numba对此代码进行处理,我将不胜感激。如果可以实现numba并行化,那么这将是最理想的(可能是性能最好的)解决方案。

数据:


2
请注意,如果 while 条件第一次为 false,则 vals_group 未定义。 - Jérôme Richard
@JérômeRichard,说得好。而且,我觉得这种情况在我的数据中永远不会发生。我会考虑并检查的,谢谢。 - Ali_Sh
1个回答

阿里云服务器只需要99元/年,新老用户同享,点击查看详情
4

这段代码可以被转换为Numba,但不是直接的。

首先,字典和列表类型必须被定义,因为 Numba的 njit 函数不能直接操作反映出来的列表 (也就是纯Python列表)。在Numba中这有点繁琐,并且导致的代码有些啰嗦:

String = nb.types.unicode_type
ValueArray = nb.float64[::1]
ValueDict = nb.types.DictType(String, ValueArray)
ParamDictValue = nb.types.Tuple([nb.int_, nb.float64, nb.int_, nb.float64])
ParamDict = nb.types.DictType(String, ParamDictValue)
FinalDictValue = nb.types.ListType(ValueArray)
FinalDict = nb.types.DictType(String, FinalDictValue)

然后您需要将输入字典进行转换:

nbValues = nb.typed.typeddict.Dict.empty(String, ValueArray)
for key,value in values_.items():
    nbValues[key] = value.copy()

nbParams = nb.typed.typeddict.Dict.empty(String, ParamDictValue)
for key,value in params.items():
    nbParams[key] = (nb.int_(value[0]), nb.float64(value[1]), nb.int_(value[2]), nb.float64(value[3]))
首先,您需要编写核心函数。由于Numba中未实现np.allclosenp.isin函数,因此它们需要手动重新实现。但主要问题在于,Numba不支持rng Numpy对象。我认为它短期内肯定不会支持它。请注意,Numba具有随机数实现,尝试模仿Numpy的行为,但是种子的管理略有不同。还需注意,如果将种子设置为相同的值,则结果应与np.random.xxx Numpy函数相同(Numpy和Numba具有不同的种子变量,这些变量不同步)。
@nb.njit(FinalDict(ValueDict, ParamDict, nb.int_, nb.int_, nb.int_))
def nbTest(values_, params, Sno, dec_, upd_):
    final_dict = nb.typed.Dict.empty(String, FinalDictValue)
    for i, j in enumerate(values_.keys()):
        Rand_vals = nb.typed.List.empty_list(ValueArray)
        goal_sum = params[j][1] * params[j][3]
        tel = goal_sum / dec_ * 10
        if params[j][0] != 0:
            for k in range(Sno):
                final_sum = 0.0
                iter_ = 0
                t = 1

                vals_group = np.empty(0, dtype=nb.float64)

                while np.abs(goal_sum - final_sum) > (1e-05 * np.abs(final_sum) + tel):
                    iter_ += 1
                    vals_group = np.random.choice(values_[j], size=params[j][0], replace=False)
                    final_sum = 0.0016 * np.sum(vals_group)
                    # final_sum = 0.0016 * np.sum(vals_group)  # (for small data volume)
                    final_sum = np.sum(vals_group ** 3)        # (for med or large data volume)
                    if iter_ == upd_:
                        t += 1
                        tel = t * tel

                # Perform an in-place deletion
                vals, gr = values_[j], vals_group
                cur = 0
                for l in range(vals.size):
                    found = False
                    for m in range(gr.size):
                        found |= vals[l] == gr[m]
                    if not found:
                        # Keep the value (delete it otherwise)
                        vals[cur] = vals[l]
                        cur += 1
                values_[j] = vals[:cur]

                Rand_vals.append(vals_group)
        else:
            for k in range(Sno):
                Rand_vals.append(np.empty(0, dtype=nb.float64))
        final_dict["R" + str(i)] = Rand_vals
    return final_dict
请注意,np.isin 的替代实现非常简单但在实际输入示例中表现良好。 以下是该函数的调用方式:
nbFinalDict = nbTest(nbValues, nbParams, Sno, dec_, upd_)
最后,字典应该转换回基本的Python对象:
finalDict = dict()
for key,value in nbFinalDict.items():
    finalDict[key] = list(value)

这个实现对于小输入而言速度很快,但是对于大输入则不然,因为np.random.choice占据了几乎所有的时间(>96%)。问题在于当所请求的项数较少(即你的情况)时,这个函数明显并非最优的选择。事实上,它以线性时间运行输入数组的长度,而不是请求项数的线性时间。


进一步优化

可以完全重写算法,以更加有效的方式从主要当前数组中提取仅有的12个随机项并将其丢弃。思路是将n个项(小目标样本)与数组末尾的其他随机位置交换,然后检查总和,重复此过程直到达成条件,最后提取最后n项的视图,再调整视图大小以丢弃最后的项。所有这些操作可以在O(n)时间内完成,而不是在O(m)时间内完成,其中m是主要当前数组的大小,n << m(例如12与20_000)。它也可以在不进行任何昂贵的分配的情况下计算。以下是结果代码:

@nb.njit(nb.void(ValueArray, nb.int_, nb.int_))
def swap(arr, i, j):
    arr[i], arr[j] = arr[j], arr[i]

@nb.njit(FinalDict(ValueDict, ParamDict, nb.int_, nb.int_, nb.int_))
def nbTest(values_, params, Sno, dec_, upd_):
    final_dict = nb.typed.Dict.empty(String, FinalDictValue)
    for i, j in enumerate(values_.keys()):
        Rand_vals = nb.typed.List.empty_list(ValueArray)
        goal_sum = params[j][1] * params[j][3]
        tel = goal_sum / dec_ * 10
        values = values_[j]
        n = params[j][0]

        if n != 0:
            for k in range(Sno):
                final_sum = 0.0
                iter_ = 0
                t = 1

                m = values.size
                assert n <= m
                group = values[-n:]

                while np.abs(goal_sum - final_sum) > (1e-05 * np.abs(final_sum) + tel):
                    iter_ += 1

                    # Swap the group view with other random items
                    for pos in range(m - n, m):
                        swap(values, pos, np.random.randint(0, m))

                    # For small data volume:
                    # final_sum = 0.0016 * np.sum(group)

                    # For med/large data volume
                    final_sum = 0.0
                    for v in group:
                        final_sum += v ** 3

                    if iter_ == upd_:
                        t += 1
                        tel *= t

                assert iter_ > 0
                values = values[:m-n]
                Rand_vals.append(group)
        else:
            for k in range(Sno):
                Rand_vals.append(np.empty(0, dtype=nb.float64))
        final_dict["R" + str(i)] = Rand_vals
    return final_dict
除了更快之外,这种实现方法的好处也在于更简单。尽管随机性使得结果检查棘手(特别是因为此函数不使用相同的方法选择随机样本),但结果看起来与以前的实现相当相似。请注意,与上一个实现不同,此实现不会删除在group中的values项目(虽然这可能不是期望的)。

基准测试

以下是最后一次实现在我的计算机上的结果(编译和转换时间除外):

Provided small input (embedded in the question):
 - Initial code:   42.71 ms
 - Numba code:      0.11 ms

Medium input:
 - Initial code:   3481 ms
 - Numba code:       11 ms

Large input:
 - Initial code:   6728 ms
 - Numba code:       20 ms

请注意,转换时间与计算时间大致相同。

这个最新的实现在小输入上比初始代码快316~388倍


注释

请注意,由于dict和lists类型,编译时间需要几秒钟。

请注意,虽然可能可以并行实现,但只能并行执行最全面的循环。问题是要计算的项目很少,时间已经相当短了(不是多线程的最佳情况)。<-- 另外,由rng.choice创建的许多临时数组的创建肯定会导致并行循环无法扩展。-> 另外,list/dict不能安全地从多个线程写入,因此需要在整个函数中使用NumPy数组才能这样做(或添加已经昂贵的额外转换)。此外,Numba并行化倾向于显着增加编译时间,而编译时间已经很长了。最后,结果将不那么确定,因为每个Numba线程都有自己的随机数生成器种子,并且线程计算的项目无法预测。prange(取决于目标平台上选择的并行运行时)。请注意,在Numpy中,默认情况下使用常规随机函数的一个全局种子(已弃用的方式),而RNG对象有自己的种子(新的首选方式)。


我已经添加了更大的数据量来测试并行化或其他性能,将 final_sum = 0.0016 * np.sum(vals_group) 更改为 final_sum = np.sum(vals_group ** 3);这个公式在 medlarger 上都有效。如果需要进行评估,可以通过增加 dec_ 的数量级(例如从 100 增加到 100000 或更多)来增加运行时间。 - Ali_Sh
我对Numba有自己的方法来生成随机数(可以在并行循环中安全使用)感到有些困惑。你是指numba控制np.random.choice以便在并行循环中更快、更安全地使用,还是说numba有自己的随机模块可供使用,而你在这个解决方案中没有使用np.random.choice - Ali_Sh
我试图澄清一下这部分内容。我的观点是,我认为在并行的Numba代码和Numpy代码之间尝试复制相同的随机值非常棘手(即使可能)。不可重复性导致代码更难以检查,因此可能会出现更多的错误(更不用说并行代码往往也有更多的错误 - 特别是更难以察觉的错误)。此外,对于错误,似乎您传递的对象是PyObject,而它们必须是类型化字典。 - Jérôme Richard
我已经在 COLAB 和我的 Windows (10 企业版,x64,CPU i5 @2.8GHz,内存 16Gb) 上测试了代码,两者都显示错误。我还没有测试过。如何找到解决方法(也许最好在 colab 上尝试,因为它会像在我的 windows 上一样出错)?我已经在 Numba ver 0.53.10.55.1(已更新)上进行了测试。我不认为这与 Numba 的实验性新功能有关吧!是吗?这个答案对我来说充满了注释,而且也很具有挑战性,所以最好将这个讨论移到聊天中进行。 - Ali_Sh
1
让我们在聊天中继续这个讨论 - Ali_Sh
显示剩余4条评论

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,