如何在训练后保存/恢复模型?

660

在使用Tensorflow训练模型后:

  1. 如何保存已经训练好的模型?
  2. 如何在之后恢复这个已保存的模型?

你能恢复Inception模型中使用的变量吗?我也在尝试解决完全相同的问题,但我无法编写用于训练Inception模型的变量集(其中我有ckpt文件)。 - exAres
我还没有尝试过Inception模型。你有这个模型的网络结构和名称吗?你需要复制网络,然后按照Ryan的说明加载权重和偏差(ckpt文件)。也许自2015年11月以来有些变化,现在可能有更简单的方法,但我不确定。 - mathetes
哦,好的。我之前加载过其他预训练的TensorFlow模型,但是现在正在寻找Inception模型的变量规格。谢谢。 - exAres
1
如果你想要恢复训练,请使用 Saver 检查点。如果你保存模型以供参考,请使用 TensorFlow SavedModel API。 - HY G
如果你正在使用LSTM,你将会有一个从字符串到字符列表的映射,确保按照相同的顺序保存和加载该列表!这不包括保存模型权重和模型图形网络,并且会导致在更改会话或数据更改时似乎未加载模型。 - devssh
29个回答

272
在(以及之后的)Tensorflow版本0.11中:
保存模型:
import tensorflow as tf

#Prepare to feed input, i.e. feed_dict and placeholders
w1 = tf.placeholder("float", name="w1")
w2 = tf.placeholder("float", name="w2")
b1= tf.Variable(2.0,name="bias")
feed_dict ={w1:4,w2:8}

#Define a test operation that we will restore
w3 = tf.add(w1,w2)
w4 = tf.multiply(w3,b1,name="op_to_restore")
sess = tf.Session()
sess.run(tf.global_variables_initializer())

#Create a saver object which will save all the variables
saver = tf.train.Saver()

#Run the operation by feeding input
print sess.run(w4,feed_dict)
#Prints 24 which is sum of (w1+w2)*b1 

#Now, save the graph
saver.save(sess, 'my_test_model',global_step=1000)

恢复模型:

import tensorflow as tf

sess=tf.Session()    
#First let's load meta graph and restore weights
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))


# Access saved Variables directly
print(sess.run('bias:0'))
# This will print 2, which is the value of bias that we saved


# Now, let's access and create placeholders variables and
# create feed-dict to feed new data

graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict ={w1:13.0,w2:17.0}

#Now, access the op that you want to run. 
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")

print sess.run(op_to_restore,feed_dict)
#This will print 60 which is calculated 

这些以及更多高级用例在这里被很好地解释了。

一个快速完整的教程,用于保存和恢复Tensorflow模型


5
+1 针对这个 # 直接访问已保存的变量 print(sess.run('bias:0'))

这将打印出我们保存的偏置值2,有助于调试以确认模型是否正确加载。可以使用"All_varaibles = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES"获取变量。另外,在还原之前必须先运行"sess.run(tf.global_variables_initializer())"。

- LGG
1
你确定我们必须再次运行global_variables_initializer吗?我使用global_variable_initialization还原了我的图形,在相同的数据上每次都给出不同的输出。因此,我将初始化注释掉,只还原了图形、输入变量和操作,现在它可以正常工作了。 - Aditya Shinde
@AdityaShinde 我不明白为什么每次都会得到不同的值。而且我没有包括恢复变量初始化步骤。顺便说一下,我正在使用自己的代码。 - Chaine
6
当你恢复张量时,为什么要在名称后面添加“:0”? - Sahar Rabinoviz
你的张量示例对我有效,但操作不行。对于操作,我必须使用get_operation_by_name而没有添加:0。即op_to_restore = graph.get_operation_by_name("op_to_restore") - Carsten
显示剩余2条评论

183

在TensorFlow版本0.11.0RC1及其后的版本中,您可以通过调用tf.train.export_meta_graphtf.train.import_meta_graph直接保存和恢复模型,具体请参见https://www.tensorflow.org/programmers_guide/meta_graph

保存模型

