通用的 C++ 多维迭代器

12
在我的当前项目中,我正在处理一个多维数据结构。底层文件是顺序存储的(即一个巨大的数组,没有向量嵌套)。使用这些数据结构的算法需要知道各个维度的大小。
我想知道是否已经以通用的方式定义了一个多维迭代器类,并且是否有任何标准或首选方法来解决这个问题。
目前,我只是使用一个线性迭代器,其中包含一些额外的方法,返回每个维度的大小和第一部分中有多少个维度。我不喜欢它的原因是因为例如无法合理地使用std::distance(即仅返回整个结构的距离,而不是每个维度的距离)。
在大多数情况下,我将以线性方式访问数据结构(从第一维开始到结束->下一维+...等等),但是知道何时一个维度“结束”会很好。我不知道如何在这种方法中仅使用operator*(),operator+()和operator==()来做到这一点。
向量嵌套的方法不受欢迎,因为我不想拆分文件。此外,算法必须在具有不同维度的结构上操作,因此很难进行泛化(或者可能有一种方法?)。
Boost multi_array也有相同的问题(多个“级别”的迭代器)。
我希望这不会太模糊或抽象。任何正确方向的提示将不胜感激。
我再次自己寻找解决方案,重新访问了boost::multi_array。事实证明,可以使用它们在数据上生成子视图,但同时也在顶层使用直接迭代器并隐含地“展平”数据结构。然而,已实现的版本的multi_array不适合我的需求,因此我可能会自己实现一个(可以处理后台文件缓存),它与其他multi_array兼容。
一旦实现完成,我将再次更新。

1
看起来你需要自己实现多维迭代器。不要仅限于运算符:你可以使用命名方法查询每个维度中当前位置的信息。 - Serge Rogatch
2
这个问题很有趣。我发现提供这样的信息的唯一方法是拥有可以推断多索引(例如 {x,y,z})的方法来自扁平化索引(反之亦然)。我不认为你可以以“标准”的方式做到这一点,除非提供自己的自定义类来实现这个目的。 - coincoin
1
如果数据结构是只读的(即您从文件中加载它,然后永远不会更改它),则可以将数据结构读入您的多维数据结构和一个平面向量。让需要平面迭代的例程在平面向量上运行,需要多维访问的例程使用多维数据结构。 - Ben Braun
@coincoin,你能否提供更多关于如何处理N维数组的详细信息? - gnzlbg
@gnzlbg,你可以在code review上找到我实现的一个例子。希望它能满足你的需求。如果不行,你可以在SO上提问,我很乐意帮忙。 - coincoin
显示剩余7条评论
2个回答

2
我刚刚决定在Github上开设一个公共代码库:MultiDim Grid,这可能有助于您的需求。这是一个正在进行中的项目,所以如果您能尝试并告诉我您需要什么,我会很高兴的。
我通过codereview上的这个主题开始了这个工作。
简单来说:

MultiDim Grid提供了一个平面的一维数组,可以在多维坐标和扁平索引之间提供通用快速访问。

您可以获得容器行为,因此可以访问迭代器。

感谢你们的回答,很抱歉我回复晚了...太忙了。现在我已经自己实现了一个类似于你们项目的东西,并且目前我对它很满意。我已经很久没有碰那段代码了,因为它只是做我想要的事情,但如果我再次需要,我会查看你们的代码,也许合并力量。 - Lazarus535

1
这并不难实现。只需明确说明您的项目需要哪些功能即可。以下是一个简单的示例。
#include <iostream>
#include <array>
#include <vector>
#include <cassert>

template<typename T, int dim>
class DimVector : public std::vector<T> {
public:
    DimVector() {
        clear();
    }

    void clear() {
        for (auto& i : _sizes)
            i = 0;
        std::vector<T>::clear();
    }

    template<class ... Types>
    void resize(Types ... args) {
        std::array<int, dim> new_sizes = { args ... };
        resize(new_sizes);
    }

    void resize(std::array<int, dim> new_sizes) {
        clear();
        for (int i = 0; i < dim; ++i)
            if (new_sizes[i] == 0)
                return;
        _sizes = new_sizes;
        int realsize = _sizes[0];
        for (int i = 1; i < dim; ++i)
            realsize *= _sizes[i];
        std::vector<T>::resize(static_cast<size_t>(realsize));
    }

    decltype(auto) operator()(std::array<int, dim> pos) {
        // check indexes and compute original index
        size_t index;
        for (int i = 0; i < dim; ++i) {
            assert(0 <= pos[i] && pos[i] < _sizes[i]);
            index = (i == 0) ? pos[i] : (index * _sizes[i] + pos[i]);
        }
        return std::vector<T>::at(index);
    }

    template<class ... Types>
    decltype(auto) at(Types ... args) {
        std::array<int, dim> pos = { args ... };
        return (*this)(pos);
    }

    int size(int d) const {
        return _sizes[d];
    }


    class Iterator {
    public:
        T& operator*() const;
        T* operator->() const;
        bool operator!=(const Iterator& other) const {
            if (&_vec != &other._vec)
                return true;
            for (int i = 0; i < dim; ++i)
                if (_pos[i] != other._pos[i])
                    return true;
            return false;
        }
        int get_dim(int d) const {
            assert(0 <= d && d < dim);
            return _pos[d];
        }
        void add_dim(int d, int value = 1) {
            assert(0 <= d && d < dim);
            _pos[d] += value;
            assert(0 <= _pos[i] && _pos[i] < _vec._sizes[i]);
        }
    private:
        DimVector &_vec;
        std::array<int, dim> _pos;
        Iterator(DimVector& vec, std::array<int, dim> pos) : _vec(vec), _pos(pos) { }
    };

    Iterator getIterator(int pos[dim]) {
        return Iterator(*this, pos);
    }

private:
    std::array<int, dim> _sizes;
};

template<typename T, int dim>
inline T& DimVector<T, dim>::Iterator::operator*() const {
    return _vec(_pos);
}

template<typename T, int dim>
inline T* DimVector<T, dim>::Iterator::operator->() const {
    return &_vec(_pos);
}

using namespace std;

int main() {

    DimVector<int, 4> v;
    v.resize(1, 2, 3, 4);
    v.at(0, 0, 0, 1) = 1;
    v.at(0, 1, 0, 0) = 1;

    for (int w = 0; w < v.size(0); ++w) {
        for (int z = 0; z < v.size(1); ++z) {
            for (int y = 0; y < v.size(2); ++y) {
                for (int x = 0; x < v.size(3); ++x) {
                    cout << v.at(w, z, y, x) << ' ';
                }
                cout << endl;
            }
            cout << "----------------------------------" << endl;
        }
        cout << "==================================" << endl;
    }
    return 0;
}

待办事项清单:

  • 优化:尽可能使用 T const&
  • 优化迭代器:预计算实际索引,然后仅更改该实际索引
  • 实现 const 存取器
  • 实现 ConstIterator
  • 实现将 DimVector 序列化到/从文件的 operator>>operator<<

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接