如何在Spark中计算距离矩阵?

10

我尝试过对样本进行配对,但由于100个样本会导致9900个样本,这样会消耗大量的内存。在spark分布式环境中,有哪些更有效的计算距离矩阵的方法?

下面是我正在尝试的伪代码片段

val input = (sc.textFile("AirPassengers.csv",(numPartitions/2)))
val i = input.map(s => (Vectors.dense(s.split(',').map(_.toDouble))))
val indexed = i.zipWithIndex()                                                                       //Including the index of each sample
val indexedData = indexed.map{case (k,v) => (v,k)}

val pairedSamples = indexedData.cartesian(indexedData)

val filteredSamples = pairedSamples.filter{ case (x,y) =>
(x._1.toInt > y._1.toInt)  //to consider only the upper or lower trainagle
 }
filteredSamples.cache
filteredSamples.count

上述代码创建了数据对,但即使我的数据集包含100个样本,通过对过滤后的样本(filteredSamples)进行配对操作,结果会产生4950个样本,这对于大数据来说可能非常昂贵。


2
请发布您已尝试过的任何代码示例,与您的问题相关的数据或样本数据以及您尝试过的任何库或资源。 - Mike Zavarello
我已经添加了代码片段。希望它能帮助你理解我的问题。 - Manoj Kondapaka
3个回答

5

我最近回答了一个类似的问题

基本上,需要计算n(n-1)/2个对,这在你的例子中是4950个计算。然而,这种方法不同的地方在于我使用连接操作替代了cartesian。使用你的代码,解决方案看起来像这样:

val input = (sc.textFile("AirPassengers.csv",(numPartitions/2)))
val i = input.map(s => (Vectors.dense(s.split(',').map(_.toDouble))))
val indexed = i.zipWithIndex()

// including the index of each sample
val indexedData = indexed.map { case (k,v) => (v,k) } 

// prepare indices
val count = i.count
val indices = sc.parallelize(for(i <- 0L until count; j <- 0L until count; if i > j) yield (i, j))

val joined1 = indices.join(indexedData).map { case (i, (j, v)) => (j, (i,v)) }
val joined2 = joined1.join(indexedData).map { case (j, ((i,v1),v2)) => ((i,j),(v1,v2)) }

// after that, you can then compute the distance using your distFunc
val distRDD = joined2.mapValues{ case (v1, v2) => distFunc(v1, v2) }

请尝试使用这种方法,并与您已发布的方法进行比较。希望这可以加快您的代码速度。


1
据我所知,通过查阅各种来源和Spark mllib clustering site,Spark目前不支持距离或pdist矩阵。
在我看来,100个样本将始终输出至少4950个值;因此,手动创建一个使用转换(如.map)的分布式矩阵求解器将是最佳解决方案。

0

这可以作为jtitusjanswer的Java版本。

public JavaPairRDD<Tuple2<Long, Long>, Double> getDistanceMatrix(Dataset<Row> ds, String vectorCol) {

    JavaRDD<Vector> rdd = ds.toJavaRDD().map(new Function<Row, Vector>() {

        private static final long serialVersionUID = 1L;

        public Vector call(Row row) throws Exception {
            return row.getAs(vectorCol);
        }

    });

    List<Vector> vectors = rdd.collect();

    long count = ds.count();

    List<Tuple2<Tuple2<Long, Long>, Double>> distanceList = new ArrayList<Tuple2<Tuple2<Long, Long>, Double>>();

    for(long i=0; i < count; i++) {
        for(long j=0; j < count && i > j; j++) {
            Tuple2<Long, Long> indexPair = new Tuple2<Long, Long>(i, j);
            double d = DistanceMeasure.getDistance(vectors.get((int)i), vectors.get((int)j));
            distanceList.add(new Tuple2<Tuple2<Long, Long>, Double>(indexPair, d));
        }
    }

    return distanceList;
}

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接