w1 = tf.Variable(tf.truncated_normal(shape=[10]), name='w1')
w2 = tf.Variable(tf.truncated_normal(shape=[20]), name='w2')
tf.add_to_collection('vars', w1)
tf.add_to_collection('vars', w2)
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my-model')
# `save` method will call `export_meta_graph` implicitly.
# you will get saved graph files:my-model.meta

恢复模型

sess = tf.Session()
new_saver = tf.train.import_meta_graph('my-model.meta')
new_saver.restore(sess, tf.train.latest_checkpoint('./'))
all_vars = tf.get_collection('vars')
for v in all_vars:
    v_ = sess.run(v)
    print(v_)

4
如何从已保存的模型中加载变量?如何复制其他变量中的值? - neel
10
我无法让这段代码正常运行。虽然模型已经被保存,但我无法恢复它。出现了以下错误:"<built-in function TF_Run> 返回了一个带有错误的结果"。 - Saad Qureshi
2
当我按照上面所示的方式访问变量时,它可以正常工作。但是,如果我使用tf.get_variable_scope().reuse_variables()后跟var = tf.get_variable("varname")更直接地获取变量,就会出现错误:“ValueError: Variable varname does not exist, or was not created with tf.get_variable()。”为什么?这不应该是可能的吗? - jpp1
4
这适用于变量,但是在还原图后如何访问占位符并向其提供值呢? - kbrose
12
这只展示了如何恢复变量。如何在不重新定义网络的情况下恢复整个模型并对新数据进行测试呢? - Chaine
显示剩余7条评论

172

Tensorflow 2文档

保存检查点

改编自官方文档

# -------------------------
# -----  Toy Context  -----
# -------------------------
import tensorflow as tf


class Net(tf.keras.Model):
    """A simple linear model."""

    def __init__(self):
        super(Net, self).__init__()
        self.l1 = tf.keras.layers.Dense(5)

    def call(self, x):
        return self.l1(x)


def toy_dataset():
    inputs = tf.range(10.0)[:, None]
    labels = inputs * 5.0 + tf.range(5.0)[None, :]
    return (
        tf.data.Dataset.from_tensor_slices(dict(x=inputs, y=labels)).repeat().batch(2)
    )


def train_step(net, example, optimizer):
    """Trains `net` on `example` using `optimizer`."""
    with tf.GradientTape() as tape:
        output = net(example["x"])
        loss = tf.reduce_mean(tf.abs(output - example["y"]))
    variables = net.trainable_variables
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))
    return loss


# ----------------------------
# -----  Create Objects  -----
# ----------------------------

net = Net()
opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(
    step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator
)
manager = tf.train.CheckpointManager(ckpt, "./tf_ckpts", max_to_keep=3)

# ----------------------------
# -----  Train and Save  -----
# ----------------------------

ckpt.restore(manager.latest_checkpoint)
if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
else:
    print("Initializing from scratch.")

for _ in range(50):
    example = next(iterator)
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)
    if int(ckpt.step) % 10 == 0:
        save_path = manager.save()
        print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
        print("loss {:1.2f}".format(loss.numpy()))


# ---------------------
# -----  Restore  -----
# ---------------------

# In another script, re-initialize objects
opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(
    step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator
)
manager = tf.train.CheckpointManager(ckpt, "./tf_ckpts", max_to_keep=3)

# Re-use the manager code above ^

ckpt.restore(manager.latest_checkpoint)
if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
else:
    print("Initializing from scratch.")

for _ in range(50):
    example = next(iterator)
    # Continue training or evaluate etc.

更多链接

  • saved_model 的全面而实用的教程。

  • keras 保存模型的详细指南。

检查点(Checkpoints)捕获模型使用的所有参数(tf.Variable对象)的确切值。 检查点不包含模型定义的任何计算描述,因此通常仅在可用源代码使用保存的参数值时才有用。

