TensorFlow权重初始化

16

关于TensorFlow网站上的MNIST教程,我进行了一项实验(gist),以查看不同权重初始化对学习的影响。 我注意到,与我在流行的[Xavier, Glorot 2010]论文中所读到的相反,无论权重初始化如何,学习都很好。

Learning curves for different weight initializations averaged over 3 runs

不同的曲线代表卷积层和全连接层权重初始化中w的不同取值。请注意,即使0.31.0的性能较低并且某些值训练速度更快(特别是0.030.1最快),所有w的值都能正常工作。然而,图表显示了一个相当大的w范围可行,表明在权重初始化方面具有“健壮性”。
def weight_variable(shape, w=0.1):
  initial = tf.truncated_normal(shape, stddev=w)
  return tf.Variable(initial)

def bias_variable(shape, w=0.1):
  initial = tf.constant(w, shape=shape)
  return tf.Variable(initial)

问题: 为什么这个网络不会受到梯度消失或梯度爆炸的问题?

我建议您阅读要点以了解实现细节,但以下是参考代码。在我的 Nvidia 960m 上大约需要一个小时,尽管我想象它也可以在 CPU 上在合理的时间内运行。

import time
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
from tensorflow.python.client import device_lib

import numpy
import matplotlib.pyplot as pyplot

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

# Weight initialization

def weight_variable(shape, w=0.1):
  initial = tf.truncated_normal(shape, stddev=w)
  return tf.Variable(initial)

def bias_variable(shape, w=0.1):
  initial = tf.constant(w, shape=shape)
  return tf.Variable(initial)


# Network architecture

def conv2d(x, W):
  return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')

def max_pool_2x2(x):
  return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],
                    strides=[1, 2, 2, 1], padding='SAME')

def build_network_for_weight_initialization(w):
    """ Builds a CNN for the MNIST-problem:
     - 32 5x5 kernels convolutional layer with bias and ReLU activations
     - 2x2 maxpooling
     - 64 5x5 kernels convolutional layer with bias and ReLU activations
     - 2x2 maxpooling
     - Fully connected layer with 1024 nodes + bias and ReLU activations
     - dropout
     - Fully connected softmax layer for classification (of 10 classes)

     Returns the x, and y placeholders for the train data, the output
     of the network and the dropbout placeholder as a tuple of 4 elements.
    """
    x = tf.placeholder(tf.float32, shape=[None, 784])
    y_ = tf.placeholder(tf.float32, shape=[None, 10])

    x_image = tf.reshape(x, [-1,28,28,1])
    W_conv1 = weight_variable([5, 5, 1, 32], w)
    b_conv1 = bias_variable([32], w)

    h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
    h_pool1 = max_pool_2x2(h_conv1)
    W_conv2 = weight_variable([5, 5, 32, 64], w)
    b_conv2 = bias_variable([64], w)

    h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
    h_pool2 = max_pool_2x2(h_conv2)

    W_fc1 = weight_variable([7 * 7 * 64, 1024], w)
    b_fc1 = bias_variable([1024], w)

    h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
    h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

    keep_prob = tf.placeholder(tf.float32)
    h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

    W_fc2 = weight_variable([1024, 10], w)
    b_fc2 = bias_variable([10], w)

    y_conv = tf.matmul(h_fc1_drop, W_fc2) + b_fc2

    return (x, y_, y_conv, keep_prob)


# Experiment

def evaluate_for_weight_init(w):
    """ Returns an accuracy learning curve for a network trained on
    10000 batches of 50 samples. The learning curve has one item
    every 100 batches."""
    with tf.Session() as sess:
        x, y_, y_conv, keep_prob = build_network_for_weight_initialization(w)
        cross_entropy = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y_conv))
        train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
        correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        sess.run(tf.global_variables_initializer())
        lr = []
        for _ in range(100):
            for i in range(100):
                batch = mnist.train.next_batch(50)
                train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})
            assert mnist.test.images.shape[0] == 10000
            # This way the accuracy-evaluation fits in my 2GB laptop GPU.
            a = sum(
                accuracy.eval(feed_dict={
                    x: mnist.test.images[2000*i:2000*(i+1)],
                    y_: mnist.test.labels[2000*i:2000*(i+1)],
                    keep_prob: 1.0})
                for i in range(5)) / 5
            lr.append(a)
        return lr


