用Keras实现LSTM

3

我有一些训练数据x_train,以及与此x_train相对应的标签y_train。下面是x_trainy_train的构建方式:

train_x = np.array([np.random.rand(1, 1000)[0] for i in range(10000)])
train_y = (np.random.randint(1,150,10000))

train_x有一万行每行均有1000列。 train_y为每个train_x样本标签赋值1到150之间的代码。

我还有一个叫做sample的样本,它只有1行1000列,我想要用它在这个LSTM模型中进行预测。该变量定义为:

sample = np.random.rand(1,1000)[0]

我正在尝试使用Keras训练和预测这些数据上的LSTM。我希望输入这个特征向量,并使用这个LSTM来预测1到150范围内的代码之一。我知道这些是随机数组,但我不能发布我拥有的数据。我已经尝试了以下方法,我相信应该可以工作,但是遇到了一些问题。

    model = Sequential()
    model.add(LSTM(output_dim = 32, input_length = 10000, input_dim = 1000,return_sequences=True))
    model.add(Dense(150, activation='relu'))
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy']) 
    history = model.fit(train_x, train_y,
              batch_size=128, nb_epoch=1,
              verbose = 1)
    model.predict(sample)

任何对这个管道的帮助或调整都将是很好的。我不确定output_dim是否正确。我想在1000维数据的每个样本上训练LSTM,然后生成1到150范围内的特定代码。谢谢。
1个回答

2

我看到至少有三件事情需要改变:

  1. Change this line:

    model.add(Dense(150, activation='relu'))
    

    to:

    model.add(Dense(150, activation='softmax'))
    

    as leaving 'relu' as activation makes your output unbounded whereas it needs to have a probabilistic interpretation (as you use categorical_crossentropy).

  2. Change loss or target:

    As you are using categorical_crossentropy you need to change your target to be a one-hot encoded vector of length 150. Another way is to leave your target but to change loss to sparse_categorical_crossentropy.

  3. Change your target range:

    Keras has a 0-based array indexing (as in Python, C and C++ so your values should be in range [0, 150) instead [1, 150].


当我将其更改为长度为150的独热编码向量时,预测结果中每个位置都会出现一堆小数。这些小数代表什么?它们是表示此类别为1、2、3、...、150的概率吗? - Mike El Jackson
是的 - 但请记住这个基于0的数组索引。类1的概率位于索引0处,类2的概率位于索引1处,以此类推。 - Marcin Możejko
好的,非常感谢!您能用一个k热编码向量来完成这个任务吗?比如说,对于样本1,我们有[5, 8, 9],对于样本2,我们有[130, 11, 12, 5, 9],其中标签数量不同。 - Mike El Jackson
是的 - 但这更棘手。这是另一个问题的情况。 - Marcin Możejko
好的,谢谢您,我接受了您的答案。我可能会为那个问题发一篇单独的提问。 - Mike El Jackson
显示剩余4条评论

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接