另一方面,SavedModel格式包括模型定义的计算序列化描述以及参数值(检查点)。该格式的模型与创建该模型的源代码无关。 因此,它们适合通过TensorFlow Serving、TensorFlow Lite、TensorFlow.js或其他编程语言的程序(如C、C ++、Java、Go、Rust、C#等TensorFlow API)进行部署。

(强调为本人所加)


Tensorflow <2


从文档中:

保存

# Create some variables.
v1 = tf.get_variable("v1", shape=[3], initializer = tf.zeros_initializer)
v2 = tf.get_variable("v2", shape=[5], initializer = tf.zeros_initializer)

inc_v1 = v1.assign(v1+1)
dec_v2 = v2.assign(v2-1)

# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, initialize the variables, do some work, and save the
# variables to disk.
with tf.Session() as sess:
  sess.run(init_op)
  # Do some work with the model.
  inc_v1.op.run()
  dec_v2.op.run()
  # Save the variables to disk.
  save_path = saver.save(sess, "/tmp/model.ckpt")
  print("Model saved in path: %s" % save_path)

还原

tf.reset_default_graph()

# Create some variables.
v1 = tf.get_variable("v1", shape=[3])
v2 = tf.get_variable("v2", shape=[5])

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
  # Restore variables from disk.
  saver.restore(sess, "/tmp/model.ckpt")
  print("Model restored.")
  # Check the values of the variables
  print("v1 : %s" % v1.eval())
  print("v2 : %s" % v2.eval())

simple_save

许多好的答案,为了完整性,我会加上我自己的意见:simple_save。同时也提供一个独立的代码示例,使用tf.data.Dataset API。

Python 3;TensorFlow 1.14

import tensorflow as tf
from tensorflow.saved_model import tag_constants

with tf.Graph().as_default():
    with tf.Session() as sess:
        ...

        # Saving
        inputs = {
            "batch_size_placeholder": batch_size_placeholder,
            "features_placeholder": features_placeholder,
            "labels_placeholder": labels_placeholder,
        }
        outputs = {"prediction": model_output}
        tf.saved_model.simple_save(
            sess, 'path/to/your/location/', inputs, outputs
        )

恢复:

graph = tf.Graph()
with restored_graph.as_default():
    with tf.Session() as sess:
        tf.saved_model.loader.load(
            sess,
            [tag_constants.SERVING],
            'path/to/your/location/',
        )
        batch_size_placeholder = graph.get_tensor_by_name('batch_size_placeholder:0')
        features_placeholder = graph.get_tensor_by_name('features_placeholder:0')
        labels_placeholder = graph.get_tensor_by_name('labels_placeholder:0')
        prediction = restored_graph.get_tensor_by_name('dense/BiasAdd:0')

        sess.run(prediction, feed_dict={
            batch_size_placeholder: some_value,
            features_placeholder: some_other_value,
            labels_placeholder: another_value
        })

独立的例子

原博客文章

以下代码为演示目的生成随机数据。

  1. 我们首先创建占位符,它们将在运行时保存数据。从中,我们创建一个Dataset,然后创建其Iterator。我们获取迭代器生成的张量,称为input_tensor,这将作为我们模型的输入。
  2. 模型本身是基于GRU的双向RNN,后跟密集分类器而构建的,因为为什么不呢。
  3. 损失函数为softmax_cross_entropy_with_logits,使用Adam优化。经过2个epoch(每个epoch 2个batch),我们使用tf.saved_model.simple_save保存“训练好”的模型。如果您按原样运行代码,则该模型将保存在名为simple/的文件夹中。
  4. 在新图形中,我们使用tf.saved_model.loader.load恢复保存的模型。我们使用graph.get_tensor_by_name获取占位符和logits以及使用graph.get_operation_by_name来初始化Iterator操作。
  5. 最后,我们对数据集中的两个batch都运行推断,并检查保存和恢复的模型是否产生相同的值。它们确实一样!

代码:

import os
import shutil
import numpy as np
import tensorflow as tf
from tensorflow.python.saved_model import tag_constants


def model(graph, input_tensor):
    """Create the model which consists of
    a bidirectional rnn (GRU(10)) followed by a dense classifier

    Args:
        graph (tf.Graph): Tensors' graph
        input_tensor (tf.Tensor): Tensor fed as input to the model

    Returns:
        tf.Tensor: the model's output layer Tensor
    """
    cell = tf.nn.rnn_cell.GRUCell(10)
    with graph.as_default():
        ((fw_outputs, bw_outputs), (fw_state, bw_state)) = tf.nn.bidirectional_dynamic_rnn(
            cell_fw=cell,
            cell_bw=cell,
            inputs=input_tensor,
            sequence_length=[10] * 32,
            dtype=tf.float32,
            swap_memory=True,
            scope=None)
        outputs = tf.concat((fw_outputs, bw_outputs), 2)
        mean = tf.reduce_mean(outputs, axis=1)
        dense = tf.layers.dense(mean, 5, activation=None)

        return dense


def get_opt_op(graph, logits, labels_tensor):
    """Create optimization operation from model's logits and labels

    Args:
        graph (tf.Graph): Tensors' graph
        logits (tf.Tensor): The model's output without activation
        labels_tensor (tf.Tensor): Target labels

    Returns:
        tf.Operation: the operation performing a stem of Adam optimizer
    """
    with graph.as_default():
        with tf.variable_scope('loss'):
            loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
                    logits=logits, labels=labels_tensor, name='xent'),
                    name="mean-xent"
                    )
        with tf.variable_scope('optimizer'):
            opt_op = tf.train.AdamOptimizer(1e-2).minimize(loss)
        return opt_op


