TensorFlow 类型错误:'BatchDataset' 对象不可迭代 / 类型错误: 'CacheDataset' 对象不支持下标操作

5
我正在遵循TensorFlow入门指南。它特别提到要在鸢尾花分类的示例项目上启用急切执行。

导入所需的Python模块,包括TensorFlow,并为此程序启用急切执行。急切执行使TensorFlow立即评估操作,返回具体值,而不是创建稍后执行的计算图。如果您习惯于REPL或python交互式控制台,则会感到如家。

因此,我按照说明启用了急切执行,并继续按照说明进行。然而,当我到达将数据集准备成张量流数据集的部分时,我遇到了一个错误。

代码单元格

train_dataset = tf.data.TextLineDataset(train_dataset_fp)
train_dataset = train_dataset.skip(1)             # skip the first header row
train_dataset = train_dataset.map(parse_csv)      # parse each row
train_dataset = train_dataset.shuffle(buffer_size=1000)  # randomize
train_dataset = train_dataset.batch(32)

# View a single example entry from a batch
features, label = iter(train_dataset).next()
print("example features:", features[0])
print("example label:", label[0])

错误

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-5-61bfe99af85b> in <module>()

      7 
      8 # View a single example entry from a batch
----> 9 features, label = iter(train_dataset).next()
     10 print("example features:", features[0])
     11 print("example label:", label[0])

TypeError: 'BatchDataset' object is not iterable

我只想继续跟随这些示例。 我该如何将 BatchDataset 对象转换为可迭代的对象?
2个回答

4

事实证明,我在项目中确实没有完成某些步骤,导致了这个问题。

将TensorFlow从1.7升级到1.8:

!pip install --upgrade tensorflow

检查您的TensorFlow是否已更新

此代码单元格:

from __future__ import absolute_import, division, print_function

import os
import matplotlib.pyplot as plt

import tensorflow as tf
import tensorflow.contrib.eager as tfe

tf.enable_eager_execution()

print("TensorFlow version: {}".format(tf.VERSION))
print("Eager execution: {}".format(tf.executing_eagerly()))

应该返回以下输出:
TensorFlow version: 1.8.0
Eager execution: True

2

请参考这里的备选方案。我们也可以使用as_numpy_iterator()从tensorflow数据集中获取值。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接