如何将Python数据生成器转换为Tensorflow张量?

4
我有一个数据生成器,可以产生训练图像。我想通过使用Python数据生成器将数据馈入Tensorflow模型,但我不知道如何将该生成器转换为Tensorflow张量。我正在寻找类似于Keras的fit_generator()函数的解决方案。
谢谢!

1
具体是什么错误?你尝试了什么?fit_generator需要一个生成器作为参数,你还需要什么? - parsethis
1个回答

10

tf.data.Dataset.from_generator() 方法提供了一种将 Python 生成器转换为 tf.Tensor 对象的方式,这些对象评估自生成器中的每个连续元素。

假设您有一个简单的生成器,它生成元组(但也可以生成列表或 NumPy 数组):

def g():
  yield 1, 10.0, "foo"
  yield 2, 20.0, "bar"
  yield 3, 30.0, "baz"
您可以使用tf.data API将生成器首先转换为tf.data.Dataset,然后转换为tf.data.Iterator,最后转换为tf.Tensor对象的元组。
dataset = tf.data.Dataset.from_generator(g, (tf.int32, tf.float32, tf.string))

iterator = dataset.make_one_shot_iterator()

int_tensor, float_tensor, str_tensor = iterator.get_next()

你可以使用int_tensorfloat_tensorstr_tensor作为您的TensorFlow模型的输入。有关更多想法,请参阅tf.data程序员指南


生成器的yield(next)中组件的图像数据(imageio.core.util.Image)对应的tf类型是什么?是否有tf.image这样的类型? - Yu Shen
嗨,@mrry,我想知道是否可以使用tf.data.Dataset.from_generator与Python生成器一起使用,该生成器将会产生(单个图像),(Python列表)? - luvwinnie
这是TF1代码,需要在TF2中运行tf.compat.v1.data.Dataset.from_generator。我相信TF2有更好的解决方案,但目前我正在使用它。 - MrMartin
实际上,TensorFlow 2的解决方案似乎只是将 iterator = dataset.make_one_shot_iterator() 替换为 iterator = iter(dataset) - MrMartin

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接