TensorFlow:数据集.train.next_batch如何定义?

19

我正在尝试学习TensorFlow,研究以下示例:https://github.com/aymericdamien/TensorFlow-Examples/blob/master/notebooks/3_NeuralNetworks/autoencoder.ipynb

然后我对下面的代码有一些问题:

for epoch in range(training_epochs):
    # Loop over all batches
    for i in range(total_batch):
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        # Run optimization op (backprop) and cost op (to get loss value)
        _, c = sess.run([optimizer, cost], feed_dict={X: batch_xs})
    # Display logs per epoch step
    if epoch % display_step == 0:
        print("Epoch:", '%04d' % (epoch+1),
              "cost=", "{:.9f}".format(c))

因为mnist只是一个数据集,mnist.train.next_batch到底是什么意思?dataset.train.next_batch是如何定义的呢?

谢谢!

1个回答

29

mnist对象是从tf.contrib.learn模块中定义的read_data_sets()函数返回的。这里实现了mnist.train.next_batch(batch_size)方法,它返回一个由两个数组组成的元组,第一个表示batch_size个MNIST图像的批次,第二个表示对应这些图像的batch-size标签的批次。

图像以2-D NumPy数组的形式返回,大小为[batch_size, 784](因为MNIST图像中有784个像素),标签以1-D NumPy数组的形式返回,大小为[batch_size](如果使用one_hot=False调用了read_data_sets()),或者以2-D NumPy数组的形式返回,大小为[batch_size, 10](如果使用one_hot=True调用了read_data_sets())。

10
值得一提的是,next_batch 会在每个 epoch 结束后重新打乱样本。你可以通过 DataSet._index_in_epoch 来跟踪当前处于哪个 epoch 中,例如 mnist.train._index_in_epoch - Yibo Yang
@YiboYang 那么,使用next_batch()函数进行训练时,是否意味着不能将所有的训练数据都输入进去?我是Tensorflow的新手,请见谅如果这个问题看起来很傻。谢谢。 - Loochie

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接