Keras LSTM online learning


I have a very predictable sequence. Here is a part of it:

deque([4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4])

Basically, it is a varying but still predictable number of 4s followed by an 8, and after every three 8s comes a 28.
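The generateNextNumber() function used later in the question is never shown. As a hypothetical stand-in, a generator that reproduces the pattern visible above (run lengths and separators read off the printed deque; the real generator may differ) could look like this:

from itertools import chain, cycle

# One full period of the sequence, read off the printed deque above:
# runs of 4s of lengths 6, 5, 6, 6, separated by 8, 8, 8, 28.
_period = list(chain.from_iterable(
    [4] * run + [sep] for run, sep in [(6, 8), (5, 8), (6, 8), (6, 28)]
))
_stream = cycle(_period)

def generateNextNumber():
    """Hypothetical generator: returns the next element of the periodic sequence."""
    return next(_stream)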

I want to build a very simple LSTM model for online prediction: every time a new number arrives, it is appended to the right of the deque. The LSTM is then trained on the old sequence formed by elements [0:seq_length] of the deque, with element [seq_length] as the training target. The window then shifts, and a prediction is made on elements [1:seq_length+1]. Finally, the leftmost element of the deque is discarded. My intuition says this should make the network memorize the sequence.
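To make the windowing concrete, here is a minimal sketch on a toy list (the names mirror the variables in the code below, with a shortened seq_length):

seq_length = 4
window = [4, 4, 8, 4, 4]                 # deque contents: seq_length + 1 elements

train_seq    = window[:seq_length]        # [4, 4, 8, 4]  -> training input
train_target = window[seq_length]         # 4             -> training target
pred_seq     = window[1:seq_length + 1]   # [4, 8, 4, 4]  -> input for predicting the next number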

However, my network only ever answers 4. After a (very long) while, it surprisingly starts missing almost everything and answers only 8. Then, after another (very long) while, it goes back to answering only 4.

My model structure is shown below. Naturally, I have experimented with different values of seq_length and lstm_cells, without success. Here are the values from the latest run:

from keras.models import Sequential
from keras.layers import LSTM, Dense, Activation

seq_length = 64  # Length of the sequence fed into the LSTM
vocab_size = 4   # Size of the model's final dense layer
lstm_cells = 16  # Size of the LSTM layer

model = Sequential()
model.add(LSTM(lstm_cells, input_shape=(seq_length, 1)))
model.add(Dense(vocab_size))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['categorical_accuracy'])

Below is how the data is prepared, trained on, and predicted from. The variable sequence is the deque shown at the beginning of this post. I maintain a list vocab=[4,8,28] that is appended to whenever a new number shows up during execution, so that vocab[i] converts class i into its corresponding sequence number. I then build a dictionary legend to do the reverse (a short demonstration of this mapping follows the loop below). This is, more or less, the continuous online loop:
import numpy as np
from keras.utils import to_categorical

while True:

    # Receives new number and puts it into the deque:
    sequence.append(generateNextNumber())

    # At this point, please note that the length of the deque is seq_length + 1.

    # Dictionary to convert numbers to classes:
    legend = dict([(v, k) for k, v in enumerate(vocab)])
    # Converts the deque into a list:
    seq_list = list(sequence)
    # Each iteration consists of 1 training step and 1 prediction. These are the training sequence and target:
    train_seq = [ [legend[i]] for i in seq_list[:seq_length] ]
    train_target = legend[ seq_list[seq_length] ]
    # And the prediction sequence just shifts the window by 1:
    pred_seq = [ [legend[i]] for i in seq_list[1:] ]

    # Batches data into a batch of size 1:
    x = np.zeros((1, seq_length, 1))
    y = np.zeros((1, vocab_size))
    x[0,:] = train_seq
    y[0,:] = to_categorical( train_target, num_classes=vocab_size )
    # Online training:
    model.fit(x=x, y=y, batch_size=1, epochs=1, verbose=0)

    # Now that one training step is done, make a prediction:
    x_pred = np.zeros((1, seq_length, 1))
    x[0,:] = pred_seq
    predicted_onehot = model.predict(x_pred)
    # Avoids "index out of range" erros when the LSTM vocab is still being built:
    predicted_index = min(np.argmax(predicted_onehot), len(vocab)-1)
    predicted_number = vocab[ predicted_index ]
    # Reverts deque length to seq_length:
    sequence.popleft()
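As a quick illustration of the vocab/legend mapping and the one-hot targets mentioned above (assuming vocab has already grown to its full three classes):

from keras.utils import to_categorical

vocab = [4, 8, 28]
legend = {v: k for k, v in enumerate(vocab)}      # {4: 0, 8: 1, 28: 2}

# A target of 28 becomes class 2, then a one-hot row of length vocab_size=4:
print(to_categorical(legend[28], num_classes=4))  # [0. 0. 1. 0.]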

Finally, here is a sample output:

HIT! Current hit rate: 34.753665869071725 (predicted: 4, sequence was: deque([4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4]))

HIT! Current hit rate: 34.75566735175926 (predicted: 4, sequence was: deque([4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4]))

Predicted 4 but it was 8

HIT! Current hit rate: 34.75660255820374 (predicted: 4, sequence was: deque([4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4]))

HIT! Current hit rate: 34.758603766640086 (predicted: 4, sequence was: deque([4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4]))

HIT! Current hit rate: 34.7606048523142 (predicted: 4, sequence was: deque([4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4]))

HIT! Current hit rate: 34.76260581523739 (predicted: 4, sequence was: deque([4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4]))

HIT! Current hit rate: 34.76460665542095 (predicted: 4, sequence was: deque([4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4]))

HIT! Current hit rate: 34.76660737287616 (predicted: 4, sequence was: deque([4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4]))

Predicted 4 but it was 28

HIT! Current hit rate: 34.767541707556425 (predicted: 4, sequence was: deque([4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4]))

What is going wrong?

Any help is greatly appreciated.


My guess is that the most efficient behavior for the LSTM is to always predict 4, since roughly 90% of the numbers are the same. Try having it predict a sine wave or something similar, and you should see it learn quickly without changing the architecture. - HitLuca
Thanks for the comment. However, I just tried the cyclic sequence 4, 8, 4, 8, 4, 8, 4, 28. At first the LSTM was somewhat split between predicting 4 and 8, but after a while it started predicting 4 all the time. - Denolth
1 Answer

The problem was in this line of my code:


x[0,:] = pred_seq

when it should have been:
x_pred[0,:] = pred_seq

Now everything works more or less correctly. I will leave the question up, since it provides some good insights into online learning with LSTMs.
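For completeness, with the fix applied, the prediction step in the question's loop becomes:

    # Corrected: fill x_pred (not x) before predicting.
    x_pred = np.zeros((1, seq_length, 1))
    x_pred[0, :] = pred_seq
    predicted_onehot = model.predict(x_pred)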

