Eigen:返回具有编译时维度检查的矩阵块的引用

5
我需要的是 这个问题 的拓展。具体来说,我想在一个遗留的 C 和 Fortran 库中创建一个 C++ Eigen 包装器,该库使用 2D 数据结构:
[   x[0,0] ...   x[0,w-1] ]
[   u[0,0] ...   u[0,w-1] ]
[          ...            ]
[ x[c-1,0] ... x[c-1,w-1] ]
[ u[c-1,0] ... u[c-1,w-1] ]

每个条目x[i,j]u[i,j]本身都是大小为(nx1)和(mx1)的列向量。 这导致了一些复杂(且容易出错)的指针算术以及一些非常难以阅读的代码。
因此,我想编写一个Eigen类,其唯一目的是尽可能轻松地提取该矩阵的条目。在C++14中,它看起来像这样data_getter.h:
#ifndef DATA_GETTER_HEADER
#define DATA_GETTER_HEADER

#include "Eigen/Dense"

template<typename T, int n, int m, int c, int w>
class DataGetter {
public:

    /** Return a reference to the data as a matrix */
    static auto asMatrix(T *raw_ptr) {
        auto out = Eigen::Map<Eigen::Matrix<T, (n + m) * c, w>>(raw_ptr);
        static_assert(decltype(out)::RowsAtCompileTime == (n + m) * c);
        static_assert(decltype(out)::ColsAtCompileTime == w);
        return out;
    }

    /** Return a reference to the submatrix
     * [ x[i,0], ..., x[i,w-1]]
     * [ u[i,0], ..., u[i,w-1]] */
    static auto W(T *raw_ptr, int i) {
        auto out = asMatrix(raw_ptr).template middleRows<n + m>((n + m) * i);
        static_assert(decltype(out)::RowsAtCompileTime == (n + m));
        static_assert(decltype(out)::ColsAtCompileTime == w);
        return out;
    }

    /** Return a reference to the submatrix [ x[i,0], ..., x[i,w-1]] */
    static auto X(T *raw_ptr, int i) {
        auto out = W(raw_ptr, i).template topRows<n>();
        static_assert(decltype(out)::RowsAtCompileTime == n);
        static_assert(decltype(out)::ColsAtCompileTime == w);
        return out;
    }

    /** Return a reference to x[i,j] */
    static auto X(T *raw_ptr, int i, int j) {
        auto out = X(raw_ptr, i).col(j);
        static_assert(decltype(out)::RowsAtCompileTime == n);
        static_assert(decltype(out)::ColsAtCompileTime == 1);
        return out;
    }

    /** Return a reference to the submatrix [ u[i,0], ..., u[i,w-1]] */
    static auto U(T *raw_ptr, int i) {
        auto out = W(raw_ptr, i).template bottomRows<m>();
        static_assert(decltype(out)::RowsAtCompileTime == m);
        static_assert(decltype(out)::ColsAtCompileTime == w);
        return out;
    }

    /** Return a reference to u[i,j] */
    static auto U(T *raw_ptr, int i, int j) {
        auto out = U(raw_ptr, i).col(j);
        static_assert(decltype(out)::RowsAtCompileTime == m);
        static_assert(decltype(out)::ColsAtCompileTime == 1);
        return out;
    }

    /** Return a reference to the submatrix
     * [ x[0,i], ..., x[c-1,i]]
     * [ u[0,i], ..., u[c-1,i]] */
    static auto C(T *raw_ptr, int i) {
        auto out = Eigen::Map<Eigen::Matrix<T, n + m, c>>(
                asMatrix(raw_ptr).col(i).template topRows<(n + m) * c>().data());
        static_assert(decltype(out)::RowsAtCompileTime == (n + m));
        static_assert(decltype(out)::ColsAtCompileTime == c);
        return out;
    }

    /** Return a reference to the submatrix [ x[0,i], ..., x[c-1,i]] */
    static auto Xc(T *raw_ptr, int i) {
        auto out = C(raw_ptr, i).template topRows<n>();
        static_assert(decltype(out)::RowsAtCompileTime == n);
        static_assert(decltype(out)::ColsAtCompileTime == c);
        return out;
    }

    /** Return a reference to the submatrix [ u[0,i], ..., u[c-1,i]] */
    static auto Uc(T *raw_ptr, int i) {
        auto out = C(raw_ptr, i).template bottomRows<m>();
        static_assert(decltype(out)::RowsAtCompileTime == m);
        static_assert(decltype(out)::ColsAtCompileTime == c);
        return out;
    }
};

#endif /* DATA_GETTER_HEADER */

这里是一个测试程序,演示了它的工作原理:

#include <iostream>
#include <vector>
#include "Eigen/Dense"
#include "data_getter.h"

using namespace std;
using namespace Eigen;

template<typename T>
void printSize(MatrixBase<T> &mat) {
    cout << T::RowsAtCompileTime << " x " << T::ColsAtCompileTime;
}

