Numba矩阵向量乘法

6

我正在尝试使用numbapro编写一个简单的矩阵向量乘法:

from numbapro import cuda
from numba import *
import numpy as np
import math
from timeit import default_timer as time

n = 100

@cuda.jit('void(float32[:,:], float32[:], float32[:])')
def cu_matrix_vector(A, b, c):
    y, x = cuda.grid(2)
    if y < n:
        c[y] = 0.0

    if x < n and y < n:
        for i in range(n):
            c[y] += A[y, i] * b[i]


A = np.array(np.random.random((n, n)), dtype=np.float32)
B = np.array(np.random.random((n, 1)), dtype=np.float32)
C = np.empty_like(B)

s = time()
dA = cuda.to_device(A)
dB = cuda.to_device(B)
dC = cuda.to_device(C)
cu_matrix_vector(dA, dB, dC)
dC.to_host()

e = time()
tcuda = e - s

但是我遇到了以下错误:
numbapro.cudadrv.error.CudaDriverError: CUDA_ERROR_LAUNCH_FAILED 复制内存D->H失败
我不明白为什么设备到主机的复制会失败。请帮忙解决。
1个回答

6

您的代码存在多个问题。

  1. B和C向量是Nx1的2D矩阵,而不是1D向量,但是您的内核类型签名将它们列为“float32[:]”-- 1D向量。它还使用单个索引对它们进行索引,这会导致GPU运行时错误,因为访问不对齐(在此处使用cuda-memcheck!)
  2. 您的内核假设有一个2D网格,但只使用其中的1列--这意味着许多线程执行相同的计算并互相覆盖。
  3. 没有给出执行配置,因此NumbaPro正在启动1个线程块的内核。(在此处使用nvprof!)

这是一段可行的代码。请注意,它使用1D块的1D网格,并循环遍历矩阵的列。因此,它针对向量/矩阵中的行数较大的情况进行了优化。对于短而宽的矩阵进行优化的内核需要使用另一种方法(并行归约)。但是我建议使用CUBLAS sgemv(也在NumbaPro中公开)。

from numbapro import cuda
from numba import *
import numpy as np
import math
from timeit import default_timer as time

m = 100000 
n = 100

@cuda.jit('void(f4[:,:], f4[:], f4[:])')
def cu_matrix_vector(A, b, c):
    row = cuda.grid(1)
    if (row < m):
        sum = 0

        for i in range(n):
            sum += A[row, i] * b[i]

        c[row] = sum

A = np.array(np.random.random((m, n)), dtype=np.float32)
B = np.array(np.random.random(m), dtype=np.float32)
C = np.empty_like(B)

s = time()
dA = cuda.to_device(A)
dB = cuda.to_device(B)
dC = cuda.to_device(C)

cu_matrix_vector[(m+511)/512, 512](dA, dB, dC)

dC.to_host()

print C

e = time()
tcuda = e - s

即使进行了编辑,我仍然得到相同的错误。具体错误如下:Traceback (most recent call last): File "<stdin>", line 1, in <module> File "matrixVector.py", line 34, in <module> dC.to_host() 复制内存D->H失败。 - kirikoumath
你说得对。实际上还有更多的问题 - 我重写了我的答案。 - harrism

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接