if __name__ == '__main__':
    # Set random seed for reproducibility
    # and create synthetic data
    np.random.seed(0)
    features = np.random.randn(64, 10, 30)
    labels = np.eye(5)[np.random.randint(0, 5, (64,))]

    graph1 = tf.Graph()
    with graph1.as_default():
        # Random seed for reproducibility
        tf.set_random_seed(0)
        # Placeholders
        batch_size_ph = tf.placeholder(tf.int64, name='batch_size_ph')
        features_data_ph = tf.placeholder(tf.float32, [None, None, 30], 'features_data_ph')
        labels_data_ph = tf.placeholder(tf.int32, [None, 5], 'labels_data_ph')
        # Dataset
        dataset = tf.data.Dataset.from_tensor_slices((features_data_ph, labels_data_ph))
        dataset = dataset.batch(batch_size_ph)
        iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes)
        dataset_init_op = iterator.make_initializer(dataset, name='dataset_init')
        input_tensor, labels_tensor = iterator.get_next()

        # Model
        logits = model(graph1, input_tensor)
        # Optimization
        opt_op = get_opt_op(graph1, logits, labels_tensor)

        with tf.Session(graph=graph1) as sess:
            # Initialize variables
            tf.global_variables_initializer().run(session=sess)
            for epoch in range(3):
                batch = 0
                # Initialize dataset (could feed epochs in Dataset.repeat(epochs))
                sess.run(
                    dataset_init_op,
                    feed_dict={
                        features_data_ph: features,
                        labels_data_ph: labels,
                        batch_size_ph: 32
                    })
                values = []
                while True:
                    try:
                        if epoch < 2:
                            # Training
                            _, value = sess.run([opt_op, logits])
                            print('Epoch {}, batch {} | Sample value: {}'.format(epoch, batch, value[0]))
                            batch += 1
                        else:
                            # Final inference
                            values.append(sess.run(logits))
                            print('Epoch {}, batch {} | Final inference | Sample value: {}'.format(epoch, batch, values[-1][0]))
                            batch += 1
                    except tf.errors.OutOfRangeError:
                        break
            # Save model state
            print('\nSaving...')
            cwd = os.getcwd()
            path = os.path.join(cwd, 'simple')
            shutil.rmtree(path, ignore_errors=True)
            inputs_dict = {
                "batch_size_ph": batch_size_ph,
                "features_data_ph": features_data_ph,
                "labels_data_ph": labels_data_ph
            }
            outputs_dict = {
                "logits": logits
            }
            tf.saved_model.simple_save(
                sess, path, inputs_dict, outputs_dict
            )
            print('Ok')
    # Restoring
    graph2 = tf.Graph()
    with graph2.as_default():
        with tf.Session(graph=graph2) as sess:
            # Restore saved values
            print('\nRestoring...')
            tf.saved_model.loader.load(
                sess,
                [tag_constants.SERVING],
                path
            )
            print('Ok')
            # Get restored placeholders
            labels_data_ph = graph2.get_tensor_by_name('labels_data_ph:0')
            features_data_ph = graph2.get_tensor_by_name('features_data_ph:0')
            batch_size_ph = graph2.get_tensor_by_name('batch_size_ph:0')
            # Get restored model output
            restored_logits = graph2.get_tensor_by_name('dense/BiasAdd:0')
            # Get dataset initializing operation
            dataset_init_op = graph2.get_operation_by_name('dataset_init')

            # Initialize restored dataset
            sess.run(
                dataset_init_op,
                feed_dict={
                    features_data_ph: features,
                    labels_data_ph: labels,
                    batch_size_ph: 32
                }

            )
            # Compute inference for both batches in dataset
            restored_values = []
            for i in range(2):
                restored_values.append(sess.run(restored_logits))
                print('Restored values: ', restored_values[i][0])

    # Check if original inference and restored inference are equal
    valid = all((v == rv).all() for v, rv in zip(values, restored_values))
    print('\nInferences match: ', valid)

