使用tensorflow Estimator时出现"TypeError: 'Tensor' object is not iterable"错误。

7
我有一个程序生成的(无限的)数据源,并试图将其用作高级Tensorflow Estimator的输入,以训练基于图像的3D物体检测器。我设置了数据集,就像Tensorflor Estimator Quickstart中所示,并且我的dataset_input_fn返回一个特征和标签张量的元组,就像Estimator.train函数指定的那样,并且如本教程所示,但是当尝试调用train函数时出现错误:TypeError:'Tensor'对象不可迭代。我做错了什么?
    def data_generator():
        """
        Generator for image (features) and ground truth object positions (labels)

        Sample an image and object positions from a procedurally generated data source
        """
        while True:
            source.step()  # generate next data point

            object_ground_truth = source.get_ground_truth() # list of 9 floats
            cam_img = source.get_cam_frame()  # image (224, 224, 3) 
            yield (cam_img, object_ground_truth)

    def dataset_input_fn():
        """
        Tensorflow `Dataset` object from generator
        """

        dataset = tf.data.Dataset.from_generator(data_generator, (tf.uint8, tf.float32), \
            (tf.TensorShape([224, 224, 3]), tf.TensorShape([9])))
        dataset = dataset.batch(16)

        iterator = dataset.make_one_shot_iterator()

        features, labels = iterator.get_next()
        return features, labels

    def main():
        """
        Estimator [from Keras model](https://www.tensorflow.org/programmers_guide/estimators#creating_estimators_from_keras_models) 

        Try to call `est_vgg.train()` leads to the error
        """
        ....
        est_vgg16 = tf.keras.estimator.model_to_estimator(keras_model=keras_vgg16)
        est_vgg16.train(input_fn=dataset_input_fn, steps=10)
        ....

这里是完整代码

(注意:与此问题不同,事物命名可能不同)

以下是堆栈跟踪:

Traceback (most recent call last):
  File "./rock_detector.py", line 155, in <module>
    main()
  File "./rock_detector.py", line 117, in main
    est_vgg16.train(input_fn=dataset_input_fn, steps=10)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 302, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 711, in _train_model
    features, labels, model_fn_lib.ModeKeys.TRAIN, self.config)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 694, in _call_model_fn
    model_fn_results = self._model_fn(features=features, **kwargs)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/keras/_impl/keras/estimator.py", line 145, in model_fn
    labels)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/keras/_impl/keras/estimator.py", line 92, in _clone_and_build_model
    keras_model, features)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/keras/_impl/keras/estimator.py", line 58, in _create_ordered_io
    for key in estimator_io_dict:
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 505, in __iter__
    raise TypeError("'Tensor' object is not iterable.")
TypeError: 'Tensor' object is not iterable.

我认为你只需要:get_next = iterator.get_next(); est_vgg16.train(input_fn=get_next, steps=10,但是我不使用keras,所以对于在那里使用的.train函数并不完全熟悉。 - David Parks
请问您能否提供错误的完整堆栈跟踪信息? - mrry
带有堆栈跟踪的更新帖子。使用高级API时,很难理解内部发生了什么。我通过切换到TensorFlow的低级接口,并手动“馈送”生成器,以相同的努力使其工作。虽然高级API的好处是它处理所有训练和详细信息,并可以优化处理过程。 - matwilso
1个回答

5
让您的输入函数返回一个像这样的特征字典:
def dataset_input_fn():
  ...
  features, labels = iterator.get_next()
  return {'image': features}, labels

3
谢谢,问题已解决。作为对未来访问者的提示,我还需要在dataset_input_fn中将tf.uint8更改为tf.float32 - matwilso

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接