library(Rcpp)
library(microbenchmark)
# Pairwise Euclidean distances between every row of x and every row of y,
# computed directly as sqrt(sum_k (x[i, k] - y[j, k])^2).
# Returns an nrow(x) by nrow(y) matrix; stops with an error when the two
# matrices have a different number of columns.
Rcpp::cppFunction('NumericMatrix crossdist(NumericMatrix x, NumericMatrix y) {
  const int rows_x = x.nrow(), rows_y = y.nrow(), dims = x.ncol();
  if (dims != y.ncol()) throw std::runtime_error("Different column number");
  NumericMatrix out(rows_x, rows_y);
  for (int a = 0; a < rows_x; ++a) {
    for (int b = 0; b < rows_y; ++b) {
      double acc = 0;
      for (int c = 0; c < dims; ++c) {
        const double delta = x(a, c) - y(b, c);
        acc += delta * delta;
      }
      out(a, b) = std::sqrt(acc);
    }
  }
  return out;
}')
# Pairwise Euclidean distances between every row of x and every row of y,
# using the expansion |a - b|^2 = |a|^2 + |b|^2 - 2 a.b so that the inner
# loop is a plain dot product and the squared row norms are computed once.
# Returns an nrow(x) by nrow(y) matrix; stops with an error when the two
# matrices have a different number of columns.
#
# Notes:
# * The row-norm buffers are NumericVectors (heap-allocated) rather than
#   variable-length stack arrays, which are not standard C++ and can
#   overflow the stack for matrices with many rows.
# * Floating-point cancellation can make the expanded squared distance
#   slightly negative for (nearly) identical rows; it is clamped to 0 so
#   sqrt() does not produce NaN. NA/NaN inputs still propagate.
# * The expansion loses relative accuracy for very close points; the direct
#   formulation in crossdist() is more precise for small distances.
Rcpp::cppFunction('NumericMatrix crossdist2(NumericMatrix x, NumericMatrix y) {
  const int n1 = x.nrow(), n2 = y.nrow(), ncol = x.ncol();
  if (ncol != y.ncol()) throw std::runtime_error("Different column number");
  NumericMatrix out(n1, n2);

  // Squared Euclidean norm of each row of x and of y.
  NumericVector rs1(n1), rs2(n2);
  for (int i = 0; i < n1; i++) {
    double sum = 0;
    for (int k = 0; k < ncol; k++) sum += x(i, k) * x(i, k);
    rs1[i] = sum;
  }
  for (int j = 0; j < n2; j++) {
    double sum = 0;
    for (int k = 0; k < ncol; k++) sum += y(j, k) * y(j, k);
    rs2[j] = sum;
  }

  for (int i = 0; i < n1; i++) {
    for (int j = 0; j < n2; j++) {
      double dot = 0;
      for (int k = 0; k < ncol; k++) dot += x(i, k) * y(j, k);
      const double d2 = rs1[i] + rs2[j] - 2 * dot;
      // Negative values are rounding noise around 0. A NaN fails the
      // comparison and is passed through sqrt unchanged.
      out(i, j) = d2 < 0 ? 0.0 : std::sqrt(d2);
    }
  }
  return out;
}')
# Benchmark inputs: a 2000 x 10 and a 1000 x 10 matrix of rnorm() draws.
x <- matrix(rnorm(2e4), ncol = 10)
y <- matrix(rnorm(1e4), ncol = 10)

