识别稀疏矩阵中有值的行(或列)

4
我需要确定在一个大的稀疏布尔矩阵中有哪些行(/列)具有定义值。我想使用这个来1.通过这些行/列切片(实际上是view)矩阵;和2.切片(/view)与矩阵相同维度的向量和矩阵的边缘。即结果可能应该是索引/布尔值的向量或(最好是)迭代器。

我已经尝试了显而易见的方法:

a = sprand(10000, 10000, 0.01)
cols = unique(a.colptr)
rows = unique(a.rowvals)

但是在我的机器上,每个这样的操作都需要大约20毫秒,可能是因为它们分配了大约1MB的内存(至少分配了colsrows)。这是一个性能关键函数,所以我希望代码能够被优化。基础代码似乎有一个用于稀疏矩阵的nzrange迭代器,但是我不容易看出如何将其应用到我的情况中。
是否有建议的方法来解决这个问题?
第二个问题:我还需要对我的稀疏矩阵视图执行此操作 - 是否类似于x = view(a,:,:); cols = unique(x.parent.colptr[x.indices[:,2]])或者是否有专门针对此类操作的功能?稀疏矩阵的视图似乎很棘手(参见https://discourse.julialang.org/t/slow-arithmetic-on-views-of-sparse-matrices/3644 - 不是一个跨帖子)。
非常感谢!

cols 实现存在一个 bug,例如 unique(speye(2).colptr) 得到的是一个 3 元素向量。更多信息请参见完整答案。 - Dan Getz
2个回答

7

关于获取稀疏矩阵的非零行和列,以下函数应该是相当高效的:

nzcols(a::SparseMatrixCSC) = collect(i 
  for i in 1:a.n if a.colptr[i]<a.colptr[i+1])

function nzrows(a::SparseMatrixCSC)
    active = falses(a.m)
    for r in a.rowval
        active[r] = true
    end
    return find(active)
end

对于一个10_000x10_000的矩阵,密度为0.1,列和行分别需要0.2毫秒和2.9毫秒。除了正确性问题外,它应该比所提出的方法更快。
关于稀疏矩阵的视图,一个快速的解决方案是将视图转换为稀疏矩阵(例如使用b = sparse(view(a,100:199,100:199))),然后使用上述函数。在代码中:
nzcols(b::SubArray{T,2,P}) where {T,P<:AbstractSparseArray} = nzcols(sparse(b))
nzrows(b::SubArray{T,2,P}) where {T,P<:AbstractSparseArray} = nzrows(sparse(b))

更好的解决方案是根据视图自定义函数。例如,当视图同时使用UnitRanges作为行和列时:
# utility predicate returning true if element of sorted v in range r
inrange(v,r) = searchsortedlast(v,last(r))>=searchsortedfirst(v,first(r))

function nzcols(b::SubArray{T,2,P,Tuple{UnitRange{Int64},UnitRange{Int64}}}
  ) where {T,P<:SparseMatrixCSC}
    return collect(i+1-start(b.indexes[2]) 
      for i in b.indexes[2]
      if b.parent.colptr[i]<b.parent.colptr[i+1] && 
        inrange(b.parent.rowval[nzrange(b.parent,i)],b.indexes[1]))
end

function nzrows(b::SubArray{T,2,P,Tuple{UnitRange{Int64},UnitRange{Int64}}}
  ) where {T,P<:SparseMatrixCSC}
    active = falses(length(b.indexes[1]))
    for c in b.indexes[2]
        for r in nzrange(b.parent,c)
            if b.parent.rowval[r] in b.indexes[1]
                active[b.parent.rowval[r]+1-start(b.indexes[1])] = true
            end
        end
    end
    return find(active)
end

这些版本比完整矩阵的版本工作速度更快(对于上述10,000x10,000矩阵的100x100子矩阵,我的机器上列和行分别需要16μs和12μs,但这些结果不稳定)。

一个适当的基准测试应该使用固定的矩阵(或至少固定随机种子)。如果我进行了这样的基准测试,我将编辑此行。


1
谢谢,这些速度快多了! 我有点惊讶需要在nzcols中“收集”生成器才能用于索引,这似乎本能地应该是可能的。 我已经发现b.indexes[2](和1)的值为Colon(),在这种情况下函数会失败,但我应该能够自己解决它 :-) (可能通过为startinrange定义包装函数)。 - Michael K. Borregaard
1
不需要使用collect,但对于nzrows来说,生成器不太明显,因此我希望保持行和列的返回类型相同。 - Dan Getz
1
此外,还有一个更快(但可读性较差)的 nzcols 函数:nzcols(a::SparseMatrixCSC) = (res=Vector{Int}(); foldl((x,y)->(if (x<a.colptr[y]) push!(res,y-1) end; a.colptr[y]),a.colptr[1],2:a.n+1) ; res) - Dan Getz
1
并且 nzrows 的一行代码为:nzrows(a::SparseMatrixCSC) = find(setindex!(falses(a.m),true,a.rowval)) - Dan Getz
1
非常感谢 :-) 我一直在尝试解密代码,看看是否可以自己修复 - 我肯定会学到很多关于稀疏矩阵的工作原理! - Michael K. Borregaard
显示剩余11条评论

