TensorFlow卷积神经网络中的全连接层权重维度

9

我一直在跟着这个用TensorFlow编写的卷积神经网络示例进行编码,但是我对这些权重的分配感到困惑:

weights = {

# 5x5 conv, 1 input, 32 outputs
'wc1': tf.Variable(tf.random_normal([5, 5, 1, 32])),

# 5x5 conv, 32 inputs, 64 outputs
'wc2': tf.Variable(tf.random_normal([5, 5, 32, 64])), 

# fully connected, 7*7*64 inputs, 1024 outputs
'wd1': tf.Variable(tf.random_normal([7*7*64, 1024])), 

# 1024 inputs, 10 outputs (class prediction)
'out': tf.Variable(tf.random_normal([1024, n_classes])) 

}

我们怎么知道'wd1'权重矩阵应该有7 x 7 x 64行?

之后它被用来重新塑造第二个卷积层的输出:

# Fully connected layer
# Reshape conv2 output to fit dense layer input
dense1 = tf.reshape(conv2, [-1, _weights['wd1'].get_shape().as_list()[0]]) 

# Relu activation
dense1 = tf.nn.relu(tf.add(tf.matmul(dense1, _weights['wd1']), _biases['bd1']))

根据我的计算,池化层2(conv2输出)有4 x 4 x 64个神经元。

为什么我们要将其重新整形为[-1,7*7*64]?

2个回答

16

从开始工作:

输入变量_X的大小为[28x28x1](忽略批次维度)。这是一个28x28的灰度图像。

第一个卷积层使用PADDING=same,因此输出一个28x28的图层,然后传递给具有k=2max_pool函数,它将每个维度都减小了一半,结果是14x14的空间布局。conv1具有32个输出 - 因此每个示例张量现在为[14x14x32]

这在conv2中重复,它具有64个输出,结果为[7x7x64]

简而言之:图像最初为28x28,在每个维度上通过maxpool的操作每次减少一半。 28/2/2 = 7。


1
你刚刚解决了我对层维度的可怕误解。我完全忽略了PADDING=same,因此感到困惑。谢谢! - jfbeltran
为什么 _X 的维度是 [28x28x1]?不应该是 [28x28] 吗?为什么有额外的 x1 - daniel451
1
卷积运算符期望一个批次,其中条目具有高度、宽度和深度。最后的x1只是一个美学上的重塑。 - dga
这个例子中,有3个卷积层和2个全连接层。卷积层的步幅为1,池化层的步幅为2。因此按照相同的逻辑,第三个卷积层的大小不应该是代码中给出的4,而应该是7/2。 - Anant Gupta
当涉及到多通道数据时会发生什么?(例如RGB图像) - EdgeRover

1
这个问题需要您对深度学习卷积有很好的理解。
基本上,模型中每个卷积层都会减少卷积金字塔横向面积。这种减少是通过卷积步幅最大池化步幅实现的。并且为了使事情更加复杂,我们根据填充(PADDING)有两个选项。
选项1- PADDING='SAME'
out_height = ceil(float(in_height) / float(strides[1]))
out_width  = ceil(float(in_width) / float(strides[2]))

选项2 - PADDING='VALID'
out_height = ceil(float(in_height - filter_height + 1) / float(strides[1]))
out_width  = ceil(float(in_width - filter_width + 1) / float(strides[2]))

对于每个卷积和最大池化调用,您都需要计算新的out_heightout_width。然后,在卷积结束时,将out_heightout_width和您最后一个卷积层的深度相乘。这个乘积的结果是输出特征映射大小,它是第一个全连接层的输入。
因此,在您的示例中,您可能只有PADDING='SAME',卷积步长为1,最大池化步长为2,两次。最后,您只需要将所有内容除以4(1,2,1,2)。
更多信息请参见tensorflow API

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接