类型错误:Fetch参数的类型无效,必须是字符串或张量。

18

我正在训练一个卷积神经网络,其结构与此例相似,用于图像分割。这些图像的尺寸为1500x1500x1,标签的大小相同。

在定义CNN结构后,以及像这个代码示例(conv_net_test.py)中启动会话时:

with tf.Session() as sess:
sess.run(init)
summ = tf.train.SummaryWriter('/tmp/logdir/', sess.graph_def)
step = 1
print ("import data, read from read_data_sets()...")

#Data defined by me, returns a DataSet object with testing and training images and labels for segmentation problem.
data = import_data_test.read_data_sets('Dataset')

# Keep training until reach max iterations
while step * batch_size < training_iters:
    batch_x, batch_y = data.train.next_batch(batch_size)
    print ("running backprop for step %d" % step)
    batch_x = batch_x.reshape(batch_size, n_input, n_input, n_channels)
    batch_y = batch_y.reshape(batch_size, n_input, n_input, n_channels)
    batch_y = np.int64(batch_y)
    sess.run(optimizer, feed_dict={x: batch_x, y: batch_y, keep_prob: dropout})
    if step % display_step == 0:
        # Calculate batch loss and accuracy
        #pdb.set_trace()
        loss, acc = sess.run([loss, accuracy], feed_dict={x: batch_x, y: batch_y, keep_prob: 1.})
    step += 1
print "Optimization Finished"

我遇到了以下TypeError错误(详细信息见下文):

    conv_net_test.py in <module>()
    178             #pdb.set_trace()
--> 179             loss, acc = sess.run([loss, accuracy], feed_dict={x: batch_x, y: batch_y, keep_prob: 1.})
    180         step += 1
    181     print "Optimization Finished!"

tensorflow/python/client/session.pyc in run(self, fetches, feed_dict, options, run_metadata)
    370     try:
    371       result = self._run(None, fetches, feed_dict, options_ptr,
--> 372                          run_metadata_ptr)
    373       if run_metadata:
    374         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

tensorflow/python/client/session.pyc in _run(self, handle, fetches, feed_dict, options, run_metadata)
    582 
    583     # Validate and process fetches.
--> 584     processed_fetches = self._process_fetches(fetches)
    585     unique_fetches = processed_fetches[0]
    586     target_list = processed_fetches[1]

tensorflow/python/client/session.pyc in _process_fetches(self, fetches)
    538           raise TypeError('Fetch argument %r of %r has invalid type %r, '
    539                           'must be a string or Tensor. (%s)'
--> 540                           % (subfetch, fetch, type(subfetch), str(e)))

TypeError: Fetch argument 1.4415792e+2 of 1.4415792e+2 has invalid type <type 'numpy.float32'>, must be a string or Tensor. (Can not convert a float32 into a Tensor or Operation.)

我在这一点上感到困惑。也许这只是一个简单的类型转换问题,但我不确定如何/在哪里进行转换。另外,为什么损失必须是一个字符串?(假设一旦修复了这个问题,准确性也会出现同样的错误)。

非常感谢任何帮助!

1个回答

65

在使用loss = sess.run(loss)时,你会重新定义Python中的变量loss

第一次运行时它会正常工作。 第二次,你会尝试执行:

sess.run(1.4415792e+2)

因为loss现在是一个浮点数。


你应该使用不同的名称,例如:

loss_val, acc = sess.run([loss, accuracy], feed_dict={x: batch_x, y: batch_y, keep_prob: 1.})

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接