使用正则化后,LSTM 仍存在过拟合问题

3
我正在处理一个时间序列预测问题,并构建了以下类似的LSTM模型:
def create_model():
    model = Sequential()
    model.add(LSTM(50,kernel_regularizer=l2(0.01), recurrent_regularizer=l2(0.01), bias_regularizer=l2(0.01), input_shape=(train_X.shape[1], train_X.shape[2])))
    model.add(Dropout(0.591))
    model.add(Dense(1))
    model.compile(loss='mean_squared_error', optimizer='adam')
    return model

当我使用以下5个拆分训练模型时:
tss = TimeSeriesSplit(n_splits = 5)
X = data.drop(labels=['target_prediction'], axis=1)
y = data['target_prediction'] 
for train_index, test_index in tss.split(X):
   train_X, test_X = X.iloc[train_index, :].values, X.iloc[test_index,:].values
   train_y, test_y = y.iloc[train_index].values, y.iloc[test_index].values
   model=create_model()
   history = model.fit(train_X, train_y, epochs=10, batch_size=64,validation_data=(test_X, test_y), verbose=0, shuffle=False)

我遇到了过拟合问题。附上损失图表 enter image description here

我不确定为什么在使用Keras模型中的正则化器时会出现过拟合。感谢任何帮助。

编辑: 尝试了这些结构

def create_model():
    model = Sequential()
    model.add(LSTM(20, input_shape=(train_X.shape[1], train_X.shape[2])))
    model.add(Dense(1))
    model.compile(loss='mean_squared_error', optimizer='adam')
    return model


def create_model(x,y):
    # define LSTM
    model = Sequential()
    model.add(Bidirectional(LSTM(20, return_sequences=True), input_shape=(x,y)))
    model.add(TimeDistributed(Dense(1, activation='sigmoid')))
    model.compile(loss='mean_squared_error', optimizer='adam')
    return model 

但仍然存在过拟合问题。


如果你使用多标签分类模型,我建议你使用0.5的dropout;而如果你的数据是二元的,我建议你使用0.2的dropout。 - Balive13
然而,辍学主要用于CNN而不是LSTM。因此最好只保留kernel_regularizer、recurrent_regularizer和bias_regularizer。另外,你可以改变学习率以防止过拟合,例如,在Adam算法中检查learning_rate = 2e-5。 - Balive13
1个回答

10

首先,移除所有的正则化和丢弃。您正在使用所有技巧并使用的0.5丢弃率过高。

减少LSTM中的单元数。从这里开始。达到一个使您的模型停止过度拟合的点。

然后,如果需要,请添加丢弃。

接下来的步骤是添加 tf.keras.Bidirectional。如果仍不满意,则增加层数。请记住对于每个LSTM层,都要保持return_sequences 为True,除了最后一层。

我很少见到使用层正则化的网络,尽管可以使用它,因为丢弃和层正则化有相同的效果,人们通常会选择丢弃(最多只看到使用0.3的情况)。


5
使用较少单元且没有使用dropout的简单架构仍会出现过拟合。 - Ricky

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接