如何在一个HDF5数据文件中读取批次进行训练?

6

我有一个大小为(21760, 1, 33, 33)的hdf5训练数据集。21760是训练样本的总数。我想使用大小为128的小批量训练数据来训练网络。

我想问:

如何使用tensorflow每次从整个数据集中提取128个小批量训练数据?

3个回答

12

如果你的数据集太大而无法像 keveman 建议的那样导入内存,你可以直接使用 h5py 对象:

import h5py
import tensorflow as tf

data = h5py.File('myfile.h5py', 'r')
data_size = data['data_set'].shape[0]
batch_size = 128
sess = tf.Session()
train_op = # tf.something_useful()
input = # tf.placeholder or something
for i in range(0, data_size, batch_size):
    current_data = data['data_set'][position:position+batch_size]
    sess.run(train_op, feed_dict={input: current_data})

如果您想的话,还可以运行大量迭代并随机选择一批:

import random
for i in range(iterations):
    pos = random.randint(0, int(data_size/batch_size)-1) * batch_size
    current_data = data['data_set'][pos:pos+batch_size]
    sess.run(train_op, feed_dict={inputs=current_data})

或者按顺序:

for i in range(iterations):
    pos = (i % int(data_size / batch_size)) * batch_size
    current_data = data['data_set'][pos:pos+batch_size]
    sess.run(train_op, feed_dict={inputs=current_data})

你可能希望编写一些更复杂的代码,随机地遍历所有数据,但要跟踪已使用的批次,以便不会比其他批次使用更频繁。在对整个训练集进行完整运行后,重新启用所有批次并重复执行。


1
这种方法在逻辑上似乎是正确的,但我使用它并没有得到任何积极的结果。我的最佳猜测是:使用上面的代码示例1,在每次迭代中,网络都会重新训练,忘记了之前学习的所有内容。因此,如果我们每次迭代获取30个样本或批次,那么在每个循环/迭代中,只有30个数据样本被使用,然后在下一个循环中,所有内容都被覆盖。 - rocksyne

7
你可以将hdf5数据集读入numpy数组,并将numpy数组的切片提供给TensorFlow模型。类似以下的伪代码可以工作:
import numpy, h5py
f = h5py.File('somefile.h5','r')
data = f.get('path/to/my/dataset')
data_as_array = numpy.array(data)
for i in range(0, 21760, 128):
  sess.run(train_op, feed_dict={input:data_as_array[i:i+128, :, :, :]})

1
谢谢。但是当训练迭代次数i很大,例如100000时,如何进行喂养? - karl_TUM
如果你只有21760个训练样本,那么你只有21760/128个不同的小批次。你需要在i循环外编写一个外部循环,并在训练数据集上运行多个时期。 - keveman
1
我有一个疑惑点。当原始数据被洗牌并提取小批量时,这是否意味着小批量的数量超过了“21760/128”? - karl_TUM

2

alkamen's 的方法在逻辑上似乎是正确的,但我使用它没有得到任何积极的结果。我最好的猜测是:在上面的代码示例1中,在每次迭代中,网络都会重新训练,忘记了之前学到的所有内容。因此,如果我们每次迭代获取30个样本或批次,那么在每次循环/迭代中,只有30个数据样本被使用,然后在下一次循环中,所有内容都被覆盖。

以下是这种方法的屏幕截图

Training always starting afresh

如图所示,损失和准确性总是重新开始。如果有人能分享可能的解决方法,我将不胜感激。


2
你在标记其他用户时,我的名字拼写为'n'而不是'm' =) - alkanen
@rocksyne,我遇到了类似的问题,每个批次后网络都没有学习。你解决了吗? - CAta.RAy
1
很遗憾,我在TensorFlow方面没有什么好运气。我已经创建了一个Github Gist代码来帮助您理解。所以我转而使用keras。我构建了一个自定义生成器来批量获取数据。请在此处查找(https://gist.github.com/rocksyne/a4022afd7a5aaacdfb873218dba21d0c)。这个函数被称为Kera的fit_generator函数(https://www.pyimagesearch.com/2018/12/24/how-to-use-keras-fit-and-fit_generator-a-hands-on-tutorial/)如果您能分享更多关于您正在做的事情的信息,我可以更好地理解您并提供更具针对性的答案。 - rocksyne
@rocksyne。感谢您的回复。我目前遇到了类似的情况。我在keras中使用我的网络很好,但是当我尝试在tensorflow中实现相同的网络时,它就无法工作。我甚至尝试创建了最简单的网络(我的问题),但没有成功。我确定我的代码有问题,但我无法找出原因。 - CAta.RAy
@CAta.RAy 我已经查看了你的代码,从表面上看,一切都很好。我说从表面上看,因为我没有你的数据集来尝试这些代码。我将查看此当前线程中的代码,并尝试找到使用tensorflow的解决方法。如果我找到了,我一定会告诉你。 - rocksyne
显示剩余5条评论

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接