BLAS sgemm/dgemm是如何工作的?

3

我正在尝试使用Python中的ctypes调用BLAS库中的sgemm函数。下面这段代码可以成功地解决C = A x B

no_trans = c_char("n")
m = c_int(number_of_rows_of_A)
n = c_int(number_of_columns_of_B)
k = c_int(number_of_columns_of_A)
one = c_float(1.0)
zero = c_float(0.0)

blaslib.sgemm_(byref(no_trans), byref(no_trans), byref(m), byref(n), byref(k),
               byref(one), A, byref(m), B, byref(k), byref(zero), C, byref(m))

现在我想要解决这个方程:C = A' x A,其中A'A的转置。以下代码能够正常运行但返回的结果是错误的:

trans = c_char("t")
no_trans = c_char("n")
m = c_int(number_of_rows_of_A)
n = c_int(number_of_columns_of_A)
one = c_float(1.0)
zero = c_float(0.0)

blaslib.sgemm_(byref(trans), byref(no_trans), byref(n), byref(n), byref(m),
               byref(one), A, byref(m), A, byref(m), byref(zero), C, byref(n))

为了进行一项测试,我插入了一个矩阵 A = [1 2; 3 4]。正确的结果是 C = [10 14; 14 20],但是 sgemm 程序输出的是 C = [5 11; 11 25]
据我所理解,由于算法会处理,所以我不需要对矩阵 A 进行转置。在第二种情况下,我的参数传递有什么问题?
非常感谢任何帮助、链接、文章或建议!
2个回答

7

Blas通常使用列主序矩阵(类似于Fortran),因此A = [1 2; 3 4]表示

    |1 3|   
A = |   |
    |2 4|

假设你的Python库也是这样做的,那么结果是正确的。请参阅此自述文件


1
你得到的结果表明sgemm计算的是A*A'而不是你想要的A'*A。简单的解决方法是交换函数的两个输入。

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接