有没有可能创建一个 input_fn
无限生成随机数据,以便在Tensorflow的Estimator API中使用?
基本上这就是我想要的:
def create_input_fn(function_to_generate_one_sample_with_label):
def _input_fn():
### some code ###
return feature_cols, labels
我希望可以像下面这样,使用一个 Estimator
实例调用该函数:
def data_generator():
features = ... generate a (random) feature vector ...
lablel = ... create suitable label ...
return features, labels
input_fn = create_input_fn(data_generator)
estimator.train(input_fn=input_fn, steps=ANY_NUMBER_OF_STEPS)
关键是要能够根据需要训练尽可能多的步骤,即时生成所需的训练数据。这是为了模型调优目的,能够尝试不同复杂度的训练数据,以便了解模型适应训练数据的能力。
编辑 正如jkm建议的那样,我尝试使用了一个实际的生成器,例如:
def create_input_fn(function, batch_size=100):
def create_generator():
while True:
features = ... generate <batch_size> feature vectors ...
lablel = ... create <batch_size> labels ...
yield features, label
g = create_generator()
def _input_fn():
return next(g)
return _input_fn
我不得不添加一个batch size才能运行它。现在它可以运行了,但是input_fn
只被调用一次,因此它不会生成任何新的数据。它只对生成的第一个<batch_size>
个样本进行训练。是否有办法告诉估计器使用提供的input_fn
来刷新数据?
num_samples
步骤重新初始化迭代器,使用sess.run(iterator.initializer)
获取新的随机值是有意义的。 - bodokaiser