使用mpi4py并行化函数调用

5

我想使用mpi4py来并行化一个优化问题。为了最小化我的函数,我使用scipy中的minimize例程。

from scipy.optimize import minimize

def f(x, data) :
    #returns f(x)
x = minimize(f, x0, args=(data))

现在,如果我想使用mpi4py并行化我的函数。最小化算法的实现是顺序的,只能在一个进程上运行,因此只有我的函数是并行化的,这不是问题,因为函数调用是最耗时的步骤。但我无法想出如何实现这个问题,其中包含并行和顺序部分。

这是我的尝试:

from scipy.optimize import minimize
from mpi4py import MPI

comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()

N = 100 # for testing
step = N//size # say that N is divisible by size
def mpi_f(x, data) :
    x0 = x[rank*step:(rank+1)*step]
    res = f(x0, data)
    res = comm.gather(res, root=0)
    if rank == 0 :
        return res

if rank == 0 :
   x = np.zeros(N)
   xs = minimize(mpi_f, x, args=(data))

显然这种方法行不通,因为mpi_f只在进程0上运行。所以我想知道该怎么办?

谢谢。

1个回答

6
在您的代码中,根进程是唯一调用comm.gather()的进程,因为根进程是唯一调用并行化成本函数的进程。因此,程序面临死锁问题。您非常清楚这个问题。
为了克服这个死锁,其他进程必须像minimize需要的那样多次调用成本函数。由于不知道需要调用的次数,因此这些进程似乎适合使用while循环。
while循环的停止条件需要定义。该标志将从根进程广播到所有进程,因为只有根进程知道minimize()函数已结束。广播必须在成本函数中执行,因为所有进程都必须在每次迭代中测试最小化函数的结束。由于minimize使用函数的返回值,因此该标志通过可变类型的引用传递传递

最后,这是一个可能解决你问题的方案。它通过 mpirun -np 4 python main.py 运行。我使用了 fmin() 而不是 minimize(),因为我的 scipy 版本过旧。

#from scipy.optimize import minimize
from scipy.optimize import fmin
from mpi4py import MPI
import numpy as np

comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()

N = 100 # for testing
step = N//size # say that N is divisible by size

def parallel_function_caller(x,stopp):
    stopp[0]=comm.bcast(stopp[0], root=0)
    summ=0
    if stopp[0]==0:
        #your function here in parallel
        x=comm.bcast(x, root=0)
        array= np.arange(x[0]-N/2.+rank*step-42,x[0]-N/2.+(rank+1)*step-42,1.)
        summl=np.sum(np.square(array))
        summ=comm.reduce(summl,op=MPI.SUM, root=0)
        if rank==0:
            print "value is "+str(summ)
    return summ

if rank == 0 :
   stop=[0]
   x = np.zeros(1)
   x[0]=20
   #xs = minimize(parallel_function_caller, x, args=(stop))
   xs = fmin(parallel_function_caller,x0= x, args=(stop,))
   print "the argmin is "+str(xs)
   stop=[1]
   parallel_function_caller(x,stop)

else :
   stop=[0]
   x=np.zeros(1)
   while stop[0]==0:
      parallel_function_caller(x,stop)

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接