Can the input_fn know the current training step in TensorFlow Estimator?


In my model (built with the TensorFlow Estimator), I would like the data input to be more dynamic. For example, I want to feed the model different data at different training steps.

The code below is an example. get_input_fn returns the input_fn, the _parse function processes the features, and _py_process_line_pair inside _parse does the actual processing. But I don't know how to pass the global_step (or a related parameter) into _py_process_line_pair.

  def _parse(self, features):
    def _py_process_line_pair(src_wds, trg_wds, cur_training_steps):
      .... (some processing depends on cur_training_steps)
      return np.array(src_ids, np.int32), np.array(trg_ids, np.int32)

    src_wds, trg_wds = features['src_wds'], features['trg_wds']
    src_ids, trg_ids = tf.py_func(
        _py_process_line_pair,
        [src_wds, trg_wds],
        [tf.int32, tf.int32])
    src_ids.set_shape([self.flags.max_src_len])
    trg_ids.set_shape([self.flags.max_trg_len])
    output = {
        'src_ids': src_ids,
        'trg_ids': trg_ids,
    }
    return output

  def get_input_fn(self, is_training, input_files, num_cpu_threads):

    def input_fn(params):
        batch_size = params['batch_size']
        if is_training:
            d = tf.data.Dataset.from_tensor_slices(tf.constant(tf.gfile.Glob(input_files)))
            d = d.repeat()
            d = d.shuffle(buffer_size=len(input_files))
            cycle_length = min(num_cpu_threads, len(input_files))
            d = d.apply(
                tf.data.experimental.parallel_interleave(
                    tf.data.TFRecordDataset,
                    sloppy=is_training,
                    cycle_length=cycle_length))
            d = d.shuffle(buffer_size=100)
        else:
            d = tf.data.TFRecordDataset(input_files)

        d = d.apply(
            tf.data.experimental.map_and_batch(
                lambda record:  self._parse(tf.parse_single_example(record, self.feature_set)),
                batch_size=batch_size,
                num_parallel_batches=num_cpu_threads,
                drop_remainder=is_training))
        return d
    return input_fn
1 Answer


This is quite straightforward: you just need to call tf.train.get_or_create_global_step() inside your _parse function to fetch the global_step tensor from the graph.

Here is a working example:

import tensorflow as tf
import numpy as np

# Synth dataset with 10 values
x = np.arange(10)

# This function replaces 'x' by the current step
def step_dependant_preprocessing(x):
    global_step = tf.train.get_or_create_global_step()
    return global_step

# Maps step_dependant_preprocessing
def input_fn():
    dataset = tf.data.Dataset.from_tensor_slices((x))
    dataset = dataset.map(step_dependant_preprocessing)
    return dataset

def model_fn(features, labels, mode, params=None):
    # Get the global step
    global_step = tf.train.get_or_create_global_step()

    # Since this example doesn't use an optimizer, we need to increment
    # the global step manually.
    increment_global_step = tf.assign_add(global_step, 1)

    # Logging hook to verify that the global step inside the input fn has 
    # the same value as the one here.
    logging_hook = tf.train.LoggingTensorHook({"true_global_step": global_step, 
                                               "input_fn_global_step": features}, 
                                              every_n_iter=1)

    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.TRAIN,
        loss=tf.constant(0.0), # Needed to use estimator.train()
        training_hooks=[logging_hook],
        train_op=increment_global_step)

estimator = tf.estimator.Estimator(model_fn=model_fn)

estimator.train(input_fn)

...

# Output

INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmppuwe9hxh/model.ckpt.
INFO:tensorflow:loss = 0.0, step = 1
INFO:tensorflow:input_fn_global_step = 1, true_global_step = 1
INFO:tensorflow:input_fn_global_step = 2, true_global_step = 2 (0.007 sec)
INFO:tensorflow:input_fn_global_step = 3, true_global_step = 3 (0.002 sec)
INFO:tensorflow:input_fn_global_step = 4, true_global_step = 4 (0.001 sec)
INFO:tensorflow:input_fn_global_step = 5, true_global_step = 5 (0.001 sec)
INFO:tensorflow:input_fn_global_step = 6, true_global_step = 6 (0.001 sec)
INFO:tensorflow:input_fn_global_step = 7, true_global_step = 7 (0.001 sec)
INFO:tensorflow:input_fn_global_step = 8, true_global_step = 8 (0.001 sec)
INFO:tensorflow:input_fn_global_step = 9, true_global_step = 9 (0.001 sec)
INFO:tensorflow:input_fn_global_step = 10, true_global_step = 10 (0.001 sec)
INFO:tensorflow:Saving checkpoints for 11 into /tmp/tmppuwe9hxh/model.ckpt.
INFO:tensorflow:Loss for final step: 0.0.
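
Applied to the _parse function from the question, the same idea would mean fetching the global-step tensor inside _parse and feeding it to tf.py_func as an extra input, so that _py_process_line_pair receives its current value as a plain NumPy integer. A rough sketch of what that could look like (same TF 1.x APIs as above, untested):

def _parse(self, features):
  def _py_process_line_pair(src_wds, trg_wds, cur_training_steps):
    # cur_training_steps arrives here as a NumPy int64 holding the value
    # of the global-step variable at the time this element is processed.
    # ... (processing that depends on cur_training_steps) ...
    return np.array(src_ids, np.int32), np.array(trg_ids, np.int32)

  # Fetch the global-step tensor from the graph and pass it to py_func
  # alongside the feature tensors.
  global_step = tf.train.get_or_create_global_step()

  src_wds, trg_wds = features['src_wds'], features['trg_wds']
  src_ids, trg_ids = tf.py_func(
      _py_process_line_pair,
      [src_wds, trg_wds, global_step],
      [tf.int32, tf.int32])
  # ... (set_shape calls and the output dict as in the question) ...
  return output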
