Keras中的fit_generator函数:批大小(batch_size)在哪里指定?

6

嗨,我不理解keras fit_generator文档。

我希望我的困惑是合理的。有一个批处理大小(batch_size)的概念,也有分批次训练的概念。使用model_fit(),我指定了批处理大小为128。

对我来说,这意味着我的数据集将每次输入128个样本,从而大大减轻内存压力。只要我有时间等待,它就可以训练1亿个样本数据集。毕竟,keras一次只“处理”128个样本。对吧?

但我非常怀疑,仅仅指定批处理大小并不能完全实现我的目标。仍然会使用大量内存。对于我的目标,我需要每次以128个示例进行训练。

所以,我猜这就是fit_generator所做的事情。我真的想问,为什么"batch_size"实际上没有按其名称所建议的那样工作呢?

更重要的是,如果需要fit_generator,我在哪里指定批处理大小?文档说要无限循环。生成器只循环一次。我如何每次循环128个样本,并记住上次停止的位置,下次当keras要求下一个批次的起始行号时,就能够调用它(第一批完成后将是第129行)。

2个回答

2

您需要在生成器内部处理批次大小。这里有一个生成随机批次的示例:

import numpy as np
data = np.arange(100)
data_lab = data%2
wholeData = np.array([data, data_lab])
wholeData = wholeData.T

def data_generator(all_data, batch_size = 20):

    while True:        

        idx = np.random.randint(len(all_data), size=batch_size)

        # Assuming the last column contains labels
        batch_x = all_data[idx, :-1]
        batch_y = all_data[idx, -1]

        # Return a tuple of (Xs,Ys) to feed the model
        yield(batch_x, batch_y)

print([x for x in data_generator(wholeData)])

0
首先,Keras的batch_size确实非常有效。如果你在使用GPU进行工作,你应该知道使用Keras时模型可能会非常庞大,特别是当你使用循环单元时。如果你在使用CPU,整个程序会加载到内存中,batch size对内存的影响不会太大。如果你使用fit()方法,整个数据集可能会被加载到内存中,Keras会在每一步产生批次。很难预测将使用多少内存。
至于fit_generator()方法,你应该构建一个Python生成器函数(使用yield而不是return),在每一步中生成一个批次。yield应该在一个无限循环中(我们经常使用while true:...)。
你有一些代码来说明你的问题吗?

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接