从Scipy稀疏矩阵中获取唯一行

5

我正在使用Python处理稀疏矩阵,想知道有没有一种有效的方法来删除稀疏矩阵中重复的行,只保留唯一的行。

我没有找到相应的函数,也不确定如何在不将稀疏矩阵转换为密集矩阵并使用numpy.unique的情况下完成该操作。


scipy中没有相关的内容。使用带有新的axis参数的np.unique可能是最好的选择。如果你必须使用sparse,我建议看一下lil格式以及它的'raw'行和数据属性。 - hpaulj
2个回答

7

没有快速的方法可以做到这一点,所以我不得不编写一个函数。它返回一个稀疏矩阵,其中包含输入稀疏矩阵的唯一行(axis=0)或列(axis=1)。 请注意,返回矩阵的唯一行或列未按字典顺序排序(与np.unique不同)。

import numpy as np
import scipy.sparse as sp

def sp_unique(sp_matrix, axis=0):
    ''' Returns a sparse matrix with the unique rows (axis=0)
    or columns (axis=1) of an input sparse matrix sp_matrix'''
    if axis == 1:
        sp_matrix = sp_matrix.T

    old_format = sp_matrix.getformat()
    dt = np.dtype(sp_matrix)
    ncols = sp_matrix.shape[1]

    if old_format != 'lil':
        sp_matrix = sp_matrix.tolil()

    _, ind = np.unique(sp_matrix.data + sp_matrix.rows, return_index=True)
    rows = sp_matrix.rows[ind]
    data = sp_matrix.data[ind]
    nrows_uniq = data.shape[0]

    sp_matrix = sp.lil_matrix((nrows_uniq, ncols), dtype=dt)  #  or sp_matrix.resize(nrows_uniq, ncols)
    sp_matrix.data = data
    sp_matrix.rows = rows

    ret = sp_matrix.asformat(old_format)
    if axis == 1:
        ret = ret.T        
    return ret


def lexsort_row(A):
    ''' numpy lexsort of the rows, not used in sp_unique'''
    return A[np.lexsort(A.T[::-1])]

if __name__ == '__main__':    
    # Test
    # Create a large sparse matrix with elements in [0, 10]
    A = 10*sp.random(10000, 3, 0.5, format='csr')
    A = np.ceil(A).astype(int)

    # unique rows
    A_uniq = sp_unique(A, axis=0).toarray()
    A_uniq = lexsort_row(A_uniq)
    A_uniq_numpy = np.unique(A.toarray(), axis=0)
    assert (A_uniq == A_uniq_numpy).all()

    # unique columns
    A_uniq = sp_unique(A, axis=1).toarray()
    A_uniq = lexsort_row(A_uniq.T).T
    A_uniq_numpy = np.unique(A.toarray(), axis=1)
    assert (A_uniq == A_uniq_numpy).all()  

优秀的回答,代码质量很高。谢谢! - Yohan Obadia

1

一个也可以使用切片。

def remove_duplicate_rows(data):
    unique_row_indices, unique_columns = [], []
    for row_idx, row in enumerate(data):
        indices = row.indices.tolist()
        if indices not in unique_columns:
            unique_columns.append(indices)
            unique_row_indices.append(row_idx)
    return data[unique_row_indices]

我发现这对于我在监督式机器学习环境中的应用特别有帮助。在那里,我的函数输入是数据和标签。通过这种方法,我可以轻松返回结果。
labels[unique_row_indices]

此外,确保在此清理后,数据和标签是相符的。


请注意,此解决方案不适用于lil_matrix格式;对我而言,csr_matrix可行。 - dafinguzman
此外,它不仅会删除重复的行,还会删除所有具有重复列索引的行。例如,它将删除 scipy.sparse.csr_matrix(np.array([[1, 1], [2, 2]])) 的第二行。 - dafinguzman

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接