我已经在Tensorflow中训练了我的神经网络,并按照以下方式保存了模型:
def neural_net(x):
layer_1 = tf.layers.dense(inputs=x, units=195, activation=tf.nn.sigmoid)
out_layer = tf.layers.dense(inputs=layer_1, units=6)
return out_layer
train_x = pd.read_csv("data_x.csv", sep=" ")
train_y = pd.read_csv("data_y.csv", sep=" ")
train_x = train_x / 6 - 0.5
train_size = 0.9
train_cnt = int(floor(train_x.shape[0] * train_size))
x_train = train_x.iloc[0:train_cnt].values
y_train = train_y.iloc[0:train_cnt].values
x_test = train_x.iloc[train_cnt:].values
y_test = train_y.iloc[train_cnt:].values
x = tf.placeholder("float", [None, 386])
y = tf.placeholder("float", [None, 6])
nn_output = neural_net(x)
cost = tf.reduce_mean(tf.losses.mean_squared_error(labels=y, predictions=nn_output))
optimizer = tf.train.AdamOptimizer(learning_rate=0.001).minimize(cost)
training_epochs = 5000
display_step = 1000
batch_size = 30
keep_prob = tf.placeholder("float")
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for epoch in range(training_epochs):
total_batch = int(len(x_train) / batch_size)
x_batches = np.array_split(x_train, total_batch)
y_batches = np.array_split(y_train, total_batch)
for i in range(total_batch):
batch_x, batch_y = x_batches[i], y_batches[i]
_, c = sess.run([optimizer, cost],
feed_dict={
x: batch_x,
y: batch_y,
keep_prob: 0.8
})
saver.save(sess, 'trained_model', global_step=1000)
现在我想在另一个文件中使用已经训练好的模型。当然有很多恢复和保存模型的例子,我浏览了很多但是还是没能够使其工作,总是会出现某种错误。所以这是我的恢复文件,请你帮我使其恢复保存的模型吧。
saver = tf.train.import_meta_graph('trained_model-1000.meta')
y_pred = []
with tf.Session() as sess:
saver.restore(sess, tf.train.latest_checkpoint('./'))
sess.run([y_pred], feed_dict={x: input_values})
例如这次尝试给了我一个错误信息:“The session graph is empty. Add operations to the graph before calling run()。”那么我应该在图表中添加什么操作以及如何添加呢?我不知道在我的模型中应该添加什么操作...我不理解Tensorflow中保存/恢复的整个概念。或者我应该完全以不同的方式进行恢复吗?感谢您提前的帮助!
saver = tf.train.import_meta_graph('trained_model-1000.meta')
放在with tf.Session() as sess:
里面吗?也许之前加上tf.reset_default_graph()
以确保一切都清空了... - gdelabneural_net(x)
,然后按照我的问题或CAta.RAy下面的答案恢复其数据。最后你可以像Alli Abbasi的回答中那样进行预测。 - T.Poe