什么是切割scipy.sparse矩阵的最快方法?

25

我通常使用

matrix[:, i:]

看起来速度没有我预期的快。


能为我们进行一下基准测试吗?甚至可以使用timeit来测试。 - IT Ninja
2个回答

22
为了得到一个稀疏矩阵作为输出,最快的行切片方法是使用csr类型,而对于列切片则是csc,详见这里。在这两种情况下,你只需要做你目前正在做的事情。
matrix[l1:l2, c1:c2]

如果你想要一个ndarray作为输出,直接在稀疏矩阵对象中进行切片可能会更快。你可以使用.A属性或.toarray()方法从稀疏矩阵中获取ndarray对象:

matrix.A[l1:l2, c1:c2] 
或者:
matrix.toarray()[l1:l2, c1:c2]

正如下面评论所提到的,如果数组足够大,将稀疏数组转换为密集数组可能会导致内存错误。


19
如果矩阵太大,使用matrix.A会导致内存错误。 - Munichong

14

我发现scipy.sparse.csr_matrix宣传的快速行索引可以通过自己编写行索引器来更快地实现。这是想法:

class SparseRowIndexer:
    def __init__(self, csr_matrix):
        data = []
        indices = []
        indptr = []

        # Iterating over the rows this way is significantly more efficient
        # than csr_matrix[row_index,:] and csr_matrix.getrow(row_index)
        for row_start, row_end in zip(csr_matrix.indptr[:-1], csr_matrix.indptr[1:]):
             data.append(csr_matrix.data[row_start:row_end])
             indices.append(csr_matrix.indices[row_start:row_end])
             indptr.append(row_end-row_start) # nnz of the row

        self.data = np.array(data)
        self.indices = np.array(indices)
        self.indptr = np.array(indptr)
        self.n_columns = csr_matrix.shape[1]

    def __getitem__(self, row_selector):
        data = np.concatenate(self.data[row_selector])
        indices = np.concatenate(self.indices[row_selector])
        indptr = np.append(0, np.cumsum(self.indptr[row_selector]))

        shape = [indptr.shape[0]-1, self.n_columns]

        return sparse.csr_matrix((data, indices, indptr), shape=shape)

换句话说,可以通过将每行的非零值分别存储在单独的数组中(每行长度不同),然后将所有这些行数组放入一个对象类型的数组中(允许每行具有不同大小)以便高效地进行索引,从而利用numpy数组的快速索引。列索引使用相同的方式进行存储。这种方法与标准的CSR数据结构略有不同,后者将所有非零值都存储在单个数组中,需要查找每行的起始位置和结束位置。这些查找可能会减慢随机访问速度,但对于连续行的检索应该是有效的。

性能分析结果

我的矩阵mat是一个1,900,000x1,250,000的csr_matrix,其中包含400,000,000个非零元素。ilocs是200,000个随机行索引的数组。

>>> %timeit mat[ilocs]
2.66 s ± 233 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

相比之下:

>>> row_indexer = SparseRowIndexer(mat)
>>> %timeit row_indexer[ilocs]
59.9 ms ± 4.51 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
SparseRowIndexer在使用高级索引时似乎比使用布尔掩码更快。

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接