在C++中动态创建一个多维数组

6

在C++中动态创建多维数组的好方法(理解惯用法/最佳实践)是什么?

例如,假设我有三个整数 whd,我想创建一个数组 MyEnum my_array[w][h][d]。(当然,w、h 和 d 不是在编译时已知的。)

最好使用嵌套的 std::vector 还是使用 new 或其他什么方式?

额外问题:是否也可以动态设置数组的维度?

1个回答

9
一般来说,嵌套std::vector不是个好主意。通常更好的做法是分配内存以作为您的多维数组的连续块,然后索引它,就好像它是多维的。该内存块可以通过new进行分配,但除非您需要对其进行精确控制(自定义分配器),否则我建议使用单个std::vector
创建一个类来管理这种资源并在其中动态设置维数并不难。组织这样一个类的良好方法是跟踪分配的内存、每个维度的大小和每个维度的步进模式。步幅描述了必须增加多少元素才能沿着给定维度到达下一个元素。
这允许高效的索引(只需指针算术),以及非常有效的重塑:只要元素数量不变,这仅需要更改形状和步幅数组。
示例: 这里是一个非常基本的类,用于存储这样一个动态多维数组的double。它按行优先顺序存储数据,这意味着最后一个索引变化得最快。因此,对于一个2D数组,第一行被连续存储,然后是第二行,依此类推。
如果需要,您可以重新调整数组的形状,更改维数。还展示了一个基本的元素访问operator[]。该类没有其他花哨的功能,但您可以扩展它以提供任何所需的功能,例如迭代器、数据的数学操作、I/O运算符等。
/*! \file dynamic_array.h
 * Basic dynamic multi-dimensional array of doubles.
 */

#ifndef DYNAMIC_ARRAY_H
#define DYNAMIC_ARRAY_H

#include <vector>
#include <numeric>
#include <functional>

class
dynamic_array
{
    public:
        dynamic_array(const std::vector<int>& shape)
            : m_nelem(std::accumulate(shape.begin(), shape.end(),
                        1, std::multiplies<int>()))
            , m_ndim(shape.size())
            , m_shape(shape)
        {
            compute_strides();
            m_data.resize(m_nelem, 0.0);
        }

        ~dynamic_array()
        {
        }

        const double& operator[](int i) const
        {
            return m_data.at(i);
        }

        double& operator[](int i)
        {
            return m_data.at(i);
        }

        const double& operator[](const std::vector<int>& indices) const
        {
            auto flat_index = std::inner_product(
                    indices.begin(), indices.end(),
                    m_strides.begin(), 0);
            return m_data.at(flat_index);
        }

        double& operator[](const std::vector<int>& indices)
        {
            auto flat_index = std::inner_product(
                    indices.begin(), indices.end(),
                    m_strides.begin(), 0);
            return m_data.at(flat_index);
        }

        void reshape(const std::vector<int>& new_shape)
        {
            auto new_nelem = std::accumulate(
                    new_shape.begin(), new_shape.end(),
                    1, std::multiplies<int>());
            if (new_nelem != m_nelem) {
                throw std::invalid_argument("dynamic_array::reshape(): "
                        "number of elements must not change.");
            }
            m_nelem = new_nelem;
            m_ndim = new_shape.size();
            m_shape = new_shape;
            compute_strides();
        }

        const std::vector<int>& shape() const
        {
            return m_shape;
        }

        const std::vector<int>& strides() const
        {
            return m_strides;
        }

        int ndim() const
        {
            return m_ndim;
        }

        int nelem() const
        {
            return m_nelem;
        }

    private:
        int m_ndim;
        int m_nelem;
        std::vector<int> m_shape;
        std::vector<int> m_strides;
        std::vector<double> m_data;

        void compute_strides()
        {
            m_strides.resize(m_ndim);
            m_strides.at(m_ndim - 1) = 1;
            std::partial_sum(m_shape.rbegin(),
                    m_shape.rend() - 1,
                    m_strides.rbegin() + 1,
                    std::multiplies<int>());
        }
};

#endif // include guard

这是一个基本功能的演示。
/*! \file test.cc
 * Basic test of the dynamic_array class.
 */
#include "dynamic_array.h"
#include <iostream>

int main(int /* argc */, const char * /* argv */[])
{
    dynamic_array arr({2, 3});
    std::cout << "Shape: { ";
    for (auto& each : arr.shape())
        std::cout << each << " ";
    std::cout << "}" << std::endl;

    std::cout << "Strides: { ";
    for (auto& each : arr.strides())
        std::cout << each << " ";
    std::cout << "}" << std::endl;

    // Reshape array, changing number of dimensions, but
    // keeping number of elements constant.
    arr.reshape({6});
    std::cout << "Shape: { ";
    for (auto& each : arr.shape())
        std::cout << each << " ";
    std::cout << "}" << std::endl;

    // Verify that the stride pattern has now also changed.
    std::cout << "Strides: { ";
    for (auto& each : arr.strides())
        std::cout << each << " ";
    std::cout << "}" << std::endl;

    return 0;
}

假设定义类的文件与test.cc在同一目录下,您可以使用g++ -std=c++14 -o test test.cc编译测试程序。


谢谢,这就是我想要的答案! - WIP

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接