我需要使用TF2 Keras模型将形状为32x32的输入分为3个类别。我的训练集有7000个示例。
>>> X_train.shape # (7000, 32, 32)
>>> Y_train.shape # (7000, 3)
每个类别的示例数量不同(例如,class_0有约2500个示例,而class_1有约800个示例等)。
我想使用tf.data API创建一个数据集对象,该对象返回训练数据的批次,并且每个类别的示例数由[n_0,n_1,n_2]指定。
我希望这些n_i来自每个类别的样本均从X_train、Y_train中随机抽取(且放回)。
例如,如果我调用get_batch([100, 150, 125]),它应该返回100个来自class_0的X_batch随机样本,150个来自class_1,以及125个来自class_2。
如何使用TF2.0 Data API实现这一点,以便可以将其用于训练Keras模型?