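I'm building a seq2seq model with Keras. I've built a single-layer encoder and decoder, and they work fine. Now I want to extend them to a multi-layer encoder and decoder. I'm building everything with the Keras functional API.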
Training:
Encoder code:
encoder_input=Input(shape=(None,vec_dimension))
encoder_lstm=LSTM(vec_dimension,return_state=True,return_sequences=True)(encoder_input)
encoder_lstm=LSTM(vec_dimension,return_state=True)(encoder_lstm)
encoder_output,encoder_h,encoder_c=encoder_lstm
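For reference, here is a minimal sketch of a stacked encoder that keeps every layer's final (h, c) pair. It assumes TensorFlow 2.x with tf.keras. With return_state=True each LSTM call returns [sequences, h, c], so the result is unpacked and only the sequence tensor is fed to the next layer. The sizes vec_dimension=8 and num_layers=2 are hypothetical placeholders.

```python
import numpy as np
from tensorflow.keras.layers import Input, LSTM
from tensorflow.keras.models import Model

vec_dimension = 8  # hypothetical size (the post's error message suggests 150)
num_layers = 2     # hypothetical depth

encoder_input = Input(shape=(None, vec_dimension))
x = encoder_input
encoder_states = []  # flat list [h_1, c_1, h_2, c_2, ...], one (h, c) pair per layer
for _ in range(num_layers):
    # return_state=True makes the call return [sequences, h, c];
    # only the sequence tensor is passed on to the next layer.
    x, h, c = LSTM(vec_dimension, return_sequences=True, return_state=True)(x)
    encoder_states += [h, c]

encoder_model = Model(encoder_input, encoder_states)

# Usage: a batch of 3 sequences, 5 time steps each.
batch = np.random.rand(3, 5, vec_dimension).astype("float32")
states = encoder_model.predict(batch, verbose=0)
print([s.shape for s in states])
```

Keeping the states of every layer, not only the last one, gives each decoder layer a state it can be initialized from.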
Decoder code:
encoder_state=[encoder_h,encoder_c]
decoder_input=Input(shape=(None,vec_dimension))
decoder_lstm= LSTM(vec_dimension,return_state=True,return_sequences=True)(decoder_input,initial_state=encoder_state)
decoder_lstm=LSTM(vec_dimension,return_state=True,return_sequences=True)(decoder_lstm)
decoder_output,_,_=decoder_lstm
For testing:
encoder_model=Model(inputs=encoder_input,outputs=encoder_state)
decoder_state_input_h=Input(shape=(vec_dimension,))  # LSTM states are (batch, units), with no time axis
decoder_state_input_c=Input(shape=(vec_dimension,))
decoder_states_input=[decoder_state_input_h,decoder_state_input_c]
decoder_output,decoder_state_h,decoder_state_c =decoder_lstm #(decoder_input,initial_state=decoder_states_input)
decoder_states=[decoder_state_h,decoder_state_c]
decoder_model=Model(inputs=[decoder_input]+decoder_states_input,outputs=[decoder_output]+decoder_states)
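A functional Model can only be built if every output is computable from the inputs it lists. The tensors held in decoder_lstm above were produced with initial_state=encoder_state, so they depend on encoder_input, which the decoder model does not list. One common way around this, sketched here under the assumption of TensorFlow 2.x tf.keras, is to create each LSTM layer once, keep the layer objects, and call them twice: once in the training graph with the encoder's states, and once in the inference graph with fresh state Inputs of shape (vec_dimension,), one h and one c per layer. Initializing decoder layer i from encoder layer i is one possible design, not the only one. The helper run_stack and the sizes are hypothetical.

```python
import numpy as np
from tensorflow.keras.layers import Input, LSTM
from tensorflow.keras.models import Model

vec_dimension = 8  # hypothetical size
num_layers = 2     # hypothetical number of encoder and decoder layers

# Create every layer object once and keep it, so that the training graph
# and the inference graph share the same weights.
encoder_layers = [LSTM(vec_dimension, return_sequences=True, return_state=True)
                  for _ in range(num_layers)]
decoder_layers = [LSTM(vec_dimension, return_sequences=True, return_state=True)
                  for _ in range(num_layers)]


def run_stack(layers, x, initial_states=None):
    """Hypothetical helper: chain LSTM layers. Layer i starts from
    initial_states[2*i:2*i+2] when given. Returns the last layer's sequence
    output and the flat state list [h_1, c_1, h_2, c_2, ...]."""
    states = []
    for i, layer in enumerate(layers):
        init = None if initial_states is None else initial_states[2 * i:2 * i + 2]
        x, h, c = layer(x, initial_state=init)
        states += [h, c]
    return x, states


# Training graph: decoder layer i is initialized from encoder layer i.
encoder_input = Input(shape=(None, vec_dimension))
decoder_input = Input(shape=(None, vec_dimension))
_, encoder_states = run_stack(encoder_layers, encoder_input)
decoder_output, _ = run_stack(decoder_layers, decoder_input, encoder_states)
training_model = Model([encoder_input, decoder_input], decoder_output)
training_model.compile(optimizer="adam", loss="mse")

# Inference graphs: the encoder model returns all states, and the decoder
# reuses the same layer objects on fresh state Inputs of shape (vec_dimension,).
encoder_model = Model(encoder_input, encoder_states)
decoder_states_input = [Input(shape=(vec_dimension,)) for _ in range(2 * num_layers)]
decoder_output_inf, decoder_states = run_stack(decoder_layers, decoder_input,
                                               decoder_states_input)
decoder_model = Model([decoder_input] + decoder_states_input,
                      [decoder_output_inf] + decoder_states)

# Usage: encode a batch, then run one decoder step.
source = np.random.rand(2, 6, vec_dimension).astype("float32")
step = np.random.rand(2, 1, vec_dimension).astype("float32")
states = list(encoder_model.predict(source, verbose=0))
outputs = decoder_model.predict([step] + states, verbose=0)
```

Because the inference outputs are computed only from decoder_input and the new state Inputs, the decoder model has no hidden dependency on encoder_input. At each decoding step, the h and c arrays returned by decoder_model.predict are fed back as the state inputs for the next step.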
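Now, when I increase the number of decoder layers, training runs fine, but testing fails with an error.
The actual problem is that, in making the decoder multi-layer, I moved initial_state from the last layer (where it used to be) to an intermediate layer, so calling the decoder at test time throws this error: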
RuntimeError: Graph disconnected: cannot obtain value for tensor Tensor("input_64:0", shape=(?, ?, 150), dtype=float32) at layer "input_64".The following previous layers were accessed without issue: []
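How should I pass initial_state=decoder_states_input to the input layer so that this error doesn't occur? In other words, how do I pass initial_state=decoder_states_input, which used to be given to the last layer, to the first layer instead? In this code I tried to build a multi-layer decoder LSTM, and that is what causes the error. With a single layer, the working code is:
Encoder (training):
encoder_input=Input(shape=(None,vec_dimension))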
encoder_lstm =LSTM(vec_dimension,return_state=True)(encoder_input)
encoder_output,encoder_h,encoder_c=encoder_lstm
Decoder (training):
encoder_state=[encoder_h,encoder_c]
decoder_input=Input(shape=(None,vec_dimension))
decoder_lstm= LSTM(vec_dimension, return_state=True, return_sequences=True)
decoder_output,_,_=decoder_lstm(decoder_input,initial_state=encoder_state)
Decoder (testing):
decoder_output,decoder_state_h,decoder_state_c=decoder_lstm(decoder_input,initial_state=decoder_states_input)
decoder_states=[decoder_state_h,decoder_state_c]
decoder_model=Model(inputs=[decoder_input]+decoder_states_input,outputs=[decoder_output]+decoder_states)
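To use the single-layer models at test time, the decoder is typically run one step at a time, carrying (h, c) forward between calls. The sketch below assumes tf.keras, a hypothetical vec_dimension=8, a hypothetical helper decode_sequence, and that each output vector is fed back as the next input for a fixed number of steps. The post does not say how decoding starts or stops, so the zero start vector and max_steps are assumptions too.

```python
import numpy as np
from tensorflow.keras.layers import Input, LSTM
from tensorflow.keras.models import Model

vec_dimension = 8  # hypothetical size

# Single-layer models built the same way as the working code in the post
# (training graph omitted here, so the weights are untrained).
encoder_input = Input(shape=(None, vec_dimension))
encoder_output, encoder_h, encoder_c = LSTM(vec_dimension, return_state=True)(encoder_input)
encoder_model = Model(encoder_input, [encoder_h, encoder_c])

decoder_input = Input(shape=(None, vec_dimension))
decoder_lstm = LSTM(vec_dimension, return_state=True, return_sequences=True)
decoder_states_input = [Input(shape=(vec_dimension,)), Input(shape=(vec_dimension,))]
decoder_output, decoder_state_h, decoder_state_c = decoder_lstm(
    decoder_input, initial_state=decoder_states_input)
decoder_model = Model([decoder_input] + decoder_states_input,
                      [decoder_output, decoder_state_h, decoder_state_c])


def decode_sequence(source, start_vector, max_steps):
    """Hypothetical helper: run the decoder one step at a time, feeding each
    output vector back as the next input and carrying (h, c) forward."""
    states = list(encoder_model.predict(source, verbose=0))
    step_input = start_vector.reshape(1, 1, vec_dimension)
    outputs = []
    for _ in range(max_steps):
        out, h, c = decoder_model.predict([step_input] + states, verbose=0)
        outputs.append(out[0, 0])
        step_input = out  # shape (1, 1, vec_dimension), reused as the next input
        states = [h, c]
    return np.stack(outputs)


source = np.random.rand(1, 5, vec_dimension).astype("float32")
start = np.zeros(vec_dimension, dtype="float32")
decoded = decode_sequence(source, start, max_steps=4)
```

For a stacked decoder the loop is the same, except that states holds one (h, c) pair per layer.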