做类似以下的事情
import numpy as np
a = np.random.rand(10**4, 10**4)
b = np.dot(a, a)
使用多个核心,运行效果良好。
然而,a
中的元素是 64 位浮点数(或在 32 位平台上为 32 位浮点数?),我想要乘以 8 位整数数组。尝试以下代码:
a = np.random.randint(2, size=(n, n)).astype(np.int8)
这导致点积在我的电脑上不能使用多个内核,因此运行速度慢了约1000倍。
array: np.random.randint(2, size=shape).astype(dtype)
dtype shape %time (average)
float32 (2000, 2000) 62.5 ms
float32 (3000, 3000) 219 ms
float32 (4000, 4000) 328 ms
float32 (10000, 10000) 4.09 s
int8 (2000, 2000) 13 seconds
int8 (3000, 3000) 3min 26s
int8 (4000, 4000) 12min 20s
int8 (10000, 10000) It didn't finish in 6 hours
float16 (2000, 2000) 2min 25s
float16 (3000, 3000) Not tested
float16 (4000, 4000) Not tested
float16 (10000, 10000) Not tested
我了解NumPy使用的是BLAS库,它不支持整数,但如果我使用SciPy的BLAS包装器,即
import scipy.linalg.blas as blas
a = np.random.randint(2, size=(n, n)).astype(np.int8)
b = blas.sgemm(alpha=1.0, a=a, b=a)
计算是多线程的。现在,对于float32,blas.sgemm与np.dot完全具有相同的时间运行,但对于非floats,它将所有内容转换为float32并输出浮点数,而这是np.dot所不做的。(此外,b现在以F_CONTIGUOUS顺序排列,这是一个较小的问题)。
因此,如果我想进行整数矩阵乘法,我必须执行以下操作之一:
1.使用NumPy的极慢的np.dot,并且很高兴保留8位整数。 2.使用SciPy的sgemm并使用4倍内存。 3.使用Numpy的np.float16,并仅使用2倍内存,但要注意np.dot在float16数组上比在float32数组上慢得多,特别是int8。 4.查找优化的库以进行多线程整数矩阵乘法(实际上,Mathematica可以做到这一点,但我更喜欢Python解决方案),最好支持1位数组,尽管8位数组也可以...(我实际上旨在对Z / 2Z上的矩阵进行乘法,我知道我可以用Sage做到这一点,它非常像Python,但是,还有严格的Python吗?)
我能按照选项4吗?是否存在这样的库?
免责声明:我实际上正在运行NumPy + MKL,但我已经尝试过vanilly NumPy上的类似测试,并获得了类似的结果。
numpy.einsum
,但这可能是一个不错的选择。 - user2379410