使用Numba计算内积的正确方法

3

我正在尝试计算两个大矩阵的内积。似乎当尝试计算点积时,numpy会创建矩阵的副本,这导致了一些内存问题。在Google搜索后,我发现numba包很有前途。然而,我无法使其正常工作。以下是我的代码:

import numpy as np
from numba import jit
import time, contextlib



@contextlib.contextmanager
def timeit():
    t=time.time()
    yield
    print(time.time()-t,"sec")


def dot1(a,b):
    return np.dot(a,b)

@jit(nopython=True)
def dot2(a,b):
    n = a.shape[0]
    m = b.shape[1]
    K = b.shape[0]
    c = np.zeros((n,m))
    for i in xrange(n):
        for j in xrange(m):
            for k in range(K):
                c[i,j] += a[i,k]*b[k,j]

    return c



def main():
    a = np.random.random((200,1000))
    b = np.random.random((1000,400))

    with timeit():
        c1 = dot1(a,b)
    with timeit():
        c2 = dot2(a,b)

具有以下运行时间:

dot1:
(0.034691810607910156, 'sec')

dot2:
(0.9215810298919678, 'sec')

有人能告诉我这里缺少什么吗?


很可能矩阵乘法已经在底层(可能使用CPU缓存/向量优化)中得到了很好的优化(在C甚至FORTRAN中),因此即时编译无法超越它。 - user707650
我查看了内存需求,但在执行np.dot时,我没有看到任何矩阵重复的证据。可能会有一些开销(我还没有深入研究),但肯定不需要两倍的内存。 - user707650
@Evert 请查看 http://wiki.scipy.org/PerformanceTips 的“大数组上的线性代数”部分。其中提到:虽然 C 只有40x40,但在执行点积操作期间检查内存使用情况将指出正在进行复制。原因是点积运算使用基础的BLAS操作,这些操作依赖于以连续的C顺序存储矩阵。 - Moj
我不了解numba。至于内存问题,请查看[https://github.com/numpy/numpy/issues/4062],看看是否相关。如果您正在对矩阵的转置进行“点”乘运算(我可以确认这个问题),或者使用numpy < 1.8,那么有可能会遇到内存问题。 - user707650
@Evert 我不是在对转置矩阵进行点乘,而是在对一个形状为转置的矩阵进行点乘。我的意思是 n by mm by n 进行点乘,其中 m 大约为 80000。帖子中使用 np.rollaxis 的解决方案似乎很有趣,但我不知道它是如何工作的,也不知道该如何将其应用到我的情况中。对此有什么建议吗? - Moj
创建一个关于内存问题的新问题,并提供一个大家可以验证的示例(即,创建一个自包含的示例,使用约0.5到2 GB的内存,这是大多数人都可以处理的)。; 这个问题是关于numba以及(缺乏)加速的。 - user707650
1个回答

1
你的算法是朴素算法。BLAS实现了更快的算法。
引用维基百科的矩阵乘法页面:
尽管如此,它出现在几个库中,例如BLAS,在矩阵维度大于100时效率显着提高。

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接