Keras LSTM multi-class classification


I have code for binary classification that I have tested on the Keras IMDB dataset:

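    from keras.models import Sequential
    from keras.layers import Embedding, LSTM, Dense

    # X_train, y_train, X_test, y_test: padded IMDB sequences and 0/1 labels (loaded elsewhere)
    model = Sequential()
    model.add(Embedding(5000, 32, input_length=500))
    model.add(LSTM(100, dropout=0.2, recurrent_dropout=0.2))
    model.add(Dense(1, activation='sigmoid'))
    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
    model.summary()  # summary() prints the table itself and returns None
    model.fit(X_train, y_train, epochs=3, batch_size=64)
    # Final evaluation of the model
    scores = model.evaluate(X_test, y_test, verbose=0)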

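I need to convert this code to multi-class classification with 7 classes in total. From reading a few articles, I understand that to do so I have to change: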

    model.add(Dense(7, activation='softmax'))
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

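Clearly, changing only those two lines does not work. What else do I need to change to make the code work for multi-class classification? I also think I need to convert the class labels to a one-hot encoding, but I don't know how to do that in Keras.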

1 Answer


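Yes, the targets have to match the loss. Either one-hot encode the target variable with the to_categorical function (for example keras.utils.to_categorical(y_train, num_classes=7)) and keep categorical_crossentropy, or leave the labels as integers 0 to 6 and use sparse_categorical_crossentropy instead: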

    model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
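The two routes compute the same loss: to_categorical turns integer labels into one-hot rows for categorical_crossentropy, while sparse_categorical_crossentropy indexes the predicted probabilities with the integer label directly. Below is a minimal pure-Python sketch of that equivalence for a single prediction. The labels, the probabilities and the helper names (one_hot, categorical_crossentropy, sparse_categorical_crossentropy) are hypothetical and written for illustration only; they are not the Keras implementations, which also average the loss over the batch.

```python
import math

NUM_CLASSES = 7


def one_hot(labels, num_classes):
    """Illustrative stand-in for to_categorical: integer labels -> one-hot rows."""
    rows = []
    for y in labels:
        row = [0.0] * num_classes
        row[y] = 1.0
        rows.append(row)
    return rows


def categorical_crossentropy(one_hot_row, probs):
    # -sum_k t_k * log(p_k); only the true-class term is non-zero
    return -sum(t * math.log(p) for t, p in zip(one_hot_row, probs) if t > 0)


def sparse_categorical_crossentropy(label, probs):
    # Same quantity, read off directly with the integer label
    return -math.log(probs[label])


y_int = [3, 0, 6]                        # hypothetical integer labels in 0..6
y_onehot = one_hot(y_int, NUM_CLASSES)
print(y_onehot[0])                       # → [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]

# Hypothetical softmax output for one sample (sums to 1)
probs = [0.05, 0.10, 0.05, 0.60, 0.10, 0.05, 0.05]
for y, row in zip(y_int, y_onehot):
    print(categorical_crossentropy(row, probs), sparse_categorical_crossentropy(y, probs))
```

Each printed pair is identical, which is why switching the loss to the sparse variant removes the need to one-hot encode y_train.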

Here is the complete code:

    from keras.models import Sequential
    from keras.layers import Embedding, LSTM, Dense

    model = Sequential()
    model.add(Embedding(5000, 32, input_length=500))
    model.add(LSTM(100, dropout=0.2, recurrent_dropout=0.2))
    # 7 output units with softmax: one probability per class
    model.add(Dense(7, activation='softmax'))
    # Sparse loss: y_train holds integer labels 0..6, no one-hot encoding needed
    model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

    model.summary()
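The output layer change matters because Dense(7, activation='softmax') produces one probability per class and the seven values sum to 1, whereas the binary model's single sigmoid unit produces one independent probability. A small pure-Python sketch of what softmax computes, assuming made-up logits for one sample (illustrative numbers, not from a trained model):

```python
import math


def softmax(logits):
    """Numerically stable softmax: shift by the max logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


# Hypothetical raw scores for ONE sample from a 7-unit output layer
logits = [2.0, 0.5, -1.0, 0.0, 1.0, -0.5, 0.3]

probs = softmax(logits)
print([round(p, 3) for p in probs])
print(round(sum(probs), 6))     # → 1.0 (one distribution over the 7 classes)
print(probs.index(max(probs)))  # → 0 (predicted class = argmax)

# A sigmoid on each unit scores every class independently, so the values do not sum to 1
print(round(sum(sigmoid(z) for z in logits), 3))
```

With the Keras model, the predicted class would likewise be the argmax of each row returned by model.predict.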

Summary output:

    Using TensorFlow backend.
    _________________________________________________________________
    Layer (type)                 Output Shape              Param #   
    =================================================================
    embedding_1 (Embedding)      (None, 500, 32)           160000    
    _________________________________________________________________
    lstm_1 (LSTM)                (None, 100)               53200     
    _________________________________________________________________
    dense_1 (Dense)              (None, 7)                 707       
    =================================================================
    Total params: 213,907
    Trainable params: 213,907
    Non-trainable params: 0
    _________________________________________________________________
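The parameter counts in the summary can be checked by hand. The sketch below uses the standard formulas for these three layer types (the helper names are hypothetical, for illustration only); only the last layer differs from the binary model, where Dense(1) would have 101 parameters.

```python
def embedding_params(input_dim, output_dim):
    # One output_dim-sized vector per vocabulary entry
    return input_dim * output_dim


def lstm_params(input_dim, units):
    # 4 gates (input, forget, cell, output), each with input weights (input_dim x units),
    # recurrent weights (units x units) and a bias vector (units)
    return 4 * (units * (input_dim + units) + units)


def dense_params(input_dim, units):
    # Weight matrix plus one bias per unit
    return input_dim * units + units


emb = embedding_params(5000, 32)   # → 160000
lstm = lstm_params(32, 100)        # → 53200
dense = dense_params(100, 7)       # → 707
print(emb, lstm, dense, emb + lstm + dense)  # → 160000 53200 707 213907
```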

Content provided by Stack Overflow (translated from the original English post).