cuBLAS中的矩阵乘法转置教程

4
问题很简单:我有两个矩阵A和B,都是M乘N的,其中M远大于N。我想先对A进行转置,然后将其乘以B(A^T * B)得到C,C是N乘N的。我已经为A和B设置好了一切,但如何正确调用cublasSgemm而不返回错误答案呢?
我知道cuBlas有一个cublasOperation_t枚举类型,用于事先转置,但不知怎么使用它。我的矩阵A和B按行主序存储在设备内存中,即[ row1 ][ row2 ][ row3 ]...。这意味着要将A解释为A-transposed,BLAS需要知道我的A是按列主序排列的。我的当前代码如下:
float *A, *B, *C;
// initialize A, B, C as device arrays, fill them with values
// initialize m = num_row_A, n = num_row_B, and k = num_col_A;
// set lda = m, ldb = k, ldc = m;
// alpha = 1, beta = 0;
// set up cuBlas handle ...

cublasSgemm(handle, CUBLAS_OP_T, CUBLAS_OP_N, m, n, k, &alpha, A, lda, B, ldb, &beta, C, ldc);

我的问题:

我是否正确设置了m、k、n?

lda、ldb、ldc怎么样?

谢谢!


你是在询问如何计算 (A^TB)(A^TB) 吗? - talonmies
1个回答

13

由于cuBLAS始终假设矩阵以列主序存储,因此您可以通过使用cublas_geam()将矩阵先转置为列主序,或者将存储在行主序中的矩阵A视为存储在列主序中的新矩阵AT。矩阵AT实际上是A的转置。对于B,也要进行相同的操作。然后,您可以通过C=AT * BT^T计算以列主序存储的矩阵C。

float* AT = A;
float* BT = B;

领先维度是与存储相关的一个参数,不管您是否使用转置标志CUBLAS_OP_T,它都不会改变。

lda = num_col_A = num_row_AT = N;
ldb = num_col_B = num_row_BT = N;
ldc = num_row_C = N;

cuBLAS中的mn是GEMM例程中结果矩阵C的行数和列数。

m = num_row_C = num_row_AT = num_col_A = N;
n = num_col_C = num_row_BT = num_col_B = N;

k是矩阵A^T和B的公共维度,

k = num_col_AT = num_row_B = M;

然后,您可以通过以下方式调用GEMM例程:

cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_T, m, n, k, &alpha, AT, lda, BT, ldb, &beta, C, ldc);

如果您想要将矩阵 C 存储在行主序中,可以使用公式 CT = BT * AT^T 计算以列主序存储的 CT。
cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_T, n, m, k, &alpha, BT, ldb, AT, lda, &beta, CT, ldc);

请注意,在这种情况下,由于C是一个方阵,因此您不必交换mn

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接