什么是TensorFlow中的局部变量?

19

TensorFlow有以下API:

tf.local_variables()

返回所有使用 collection=[LOCAL_VARIABLES] 创建的变量。

返回:

本地Variable对象的列表。

在TensorFlow中,什么是本地变量?能否给我一个例子?


请查看此问题,其中有一个本地变量在使用前必须进行初始化。 - Rfank2019
3个回答

27
简短回答:在TF中,局部变量是使用collections=[tf.GraphKeys.LOCAL_VARIABLES]创建的任何变量。例如:
e = tf.Variable(6, name='var_e', collections=[tf.GraphKeys.LOCAL_VARIABLES])

LOCAL_VARIABLES: 指每台计算机本地的Variable对象子集。通常用于临时变量,如计数器等。注意:使用tf.contrib.framework.local_variable将其添加到此集合中。

它们通常不会保存/还原到检查点,并用于临时或中间值。


长答案:这对我来说也是一个困惑的来源。起初,我认为本地变量意味着与几乎任何编程语言中的本地变量相同,但这并不是同一回事:

import tensorflow as tf

def some_func():
    z = tf.Variable(1, name='var_z')

a = tf.Variable(1, name='var_a')
b = tf.get_variable('var_b', 2)
with tf.name_scope('aaa'):
    c = tf.Variable(3, name='var_c')

with tf.variable_scope('bbb'):
    d = tf.Variable(3, name='var_d')

some_func()
some_func()

print [str(i.name) for i in tf.global_variables()]
print [str(i.name) for i in tf.local_variables()]

无论我尝试了什么,我总是只收到全局变量:
['var_a:0', 'var_b:0', 'aaa/var_c:0', 'bbb/var_d:0', 'var_z:0', 'var_z_1:0']
[]

tf.local_variables 的文档没有提供很多细节:

本地变量 - 每个进程的变量,通常不保存/恢复到检查点并用于临时或中间值。例如,它们可以用作度量计算的计数器或此机器读取数据的纪元数。local_variable() 自动将新变量添加到 GraphKeys.LOCAL_VARIABLES 中。该便利函数返回该集合的内容。


但是当我阅读tf.Variable类中init方法的文档时,我发现在创建变量时,你可以通过指定一个collections列表来确定所需的变量类型。
可用的集合元素列表在此处。因此,要创建一个本地变量,你需要像这样做。你将在local_variables列表中看到它:
e = tf.Variable(6, name='var_e', collections=[tf.GraphKeys.LOCAL_VARIABLES])
print [str(i.name) for i in tf.local_variables()]

19

这与常规变量相同,但它位于默认值 (GraphKeys.VARIABLES) 之外的不同集合中。该集合由保存器用于初始化默认的要保存的变量列表,因此指定为 local 将不会默认保存该变量。

根据代码库,我只看到一个地方使用它,即limit_epochs

  with ops.name_scope(name, "limit_epochs", [tensor]) as name:
    zero64 = constant_op.constant(0, dtype=dtypes.int64)
    epochs = variables.Variable(
        zero64, name="epochs", trainable=False,
        collections=[ops.GraphKeys.LOCAL_VARIABLES])

分布式 TensorFlow 在分布复制模式下使用局部变量。 - Rfank2019

11

我认为,在这里需要理解 TensorFlow 集合。

TensorFlow 提供了集合,它们是张量或其他对象(例如 tf.Variable 实例)的命名列表。

以下是内置的集合:

tf.GraphKeys.GLOBAL_VARIABLES               #=> 'variables'                                                                                                                                                                                 
tf.GraphKeys.LOCAL_VARIABLES                #=> 'local_variables'                                                                                                                                                                           
tf.GraphKeys.MODEL_VARIABLES                #=> 'model_variables'                                                                                                                                                                           
tf.GraphKeys.TRAINABLE_VARIABLES            #=> 'trainable_variables' 

通常,在创建变量时,可以将其添加到给定集合中,方法是将该集合作为传递给 collections 参数的集合之一进行显式传递。