这将打印:

$ python3 save_and_restore.py

Epoch 0, batch 0 | Sample value: [-0.13851789 -0.3087595   0.12804556  0.20013677 -0.08229901]
Epoch 0, batch 1 | Sample value: [-0.00555491 -0.04339041 -0.05111827 -0.2480045  -0.00107776]
Epoch 1, batch 0 | Sample value: [-0.19321944 -0.2104792  -0.00602257  0.07465433  0.11674127]
Epoch 1, batch 1 | Sample value: [-0.05275984  0.05981954 -0.15913513 -0.3244143   0.10673307]
Epoch 2, batch 0 | Final inference | Sample value: [-0.26331693 -0.13013336 -0.12553    -0.04276478  0.2933622 ]
Epoch 2, batch 1 | Final inference | Sample value: [-0.07730117  0.11119192 -0.20817074 -0.35660955  0.16990358]

Saving...
INFO:tensorflow:Assets added to graph.
INFO:tensorflow:No assets to write.
INFO:tensorflow:SavedModel written to: b'/some/path/simple/saved_model.pb'
Ok

Restoring...
INFO:tensorflow:Restoring parameters from b'/some/path/simple/variables/variables'
Ok
Restored values:  [-0.26331693 -0.13013336 -0.12553    -0.04276478  0.2933622 ]
Restored values:  [-0.07730117  0.11119192 -0.20817074 -0.35660955  0.16990358]

Inferences match:  True

1
我是初学者,需要更多的解释...:如果我有一个CNN模型,我应该只存储1.输入占位符2.标签占位符和3.CNN输出?还是所有中间的tf.contrib.layers都要存储? - VimNing
3
图表已完全恢复。您可以运行 [n.name for n in graph2.as_graph_def().node] 进行检查。正如文档所述,简单保存旨在简化与 TensorFlow 服务的交互,这就是参数的重点;但其他变量仍然会被恢复,否则将无法进行推断。就像我在示例中所做的那样,只需获取您感兴趣的变量即可。请查看文档 - ted
2
不错吧,但是它也能与Eager模式模型和tfe.Saver一起使用吗? - Geoffrey Anderson
1
如果没有将global_step作为参数,如果您停止然后尝试重新开始训练,它会认为您是从头开始的。这将至少破坏您的TensorBoard可视化。 - Monica Heddneck
2
我正在尝试调用restore,但出现了这个错误“ValueError: No variables to save”。有人可以帮忙吗? - Elaine Chen
显示剩余15条评论

130

对于 TensorFlow 版本 < 0.11.0RC1:

保存的检查点包含了您模型中 Variable 的值,而不是模型/图本身,这意味着当您恢复检查点时,图应该是相同的。

这里有一个线性回归的例子,其中有一个训练循环会保存变量检查点,并且有一个评估部分,将恢复在之前运行中保存的变量并计算预测结果。当然,如果需要,您也可以恢复变量并继续训练。

x = tf.placeholder(tf.float32)
y = tf.placeholder(tf.float32)

