TensorFlow的数据集API返回的大小不是常量。

3

我正在使用TensorFlow的数据集 API。 使用简单的案例测试我的代码。下面展示了我使用的简单代码。问题是,当数据集大小较小时,似乎从数据集API返回的大小不一致。我确定有正确的方法来处理它。但即使我阅读了该页面和教程中的所有函数,我也找不到解决方法。

import numpy as np
import tensorflow as tf

data_source = tf.zeros([24, 200, 64, 64, 1]) #[number_of_video, steps, pixel_w, pixel_h, channel]
dataset = tf.contrib.data.Dataset.from_tensor_slices(data_source)
dataset = dataset.shuffle(buffer_size=100)
dataset = dataset.batch(16)
dataset = dataset.repeat()

iterator = tf.contrib.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes)
next_element = iterator.get_next()
training_init_op = iterator.make_initializer(dataset)

with tf.Session() as sess:
    sess.run(training_init_op)
    next_elem = next_element.eval()
    print(np.shape(next_elem))
    next_elem = next_element.eval()
    print(np.shape(next_elem))
    next_elem = next_element.eval()
    print(np.shape(next_elem))
    next_elem = next_element.eval()
    print(np.shape(next_elem))
    next_elem = next_element.eval()
    print(np.shape(next_elem))
    next_elem = next_element.eval()
    print(np.shape(next_elem))
    next_elem = next_element.eval()
    print(np.shape(next_elem))

这个数据集是灰度视频。总共有24个视频序列,步长都是200。每帧大小为64x64像素,只有一个通道。我将批处理大小设置为16,缓冲区大小为100。但是代码的结果是:

(16, 200, 64, 64, 1)
(8, 200, 64, 64, 1)
(16, 200, 64, 64, 1)
(8, 200, 64, 64, 1)
(16, 200, 64, 64, 1)
(8, 200, 64, 64, 1)
(16, 200, 64, 64, 1)

返回结果如下:

视频的返回大小为16或8。我猜是因为原始数据大小只有24,当到达数据结尾时,API只返回剩余的部分。

但我不明白。我也将缓冲区大小设置为100。那意味着缓冲区应该提前填充小数据集。从缓冲区中,API应选择批量大小为16的next_element。

当我在tensorflow中使用队列类型的API时,我没有这个问题。无论原始数据的大小是多少,迭代器到达数据集末尾总会有一个时刻。我想知道其他人如何使用此API解决这个问题。

2个回答

6
尝试在调用batch()之前调用repeat():
data_source = tf.zeros([24, 200, 64, 64, 1]) #[number_of_video, steps, pixel_w, pixel_h, channel]
dataset = tf.contrib.data.Dataset.from_tensor_slices(data_source)
dataset = dataset.shuffle(buffer_size=100)
dataset = dataset.repeat()
dataset = dataset.batch(16)

我得到的结果:
(16, 200, 64, 64, 1)
(16, 200, 64, 64, 1)
(16, 200, 64, 64, 1)
(16, 200, 64, 64, 1)
(16, 200, 64, 64, 1)
(16, 200, 64, 64, 1)
(16, 200, 64, 64, 1)

0
您可以使用以下代码来解决问题:
batched = dataset.apply(tf.contrib.data.batch_and_drop_remainder(128))

1
回答时,请提供更多解释。 - Bram Vanroy

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接