理论上,一个变量可以在任何内置或自定义集合中组合。但是,内置集合用于特定目的:

  • tf.GraphKeys.GLOBAL_VARIABLES
    • Variable() 构造函数或 get_variable() 方法会自动将新变量添加到图集合 GraphKeys.GLOBAL_VARIABLES 中,除非显式传递 collections 参数并且不包括 GLOBAL_VARIABLE
    • 按照惯例,这些变量在分布式环境中是共享的(模型变量是其中的子集)。
    • 更多细节请参阅tf.global_variables()
  • tf.GraphKeys.TRAINABLE_VARIABLES
    • 当传递参数 trainable=True(默认行为)时,Variable() 构造函数和 get_variable() 方法会自动将新变量添加到此图集合中。但是,您可以使用 collections 参数将变量添加到任何所需的集合中。
    • 按照惯例,这些是将由优化器训练的变量。
    • 更多细节请参阅tf.trainable_variables()
  • tf.GraphKeys.LOCAL_VARIABLES
    • 您可以使用 tf.contrib.framework.local_variable() 方法将变量添加到此集合中。但是,您可以使用 collections 参数将变量添加到任何所需的集合中。
    • 按照惯例,这些是每台机器本地的变量。它们是进程级别的变量,通常不保存/恢复到检查点,并用于临时或中间值。例如,它们可以用作度量计算的计数器或此机器读取数据的历元数量。
    • 更多细节请参阅tf.local_variables()
  • tf.GraphKeys.MODEL_VARIABLES
    • 您可以使用 tf.contrib.framework.model_variable() 方法将变量添加到此集合中。但是,您可以使用 collections 参数将变量添加到任何所需的集合中。
    • 按照惯例,这些是用于推理(前向传播)模型中的变量。
    • 更多细节请参阅tf.model_variables()

您也可以使用自己的集合。任何字符串都是有效的集合名称,无需显式创建集合。要在创建变量后将变量(或任何其他对象)添加到集合中,请调用tf.add_to_collection()

例如,

tf.__version__                                                            #=> '1.9.0'                                                                                                                                                       

# initializing using a Tensor                                                                                                                                                                                                               
my_variable01 = tf.get_variable("var01", dtype=tf.int32, initializer=tf.constant([23, 42]))                                                                                                                                                 
# initializing using a convenient initializer                                                                                                                                                                                               
my_variable02 = tf.get_variable("var02", shape=[1, 2, 3], dtype=tf.int32, initializer=tf.zeros_initializer)                                                                                                                                 

my_variable03 = tf.get_variable("var03", dtype=tf.int32, initializer=tf.constant([1, 2]), trainable=None)                                                                                                                                   
my_variable04 = tf.get_variable("var04", dtype=tf.int32, initializer=tf.constant([3, 4]), trainable=False)                                                                                                                                  
my_variable05 = tf.get_variable("var05", shape=[1, 2, 3], dtype=tf.int32, initializer=tf.ones_initializer, trainable=True)                                                                                                                  

my_variable06 = tf.get_variable("var06", dtype=tf.int32, initializer=tf.constant([5, 6]), collections=[tf.GraphKeys.LOCAL_VARIABLES], trainable=None)                                                                                       
my_variable07 = tf.get_variable("var07", dtype=tf.int32, initializer=tf.constant([7, 8]), collections=[tf.GraphKeys.LOCAL_VARIABLES], trainable=True)                                                                                       

my_variable08 = tf.get_variable("var08", dtype=tf.int32, initializer=tf.constant(1), collections=[tf.GraphKeys.MODEL_VARIABLES], trainable=None)                                                                                            
my_variable09 = tf.get_variable("var09", dtype=tf.int32, initializer=tf.constant(2), collections=[tf.GraphKeys.GLOBAL_VARIABLES, tf.GraphKeys.LOCAL_VARIABLES, tf.GraphKeys.MODEL_VARIABLES, tf.GraphKeys.TRAINABLE_VARIABLES, "my_collectio
n"])                                                                                                                                                                                                                                        
my_variable10 = tf.get_variable("var10", dtype=tf.int32, initializer=tf.constant(3), collections=["my_collection"], trainable=True)                                                                                                         

[var.name for var in tf.global_variables()]                               #=> ['var01:0', 'var02:0', 'var03:0', 'var04:0', 'var05:0', 'var09:0']                                                                                            
[var.name for var in tf.local_variables()]                                #=> ['var06:0', 'var07:0', 'var09:0']                                                                                                                             
[var.name for var in tf.trainable_variables()]                            #=> ['var01:0', 'var02:0', 'var05:0', 'var07:0', 'var09:0', 'var10:0']                                                                                            
[var.name for var in tf.model_variables()]                                #=> ['var08:0', 'var09:0']                                                                                                                                        
[var.name for var in tf.get_collection("trainable_variables")]            #=> ['var01:0', 'var02:0', 'var05:0', 'var07:0', 'var09:0', 'var10:0']                                                                                            
[var.name for var in tf.get_collection("my_collection")]                  #=> ['var09:0', 'var10:0']                                                                                                                                        

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接