w = tf.Variable(tf.zeros([1, 1], dtype=tf.float32))
b = tf.Variable(tf.ones([1, 1], dtype=tf.float32))
y_hat = tf.add(b, tf.matmul(x, w))

...more setup for optimization and what not...

saver = tf.train.Saver()  # defaults to saving all variables - in this case w and b

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    if FLAGS.train:
        for i in xrange(FLAGS.training_steps):
            ...training loop...
            if (i + 1) % FLAGS.checkpoint_steps == 0:
                saver.save(sess, FLAGS.checkpoint_dir + 'model.ckpt',
                           global_step=i+1)
    else:
        # Here's where you're restoring the variables w and b.
        # Note that the graph is exactly as it was when the variables were
        # saved in a prior training run.
        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            ...no checkpoint found...

        # Now you can run the model to get predictions
        batch_x = ...load some data...
        predictions = sess.run(y_hat, feed_dict={x: batch_x})

这里是用于保存和恢复的 Variable文档。此外,这里还有用于保存和恢复的Saver文档


1
FLAGS是用户定义的。以下是定义它们的示例:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/tutorials/mnist/fully_connected_feed.py - Ryan Sepassi
batch_x需要以哪种格式呈现?二进制?Numpy数组? - pepe
@pepe Numpy数组应该是可以的。元素的类型应该对应于占位符的类型。[链接]https://www.tensorflow.org/versions/r0.9/api_docs/python/framework.html#tensor-types - Donny
FLAGS 给出了“未定义”的错误。你能告诉我这段代码中 FLAGS 的定义是什么吗?@RyanSepassi - Muhammad Hannan
为了明确起见:Tensorflow的最新版本确实允许存储模型/图形。[对于0.11限制的哪些方面适用不清楚。鉴于得到了大量赞同,我倾向于相信这个通用语句在最新版本中仍然是正确的。] - bluenote10
显示剩余2条评论

84

我的环境: Python 3.6, Tensorflow 1.3.0

尽管有许多解决方案,但其中大多数都基于tf.train.Saver。当我们加载由Saver保存的.ckpt文件时,我们必须重新定义tensorflow网络或使用一些奇怪且难以记住的名称,例如'placehold_0:0''dense/Adam/Weight:0'。在这里,我建议使用tf.saved_model,下面给出了一个最简单的示例,您可以从Serving a TensorFlow Model中了解更多:

保存模型:

import tensorflow as tf

# define the tensorflow network and do some trains
x = tf.placeholder("float", name="x")
w = tf.Variable(2.0, name="w")
b = tf.Variable(0.0, name="bias")

h = tf.multiply(x, w)
y = tf.add(h, b, name="y")
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# save the model
export_path =  './savedmodel'
builder = tf.saved_model.builder.SavedModelBuilder(export_path)

tensor_info_x = tf.saved_model.utils.build_tensor_info(x)
tensor_info_y = tf.saved_model.utils.build_tensor_info(y)

prediction_signature = (
  tf.saved_model.signature_def_utils.build_signature_def(
      inputs={'x_input': tensor_info_x},
      outputs={'y_output': tensor_info_y},
      method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))

builder.add_meta_graph_and_variables(
  sess, [tf.saved_model.tag_constants.SERVING],
  signature_def_map={
      tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
          prediction_signature 
  },
  )
builder.save()

加载模型:

import tensorflow as tf
sess=tf.Session() 
signature_key = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
input_key = 'x_input'
output_key = 'y_output'

export_path =  './savedmodel'
meta_graph_def = tf.saved_model.loader.load(
           sess,
          [tf.saved_model.tag_constants.SERVING],
          export_path)
signature = meta_graph_def.signature_def

x_tensor_name = signature[signature_key].inputs[input_key].name
y_tensor_name = signature[signature_key].outputs[output_key].name

x = sess.graph.get_tensor_by_name(x_tensor_name)
y = sess.graph.get_tensor_by_name(y_tensor_name)

y_out = sess.run(y, {x: 3.0})

