Ensemble of CNN and RNN models in Keras

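I am trying to implement in Keras the model from the paper "Ensemble Application of Convolutional and Recurrent Neural Networks for Multi-label Text Categorization". The model looks like this (taken from the paper): [figure from the paper, not shown]. This is the code I have so far: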
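from keras.layers import (Input, Embedding, Conv1D, MaxPooling1D, Concatenate,
                          Flatten, Dense, Reshape, LSTM)
from keras.models import Model

document_input = Input(shape=(None,), dtype='int32')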
embedding_layer = Embedding(vocab_size, WORD_EMB_SIZE, weights=[initial_embeddings], 
                                input_length=DOC_SEQ_LEN, trainable=True)
convs = []
filter_sizes = [2,3,4,5]

doc_embedding = embedding_layer(document_input)
for filter_size in filter_sizes:
    l_conv = Conv1D(filters=256, kernel_size=filter_size, padding='same', activation='relu')(doc_embedding)
    l_pool = MaxPooling1D(filter_size)(l_conv)
    convs.append(l_pool)

l_merge = Concatenate(axis=1)(convs)
l_flat = Flatten()(l_merge)
l_dense = Dense(100, activation='relu')(l_flat)
l_dense_3d = Reshape((1,int(l_dense.shape[1])))(l_dense)

gene_variation_input = Input(shape=(None,), dtype='int32')
gene_variation_embedding = embedding_layer(gene_variation_input)
rnn_layer = LSTM(100, return_sequences=False, stateful=True)(gene_variation_embedding,initial_state=[l_dense_3d])

l_flat = Flatten()(rnn_layer)
output_layer = Dense(9, activation='softmax')(l_flat)
model = Model(inputs=[document_input,gene_variation_input], outputs=[output_layer])

I am not sure whether I have wired up the text feature vector the way the figure above shows. When I try it, I get this error:
ValueError: Layer lstm_9 expects 3 inputs, but it received 2 input tensors. Input received: [<tf.Tensor 'embedding_10_1/Gather:0' shape=(?, ?, 200) dtype=float32>, <tf.Tensor 'reshape_9/Reshape:0' shape=(?, 1, 100) dtype=float32>]

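I followed the section of the Keras documentation on RNN initial states, as well as the code (link) (link). Hopefully that helps.
Update: based on the suggestions and further reading of the code, the model now looks like this.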
embedding_layer = Embedding(vocab_size, WORD_EMB_SIZE, weights=[initial_embeddings], trainable=True)

document_input = Input(shape=(DOC_SEQ_LEN,), batch_shape=(BATCH_SIZE, DOC_SEQ_LEN),dtype='int32')
doc_embedding = embedding_layer(document_input)

convs = []
filter_sizes = [2,3,4,5]

for filter_size in filter_sizes:
    l_conv = Conv1D(filters=256, kernel_size=filter_size, padding='same', activation='relu')(doc_embedding)
    l_pool = MaxPooling1D(filter_size)(l_conv)
    convs.append(l_pool)

l_merge = Concatenate(axis=1)(convs)
l_flat = Flatten()(l_merge)
l_dense = Dense(100, activation='relu')(l_flat)

gene_variation_input = Input(shape=(GENE_VARIATION_SEQ_LEN,), batch_shape=(BATCH_SIZE, GENE_VARIATION_SEQ_LEN),dtype='int32')
gene_variation_embedding = embedding_layer(gene_variation_input)

rnn_layer = LSTM(100, return_sequences=False, 
                 batch_input_shape=(BATCH_SIZE, GENE_VARIATION_SEQ_LEN, WORD_EMB_SIZE),
                 stateful=False)(gene_variation_embedding, initial_state=[l_dense, l_dense])

output_layer = Dense(9, activation='softmax')(rnn_layer)

model = Model(inputs=[document_input,gene_variation_input], outputs=[output_layer])

Model summary

____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
====================================================================================================
input_8 (InputLayer)             (32, 9)               0                                            
____________________________________________________________________________________________________
input_7 (InputLayer)             (32, 4000)            0                                            
____________________________________________________________________________________________________
embedding_6 (Embedding)          multiple              73764400    input_7[0][0]                    
                                                                   input_8[0][0]                    
____________________________________________________________________________________________________
conv1d_13 (Conv1D)               (32, 4000, 256)       102656      embedding_6[0][0]                
____________________________________________________________________________________________________
conv1d_14 (Conv1D)               (32, 4000, 256)       153856      embedding_6[0][0]                
____________________________________________________________________________________________________
conv1d_15 (Conv1D)               (32, 4000, 256)       205056      embedding_6[0][0]                
____________________________________________________________________________________________________
conv1d_16 (Conv1D)               (32, 4000, 256)       256256      embedding_6[0][0]                
____________________________________________________________________________________________________
max_pooling1d_13 (MaxPooling1D)  (32, 2000, 256)       0           conv1d_13[0][0]                  
____________________________________________________________________________________________________
max_pooling1d_14 (MaxPooling1D)  (32, 1333, 256)       0           conv1d_14[0][0]                  
____________________________________________________________________________________________________
max_pooling1d_15 (MaxPooling1D)  (32, 1000, 256)       0           conv1d_15[0][0]                  
____________________________________________________________________________________________________
max_pooling1d_16 (MaxPooling1D)  (32, 800, 256)        0           conv1d_16[0][0]                  
____________________________________________________________________________________________________
concatenate_4 (Concatenate)      (32, 5133, 256)       0           max_pooling1d_13[0][0]           
                                                                   max_pooling1d_14[0][0]           
                                                                   max_pooling1d_15[0][0]           
                                                                   max_pooling1d_16[0][0]           
____________________________________________________________________________________________________
flatten_4 (Flatten)              (32, 1314048)         0           concatenate_4[0][0]              
____________________________________________________________________________________________________
dense_6 (Dense)                  (32, 100)             131404900   flatten_4[0][0]                  
____________________________________________________________________________________________________
lstm_4 (LSTM)                    (32, 100)             120400      embedding_6[1][0]                
                                                                   dense_6[0][0]                    
                                                                   dense_6[0][0]                    
____________________________________________________________________________________________________
dense_7 (Dense)                  (32, 9)               909         lstm_4[0][0]                     
====================================================================================================
Total params: 206,008,433
Trainable params: 206,008,433
Non-trainable params: 0
____________________________________________________________________________________________________
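The parameter counts in the summary can be cross-checked by hand. The sketch below is only an arithmetic check, not part of the original post: the helper names (`conv1d_params`, `dense_params`, `lstm_params`) are made up for illustration, the embedding width of 200 is inferred from the `(?, ?, 200)` shape in the error message, and the embedding parameter count is copied from the summary rather than derived.

```python
# Values read off the model summary above (assumptions, not taken from the code).
EMB = 200                  # embedding width, inferred from the error message shape
SEQ_LEN = 4000             # document sequence length (input_7)
FILTERS = 256
FILTER_SIZES = [2, 3, 4, 5]
UNITS = 100                # Dense(100) and LSTM(100)
CLASSES = 9
EMBEDDING_PARAMS = 73_764_400  # copied from the summary (vocab_size * EMB)


def conv1d_params(kernel_size, in_channels, filters):
    # One kernel of shape (kernel_size, in_channels) per filter, plus one bias each.
    return kernel_size * in_channels * filters + filters


def dense_params(n_in, n_out):
    # Weight matrix plus bias vector.
    return n_in * n_out + n_out


def lstm_params(n_in, units):
    # Four gates, each with an input kernel, a recurrent kernel and a bias.
    return 4 * ((n_in + units) * units + units)


conv_counts = [conv1d_params(k, EMB, FILTERS) for k in FILTER_SIZES]
# MaxPooling1D(k) uses stride k and no padding by default, so the length is floored.
pooled_lengths = [SEQ_LEN // k for k in FILTER_SIZES]
flat_size = sum(pooled_lengths) * FILTERS      # Concatenate(axis=1) then Flatten
dense_count = dense_params(flat_size, UNITS)
lstm_count = lstm_params(EMB, UNITS)
output_count = dense_params(UNITS, CLASSES)

total = EMBEDDING_PARAMS + sum(conv_counts) + dense_count + lstm_count + output_count
print(conv_counts, pooled_lengths, flat_size, dense_count, lstm_count, output_count, total)
```

Under these assumptions the numbers agree with the summary, and the `Dense(100)` layer that follows `Flatten` alone accounts for roughly 64% of all parameters.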

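Shouldn't initial_states be in the LSTM call? - Marcin Możejko
Based on some issues on GitHub and on the code, it has to be set in the arguments passed to the call. I am also trying recurrentShop. - bicepjai
Yes, but you passed it to the Embedding, not to the LSTM. - Marcin Możejko
You're right. That was a typo. I'll fix it. - bicepjai
1 Answer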


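An LSTM has two states (the hidden state h and the cell state c), but you are providing only one initial state. You can do either of the following:

Replace the LSTM with an RNN that has only one hidden state, such as GRU: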

rnn_layer = GRU(100, return_sequences=False, stateful=True)(gene_variation_embedding, initial_state=[l_dense_3d])

Or set the initial value of the LSTM's second state to zeros:

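from keras import backend as K
from keras.layers import Lambda

# Build an all-zeros tensor with the same shape as the supplied initial state.
zeros = Lambda(lambda x: K.zeros_like(x), output_shape=lambda s: s)(l_dense_3d)
rnn_layer = LSTM(100, return_sequences=False, stateful=True)(gene_variation_embedding, initial_state=[l_dense_3d, zeros])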
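To make the "two states" point concrete, here is a minimal NumPy sketch of an LSTM forward pass that takes both initial states explicitly. It is an illustration under assumptions, not Keras's implementation: the function names (`lstm_step`, `run_lstm`), the random weights, and the sizes are all made up, and `cnn_features` is only a stand-in for the CNN branch's `Dense(100)` output. Each state is a 2D array of shape `(batch, units)`, which is also the shape of the 2D `l_dense` tensor that the updated question passes.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def lstm_step(x_t, h_prev, c_prev, W, U, b):
    """One LSTM time step; W: (n_in, 4*units), U: (units, 4*units), b: (4*units,)."""
    z = x_t @ W + h_prev @ U + b
    i, f, g, o = np.split(z, 4, axis=-1)      # input, forget, candidate, output
    c = sigmoid(f) * c_prev + sigmoid(i) * np.tanh(g)   # cell state update
    h = sigmoid(o) * np.tanh(c)                          # hidden state update
    return h, c


def run_lstm(x, h0, c0, W, U, b):
    """Run over all time steps of x (batch, steps, n_in); return the last hidden state."""
    h, c = h0, c0
    for t in range(x.shape[1]):
        h, c = lstm_step(x[:, t, :], h, c, W, U, b)
    return h


rng = np.random.default_rng(0)
batch, steps, n_in, units = 4, 9, 200, 100    # illustrative sizes only
x = rng.normal(size=(batch, steps, n_in))     # stand-in for the embedded sequence
W = rng.normal(scale=0.1, size=(n_in, 4 * units))
U = rng.normal(scale=0.1, size=(units, 4 * units))
b = np.zeros(4 * units)

cnn_features = rng.normal(size=(batch, units))  # stand-in for the CNN feature vector
h0 = cnn_features            # hidden state initialised from the CNN features
c0 = np.zeros_like(h0)       # cell state initialised to zeros (second option above)

out = run_lstm(x, h0, c0, W, U, b)
print(out.shape)
```

A GRU step would carry only `h`, which is why the first option needs a single initial state.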

My understanding is that the initial states are h_0 and c_0. After reading https://philipperemy.github.io/keras-stateful-lstm/, Keras's definition of stateful became clear to me. But I only want to set h_0 and c_0, and stateful=False seems to support that as well. - bicepjai
stateful is for when you want the network to remember its state across batches; that is not the same thing. - farizrahman4u
@farizrahman4u When I use hidden_states = K.variable(value=np.zeros((1, 10))) and cell_states = K.variable(value=np.zeros((1, 10))), running lstm = LSTM(10)(input, initial_state=[hidden_states, cell_states]) raises TypeError: 'list' object is not callable. - user2614596
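The distinction drawn in the comments, between supplying an initial state and carrying state across batches, can be shown with a toy recurrent loop. This is a hedged NumPy sketch and not Keras code: `simple_rnn`, the weights, and the sizes are hypothetical, and the two loops only mimic the two behaviours.

```python
import numpy as np


def simple_rnn(x, h0, W, U):
    """Toy tanh RNN over x (batch, steps, n_in); returns the final hidden state."""
    h = h0
    for t in range(x.shape[1]):
        h = np.tanh(x[:, t, :] @ W + h @ U)
    return h


rng = np.random.default_rng(1)
batch, steps, n_in, units = 2, 3, 4, 5        # illustrative sizes only
W = rng.normal(size=(n_in, units))
U = rng.normal(size=(units, units))
batches = [rng.normal(size=(batch, steps, n_in)) for _ in range(2)]
h_init = rng.normal(size=(batch, units))      # stand-in for a supplied initial state

# initial_state-style: every batch starts from the state that was passed in.
fresh = [simple_rnn(xb, h_init, W, U) for xb in batches]

# stateful-style: the final state of one batch is the starting state of the next.
carried = []
h = np.zeros((batch, units))
for xb in batches:
    h = simple_rnn(xb, h, W, U)
    carried.append(h)

print(np.allclose(fresh[1], carried[1]))
```

In the first loop the second batch knows nothing about the first; in the second loop it starts where the first one ended, which is the behaviour `stateful=True` is meant for.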
