用Rcpp和RcppArmadillo计算多元正态/高斯函数的一阶导数

3

我正在尝试在R中实现多元正态分布的一阶导数,基于这里这里发布的多元正态分布的rcpp实现。

下面是一个快速的R实现:

mvnormDeriv = function(..., mu=rep(0,length(list(...))), sigma=diag(length(list(...)))) {
    if(sd(laply(list(...),length))!=0)
        stop("The vectors not same length.")
    fn = function(x) -1 * c((1/sqrt(det(2*pi*sigma))) * exp(-0.5*t(x-mu)%*%solve(sigma)%*%(x-mu))) * solve(sigma,(x-mu))
    out = t(apply(cbind(...),1,fn))
    colnames(out) = c('x', 'y')
    return(out[,1])
}

以及一些基准测试数据:

set.seed(123456789)
sigma = rWishart(1, 2, diag(2))
means = rnorm(2)
X     = rmvnorm(10000, means, sigma[,,1])
x1    = X[,1]
x2    = X[,2]
benchmark(mvnormDeriv(x1,x2,mu=means,sigma=sigma),
    order="relative", replications=5)[,1:4]

公式可以在《矩阵手册》(2012)的第346个公式中找到。
我无法修改来自此处的多元正态分布的rcpp实现。以下是我尝试使用的一些代码。
// [[Rcpp::export]]
arma::vec dmvnormDeriv_arma(arma::mat x,  SEXP mu_sexp, arma::mat sigma, bool log = false) {

    // create Rcpp vector and matrix from SEXP arguments
    Rcpp::NumericVector mu_rcpp(mu_sexp);
    // create views for arma objects(reuses memory and avoids extra copy)
    arma::vec mu_vec(mu_rcpp.begin(), mu_rcpp.size(), false);
    arma::rowvec mu(mu_rcpp.begin(), mu_rcpp.size(), false);

    // return(mu_vec);
    arma::vec distval = Mahalanobis(x,  mu, sigma);
    double logdet = sum(arma::log(arma::eig_sym(sigma)));
    double log2pi = std::log(2.0 * M_PI);
    arma::vec val = exp(-( (x.n_cols * log2pi + logdet + distval)/2));

    // x.each_row() -= mu;
    // arma::vec val2 = solve(sigma, x.row(1));
    // arma::vec retval = -1 * val(1) * solve(sigma, x.row(1)-mu_vec);

    return(val);
}

当然,这还不完整。你有什么想法可以在rcpp或者Armadillo中实现“* solve(sigma,(x-mu))”部分吗?我遇到了处理不同变量类型和运行每一行x的解决方案的问题。

“solve()”不就是“inv()”的简写吗?如果是这样,你可以直接从Armadillo文档中获取。我们即将更新您所提到的画廊帖子,请留意。然后使用此功能提交一个新的帖子 :) - Dirk Eddelbuettel
谢谢,我已经添加了一个解决方案。如果有兴趣的话,我会提交一些内容到画廊。如果您有改进建议,请告诉我。 - user2503795
略微偏题,但你提到了使用2012年的Matrix Cookbook,能否在问题中提供一个链接?这很难找到。就话题而言,它可能有助于使问题更加自包含。谢谢! - Jotaf
1个回答

5

下面是基于RcppArmadillo的解决方案。它比R实现快100倍以上。首先,这里是C++实现,它依赖于rcpp gallery example

// [[Rcpp::export]]
arma::mat dmvnormderiv_arma(arma::mat x, arma::rowvec mean, arma::mat sigma, bool log = false) {
    // get result for mv normal
    arma::vec distval = Mahalanobis(x,  mean, sigma);
    double logdet = sum(arma::log(arma::eig_sym(sigma)));
    double log2pi = std::log(2.0 * M_PI);
    arma::vec mvnorm = exp(-( (x.n_cols * log2pi + logdet + distval)/2));

    // create output matrix with one column for each derivative
    int n = x.n_rows;
    arma::mat deriv;
    deriv.copy_size(x);
    for (int i=0; i < n; i++) {
        deriv.row(i) = -1 * mvnorm(i) * trans(solve(sigma, trans(x.row(i) - mean)));
    }

    return(deriv);
}

还有两个R实现,一个是纯R,另一个基于mvtnorm包中的dmvnorm

library('RcppArmadillo')
library('mvtnorm')
library('rbenchmark')
sourceCpp('mvnorm.cpp')

mvnormDeriv = function(X, mu=rep(0,ncol(X)), sigma=diag(ncol(X))) {
    fn = function(x) -1 * c((1/sqrt(det(2*pi*sigma))) * exp(-0.5*t(x-mu)%*%solve(sigma)%*%(x-mu))) * solve(sigma,(x-mu))
    out = t(apply(X,1,fn))
    return(out)
}
dmvnormDeriv = function(X, mean, sigma) {
    if (is.vector(X)) X <- matrix(X, ncol = length(X))
    if (missing(mean)) mean <- rep(0, length = ncol(X))
    if (missing(sigma)) sigma <- diag(ncol(X))
    n = nrow(X)
    mvnorm = dmvnorm(X, mean = mean, sigma = sigma)
    deriv = array(NA,c(n,ncol(X)))
    for (i in 1:n)
        deriv[i,] = -mvnorm[i] * solve(sigma,(X[i,]-mean))
    return(deriv)
}

最后附上一些基准测试结果:
set.seed(123456789)
sigma = rWishart(1, 2, diag(2))[,,1]
means = rnorm(2)
X     = rmvnorm(10000, means, sigma)

benchmark(dmvnormderiv_arma(X,means,sigma),
        mvnormDeriv(X,mu=means,sigma=sigma),
        dmvnormDeriv(X,mean=means,sigma=sigma),
        order="relative", replications=5)[,1:4]

                                          test replications elapsed
1           dmvnormderiv_arma(X, means, sigma)            5   0.016
3 dmvnormDeriv(X, mean = means, sigma = sigma)            5   2.118
2    mvnormDeriv(X, mu = means, sigma = sigma)            5   5.939
  relative
1    1.000
3  132.375
2  371.187

干得好。看起来相当不错,几乎就绪了。请查看现有的.Rmd(或者,如果您喜欢,.cpp)文件,并通过邮件发送给我,或者在Github上进行分叉并执行常见的拉取请求操作。 - Dirk Eddelbuettel
而不是执行sum(arma::log(arma::eig_sym(sigma))),log_det()可能更有效。 - mtall
Dirk,这里有一个Rmd文件的要点。我会在画廊GitHub页面上添加一个问题,以便讨论可以在那里继续。 - user2503795

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接