numpy数组和scipy稀疏矩阵的张量点积

6

针对一个当前项目,我需要计算一批与同一个相当稀疏的矩阵相关联的向量的内积。这些向量是与二维网格相关联的,所以我将这些向量存储在三维数组中:

例如:

X是一个维度为(I,J,N)的数组。矩阵A的尺寸是(N,N)。任务是为每个在I,J中计算A.dot(X[i,j])

对于numpy数组,可以轻松完成此操作:

Y = X.dot(A.T) 

现在我想将A存储为稀疏矩阵,因为它是稀疏的,并且只包含非常有限数量的非零条目,这导致了许多不必要的乘法。不幸的是,上面的解决方案行不通,因为numpy点积不能与稀疏矩阵一起使用。据我所知,scipy稀疏矩阵没有类似于tensordot的操作。
有人知道用稀疏矩阵A计算上述数组Y的一个好的有效方法吗?
2个回答

3
显而易见的方法是在向量上运行循环,并使用稀疏矩阵的.dot方法:
def naive_sps_x_dense_vecs(sps_mat, dense_vecs):
    rows, cols = sps_mat.shape
    I, J, _ = dense_vecs.shape
    out = np.empty((I, J, rows))
    for i in xrange(I):
        for j in xrange(J):
            out[i, j] = sps_mat.dot(dense_vecs[i, j])
    return out

但是您可以通过将三维数组重塑为二维数组并避免使用Python循环来加快速度:

def sps_x_dense_vecs(sps_mat, dense_vecs):
    rows, cols = sps_mat.shape
    vecs_shape = dense_vecs.shape
    dense_vecs = dense_vecs.reshape(-1, cols)
    out = sps_mat.dot(dense_vecs.T).T
    return out.reshape(vecs.shape[:-1] + (rows,))

问题在于我们需要将稀疏矩阵作为第一个参数,以便我们可以调用其.dot方法,这意味着返回的结果是转置的,这又意味着在转置后,最后的reshape操作将触发整个数组的复制。因此,对于较大的IJ值,结合不太大的N值,后一种方法将比前一种方法快几倍,但对于其他参数组合,性能甚至可能相反:
n, i, j = 100, 500, 500
a = sps.rand(n, n, density=1/n, format='csc')
vecs = np.random.rand(i, j, n)

>>> np.allclose(naive_sps_x_dense_vecs(a, vecs), sps_x_dense_vecs(a, vecs))
True

n, i, j = 100, 500, 500
%timeit naive_sps_x_dense_vecs(a, vecs)
1 loops, best of 3: 3.85 s per loop
%timeit sps_x_dense_vecs(a, vecs)
1 loops, best of 3: 576 ms per 

n, i, j = 1000, 200, 200
%timeit naive_sps_x_dense_vecs(a, vecs)
1 loops, best of 3: 791 ms per loop
%timeit sps_x_dense_vecs(a, vecs)
1 loops, best of 3: 1.3 s per loop

0
你可以使用 jax 来实现你想要的功能。假设你的稀疏矩阵是以 csr_array 格式存储的,你需要先将其转换为一个 jax BCOO array
from scipy import sparse
from jax.experimental import sparse as jaxsparse
import jax.numpy as jnp

def convert_to_BCOO(x):
    x = x.transpose()  #get the transpose
    x = x.tocoo()
    x = jaxsparse.BCOO((x.data, jnp.column_stack((x.row, x.col))),
                       shape=x.shape)
    x = L.sort_indices()

您可以使用 jax.sparsify 创建如下所示的稀疏点积。
def dot(x, y):
    return jnp.dot(x, y)
sp_dot = jaxsparse.sparsify(dot)

A_transpose = convert_to_BCOO(A)
Y = sp_dot(X,A_transpose)

函数sp_dot现在遵循与numpy.dot完全相同的规则。

希望这可以帮到你!


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接