在Tensorflow的input_fn中生成无限随机的训练数据

3

有没有可能创建一个 input_fn 无限生成随机数据,以便在Tensorflow的Estimator API中使用?

基本上这就是我想要的:

def create_input_fn(function_to_generate_one_sample_with_label):
    def _input_fn():
        ### some code ###
        return feature_cols, labels

我希望可以像下面这样,使用一个 Estimator 实例调用该函数:

def data_generator():
    features = ... generate a (random) feature vector ...
    lablel = ... create suitable label ...
    return features, labels

input_fn = create_input_fn(data_generator)
estimator.train(input_fn=input_fn, steps=ANY_NUMBER_OF_STEPS)

关键是要能够根据需要训练尽可能多的步骤,即时生成所需的训练数据。这是为了模型调优目的,能够尝试不同复杂度的训练数据,以便了解模型适应训练数据的能力。


编辑 正如jkm建议的那样,我尝试使用了一个实际的生成器,例如:

def create_input_fn(function, batch_size=100):  
    def create_generator():
        while True:
            features = ... generate <batch_size> feature vectors ...
            lablel = ... create <batch_size> labels ...
            yield features, label
    g = create_generator()
    def _input_fn():
        return next(g)
    return _input_fn

我不得不添加一个batch size才能运行它。现在它可以运行了,但是input_fn只被调用一次,因此它不会生成任何新的数据。它只对生成的第一个<batch_size>个样本进行训练。是否有办法告诉估计器使用提供的input_fn来刷新数据?

2个回答

1
我认为您可以使用最新的Tf Dataset API来获得所需的行为,您需要tensorflow>=1.2.0。
# Define number of samples and input shape for each iteration
# you can set minval or maxval as per you data distribution and label distributon requirements
 num_samples = [20000,]
 input_shape = [32, 32, 3]
dataset = tf.contrib.data.Dataset.from_tensor_slices((tf.random_normal([num_examples+input_shape]),  tf.random_uniform([num_samples], minval=0, maxval=5)))
# Define batch_size
batch_size = 128
dataset = dataset.batch(batch_size)
# Define iterator
iterator = dataset.make_initializable_iterator()
# Get one batch
next_example, next_label = iterator.get_next()
# calculate loss from the estimator fucntion you are using
estimator_loss = some_estimator(next_example, next_label)
# Set number of Epochs here
num_epochs = 100
for _ in range(num_epochs):
    sess.run(iterator.initializer)
    while True:
        try:
            _loss = sess.run(estimator_loss)
        except tf.errors.OutOfRangeError:
            break

我认为每隔 num_samples 步骤重新初始化迭代器,使用 sess.run(iterator.initializer) 获取新的随机值是有意义的。 - bodokaiser

0

警告 - 我本人没有使用过Tensorflow,我只是根据API文档进行操作。

话虽如此 - 如果没有什么陷阱,你应该能够做到你需要的。只需将生成器作为一个生成器(yield特征和标签而不是返回它们),并将整个生成过程放入一个无限循环中。例如:

def data_generator():
    while True:
        #do generatey things here
        yield feature, labels

这个函数可以被重复调用,每次调用都会生成新的值。


1
谢谢您的回复,它有所帮助,但仍然没有达到我想要的效果。请查看已编辑的问题 :) - Sindre Tosse

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接