在Keras的fit_generator中,“shuffle”参数是什么意思?

11

我手动构建了一个数据生成器,每次调用会产生一个[input, target]元组。 我将生成器设置为每个周期随机打乱训练样本的顺序。然后我使用fit_generator调用我的generator,但是对这个函数中的“shuffle”参数感到困惑:

fit_generator(self, generator, steps_per_epoch=None, epochs=1, verbose=1, callbacks=None, validation_data=None, validation_steps=None, class_weight=None, max_queue_size=10, workers=1, use_multiprocessing=False, shuffle=True, initial_epoch=0)

来自Keras API

shuffle:是否在每个 epoch 开始的时候,对 batch 的顺序进行洗牌。只能使用 keras.utils.Sequence 实例。

我原本认为“shuffle”应该由generator来完成。如果我的自定义generator决定在每次迭代中输出哪个批次,它怎么能洗牌批次的顺序呢?

1个回答

12
正如您引用的文档所述,shuffle参数仅适用于实现keras.utils.Sequence的生成器。
如果您使用的是“简单”生成器(例如keras.preprocessing.image.ImageDataGenerator或自定义的非Sequence生成器),则该生成器实现了一个方法来返回单个批次(使用yield - 您可以在此问题中了解更多信息)。因此,只有生成器本身控制返回哪个批次。 keras.utils.Sequence被引入以支持多处理:

Sequence是更安全的进行多处理的方法。这种结构保证网络每个时期仅对每个样本进行一次训练,而这不是生成器的情况。

为此,您需要实现一个方法,通过批索引返回批:__getitem__(self, idx)。如果启用shuffle参数,则将随机顺序的索引传递给__getitem__方法。
但是,您也可以将其设置为false,并通过实现on_epoch_end方法来进行自己的洗牌。

1
如果我在_fit-generator_中使用自己的非序列生成器并设置_shuffle=True_,会发生什么? - Tu Bui
1
什么也没有发生。如果你查看源代码:https://github.com/keras-team/keras/blob/3b444513b52cf05e7d40f2ffdb7ab7283bb2ce06/keras/engine/training.py#L2168,则当你的生成器是一个Sequence时,该参数才会被使用。 - Mark Loyman
在方法__getitem_(...)中,有没有办法知道哪个工作线程正在获取特定批次(由“idx”标识)?提出这个问题的动机是我想将工作负载分散到构建单独数据集(例如负样本)的2个工作线程上。理想情况下,这应该在on_epoch_end中完成,但可能不会被多进程执行? - kawingkelvin

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接