tf.train.MonitoredTrainingSession和来自Dataset的可重新初始化迭代器

12

似乎MonitoredTrainingSession在第一次调用.run(..)之前执行一些操作(日志记录?),这意味着当我执行以下操作时:

train_data = reader.traindata() # returns a tf.contrib.data.Dataset
it = tf.contrib.data.Iterator.from_structure(train_data.output_types, train_data.output_shapes)
init_train = it.make_initializer(train_data)
ne = it.get_next()
ts = tf.train.MonitoredTrainingSession(checkpoint_dir=save_path)

... no calls to ts.run ...

ts.run(init_train)

这会产生错误:

FailedPreconditionError (see above for traceback): GetNext() failed because the iterator has not been initialized. Ensure that you have run the initializer operation for this iterator before getting the next element

看起来MonitoredTrainingSession在运行我提供的操作之前进行了一些操作,导致无法与数据集中的可重新初始化迭代器一起使用。

我确定我错过了什么,并且很想知道是什么:-)


部分回答自己,我通过使用以下方式成功解决了问题: .ts._coordinated_creator.tf_sess.run(init_train) 但这种方法非常像是一种hack,而不是推荐的做法? - Viktor Ogeman
1个回答

8

看起来Tensorflow中还没有直接的解决方案。是的,他们没有为Dataset API提供完全的支持,这很奇怪。

原因是当从检查点加载时,监视会话跳过运行init_op。因此,迭代器初始化程序应该是一个局部变量。

当前的解决方法建议在这个问题中给出 - https://github.com/tensorflow/tensorflow/issues/12859

scaffold = tf.train.Scaffold(local_init_op=tf.group(tf.local_variables_initializer(),
                                     init_train))
with tf.train.MonitoredTrainingSession(scaffold=scaffold, 
                                       checkpoint_dir=checkpoint_dir) as sess:
    while not sess.should_stop():
        sess.run(train_op)

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接