如何在Apache Spark中计算RowMatrix的逆矩阵?

7

我有一个以RowMatrix形式分布的矩阵X。我正在使用Spark 1.3.0。我需要能够计算X的逆。


一个算法在https://arxiv.org/pdf/1801.04723.pdf中有描述。 - Andrew
3个回答

8
import org.apache.spark.mllib.linalg.{Vectors,Vector,Matrix,SingularValueDecomposition,DenseMatrix,DenseVector}
import org.apache.spark.mllib.linalg.distributed.RowMatrix

def computeInverse(X: RowMatrix): DenseMatrix = {
  val nCoef = X.numCols.toInt
  val svd = X.computeSVD(nCoef, computeU = true)
  if (svd.s.size < nCoef) {
    sys.error(s"RowMatrix.computeInverse called on singular matrix.")
  }

  // Create the inv diagonal matrix from S 
  val invS = DenseMatrix.diag(new DenseVector(svd.s.toArray.map(x => math.pow(x,-1))))

  // U cannot be a RowMatrix
  val U = new DenseMatrix(svd.U.numRows().toInt,svd.U.numCols().toInt,svd.U.rows.collect.flatMap(x => x.toArray))

  // If you could make V distributed, then this may be better. However its alreadly local...so maybe this is fine.
  val V = svd.V
  // inv(X) = V*inv(S)*transpose(U)  --- the U is already transposed.
  (V.multiply(invS)).multiply(U)
  }

4

我在使用这个选项时遇到了问题。

conf.set("spark.sql.shuffle.partitions", "12")

RowMatrix中的行已经被打乱。

以下是我成功使用的更新。

import org.apache.spark.mllib.linalg.{DenseMatrix,DenseVector}
import org.apache.spark.mllib.linalg.distributed.IndexedRowMatrix

def computeInverse(X: IndexedRowMatrix)
: DenseMatrix = 
{
  val nCoef = X.numCols.toInt
  val svd = X.computeSVD(nCoef, computeU = true)
  if (svd.s.size < nCoef) {
    sys.error(s"IndexedRowMatrix.computeInverse called on singular matrix.")
  }

  // Create the inv diagonal matrix from S 
  val invS = DenseMatrix.diag(new DenseVector(svd.s.toArray.map(x => math.pow(x, -1))))

  // U cannot be a RowMatrix
  val U = svd.U.toBlockMatrix().toLocalMatrix().multiply(DenseMatrix.eye(svd.U.numRows().toInt)).transpose

  val V = svd.V
  (V.multiply(invS)).multiply(U)
}

0
X.computeSVD返回的矩阵U的维度为m x k,其中m是原始(分布式)RowMatrix X的行数。预计m很大(可能大于k),因此如果我们希望代码能够扩展到非常大的m值,就不建议在驱动程序中收集它。
我认为下面两种解决方案都有这个缺陷。由@Alexander Kharlamov给出的答案调用了val U = svd.U.toBlockMatrix().toLocalMatrix()来在驱动程序中收集矩阵。同样的情况也发生在由@Climbs_lika_Spyder(顺便说一句,你的昵称很棒!!)给出的答案中,该答案调用了svd.U.rows.collect.flatMap(x => x.toArray)。我宁愿建议依靠分布式矩阵乘法,例如Scala代码发布在这里

我在您添加的链接中没有看到任何逆运算。 - Climbs_lika_Spyder
@Climbs_lika_Spyder 这个链接是关于分布式矩阵乘法的,可以替换你解决方案中最后一行的本地矩阵乘法(V.multiply(invS)).multiply(U),这样你就不需要在驱动程序中收集U了。我认为VinvS不够大,不会引起问题。 - Pablo

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接