从Eigen::Tensor创建tensorflow::Tensor

3
我该如何从Eigen :: Tensor创建一个tensorflow :: Tensor?我可以逐个复制元素,但我希望有更好的方法。
2个回答

2
没有公共API可以直接从Eigen :: Tensor创建tensorflow :: Tensor而不复制数据。但是,您可以使用以下API创建tensorflow :: Tensor并将其解释为Eigen :: TensorMap: tensorflow :: Tensor tf_tensor(tensor_constructor_args); // 对于一般情况: Eigen :: TensorMap <type_params> eigen_tensor = tf_tensor.tensor<Type,NumDims>(); // 如果您知道张量是矩阵/向量/标量,则可以使用快捷方式 Eigen :: TensorMap <type_params> eigen_matrix = tf_tensor.matrix<Type>(); Eigen :: TensorMap <type_params> eigen_vector = tf_tensor.vector<Type>(); Eigen :: TensorMap <type_params> eigen_scalar = tf_tensor.scalar<Type>();

这样可以避免复制。此外,Eigen张量和TensorMaps共享相同的API,因此您可以互换使用它们。

非常感谢,@Benoit Steiner,但这对我帮助不大。但也许我正在尝试解决一个错误的问题。我使用Eigen::Tensor的原因是可以轻松地对它们进行切片。因此,我从文件中读取所有样本,并在每次迭代中传递给模型相应的切片。我能用纯tensorflow::Tensor做到这一点吗? - Moshe Kravchik
没有简单的方法来切分一个tensorflow::Tensor。然而,TensorFlow提供了一个tf.slice操作,您可以使用它来切分您的输入数据,并在每次迭代中单独地馈送每个样本。 - Benoit Steiner
你的回答似乎展示了如何从Tensorflow::Tensor创建一个Eigen::TensorMap,但问题是相反的...如何从现有的Eigen::Tensor创建一个Tensorflow::Tensor(这也是我遇到的同样问题,导致我来到这个问题)。 - FJC

0

这里有一个可能会有用的例子:

Eigen::Tensor<float, 3> TensorflowToEigen(const tensorflow::Tensor& tensor) {
  const tensorflow::TensorShape dims = tensor.shape();
  Eigen::Tensor<float, 3, Eigen::RowMajor> rm_tensor =
      tensor.tensor<float, 3>();
  // Change to ColMajor. swap_layout changes the ordering of dimensions, so we
  // shuffle them back.
  Eigen::Tensor<float, 3> cm_tensor =
      rm_tensor.swap_layout().shuffle(Eigen::make_index_list(2, 1, 0));
  return cm_tensor;
}

tensorflow::Tensor EigenToTensorflow(const Eigen::Tensor<float, 3>& tensor) {
  const Eigen::DSizes<int64_t, 3> dims = tensor.dimensions();
  // Change to RowMajor. swap_layout  changes the ordering of dimensions, so we
  // shuffle them back.
  Eigen::Tensor<float, 3, Eigen::RowMajor> rm_tensor =
      tensor.swap_layout().shuffle(Eigen::make_index_list(2, 1, 0));
  tensorflow::Tensor tf_tensor(
      tensorflow::DT_FLOAT,
      tensorflow::TensorShape({dims[0], dims[1], dims[2]}));
  tf_tensor.tensor<float, 3>() = rm_tensor;
  return tf_tensor;
}

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接