CUDA如何获取网格、块、线程大小并并行化非正方形矩阵计算

23

我是CUDA的新手,需要帮助理解一些内容。我需要帮助将这两个for循环并行化。具体来说,如何设置dimBlock和dimGrid以使其运行更快。我知道这看起来像sdk中的向量加法示例,但该示例仅适用于方阵,并且当我尝试修改该代码以适应我的128 x 1024矩阵时,它无法正常工作。

__global__ void mAdd(float* A, float* B, float* C)
{
    for(int i = 0; i < 128; i++)
    {
        for(int j = 0; j < 1024; j++)
        {
            C[i * 1024 + j] = A[i * 1024 + j] + B[i * 1024 + j];
        }
    }
}

这段代码是较大循环中的一部分,它是最简单的部分,所以我决定尝试并行化,并同时学习CUDA。我已经阅读了指南,但仍不明白如何正确设置网格/块/线程的数量并有效地使用它们。


5
pycuda中,只需使用C[i] = A[i] + B[i]语句即可完成操作。具体见demo.py - jfs
1个回答

44
你所写的内核是完全串行的。每个启动执行的线程都将执行相同的工作。
CUDA(以及OpenCL和其他类似的“单程序,多数据”类型编程模型)背后的主要思想是,你会进行“数据并行”操作 - 所以需要执行相同的、大部分独立的操作多次 - 并编写一个可以执行该操作的内核。然后启动大量(半)自主线程在输入数据集上执行该操作。
在数组添加示例中,数据并行操作是:
C[k] = A[k] + B[k];

对于所有k在0到128 * 1024之间的情况。每个加法操作都是完全独立的,没有排序要求,因此可以由不同的线程执行。为了在CUDA中表示这个,可能会像这样编写内核:

__global__ void mAdd(float* A, float* B, float* C, int n)
{
    int k = threadIdx.x + blockIdx.x * blockDim.x;

    if (k < n)
        C[k] = A[k] + B[k];
}

[免责声明:代码是在浏览器中编写的,未经测试,使用风险自负]

这里,串行代码中的内部和外部循环被一个CUDA线程代替进行每个操作,并且我在代码中添加了一个限制检查,以便在启动的线程数超过所需操作数的情况下,不会发生缓冲区溢出。如果像这样启动内核:

const int n = 128 * 1024;
int blocksize = 512; // value usually chosen by tuning and hardware constraints
int nblocks = n / blocksize; // value determine by block size and total work

madd<<<nblocks,blocksize>>>mAdd(A,B,C,n);

然后,将启动256个块,每个块包含512个线程,以并行执行数组加法操作。请注意,如果输入数据大小不能表示为块大小的整数倍,则块数需要上取整以覆盖完整的输入数据集。

上述所有内容都是对CUDA范式进行了极度简化的概述,只是一个非常微不足道的操作,但也许足够让您继续学习。如今,CUDA已经相当成熟,并且有很多好的、免费的教育材料在网上流传,您可以使用它们来进一步阐明我在此答案中忽略的编程模型的许多方面。


1
int k = threadIdx.x + gridDim.x * blockDim.x; 这肯定是不正确的吧?在你的例子中,gridDim.x * blockDim.x 总是等于 256*512。应该是 int k = threadIdx.x + blockIdx.x * blockDim.x; 我试图编辑它,但被拒绝了。 - Ozone
1
警告给浏览速读者:nblocks = ceil(n / nthreads); // 如果你的数据不能完美地被分割。 - ofer.sheffer
@ofer.sheffer:我确实写了“请注意,如果输入数据大小不能表示为块大小的整数倍,则块数需要向上舍入以覆盖完整的输入数据集。”这不够清楚吗? - talonmies
1
@talonmies,您的答案非常好,我为它投了赞成票。另一方面,当我阅读它时,我在想“他忘了+1”,以防数据不能均匀划分...然后我继续阅读其他一些东西,回到这里完成阅读时,我注意到您已经写上了。作为一个通常只看代码并考虑稍后阅读每个字的粗略阅读者 - 我觉得我的警告会帮助我的未来自己。 - ofer.sheffer
我怎么知道 nthreads?不是 blocksize 就是线程数吗? - smcs

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接