数值错误:顺序图层的输入与图层不兼容:期望的最小维度为4,而实际维度为3。完整形状为[8,28,28]。

22

我一直收到与输入形状相关的错误。希望能得到帮助。谢谢!

import tensorflow as tf

(xtrain, ytrain), (xtest, ytest) = tf.keras.datasets.mnist.load_data()

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(16, kernel_size=3, activation='relu'),
    tf.keras.layers.MaxPooling2D(pool_size=2),
    tf.keras.layers.Conv2D(32, kernel_size=3, activation='relu'),
    tf.keras.layers.MaxPooling2D(pool_size=2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
    ])

model.compile(loss='categorical_crossentropy', 
              optimizer='adam',
              metrics='accuracy')

history = model.fit(xtrain, ytrain,
                    validation_data=(xtest, ytest),
                    epochs=10, batch_size=8)

数值错误: 序列层的输入与该层不兼容:期望的最小维度为4,但发现维度为3。接收到的完整形状为:[8, 28, 28]

3个回答

19

您创建的模型的输入层需要一个4维张量才能正常工作,但您传递给它的x_train张量只有3个维度。

这意味着您需要使用 .reshape(n_images, 286, 384, 1) 来重新整形您的训练集。 现在您已经添加了一个额外的维度而没有改变数据,您的模型已经准备好运行了。

在训练模型之前,您需要将 x_train 张量重塑为四个维度。

例如:

x_train = x_train.reshape(-1, 28, 28, 1)

有关Keras输入的更多信息,请查看此答案


14

你需要添加一个通道维度。Keras期望的数据格式如下:

(n_samples, height, width, channels)
例如,如果您的图像是灰度的,它们只有1个通道,因此需要按照以下格式提供给Keras:

对于这个例子,如果您的图像是灰度的,则只有1个通道,因此需要以此格式提供给Keras:

(60000, 28, 28, 1)

不幸的是,灰度图像通常会在没有通道维度的情况下被给予/下载,例如在tf.keras.datasets.mnist.load_data中,它将是(60000, 28, 28),这是有问题的。

解决方案:

您可以使用tf.expand_dims添加一个维度。

xtrain = tf.expand_dims(xtrain, axis=-1)

现在您的输入形状将为:

(60000, 28, 28, 1)

还有其他同样的选择:

xtrain = xtrain[..., np.newaxis]
xtrain = xtrain[..., None]
xtrain = xtrain.reshape(-1, 28, 28, 1)
xtrain = tf.reshape(xtrain, (-1, 28, 28, 1))
xtrain = np.expand_dims(xtrain, axis=-1)

请问,(60000, 28, 28, 1)中的60000是什么意思? - JunaidAfzal
2
这是样本数量。 - Nicolas Gervais
谢谢,还有一件事。我有一个形状为(170, 64, 64)的张量,但是当我增加它的维度时,它变成了(170, 4096, 1),但我想要的是:(170, 64, 64, 1),我该怎么做??? - JunaidAfzal
2
np.expand_dims(your_array, axis=-1) - Nicolas Gervais

2
您可以使用tf.newaxis添加新的轴:
arr = tf.random.uniform(shape=(8,8,28))
print(arr.shape)
arr = arr[:,:,:,tf.newaxis]
print(arr.shape)

# (8, 8, 28)
# (8, 8, 28, 1)

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接