在Keras中多次调用“fit”的含义是什么?

52

我一直在处理数百GB的图像,为此我创建了一个训练函数,它会将这些图像分块成4GB的大小,并对每个块调用fit。我担心我只是在最后一块上进行训练而不是整个数据集。

实际上,我的伪代码看起来像这样:

DS = lazy_load_400GB_Dataset()
for section in DS:
    X_train = section.images
    Y_train = section.classes

    model.fit(X_train, Y_train, batch_size=16, nb_epoch=30)
我知道API和Keras论坛上说这将在整个数据集上训练,但我无法直观地理解为什么网络不会仅在最后的训练块上重新学习。希望能得到一些帮助来理解这个问题。 最好, 乔

3
创建一个自定义生成器类并在fit_generator中使用它可以避免多次调用fit的问题。 - Dr. Snoopy
2个回答

51

17
该语句"是的,连续调用 fit 会逐步训练模型"似乎是正确的,但是当我使用连续调用 fit 训练我的模型时,我看到的是:第一次调用需要一段时间才能达到我数据集通常的 val_acc: 0.9x,每次后续调用都比这个初始调用更快,但是每次调用 fit 时,我都会看到 val_acc 下降到约为 0.05 的水平,然后再回升到90%左右。如果它是在逐步训练模型,为什么会发生这种情况呢? - alexexchanges
2
我也想听到这个问题的答案。 - Naveen Reddy Marthala

39
对于无法放入内存的数据集,Keras Documentation FAQ section 中有答案。

您可以使用 model.train_on_batch(X, y)model.test_on_batch(X, y) 进行批量训练。请参阅 models documentation

另外,您可以编写一个生成器,产生批量的训练数据,并使用方法 model.fit_generator(data_generator, samples_per_epoch, nb_epoch)

您可以在我们的 CIFAR10 example 中看到批量训练的实际操作。

因此,如果您想要按照您目前的方式迭代数据集,您应该使用 model.train_on_batch 并自己处理批次大小和迭代。

需要注意的是,您应该确保每次训练模型时,使用的样本顺序在每个epoch之后都被打乱。您编写的示例代码似乎没有对数据集进行打乱。您可以在这里这里阅读更多关于打乱数据的内容。


3
我知道我们可以使用train_on_batch,但我仍然不明白为什么OP的原始代码不起作用。fit()是否在每次数据馈送的迭代中更新模型? - ymeng
2
鉴于下面@curlyhairedgenius的回答,目前还不清楚多次调用model.fitmodel.train_on_batch之间的区别是什么。是否有区别? - Olshansky
3
model.fit 管理您提供的数据集的输入和输出,将它们分成批次并逐步训练每个批次,同时报告进度并支持在训练过程中执行自定义回调。而 model.train_on_batch 则仅采用一个输入和输出的批次,并针对单个步骤训练模型。 - Makis Tsantekidis

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接