不使用iostream保存C++11随机生成器的状态

7
什么是在不使用iostream接口的情况下存储C++11随机生成器状态的最佳方法?我想像这里列出的第一个替代方法[1]那样做。然而,这种方法要求对象包含PRNG状态并且仅包含PRNG状态。特别地,如果实现使用pimpl模式(至少可能会在重新加载状态时由于加载了错误数据而导致应用程序崩溃),或者与生成序列无关的PRNG对象存在更多的状态变量,则该方法将失败。
对象的大小是由实现定义的:

我缺少像这样的成员函数

  1. size_t state_size();
  2. const size_t* get_state() const;
  3. void set_state(size_t n_elems,const size_t* state_new);

(1)应返回随机生成器状态数组的大小。

(2)应返回指向状态数组的指针。该指针由PRNG管理。

(3)应将缓冲区std::min(n_elems,state_size())从state_new指向的缓冲区复制。

这种接口允许更灵活的状态操作。或者有没有任何PRNG的状态无法表示为无符号整数数组?

[1]使用流来保存boost随机生成器状态的更快替代方法


你可能会发现这个问题有所帮助。除此之外,我认为在没有关于底层实现的一些知识的情况下,序列化RNG(或任何对象)是不可能的。如果真的有可能,那么它可能涉及到一些...奇怪的黑客技巧。 - More Axes
1
g++ 和 Ideone 的 sizeof std::mt19937 返回值分别比存储 Mersenne Twister 状态数组所需的空间多 8 和 4 个字节。如果这是单个值,那么我敢打赌它是一个指针(我假设您使用的是 64 位系统而 Ideone 不是),在序列化期间需要相应地处理它。如果它是非指针值(或两个值在 g++ 的情况下),则可以安全地按原样序列化。 - More Axes
3
你也可以尝试另一种方法:创建一个包装器类,存储种子和调用次数,只存储这些信息。在反序列化时,使用存储的种子为RNG生成随机数,并调用相应的次数。由于我不确定C++11的分布函数是否总是从它们被调用的RNG请求相同数量的伪随机字节,所以这种方法可能有些奇怪,但值得一试。虽然反序列化过程会比较慢,但序列化速度将非常快。 - More Axes
@MoreAxes 这是一个大小值:_UIntType _M_x[state_size];size_t _M_p; _M_p 是状态大小,也会被写入序列化。生成1e6个随机数需要多长时间? - user877329
根据Boost的RNG性能分析,生成一百万个随机数应该需要大约5毫秒的时间,这是由Boost实现的。我希望标准实现不会比这更差,但有时候奇怪的事情也会发生。你在哪里找到这个“_UIntType _M_x[state_size]; size_t _M_p;”的代码? - More Axes
显示剩余5条评论
1个回答

0

我已经为我在OP评论中提到的方法编写了一个简单(-ish)的测试。显然它没有经过实战考验,但是这个想法已经被代表了 - 你应该能够从这里开始。

由于读取的字节数比序列化整个引擎要少得多,因此这两种方法的性能实际上可能是可比较的。测试这个假设以及进一步的优化留给读者作为练习。

#include <iostream>
#include <random>
#include <chrono>
#include <cstdint>
#include <fstream>

using namespace std;

struct rng_wrap
{
    // it would also be advisable to somehow
    // store what kind of RNG this is,
    // so we don't deserialize an mt19937
    // as a linear congruential or something,
    // but this example only covers mt19937

    uint64_t seed;
    uint64_t invoke_count;
    mt19937 rng;

    typedef mt19937::result_type result_type;

    rng_wrap(uint64_t _seed) :
        seed(_seed),
        invoke_count(0),
        rng(_seed)
    {}

    rng_wrap(istream& in) {
        in.read(reinterpret_cast<char*>(&seed), sizeof(seed));
        in.read(reinterpret_cast<char*>(&invoke_count), sizeof(invoke_count));
        rng = mt19937(seed);
        rng.discard(invoke_count);
    }

    void discard(unsigned long long z) {
        rng.discard(z);
        invoke_count += z;
    }

    result_type operator()() {
        ++invoke_count;
        return rng();
    }

    static constexpr result_type min() {
        return mt19937::min();
    }

    static constexpr result_type max() {
        return mt19937::max();
    }
};

ostream& operator<<(ostream& out, rng_wrap& wrap)
{
    out.write(reinterpret_cast<char*>(&(wrap.seed)), sizeof(wrap.seed));
    out.write(reinterpret_cast<char*>(&(wrap.invoke_count)), sizeof(wrap.invoke_count));
    return out;
}

istream& operator>>(istream& in, rng_wrap& wrap)
{
    wrap = rng_wrap(in);
    return in;
}

void test(rng_wrap& rngw, int count, bool quiet=false)
{
    uniform_int_distribution<int> integers(0, 9);
    uniform_real_distribution<double> doubles(0, 1);
    normal_distribution<double> stdnorm(0, 1);

    if (quiet) {
        for (int i = 0; i < count; ++i)
            integers(rngw);

        for (int i = 0; i < count; ++i)
            doubles(rngw);

        for (int i = 0; i < count; ++i)
            stdnorm(rngw);
    } else {
        cout << "Integers:\n";
        for (int i = 0; i < count; ++i)
            cout << integers(rngw) << " ";

        cout << "\n\nDoubles:\n";
        for (int i = 0; i < count; ++i)
            cout << doubles(rngw) << " ";

        cout << "\n\nNormal variates:\n";
        for (int i = 0; i < count; ++i)
            cout << stdnorm(rngw) << " ";
        cout << "\n\n\n";
    }
}


int main(int argc, char** argv)
{
    rng_wrap rngw(123456790ull);

    test(rngw, 10, true);  // this is just so we don't start with a "fresh" rng
    uint64_t seed1 = rngw.seed;
    uint64_t invoke_count1 = rngw.invoke_count;

    ofstream outfile("rng", ios::binary);
    outfile << rngw;
    outfile.close();

    cout << "Test 1:\n";
    test(rngw, 10);  // test 1

    ifstream infile("rng", ios::binary);
    infile >> rngw;
    infile.close();

    cout << "Test 2:\n";
    test(rngw, 10);  // test 2 - should be identical to 1

    return 0;
}

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接