在TensorFlow中如何展开(Flatten)数据集

4

我正在尝试将TensorFlow中的数据集转换为具有多个单值张量的数据集。目前数据集的结构如下:

[12 43 64 34 45 2 13 54] [34 65 34 67 87 12 23 43] [23 53 23 1 5] ...

转换后,它应该是这样的:
[12] [43] [64] [34] [45] [2] [13] [54] [34] [65] [34] [67] [87] [12] ...

我的最初想法是在数据集上使用flat_map,然后使用reshapeunstack将每个张量转换为张量列表:

output_labels = self.dataset.flat_map(convert_labels)

...

def convert_labels(tensor):
    id_list = tf.unstack(tf.reshape(tensor, [-1, 1]))
    return tf.data.Dataset.from_tensors(id_list)

然而每个张量的形状只有部分已知(即 (?, 1)),这就是为什么无法执行 unstack 操作的原因。是否有任何方法可以在不显式迭代它们的情况下仍然“连接”不同的张量?

1个回答

4
你的解决方案非常接近,但是Dataset.flat_map()需要一个返回tf.data.Dataset对象而不是张量列表的函数。幸运的是,Dataset.from_tensor_slices()方法完全符合你的用例,因为它可以将张量分成可变数量的元素:
output_labels = self.dataset.flat_map(tf.data.Dataset.from_tensor_slices)

请注意,tf.contrib.data.unbatch() 转换实现了相同的功能,并在当前TensorFlow主分支中具有稍微更高效的实现(将包含在1.9版本中):
output_labels = self.dataset.apply(tf.contrib.data.unbatch())

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接