恢复保存的 TensorFlow 模型以在测试集上进行评估。

3

我看过一些有关恢复TF模型的帖子,以及Google文档页面上有关导出图形的内容,但我认为我还是缺少了些什么。

我使用这个Gist中的代码来保存模型,以及这个utils文件来定义模型。

现在,我想将其恢复并在先前未见过的测试数据上运行,如下所示:

def evaluate(X_data, y_data):
    num_examples = len(X_data)
    total_accuracy = 0
    total_loss = 0
    sess = tf.get_default_session()
    acc_steps = len(X_data) // BATCH_SIZE
    for i in range(acc_steps):
        batch_x, batch_y = next_batch(X_val, Y_val, BATCH_SIZE)

        loss, accuracy = sess.run([loss_value, acc], feed_dict={
                images_placeholder: batch_x,
                labels_placeholder: batch_y,
                keep_prob: 0.5
                })
        total_accuracy += (accuracy * len(batch_x))
        total_loss += (loss * len(batch_x))
    return (total_accuracy / num_examples, total_loss / num_examples)

## re-execute the code that defines the model

# Image Tensor
images_placeholder = tf.placeholder(tf.float32, shape=[None, 32, 32, 3], name='x')

gray = tf.image.rgb_to_grayscale(images_placeholder, name='gray')

gray /= 255.

# Label Tensor
labels_placeholder = tf.placeholder(tf.float32, shape=(None, 43), name='y')

# dropout Tensor
keep_prob = tf.placeholder(tf.float32, name='drop')

# construct model
logits = inference(gray, keep_prob)

# calculate loss
loss_value = loss(logits, labels_placeholder)

# training
train_op = training(loss_value, 0.001)

# accuracy
acc = accuracy(logits, labels_placeholder)

with tf.Session() as sess:
    loader = tf.train.import_meta_graph('gtsd.meta')
    loader.restore(sess, tf.train.latest_checkpoint('./'))
    sess.run(tf.initialize_all_variables())   
    test_accuracy = evaluate(X_test, y_test)
    print("Test Accuracy = {:.3f}".format(test_accuracy[0]))

我的测试准确度只有 3%。然而,如果我在训练模型后不关闭笔记本,立即运行测试代码,我可以得到 95% 的准确度。

这使我认为我没有正确地加载模型?

3个回答

5
这个问题源于以下两行代码:
loader.restore(sess, tf.train.latest_checkpoint('./'))
sess.run(tf.initialize_all_variables())   

第一行代码从检查点中加载保存的模型。第二行代码重新初始化模型中的所有变量(如权重矩阵、卷积过滤器和偏置向量),通常使用随机数,并覆盖已加载的值。
解决方案很简单:删除第二行代码(sess.run(tf.initialize_all_variables())),评估将继续使用从检查点加载的训练值。
PS. 这种更改可能会导致出现“未初始化变量”的错误。在这种情况下,您应该在执行loader.restore(sess, tf.train.latest_checkpoint('./'))之前执行sess.run(tf.initialize_all_variables())来初始化未保存在检查点中的任何变量。

谢谢 @mrry,我现在会尝试一下。 - Sam Hammamy
正如您所预期的那样,TF报告了一个未初始化变量的错误。当我按照您建议的将该行向上移动时,它仍然只给出2%的准确度,因此它从头开始。 - Sam Hammamy
哦,我注意到另一个问题了!tf.train.import_meta_graph()将在当前图中加载模型结构的第二个副本。如果在创建tf.Session之前的代码中构建了图的副本(包括所有权重),那么那些权重将保持未初始化状态,只有第二个副本中的权重将被恢复。有两种方法可以解决这个问题:(1)不使用tf.train.import_meta_graph(),直接创建tf.train.Saver并使用它来将检查点恢复到图的初始副本中;或者... - mrry
(2) 在使用 tf.train.import_meta_graph() 之前避免构建评估图,而是使用内省方法,如 tf.get_default_graph().get_operation_by_name() 查找原始图中的损失、精确度和占位符张量。这两种方法都可能需要一些重构(基本上您必须确保变量名称在图形和检查点中相同),但我期望选项(1)将涉及更少的工作。 - mrry
我尝试了选项(1),但它没有解决问题。我在##重新执行定义模型的代码时注释掉了sess = tf.Session(),但仍然没有运气。现在我将尝试选项2。 - Sam Hammamy
1
@mrry Dropout怎么样?在评估时间应该如何将其重置为1.0?声明一个新的 tf.placeholder() 就可以吗,还是应该从训练中恢复placeholder? - Nicolai Anton Lynnerup

2

我有类似的问题,对我来说这个方法有效:

with tf.Session() as sess:
    saver=tf.train.Saver(tf.all_variables())
    saver=tf.train.import_meta_graph('model.meta')
    saver.restore(sess,"model")

    test_accuracy = evaluate(X_test, y_test)

1

找到的答案在这里,最终得以工作的方法如下:

save_path = saver.save(sess, '/home/ubuntu/gtsd-12-23-16.chkpt')
print("Model saved in file: %s" % save_path)
## later re-run code that creates the model
# Image Tensor
images_placeholder = tf.placeholder(tf.float32, shape=[None, 32, 32, 3], name='x')

gray = tf.image.rgb_to_grayscale(images_placeholder, name='gray')

gray /= 255.

# Label Tensor
labels_placeholder = tf.placeholder(tf.float32, shape=(None, 43), name='y')

# dropout Tensor
keep_prob = tf.placeholder(tf.float32, name='drop')

# construct model
logits = inference(gray, keep_prob)

# calculate loss
loss_value = loss(logits, labels_placeholder)

# training
train_op = training(loss_value, 0.001)

# accuracy
acc = accuracy(logits, labels_placeholder)

saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, '/home/ubuntu/gtsd-12-23-16.chkpt')
        print("Model restored.")
        test_accuracy = evaluate(X_test, y_test)
        print("Test Accuracy = {:.3f}".format(test_accuracy[0]*100))

你不是忘了设置 keep_prop = 1 吗? - Nicolai Anton Lynnerup

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接