Keras模型.fit()与tf.dataset API + 验证数据的使用方法

21

我已经通过以下代码,让我的 Keras 模型能够使用 tf.Dataset 进行工作:

# Initialize batch generators(returns tf.Dataset)
batch_train = build_features.get_train_batches(batch_size=batch_size)

# Create TensorFlow Iterator object
iterator = batch_train.make_one_shot_iterator()
dataset_inputs, dataset_labels = iterator.get_next()

# Create Model
logits = .....(some layers)
keras.models.Model(inputs=dataset_inputs, outputs=logits)

# Train network
model.compile(optimizer=train_opt, loss=model_loss, target_tensors=[dataset_labels])
model.fit(epochs=epochs, steps_per_epoch=num_batches, callbacks=callbacks, verbose=1)

但是当我尝试将validation_data参数传递给模型时,fit告诉我不能与生成器一起使用。在使用tf.Dataset时,是否有办法同时使用验证数据?

例如在TensorFlow中,我可以这样做:

# initialize batch generators
batch_train = build_features.get_train_batches(batch_size=batch_size)
batch_valid = build_features.get_valid_batches(batch_size=batch_size)

# create TensorFlow Iterator object
iterator = tf.data.Iterator.from_structure(batch_train.output_types,
                                           batch_train.output_shapes)

# create two initialization ops to switch between the datasets
init_op_train = iterator.make_initializer(batch_train)
init_op_valid = iterator.make_initializer(batch_valid)

那么只需使用 sess.run(init_op_train)sess.run(init_op_valid) 来在数据集之间切换。

我尝试了实现一个回调函数来完成这个任务(切换到验证集、预测和返回),但它告诉我不能在回调函数中使用model.predict。

有人能帮助我使用Keras+Tf.Dataset使验证集正常工作吗?

编辑:将答案合并到代码中:

因此,最终对我有用的是,感谢所选的答案:

# Initialize batch generators(returns tf.Dataset)
batch_train = # returns tf.Dataset
batch_valid = # returns tf.Dataset

# Create TensorFlow Iterator object and wrap it in a generator
itr_train = make_iterator(batch_train)
itr_valid = make_iterator(batch_train)

# Create Model
logits = # the keras model
keras.models.Model(inputs=dataset_inputs, outputs=logits)

# Train network
model.compile(optimizer=train_opt, loss=model_loss, target_tensors=[dataset_labels])
model.fit_generator(
    generator=itr_train, validation_data=itr_valid, validation_steps=batch_size,
    epochs=epochs, steps_per_epoch=num_batches, callbacks=cbs, verbose=1, workers=0)

def make_iterator(dataset):
    iterator = dataset.make_one_shot_iterator()
    next_val = iterator.get_next()

    with K.get_session().as_default() as sess:
        while True:
            *inputs, labels = sess.run(next_val)
            yield inputs, labels

这不会引入任何额外开销


2
在你的更改之后,你如何将dataset_inputs输入到模型中?我不明白keras.models.Model(inputs=dataset_inputs, outputs=logits)这一行的含义,我假设这是“model”变量的内容,请问你能否提供完整的代码?我遇到了完全相同的问题,但似乎无法应用你的代码。谢谢! - josesuero
@mark rofail,我认为这行是错误的,应该改成batch_valid:itr_valid = make_iterator(batch_train)。 - Robert Lugg
2个回答

3
我使用fit_genertor解决了这个问题。我在这里找到了解决方案(链接),并应用了@Dat-Nguyen的解决方案。
您只需要创建两个迭代器,一个用于训练,一个用于验证,然后创建自己的生成器,从数据集中提取批次并以(batch_data,batch_labels)的形式提供数据。最后,在model.fit_generator中传递train_generator和validation_generator。

所以我必须将TensorFlow迭代器包装在Python生成器中,如下所示:iterator = ds.make_one_shot_iterator() while True: next_val = iterator.get_next() yield sess.run(next_val) - Mark Rofail
据我所知,您无法访问小批量指标,但是定义自定义损失函数并在编译模型时将其包含在指标中应该可以做到这一点。Keras应该会给出每个时期的平均AUC。这是我想出的AUC损失函数: `from sklearn.metrics import roc_auc_scoredef roc_auc(y_true, y_pred): return roc_auc_score(y_true, y_pred)` - Mark Rofail
抱歉,我的意思是“我无法访问验证”,而不是“我可以”。可能我会为此开一个新帖子。 - W. Sam
如果在 epoch 结束时提供了验证数据,Keras 将自动计算所有指标。 - Mark Rofail
1
@Dat-Nguyen的解决方案已更改为直接将迭代器传递给model.fit而不是fit_generator。它应该支持TensorFlow 1.9,但在我的情况下没有起作用,会出现“AttributeError:'Iterator'对象没有属性'ndim'”的错误。 - W. Sam
显示剩余3条评论

2
连接可重置迭代器到Keras模型的方法是插入一个同时返回x和y值的迭代器:
sess = tf.Session()
keras.backend.set_session(sess) 

x = np.random.random((5, 2))
y = np.array([0, 1] * 3 + [1, 0] * 2).reshape(5, 2) # One hot encoded
input_dataset = tf.data.Dataset.from_tensor_slices((x, y))

# Create your reinitializable_iterator and initializer
reinitializable_iterator = tf.data.Iterator.from_structure(input_dataset.output_types, input_dataset.output_shapes)
init_op = reinitializable_iterator.make_initializer(input_dataset)

#run the initializer
sess.run(init_op) # feed_dict if you're using placeholders as input

# build keras model and plug in the iterator
model = keras.Model.model(...)
model.compile(...)
model.fit(reinitializable_iterator,...)

如果您也有一个验证数据集,最简单的做法是创建一个单独的迭代器并将其插入到validation_data参数中。请确保定义您的steps_per_epoch和validation_steps,因为它们无法被推断出来。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接