TensorFlow Estimator无法初始化全局变量。

3

我正在使用tensorflow slim resnet_v2提取图像特征。 resnet_v2_152.ckpt来自:resnet_v2_152.ckpt 这是我的代码。

import tensorflow as tf

import tensorflow.contrib.slim.python.slim.nets.resnet_v2 as resnet_v2


def cnn_model_fn(features, labels, mode):
    net, end_points = resnet_v2.resnet_v2_152(inputs=features, is_training=mode == tf.estimator.ModeKeys.TRAIN)
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode, predictions=net)
    else:
        raise NotImplementedError('only support predict!')


def parse_filename(filename):
    image_string = tf.read_file(filename)
    image_decoded = tf.image.decode_jpeg(image_string, channels=3)
    image_resized = tf.image.resize_images(image_decoded, [256, 256])
    return image_resized


def dataset_input_fn(dataset, num_epochs=None, batch_size=128, shuffle=False, buffer_size=1000, seed=None):
    def input_fn():
        d = dataset.repeat(num_epochs).batch(batch_size)
        if shuffle:
            d = d.shuffle(buffer_size)
        iterator = d.make_one_shot_iterator()
        next_example = iterator.get_next()
        return next_example

    return input_fn


filenames = sorted(tf.gfile.Glob('/root/data/COCO/download/val2014/*'))
dataset = tf.contrib.data.Dataset.from_tensor_slices(filenames).map(parse_filename)

input_fn = dataset_input_fn(dataset, num_epochs=1, batch_size=1, shuffle=False)

estimator = tf.estimator.Estimator(model_fn=cnn_model_fn, model_dir=None)

es = estimator.predict(input_fn=input_fn,
                       checkpoint_path='/root/data/checkpoints/resnet_v2_152_2017_04_14/resnet_v2_152.ckpt')
print(es.__next__())


print("Done!")

我遇到了如下错误:

2017-09-10 22:06:36.875590: W tensorflow/core/framework/op_kernel.cc:1192] Not found: Tensor name "resnet_v2_152/block1/unit_1/bottleneck_v2/conv1/biases" not found in checkpoint files /root/data/checkpoints/resnet_v2_152_2017_04_14/resnet_v2_152.ckpt
     [[Node: save/RestoreV2_1 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](_arg_save/Const_0_0, save/RestoreV2_1/tensor_names, save/RestoreV2_1/shape_and_slices)]]
Traceback (most recent call last):
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1327, in _do_call
    return fn(*args)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1306, in _run_fn
    status, run_metadata)
  File "/usr/lib/python3.5/contextlib.py", line 66, in __exit__
    next(self.gen)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/errors_impl.py", line 466, in raise_exception_on_not_ok_status
    pywrap_tensorflow.TF_GetCode(status))
tensorflow.python.framework.errors_impl.NotFoundError: Tensor name "resnet_v2_152/block1/unit_1/bottleneck_v2/conv1/biases" not found in checkpoint files /root/data/checkpoints/resnet_v2_152_2017_04_14/resnet_v2_152.ckpt
     [[Node: save/RestoreV2_1 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](_arg_save/Const_0_0, save/RestoreV2_1/tensor_names, save/RestoreV2_1/shape_and_slices)]]
     [[Node: save/RestoreV2_242/_309 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/gpu:0", send_device="/job:localhost/replica:0/task:0/cpu:0", send_device_incarnation=1, tensor_name="edge_1240_save/RestoreV2_242", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/gpu:0"]()]]

我认为可以通过将conv1/biases初始化为0来解决这个问题,但是tensorflow Estimator没有提供这样的函数。我该怎么办?

1个回答

1
我认为,你希望加载预训练的权重而不仅仅是初始化ResNet中的变量。你应该考虑使用tf.train.Scaffold对象。
模型例程应该像这样:
def cnn_model_fn(features, labels, mode):
    with slim.arg_scope(resnet_v2.resnet_arg_scope()):
        logits, end_points = resnet_v2.resnet_v2_152(features,
                                 is_training=mode == tf.estimator.ModeKeys.TRAIN)

    checkpoint_file = 'resnet_v2_152.ckpt'
    init_fn = slim.assign_from_checkpoint_fn(
        checkpoint_file,
        [var for var in tf.global_variables()])

    saver = tf.train.Saver(max_to_keep=10)
    scaffold = tf.train.Scaffold(init_fn=init_fn, saver=saver)

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode,
                                          predictions={'logits': logits}, 
                                          scaffold=scaffold)
    else:
        raise NotImplementedError('only support predict!')

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接