你说得对,与常规的dropout不同,MC Dropout在推理期间也会应用。如果你搜索一下,就可以很容易地找到相关信息。
关于通道化dropout,我的理解是它不是舍弃特定神经元,而是整个通道都被舍弃。
现在在Keras中(我将使用tf.keras
),我们需要实现它。
MC Dropout
像平常一样,在Keras中定义一个自定义层,该层无论是训练还是测试都应用dropout,因此我们可以使用带有固定dropout率的tf.nn.dropout()
:
import tensorflow as tf
class MCDropout(tf.keras.layers.Layer):
def __init__(self, rate):
super(MCDropout, self).__init__()
self.rate = rate
def call(self, inputs):
return tf.nn.dropout(inputs, rate=self.rate)
使用示例:
import tensorflow as tf
import numpy as np
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Conv2D(filters=6, kernel_size=3))
model.add(MCDropout(rate=0.5))
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(2))
model.compile(optimizer=tf.keras.optimizers.SGD(0.001),
loss='binary_crossentropy',
metrics=['accuracy'])
x_train = np.random.normal(size=(10, 4, 4, 3))
x_train = np.vstack([x_train, 2*np.random.normal(size=(10, 4, 4, 3))])
y_train = [[1, 0] for _ in range(10)] + [[0, 1] for _ in range(10)]
y_train = np.array(y_train)
model.fit(x_train,
y_train,
epochs=2,
batch_size=10,
validation_data=(x_train, y_train))
通道丢弃
在这里,您可以使用相同的tf.nn.dropout()
函数,但是您必须指定噪声形状。 tf.nn.dropout()的文档提供了如何实现删除通道的示例:
shape(x) = [k, l, m, n] 和noise_shape = [k, 1, 1, n], 每个批次和通道组件将被独立保留,并且每行和每列将一起被保留或不保留。
这就是我们将在call()
方法中执行的操作:
class ChannelWiseDropout(tf.keras.layers.Layer):
def __init__(self, rate):
super(ChannelWiseDropout, self).__init__()
self.rate = rate
def call(self, inputs):
shape = tf.keras.backend.shape(inputs)
noise_shape = (shape[0], 1, 1, shape[-1])
return tf.nn.dropout(inputs,
rate=self.rate,
noise_shape=noise_shape)
将其应用于一些示例:
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.InputLayer(input_shape=(4, 4, 3)))
model.add(tf.keras.layers.Conv2D(filters=3, kernel_size=3))
model.add(ChannelWiseDropout(rate=0.5))
x_train = np.random.normal(size=(1, 4, 4, 3))
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
res = sess.run(model.output, feed_dict={model.inputs[0]:x_train})
print(res[:, :, :, 0])
print(res[:, :, :, 1])
print(res[:, :, :, 2])
注意
我使用的是tf.__version__ == '1.13.1'
。旧版本的tf
使用keep_prob = 1 - rate
而不是rate
参数。