I got my Keras model working with a tf.Dataset using the following code:
# Initialize batch generators(returns tf.Dataset)
batch_train = build_features.get_train_batches(batch_size=batch_size)
# Create TensorFlow Iterator object
iterator = batch_train.make_one_shot_iterator()
dataset_inputs, dataset_labels = iterator.get_next()
# Create Model
logits = .....(some layers)
model = keras.models.Model(inputs=dataset_inputs, outputs=logits)
# Train network
model.compile(optimizer=train_opt, loss=model_loss, target_tensors=[dataset_labels])
model.fit(epochs=epochs, steps_per_epoch=num_batches, callbacks=callbacks, verbose=1)
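But when I try to pass the validation_data parameter to model.fit, it tells me that I cannot use it with the generator. Is there a way to use validation data while using tf.Dataset?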
For example, in TensorFlow I could do the following:
# initialize batch generators
batch_train = build_features.get_train_batches(batch_size=batch_size)
batch_valid = build_features.get_valid_batches(batch_size=batch_size)
# create TensorFlow Iterator object
iterator = tf.data.Iterator.from_structure(batch_train.output_types,
                                           batch_train.output_shapes)
# create two initialization ops to switch between the datasets
init_op_train = iterator.make_initializer(batch_train)
init_op_valid = iterator.make_initializer(batch_valid)
Then just run sess.run(init_op_train) and sess.run(init_op_valid) to switch between the datasets.
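The switching pattern described here (one iterator that is re-pointed at either dataset before each pass) can be illustrated without TensorFlow. The following is only a framework-free sketch: the class `SwitchableIterator`, the helper `run_epoch`, and the toy batches are hypothetical stand-ins and are not part of the TensorFlow API.

```python
class SwitchableIterator:
    """Hypothetical stand-in for a reinitializable iterator (not TensorFlow API)."""

    def __init__(self):
        self._it = None

    def initialize(self, dataset):
        # Plays the role of sess.run(init_op_train) / sess.run(init_op_valid).
        self._it = iter(dataset)

    def get_next(self):
        # Raises StopIteration once the current dataset is exhausted.
        return next(self._it)


def run_epoch(iterator, dataset):
    """Point the iterator at `dataset` and consume it exactly once."""
    iterator.initialize(dataset)
    seen = []
    while True:
        try:
            seen.append(iterator.get_next())
        except StopIteration:
            break
    return seen


# Toy batches standing in for the two tf.Dataset objects.
batch_train = [("x0", 0), ("x1", 1), ("x2", 0)]
batch_valid = [("v0", 1)]

iterator = SwitchableIterator()
history = []
for epoch in range(2):
    train_seen = run_epoch(iterator, batch_train)  # training pass
    valid_seen = run_epoch(iterator, batch_valid)  # validation pass
    history.append((len(train_seen), len(valid_seen)))
```

Each epoch consumes the training batches first and then the validation batches through the same iterator object, which is the behaviour the two initializer ops provide in a TensorFlow session.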
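I tried implementing a callback that does just that (switch to the validation set, predict, and switch back), but it tells me I can't use model.predict inside a callback.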
Can someone help me get validation working with Keras + tf.Dataset?
Edit: incorporating the answer into the code
So, thanks to the selected answer, what finally worked for me is:
# Initialize batch generators(returns tf.Dataset)
batch_train = # returns tf.Dataset
batch_valid = # returns tf.Dataset
# Create TensorFlow Iterator objects and wrap each one in a generator
itr_train = make_iterator(batch_train)
itr_valid = make_iterator(batch_valid)
# Create Model
logits = # the keras model
model = keras.models.Model(inputs=dataset_inputs, outputs=logits)
# Train network
model.compile(optimizer=train_opt, loss=model_loss, target_tensors=[dataset_labels])
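# validation_steps is the number of validation batches drawn per epoch;
# workers=0 runs the generators on the main thread.
model.fit_generator(
    generator=itr_train, validation_data=itr_valid, validation_steps=batch_size,
    epochs=epochs, steps_per_epoch=num_batches, callbacks=cbs, verbose=1, workers=0)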
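# Assumes standalone Keras; with tf.keras, import tensorflow.keras.backend instead
from keras import backend as K
def make_iterator(dataset):
    # Build the iterator ops once, outside the loop
    iterator = dataset.make_one_shot_iterator()
    next_val = iterator.get_next()
    with K.get_session().as_default() as sess:
        while True:
            # Evaluate the next batch and split it into inputs and labels
            *inputs, labels = sess.run(next_val)
            yield inputs, labels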
This does not introduce any additional overhead.
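One detail worth checking in the setup above: fit_generator expects its generators to loop indefinitely, and validation_steps is the number of batches it draws from the validation generator each epoch. The sketch below illustrates both points without TensorFlow, assuming a finite batch source; the names `endless_batches`, `steps_for`, and `valid_batches` are hypothetical helpers for illustration, and the toy data stands in for what sess.run(next_val) would return.

```python
import math


def endless_batches(make_batches):
    """Yield (inputs, labels) forever, restarting the source when it runs out.

    `make_batches` is a hypothetical zero-argument callable that returns a
    fresh iterable of (inputs, labels) pairs.
    """
    while True:
        for inputs, labels in make_batches():
            yield inputs, labels


def steps_for(num_samples, batch_size):
    """Number of batches needed to see every sample once."""
    return math.ceil(num_samples / batch_size)


# Toy validation data: 10 samples in batches of 4, so the last batch is short.
samples = list(range(10))
batch_size = 4


def valid_batches():
    for start in range(0, len(samples), batch_size):
        chunk = samples[start:start + batch_size]
        yield chunk, [s % 2 for s in chunk]


validation_steps = steps_for(len(samples), batch_size)
itr_valid = endless_batches(valid_batches)

# Drawing validation_steps batches covers the validation set exactly once.
one_pass = [next(itr_valid) for _ in range(validation_steps)]
```

With a value computed this way, each epoch's validation pass sees every validation sample once, and the wrapper keeps yielding batches on later epochs instead of running dry.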