Tensorflow CIFAR10教程失败

4
我从教程这里下载了CIFAR10代码,并尝试运行教程。我使用以下命令运行它:
python cifar10_train.py

程序开始运行,并按预期下载数据文件。但是,在尝试打开输入文件时,它会失败并出现以下跟踪信息:

Traceback (most recent call last):
  File "cifar10_train.py", line 120, in <module>
    tf.app.run()
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/platform/app.py", line 43, in run
    sys.exit(main(sys.argv[:1] + flags_passthrough))
  File "cifar10_train.py", line 116, in main
    train()
  File "cifar10_train.py", line 63, in train
    images, labels = cifar10.distorted_inputs()
  File "/notebooks/Python Scripts/tensorflowModels/tutorials/image/cifar10/cifar10.py", line 157, in distorted_inputs
    batch_size=FLAGS.batch_size)
  File "/notebooks/Python Scripts/tensorflowModels/tutorials/image/cifar10/cifar10_input.py", line 161, in distorted_inputs
    read_input = read_cifar10(filename_queue)
  File "/notebooks/Python Scripts/tensorflowModels/tutorials/image/cifar10/cifar10_input.py", line 87, in read_cifar10
    tf.strided_slice(record_bytes, [0], [label_bytes]), tf.int32)
TypeError: strided_slice() takes at least 4 arguments (3 given)

果然,当我调查代码时,在cifar10_input.py中有一个只有3个参数的strided_slice()调用:

tf.strided_slice(record_bytes, [0], [label_bytes])

虽然tensorflow文档确实指出必须至少有4个参数,但是为什么会出错呢?我已经下载了最新版的tensorflow(0.12)并且运行cifar代码的主分支。


1
这可能值得在他们的GitHub页面上提出一个问题。我往前看了几个版本,它们都需要4个参数。 - chris
谢谢。我已经在GitHub上添加了一些讨论,并得到了一个解决方案(我认为),我已经在下面添加了。我仍然有点不确定代码为什么处于这种不工作的状态,但它似乎目前正在运行。 - BobbyG
1个回答

2

GitHub上的讨论后,我进行了以下更改,似乎使它能够正常工作:

在cifar10_input.py文件中:

-  result.label = tf.cast(tf.strided_slice(record_bytes, [0], [label_bytes]), tf.int32)
+  result.label = tf.cast(tf.slice(record_bytes, [0], [label_bytes]), tf.int32)



-  depth_major = tf.reshape( tf.strided_slice(record_bytes, [label_bytes], [label_bytes + image_bytes]),      [result.depth, result.height, result.width])
+  depth_major = tf.reshape(tf.slice(record_bytes, [label_bytes], [image_bytes]), [result.depth, result.height, result.width])

然后在cifar10_input.py和cifar10.py两个文件中,我需要搜索"deprecated",并将其替换为根据api指南中所读内容的有效函数(希望是正确的)。以下是一些示例:

-  tf.contrib.deprecated.image_summary('images', images)
+  tf.summary.image('images', images)

并且

 - tf.contrib.deprecated.histogram_summary(tensor_name + '/activations', x)
 - tf.contrib.deprecated.scalar_summary(tensor_name + '/sparsity',
 + tf.summary.histogram(tensor_name + '/activations', x)
 + tf.summary.scalar(tensor_name + '/sparsity',

现在似乎已经顺畅运行了。我会看看它是否能够正常完成,并且我所做的修改是否产生了期望的诊断输出。

不过,我仍然希望听到一个更接近代码的权威答案。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接