Rcpp如何使用NumericVector选择/子集NumericMatrix列

5
我可以通过以下方式选择矩阵的所有行和矩阵范围内的列:
library(Rcpp)
cppFunction('
NumericMatrix subset(NumericMatrix x){
  return x(_, Range(0, 1));
}
')

然而,我希望根据 NumericVector y 来选择列,例如,y 可能是类似于 c(0, 1, 0, 0, 1) 的内容。我尝试了以下代码:

library(Rcpp)
cppFunction('
NumericMatrix subset(NumericMatrix x, NumericVector y){
  return x(_, y);
}
')

但它无法编译。我该怎么做?

1个回答

9

哎呀,Rcpp 对于非连续视图或仅在一个语句中选择列1和4的支持不是很好。正如您所看到的,使用 Rcpp::Range() 可以选择连续视图或选择所有列。您可能想升级到 RcppArmadillo 以更好地控制矩阵子集

RcppArmadillo 子集示例

#include <RcppArmadillo.h>
// [[Rcpp::depends(RcppArmadillo)]]

// [[Rcpp::export]]
arma::mat matrix_subset_idx(const arma::mat& x,
                            const arma::uvec& y) { 

    // y must be an integer between 0 and columns - 1
    // Allows for repeated draws from same columns.
    return x.cols( y );
}


// [[Rcpp::export]]
arma::mat matrix_subset_logical(const arma::mat& x,
                                const arma::vec& y) { 
    // Assumes that y is 0/1 coded.
    // find() retrieves the integer index when y is equivalent 1. 
    return x.cols( arma::find(y == 1) );
}

测试

# Sample data
x = matrix(1:15, ncol = 5)
x
#      [,1] [,2] [,3] [,4] [,5]
# [1,]    1    4    7   10   13
# [2,]    2    5    8   11   14
# [3,]    3    6    9   12   15

# Subset only when 1 (TRUE) is found:
matrix_subset_logical(x, c(0, 1, 0, 0, 1))
#      [,1] [,2]
# [1,]    4   13
# [2,]    5   14
# [3,]    6   15

# Subset with an index representing the location
# Note: C++ indices start at 0 not 1!
matrix_subset_idx(x, c(1, 3))
#      [,1] [,2]
# [1,]    4   13
# [2,]    5   14
# [3,]    6   15

纯Rcpp逻辑

如果您不想承担使用armadillo的依赖关系,那么在Rcpp中矩阵子集的等效方法是:

#include <Rcpp.h>

// [[Rcpp::export]]
Rcpp::NumericMatrix matrix_subset_idx_rcpp(
        Rcpp::NumericMatrix x, Rcpp::IntegerVector y) { 

    // Determine the number of observations
    int n_cols_out = y.size();

    // Create an output matrix
    Rcpp::NumericMatrix out = Rcpp::no_init(x.nrow(), n_cols_out);

    // Loop through each column and copy the data. 
    for(unsigned int z = 0; z < n_cols_out; ++z) {
        out(Rcpp::_, z) = x(Rcpp::_, y[z]);
    }

    return out;
}

非常感谢!顺便说一下,y并不是以01编码的。我只是多次选择列0和列1。本质上,我想要用替换方式获取这些列。 - Euler_Salter
@Euler_Salter添加了一个仅使用Rcpp的示例,并使用RcppArmadillo更新了子集操作以处理重复位置(在这种情况下,只需删除find()并直接使用y)。 - coatless
很好的答案,我之前点了赞,但是你在子集索引的例子中错过了一个复制和粘贴。调用类似于 matrix_subset_idx(x, c(1,3)) 这样的内容。 - Dirk Eddelbuettel

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接