flatten_parameters()是做什么的?

16

我看到许多Pytorch示例在RNN的前向函数中使用flatten_parameters

self.rnn.flatten_parameters()

我看到了这个RNNBase,它写到:

重置参数数据指针,以便它们可以使用更快的代码路径

那是什么意思?


1
我认为它只是将所有权重压缩到一块连续的内存中。 - shyam padia
1个回答

9

这可能不是对你问题的完整回答。但是,如果你查看flatten_parameters的源代码,你会注意到它调用了_cudnn_rnn_flatten_weight

...
NoGradGuard no_grad;
torch::_cudnn_rnn_flatten_weight(...)
...

这个函数是执行任务的函数。你会发现它实际上所做的是将模型的权重复制到一个vector<Tensor>中(检查params_arr的声明):

  // Slice off views into weight_buf
  std::vector<Tensor> params_arr;
  size_t params_stride0;
  std::tie(params_arr, params_stride0) = get_parameters(handle, rnn, rnn_desc, x_desc, w_desc, weight_buf);

  MatrixRef<Tensor> weight{weight_arr, static_cast<size_t>(weight_stride0)},
                    params{params_arr, params_stride0};

并且权重复制中

  // Copy weights
  _copyParams(weight, params);

此外,需要注意的是他们在文档中明确称之为“重置”,更新了weights的原始指针以使用params的新指针,通过对orig_param.set_(new_param.view_as(orig_param));进行就地操作.set_(下划线是他们表示就地操作的符号)。
  // Update the storage
  for (size_t i = 0; i < weight.size(0); i++) {
    for (auto orig_param_it = weight[i].begin(), new_param_it = params[i].begin();
         orig_param_it != weight[i].end() && new_param_it != params[i].end();
         orig_param_it++, new_param_it++) {
      auto orig_param = *orig_param_it, new_param = *new_param_it;
      orig_param.set_(new_param.view_as(orig_param));
    }
  }

根据n2798 (C++0x的草案)

©ISO/IECN3092

23.3.6 类模板 vector

向量是一种支持随机访问迭代器的序列容器。此外,它支持(平摊)常数时间在末尾插入和删除操作;在中间插入和删除需要线性时间。存储管理由系统自动处理,但可以提供提示以提高效率。 向量的元素是连续存储的,这意味着如果 v 是一个类型为 <T,Allocator> 的向量,其中 T 是除 bool 之外的某种类型,则对于所有 0 <= n < v.size(),它遵循 identity&v[n] == &v[0] + n


在某些情况

用户警告:RNN模块权重不是单个连续的内存块的一部分。这意味着它们需要在每次调用时进行压缩,可能会大大增加内存使用量。要再次压缩权重,请调用flatten_parameters()

他们明确建议在代码警告中拥有一个连续的内存块。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接