ws = [0.0001, 0.0003, 0.001, 0.003, 0.01, 0.03, 0.1, 0.3, 1.0]
accuracies = [
    [evaluate_for_weight_init(w) for w in ws]
    for _ in range(3)
]


# Plotting results

pyplot.plot(numpy.array(accuracies).mean(0).T)
pyplot.ylim(0.9, 1)
pyplot.xlim(0,140)
pyplot.xlabel('batch (x 100)')
pyplot.ylabel('test accuracy')
pyplot.legend(ws)

2
随着网络深度的增加,梯度问题也随之增加。简单地解释你的结果是,类似LeNet的网络浅层结构足够简单,不会过多遭受初始化问题的困扰。如果使用更深的网络,你的观察结果可能会有所不同。 - P-Gn
这也是我的一个假设,但我想要确定或了解可能存在的其他解释。 - Herbert
1
啊,对于这个例子的另一种解释可能是,逻辑函数比ReLU更容易出现梯度消失的情况。如果有人能对此发表评论,那将是非常有价值的。 - Herbert
我检查了你的代码并发现在def evaluate_for_weight_init(w)for i in range(100):中期望一个缩进块,所以代码无法工作。请您检查一下。 - Mario
不工作是指什么意思?快速浏览我2年前发布的代码,表明缩进是正确的。不工作是指出现错误,还是不按预期执行? - Herbert
2个回答

16

权重初始化策略对于提高模型的性能至关重要,但常常被忽视。由于该问题现在是谷歌的首要搜索结果,我认为这篇文章需要更详细的回答。

一般而言,每个层的激活函数梯度、传入/传出连接数(fan_in/fan_out)以及权重方差的乘积应等于1。这样,当你通过网络进行反向传播时,输入和输出梯度之间的方差将保持一致,你就不会遭受梯度消失或爆炸。尽管ReLU更具抗爆炸/消失梯度的能力,但仍可能存在问题。

OP使用的tf.truncated_normal做随机初始化,这鼓励了权重的“不同”更新,但不考虑上述优化策略。在较小的网络中,这可能不是一个问题,但如果你想要更深的网络或更快的训练时间,那么最好根据最近的研究尝试基于权重初始化的策略。

对于ReLU函数之前的权重,可以使用以下默认设置:

tf.contrib.layers.variance_scaling_initializer

对于tanh/sigmoid激活的层,“xavier”可能更合适:

tf.contrib.layers.xavier_initializer

有关这两个函数及其相关论文的更多详细信息,请访问: https://www.tensorflow.org/versions/r0.12/api_docs/python/contrib.layers/initializers

除了权重初始化策略外,进一步的优化可以尝试批量归一化:https://www.tensorflow.org/api_docs/python/tf/nn/batch_normalization


逻辑sigmoid怎么样?使用Xavier初始化方法是否合适?在使用批量归一化时,仍然需要正确地初始化权重吗?激活函数是由哪个方法决定的,是在权重之前还是之后(我猜是之后)?只是有点挑剔 :P - Herbert
好问题。Xavier 应该使用逻辑 sigmoid,ReLU 被证明是特别有问题的(参见 https://arxiv.org/abs/1704.08863)。使用批量归一化与正确的权重初始化相结合,应该可以帮助您从 ~10 层到 ~30 层。之后,您需要开始查看跳过连接。激活函数是接收问题权重的函数(因此在之后)。我更新了答案,并添加了一些相关细节。 - Shane
我注意到我从未花时间接受一个答案,对此感到抱歉并感谢您的帮助! - Herbert

6

逻辑函数更容易出现梯度消失的问题,因为它们的梯度都小于1,所以在反向传播时,乘以越多的逻辑函数,梯度就会变得越小(而且非常快),而ReLU在正部分具有梯度1,因此不会出现这个问题。

此外,您的网络并不够深,不会受到这个问题的影响。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接