如何在Python的joblib中写入共享变量

24
以下代码并行化了一个for循环。
import networkx as nx;
import numpy as np;
from joblib import Parallel, delayed;
import multiprocessing;

def core_func(repeat_index, G, numpy_arrary_2D):
  for u in G.nodes():
    numpy_arrary_2D[repeat_index][u] = 2;
  return;

if __name__ == "__main__":
  G = nx.erdos_renyi_graph(100000,0.99);
  nRepeat = 5000;
  numpy_array = np.zeros([nRepeat,G.number_of_nodes()]);
  Parallel(n_jobs=4)(delayed(core_func)(repeat_index, G, numpy_array) for repeat_index in range(nRepeat));
  print(np.mean(numpy_array));

如您所见,预期要打印的值为2。但是,当我在集群(多核、共享内存)上运行我的代码时,它返回0.0。

我认为问题在于每个 worker 都创建了自己的 numpy_array 对象副本,而主函数中创建的那个并没有更新。我该如何修改代码,使得 numpy 数组 numpy_array 可以被更新?


{btsdaf} - Sergey Vasilyev
2个回答

27

joblib默认使用processes的多进程池,正如它的手册所说:

在底层,Parallel对象创建一个多进程池,将Python解释器分叉为多个进程,以执行列表中的每个项。延迟函数是一种简单的技巧,可以使用函数调用语法创建一个元组(函数,args,kwargs)。

这意味着每个进程都继承了数组的原始状态,但它写入数组中的任何内容在进程退出时都会丢失。只有函数结果会返回到调用(主)进程。但是因为您没有返回任何内容,所以返回了None

要使共享数组可修改,有两种方法:使用线程和使用共享内存。


与进程不同,线程共享内存。因此,您可以向数组中写入数据,每个任务都会看到这种更改。根据joblib的手册,可以通过以下方式完成:

  Parallel(n_jobs=4, backend="threading")(delayed(core_func)(repeat_index, G, numpy_array) for repeat_index in range(nRepeat));

当您运行它时:

$ python r1.py 
2.0

然而,当您将复杂的内容写入数组时,请确保正确处理数据或数据块周围的锁定,否则可能会出现竞争条件(请谷歌它)。

同时请仔细阅读有关GIL的信息,因为Python中的计算多线程是受限制的(不像I/O多线程)。


如果仍然需要进程(例如由于GIL),可以将该数组放入共享内存中。

这是一个比较复杂的话题,但是joblib手册中也展示了joblib + numpy共享内存示例


2
正如Sergey在他的回答中所写的那样,进程不共享状态和内存。这就是为什么您没有看到预期的答案。
线程共享状态和内存空间,因为它们在同一个进程下运行。如果您有许多I/O操作,这将非常有用。由于GIL的存在,它不会为您提供更多的处理能力(更多的CPU)。
一种在进程之间通信的技术是使用Manager创建代理对象。您可以创建一个管理器对象,该对象在进程之间同步资源。
引用:
Manager()返回的管理器对象控制一个服务器进程,该进程持有Python对象,并允许其他进程使用代理对它们进行操作。
我没有测试过这段代码(我没有您使用的所有模块),并且可能需要对代码进行更多修改,但是使用Manager对象应该是这样的。
if __name__ == "__main__":
    G = nx.erdos_renyi_graph(100000,0.99);
    nRepeat = 5000;

    manager = multiprocessing.Manager()
    numpys = manager.list(np.zeros([nRepeat, G.number_of_nodes()])

    Parallel(n_jobs=4)(delayed(core_func)(repeat_index, G, numpys, que) for repeat_index in range(nRepeat));
    print(np.mean(numpys));

1
{btsdaf} - Sergey Vasilyev

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接