Tensorflow:如何在应用程序中使用已训练好的模型?

13
我已经训练好一个Tensorflow模型,现在我想将这个“函数”导出来在我的Python程序中使用。这是否可行,如果可以,如何实现?任何帮助都将不胜感激,因为在文档中没有找到太多相关信息。(我不想保存会话!)
我现在按照您的建议已经将会话存储起来了。我正在这样加载它:
f = open('batches/batch_9.pkl', 'rb')
input = pickle.load(f)
f.close()
sess = tf.Session()

saver = tf.train.Saver()
saver.restore(sess, 'trained_network.ckpt')
y_pred = []

sess.run(y_pred, feed_dict={x: input})

print(y_pred)

然而,当我尝试初始化saver时,会出现“没有变量可保存”的错误。

我的目的是:我正在为一个棋盘游戏编写机器人程序,输入是格式化成张量的棋盘状态。现在我想返回一个张量,它给我推荐下一步最好的位置,即一个张量在每个位置都是0,但在一个位置上是1。


你可以将所有变量都存储起来。 - Natecat
为什么你不想保存一个会话? - aseipel
我想使用网络进行预测,但我不知道如何在会话中实现。 - Teywazz
你能展示一下你的代码吗?你应该能够恢复你的会话,然后使用session.run()来计算所需的张量。 - aseipel
2个回答

9

我不知道是否还有其他的方法可以实现,但你可以通过保存会话,在另一个Python程序中使用你的模型:

你的训练代码:

# build your model

sess = tf.Session()
# train your model
saver = tf.train.Saver()
saver.save(sess, 'model/model.ckpt')

在你的应用程序中:

# build your model (same as training)
sess = tf.Session()
saver = tf.train.Saver()
saver.restore(sess, 'model/model.ckpt')

您可以使用feed_dict来评估模型中的任何张量。这显然取决于您的模型。例如:
#evaluate tensor
sess.run(y_pred, feed_dict={x: input_data})

0
加载模型,如果模型不存在,则训练并保存模型。
if os.path.exists('MNIST.h5'):
    model = tf.keras.models.load_model('MNIST.h5')
else:
    model.fit(train_X,train_y, epochs=30)
    model.save('MNIST.h5')

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接