Tensorflow:在类中创建图并在外部运行

10

我认为我在理解TensorFlow中的图表以及如何访问它们方面遇到了困难。我的直觉是,在“with graph:”下面的代码将形成一个单独的实体作为图表。 因此,我决定创建一个类,在实例化时构建一个图表,并拥有一个运行图表的函数,如下所示;

class Graph(object):

    #To build the graph when instantiated
    def __init__(self, parameters ):
        self.graph = tf.Graph()
        with self.graph.as_default():
             ...
             prediction = ... 
             cost       = ...
             optimizer  = ...
             ...
    # To launch the graph
    def launchG(self, inputs):
        with tf.Session(graph=self.graph) as sess:
             ...
             sess.run(optimizer, feed_dict)
             loss = sess.run(cost, feed_dict)
             ...
        return variables
下一步是创建一个主文件,用于汇总要传递给类的参数、构建图形,然后运行它。
#Main file
...
parameters_dict = { 'n_input': 28, 'learnRate': 0.001, ... }

#Building graph
G = Graph(parameters_dict)
P = G.launchG(Input)
...

对我来说这太优雅了,但它并没有完全起作用(显然)。事实上,似乎launchG函数无法访问图中定义的节点,这会导致出现错误,例如;

---> 26 sess.run(optimizer, feed_dict)

NameError: name 'optimizer' is not defined
也许是我的Python(和TensorFlow)理解太有限了,但我曾经奇怪地认为,使用创建的图形(G)作为参数运行会话应该可以访问其中的节点,而无需要求我明确访问。有什么启示吗?
1个回答

14

节点predictioncostoptimizer是在方法__init__中创建的局部变量,无法在方法launchG中访问。

最简单的修复方法是将它们声明为您的类Graph的属性:

class Graph(object):

    #To build the graph when instantiated
    def __init__(self, parameters ):
        self.graph = tf.Graph()
        with self.graph.as_default():
             ...
             self.prediction = ... 
             self.cost       = ...
             self.optimizer  = ...
             ...
    # To launch the graph
    def launchG(self, inputs):
        with tf.Session(graph=self.graph) as sess:
             ...
             sess.run(self.optimizer, feed_dict)
             loss = sess.run(self.cost, feed_dict)
             ...
        return variables

您还可以使用其确切名称使用 graph.get_tensor_by_namegraph.get_operation_by_name 检索图的节点。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接