TensorFlow Dataset API中make_initializable_iterator和make_one_shot_iterator的区别

15

我想知道make_initializable_iteratormake_one_shot_iterator之间的区别。
1. TensorFlow文档中提到:"one-shot" 迭代器目前不支持重新初始化。这是什么意思?
2. 以下两个代码片段等价吗? 使用 make_initializable_iterator

iterator = data_ds.make_initializable_iterator()
data_iter = iterator.get_next()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
for e in range(1, epoch+1):
    sess.run(iterator.initializer)
    while True:
        try:
            x_train, y_train = sess.run([data_iter])
            _, cost = sess.run([train_op, loss_op], feed_dict={X: x_train,
                                                               Y: y_train})
        except tf.errors.OutOfRangeError:   
            break
sess.close()

使用make_one_shot_iterator

iterator = data_ds.make_one_shot_iterator()
data_iter = iterator.get_next()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
for e in range(1, epoch+1):
    while True:
        try:
            x_train, y_train = sess.run([data_iter])
            _, cost = sess.run([train_op, loss_op], feed_dict={X: x_train,
                                                               Y: y_train})
        except tf.errors.OutOfRangeError:   
            break
sess.close()
1个回答

12

如果你想使用相同的代码进行培训和验证,你可能想要使用相同的迭代器,但初始化为指向不同的数据集; 像以下这样:

def _make_batch_iterator(filenames):
    dataset = tf.data.TFRecordDataset(filenames)
    ...
    return dataset.make_initializable_iterator()


filenames = tf.placeholder(tf.string, shape=[None])
iterator = _make_batch_iterator(filenames)

with tf.Session() as sess:
    for epoch in range(num_epochs):

        # Initialize iterator with training data
        sess.run(iterator.initializer,
                 feed_dict={filenames: ['training.tfrecord']})

        _train_model(...)

        # Re-initialize iterator with validation data
        sess.run(iterator.initializer,
                 feed_dict={filenames: ['validation.tfrecord']})

        _validate_model(...)

使用单次迭代器,您无法像这样重新初始化它。


你能解释一下initializable和reinitializable迭代器之间的区别吗? - Nima
我们能否“重新加载”数据集?例如:features,labels = trainDataset.make_one_shot_iterator().get_next(),graph = fn(features,labels)。然后在训练之后,重新加载 features,labels = TestDataset.xxx().get_next 吗?因为我认为它是一个不同的数据集而不是重新初始化。 - Leighton
@Leighton 这可能意味着您需要创建额外的数据集图表,这通常不是您想要的。 - Alex Kreimer

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接