TensorFlow数据集API使图形协议缓冲区文件大小翻倍。

4

摘要: 使用新的tf.contrib.data.Dataset功能会使我的图形protobuf文件大小翻倍,且我无法在Tensorboard中可视化该图形。

详细信息:

我正在尝试使用新的TensorFlow tf.contrib.data.Dataset功能与tf.contrib.learn.Experiment框架。我的输入数据被定义为输入函数,返回特征和标签张量。

如果我使用tf.train.slice_input_producer函数创建输入函数,则生成的graph.pbtxt文件大小为620兆字节,.meta文件大小约为165兆字节(完整代码在此处)。

def train_inputs():
    with tf.name_scope('Training_data'):
        x = tf.constant(mnist.train.images.reshape([-1, 28, 28, 1]))
        y = tf.constant(mnist.train.labels)
        sliced_input = tf.train.slice_input_producer(
            tensor_list=[x, y], shuffle=True)
        return tf.train.shuffle_batch(
            sliced_input, batch_size=batch_size,
            capacity=10000, min_after_dequeue=batch_size*10)

现在,如果我使用新的tf.contrib.data.Dataset.from_tensor_slices来创建输入函数,就像下面的代码块中一样(完整代码在此处),那么我的graph.pbtxt文件的大小会增加一倍,达到1.3G,而.meta文件的大小会增加一倍,达到330M。
def train_inputs():
    with tf.name_scope('Training_data'):
        images = mnist.train.images.reshape([-1, 28, 28, 1])
        labels = mnist.train.labels
        dataset = tf.contrib.data.Dataset.from_tensor_slices(
            (images, labels))
        dataset = dataset.repeat(None)  # Infinite
        dataset = dataset.shuffle(buffer_size=10000)
        dataset = dataset.batch(batch_size)
        iterator = dataset.make_one_shot_iterator()
        next_example, next_label = iterator.get_next()
        return next_example, next_label

现在由于graph.pbtxt文件太大,TensorBoard需要很长时间来解析此文件,我无法通过可视化方式调试我的模型图形。我在数据集文档中发现,大小增加的原因是:"数组的内容将被多次复制"解决方案是使用占位符。然而,在这种情况下,我需要使用活动会话将numpy数组馈送到占位符以初始化迭代器:
sess.run(iterator.initializer, feed_dict={features_placeholder: features, labels_placeholder: labels})

然而,使用tf.contrib.learn.Experiment框架时,这似乎超出了我的控制范围。

我应该如何使用Experiment框架初始化迭代器的初始值?或者找到一个使用Dataset API而不增加图形大小的解决方法?

1个回答

3

我使用 tf.train.SessionRunHook 解决了我的问题。我创建了一个 SessionRunHook 对象,在会话创建后初始化迭代器:

class IteratorInitializerHook(tf.train.SessionRunHook):
    def __init__(self):
        super(IteratorInitializerHook, self).__init__()
        self.iterator_initiliser_func = None

    def after_create_session(self, session, coord):
        self.iterator_initiliser_func(session)

创建数据集迭代器时设置了初始化函数:

iterator_initiliser_hook.iterator_initiliser_func = \
    lambda sess: sess.run(
        iterator.initializer,
        feed_dict={images_placeholder: images,
                   labels_placeholder: labels})

我将钩子对象传递给tf.contrib.learn.Experimenttrain_monitorseval_hooks参数。

生成的graph.pbtxt文件现在仅有500K,而.meta文件仅有244K。

完整示例请点击此处。


不错,这也解决了我的问题。但似乎是一个变通方法? - Ohad Meir

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接