在我的模型中(使用Tensorflow Estimator),我希望数据输入更加动态。例如,在训练期间提供不同的数据(在不同的训练步骤中,向模型提供不同的数据)。
以下代码是一个示例。get_input_fn提供了input_fn,_parse函数处理特征,_parse中的_py_process_line_pair进行了精确处理。但我不知道如何将global_step(或相关参数)传递到_py_process_line_pair中。
def _parse(self, features):
def _py_process_line_pair(src_wds, trg_wds, cur_training_steps):
.... (some processing depends on cur_training_steps)
return np.array(src_ids, np.int32), np.array(trg_ids, np.int32)
src_wds, trg_wds = features['src_wds'], features['trg_wds']
src_ids, trg_ids = tf.py_func(
_py_process_line_pair,
[src_wds, trg_wds],
[tf.int32, tf.int32])
src_ids.set_shape(
[self.flags.max_src_len])
trg_ids.set_shape(
[self.flags.max_trg_len])
output = {
'src_ids': src_ids,
'trg_ids': trg_ids,
}
return output
def get_input_fn(self, is_training, input_files, num_cpu_threads):
def input_fn(params):
batch_size = params['batch_size']
if is_training:
d = tf.data.Dataset.from_tensor_slices(tf.constant(tf.gfile.Glob(input_files)))
d = d.repeat()
d = d.shuffle(buffer_size=len(input_files))
cycle_length = min(num_cpu_threads, len(input_files))
d = d.apply(
tf.data.experimental.parallel_interleave(
tf.data.TFRecordDataset,
sloppy=is_training,
cycle_length=cycle_length))
d = d.shuffle(buffer_size=100)
else:
d = tf.data.TFRecordDataset(input_files)
d = d.apply(
tf.data.experimental.map_and_batch(
lambda record: self._parse(tf.parse_single_example(record, self.feature_set)),
batch_size=batch_size,
num_parallel_batches=num_cpu_threads,
drop_remainder=is_training))
return d
return input_fn