我正在尝试计算两个大矩阵的内积。似乎当尝试计算点积时,numpy
会创建矩阵的副本,这导致了一些内存问题。在Google搜索后,我发现numba
包很有前途。然而,我无法使其正常工作。以下是我的代码:
import numpy as np
from numba import jit
import time, contextlib
@contextlib.contextmanager
def timeit():
t=time.time()
yield
print(time.time()-t,"sec")
def dot1(a,b):
return np.dot(a,b)
@jit(nopython=True)
def dot2(a,b):
n = a.shape[0]
m = b.shape[1]
K = b.shape[0]
c = np.zeros((n,m))
for i in xrange(n):
for j in xrange(m):
for k in range(K):
c[i,j] += a[i,k]*b[k,j]
return c
def main():
a = np.random.random((200,1000))
b = np.random.random((1000,400))
with timeit():
c1 = dot1(a,b)
with timeit():
c2 = dot2(a,b)
具有以下运行时间:
dot1:
(0.034691810607910156, 'sec')
dot2:
(0.9215810298919678, 'sec')
有人能告诉我这里缺少什么吗?
np.dot
时,我没有看到任何矩阵重复的证据。可能会有一些开销(我还没有深入研究),但肯定不需要两倍的内存。 - user707650n by m
与m by n
进行点乘,其中m
大约为80000
。帖子中使用np.rollaxis
的解决方案似乎很有趣,但我不知道它是如何工作的,也不知道该如何将其应用到我的情况中。对此有什么建议吗? - Moj