在TensorFlow中,optimizer.compute_gradient()和tf.gradients()有什么区别?

10

我写的以下代码,在self.optimizer.compute_gradients(self.output,all_variables)这一行失败。

import tensorflow as tf
import tensorlayer as tl
from tensorflow.python.framework import ops
import numpy as np

class Network1():


def __init__(self):
    ops.reset_default_graph()
    tl.layers.clear_layers_name()

    self.sess = tf.Session()
    self.optimizer = tf.train.AdamOptimizer(learning_rate=0.1)

    self.input_x = tf.placeholder(tf.float32, shape=[None, 784],name="input")  

    input_layer = tl.layers.InputLayer(self.input_x)        

    relu1 = tl.layers.DenseLayer(input_layer, n_units=800, act = tf.nn.relu, name="relu1")
    relu2 = tl.layers.DenseLayer(relu1, n_units=500, act = tf.nn.relu, name="relu2")

    self.output = relu2.all_layers[-1]
    all_variables = relu2.all_layers

    self.gradient = self.optimizer.compute_gradients(self.output,all_variables)

    init_op = tf.initialize_all_variables()
    self.sess.run(init_op)

带有警告,

TypeError: 参数不是tf.Variable类型:Tensor("relu1/Relu:0", shape=(?, 800), dtype=float32)

然而,当我将该行代码更改为tf.gradients(self.output,all_variables)时,代码可以正常工作,至少没有报出警告。我错在哪里了?因为我认为这两种方法实际上执行的是相同的操作,即返回(梯度、变量)对列表。


什么是 tensorlayers?我们有 tf.contrib.layers - drpng
2个回答

5

optimizer.compute_gradients是对tf.gradients()的封装。你可以在这里看到它执行了额外的断言(这也解释了你的错误)。


谢谢您的回答,但是我们为什么需要额外的断言? - ytutow

5
我想通过一个简单的例子来补充上面的答案。optimizer.compute_gradients 返回一个元组列表,包括 (grads, vars) 对。变量始终存在,但梯度可能为 None。这是有道理的,因为计算特定 var_list 中某些变量相对于某些变量的 loss 的梯度可能会是 None。这表示没有依赖关系。
另一方面,tf.gradients 仅返回每个变量的 sum(dy/dx) 列表。它必须伴随着变量列表才能应用梯度更新。
因此,以下两种方法可以互换使用:
        ### Approach 1 ###
        variable_list = desired_list_of_variables
        gradients = optimizer.compute_gradients(loss,var_list=variable_list)
        optimizer.apply_gradients(gradients)

        # ### Approach 2 ###
        variable_list = desired_list_of_variables
        gradients = tf.gradients(loss, var_list=variable_list)
        optimizer.apply_gradients(zip(gradients, variable_list))

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接