在Keras中定义自定义LSTM单元?

7

我使用TensorFlow作为Keras的后端。如果我想修改LSTM单元,例如"删除"输出门,我该怎么做?输出门是一个乘法门,所以我需要将其设置为固定值,以便无论乘以它的是什么,都不会对其产生影响。

1个回答

11

首先,您应该定义自己的自定义层。如果您需要一些直觉来实现自己的单元,请参见Keras存储库中的LSTMCell。例如,您的自定义单元将是:

class MinimalRNNCell(keras.layers.Layer):

    def __init__(self, units, **kwargs):
        self.units = units
        self.state_size = units
        super(MinimalRNNCell, self).__init__(**kwargs)

    def build(self, input_shape):
        self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
                                      initializer='uniform',
                                      name='kernel')
        self.recurrent_kernel = self.add_weight(
            shape=(self.units, self.units),
            initializer='uniform',
            name='recurrent_kernel')
        self.built = True

    def call(self, inputs, states):
        prev_output = states[0]
        h = K.dot(inputs, self.kernel)
        output = h + K.dot(prev_output, self.recurrent_kernel)
        return output, [output]

接着使用tf.keras.layers.RNN来使用您的单元格:

cell = MinimalRNNCell(32)
x = keras.Input((None, 5))
layer = RNN(cell)
y = layer(x)

# Here's how to use the cell to build a stacked RNN:

cells = [MinimalRNNCell(32), MinimalRNNCell(64)]
x = keras.Input((None, 5))
layer = RNN(cells)
y = layer(x)

谢谢。这种手动方法与使用“stateful”和其他参数管理LSTM状态有何不同? - user8038245
最好能够显示至少方法参数:build(self, batch_input_shape)和call(self, inputs, states),以及可选的trainingmask,以及这些方法返回的内容:对于build()方法没有返回值,而对于call()方法则返回inputs和states。 - MiniQuark
1
嗨,我正在尝试复制多个输入的代码。所以我改变了y=layer([input_1,input_2]),并且还改变了input_shape的形状,但它会抛出错误,如https://stackoverflow.com/questions/58408106/define-custom-lstm-with-multiple-inputs中所述。如何克服这个错误...有什么想法吗? - Saikat

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接