这不是一个改进的工作答案,而是对于稀疏索引和 "triu" 的探索。它可能会给你一些更直接计算的想法。你从 tri 开始,并期望得到 tri,这意味着这不是一个简单的任务,甚至使用密集数组(其索引速度要快得多)也不是。
"sparse.csr" 索引使用矩阵乘法。我将用密集数组来说明这一点:
In [304]: X = np.array([
...: [1, 2, 3, 3, 1],
...: [0, 1, 3, 3, 2],
...: [0, 0, 1, 1, 3],
...: [0, 0, 0, 1, 3],
...: [0, 0, 0, 0, 1],
...: ])
In [305]: idx = np.array([1,2,4,2])
In [306]: X[idx[:,None],idx]
Out[306]:
array([[1, 3, 2, 3],
[0, 1, 3, 1],
[0, 0, 1, 0],
[0, 1, 3, 1]])
In [307]: m = np.array([[0,1,0,0,0],[0,0,1,0,0],[0,0,0,0,1],[0,0,1,0,0]])
In [308]: m@X@m.T
Out[308]:
array([[1, 3, 2, 3],
[0, 1, 3, 1],
[0, 0, 1, 0],
[0, 1, 3, 1]])
并且使用完整的距离数组:
In [309]: X2 = X+X.T-np.diag(np.diag(X))
In [311]: X2[idx[:,None],idx]
Out[311]:
array([[1, 3, 2, 3],
[3, 1, 3, 1],
[2, 3, 1, 3],
[3, 1, 3, 1]])
In [312]: m@X2@m.T
Out[312]:
array([[1, 3, 2, 3],
[3, 1, 3, 1],
[2, 3, 1, 3],
[3, 1, 3, 1]])
我不知道是否可能从
X
(或
X2
)直接构造出提供所需结果的
m
,无论是上三角还是其他形式。
In [316]: sparse.triu(Out[312])
Out[316]:
<4x4 sparse matrix of type '<class 'numpy.int64'>'
with 10 stored elements in COOrdinate format>
In [317]: _.A
Out[317]:
array([[1, 3, 2, 3],
[0, 1, 3, 1],
[0, 0, 1, 3],
[0, 0, 0, 1]])
的作用是:
In [331]: A = sparse.coo_matrix(_312)
...: mask = A.row <= A.col
In [332]: A
Out[332]:
<4x4 sparse matrix of type '<class 'numpy.int64'>'
with 16 stored elements in COOrdinate format>
In [333]: mask
Out[333]:
array([ True, True, True, True, False, True, True, True, False,
False, True, True, False, False, False, True])
这个
mask
数组有16项,
A.nnz
。
然后它从
A
的属性中选择数据/行/列数组,并生成一个新的
coo
矩阵:
In [334]: d=A.data[mask]
In [335]: r=A.row[mask]
In [336]: c=A.col[mask]
In [337]: d
Out[337]: array([1, 3, 2, 3, 1, 3, 1, 1, 3, 1])
In [338]: sparse.coo_matrix((d, (r,c)))
Out[338]:
<4x4 sparse matrix of type '<class 'numpy.int64'>'
with 10 stored elements in COOrdinate format>
In [339]: _.A
Out[339]:
array([[1, 3, 2, 3],
[0, 1, 3, 1],
[0, 0, 1, 3],
[0, 0, 0, 1]])
"
np.triu
使用类似于mask
的方式:
"
In [349]: np.tri(4,4,-1)
Out[349]:
array([[0., 0., 0., 0.],
[1., 0., 0., 0.],
[1., 1., 0., 0.],
[1., 1., 1., 0.]])
triu
来节省内存是个好主意。但是我开始有这样的印象,即这并不值得。创建triu
似乎会消耗大量内存。它是否实现了布尔掩码? - Gregor SturmX.nnz / X.shape[0]**2
= 8.28e-05,翻译为中文是:远小于1%。 - Gregor Sturm