无法理解MPI中的MPI_Reduce_scatter函数

14
我正在尝试理解MPI_Reduce_scatter函数,但似乎我的推断总是错误的:(
文档说明如下(link):

MPI_Reduce_scatter首先对由sendbuf、count和datatype定义的发送缓冲区中的元素数量为S(i)recvcounts[i]的向量进行逐元素约简。接下来,结果向量被分成n个不相交的段,其中n是组中进程的数量。第i个段包含recvcounts[i]个元素。第i个段被发送到进程i并存储在由recvbuf、recvcounts[i]和datatype定义的接收缓冲区中。

我有以下(非常简单的)C程序,我希望得到前recvcounts[i]个元素的最大值,但似乎我做错了什么...

#include <stdio.h>
#include <stdlib.h>
#include "mpi.h"

#define NUM_PE 5
#define NUM_ELEM 3

char *print(int arr[], int n);

int main(int argc, char *argv[]) {
    int rank, size, i, n;
    int sendbuf[5][3] = {
        {  1,  2,  3 },
        {  4,  5,  6 },
        {  7,  8,  9 },
        { 10, 11, 12 },
        { 13, 14, 15 }
    };
    int recvbuf[15] = {0};
    int recvcounts[5] = {
        3, 3, 3, 3, 3
    };

    MPI_Init(&argc, &argv);
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    MPI_Comm_size(MPI_COMM_WORLD, &size);

    n = sizeof(sendbuf[rank]) / sizeof(int);
    printf("sendbuf (thread %d): %s\n", rank, print(sendbuf[rank], n));

    MPI_Reduce_scatter(sendbuf, recvbuf, recvcounts, MPI_INT, MPI_MAX, MPI_COMM_WORLD);

    n = sizeof(recvbuf) / sizeof(int);
    printf("recvbuf (thread %d): %s\n", rank, print(recvbuf, n)); // <--- I receive the same output as with sendbuf :(

    MPI_Finalize();

    return 0;
}

char *print(int arr[], int n) { } // it returns a string formatted as the following output

我的程序的输出对于recvbuf和sendbuf是相同的。我预期recvbuf应该包含最大值:
$ mpicc 03_reduce_scatter.c
$ mpirun -n 5 ./a.out
sendbuf (thread 4): [ 13, 14, 15 ]
sendbuf (thread 3): [ 10, 11, 12 ]
sendbuf (thread 2): [  7,  8,  9 ]
sendbuf (thread 0): [  1,  2,  3 ]
sendbuf (thread 1): [  4,  5,  6 ]
recvbuf (thread 1): [  4,  5,  6,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0 ]
recvbuf (thread 2): [  7,  8,  9,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0 ]
recvbuf (thread 0): [  1,  2,  3,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0 ]
recvbuf (thread 3): [ 10, 11, 12,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0 ]
recvbuf (thread 4): [ 13, 14, 15,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0 ]
1个回答

24

是的,Reduce_scatter的文档很简洁,并且它并不是被广泛使用的,因此没有太多的示例。来自OCW MIT讲座的前几张幻灯片有一个好的图表,并提出了一个用例。

关键是像往常一样阅读MPI文档并特别注意给实现者的建议:

"MPI_REDUCE_SCATTER例程在功能上等同于:具有接收计数之和recvcounts[i]的MPI_REDUCE集合操作,后跟具有sendcounts等于recvcounts的MPI_SCATTERV。"

因此,让我们通过您的示例进行演示:

MPI_Reduce_scatter(sendbuf, recvbuf, recvcounts, MPI_INT, MPI_MAX, MPI_COMM_WORLD);

这将相当于这个:

int totcounts = 15;  // = sum of {3, 3, 3, 3, 3}
MPI_Reduce({1,2,3...15}, tmpbuffer, totcounts, MPI_INT, MPI_MAX, 0, MPI_COMM_WORLD);
MPI_Scatterv(tmpbuffer, recvcounts, [displacements corresponding to recvcounts], 
              MPI_INT, rcvbuffer, 3, MPI_INT, 0, MPI_COMM_WORLD);

于是每个人都会提交相同的数字{1...15},然后这些数字的每一列将相互取最大值,得到{max(1,1...1), max(2,2...2) ... max(15,15...15)} = {1,2,...15}。

然后这些数字将被分配给处理器,每次3个,得到{1,2,3},{4,5,6},{7,8,9}...

这就是发生的事情,那么我们如何让你想要的事情发生呢?我明白你想要对每行进行最大值运算,并且每个处理器获取其对应的行最大值。例如,假设数据看起来像这样:

Proc 0: 1 5 9 13
Proc 1: 2 6 10 14
Proc 2: 3 7 11 15
Proc 3: 4 8 12 16

我们希望最终处理器0(假设)拥有所有第0个数据块中最大的值,处理器1拥有所有第1个数据块中最大的值,以此类推,这样我们最终会得到:

Proc 0: 4
Proc 1: 8
Proc 2: 12
Proc 3: 16

那么让我们看看如何实现。首先,每个人都将拥有一个值,因此所有的接收计数都为1。其次,每个进程将必须发送不同的数据。因此,我们将拥有类似以下的内容:

#include <stdio.h>
#include <stdlib.h>
#include "mpi.h"

int main(int argc, char *argv[]) {
    int rank, size, i, n;

    MPI_Init(&argc, &argv);
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    MPI_Comm_size(MPI_COMM_WORLD, &size);

    int sendbuf[size];
    int recvbuf;

    for (int i=0; i<size; i++)
        sendbuf[i] = 1 + rank + size*i;

    printf("Proc %d: ", rank);
    for (int i=0; i<size; i++) printf("%d ", sendbuf[i]);
    printf("\n");

    int recvcounts[size];
    for (int i=0; i<size; i++)
        recvcounts[i] = 1;

    MPI_Reduce_scatter(sendbuf, &recvbuf, recvcounts, MPI_INT, MPI_MAX, MPI_COMM_WORLD);

    printf("Proc %d: %d\n", rank, recvbuf);

    MPI_Finalize();

    return 0;
}

运行程序的输出结果为(已重新排列以提高清晰度):

Proc 0: 1 5 9 13 
Proc 1: 2 6 10 14 
Proc 2: 3 7 11 15
Proc 3: 4 8 12 16

Proc 0: 4
Proc 1: 8
Proc 2: 12
Proc 3: 16

4
哇!很少见到对如此具体的问题给出如此详尽和信息丰富的答案。非常感谢,也感谢您花时间写下这个详细的解释! - StockBreak

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接