int main() {

    using T = double;
    const int n = 2;
    const int m = 3;
    const int c = 2;
    const int w = 5;
    const int size = w * (c * (n + m));
    std::vector<T> vec;
    for (int i = 0; i < size; ++i)
        vec.push_back(i);

    /* Define the interface that we will use a lot */
    using Data = DataGetter<T, n, m, c, w>;

    /* Now let's map that pointer to some submatrices */
    Ref<Matrix<T, (n + m) * c, w>> allData = Data::asMatrix(vec.data());
    Ref<Matrix<T, n, w>> x1 = Data::X(vec.data(), 1);
    Ref<Matrix<T, n, c>> xc2 = Data::Xc(vec.data(), 2);
    Ref<Matrix<T, n + m, c>> xuc2 = Data::C(vec.data(), 2);
    Ref<Matrix<T, n, 1>> x12 = Data::X(vec.data(), 1, 2);

    cout << "Data::asMatrix( T* ): ";
    printSize(allData);
    cout << endl << endl << allData << endl << endl;
    cout << "Data::X( T*, 1 )    : ";
    printSize(x1);
    cout << endl << endl << x1 << endl << endl;
    cout << "Data::Xc( T*, 2 )   : ";
    printSize(xc2);
    cout << endl << endl << xc2 << endl << endl;
    cout << "Data::C( T*, 2 )    : ";
    printSize(xuc2);
    cout << endl << endl << xuc2 << endl << endl;
    cout << "Data::X( T*, 1, 2 ) : ";
    printSize(x12);
    cout << endl << endl << x12 << endl << endl;

    /* Now changes to x12 should be reflected in the other variables */
    x12.setZero();

    cout << "-----" << endl << endl << "x12.setZero() " << endl << endl << "-----" << endl;

    cout << "allData" << endl << endl << allData << endl << endl;
    cout << "x1" << endl << endl << x1 << endl << endl;
    cout << "xc2" << endl << endl << xc2 << endl << endl;
    cout << "xuc2" << endl << endl << xuc2 << endl << endl;
    cout << "x12" << endl << endl << x12 << endl << endl;
    return 0;
}

具体来说,它会生成以下输出(如预期):
Data::asMatrix( T* ): 10 x 5

 0 10 20 30 40
 1 11 21 31 41
 2 12 22 32 42
 3 13 23 33 43
 4 14 24 34 44
 5 15 25 35 45
 6 16 26 36 46
 7 17 27 37 47
 8 18 28 38 48
 9 19 29 39 49

Data::X( T*, 1 )    : 2 x 5

 5 15 25 35 45
 6 16 26 36 46

Data::Xc( T*, 2 )   : 2 x 2

20 25
21 26

Data::C( T*, 2 )    : 5 x 2

20 25
21 26
22 27
23 28
24 29

Data::X( T*, 1, 2 ) : 2 x 1

25
26

-----

x12.setZero() 

-----
allData

 0 10 20 30 40
 1 11 21 31 41
 2 12 22 32 42
 3 13 23 33 43
 4 14 24 34 44
 5 15  0 35 45
 6 16  0 36 46
 7 17 27 37 47
 8 18 28 38 48
 9 19 29 39 49

x1

 5 15  0 35 45
 6 16  0 36 46

xc2

20  0
21  0

xuc2

20  0
21  0
22 27
23 28
24 29

x12

0
0

问题在于维度的编译时检查似乎没有起作用。在 data_getter.h 中,您可能会注意到我在维度上放了一堆 static_assert。这可能看起来有点过头了,但我想确保表达式确实执行了编译时操作,以便我们可以对维度进行检查。如果它们是动态表达式,那么所有大小都将为-1。
然而,尽管所有 static_assert 都通过了,似乎没有对引用进行任何编译时检查。例如,如果我们在测试程序中更改以下行:
Ref<Matrix<T, (n + m) * c, w>> allData = Data::asMatrix(vec.data());

转换为

Ref<Matrix<T, (n + m) * c + 1, w>> allData = Data::asMatrix(vec.data());

代码编译成功,但运行时崩溃。这似乎表明Ref正在丢弃维度。那么我应该如何定义这些变量呢?
可能会想到的一个想法是将这些返回值也定义为auto。然而,这是Eigen文档明确反对的,因为如果我们最终在循环中使用输出,它可能会导致表达式被反复评估。这就是我使用Ref的原因。此外,既然我们在编译时知道大小,明确声明大小似乎是一个好主意...
那么这是Ref的一个bug吗?对于所有访问器方法产生的变量,我应该使用什么类型?

@chtz 更新了更多信息... - bremen_matt
@ggael 这是一个 bug 吗,还是预期行为? - bremen_matt
这是一个缺点。 - ggael
@ggael,你认为这个功能将来会被添加吗?还是需要进行重大重新设计才行? - bremen_matt
还有没有其他方法可以实现上述行为而不使用 Ref - bremen_matt
显示剩余16条评论
1个回答

3

如果您在评论中错过了... @ggael说Eigen Ref在编译时不会检查维度。


1
这是固定的链接:https://bitbucket.org/eigen/eigen/commits/e1203d5ceb8e669f76fd643f4056b7e28d987077 - ggael
@ggael,请将此作为答案发布,以便您可以获得悬赏。 - bremen_matt

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接