如何使用tf.constant或numpy数组初始化tf.Variable?

4
我正在尝试在tf.InteractiveSession()中初始化一个tf.Variable()。我已经有一些预训练的权重,它们是单独的numpy文件。如何使用这些numpy值有效地初始化变量?我已经尝试了以下选项:1. 使用tf.assign() 2. 在tf.Variable()创建过程中直接使用sess.run()。似乎数值没有正确地初始化。以下是我尝试的一些代码,请告诉我哪个是正确的?
def read_numpy(file):
    return np.fromfile(file,dtype='f')

def build_network():
    with tf.get_default_graph().as_default():
        x = tf.Variable(tf.constant(read_numpy('foo.npy')),name='var1')
        sess = tf.get_default_session()
        with sess.as_default():
            sess.run(tf.global_variables_initializer())

sess = tf.InteractiveSession()
with sess.as_default():
    build_network()

这是正确的做法吗?我已经打印了session对象,它是整个过程中使用的同一个会话。

编辑:目前似乎使用sess.run(tf.global_variables_initializer())会调用一个随机初始化操作。

1个回答

3

tf.Variable()接受numpy数组作为初始值:

import tensorflow as tf
import numpy as np

init = np.ones((2, 2))
x = tf.Variable(init) # <-- set initial value to assign to a variable

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer()) # <-- this will assign the init value
    print(x.eval())
# [[1. 1.]
#  [1. 1.]]

只需使用numpy数组进行初始化,无需首先将其转换为张量。

或者,您还可以使用tf.Variable.load()将numpy数组的值分配给会话上下文中的变量:

import tensorflow as tf
import numpy as np

x = tf.Variable(tf.zeros((2, 2)))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    init = np.ones((2, 2))
    x.load(init)
    print(x.eval())
# [[1. 1.]
#  [1. 1.]]

使用 tf.get_variable() 好吗?我用了它,而且它有效。我的属性是 initializer=tf.constant(read_numpy('foo.npy')) - lamo_738
可以的。tf.Variable()tf.get_variable()之间的区别在于,tf.Variable()创建一个新变量,而get_variable()可能会创建一个新变量,也可能返回一个现有的变量。更多信息请阅读此答案 - Vlad
哦,好的,第一次使用tf.get_variable()时出现了错误,因为命名冲突了...但是后来我将它包含在tf.variable_scope()中,问题就解决了。感谢您的回复! - lamo_738
很高兴能帮到你。我仍然会使用 x = tf.Variable(read_numpy('foo.npy'))。这样更简洁。 - Vlad

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接