Tensorflow中的批次迭代是如何工作的?

3

我正在尝试在我的数据上重用PTB语言模型,但缺乏对Tensorflow的了解,不知道它如何处理训练数据的批次迭代。以下是我对训练期间批次迭代的理解:

while epoch <= maxepoch do
  for minibatch in data_iterator() do
    model.forward(minibatch)
    (...)
  end
end

这不能再简单了,是吧?其他框架中也有类似的功能,但Tensorflow没有 :) 以下是PTB语言模型教程中的小批量函数示例:

def ptb_producer(raw_data, batch_size, num_steps, name=None):
    with tf.name_scope(name, "PTBProducer", [raw_data, batch_size, num_steps]):
        raw_data = tf.convert_to_tensor(raw_data, name="raw_data", dtype=tf.int32)

        data_len = tf.size(raw_data)
        batch_len = data_len // batch_size
        data = tf.reshape(raw_data[0 : batch_size * batch_len],
                                            [batch_size, batch_len])

        epoch_size = (batch_len - 1) // num_steps
        assertion = tf.assert_positive(
                epoch_size,
                message="epoch_size == 0, decrease batch_size or num_steps")
        with tf.control_dependencies([assertion]):
            epoch_size = tf.identity(epoch_size, name="epoch_size")

        i = tf.train.range_input_producer(epoch_size, shuffle=False).dequeue()
        x = tf.strided_slice(data, [0, i * num_steps], [batch_size, (i + 1) * num_steps])
        x.set_shape([batch_size, num_steps])
        y = tf.strided_slice(data, [0, i * num_steps + 1], [batch_size, (i + 1) * num_steps + 1])
        y.set_shape([batch_size, num_steps])
        return x, y

一旦调用此函数,它将返回x个输入和y个目标。我在这里没有看到Python迭代器的迹象,但是有一个对tf.strided_slice的调用,该函数使用由tf.train.range_input_producer生成的i索引,因此应该模拟对数据的滑动窗口。然而,在训练之前仅调用一次此函数,那么它如何迭代我的数据呢?这一点不清楚。能否有人解释一下这种"魔法"和完全晦涩的Tensorflow机制?
1个回答

2
"神奇"的地方在于调用tf.train.range_input_producer的那一行代码中:
i = tf.train.range_input_producer(epoch_size, shuffle=False).dequeue()

这段代码创建了一个操作,它从保存0..epoch_size-1整数的队列中弹出值。换句话说,它遍历了范围0..epoch_size-1。


是的,这似乎与直觉相反。以下是在tensorflow中使用队列的简单可运行示例:

index = tf.train.range_input_producer(10, shuffle=False).dequeue()

with tf.Session() as sess:
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(coord=coord)

  for i in range(15):
    print(sess.run(index))

  coord.request_stop()
  coord.join(threads)

运行后,您应该看到从09的值,然后再看到从04的5个值。请注意sess.run评估相同的张量index,但每次都会得到不同的值。可以添加进一步依赖于index的操作,它们将使用新的index值进行评估。
另请注意,队列在另一个线程中操作,因此要使用tf.train.range_input_producer,需要启动一个Coordinator并生成多个线程(最后停止它们)。如果尝试没有Coordinator运行相同的示例,则 sess.run(index) 将阻止脚本执行。
您可以在此示例中进行调试,例如设置shuffle=True等。
回到PTB生产者片段:
i = tf.train.range_input_producer(epoch_size, shuffle=False).dequeue()
x = tf.strided_slice(data, [0, i*num_steps], [batch_size, (i+1)*num_steps])
x.set_shape([batch_size, num_steps])
y = tf.strided_slice(data, [0, i*num_steps+1], [batch_size, (i+1)*num_steps+1])
y.set_shape([batch_size, num_steps])

现在应该很清楚,即使xy被定义为简单张量,它们实际上是data的切片的迭代器。所有线程工作都由tf.train.Supervisor处理。因此,调用一个依赖于xy的优化操作将自动获取新批次。

建议阅读:


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接