4
对于这个很好的SavedModel API示例,我给予点赞。但是,我希望你的“保存模型”部分能像Ryan Sepassi的回答一样展示一个训练循环!我意识到这是一个老问题,但这个回复是我在谷歌上找到的为数不多(也很有价值)的SavedModel示例之一。 - Dylan F
@Tom 这是一个很好的答案 - 只针对新的 SavedModel。你能看一下这个 SavedModel 的问题吗?https://stackoverflow.com/questions/48540744/tensorflow-savedmodel-how-to-iterative-save - bluesummers
现在使用TF Eager模型使其全部正常工作。Google在他们2018年的演示中建议每个人都远离TF图形代码。 - Geoffrey Anderson

55
这个模型有两个部分:模型定义和张量的数值。模型定义由Supervisor保存在模型目录中的graph.pbtxt文件中,而张量的数值则保存在checkpoint文件(例如model.ckpt-1003418)中。
可以使用tf.import_graph_def来恢复模型定义,而使用Saver来恢复张量的数值。
然而,Saver使用特殊的集合来持有变量列表,并将其附加到模型图上。这个集合不是使用import_graph_def初始化的,所以目前无法同时使用这两种方法(我们正在努力解决这个问题)。目前,你必须使用Ryan Sepassi的方法--手动构建一个具有相同节点名称的图形,并使用Saver将权重加载到其中。
(或者,您可以通过使用import_graph_def,手动创建变量,并为每个变量使用tf.add_to_collection(tf.GraphKeys.VARIABLES, variable)的方法进行处理,然后再使用Saver)。

在使用inceptionv3的classify_image.py示例中,只有graphdef被加载。这是否意味着现在GraphDef也包含Variable? - jrabary
1
@jrabary 这个模型可能已经被冻结 - Eric Platon
1
嗨,我是新手使用TensorFlow,在保存模型方面遇到了麻烦。如果您能帮助我,我将不胜感激。 https://stackoverflow.com/questions/48083474/finish-tensorflow-training-in-progress - Ruchir Baronia

38

您也可以选择更简单的方式。

步骤1:初始化所有变量

W1 = tf.Variable(tf.truncated_normal([6, 6, 1, K], stddev=0.1), name="W1")
B1 = tf.Variable(tf.constant(0.1, tf.float32, [K]), name="B1")

Similarly, W2, B2, W3, .....

步骤2:将会话保存在Saver模型中并进行保存

model_saver = tf.train.Saver()

# Train the model and save it in the end
model_saver.save(session, "saved_models/CNN_New.ckpt")

步骤三:恢复模型

with tf.Session(graph=graph_cnn) as session:
    model_saver.restore(session, "saved_models/CNN_New.ckpt")
    print("Model restored.") 
    print('Initialized')

步骤四:检查你的变量

W1 = session.run(W1)
print(W1)
在不同的 Python 实例中运行时,请使用:

with tf.Session() as sess:
    # Restore latest checkpoint
    saver.restore(sess, tf.train.latest_checkpoint('saved_model/.'))

    # Initalize the variables
    sess.run(tf.global_variables_initializer())

    # Get default graph (supply your custom graph if you have one)
    graph = tf.get_default_graph()

    # It will give tensor object
    W1 = graph.get_tensor_by_name('W1:0')

    # To get the value (numpy array)
    W1_value = session.run(W1)

你好,我想请问如何在训练模型3000次后保存模型,类似于Caffe。我发现TensorFlow只会保存最后一个模型,尽管我将迭代次数与模型连接起来以区分所有迭代。我的意思是model_3000.ckpt,model_6000.ckpt,--- model_100000.ckpt。您能否解释一下为什么它不会保存所有模型,而只保存最后3个迭代的模型呢?谢谢。 - khan
2
@khan 请参考 https://dev59.com/JZjga4cB1Zd3GeqPOsDE - Himanshu Babal
3
能否获取图中保存的所有变量/操作名称的方法是什么? - Moondra

21

在大多数情况下,使用 tf.train.Saver 进行磁盘保存和恢复是最佳选择:

... # build your model
saver = tf.train.Saver()

with tf.Session() as sess:
    ... # train the model
    saver.save(sess, "/tmp/my_great_model")

with tf.Session() as sess:
    saver.restore(sess, "/tmp/my_great_model")
    ... # use the model

