R中矩阵对象的距离函数

3

我有一个非常简单的问题。

给定一个N维点(例如,一个向量,其中每个元素代表一个维度)表示为x,以及一个MxN维矩阵(或具有N维的M个点的组合!)表示为y

set.seed(999)
data <- matrix(runif(1100), nrow = 11, ncol = 10)

x <- data[1, ]
y <- data[2:nrow(data), ]

我想要计算 x 和每个点的距离度量 y。我知道一种简单的方法是这样做:

distances <- dist(rbind(x, y))

然而,针对这种特殊情况,我认为这并不是非常有效,原因如下:
  1. 我需要使用rbind函数,它会占用大量内存。
  2. dist计算每个点之间的距离,但我只对其中10个距离感兴趣,或者说只对y中每个点与x之间的距离感兴趣,我不关心y中各点之间的距离。
  3. 由于上条原因,我需要手动选择距离矩阵的最后一行来获取我实际需要的距离。
我想到的一个可能的解决方案是通过循环y手动应用距离测量。
distances <- apply(y, MARGIN = 1, function(a, b = x) {
   sqrt(sum((a - b)^2))
})

然而,当我对这两种方法进行计时时,得到的结果是:
func1 <- function(x, y) {
  apply(y, MARGIN = 1, function(a, b = x) {
    sqrt(sum((a - b)^2))
  })
}

func2 <- function(x, y) {
  dist(rbind(x, y))
}

microbenchmark::microbenchmark(
  func1(x, y),
  func2(x, y)
)

Unit: microseconds
        expr    min     lq     mean median      uq      max neval
 func1(x, y) 29.602 30.450 61.21791 31.301 32.3510 2916.101   100
 func2(x, y) 15.101 15.801 28.55304 17.201 17.7015 1143.001   100

所以我的问题是:有没有一种比使用 dist 更快的方法来解决这个问题?
3个回答

3

更新2: 如果我们假设数据完整且准确性达到 <code>epsilon</code>,那么我们可以使用rcpp实现更快速的距离计算版本。我已经添加了下面的内容,最快的版本使用了字节码编译器。对于那些有RcppParallel经验的人来说,这可能还可以进一步改进。

更新: 从fields包中的rdist函数是迄今为止发现的最快方法(请参见在R中高效地计算一个点和一组点之间所有距离)。当不使用字节码编译器时,它似乎是最快的。

对先前结果进行简要测试后,我发现在使用字节码编译器时,vapply比所有其他方法都要快(第一次运行时它会首次编译函数,这就是为什么字节码运行期间maxtime较大的原因)。

我在这里也尝试了@akrun和@ThomasIsCoding的方法。

library(microbenchmark)
library(compiler)
library(collapse)
library(fields)
library(Rcpp)

set.seed(999)
data <- matrix(runif(1100), nrow = 11, ncol = 10)

x <- data[1, ]
y <- data[2:nrow(data), ]

distances <- dist(rbind(x, y))

func1 <- function(x, y) {
  apply(y, MARGIN = 1, function(a, b = x) {
    sqrt(sum((a - b)^2))
  })
}

func2 <- function(x, y) {
  dist(rbind(x, y))
}

func3 <- function(x, y) {
  dapply(y, function(a, b = x) {
    sqrt(sum((a-b)^2))
  }, MARGIN = 1)
}

func4 <- function(x, y) {
  vapply(seq_len(nrow(y)), function(i, b = x) sqrt(sum((y[i,]-b)^2)), numeric(1))
}

func5 <- function(x, y) {
  rdist(rbind(x, y))
}

cppFunction('NumericVector func6(NumericVector x, NumericVector y) {
  int n = x.size();
  int n2 = y.size();
  
  int maxiters = n2/n;
  
  NumericVector results(maxiters);
  
  for(int i = 0; i < maxiters; i++) {
    results[i] = 0;
    for(int j = 0; j < n; j++) {
      double val = x[j] - y[j * maxiters + i];
      results[i] += val * val;
    }
    results[i] = sqrt(results[i]);
  }
  
  return results;
  
}')

func7 <- function(x, y) sqrt(rowSums((y-x[col(y)])^2))

func8 <- function(x, y) sqrt(colSums((t(y) - x)^2))

compiler::enableJIT(0)
#> [1] 3

microbenchmark::microbenchmark(
  func1(x, y),
  func2(x, y),
  func3(x, y),
  func4(x, y),
  func5(x, y),
  func6(x, y),
  func7(x, y),
  func8(x, y)
)
#>Unit: microseconds
#>        expr    min      lq     mean  median      uq      max neval
#> func1(x, y) 37.001 42.8010 50.53103 45.4520 53.6515  138.302   100
#> func2(x, y) 20.201 25.3510 30.23096 27.8515 31.4010   70.401   100
#> func3(x, y) 23.901 27.6510 55.45699 30.0010 35.7505 2248.902   100
#> func4(x, y) 20.501 23.2010 28.20101 24.6020 31.4010  119.501   100
#> func5(x, y)  6.100  8.6020 19.27804  9.6515 11.4510  891.001   100
#> func6(x, y)  1.501  2.4010 11.60706  2.9010  3.4510  848.102   100
#> func7(x, y) 14.401 17.2505 27.73793 19.7510 23.2510  596.002   100
#> func8(x, y) 18.901 22.5510 27.91699 24.9015 29.3010   73.301   100


