Problem with tf.train.Saver() and GPU - TensorFlow


My code is structured as follows:

with tf.device('/gpu:1'):
    ...
    model = get_model(input_pl)
    ...
    with tf.Session() as sess:
        saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())
        for epoch in range(num_epochs):
            ...
            for n in range(num_batches):
                ...
                sess.run(...)
            # eval epoch
        saver.save(sess, ...)

I want to save the model after the training phase. When I run it, I get the following error:
InvalidArgumentError (see above for traceback): Cannot assign a device for operation 'save/SaveV2': Could not satisfy explicit device specification '/device:GPU:1' because no supported kernel for GPU devices is available.

After reading this issue, I changed my code like this:

saver = tf.train.Saver()
with tf.device('/gpu:1'):
    ...
    model = get_model(pointcloud_pl)
    ...
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for epoch in range(num_epochs):
            ...
            for n in range(num_batches):
                ...
                sess.run(...)
            # eval epoch
        saver.save(sess, ...)

But now I get this error:

ValueError: No variables to save

I also tried it this way:

with tf.Session() as sess:
    saver = tf.train.Saver()
    ...
    with tf.device('/gpu:1'):
        sess.run(tf.global_variables_initializer())
        for epoch in range(num_epochs):
            ...
            for n in range(num_batches):
                ...
                sess.run(...)
            # eval epoch
        saver.save(sess, ...)

I still get the same error. It is always raised at the line saver = tf.train.Saver().

How can I fix this?


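Where do you build the graph? Could you add a comment to your code showing where the graph is built? - Chan Kha Vu
Edited the first code block: the graph is built between with tf.device(): and tf.Session(). - User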
1 Answer


I solved it by arranging the code in this order:

  1. tf.Session()
  2. build the model
  3. saver = tf.train.Saver()
  4. with tf.device():
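Step 3 has to come after step 2: a tf.train.Saver() constructed without a var_list collects the variables that already exist in the graph at the moment it is created, and raises ValueError: No variables to save if there are none. That is why the two attempts in the question that create the saver before get_model(...) fail. The minimal sketch below reproduces this. It assumes the TF1-compatible API tf.compat.v1 (so it also runs under TensorFlow 2.x), and the single variable w is a hypothetical stand-in for the real model.

```python
import tensorflow as tf

# Assumption: TF1-style graph code run through tf.compat.v1.
tf1 = tf.compat.v1
tf1.disable_eager_execution()

graph = tf.Graph()
with graph.as_default():
    # Saver created before any variable exists: nothing to collect, so it raises.
    try:
        tf1.train.Saver()
        early_error = None
    except ValueError as e:
        early_error = str(e)

    # Hypothetical stand-in for get_model(): creates one variable.
    w = tf1.get_variable("w", shape=[4, 2])

    # Saver created after the model: it now finds `w` and builds without error.
    saver = tf1.train.Saver()

print(early_error)
```

The same ordering rule applies whatever get_model builds: the Saver sees only variables created before it.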

Here is example code:

with tf.Session() as sess:
    ...
    model = get_model(input_pl)  # build the graph first
    saver = tf.train.Saver()     # then create the Saver, so it sees the model's variables
    ...
    with tf.device('/gpu:1'):
        sess.run(tf.global_variables_initializer())
        for epoch in range(num_epochs):
            ...
            for n in range(num_batches):
                ...
                sess.run(...)
            # eval epoch
        saver.save(sess, ...)
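tf.device() only affects operations created inside its block while the graph is being built. Wrapping sess.run(...) calls in it does not change where anything runs, so in the ordering above the model's ops are not pinned to /gpu:1 and TensorFlow uses its default placement. If the goal is to keep the model on /gpu:1, one possible arrangement is sketched below. Put only the model-building code inside the device scope, then create the Saver afterwards and outside that scope, so that save/SaveV2 (which, per the error message, has no GPU kernel) is not pinned to the GPU. This is a sketch under assumptions not taken from the original: it uses tf.compat.v1, a hypothetical one-layer get_model, a hypothetical squared-error loss, random placeholder data, and a temporary checkpoint directory. allow_soft_placement=True lets TensorFlow fall back to another device when an op cannot run where it was pinned, for example on a machine without a second GPU.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Assumption: TF1-style graph code run through tf.compat.v1.
tf1 = tf.compat.v1
tf1.disable_eager_execution()


def get_model(x):
    # Hypothetical stand-in for the question's get_model(): one dense layer.
    w = tf1.get_variable("w", shape=[4, 2])
    b = tf1.get_variable("b", shape=[2], initializer=tf1.zeros_initializer())
    return tf.matmul(x, w) + b


graph = tf.Graph()
with graph.as_default():
    input_pl = tf1.placeholder(tf.float32, shape=[None, 4])
    target_pl = tf1.placeholder(tf.float32, shape=[None, 2])  # hypothetical labels

    # Only graph construction for the model goes inside the device scope.
    with tf.device('/gpu:1'):
        model = get_model(input_pl)
        loss = tf.reduce_mean(tf.square(model - target_pl))  # hypothetical loss
        train_op = tf1.train.GradientDescentOptimizer(0.1).minimize(loss)

    # Created after the variables exist and outside the device scope,
    # so the save ops are not explicitly pinned to the GPU.
    saver = tf1.train.Saver()
    init_op = tf1.global_variables_initializer()

# Fall back to another device if /gpu:1 (or a GPU kernel) is unavailable.
config = tf1.ConfigProto(allow_soft_placement=True)
num_epochs, num_batches = 2, 3  # placeholder values
ckpt_dir = tempfile.mkdtemp()   # assumption: temporary checkpoint location

with tf1.Session(graph=graph, config=config) as sess:
    sess.run(init_op)
    for epoch in range(num_epochs):
        for n in range(num_batches):
            # Random placeholder batch in place of real training data.
            x = np.random.rand(8, 4).astype(np.float32)
            y = np.random.rand(8, 2).astype(np.float32)
            sess.run(train_op, feed_dict={input_pl: x, target_pl: y})
    save_path = saver.save(sess, os.path.join(ckpt_dir, "model.ckpt"))

print(save_path)
```

saver.save returns the checkpoint prefix. With the default V2 format, TensorFlow writes files such as model.ckpt.index and model.ckpt.data-00000-of-00001 next to it.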

Content from Stack Overflow.