# Time each cross-distance implementation 100 times. microbenchmark labels
# results by deparsing each expression, so the expressions below must keep
# their tokens unchanged (whitespace and line breaks do not matter).
b <- microbenchmark(
  times = 100,
  crossdist(x, y),
  crossdist2(x, y),
  Rfast::dista(x, y),
  proxy::dist(x, y),
  pracma::distmat(x, y),
  as.matrix(pdist::pdist(x, y)),
  sqrt(outer(rowSums(x^2), rowSums(y^2), "+") - 2 * tcrossprod(x, y)),
  sqrt(outer(rowSums(x^2), rowSums(y^2), "+") - 2 * x %*% t(y)),
  sqrt(Rfast::Outer(Rfast::rowsums(y^2), Rfast::rowsums(x^2), "+") -
    2 * x %*% t(y)),
  sqrt(Rfast::Outer(Rfast::rowsums(y^2), Rfast::rowsums(x^2), "+") -
    2 * Rfast::Tcrossprod(x, y)),
  outer(
    split(x, row(x)), split(y, row(y)),
    Vectorize(function(x, y) sqrt(sum((x - y)^2)))
  )
)
# Median run time of each benchmarked expression, ordered fastest first.
a <- aggregate(b$time, by = list(b$expr), FUN = median)
a <- a[order(a[[2]]), ]
# One line per expression: its median time as a multiple of the fastest
# median (3 decimals), followed by the expression text with spaces removed.
writeLines(paste(
  sprintf("%.3f", a[[2]] / min(a[[2]])),
  gsub(" ", "", a[[1]])
))
结果(各方法的中位数耗时相对于最快方法的倍数,从快到慢排列):
1.000 crossdist(x,y)
1.054 crossdist2(x,y)
1.217 sqrt(Rfast::Outer(Rfast::rowsums(y^2),Rfast::rowsums(x^2),"+")-2*Rfast::Tcrossprod(x,y))
1.227 sqrt(Rfast::Outer(Rfast::rowsums(y^2),Rfast::rowsums(x^2),"+")-2*x%*%t(y))
1.897 Rfast::dista(x,y)
1.946 sqrt(outer(rowSums(x^2),rowSums(y^2),"+")-2*tcrossprod(x,y))
1.950 sqrt(outer(rowSums(x^2),rowSums(y^2),"+")-2*x%*%t(y))
2.004 proxy::dist(x,y)
2.402 as.matrix(pdist::pdist(x,y))
3.674 pracma::distmat(x,y)
177.474 outer(split(x,row(x)),split(y,row(y)),Vectorize(function(x,y)sqrt(sum((x-y)^2))))
`tcrossprod(m1, m2)` 是 `m1 %*% t(m2)` 的一个稍快的替代写法,不过在这个基准测试中两者速度相差无几:
> m1=matrix(rnorm(2e4),,10);m2=matrix(rnorm(1e4),,10)
> microbenchmark(times=1000,tcrossprod(m1,m2),m1%*%t(m2),Rfast::Tcrossprod(m1,m2))
expr min lq mean median uq
tcrossprod(m1, m2) 12.28305 13.06046 17.58402 17.60379 17.74104
m1 %*% t(m2) 12.79996 17.30764 17.52570 17.59473 17.70758
Rfast::Tcrossprod(m1, m2) 11.48939 13.81658 17.68059 17.23675 17.37447
下面是一种快速计算 `m1` 第 1 行到 `m2` 第 1 行、`m1` 第 2 行到 `m2` 第 2 行(依此类推)之间距离的方法(要求 `m1` 与 `m2` 的维度相同):
# Euclidean distance between paired rows: row i of m1 to row i of m2
# (m1 and m2 must have the same dimensions).
local({
  diffs <- m1 - m2
  sqrt(rowSums(diffs * diffs))
})
下面是一种快速计算向量 `v` 到矩阵 `m` 每一行距离的方法(要求 `length(v) == ncol(m)`):
# Euclidean distance from vector v to each row of matrix m, via the expansion
# |m_i - v|^2 = |m_i|^2 + |v|^2 - 2 m_i . v (requires length(v) == ncol(m)).
# The three terms are computed by different routines (rowSums/sum vs. BLAS),
# so for rows equal or very close to v the difference can round to a tiny
# negative number; pmax(..., 0) clamps it so sqrt() returns 0 instead of NaN.
sqrt(pmax(rowSums(m^2) + sum(v^2) - 2 * (m %*% as.matrix(v))[, 1], 0))