compiler::enableJIT(3)
#> [1] 0

microbenchmark::microbenchmark(
  func1(x, y),
  func2(x, y),
  func3(x, y),
  func4(x, y),
  func5(x, y),
  func6(x, y),
  func7(x, y),
  func8(x, y)
  
)
#>Unit: microseconds
#>        expr    min      lq     mean  median      uq      max neval
#> func1(x, y) 32.100 35.9510 85.49213 39.4015 44.2510 4298.002   100
#> func2(x, y) 19.701 23.6010 45.11697 26.0505 29.6005 1732.702   100
#> func3(x, y) 19.801 22.2515 76.96108 24.8010 27.6510 5023.201   100
#> func4(x, y) 16.302 19.2510 77.46094 20.3010 21.8005 5564.701   100
#> func5(x, y)  6.201  8.5010 41.53397  9.4510 11.0510 3032.301   100
#> func6(x, y)  1.401  2.3010 13.95802  2.7005  3.0020 1101.801   100
#> func7(x, y) 14.201 16.7010 64.09999 18.6510 21.0015 4307.901   100
#> func8(x, y) 19.201 22.4500 64.33288 24.8510 27.5010 3776.101   100

本例使用的是 reprex包(v2.0.0),创建于2021/04/04。

仅输出结果。

#ordinary compiler

#>Unit: microseconds
#>        expr    min      lq     mean  median      uq      max neval
#> func1(x, y) 37.001 42.8010 50.53103 45.4520 53.6515  138.302   100
#> func2(x, y) 20.201 25.3510 30.23096 27.8515 31.4010   70.401   100
#> func3(x, y) 23.901 27.6510 55.45699 30.0010 35.7505 2248.902   100
#> func4(x, y) 20.501 23.2010 28.20101 24.6020 31.4010  119.501   100
#> func5(x, y)  6.100  8.6020 19.27804  9.6515 11.4510  891.001   100
#> func6(x, y)  1.501  2.4010 11.60706  2.9010  3.4510  848.102   100
#> func7(x, y) 14.401 17.2505 27.73793 19.7510 23.2510  596.002   100
#> func8(x, y) 18.901 22.5510 27.91699 24.9015 29.3010   73.301   100

#bytecode compiler

#>Unit: microseconds
#>        expr    min      lq     mean  median      uq      max neval
#> func1(x, y) 32.100 35.9510 85.49213 39.4015 44.2510 4298.002   100
#> func2(x, y) 19.701 23.6010 45.11697 26.0505 29.6005 1732.702   100
#> func3(x, y) 19.801 22.2515 76.96108 24.8010 27.6510 5023.201   100
#> func4(x, y) 16.302 19.2510 77.46094 20.3010 21.8005 5564.701   100
#> func5(x, y)  6.201  8.5010 41.53397  9.4510 11.0510 3032.301   100
#> func6(x, y)  1.401  2.3010 13.95802  2.7005  3.0020 1101.801   100
#> func7(x, y) 14.201 16.7010 64.09999 18.6510 21.0015 4307.901   100
#> func8(x, y) 19.201 22.4500 64.33288 24.8510 27.5010 3776.101   100

伟大的贡献。 - eduardokapp
我已经使用更快的Rcpp方法进行了更新,希望这能帮到你。 - Joel Kandiah

3

一个选项是来自collapsedapply

 library(collapse)
 func3 <- function(x, y) {
     dapply(y, function(a, b = x) {
             sqrt(sum((a-b)^2))
          }, MARGIN = 1)
  }

或者可以使用vapply

func4 <- function(x, y) {
  vapply(seq_len(nrow(y)), function(i, b = x) sqrt(sum((y[i,]-b)^2)), numeric(1))
 }

或者可以复制向量并在减去后使用rowSums

func7 <- function(x, y) sqrt(rowSums((y-x[col(y)])^2))
microbenchmark::microbenchmark(func1(x, y), func3(x, y), func4(x, y), func7(x, y))
#Unit: microseconds
#        expr    min      lq     mean  median      uq      max neval cld
# func1(x, y) 37.605 39.7475 61.17471 40.7595 42.1865 1955.888   100   a
# func3(x, y) 22.212 23.5945 68.63660 24.8320 25.8670 4333.933   100   a
# func4(x, y) 21.089 22.7930 24.11542 23.5945 24.2315   58.050   100   a
# func7(x, y)  7.731  8.9135 44.45935 10.0615 10.9500 3415.959   100   a

它会更快吗?在R基础中是否有替代方法? - eduardokapp
@eduardokapp 它比 apply 更快。 - akrun
哇!它非常快。非常有趣。我尝试使用vapply,但我不知道你必须用i进行“索引”。 - eduardokapp
默认情况下,它是按列进行的,就像sapply/lapply一样。因此,我们正在循环遍历行序列。 - akrun

1

这里是另一个基于R语言的示例:

sqrt(colSums((t(y) - x)^2))

这与使用“dist”相同。正如我在问题中所说的那样,您只是对结果进行了子集处理。 - eduardokapp
@eduardokapp 抱歉,是我不好。请看我的更新。 - ThomasIsCoding
1
@Joel Kandiah,您是否愿意将这个替代方案包含在您的基准测试中?谢谢。 - eduardokapp
1
@eduardokapp 我已经更新了我的帖子,包含到目前为止找到的所有解决方案。 - Joel Kandiah

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接