您还可以保存/恢复图结构本身(有关详细信息,请参阅MetaGraph文档)。默认情况下,Saver将图结构保存到.meta文件中。您可以调用import_meta_graph()进行恢复。它会恢复图结构并返回一个Saver,您可以使用它来恢复模型的状态:

saver = tf.train.import_meta_graph("/tmp/my_great_model.meta")

with tf.Session() as sess:
    saver.restore(sess, "/tmp/my_great_model")
    ... # use the model

然而,有些情况下您需要更快的方法。例如,如果您实现了早停机制,您希望在训练过程中每次模型(根据验证集的度量结果)改进时保存检查点,然后如果一段时间内没有进展,您希望回滚到最佳模型。如果每次改进都将模型保存到磁盘上,将会大大减慢训练速度。诀窍是将变量状态保存到 内存,然后稍后恢复它们:

... # build your model

# get a handle on the graph nodes we need to save/restore the model
graph = tf.get_default_graph()
gvars = graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
assign_ops = [graph.get_operation_by_name(v.op.name + "/Assign") for v in gvars]
init_values = [assign_op.inputs[1] for assign_op in assign_ops]

with tf.Session() as sess:
    ... # train the model

    # when needed, save the model state to memory
    gvars_state = sess.run(gvars)

    # when needed, restore the model state
    feed_dict = {init_value: val
                 for init_value, val in zip(init_values, gvars_state)}
    sess.run(assign_ops, feed_dict=feed_dict)

简要说明:当你创建一个变量 X 时,TensorFlow会自动创建一个赋值操作 X/Assign 来设置变量的初始值。我们不需要创建占位符和额外的赋值操作(这只会使图表混乱),而是直接使用这些现有的赋值操作。每个赋值操作的第一个输入是它所初始化的变量的引用,第二个输入(assign_op.inputs[1])是初始值。因此,为了设置任何我们想要的值(而不仅仅是初始值),我们需要使用一个 feed_dict 并替换初始值。是的,TensorFlow允许您为任何操作提供值,而不仅仅是占位符,所以这很好用。


感谢您的回答。我有一个类似的问题,即如何将单个.ckpt文件转换为两个.index和.data文件(例如tf.slim中提供的预训练inception模型)。我的问题在这里:https://stackoverflow.com/questions/47762114/converting-a-pb-file-to-meta-in-tf-1-3 - Amir
嗨,我是tensorflow的新手,正在保存模型时遇到了麻烦。如果您能帮助我,我将不胜感激。https://stackoverflow.com/questions/48083474/finish-tensorflow-training-in-progress - Ruchir Baronia

17

正如 Yaroslav 所说,您可以通过导入图形、手动创建变量,然后使用 Saver 来黑客式地从 graph_def 和 checkpoint 进行恢复。

我为了个人使用而实现了这一点,所以我想在这里分享代码。

链接: https://gist.github.com/nikitakit/6ef3b72be67b86cb7868

(当然,这是一种黑客方式,不能保证以这种方式保存的模型将来版本的 TensorFlow 中仍然可读。)


14

如果是内部保存的模型,您只需为所有变量指定一个还原器,如下所示:

restorer = tf.train.Saver(tf.all_variables())

并将其用于恢复当前会话中的变量:

restorer.restore(self._sess, model_file)

为了外部模型,您需要指定其变量名称与您的变量名称之间的映射关系。 您可以使用以下命令查看模型变量名称

python /path/to/tensorflow/tensorflow/python/tools/inspect_checkpoint.py --file_name=/path/to/pretrained_model/model.ckpt

inspect_checkpoint.py脚本可以在Tensorflow源代码的'./tensorflow/python/tools'文件夹中找到。

要指定映射,您可以使用我的Tensorflow-Worklab,其中包含一组类和脚本来训练和重新训练不同的模型。它包括重新训练ResNet模型的示例,位于这里


all_variables()现已过时。 - MiniQuark
嗨,我是tensorflow的新手,遇到了保存模型的困难。如果您能帮助我,我会非常感激。https://stackoverflow.com/questions/48083474/finish-tensorflow-training-in-progress - Ruchir Baronia

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接