如何获取变量的当前值?

40

假设我们有一个变量:

x = tf.Variable(...)

使用assign()方法可以在训练过程中更新此变量的值。

获取变量当前值的最佳方法是什么?

我知道我们可以使用以下方法:

session.run(x)

但我担心这会触发一整个操作链。

在Theano中,您只需执行以下操作:

y = theano.shared(...)
y_vals = y.get_value()

我正在寻找在TensorFlow中等效的东西。

4个回答

39
变量的值只能通过在“会话”中运行它来获得。在常见问题解答中写道:

Tensor对象是对操作结果的符号句柄,但实际上不保存操作输出的值。

因此TF的等效方式将是:

import tensorflow as tf

x = tf.Variable([1.0, 2.0])

init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    v = sess.run(x)
    print(v)  # will show you your variable.

使用 init = global_variables_initializer() 这一部分非常重要,应该在初始化变量时完成。

如果你在 IPython 中工作,请查看 InteractiveSession


并且需要非常清楚的是:运行变量只会生成变量当前的值;它不会运行任何与其相关的赋值操作。这很便宜。 - dga
@dga 是的,如果变量依赖于其他n个变量,则它们也需要被评估。如果您想获取多个值的值,则可以执行以下操作:a,b = sess.run([v1,v2]) - Salvador Dali

24

一般而言,session.run(x) 只会评估计算 x 所必需的节点,其他节点不会被评估,因此如果您想检查变量的值,它应该是相对便宜的。

要了解更多上下文,请参阅这个很棒的答案:https://dev59.com/9VwX5IYBdhLWcg3wnwcy#33610914


17

tf.Print可以让你的生活更简单!

tf.Print会在调用它的代码中评估时打印您告诉它打印的张量的值。

例如:

import tensorflow as tf
x = tf.Variable([1.0, 2.0])
x = tf.Print(x,[x])
x = 2* x

tf.initialize_all_variables()

sess = tf.Session()
sess.run()

[1.0 2.0]

这是因为它会在tf.Print行被执行的时候打印出x的值。如果您改为:

v = x.eval()
print(v)

你将会得到:

[2.0 4.0 ]

因为它将给出x的最终值。


1
由于在tensorflow 2.0.0中取消了tf.Variable(),如果你想从一个张量(比如 "net")中提取值,可以使用以下方法:
net.[tf.newaxis,:,:].numpy().

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接