使用Rcpp中的rmultinom函数

3
我想在C++代码中使用R函数rmultinom,并与Rcpp一起使用。但我遇到了一个关于参数不足的错误 - 我不熟悉这些参数应该是什么,因为它们与在R中使用的函数参数不对应。我还没有使用"::Rf_foo"语法从Rcpp代码访问R函数的成功经验。
下面是我的简化代码版本(是的,我正在编写Gibbs采样器)。
#include <Rcpp.h>                                                                                                                                     
using namespace Rcpp;                                                                                                                                 

// C++ implementation of the R which() function.                                                                                                      
int whichC(NumericVector x, double val) {                                                                                                             
  int ind = -1;                                                                                                                                       
  int n = x.size();                                                                                                                                   
  for (int i = 0; i < n; ++i) {                                                                                                                       
    if (x[i] == val) {                                                                                                                                
      if (ind == -1) {                                                                                                                                
        ind = i;                                                                                                                                      
      } else {                                                                                                                                        
        throw std::invalid_argument( "value appears multiple times." );                                                                               
      }                                                                                                                                               
    } // end if                                                                                                                                       
  } // end for                                                                                                                                        
  if (ind != -1) {                                                                                                                                    
    return ind;                                                                                                                                       
  } else {                                                                                                                                            
    throw std::invalid_argument( "value doesn't appear here!" );                                                                                      
    return -1;                                                                                                                                        
  }                                                                                                                                                   
}                                                                                                                                                     

// [[Rcpp::export]]                                                                                                                                   
int multSample(double p1, double p2, double p3) {                                                                                                     
  NumericVector params(3);                                                                                                                            
  params(0) = p1;                                                                                                                                     
  params(1) = p2;                                                                                                                                     
  params(2) = p3;                                                                                                                                     

  // HERE'S THE PROBLEM.                                                                                                                              
  RObject sampled = rmultinom(1, 1, params);                                                                                                          
  int out = whichC(as<NumericVector>(sampled), 1);                                                                                                    
  return out;                                                                                                                                         
}

我是一个新手,所以我意识到这些代码很可能是新手和低效的。我乐于听取关于如何改进我的c++代码的建议,但我的优先事项是了解rmultinom业务。谢谢!
顺便说一句,我很抱歉与this thread相似,但:
1.答案不适用于我的目的 2.差异足以证明需要提出不同的问题(你认为呢?) 3.该问题已经发布并回答了一年。
3个回答

5
以下是用户95215的答案,已经进行了修改以便编译,并且还有一种更符合Rcpp风格的版本:
#include <Rcpp.h>
using namespace Rcpp;

// [[Rcpp::export]]
IntegerVector oneMultinomC(NumericVector probs) {
    int k = probs.size();
    SEXP ans;
    PROTECT(ans = Rf_allocVector(INTSXP, k));
    probs = Rf_coerceVector(probs, REALSXP);
    rmultinom(1, REAL(probs), k, &INTEGER(ans)[0]);
    UNPROTECT(1);
    return(ans);
}

// [[Rcpp::export]]
IntegerVector oneMultinomCalt(NumericVector probs) {
    int k = probs.size();
    IntegerVector ans(k);
    rmultinom(1, probs.begin(), k, ans.begin());
    return(ans);
}

0
如果我尝试编译你的代码,会出现编译器错误:
> Rcpp::sourceCpp('~/scratch/multSample.cpp')
multSample.cpp:33:21: error: no matching function for call to 'rmultinom'
  RObject sampled = rmultinom(1, 1, params);
                    ^~~~~~~~~
/Library/Frameworks/R.framework/Resources/include/Rmath.h:449:6: note: candidate function not viable: requires 4 arguments, but 3 were provided
void    rmultinom(int, double*, int, int*);
        ^
1 error generated.

正如它所示,您没有正确指定参数。请注意,与其他函数相比,rmultinom接口有点棘手:它填充由*rn指向的内存,而不是返回一个新对象(具有自己的新分配的内存)。

如果您查看R源代码,您将看到接口,并且您还可以在此处看到其使用示例(实际上,stats制作了一个包装器函数,执行一些更多的参数检查等)。但请注意这里的用法:

rmultinom(size, REAL(prob), k, &INTEGER(ans)[ik]);

换句话说,通过将该内存的地址传递到rmultinom函数中,它正在填充一个名为ans的INTSXP。
因此,如果您想从Rcpp使用此函数,则必须执行类似的操作--但也许这值得进行类似的糖果向量化处理,以避免该接口的丑陋。您可以尝试做以下事情:
IntegerMatrix sampled(nrow, ncol);
rmultinom(1, 1, params, sampled.begin());

或类似的东西。


谢谢,您提供的链接帮助我找到了一个可行的答案。我会发布一个详细说明那段代码的答案(如果这不符合惯例,请原谅我)。 - sinwav

0

Kevin提供的示例和链接使我能够找到一个可行的答案。需要对类型进行一些处理。我编写了一个函数,允许您从多项分布中抽取一个向量样本。以下是代码。

#include <Rcpp.h>
using namespace Rcpp;

// [[Rcpp::export]]
NumericVector oneMultinomC(NumericVector probs) {
    int k = probs.size();
    SEXP ans;
    PROTECT(ans = RF_allocVector(INTSXP, k));
    probs = RF_coerceVector(probs, REALSXP);
    rmultinom(1, REAL(probs), k, &INTEGER(ans)[0]);
    UNPROTECT(1);
    return ans;
}

我不理解这里发生了一半的事情。特别是,我不理解'rmultinom'的第四个参数。我知道它是指向存储输出的内存位置的指针,但我不理解'[0]'部分。尽管如此,它还是有效的。男孩和女孩们,继续进行吉布斯采样。


我认为你也可以用更符合Rcpp惯用风格的方式来编写,而不需要使用SEXPPROTECT等内容... - Dirk Eddelbuettel

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接