Error when restoring a TensorFlow checkpoint file

When I use the saver.restore() method in TensorFlow, I get the following error. Any ideas why this happens?

I save the model like this: saver.save(sess, checkpoint_path, global_step=step)

The error message is:

tensorflow.python.framework.errors.InvalidArgumentError: Node 'Variable_1/Assign': Unknown input node Variable_1
     [[Node: Variable_1/initial_value = Const[dtype=DT_FLOAT, value=Tensor<type: float shape: [] values: 0.9>]()]]

Full traceback:

can't determine number of CPU cores: assuming 4
I tensorflow/core/common_runtime/local_device.cc:25] Local device intra op parallelism threads: 4
can't determine number of CPU cores: assuming 4
I tensorflow/core/common_runtime/local_session.cc:45] Local session inter op parallelism threads: 4
('1.1- label batch shape is ', TensorShape([Dimension(128)]))
Inferencing
('in inferemcee ', TensorShape([Dimension(128), Dimension(3072)]), <class 'tensorflow.python.framework.ops.Tensor'>)
Evaluation..
tmp/ckpt/model.ckpt-9100
W tensorflow/core/common_runtime/executor.cc:1027] 0x7fc789748be0 Compute status: Cancelled: Enqueue operation was cancelled
     [[Node: input/string_input_producer/string_input_producer_EnqueueMany = QueueEnqueueMany[Tcomponents=[DT_STRING], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/cpu:0"](input/string_input_producer, input/string_input_producer/limit_epochs)]]
I tensorflow/core/kernels/fifo_queue.cc:154] Skipping cancelled enqueue attempt
Traceback (most recent call last):
  File "/ProjectS/Cifar-Eval/my_eval.py", line 112, in <module>
    tf.app.run()
W tensorflow/core/common_runtime/executor.cc:1027] 0x7fc78b939670 Compute status: Cancelled: Enqueue operation was cancelled
     [[Node: input/batching_shuffling/random_shuffle_queue_enqueue = QueueEnqueue[Tcomponents=[DT_FLOAT, DT_INT32], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/cpu:0"](input/batching_shuffling/random_shuffle_queue, input/sub, input/Cast_2)]]
  File "/Users/user/tensorflow/lib/python2.7/site-packages/tensorflow/python/platform/default/_app.py", line 11, in run
W tensorflow/core/common_runtime/executor.cc:1027] 0x7fc78954f080 Compute status: Cancelled: Enqueue operation was cancelled
     [[Node: input/batching_shuffling/random_shuffle_queue_enqueue = QueueEnqueue[Tcomponents=[DT_FLOAT, DT_INT32], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/cpu:0"](input/batching_shuffling/random_shuffle_queue, input/sub, input/Cast_2)]]
W tensorflow/core/common_runtime/executor.cc:1027] 0x7fc78954e5d0 Compute status: Cancelled: Enqueue operation was cancelled
     [[Node: input/batching_shuffling/random_shuffle_queue_enqueue = QueueEnqueue[Tcomponents=[DT_FLOAT, DT_INT32], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/cpu:0"](input/batching_shuffling/random_shuffle_queue, input/sub, input/Cast_2)]]
W tensorflow/core/common_runtime/executor.cc:1027] 0x7fc789550370 Compute status: Cancelled: Enqueue operation was cancelled
     [[Node: input/batching_shuffling/random_shuffle_queue_enqueue = QueueEnqueue[Tcomponents=[DT_FLOAT, DT_INT32], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/cpu:0"](input/batching_shuffling/random_shuffle_queue, input/sub, input/Cast_2)]]
W tensorflow/core/common_runtime/executor.cc:1027] 0x7fc78ba28cb0 Compute status: Cancelled: Enqueue operation was cancelled
     [[Node: input/batching_shuffling/random_shuffle_queue_enqueue = QueueEnqueue[Tcomponents=[DT_FLOAT, DT_INT32], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/cpu:0"](input/batching_shuffling/random_shuffle_queue, input/sub, input/Cast_2)]]
    sys.exit(main(sys.argv))
  File "/ProjectS/Cifar-Eval/my_eval.py", line 108, in main
    my_eval()
  File "/ProjectS/Cifar-Eval/my_eval.py", line 85, in my_eval
    saver.restore(sess, ckpt.model_checkpoint_path)
  File "/Users/user/tensorflow/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 864, in restore
    sess.run([self._restore_op_name], {self._filename_tensor_name: save_path})
  File "/Users/user/tensorflow/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 345, in run
    results = self._do_run(target_list, unique_fetch_targets, feed_dict_string)
  File "/Users/user/tensorflow/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 419, in _do_run
    e.code)
tensorflow.python.framework.errors.InvalidArgumentError: Node 'Variable_1/Assign': Unknown input node Variable_1
     [[Node: Reshape/shape = Const[dtype=DT_INT32, value=Tensor<type: int32 shape: [4] values: -1 32 32...>]()]]
Caused by op u'Reshape/shape', defined at:
  File "/ProjectS/Cifar-Eval/my_eval.py", line 112, in <module>
    tf.app.run()
  File "/Users/user/tensorflow/lib/python2.7/site-packages/tensorflow/python/platform/default/_app.py", line 11, in run
    sys.exit(main(sys.argv))
  File "/ProjectS/Cifar-Eval/my_eval.py", line 108, in main
    my_eval()
  File "/ProjectS/Cifar-Eval/my_eval.py", line 78, in my_eval
    logits = my_cifar.inference(images_placeholder)
  File "/ProjectS/Cifar-Eval/my_cifar.py", line 68, in inference
    images = tf.reshape(images, shape=[-1, 32, 32, 3])
  File "/Users/user/tensorflow/lib/python2.7/site-packages/tensorflow/python/ops/gen_array_ops.py", line 554, in reshape
    name=name)
  File "/Users/user/tensorflow/lib/python2.7/site-packages/tensorflow/python/ops/op_def_library.py", line 397, in apply_op
    values, name=input_arg.name, dtype=dtype)
  File "/Users/user/tensorflow/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 468, in convert_to_tensor
    ret = conversion_func(value, dtype=dtype, name=name)
  File "/Users/user/tensorflow/lib/python2.7/site-packages/tensorflow/python/ops/constant_op.py", line 147, in constant
    attrs={"value": tensor_value, "dtype": dtype_value}, name=name).outputs[0]
  File "/Users/user/tensorflow/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1710, in create_op
    original_op=self._default_original_op, op_def=op_def)
  File "/Users/user/tensorflow/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 988, in __init__
    self._traceback = _extract_stack()

Code that restores the checkpoint file:

import tensorflow as tf

import my_cifar
import my_input

FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string('eval_dir', 'tmp/log_eval',
                           """Directory where to write event logs.""")

tf.app.flags.DEFINE_string('checkpoint_dir', 'tmp/ckpt',
                           """Directory where to read model checkpoints.""")


IMAGE_PIXELS = 32 * 32 * 3


def placeholder_inputs(batch_size):
  """Generate placeholder variables to represent the input tensors.
  These placeholders are used as inputs by the rest of the model building
  code and will be fed from the downloaded data in the .run() loop, below.
  Args:
    batch_size: The batch size will be baked into both placeholders.
  Returns:
    images_placeholder: Images placeholder.
    labels_placeholder: Labels placeholder.
  """
  # Note that the shapes of the placeholders match the shapes of the full
  # image and label tensors, except the first dimension is now batch_size
  # rather than the full size of the train or test data sets.
  # batch_size = -1
  images_placeholder = tf.placeholder(tf.float32, shape=(batch_size,
                                                         IMAGE_PIXELS))
  # 32, 32, 3))
  labels_placeholder = tf.placeholder(tf.int32, shape=batch_size)

  return images_placeholder, labels_placeholder


def my_eval():
  with tf.Graph().as_default():

    v1 = tf.Variable(0)

    images_placeholder, labels_placeholder = placeholder_inputs(FLAGS.batch_size)

    # Get images and labels for CIFAR-10.
    val_images, val_labels = my_input.inputs(False)

    init_op = tf.initialize_all_variables()

    coord = tf.train.Coordinator()

    with tf.Session() as sess:

      sess.run(init_op)

      saver = tf.train.Saver()
      # Start the queue runners.

      threads = tf.train.start_queue_runners(sess=sess, coord=coord)

      summary_op = tf.merge_all_summaries()
      summary_writer = tf.train.SummaryWriter(FLAGS.eval_dir,
                                              graph_def=sess.graph_def)


      # Build a Graph that computes the logits predictions from the
      # inference model.
      logits = my_cifar.inference(images_placeholder)

      acc = my_cifar.evaluation(logits, labels_placeholder)

      ckpt = tf.train.get_checkpoint_state(checkpoint_dir=FLAGS.checkpoint_dir)
      print ckpt.model_checkpoint_path
      if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
        print('Restored!')

      images_val_r, labels_val_r = sess.run([val_images, val_labels])
      val_feed = {images_placeholder: images_val_r,
                  labels_placeholder: labels_val_r}

      tf.scalar_summary('Acc', acc)

      print('Calculating Acc  :')

      acc_r = sess.run(acc, feed_dict=val_feed)
      print(acc_r)

      # Write results to TensorBoard
      summary_str = sess.run(summary_op)
      summary_writer.add_summary(summary_str)


      coord.join(threads)


def main(argv=None):
  my_eval()


if __name__ == '__main__':
  tf.app.run()


@mrry, could you let me know if you spot any mistake here? - Hamed MP
Try giving every variable an explicit name; when you load the checkpoint, the auto-generated variable names may be different. - fabrizioM
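To illustrate fabrizioM's point, here is a simplified pure-Python simulation (not TensorFlow's actual code) of how a graph hands out unique names to unnamed variables: the first one gets `Variable`, later ones get `Variable_1`, `Variable_2`, and so on. The helpers `UniqueNamer` and `build_names` are hypothetical. Creating an extra variable such as `v1 = tf.Variable(0)` before the model shifts every later auto-generated name by one, so the names in the evaluation graph no longer line up with the keys stored in the checkpoint.

```python
from collections import defaultdict


class UniqueNamer:
    """Simplified, hypothetical stand-in for a graph's unique-name counter."""

    def __init__(self):
        self._counts = defaultdict(int)

    def unique_name(self, base):
        # The first request for a base name keeps it; later ones get _1, _2, ...
        n = self._counts[base]
        self._counts[base] += 1
        return base if n == 0 else "%s_%d" % (base, n)


def build_names(extra_variable_first):
    """Return the names two unnamed model variables would receive."""
    namer = UniqueNamer()
    if extra_variable_first:
        namer.unique_name("Variable")  # plays the role of v1 = tf.Variable(0)
    return [namer.unique_name("Variable"), namer.unique_name("Variable")]


train_names = build_names(extra_variable_first=False)  # graph that was saved
eval_names = build_names(extra_variable_first=True)    # graph doing the restore
print(train_names)  # ['Variable', 'Variable_1']
print(eval_names)   # ['Variable_1', 'Variable_2']
```

Passing an explicit `name=` argument to `tf.Variable` keeps the checkpoint keys stable regardless of how many other variables are created first.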
1 Answer

You are trying to load a variable that does not exist in the original network. I think omitting

    v1 = tf.Variable(0)

will solve the problem.

If you want to add a new variable, you need to load the checkpoint differently, along these lines:

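import os
import tensorflow as tf

# Assumes checkpoint_dir, ckpt_name and an open session `sess` exist, and that
# the full model graph (including any new variables) has already been built.
reader = tf.train.NewCheckpointReader(os.path.join(checkpoint_dir, ckpt_name))
restore_dict = dict()
for v in tf.trainable_variables():
    tensor_name = v.name.split(':')[0]  # strip the ':0' output suffix
    if reader.has_tensor(tensor_name):
        print('has tensor ', tensor_name)
        restore_dict[tensor_name] = v
    # put the logic for the new/modified variable here and add it to restore_dict, e.g.
    # restore_dict['my_var_scope/my_var'] = get_my_variable()

# Restore only the matched variables; the others still need to be initialized.
saver = tf.train.Saver(restore_dict)
saver.restore(sess, os.path.join(checkpoint_dir, ckpt_name))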
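The matching step in the loop above can be sketched in plain Python, without TensorFlow, to show which variables end up restored and which are left for fresh initialization. This is a minimal sketch under assumptions: `split_for_restore` is a hypothetical helper, the list of tensor names stands in for `tf.trainable_variables()`, and the set of keys stands in for `reader.has_tensor()`. In TF 1.x you would pass the resulting dict to `tf.train.Saver` and initialize the leftover variables separately (for example with `tf.variables_initializer`).

```python
def split_for_restore(graph_var_names, checkpoint_keys):
    """Split graph tensor names into restorable ones and ones needing init.

    graph_var_names: names like "conv1/weights:0" (hypothetical example data).
    checkpoint_keys: names stored in the checkpoint, like "conv1/weights".
    Returns (restore, fresh): restore maps checkpoint key -> tensor name,
    fresh lists tensor names that have no matching checkpoint entry.
    """
    restore, fresh = {}, []
    for tensor_name in graph_var_names:
        op_name = tensor_name.split(":")[0]  # drop the ":0" output index
        if op_name in checkpoint_keys:
            restore[op_name] = tensor_name
        else:
            fresh.append(tensor_name)
    return restore, fresh


# Hypothetical graph with one newly added layer that the checkpoint lacks.
graph_vars = ["conv1/weights:0", "conv1/biases:0", "new_head/weights:0"]
ckpt_keys = {"conv1/weights", "conv1/biases", "global_step"}

restore, fresh = split_for_restore(graph_vars, ckpt_keys)
print(sorted(restore))  # ['conv1/biases', 'conv1/weights']
print(fresh)            # ['new_head/weights:0']
```

Keys that exist only in the checkpoint (here `global_step`) are simply ignored, which is why restoring through an explicit dict tolerates a model whose variable set has changed.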
