如何在Keras中使用分类one-hot标签进行训练?

8

我有类似以下的输入:

[
[1, 2, 3]
[4, 5, 6]
[7, 8, 9]
...]

数据的形状为(1, num_samples, num_features),标签的格式如下:

[
[0, 1]
[1, 0]
[1, 0]
...]

数组的形状为(1, num_samples, 2)

然而,当我尝试运行以下Keras代码时,出现了以下错误:ValueError: Error when checking model target: expected dense_1 to have 2 dimensions, but got array with shape (1, 8038, 2)。从我所读的内容来看,这似乎源于我的标签是二维的,而不仅仅是整数。如果是这样,我该如何在Keras中使用one-hot标签?

以下是代码:

num_features = 463
trX = np.random(8038, num_features)
trY = # one-hot array of shape (8038, 2) as described above

def keras_builder():  #generator to build the inputs
    while(1):
        x = np.reshape(trX, (1,) + np.shape(trX))
        y = np.reshape(trY, (1,) + np.shape(trY))
        print(np.shape(x)) # (1, 8038, 463)
        print(np.shape(y)) # (1, 8038, 2)
        yield x, y

model = Sequential()
model.add(LSTM(100, input_dim = num_features))
model.add(Dense(1, activation='sigmoid'))
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit_generator(keras_builder(), samples_per_epoch = 1, nb_epoch=3, verbose = 2, nb_worker = 1)

这将立即抛出上述错误:

Traceback (most recent call last):
  File "file.py", line 35, in <module>
    model.fit_generator(keras_builder(), samples_per_epoch = 1, nb_epoch=3, verbose = 2, nb_worker = 1)
  ...
ValueError: Error when checking model target: expected dense_1 to have 2 dimensions, but got array with shape (1, 8038, 2)

谢谢!

1个回答

6
有很多事情是不符合的。
我假设您正在尝试解决一个连续分类任务,即您的数据的形状为(<batch size>, <sequence length>, <feature length>)
在批处理生成器中,您创建了一个批次,其中包含一个长度为8038且每个序列元素具有463个特征的序列。您创建了一个匹配的Y批次进行比较,其中包含一个具有8038个元素的序列,每个元素大小为2。
您的问题是Y与最后一层的输出不匹配。您的Y是三维的,而模型的输出仅为二维:Y.shape = (1, 8038, 2)dense_1.shape = (1,1)不匹配。这就解释了您收到的错误消息。
解决方法:您需要在LSTM层中启用return_sequences=True以返回一个序列,而不仅仅是最后一个元素(有效地删除时间维)。这将在LSTM层产生输出形状(1, 8038, 100)。由于Dense层无法处理序列数据,因此您需要将其分别应用于每个序列元素,这可以通过将其包装在TimeDistributed包装器中完成。然后,您的模型将具有输出形状(1, 8038, 1)
您的模型应该如下所示:
from keras.layers.wrappers import TimeDistributed

model = Sequential()
model.add(LSTM(100, input_dim=num_features, return_sequences=True))
model.add(TimeDistributed(Dense(1, activation='sigmoid')))

当检查模型摘要时,可以很容易地发现这一点:

print(model.summary()) 

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接