2

如果索引不是范围,那么将其转换为稀疏矩阵的回退方法可行,但这里还有针对向量索引的版本。如果索引是混合的,则需要创建另一组版本。虽然很重复,但这正是 Julia 的优势,当版本完成后,代码将根据调用者中的类型正确选择优化方法,而不需要太多的努力。

function sortedintersecting(v1, v2)
    i,j = start(v1), start(v2)
    while i <= length(v1) && j <= length(v2)
        if v1[i] == v2[j] return true
        elseif v1[i] > v2[j] j += 1
        else i += 1
        end
    end
    return false
end

function nzcols(b::SubArray{T,2,P,Tuple{Vector{Int64},Vector{Int64}}}
  ) where {T,P<:SparseMatrixCSC}
    brows = sort(unique(b.indexes[1]))
    return [k 
      for (k,i) in enumerate(b.indexes[2])
      if b.parent.colptr[i]<b.parent.colptr[i+1] && 
        sortedintersecting(brows,b.parent.rowval[nzrange(b.parent,i)])]
end

function nzrows(b::SubArray{T,2,P,Tuple{Vector{Int64},Vector{Int64}}}
  ) where {T,P<:SparseMatrixCSC}
    active = falses(length(b.indexes[1]))
    for c in b.indexes[2]
      active[findin(b.indexes[1],b.parent.rowval[nzrange(b.parent,c)])] = true
    end
    return find(active)
end

-- ADDENDUM --

由于发现使用 Vector{Int} 索引的 nzrows 速度有点慢,因此尝试通过使用利用排序性能的版本来替换 findin 来提高其速度:

function findin2(inds,v,w)
    i,j = start(v),start(w)
    res = Vector{Int}()
    while i<=length(v) && j<=length(w)
        if v[i]==w[j]
            push!(res,inds[i])
            i += 1
        elseif (v[i]<w[j]) i += 1
        else j += 1
        end
    end
    return res
end

function nzrows(b::SubArray{T,2,P,Tuple{Vector{Int64},Vector{Int64}}}
  ) where {T,P<:SparseMatrixCSC}
    active = falses(length(b.indexes[1]))
    inds = sortperm(b.indexes[1])
    brows = (b.indexes[1])[inds] 
    for c in b.indexes[2]
      active[findin2(inds,brows,b.parent.rowval[nzrange(b.parent,c)])] = true
    end
    return find(active)
end

1
原来sortedintersecting很有用 :)。我从评论中复制了版本并稍微调整了一下间距。 - Dan Getz
1
nzrows 中的 active[...] 行非常神奇 - 我有点惊讶它如此简洁地工作。 - Dan Getz
1
非常感谢!如果我能将两个都标记为答案就好了! :-) - Michael K. Borregaard
好的!马上完成了。 - Michael K. Borregaard
1
实际上它确实可以 - 只比UnitRange版本慢9倍,这可能是你能得到的最快速度。有趣的是,我昨天也尝试了同样的事情(我创建了一个与你的findin2完全相同的函数),但由于nzrows调用函数中的遗漏错误,我没有注意到它的基准测试非常糟糕。很高兴看到我走在了正确的轨道上! - Michael K. Borregaard
显示剩余2条评论

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接